commit de33dad47eea9beaf50856f5c7ddcc8b45e96913
Author: itqop
Date:   Fri Jul 18 20:22:57 2025 +0300

    first commit

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..a51fddd
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,42 @@
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+*.egg-info/
+dist/
+build/
+.Python
+
+.venv/
+venv/
+env/
+ENV/
+.env
+config.env
+
+.vscode/
+.idea/
+*.swp
+*.swo
+
+.DS_Store
+Thumbs.db
+Desktop.ini
+
+storage/chroma/
+!storage/chroma/.gitkeep
+*.log
+logs/
+.temporary/
+*.tmp
+*.temp
+
+.pytest_cache/
+.coverage
+htmlcov/
+.tox/
+.ruff_cache/
+
+data/local/
+data/temp/
+articles_konsol_pro/
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..faa29e9
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,20 @@
+fastapi
+uvicorn[standard]
+pydantic
+pydantic-settings
+langchain
+langchain-community
+langgraph
+chromadb
+openai
+google-generativeai
+python-multipart
+python-dotenv
+pypdf
+markdown
+python-jose[cryptography]
+passlib[bcrypt]
+tiktoken
+httpx
+pytest
+pytest-asyncio
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/app/__init__.py b/src/app/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/app/config.py b/src/app/config.py
new file mode 100644
index 0000000..4784333
--- /dev/null
+++ b/src/app/config.py
@@ -0,0 +1,31 @@
+from typing import Literal
+from pydantic_settings import BaseSettings, SettingsConfigDict
+
+
+class Settings(BaseSettings):
+    llm_provider: Literal["openai", "gemini"] = "openai"
+    llm_model: str = "gpt-4o-mini"
+    embedding_model: str = "text-embedding-3-large"
+
+    # pydantic-settings resolves these from the environment by field name
+    # (case-insensitive): OPENAI_API_KEY, GEMINI_API_KEY, API_SECRET_KEY.
+    openai_api_key: str = ""
+    gemini_api_key: str = ""
+    api_secret_key: str = "secret"
+
+    chroma_persist_dir: str = "./storage/chroma"
+    top_k: int = 30
+    top_n_context: int = 6
+    max_tokens_completion: int = 1024
+    langgraph_tracing: bool = False
+    service_lang: str = "ru"
+    sales_rep_name: str = "Команда Консоль.Про"
+    chunk_size: int = 500
+    chunk_overlap: int = 100
+    # Expose exception details in API error responses only when debugging.
+    debug: bool = False
+
+    model_config = SettingsConfigDict(env_file=".env")
+
+
+settings = Settings()
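Note: Settings loads overrides from a .env file in the project root (see model_config above). A minimal sketch of that file; the variable names follow the field names of Settings, and every value here is a placeholder:

    # .env (example values only)
    LLM_PROVIDER=openai
    LLM_MODEL=gpt-4o-mini
    OPENAI_API_KEY=sk-replace-me
    API_SECRET_KEY=change-me
    CHROMA_PERSIST_DIR=./storage/chroma
    DEBUG=false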
diff --git a/src/app/main.py b/src/app/main.py
new file mode 100644
index 0000000..4ba6d20
--- /dev/null
+++ b/src/app/main.py
@@ -0,0 +1,115 @@
+import os
+import sys
+from fastapi import FastAPI, Request
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse
+import time
+import json
+import logging
+import uuid
+
+# Three dirname() calls put the project root on sys.path so that
+# `from src.app...` imports resolve when this file is run directly.
+sys.path.append(
+    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+)
+
+from src.app.routers import generate, ingest, health
+from src.app.config import settings
+
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+app = FastAPI(
+    title="AI Email Assistant",
+    description="Персонализированная генерация холодных писем с использованием RAG и LangGraph",
+    version="1.0.0",
+    docs_url="/docs",
+    redoc_url="/redoc",
+)
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+@app.middleware("http")
+async def logging_middleware(request: Request, call_next):
+    start_time = time.time()
+
+    # uuid4 avoids the request-id collisions a bare timestamp produces
+    # under concurrent requests.
+    request_id = uuid.uuid4().hex
+
+    response = await call_next(request)
+
+    process_time = time.time() - start_time
+
+    log_data = {
+        "request_id": request_id,
+        "method": request.method,
+        "url": str(request.url),
+        "status_code": response.status_code,
+        "process_time": round(process_time, 4),
+        "client_ip": request.client.host if request.client else "unknown",
+    }
+
+    logger.info(json.dumps(log_data))
+
+    response.headers["X-Request-ID"] = request_id
+    response.headers["X-Process-Time"] = str(round(process_time, 4))
+
+    return response
+
+
+@app.exception_handler(Exception)
+async def global_exception_handler(request: Request, exc: Exception):
+    logger.error(f"Global exception handler caught: {exc}", exc_info=True)
+
+    return JSONResponse(
+        status_code=500,
+        content={
+            "error": "Internal server error",
+            "code": "INTERNAL_ERROR",
+            # Exception details are exposed only in debug mode; llm_provider
+            # is always truthy and must not gate this.
+            "detail": (
+                str(exc) if settings.debug else "An unexpected error occurred"
+            ),
+        },
+    )
+
+
+app.include_router(health.router)
+app.include_router(generate.router)
+app.include_router(ingest.router)
+
+
+@app.get("/")
+async def root():
+    return {
+        "service": "AI Email Assistant",
+        "version": "1.0.0",
+        "description": "Персонализированная генерация холодных писем",
+        "endpoints": {
+            "health": "/healthz",
+            "readiness": "/readiness",
+            "generate": "/api/v1/generate_email",
+            "admin": "/api/v1/admin/",
+            "docs": "/docs",
+        },
+        "status": "operational",
+    }
+
+
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(
+        "src.app.main:app", host="0.0.0.0", port=8000, reload=True, log_level="info"
+    )
diff --git a/src/app/routers/__init__.py b/src/app/routers/__init__.py
new file mode 100644
index 0000000..e69de29
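Note: a minimal smoke test for the generation endpoint wired up above, as a sketch only. The four required lead fields mirror the checks in input_validation.py; LeadInput may accept additional fields not shown in this commit, and all values here are invented:

    import httpx

    lead = {
        "contact": "Иван Петров",
        "position": "Технический директор",
        "company_name": "ООО Пример",
        "segment": "logistics",
    }

    resp = httpx.post(
        "http://localhost:8000/api/v1/generate_email", json=lead, timeout=120.0
    )
    resp.raise_for_status()
    print(resp.json()["subject"])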
diff --git a/src/app/routers/generate.py b/src/app/routers/generate.py
new file mode 100644
index 0000000..e7cdc81
--- /dev/null
+++ b/src/app/routers/generate.py
@@ -0,0 +1,106 @@
+from fastapi import APIRouter, HTTPException
+
+from src.models.lead import LeadInput
+from src.models.email import EmailResponse
+from src.graph.build_graph import email_generation_graph
+from src.app.config import settings
+
+router = APIRouter(prefix="/api/v1", tags=["generation"])
+
+
+@router.post("/generate_email", response_model=EmailResponse)
+async def generate_email(lead_data: LeadInput):
+    try:
+        request = lead_data.model_dump()
+
+        initial_state = {
+            "raw_input": request,
+            "lead_input": None,
+            "lead_model": None,
+            "lead_features": None,
+            "retrieval_query": None,
+            "retrieved_chunks": None,
+            "ranked_context": None,
+            "prompt_payload": None,
+            "llm_output": None,
+            "email_draft": None,
+            "email_clean": None,
+            "email_response": None,
+            "error": None,
+            "error_code": None,
+            "trace_meta": None,
+        }
+
+        result_state = email_generation_graph.invoke(initial_state)
+
+        if result_state.get("error"):
+            error_code = result_state.get("error_code", "UNKNOWN_ERROR")
+            error_message = result_state.get("error", "Unknown error occurred")
+            trace_meta = result_state.get("trace_meta", {})
+
+            if error_code == "VALIDATION_ERROR":
+                raise HTTPException(
+                    status_code=400,
+                    detail={
+                        "error": error_message,
+                        "code": error_code,
+                        "trace": trace_meta,
+                    },
+                )
+            elif error_code == "NO_RESULTS":
+                raise HTTPException(
+                    status_code=404,
+                    detail={
+                        "error": "No relevant knowledge found for this lead profile",
+                        "code": error_code,
+                    },
+                )
+            elif error_code in ["LLM_ERROR", "OPENAI_ERROR", "GEMINI_ERROR"]:
+                raise HTTPException(
+                    status_code=502,
+                    detail={
+                        "error": "External service error",
+                        "code": error_code,
+                        "details": error_message,
+                    },
+                )
+            else:
+                raise HTTPException(
+                    status_code=500,
+                    detail={
+                        "error": error_message,
+                        "code": error_code,
+                        "trace": trace_meta,
+                    },
+                )
+
+        email_response = result_state.get("email_response")
+        if not email_response:
+            raise HTTPException(
+                status_code=500,
+                detail={
+                    "error": "Failed to generate email response",
+                    "code": "MISSING_RESPONSE",
+                },
+            )
+
+        return email_response
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(
+            status_code=500,
+            detail={"error": f"Unexpected error: {str(e)}", "code": "INTERNAL_ERROR"},
+        )
+
+
+@router.get("/status")
+async def get_status():
+    return {
+        "status": "operational",
+        "service": "email-generation",
+        "version": "1.0.0",
+        "provider": settings.llm_provider,
+        "model": settings.llm_model,
+    }
diff --git a/src/app/routers/health.py b/src/app/routers/health.py
new file mode 100644
index 0000000..952918a
--- /dev/null
+++ b/src/app/routers/health.py
@@ -0,0 +1,99 @@
+from datetime import datetime, timezone
+
+from fastapi import APIRouter
+from src.app.config import settings
+
+router = APIRouter(tags=["health"])
+
+
+@router.get("/healthz")
+async def health_check():
+    try:
+        from src.services.chroma_store import ChromaStore
+        from src.services.embeddings import EmbeddingService
+
+        chroma_store = ChromaStore()
+        doc_count = chroma_store.get_count()
+
+        # Instantiated only to verify the embedding backend is configured.
+        EmbeddingService()
+
+        health_status = {
+            "status": "healthy",
+            "service": "ai-email-assistant",
+            "version": "1.0.0",
+            "timestamp": datetime.now(timezone.utc).isoformat(),
+            "components": {
+                "knowledge_base": {
+                    "status": "healthy" if doc_count > 0 else "warning",
+                    "documents": doc_count,
+                    "persist_dir": settings.chroma_persist_dir,
+                },
+                "embedding_service": {
+                    "status": "healthy",
+                    "provider": settings.llm_provider,
+                    "model": settings.embedding_model,
+                },
+                "llm_service": {
+                    "status": "healthy",
+                    "provider": settings.llm_provider,
+                    "model": settings.llm_model,
+                },
+            },
+            "configuration": {
+                "top_k": settings.top_k,
+                "top_n_context": settings.top_n_context,
+                "max_tokens": settings.max_tokens_completion,
+                "language": settings.service_lang,
+            },
+        }
+
+        if doc_count == 0:
+            health_status["status"] = "degraded"
+            health_status["warnings"] = ["No documents in knowledge base"]
+
+        return health_status
+
+    except Exception as e:
+        return {
+            "status": "unhealthy",
+            "service": "ai-email-assistant",
+            "error": str(e),
+            "timestamp": datetime.now(timezone.utc).isoformat(),
+        }
+
+
+@router.get("/readiness")
+async def readiness_check():
+    try:
+        from src.services.chroma_store import ChromaStore
+
+        chroma_store = ChromaStore()
+        doc_count = chroma_store.get_count()
+
+        if doc_count == 0:
+            return {
+                "ready": False,
+                "reason": "Knowledge base is empty - run ingestion first",
+            }
+
+        api_key_configured = bool(
+            settings.openai_api_key
+            if settings.llm_provider == "openai"
+            else settings.gemini_api_key
+        )
+
+        if not api_key_configured:
+            return {
+                "ready": False,
+                "reason": f"API key not configured for {settings.llm_provider}",
+            }
+
+        return {
+            "ready": True,
+            "documents": doc_count,
+            "provider": settings.llm_provider,
+        }
+
+    except Exception as e:
+        return {"ready": False, "reason": f"Service check failed: {str(e)}"}
diff --git a/src/app/routers/ingest.py b/src/app/routers/ingest.py
new file mode 100644
index 0000000..5404782
--- /dev/null
+++ b/src/app/routers/ingest.py
@@ -0,0 +1,108 @@
+import os
+from fastapi import APIRouter, HTTPException, Depends, Header
+from typing import Optional
+
+from src.app.config import settings
+
+router = APIRouter(prefix="/api/v1/admin", tags=["administration"]) + + +def verify_admin_token(authorization: Optional[str] = Header(None)): + if not authorization: + raise HTTPException(status_code=401, detail="Authorization header required") + + if not authorization.startswith("Bearer "): + raise HTTPException(status_code=401, detail="Invalid authorization format") + + token = authorization.replace("Bearer ", "") + if token != settings.api_secret_key: + raise HTTPException(status_code=403, detail="Invalid admin token") + + return True + + +@router.post("/ingest") +async def trigger_ingest(recreate: bool = False, _: bool = Depends(verify_admin_token)): + try: + from src.ingest.ingest_cli import run_ingest + + data_dir = "articles_konsol_pro" + platform_file = "data/platform_overview.md" + persist_dir = settings.chroma_persist_dir + + if not os.path.exists(data_dir): + raise HTTPException( + status_code=404, detail=f"Data directory not found: {data_dir}" + ) + + result = run_ingest( + data_dir=data_dir, + platform_md=platform_file, + persist_dir=persist_dir, + recreate=recreate, + chunk_size=settings.chunk_size, + chunk_overlap=settings.chunk_overlap, + embedding_model=settings.embedding_model, + ) + + return { + "status": "success", + "message": "Knowledge base updated successfully", + "details": result, + } + + except Exception as e: + raise HTTPException( + status_code=500, + detail={"error": f"Ingest failed: {str(e)}", "code": "INGEST_ERROR"}, + ) + + +@router.get("/knowledge-base/stats") +async def get_knowledge_base_stats(_: bool = Depends(verify_admin_token)): + try: + from src.services.chroma_store import ChromaStore + + chroma_store = ChromaStore() + doc_count = chroma_store.get_count() + + collection = chroma_store.get_or_create_collection() + + sample_docs = collection.peek(limit=5) + + return { + "total_documents": doc_count, + "collection_name": chroma_store.collection_name, + "persist_directory": chroma_store.persist_directory, + "sample_metadata": [ + doc.get("industry", []) if doc else [] + for doc in (sample_docs.get("metadatas", []) or []) + ], + "embedding_model": settings.embedding_model, + } + + except Exception as e: + raise HTTPException( + status_code=500, + detail={"error": f"Failed to get stats: {str(e)}", "code": "STATS_ERROR"}, + ) + + +@router.delete("/knowledge-base") +async def clear_knowledge_base(_: bool = Depends(verify_admin_token)): + try: + from src.services.chroma_store import ChromaStore + + chroma_store = ChromaStore() + chroma_store.delete_collection() + + return {"status": "success", "message": "Knowledge base cleared successfully"} + + except Exception as e: + raise HTTPException( + status_code=500, + detail={ + "error": f"Failed to clear knowledge base: {str(e)}", + "code": "CLEAR_ERROR", + }, + ) diff --git a/src/graph/__init__.py b/src/graph/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/graph/build_graph.py b/src/graph/build_graph.py new file mode 100644 index 0000000..9c3a9bb --- /dev/null +++ b/src/graph/build_graph.py @@ -0,0 +1,82 @@ +from langgraph.graph import StateGraph, END +from typing import Dict, Any + +from src.graph.state import EmailGenerationState +from src.graph.nodes.input_validation import input_validation_node +from src.graph.nodes.feature_extract import feature_extract_node +from src.graph.nodes.build_query import build_query_node +from src.graph.nodes.vector_search import vector_search_node +from src.graph.nodes.context_rank import context_rank_node +from src.graph.nodes.prompt_build import 
prompt_build_node +from src.graph.nodes.llm_generate import llm_generate_node +from src.graph.nodes.parse_output import parse_output_node +from src.graph.nodes.guardrails import guardrails_node +from src.graph.nodes.return_result import return_result_node + + +def should_continue(state: EmailGenerationState) -> str: + if state.get("error"): + return "error" + return "continue" + + +def create_email_graph() -> StateGraph: + workflow = StateGraph(EmailGenerationState) + + workflow.add_node("input_validation", input_validation_node) + workflow.add_node("feature_extract", feature_extract_node) + workflow.add_node("build_query", build_query_node) + workflow.add_node("vector_search", vector_search_node) + workflow.add_node("context_rank", context_rank_node) + workflow.add_node("prompt_build", prompt_build_node) + workflow.add_node("llm_generate", llm_generate_node) + workflow.add_node("parse_output", parse_output_node) + workflow.add_node("guardrails", guardrails_node) + workflow.add_node("return_result", return_result_node) + + workflow.set_entry_point("input_validation") + + workflow.add_conditional_edges( + "input_validation", + should_continue, + {"continue": "feature_extract", "error": END}, + ) + + workflow.add_conditional_edges( + "feature_extract", should_continue, {"continue": "build_query", "error": END} + ) + + workflow.add_conditional_edges( + "build_query", should_continue, {"continue": "vector_search", "error": END} + ) + + workflow.add_conditional_edges( + "vector_search", should_continue, {"continue": "context_rank", "error": END} + ) + + workflow.add_conditional_edges( + "context_rank", should_continue, {"continue": "prompt_build", "error": END} + ) + + workflow.add_conditional_edges( + "prompt_build", should_continue, {"continue": "llm_generate", "error": END} + ) + + workflow.add_conditional_edges( + "llm_generate", should_continue, {"continue": "parse_output", "error": END} + ) + + workflow.add_conditional_edges( + "parse_output", should_continue, {"continue": "guardrails", "error": END} + ) + + workflow.add_conditional_edges( + "guardrails", should_continue, {"continue": "return_result", "error": END} + ) + + workflow.add_edge("return_result", END) + + return workflow.compile() + + +email_generation_graph = create_email_graph() diff --git a/src/graph/nodes/__init__.py b/src/graph/nodes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/graph/nodes/build_query.py b/src/graph/nodes/build_query.py new file mode 100644 index 0000000..c0687f6 --- /dev/null +++ b/src/graph/nodes/build_query.py @@ -0,0 +1,22 @@ +from src.graph.state import EmailGenerationState +from src.services.retrieval import RetrievalService + + +def build_query_node(state: EmailGenerationState) -> EmailGenerationState: + try: + lead_features = state.get("lead_features") + if not lead_features: + state["error"] = "Lead features are required for query building" + state["error_code"] = "MISSING_FEATURES" + return state + + retrieval_service = RetrievalService(None, None) + retrieval_query = retrieval_service.build_retrieval_query(lead_features) + + state["retrieval_query"] = retrieval_query + return state + + except Exception as e: + state["error"] = f"Query building error: {str(e)}" + state["error_code"] = "QUERY_BUILD_ERROR" + return state diff --git a/src/graph/nodes/context_rank.py b/src/graph/nodes/context_rank.py new file mode 100644 index 0000000..2a88d75 --- /dev/null +++ b/src/graph/nodes/context_rank.py @@ -0,0 +1,36 @@ +from src.graph.state import EmailGenerationState +from 
src.services.chroma_store import ChromaStore +from src.services.embeddings import EmbeddingService +from src.services.retrieval import RetrievalService + + +def context_rank_node(state: EmailGenerationState) -> EmailGenerationState: + try: + retrieved_chunks = state.get("retrieved_chunks") + lead_features = state.get("lead_features") + + if not retrieved_chunks: + state["error"] = "Retrieved chunks are required for ranking" + state["error_code"] = "MISSING_CHUNKS" + return state + + if not lead_features: + state["error"] = "Lead features are required for ranking" + state["error_code"] = "MISSING_FEATURES" + return state + + chroma_store = ChromaStore() + embedding_service = EmbeddingService() + retrieval_service = RetrievalService(chroma_store, embedding_service) + + ranked_chunks = retrieval_service.rank_chunks(retrieved_chunks, lead_features) + + ranked_context = retrieval_service.build_context(ranked_chunks) + + state["ranked_context"] = ranked_context + return state + + except Exception as e: + state["error"] = f"Context ranking error: {str(e)}" + state["error_code"] = "RANKING_ERROR" + return state diff --git a/src/graph/nodes/feature_extract.py b/src/graph/nodes/feature_extract.py new file mode 100644 index 0000000..d11ae80 --- /dev/null +++ b/src/graph/nodes/feature_extract.py @@ -0,0 +1,126 @@ +from typing import List +from src.graph.state import EmailGenerationState +from src.models.lead import LeadModel, LeadFeatures + + +def feature_extract_node(state: EmailGenerationState) -> EmailGenerationState: + try: + lead_input = state.get("lead_input") + if not lead_input: + state["error"] = "Lead input is required for feature extraction" + state["error_code"] = "MISSING_INPUT" + return state + + lead_model = LeadModel.from_input(lead_input) + + role_category = lead_model.role_category or "other" + industry_tag = lead_model.industry_tag or "other" + + pain_points = _get_pain_points(industry_tag, role_category) + key_benefits = _get_key_benefits(industry_tag, role_category) + search_keywords = _get_search_keywords(industry_tag, role_category) + + lead_features = LeadFeatures( + role_category=role_category, + industry_tag=industry_tag, + pain_points=pain_points, + key_benefits=key_benefits, + search_keywords=search_keywords, + ) + + state["lead_model"] = lead_model + state["lead_features"] = lead_features + return state + + except Exception as e: + state["error"] = f"Feature extraction error: {str(e)}" + state["error_code"] = "FEATURE_EXTRACTION_ERROR" + return state + + +def _get_pain_points(industry_tag: str, role_category: str) -> List[str]: + industry_pain_points = { + "marketing_agency": [ + "быстрое подключение подрядчиков", + "массовые выплаты", + "отчетность", + ], + "logistics": ["управление исполнителями", "документооборот", "география"], + "software": ["работа с фрилансерами", "ИП", "интеграции"], + "retail": ["сезонность", "временные сотрудники", "масштабирование"], + "consulting": ["привлечение экспертов", "проектная работа"], + "construction": ["подрядчики", "документы", "сроки"], + "other": ["автоматизация", "выплаты", "документооборот"], + } + + role_pain_points = { + "tech": ["интеграции", "автоматизация", "API"], + "finance": ["выплаты", "отчетность", "налоги"], + "ops": ["процессы", "управление", "координация"], + "hr": ["онбординг", "документы", "персонал"], + "ceo": ["эффективность", "масштабирование", "затраты"], + "sales": ["процессы продаж", "клиенты"], + "marketing": ["кампании", "подрядчики"], + "other": ["общие бизнес-процессы"], + } + + pain_points = 
industry_pain_points.get(industry_tag, industry_pain_points["other"]) + pain_points.extend(role_pain_points.get(role_category, role_pain_points["other"])) + + return list(set(pain_points)) + + +def _get_key_benefits(industry_tag: str, role_category: str) -> List[str]: + benefits = [ + "быстрое подключение исполнителей", + "автоматические выплаты", + "сбор документов", + "снижение ошибок", + "экономия времени", + ] + + if industry_tag == "marketing_agency": + benefits.extend(["масштабирование команды", "управление проектами"]) + elif industry_tag == "logistics": + benefits.extend(["географическое покрытие", "отслеживание"]) + elif industry_tag == "software": + benefits.extend(["API интеграции", "техническая поддержка"]) + + if role_category == "tech": + benefits.extend(["техническая интеграция", "автоматизация"]) + elif role_category == "finance": + benefits.extend(["финансовая отчетность", "налоговое планирование"]) + elif role_category == "ops": + benefits.extend(["операционная эффективность", "процессы"]) + + return list(set(benefits)) + + +def _get_search_keywords(industry_tag: str, role_category: str) -> List[str]: + keywords = ["самозанятые", "исполнители", "автоматизация", "выплаты"] + + industry_keywords = { + "marketing_agency": ["маркетинг", "агентство", "подрядчики", "креатив"], + "logistics": ["логистика", "доставка", "перевозки", "склад"], + "software": ["разработка", "IT", "программирование", "фриланс"], + "retail": ["розница", "торговля", "продажи", "сезон"], + "consulting": ["консалтинг", "эксперты", "проекты"], + "construction": ["строительство", "подрядчики", "объекты"], + "other": ["бизнес", "процессы", "команда"], + } + + role_keywords = { + "tech": ["технический", "IT", "разработка", "интеграция"], + "finance": ["финансы", "бухгалтерия", "учет", "налоги"], + "ops": ["операции", "управление", "процессы"], + "hr": ["персонал", "HR", "найм", "кадры"], + "ceo": ["руководство", "стратегия", "развитие"], + "sales": ["продажи", "клиенты", "сделки"], + "marketing": ["маркетинг", "реклама", "продвижение"], + "other": ["общие", "универсальные"], + } + + keywords.extend(industry_keywords.get(industry_tag, industry_keywords["other"])) + keywords.extend(role_keywords.get(role_category, role_keywords["other"])) + + return list(set(keywords)) diff --git a/src/graph/nodes/guardrails.py b/src/graph/nodes/guardrails.py new file mode 100644 index 0000000..eab5d43 --- /dev/null +++ b/src/graph/nodes/guardrails.py @@ -0,0 +1,201 @@ +import re +from typing import List, Dict, Any +from src.graph.state import EmailGenerationState +from src.models.email import EmailDraftClean +from src.models.errors import GuardrailsError + + +def guardrails_node(state: EmailGenerationState) -> EmailGenerationState: + try: + email_draft = state.get("email_draft") + if not email_draft: + state["error"] = "Email draft is required for guardrails check" + state["error_code"] = "MISSING_DRAFT" + return state + + violations = [] + + pii_violations = _check_pii_leakage( + email_draft.subject + " " + email_draft.body + ) + if pii_violations: + violations.extend(pii_violations) + + false_claims = _check_false_claims(email_draft.body, email_draft.used_chunks) + if false_claims: + violations.extend(false_claims) + + tone_violations = _check_tone_appropriateness(email_draft.body) + if tone_violations: + violations.extend(tone_violations) + + if violations: + cleaned_subject, cleaned_body = _apply_fixes( + email_draft.subject, email_draft.body, violations + ) + else: + cleaned_subject = email_draft.subject + cleaned_body = 
email_draft.body + + llm_output = state.get("llm_output") + lead_model = state.get("lead_model") + ranked_context = state.get("ranked_context") + + meta = { + "locale": lead_model.locale if lead_model else "ru", + "lead_normalized": lead_model.dict() if lead_model else {}, + "used_chunks": email_draft.used_chunks, + "model": llm_output.model if llm_output else "unknown", + "tokens_prompt": llm_output.tokens_prompt if llm_output else 0, + "tokens_completion": llm_output.tokens_completion if llm_output else 0, + "guardrails_violations": len(violations), + "context_chunks_used": len(ranked_context.chunks) if ranked_context else 0, + } + + email_clean = EmailDraftClean( + subject=cleaned_subject, body=cleaned_body, meta=meta + ) + + state["email_clean"] = email_clean + return state + + except GuardrailsError as e: + state["error"] = e.message + state["error_code"] = "GUARDRAILS_ERROR" + state["trace_meta"] = {"violation_type": e.violation_type, "details": e.details} + return state + + except Exception as e: + state["error"] = f"Guardrails check error: {str(e)}" + state["error_code"] = "GUARDRAILS_ERROR" + return state + + +def _check_pii_leakage(text: str) -> List[Dict[str, Any]]: + violations = [] + + email_pattern = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b" + phone_pattern = r"\b(?:\+7|8)?[\s\-]?\(?[0-9]{3}\)?[\s\-]?[0-9]{3}[\s\-]?[0-9]{2}[\s\-]?[0-9]{2}\b" + + if re.search(email_pattern, text): + violations.append( + { + "type": "pii_email", + "message": "Email address detected in content", + "severity": "medium", + } + ) + + if re.search(phone_pattern, text): + violations.append( + { + "type": "pii_phone", + "message": "Phone number detected in content", + "severity": "medium", + } + ) + + return violations + + +def _check_false_claims(body: str, used_chunks: List[str]) -> List[Dict[str, Any]]: + violations = [] + + specific_numbers = re.findall( + r"\b\d+%|\b\d+\s*минут|\b\d+\s*дн[ейя]|\b\d+\s*раз", body + ) + + if specific_numbers and not used_chunks: + violations.append( + { + "type": "unverified_metrics", + "message": "Specific metrics mentioned without supporting context", + "severity": "high", + "numbers": specific_numbers, + } + ) + + guarantee_words = ["гарантируем", "гарантия", "обещаем", "100% результат"] + for word in guarantee_words: + if word.lower() in body.lower(): + violations.append( + { + "type": "false_guarantee", + "message": f"Strong guarantee detected: {word}", + "severity": "high", + } + ) + + return violations + + +def _check_tone_appropriateness(body: str) -> List[Dict[str, Any]]: + violations = [] + + aggressive_words = [ + "срочно", + "немедленно", + "прямо сейчас", + "только сегодня", + "последний шанс", + ] + for word in aggressive_words: + if word.lower() in body.lower(): + violations.append( + { + "type": "aggressive_tone", + "message": f"Aggressive language detected: {word}", + "severity": "medium", + } + ) + + if body.count("!") > 2: + violations.append( + { + "type": "excessive_exclamation", + "message": "Too many exclamation marks", + "severity": "low", + } + ) + + return violations + + +def _apply_fixes( + subject: str, body: str, violations: List[Dict[str, Any]] +) -> tuple[str, str]: + fixed_subject = subject + fixed_body = body + + for violation in violations: + if violation["type"] == "false_guarantee": + replacements = { + "гарантируем": "помогаем", + "гарантия": "возможность", + "обещаем": "стремимся", + "100% результат": "хорошие результаты", + } + for old, new in replacements.items(): + fixed_body = re.sub(old, new, fixed_body, 
flags=re.IGNORECASE) + + elif violation["type"] == "aggressive_tone": + replacements = { + "срочно": "", + "немедленно": "", + "прямо сейчас": "", + "только сегодня": "", + "последний шанс": "", + } + for old, new in replacements.items(): + fixed_body = re.sub(old, new, fixed_body, flags=re.IGNORECASE) + + elif violation["type"] == "excessive_exclamation": + fixed_body = re.sub(r"!{2,}", "!", fixed_body) + fixed_subject = re.sub(r"!{2,}", "!", fixed_subject) + + elif violation["type"] == "unverified_metrics": + for number in violation.get("numbers", []): + replacement = f"в кейсах клиентов {number}" + fixed_body = fixed_body.replace(number, replacement) + + return fixed_subject.strip(), fixed_body.strip() diff --git a/src/graph/nodes/input_validation.py b/src/graph/nodes/input_validation.py new file mode 100644 index 0000000..15b56dc --- /dev/null +++ b/src/graph/nodes/input_validation.py @@ -0,0 +1,43 @@ +from typing import Dict, Any +from src.graph.state import EmailGenerationState +from src.models.lead import LeadInput, LeadModel +from src.models.errors import ValidationError + + +def input_validation_node(state: EmailGenerationState) -> EmailGenerationState: + try: + raw_input = state.get("raw_input") + if not raw_input: + state["error"] = "Raw input is required" + state["error_code"] = "MISSING_INPUT" + return state + + lead_input = LeadInput(**raw_input) + + if not lead_input.contact.strip(): + raise ValidationError("Contact cannot be empty") + + if not lead_input.position.strip(): + raise ValidationError("Position cannot be empty") + + if not lead_input.company_name.strip(): + raise ValidationError("Company name cannot be empty") + + if not lead_input.segment.strip(): + raise ValidationError("Segment cannot be empty") + + lead_model = LeadModel.from_input(lead_input) + + state["lead_input"] = lead_input + state["lead_model"] = lead_model + + return state + + except ValidationError as e: + state["error"] = str(e) + state["error_code"] = "VALIDATION_ERROR" + return state + except Exception as e: + state["error"] = f"Input validation failed: {str(e)}" + state["error_code"] = "VALIDATION_FAILED" + return state diff --git a/src/graph/nodes/llm_generate.py b/src/graph/nodes/llm_generate.py new file mode 100644 index 0000000..38c1175 --- /dev/null +++ b/src/graph/nodes/llm_generate.py @@ -0,0 +1,48 @@ +from src.graph.state import EmailGenerationState +from src.llm_clients.openai_client import OpenAIClient +from src.llm_clients.gemini_client import GeminiClient +from src.models.errors import LLMError +from src.app.config import settings + + +def llm_generate_node(state: EmailGenerationState) -> EmailGenerationState: + try: + prompt_payload = state.get("prompt_payload") + if not prompt_payload: + state["error"] = "Prompt payload is required for LLM generation" + state["error_code"] = "MISSING_PROMPT" + return state + + if settings.llm_provider == "openai": + llm_client = OpenAIClient() + elif settings.llm_provider == "gemini": + llm_client = GeminiClient() + else: + state["error"] = f"Unsupported LLM provider: {settings.llm_provider}" + state["error_code"] = "UNSUPPORTED_PROVIDER" + return state + + llm_output = llm_client.generate_completion( + system_prompt=prompt_payload.system_prompt, + user_prompt=prompt_payload.user_prompt, + max_tokens=prompt_payload.max_tokens, + temperature=prompt_payload.temperature, + ) + + state["llm_output"] = llm_output + return state + + except LLMError as e: + state["error"] = e.message + state["error_code"] = "LLM_ERROR" + state["trace_meta"] = { + 
"provider": e.provider, + "model": e.model, + "details": e.details, + } + return state + + except Exception as e: + state["error"] = f"LLM generation error: {str(e)}" + state["error_code"] = "GENERATION_ERROR" + return state diff --git a/src/graph/nodes/parse_output.py b/src/graph/nodes/parse_output.py new file mode 100644 index 0000000..24ea513 --- /dev/null +++ b/src/graph/nodes/parse_output.py @@ -0,0 +1,141 @@ +import json +import re +from src.graph.state import EmailGenerationState +from src.models.email import EmailDraft +from src.models.errors import ParseError + + +def parse_output_node(state: EmailGenerationState) -> EmailGenerationState: + try: + llm_output = state.get("llm_output") + if not llm_output: + state["error"] = "LLM output is required for parsing" + state["error_code"] = "MISSING_LLM_OUTPUT" + return state + + content = llm_output.content.strip() + + content = _clean_json_content(content) + + try: + parsed_data = json.loads(content) + except json.JSONDecodeError as e: + parsed_data = _fallback_parse(content) + + if not isinstance(parsed_data, dict): + raise ParseError("Response is not a JSON object", content) + + subject = parsed_data.get("subject", "").strip() + body = parsed_data.get("body", "").strip() + short_reasoning = parsed_data.get("short_reasoning", "") + used_chunks = parsed_data.get("used_chunks", []) + + if not subject: + raise ParseError("Subject is required", content) + + if not body: + raise ParseError("Body is required", content) + + subject = _validate_subject(subject) + body = _validate_body(body) + + email_draft = EmailDraft( + subject=subject, + body=body, + short_reasoning=short_reasoning, + used_chunks=used_chunks if isinstance(used_chunks, list) else [], + ) + + state["email_draft"] = email_draft + return state + + except ParseError as e: + state["error"] = e.message + state["error_code"] = "PARSE_ERROR" + state["trace_meta"] = {"raw_output": e.raw_output[:500], "details": e.details} + return state + + except Exception as e: + state["error"] = f"Output parsing error: {str(e)}" + state["error_code"] = "PARSING_ERROR" + return state + + +def _clean_json_content(content: str) -> str: + content = re.sub(r"^```json\s*", "", content) + content = re.sub(r"\s*```$", "", content) + content = re.sub(r"^```\s*", "", content) + content = content.strip() + return content + + +def _fallback_parse(content: str) -> dict: + lines = content.split("\n") + result = {} + + current_key = None + current_value = [] + + for line in lines: + line = line.strip() + if ":" in line and line.startswith('"') and line.count('"') >= 4: + if current_key: + result[current_key] = "\n".join(current_value) + + parts = line.split(":", 1) + current_key = parts[0].strip('"').strip() + current_value = [parts[1].strip().strip(",").strip('"')] + elif current_key: + current_value.append(line.strip(",").strip('"')) + + if current_key: + result[current_key] = "\n".join(current_value) + + return result + + +def _validate_subject(subject: str) -> str: + if len(subject) > 80: + words = subject.split() + truncated = [] + char_count = 0 + + for word in words: + if char_count + len(word) + 1 <= 77: + truncated.append(word) + char_count += len(word) + 1 + else: + break + + subject = " ".join(truncated) + "..." 
+ + spam_patterns = [ + r"(!{2,})", + r"(СКИДКА|АКЦИЯ|СРОЧНО|БЕСПЛАТНО)", + r"(\$|\€|\₽)", + ] + + for pattern in spam_patterns: + subject = re.sub(pattern, "", subject, flags=re.IGNORECASE) + + return subject.strip() + + +def _validate_body(body: str) -> str: + if len(body) > 2000: + body = body[:1950] + "..." + + required_elements = {"greeting": False, "company_mention": False, "cta": False} + + greetings = ["добрый день", "здравствуйте", "приветствую"] + if any(greeting in body.lower() for greeting in greetings): + required_elements["greeting"] = True + + if "консоль" in body.lower(): + required_elements["company_mention"] = True + + cta_phrases = ["звонок", "демо", "встреча", "обсудить", "покажу"] + if any(phrase in body.lower() for phrase in cta_phrases): + required_elements["cta"] = True + + return body.strip() diff --git a/src/graph/nodes/prompt_build.py b/src/graph/nodes/prompt_build.py new file mode 100644 index 0000000..2c769ed --- /dev/null +++ b/src/graph/nodes/prompt_build.py @@ -0,0 +1,40 @@ +from src.graph.state import EmailGenerationState +from src.services.prompt_templates import PromptBuilder +from src.models.email import PromptPayload +from src.app.config import settings + + +def prompt_build_node(state: EmailGenerationState) -> EmailGenerationState: + try: + lead_model = state.get("lead_model") + ranked_context = state.get("ranked_context") + + if not lead_model: + state["error"] = "Lead model is required for prompt building" + state["error_code"] = "MISSING_LEAD" + return state + + if not ranked_context: + state["error"] = "Ranked context is required for prompt building" + state["error_code"] = "MISSING_CONTEXT" + return state + + prompt_builder = PromptBuilder() + system_prompt, user_prompt = prompt_builder.build_prompt( + lead_model, ranked_context + ) + + prompt_payload = PromptPayload( + system_prompt=system_prompt, + user_prompt=user_prompt, + max_tokens=settings.max_tokens_completion, + temperature=0.7, + ) + + state["prompt_payload"] = prompt_payload + return state + + except Exception as e: + state["error"] = f"Prompt building error: {str(e)}" + state["error_code"] = "PROMPT_BUILD_ERROR" + return state diff --git a/src/graph/nodes/return_result.py b/src/graph/nodes/return_result.py new file mode 100644 index 0000000..d1c25a4 --- /dev/null +++ b/src/graph/nodes/return_result.py @@ -0,0 +1,23 @@ +from src.graph.state import EmailGenerationState +from src.models.email import EmailResponse + + +def return_result_node(state: EmailGenerationState) -> EmailGenerationState: + try: + email_clean = state.get("email_clean") + if not email_clean: + state["error"] = "Clean email is required for result return" + state["error_code"] = "MISSING_CLEAN_EMAIL" + return state + + email_response = EmailResponse( + subject=email_clean.subject, body=email_clean.body, meta=email_clean.meta + ) + + state["email_response"] = email_response + return state + + except Exception as e: + state["error"] = f"Result return error: {str(e)}" + state["error_code"] = "RETURN_ERROR" + return state diff --git a/src/graph/nodes/vector_search.py b/src/graph/nodes/vector_search.py new file mode 100644 index 0000000..b759564 --- /dev/null +++ b/src/graph/nodes/vector_search.py @@ -0,0 +1,39 @@ +from src.graph.state import EmailGenerationState +from src.services.chroma_store import ChromaStore +from src.services.embeddings import EmbeddingService +from src.services.retrieval import RetrievalService +from src.models.errors import RetrievalError + + +def vector_search_node(state: EmailGenerationState) -> 
EmailGenerationState: + try: + retrieval_query = state.get("retrieval_query") + if not retrieval_query: + state["error"] = "Retrieval query is required for vector search" + state["error_code"] = "MISSING_QUERY" + return state + + chroma_store = ChromaStore() + embedding_service = EmbeddingService() + retrieval_service = RetrievalService(chroma_store, embedding_service) + + chunks = retrieval_service.search_relevant_chunks(retrieval_query) + + if not chunks: + state["error"] = "No relevant documents found" + state["error_code"] = "NO_RESULTS" + return state + + state["retrieved_chunks"] = chunks + return state + + except RetrievalError as e: + state["error"] = e.message + state["error_code"] = "RETRIEVAL_ERROR" + state["trace_meta"] = e.details + return state + + except Exception as e: + state["error"] = f"Vector search error: {str(e)}" + state["error_code"] = "SEARCH_ERROR" + return state diff --git a/src/graph/state.py b/src/graph/state.py new file mode 100644 index 0000000..78eaa9f --- /dev/null +++ b/src/graph/state.py @@ -0,0 +1,30 @@ +from typing import Optional, List, Dict, Any, TypedDict +from src.models.lead import LeadInput, LeadModel, LeadFeatures +from src.models.email import ( + DocChunk, + RetrievalQuery, + RankedContext, + PromptPayload, + LLMRawOutput, + EmailDraft, + EmailDraftClean, + EmailResponse, +) + + +class EmailGenerationState(TypedDict): + raw_input: Optional[Dict[str, Any]] + lead_input: Optional[LeadInput] + lead_model: Optional[LeadModel] + lead_features: Optional[LeadFeatures] + retrieval_query: Optional[RetrievalQuery] + retrieved_chunks: Optional[List[DocChunk]] + ranked_context: Optional[RankedContext] + prompt_payload: Optional[PromptPayload] + llm_output: Optional[LLMRawOutput] + email_draft: Optional[EmailDraft] + email_clean: Optional[EmailDraftClean] + email_response: Optional[EmailResponse] + error: Optional[str] + error_code: Optional[str] + trace_meta: Optional[Dict[str, Any]] diff --git a/src/ingest/__init__.py b/src/ingest/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/ingest/chunker.py b/src/ingest/chunker.py new file mode 100644 index 0000000..8e78b6c --- /dev/null +++ b/src/ingest/chunker.py @@ -0,0 +1,329 @@ +import re +import tiktoken +from typing import List, Dict, Any +from dataclasses import dataclass + +from src.ingest.loader import Document +from src.models.email import DocChunk + + +@dataclass +class ChunkConfig: + chunk_size: int = 500 + chunk_overlap: int = 100 + min_chunk_size: int = 50 + preserve_context: bool = True + + +class DocumentChunker: + def __init__(self, config: ChunkConfig = None): + self.config = config or ChunkConfig() + self.encoding = tiktoken.get_encoding("cl100k_base") + + def chunk_document(self, document: Document) -> List[DocChunk]: + chunks = [] + + sections = self._split_by_semantic_blocks(document.content) + + if self.config.preserve_context: + sections = self._add_context_bridges(sections) + + chunk_counter = 0 + for section_title, section_content in sections: + section_chunks = self._chunk_section( + section_content, document, section_title, chunk_counter + ) + chunks.extend(section_chunks) + chunk_counter += len(section_chunks) + + return chunks + + def chunk_documents(self, documents: List[Document]) -> List[DocChunk]: + all_chunks = [] + + for document in documents: + doc_chunks = self.chunk_document(document) + all_chunks.extend(doc_chunks) + + return all_chunks + + def _add_context_bridges( + self, sections: List[tuple[str, str]] + ) -> List[tuple[str, str]]: + if len(sections) 
<= 1: + return sections + + enhanced_sections = [] + + for i, (title, content) in enumerate(sections): + enhanced_content = content + + if i > 0: + prev_title, prev_content = sections[i - 1] + if self._should_add_context(prev_title, title): + context_snippet = self._extract_key_context(prev_content) + if context_snippet: + enhanced_content = f"Контекст: {context_snippet}\n\n{content}" + + enhanced_sections.append((title, enhanced_content)) + + return enhanced_sections + + def _should_add_context(self, prev_title: str, current_title: str) -> bool: + context_pairs = [ + ("Проблема", "Решение"), + ("Решение", "Результат"), + ("О клиенте", "Проблема"), + ] + + for prev_type, curr_type in context_pairs: + if ( + prev_type.lower() in prev_title.lower() + and curr_type.lower() in current_title.lower() + ): + return True + return False + + def _extract_key_context(self, content: str) -> str: + sentences = content.split(".")[:2] + context = ". ".join(sentences).strip() + if len(context) > 150: + context = context[:150] + "..." + return context + + def _split_by_semantic_blocks(self, content: str) -> List[tuple[str, str]]: + sections = [] + lines = content.split("\n") + + current_section = "" + current_title = "Введение" + + semantic_markers = [ + ("Проблема", ["проблема", "вызов", "трудност", "сложност"]), + ("Решение", ["решение", "как", "что сделали", "подход"]), + ("Результат", ["результат", "итог", "достижени", "эффект"]), + ("О клиенте", ["о клиенте", "о компании", "клиент"]), + ] + + for line in lines: + line_stripped = line.strip() + + header_match = re.match(r"^(#{1,3})\s+(.+)$", line_stripped) + + if header_match: + if current_section.strip(): + sections.append((current_title, current_section.strip())) + + new_title = header_match.group(2) + + semantic_title = self._classify_section(new_title.lower()) + current_title = semantic_title if semantic_title else new_title + current_section = "" + else: + current_section += line + "\n" + + if current_section.strip(): + sections.append((current_title, current_section.strip())) + + if not sections: + sections = [("Основной контент", content)] + + merged_sections = self._merge_small_sections(sections) + + return merged_sections + + def _classify_section(self, title_lower: str) -> str: + classifications = { + "Проблема": [ + "проблема", + "вызов", + "трудност", + "сложност", + "боль", + "до внедрения", + "было", + "раньше", + ], + "Решение": [ + "решение", + "как", + "что сделали", + "подход", + "внедрение", + "переход", + "процесс", + "теперь", + ], + "Результат": [ + "результат", + "итог", + "достижени", + "эффект", + "выгода", + "метрики", + "экономия", + "улучшен", + ], + "О клиенте": ["о клиенте", "о компании", "клиент", "заказчик", "компания"], + "Процесс": [ + "как работает", + "алгоритм", + "этапы", + "процесс", + "схема", + "шаги", + ], + } + + for semantic_name, keywords in classifications.items(): + if any(keyword in title_lower for keyword in keywords): + return semantic_name + + return None + + def _merge_small_sections( + self, sections: List[tuple[str, str]] + ) -> List[tuple[str, str]]: + merged = [] + i = 0 + + while i < len(sections): + title, content = sections[i] + token_count = len(self.encoding.encode(content)) + + if token_count < self.config.min_chunk_size and i + 1 < len(sections): + next_title, next_content = sections[i + 1] + merged_content = f"{content}\n\n{next_content}" + merged_title = f"{title} / {next_title}" + merged.append((merged_title, merged_content)) + i += 2 + else: + merged.append((title, content)) + i += 1 + 
+ return merged + + def _chunk_section( + self, + section_content: str, + document: Document, + section_title: str, + start_counter: int, + ) -> List[DocChunk]: + chunks = [] + + tokens = self.encoding.encode(section_content) + total_tokens = len(tokens) + + if total_tokens <= self.config.chunk_size: + chunk = self._create_chunk( + section_content, document, section_title, start_counter + ) + return [chunk] + + sentences = self._split_into_sentences(section_content) + + current_chunk = "" + current_tokens = 0 + chunk_idx = 0 + + for sentence in sentences: + sentence_tokens = len(self.encoding.encode(sentence)) + + if current_tokens + sentence_tokens > self.config.chunk_size: + if current_chunk.strip(): + chunk = self._create_chunk( + current_chunk.strip(), + document, + section_title, + start_counter + chunk_idx, + ) + chunks.append(chunk) + chunk_idx += 1 + + overlap_text = self._get_overlap_text( + current_chunk, self.config.chunk_overlap + ) + current_chunk = overlap_text + sentence + current_tokens = len(self.encoding.encode(current_chunk)) + else: + current_chunk += sentence + current_tokens += sentence_tokens + + if current_chunk.strip(): + chunk = self._create_chunk( + current_chunk.strip(), + document, + section_title, + start_counter + chunk_idx, + ) + chunks.append(chunk) + + return chunks + + def _split_into_sentences(self, text: str) -> List[str]: + sentence_endings = r"[.!?]+\s+" + sentences = re.split(sentence_endings, text) + + result = [] + for i, sentence in enumerate(sentences): + if sentence.strip(): + if i < len(sentences) - 1: + sentence += ". " + result.append(sentence) + + return result + + def _get_overlap_text(self, text: str, overlap_tokens: int) -> str: + tokens = self.encoding.encode(text) + if len(tokens) <= overlap_tokens: + return text + + overlap_tokens_slice = tokens[-overlap_tokens:] + overlap_text = self.encoding.decode(overlap_tokens_slice) + + sentences = overlap_text.split(".") + if len(sentences) > 1: + return ". ".join(sentences[1:]) + ". 
" + + return overlap_text + " " + + def _create_chunk( + self, content: str, document: Document, section_title: str, chunk_index: int + ) -> DocChunk: + chunk_id = f"{document.doc_id}#c{chunk_index}" + + token_count = len(self.encoding.encode(content)) + + metadata = document.metadata.copy() + metadata.update( + { + "section_title": section_title, + "parent_doc_title": document.metadata.get("title", "Unknown"), + "chunk_index": chunk_index, + "doc_source": document.source, + "semantic_type": self._get_semantic_type(section_title), + } + ) + + return DocChunk( + chunk_id=chunk_id, + parent_doc_id=document.doc_id, + content_text=content, + token_count=token_count, + metadata=metadata, + ) + + def _get_semantic_type(self, section_title: str) -> str: + title_lower = section_title.lower() + + if any(word in title_lower for word in ["проблема", "вызов", "трудност"]): + return "problem" + elif any(word in title_lower for word in ["решение", "подход", "внедрение"]): + return "solution" + elif any(word in title_lower for word in ["результат", "итог", "эффект"]): + return "result" + elif any(word in title_lower for word in ["клиент", "компани"]): + return "client_info" + else: + return "general" diff --git a/src/ingest/ingest_cli.py b/src/ingest/ingest_cli.py new file mode 100644 index 0000000..f43c4eb --- /dev/null +++ b/src/ingest/ingest_cli.py @@ -0,0 +1,241 @@ +import os +import sys +import argparse +import hashlib +from typing import List, Dict, Any + +sys.path.append( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) + +from src.ingest.loader import MarkdownLoader, create_platform_overview +from src.ingest.chunker import DocumentChunker, ChunkConfig +from src.services.chroma_store import ChromaStore +from src.services.embeddings import EmbeddingService +from src.app.config import settings + + +def get_file_hash(file_path: str) -> str: + with open(file_path, "rb") as f: + return hashlib.md5(f.read()).hexdigest() + + +def run_ingest( + data_dir: str, + platform_md: str = None, + persist_dir: str = None, + recreate: bool = False, + incremental: bool = True, + chunk_size: int = 500, + chunk_overlap: int = 100, + embedding_model: str = None, +) -> Dict[str, Any]: + + print(f"Starting ingestion process...") + print(f"Data directory: {data_dir}") + print(f"Persist directory: {persist_dir or settings.chroma_persist_dir}") + print(f"Recreate: {recreate}") + print(f"Incremental: {incremental}") + + if not os.path.exists(data_dir): + raise FileNotFoundError(f"Data directory not found: {data_dir}") + + chroma_store = ChromaStore(persist_dir) + + if recreate: + print("Recreating collection...") + chroma_store.recreate_collection() + incremental = False + + loader = MarkdownLoader() + + documents_to_process = [] + existing_doc_ids = set() + + if incremental and not recreate: + existing_doc_ids = chroma_store.get_existing_doc_ids() + print(f"Found {len(existing_doc_ids)} existing documents in collection") + + print("Loading markdown files...") + all_documents = loader.load_directory(data_dir) + + new_docs = 0 + updated_docs = 0 + skipped_docs = 0 + + for doc in all_documents: + doc_id = doc.doc_id + + if incremental and doc_id in existing_doc_ids: + if chroma_store.document_exists(doc_id): + print(f"Skipping existing document: {doc_id}") + skipped_docs += 1 + continue + + if incremental and doc_id in existing_doc_ids: + print(f"Updating document: {doc_id}") + chroma_store.delete_document_chunks(doc_id) + updated_docs += 1 + else: + new_docs += 1 + + 
documents_to_process.append(doc) + + if platform_md and os.path.exists(platform_md): + platform_doc = loader.load_file(platform_md) + if not incremental or "platform_overview" not in existing_doc_ids: + documents_to_process.append(platform_doc) + print("Added platform overview from file") + else: + platform_doc = create_platform_overview() + if not incremental or "platform_overview" not in existing_doc_ids: + documents_to_process.append(platform_doc) + print("Added generated platform overview") + + print( + f"Processing {len(documents_to_process)} documents ({new_docs} new, {updated_docs} updated, {skipped_docs} skipped)" + ) + + if not documents_to_process: + print("No documents to process. Collection is up to date.") + return { + "documents_loaded": 0, + "chunks_created": 0, + "embeddings_generated": 0, + "total_in_collection": chroma_store.get_count(), + "new_documents": new_docs, + "updated_documents": updated_docs, + "skipped_documents": skipped_docs, + "embedding_model": embedding_model or settings.embedding_model, + "chunk_config": {"size": chunk_size, "overlap": chunk_overlap}, + } + + chunk_config = ChunkConfig(chunk_size=chunk_size, chunk_overlap=chunk_overlap) + chunker = DocumentChunker(chunk_config) + + print("Chunking documents...") + chunks = chunker.chunk_documents(documents_to_process) + print(f"Created {len(chunks)} chunks") + + if chunks: + embedding_service = EmbeddingService() + + print("Generating embeddings...") + texts = [chunk.content_text for chunk in chunks] + embeddings = embedding_service.embed_batch(texts) + print(f"Generated embeddings for {len(embeddings)} chunks") + + print("Storing in ChromaDB...") + chroma_store.add_chunks(chunks, embeddings) + + final_count = chroma_store.get_count() + print(f"Ingestion complete. Total documents in collection: {final_count}") + + return { + "documents_loaded": len(documents_to_process), + "chunks_created": len(chunks), + "embeddings_generated": len(embeddings) if chunks else 0, + "total_in_collection": final_count, + "new_documents": new_docs, + "updated_documents": updated_docs, + "skipped_documents": skipped_docs, + "embedding_model": embedding_model or settings.embedding_model, + "chunk_config": {"size": chunk_size, "overlap": chunk_overlap}, + } + + +def create_data_directory(): + os.makedirs("data/articles_konsol_pro", exist_ok=True) + + sample_content = """# Пример кейса + +Компания FIVE (маркетинговое агентство) автоматизировала работу с самозанятыми через Консоль.Про. + +## Результаты +- Сокращение времени онбординга с 2 дней до 15 минут +- Автоматизация выплат - теперь занимает минуты вместо часов +- Снижение ошибок на 95% +- Управление 200+ исполнителями одним сотрудником + +## Внедрение +Технический директор Никита Помящий отмечает, что внедрение заняло всего 1 день. 
+""" + + with open("data/articles_konsol_pro/sample_case.md", "w", encoding="utf-8") as f: + f.write(sample_content) + + print("Created sample data directory and file") + + +def main(): + parser = argparse.ArgumentParser( + description="Ingest knowledge base data into ChromaDB" + ) + + parser.add_argument( + "--data-dir", + default="data/articles_konsol_pro", + help="Directory containing markdown files", + ) + parser.add_argument( + "--platform-md", + default="data/platform_overview.md", + help="Platform overview markdown file", + ) + parser.add_argument( + "--persist-dir", default=None, help="ChromaDB persist directory" + ) + parser.add_argument( + "--recreate", + action="store_true", + help="Recreate the collection (delete existing)", + ) + parser.add_argument( + "--no-incremental", + action="store_true", + help="Disable incremental updates (process all files)", + ) + parser.add_argument( + "--chunk-size", type=int, default=500, help="Chunk size in tokens" + ) + parser.add_argument( + "--chunk-overlap", type=int, default=100, help="Chunk overlap in tokens" + ) + parser.add_argument( + "--embedding-model", default=None, help="Embedding model to use" + ) + parser.add_argument( + "--create-sample", + action="store_true", + help="Create sample data directory and files", + ) + + args = parser.parse_args() + + if args.create_sample: + create_data_directory() + return + + try: + result = run_ingest( + data_dir=args.data_dir, + platform_md=args.platform_md, + persist_dir=args.persist_dir, + recreate=args.recreate, + incremental=not args.no_incremental, + chunk_size=args.chunk_size, + chunk_overlap=args.chunk_overlap, + embedding_model=args.embedding_model, + ) + + print("\nIngestion Summary:") + for key, value in result.items(): + print(f" {key}: {value}") + + except Exception as e: + print(f"Error during ingestion: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/src/ingest/loader.py b/src/ingest/loader.py new file mode 100644 index 0000000..1cec84a --- /dev/null +++ b/src/ingest/loader.py @@ -0,0 +1,337 @@ +import os +import re +from typing import List, Dict, Any +from dataclasses import dataclass +import markdown + + +@dataclass +class Document: + content: str + metadata: Dict[str, Any] + doc_id: str + source: str + + +class MarkdownLoader: + def __init__(self): + self.md = markdown.Markdown(extensions=["meta", "toc"]) + + def load_file(self, file_path: str) -> Document: + with open(file_path, "r", encoding="utf-8") as f: + content = f.read() + + html_content = self.md.convert(content) + + title = self._extract_title(content) + doc_id = self._generate_doc_id(file_path) + + metadata = self._extract_metadata(content, file_path) + metadata.update( + { + "title": title, + "doc_type": "case" if "case" in file_path.lower() else "info", + "source": file_path, + "file_size": len(content), + } + ) + + clean_content = self._clean_content(content) + + return Document( + content=clean_content, metadata=metadata, doc_id=doc_id, source=file_path + ) + + def load_directory(self, dir_path: str) -> List[Document]: + documents = [] + + for root, dirs, files in os.walk(dir_path): + for file in files: + if file.endswith(".md"): + file_path = os.path.join(root, file) + try: + doc = self.load_file(file_path) + documents.append(doc) + except Exception as e: + print(f"Error loading {file_path}: {e}") + + return documents + + def _extract_title(self, content: str) -> str: + lines = content.strip().split("\n") + for line in lines: + line = line.strip() + if line.startswith("# "): + title = 
line[2:].strip() + if title and not title.startswith("["): + return title + + return "Untitled" + + def _generate_doc_id(self, file_path: str) -> str: + filename = os.path.basename(file_path) + name_without_ext = os.path.splitext(filename)[0] + return name_without_ext.replace(" ", "_").replace("-", "_").lower() + + def _extract_metadata(self, content: str, file_path: str) -> Dict[str, Any]: + metadata = {} + + filename = os.path.basename(file_path).lower() + content_lower = content.lower() + + industry_mapping = { + "маркетинг": "marketing_agency", + "агентство": "marketing_agency", + "реклам": "marketing_agency", + "блогер": "marketing_agency", + "mediar": "marketing_agency", + "büro": "marketing_agency", + "логист": "logistics", + "достав": "logistics", + "склад": "logistics", + "грузчик": "logistics", + "разраб": "software", + "програм": "software", + "progkids": "software", + "it": "software", + "диджитал": "software", + "строит": "construction", + "недвиж": "construction", + "этажи": "construction", + "рознич": "retail", + "торгов": "retail", + "консалт": "consulting", + "экобренд": "manufacturing", + "wonder": "manufacturing", + "производ": "manufacturing", + "колл-центр": "call_center", + "звонки": "call_center", + } + + industries = [] + for keyword, industry in industry_mapping.items(): + if keyword in content_lower or keyword in filename: + if industry not in industries: + industries.append(industry) + + if not industries: + industries = ["other"] + + metadata["industry"] = industries + + roles_mapping = { + "технический директор": "tech", + "техн": "tech", + "cto": "tech", + "операционный директор": "ops", + "директор": "ceo", + "руководи": "ceo", + "основатель": "ceo", + "фин": "finance", + "бухгалт": "finance", + "cfo": "finance", + "операц": "ops", + "coo": "ops", + "hr": "hr", + "кадр": "hr", + "маркет": "marketing", + "продаж": "sales", + "менеджер": "other", + } + + roles = [] + for keyword, role in roles_mapping.items(): + if keyword in content_lower: + if role not in roles: + roles.append(role) + + if not roles: + roles = ["other"] + + metadata["roles_relevant"] = roles + + metrics = self._extract_metrics(content) + if metrics: + metadata["metrics"] = metrics + + metadata["language"] = "ru" + + import datetime + + metadata["created_at"] = datetime.datetime.now().isoformat() + metadata["updated_at"] = datetime.datetime.now().isoformat() + + return metadata + + def _extract_metrics(self, content: str) -> Dict[str, Any]: + metrics = {} + + time_patterns = [ + (r"(\d+)\s*минут[ауы]?", "processing_minutes"), + (r"(\d+)\s*час[ауов]?", "processing_hours"), + (r"(\d+)\s*дн[ейяах]", "processing_days"), + ( + r"с\s+(\d+)\s*дн[ейя]\s+до\s+(\d+)\s*минут", + "improvement_days_to_minutes", + ), + ( + r"с\s+(\d+)\s*час[ауов]?\s+до\s+(\d+)\s*минут", + "improvement_hours_to_minutes", + ), + (r"(\d+)\s*секунд", "processing_seconds"), + ] + + for pattern, key in time_patterns: + matches = re.findall(pattern, content, re.IGNORECASE) + if matches: + try: + if key.startswith("improvement_"): + if len(matches[0]) == 2: + metrics[f"{key}_before"] = int(matches[0][0]) + metrics[f"{key}_after"] = int(matches[0][1]) + else: + metrics[key] = int(matches[0]) + else: + metrics[key] = int(matches[0]) + except (ValueError, IndexError): + pass + + percentage_patterns = [ + (r"(\d+)%\s*снижени", "error_reduction_pct"), + (r"снижение[^0-9]*(\d+)%", "error_reduction_pct"), + (r"(\d+)%\s*документ", "document_collection_pct"), + (r"(\d+)%\s*точност", "accuracy_pct"), + (r"увеличи[лв]\w*\s+в\s+(\d+)\s*раз", 
"growth_multiplier"), + ] + + for pattern, key in percentage_patterns: + matches = re.findall(pattern, content, re.IGNORECASE) + if matches: + try: + metrics[key] = int(matches[0]) + except ValueError: + pass + + volume_patterns = [ + (r"(\d+)\s*блогер", "bloggers_count"), + (r"(\d+)\s*исполнител", "contractors_count"), + (r"(\d+)\s*сотрудник", "employees_count"), + (r"бол[ьеее]+\s+(\d+)", "more_than_count"), + (r"свыше\s+(\d+)", "over_count"), + ] + + for pattern, key in volume_patterns: + matches = re.findall(pattern, content, re.IGNORECASE) + if matches: + try: + metrics[key] = int(matches[0]) + except ValueError: + pass + + return metrics + + def _clean_content(self, content: str) -> str: + content = re.sub(r"^\s*#+\s*", "", content, flags=re.MULTILINE) + + content = re.sub(r"\*\*(.*?)\*\*", r"\1", content) + content = re.sub(r"\*(.*?)\*", r"\1", content) + + content = re.sub(r"\[([^\]]+)\]\([^\)]*\)", r"\1", content) + + content = re.sub(r"^\s*[-*+]\s+", "• ", content, flags=re.MULTILINE) + + content = re.sub( + r"\d+\s+\d+\s+\[Комментировать\]\(\)\s+\d{2}\.\d{2}\.\d{2}", "", content + ) + + noise_patterns = [ + r"Автор и редактор журнала Консоль", + r"Автор\s+\[.*?\]\(\)", + r"Поделиться", + r"Ваше мнение\?", + r"Отлично\s+Хорошо\s+Нормально\s+Плохо\s+Ужасно", + r"Сайт использует файлы cookie.*?Принять", + r"\[Политика конфиденциальности\]\(\)", + r"\[Пользовательское соглашение\]\(\)", + r"hello@konsol\.pro", + r"\+7 \(\d{3}\) \d{3}-\d{2}-\d{2}", + r"125047.*?дом \d+", + r"\[Разработка — SKDO\]\(\)", + r"\[Подключиться к Консоли\]\(\)", + r"\[Кейсы наших клиентов\]\(\)", + r"\[Делимся экспертизой\]\(\)", + r"^\s*\d+\s*$", + r"^\s*\[\d+\]\(\)\s*\d{2}\.\d{2}\.\d{2}\s*$", + ] + + for pattern in noise_patterns: + content = re.sub(pattern, "", content, flags=re.MULTILINE | re.IGNORECASE) + + related_articles_pattern = r"###\s+\[.*?\]\(\).*?(?=###|\Z)" + content = re.sub(related_articles_pattern, "", content, flags=re.DOTALL) + + content = re.sub(r"\n{3,}", "\n\n", content) + + lines = content.split("\n") + filtered_lines = [] + for line in lines: + line = line.strip() + if line and not (line.startswith("[") and line.endswith("]()")): + if not re.match(r"^\d+\s*$", line): + if len(line) > 10 or line.startswith("•"): + filtered_lines.append(line) + + content = "\n".join(filtered_lines) + + return content.strip() + + +def create_platform_overview() -> Document: + content = """ +Консоль.Про – платформа автоматизации работы с самозанятыми, ИП и физлицами. + +Основные возможности: +• Подключение нового исполнителя за ~15 минут +• Выплаты в течение минут вместо часов +• Сбор 100% закрывающих документов +• Снижение ошибок до 95% +• Управление сотнями исполнителей одним сотрудником +• API интеграции для автоматизации процессов +• Автоматический сбор чеков и документов +• Снижение времени онбординга с 2 дней до ~20 минут + +Платформа решает ключевые задачи бизнеса: +• Быстрое масштабирование команды исполнителей +• Автоматизация документооборота и выплат +• Снижение операционных затрат +• Обеспечение налогового соответствия +• Упрощение работы с подрядчиками + +Внедрение платформы занимает около 1 дня. 
+ """ + + metadata = { + "title": "Платформа Консоль.Про - Обзор", + "doc_type": "platform_overview", + "source": "internal", + "industry": ["generic"], + "roles_relevant": ["tech", "finance", "ops", "ceo"], + "metrics": { + "onboarding_minutes": 15, + "onboarding_days_before": 2, + "onboarding_minutes_after": 20, + "error_reduction_pct": 95, + "document_collection_pct": 100, + "implementation_days": 1, + }, + "language": "ru", + "created_at": "2024-01-01T00:00:00", + "updated_at": "2024-01-01T00:00:00", + } + + return Document( + content=content.strip(), + metadata=metadata, + doc_id="platform_overview", + source="platform_overview.md", + ) diff --git a/src/llm_clients/__init__.py b/src/llm_clients/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/llm_clients/base.py b/src/llm_clients/base.py new file mode 100644 index 0000000..5d05ae9 --- /dev/null +++ b/src/llm_clients/base.py @@ -0,0 +1,20 @@ +from abc import ABC, abstractmethod +from typing import List, Dict, Any, Optional, Tuple +from src.models.email import LLMRawOutput + + +class LLMClient(ABC): + @abstractmethod + def generate_completion( + self, + system_prompt: str, + user_prompt: str, + max_tokens: int = 1024, + temperature: float = 0.7, + **kwargs + ) -> LLMRawOutput: + pass + + @abstractmethod + def count_tokens(self, text: str) -> int: + pass diff --git a/src/llm_clients/gemini_client.py b/src/llm_clients/gemini_client.py new file mode 100644 index 0000000..a39a6b9 --- /dev/null +++ b/src/llm_clients/gemini_client.py @@ -0,0 +1,56 @@ +import google.generativeai as genai +import tiktoken +from typing import Dict, Any + +from src.llm_clients.base import LLMClient +from src.models.email import LLMRawOutput +from src.models.errors import LLMError +from src.app.config import settings + + +class GeminiClient(LLMClient): + def __init__(self): + if not settings.gemini_api_key: + raise ValueError("Gemini API key is required") + + genai.configure(api_key=settings.gemini_api_key) + self.model = genai.GenerativeModel("gemini-pro") + self.encoding = tiktoken.get_encoding("cl100k_base") + + def generate_completion( + self, + system_prompt: str, + user_prompt: str, + max_tokens: int = 1024, + temperature: float = 0.7, + **kwargs, + ) -> LLMRawOutput: + try: + prompt = f"Системная инструкция: {system_prompt}\n\nЗапрос пользователя: {user_prompt}" + + generation_config = genai.types.GenerationConfig( + max_output_tokens=max_tokens, + temperature=temperature, + ) + + response = self.model.generate_content( + prompt, generation_config=generation_config + ) + + content = response.text + + prompt_tokens = self.count_tokens(prompt) + completion_tokens = self.count_tokens(content) + + return LLMRawOutput( + content=content, + tokens_prompt=prompt_tokens, + tokens_completion=completion_tokens, + model="gemini-pro", + ) + + except Exception as e: + raise LLMError(f"Gemini generation error: {str(e)}", "gemini", "gemini-pro") + + def count_tokens(self, text: str) -> int: + return len(self.encoding.encode(text)) diff --git a/src/llm_clients/openai_client.py b/src/llm_clients/openai_client.py new file mode 100644 index 0000000..949d60a --- /dev/null +++ b/src/llm_clients/openai_client.py @@ -0,0 +1,55 @@ +import openai +import tiktoken +from typing import Dict, Any + +from src.llm_clients.base import LLMClient +from src.models.email import LLMRawOutput +from src.models.errors import LLMError +from src.app.config import settings + + +class OpenAIClient(LLMClient): + def __init__(self): + if not settings.openai_api_key: + raise 
ValueError("OpenAI API key is required") + + self.client = openai.OpenAI(api_key=settings.openai_api_key) + self.model = settings.llm_model + self.encoding = tiktoken.encoding_for_model(self.model) + + def generate_completion( + self, + system_prompt: str, + user_prompt: str, + max_tokens: int = 1024, + temperature: float = 0.7, + **kwargs, + ) -> LLMRawOutput: + try: + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + + response = self.client.chat.completions.create( + model=self.model, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + response_format={"type": "json_object"}, + ) + + content = response.choices[0].message.content + + return LLMRawOutput( + content=content, + tokens_prompt=response.usage.prompt_tokens, + tokens_completion=response.usage.completion_tokens, + model=self.model, + ) + + except Exception as e: + raise LLMError(f"OpenAI generation error: {str(e)}", "openai", self.model) + + def count_tokens(self, text: str) -> int: + return len(self.encoding.encode(text)) diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/models/email.py b/src/models/email.py new file mode 100644 index 0000000..7160106 --- /dev/null +++ b/src/models/email.py @@ -0,0 +1,63 @@ +from typing import Optional, List, Dict, Any +from pydantic import BaseModel, Field + + +class DocChunk(BaseModel): + chunk_id: str + parent_doc_id: str + content_text: str + token_count: int + metadata: Dict[str, Any] + similarity_score: Optional[float] = None + + +class RetrievalQuery(BaseModel): + text_query: str + vector_query: Optional[List[float]] = None + metadata_filters: Optional[Dict[str, Any]] = None + k: int = 30 + + +class RankedContext(BaseModel): + chunks: List[DocChunk] + total_tokens: int + summary_bullets: List[str] + + +class PromptPayload(BaseModel): + system_prompt: str + user_prompt: str + max_tokens: int + temperature: float = 0.7 + + +class LLMRawOutput(BaseModel): + content: str + tokens_prompt: int + tokens_completion: int + model: str + + +class EmailDraft(BaseModel): + subject: str = Field(max_length=80) + body: str = Field(max_length=2000) + short_reasoning: Optional[str] = None + used_chunks: List[str] = [] + + +class EmailDraftClean(BaseModel): + subject: str + body: str + meta: Dict[str, Any] + + +class EmailResponse(BaseModel): + subject: str + body: str + meta: Dict[str, Any] = Field(default_factory=dict) + + +class ErrorResponse(BaseModel): + error: str + code: str + details: Optional[Dict[str, Any]] = None diff --git a/src/models/errors.py b/src/models/errors.py new file mode 100644 index 0000000..4df55f6 --- /dev/null +++ b/src/models/errors.py @@ -0,0 +1,53 @@ +from typing import Optional, Dict, Any + + +class ValidationError(Exception): + def __init__(self, message: str, details: Optional[Dict[str, Any]] = None): + self.message = message + self.details = details or {} + super().__init__(self.message) + + +class RetrievalError(Exception): + def __init__(self, message: str, details: Optional[Dict[str, Any]] = None): + self.message = message + self.details = details or {} + super().__init__(self.message) + + +class LLMError(Exception): + def __init__( + self, + message: str, + provider: str, + model: str, + details: Optional[Dict[str, Any]] = None, + ): + self.message = message + self.provider = provider + self.model = model + self.details = details or {} + super().__init__(self.message) + + +class ParseError(Exception): + def 
__init__( + self, message: str, raw_output: str, details: Optional[Dict[str, Any]] = None + ): + self.message = message + self.raw_output = raw_output + self.details = details or {} + super().__init__(self.message) + + +class GuardrailsError(Exception): + def __init__( + self, + message: str, + violation_type: str, + details: Optional[Dict[str, Any]] = None, + ): + self.message = message + self.violation_type = violation_type + self.details = details or {} + super().__init__(self.message) diff --git a/src/models/lead.py b/src/models/lead.py new file mode 100644 index 0000000..883a2c0 --- /dev/null +++ b/src/models/lead.py @@ -0,0 +1,105 @@ +from typing import Literal, Optional +from pydantic import BaseModel, ConfigDict, EmailStr, Field + + +class LeadInput(BaseModel): + contact: str = Field(min_length=1, max_length=100) + position: str = Field(min_length=1, max_length=100) + company_name: str = Field(min_length=1, max_length=100) + segment: str = Field(min_length=1, max_length=100) + email: Optional[EmailStr] = None + locale: Literal["ru", "en"] = "ru" + notes: Optional[str] = Field(default=None, max_length=500) + + model_config = ConfigDict(populate_by_name=True) + + +class LeadModel(BaseModel): + contact_name: str + contact_first_name: Optional[str] = None + contact_last_name: Optional[str] = None + role_title: str + role_category: Optional[ + Literal["tech", "finance", "ops", "hr", "ceo", "sales", "marketing", "other"] + ] = None + company_name: str + industry_segment: Optional[str] = None + industry_tag: Optional[str] = None + email: Optional[EmailStr] = None + locale: Literal["ru", "en"] = "ru" + notes: Optional[str] = None + + @classmethod + def from_input(cls, lead_input: LeadInput) -> "LeadModel": + contact_parts = lead_input.contact.strip().split() + first_name = contact_parts[0] if contact_parts else "" + last_name = " ".join(contact_parts[1:]) if len(contact_parts) > 1 else None + + role_category = cls._extract_role_category(lead_input.position) + industry_tag = cls._extract_industry_tag(lead_input.segment) + + return cls( + contact_name=lead_input.contact, + contact_first_name=first_name, + contact_last_name=last_name, + role_title=lead_input.position, + role_category=role_category, + company_name=lead_input.company_name, + industry_segment=lead_input.segment, + industry_tag=industry_tag, + email=lead_input.email, + locale=lead_input.locale, + notes=lead_input.notes, + ) + + @staticmethod + def _extract_role_category(role_title: str) -> Optional[str]: + role_lower = role_title.lower() + + if any( + word in role_lower for word in ["техн", "cto", "it", "разраб", "программ"] + ): + return "tech" + elif any(word in role_lower for word in ["фин", "бухг", "казнач", "cfo"]): + return "finance" + elif any(word in role_lower for word in ["операц", "coo", "управл"]): + return "ops" + elif any(word in role_lower for word in ["hr", "кадр", "персонал"]): + return "hr" + elif any( + word in role_lower for word in ["гендир", "ceo", "директор", "руковод"] + ): + return "ceo" + elif any(word in role_lower for word in ["продаж", "коммерч", "sales"]): + return "sales" + elif any(word in role_lower for word in ["маркет", "реклам", "pr"]): + return "marketing" + else: + return "other" + + @staticmethod + def _extract_industry_tag(segment: str) -> Optional[str]: + segment_lower = segment.lower() + + if any(word in segment_lower for word in ["маркет", "реклам", "агентс"]): + return "marketing_agency" + elif any(word in segment_lower for word in ["логист", "достав", "склад"]): + return "logistics" + 
elif any(word in segment_lower for word in ["разраб", "софт", "it", "програм"]): + return "software" + elif any(word in segment_lower for word in ["рознич", "торгов", "магазин"]): + return "retail" + elif any(word in segment_lower for word in ["консалт", "консульт"]): + return "consulting" + elif any(word in segment_lower for word in ["строит", "недвиж"]): + return "construction" + else: + return "other" + + +class LeadFeatures(BaseModel): + role_category: str + industry_tag: str + pain_points: list[str] + key_benefits: list[str] + search_keywords: list[str] diff --git a/src/services/__init__.py b/src/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/services/chroma_store.py b/src/services/chroma_store.py new file mode 100644 index 0000000..92abe9b --- /dev/null +++ b/src/services/chroma_store.py @@ -0,0 +1,169 @@ +import os +from typing import List, Dict, Any, Optional, Set +import chromadb +from chromadb.config import Settings as ChromaSettings +from chromadb.api.models.Collection import Collection + +from src.models.email import DocChunk +from src.app.config import settings + + +class ChromaStore: + def __init__(self, persist_directory: str = None): + self.persist_directory = persist_directory or settings.chroma_persist_dir + os.makedirs(self.persist_directory, exist_ok=True) + + self.client = chromadb.PersistentClient( + path=self.persist_directory, + settings=ChromaSettings(anonymized_telemetry=False), + ) + self.collection_name = "knowledge_base" + self.collection: Optional[Collection] = None + + def get_or_create_collection(self) -> Collection: + if self.collection is None: + self.collection = self.client.get_or_create_collection( + name=self.collection_name, metadata={"hnsw:space": "cosine"} + ) + return self.collection + + def _sanitize_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]: + sanitized = {} + + for key, value in metadata.items(): + if isinstance(value, list): + sanitized[key] = ", ".join(str(item) for item in value) + elif isinstance(value, dict): + import json + + sanitized[key] = json.dumps(value, ensure_ascii=False) + elif isinstance(value, (str, int, float, bool)) or value is None: + sanitized[key] = value + else: + sanitized[key] = str(value) + + return sanitized + + def get_existing_doc_ids(self) -> Set[str]: + collection = self.get_or_create_collection() + + try: + result = collection.get(include=["metadatas"]) + if result and result.get("metadatas"): + doc_ids = set() + for metadata in result["metadatas"]: + if metadata and "source" in metadata: + source = metadata["source"] + filename = os.path.basename(source) + doc_id = ( + os.path.splitext(filename)[0] + .replace(" ", "_") + .replace("-", "_") + .lower() + ) + doc_ids.add(doc_id) + return doc_ids + except Exception: + pass + + return set() + + def document_exists(self, doc_id: str) -> bool: + collection = self.get_or_create_collection() + + try: + result = collection.get(where={"parent_doc_id": doc_id}, limit=1) + return bool(result and result.get("ids")) + except Exception: + return False + + def delete_document_chunks(self, doc_id: str): + collection = self.get_or_create_collection() + + try: + result = collection.get( + where={"parent_doc_id": doc_id}, include=["metadatas"] + ) + + if result and result.get("ids"): + chunk_ids = result["ids"] + collection.delete(ids=chunk_ids) + print(f"Deleted {len(chunk_ids)} chunks for document {doc_id}") + except Exception as e: + print(f"Error deleting chunks for {doc_id}: {e}") + + def add_chunks(self, chunks: 
List[DocChunk], embeddings: List[List[float]]): + collection = self.get_or_create_collection() + + ids = [chunk.chunk_id for chunk in chunks] + documents = [chunk.content_text for chunk in chunks] + metadatas = [] + + for chunk in chunks: + metadata = chunk.metadata.copy() + metadata.update( + {"parent_doc_id": chunk.parent_doc_id, "token_count": chunk.token_count} + ) + sanitized_metadata = self._sanitize_metadata(metadata) + metadatas.append(sanitized_metadata) + + collection.add( + ids=ids, embeddings=embeddings, documents=documents, metadatas=metadatas + ) + + def search_similar( + self, + query_embedding: List[float], + k: int = 30, + where: Optional[Dict[str, Any]] = None, + ) -> List[DocChunk]: + collection = self.get_or_create_collection() + + results = collection.query( + query_embeddings=[query_embedding], + n_results=k, + where=where, + include=["documents", "metadatas", "distances"], + ) + + chunks = [] + if results["ids"] and results["ids"][0]: + for i, chunk_id in enumerate(results["ids"][0]): + similarity_score = ( + 1.0 - results["distances"][0][i] if results["distances"] else None + ) + + metadata = results["metadatas"][0][i] if results["metadatas"] else {} + parent_doc_id = metadata.pop("parent_doc_id", "") + token_count = metadata.pop("token_count", 0) + + chunk = DocChunk( + chunk_id=chunk_id, + parent_doc_id=parent_doc_id, + content_text=( + results["documents"][0][i] if results["documents"] else "" + ), + token_count=token_count, + metadata=metadata, + similarity_score=similarity_score, + ) + chunks.append(chunk) + + return chunks + + def get_count(self) -> int: + collection = self.get_or_create_collection() + return collection.count() + + def delete_collection(self): + try: + self.client.delete_collection(self.collection_name) + self.collection = None + except Exception: + pass + + def recreate_collection(self): + self.delete_collection() + self.collection = self.client.create_collection( + name=self.collection_name, metadata={"hnsw:space": "cosine"} + ) diff --git a/src/services/embeddings.py b/src/services/embeddings.py new file mode 100644 index 0000000..0d5d28f --- /dev/null +++ b/src/services/embeddings.py @@ -0,0 +1,88 @@ +from typing import List, Union +from abc import ABC, abstractmethod +import openai +import google.generativeai as genai + +from src.app.config import settings +from src.models.errors import LLMError + + +class EmbeddingProvider(ABC): + @abstractmethod + def embed_text(self, text: str) -> List[float]: + pass + + @abstractmethod + def embed_batch(self, texts: List[str]) -> List[List[float]]: + pass + + +class OpenAIEmbeddingProvider(EmbeddingProvider): + def __init__(self): + if not settings.openai_api_key: + raise ValueError("OpenAI API key is required") + openai.api_key = settings.openai_api_key + self.client = openai.OpenAI(api_key=settings.openai_api_key) + self.model = settings.embedding_model + + def embed_text(self, text: str) -> List[float]: + try: + response = self.client.embeddings.create(model=self.model, input=text) + return response.data[0].embedding + except Exception as e: + raise LLMError(f"OpenAI embedding error: {str(e)}", "openai", self.model) + + def embed_batch(self, texts: List[str], batch_size: int = 64) -> List[List[float]]: + embeddings = [] + + for i in range(0, len(texts), batch_size): + batch = texts[i : i + batch_size] + try: + response = self.client.embeddings.create(model=self.model, input=batch) + batch_embeddings = [data.embedding for data in response.data] + embeddings.extend(batch_embeddings) + except 
Exception as e: + raise LLMError( + f"OpenAI batch embedding error: {str(e)}", "openai", self.model + ) + + return embeddings + + +class GeminiEmbeddingProvider(EmbeddingProvider): + def __init__(self): + if not settings.gemini_api_key: + raise ValueError("Gemini API key is required") + genai.configure(api_key=settings.gemini_api_key) + self.model = "models/embedding-001" + + def embed_text(self, text: str) -> List[float]: + try: + result = genai.embed_content( + model=self.model, content=text, task_type="retrieval_document" + ) + return result["embedding"] + except Exception as e: + raise LLMError(f"Gemini embedding error: {str(e)}", "gemini", self.model) + + def embed_batch(self, texts: List[str]) -> List[List[float]]: + embeddings = [] + for text in texts: + embeddings.append(self.embed_text(text)) + return embeddings + + +class EmbeddingService: + def __init__(self): + if settings.llm_provider == "openai": + self.provider = OpenAIEmbeddingProvider() + elif settings.llm_provider == "gemini": + self.provider = GeminiEmbeddingProvider() + else: + raise ValueError(f"Unsupported embedding provider: {settings.llm_provider}") + + def embed_text(self, text: str) -> List[float]: + return self.provider.embed_text(text) + + def embed_batch(self, texts: List[str]) -> List[List[float]]: + return self.provider.embed_batch(texts) diff --git a/src/services/prompt_templates.py b/src/services/prompt_templates.py new file mode 100644 index 0000000..fa265eb --- /dev/null +++ b/src/services/prompt_templates.py @@ -0,0 +1,113 @@ +from typing import Dict, Any +from src.models.lead import LeadModel +from src.models.email import RankedContext +from src.app.config import settings + + +SYSTEM_PROMPT = """Вы — AI-ассистент отдела продаж платформы «Консоль.Про». Ваша задача: на основе профиля лида и предоставленного контекста сформировать одно персонализированное холодное письмо (первое касание) на языке лида. Вы пишете кратко, уважительно, без давления. Вы показываете ценность через конкретные результаты клиентов и цифры. Вы не придумываете факты; используете только контекст. Если нет данных — говорите общими преимуществами платформы. +Формат ответа строго JSON: {"subject": str, "body": str, "short_reasoning": str, "used_chunks": [ids...]}. Без тройных кавычек, без Markdown.""" + + +USER_PROMPT_TEMPLATE = """[ПРОФИЛЬ ЛИДА] +Имя: {contact_first_name} +Должность: {role_title} (категория: {role_category}) +Компания: {company_name} +Индустрия: {industry_tag} + +[КОНТЕКСТ ИЗ БАЗЫ ЗНАНИЙ] +{context_bullets} + +[СТИЛЬ ПИСЬМА] +- Язык: {locale} +- Тон: деловой, лаконичный, дружелюбный. +- Письмо до ~1600 символов. +- 1 CTA: предложить 15-мин звонок/демо. Не навязчиво. +- Можно упомянуть, что внедрение платформы занимает ~1 день (если релевантно). + +[ШАБЛОН СТРУКТУРЫ ТЕКСТА] +1. Приветствие по имени. +2. Короткий хук, отсылающий к индустрии/процессам с исполнителями. +3. Ценность Консоль.Про + 1–3 релевантные выгоды (цифры, если в контексте). +4. Мини-соцдоказательство (кейс, факт). +5. CTA. +6. Подпись: Имя менеджера ({sales_rep_name}) | Консоль.Про. 
+ +Напомню: ответ только JSON.""" + + +class PromptBuilder: + def __init__(self): + self.system_prompt = SYSTEM_PROMPT + self.user_template = USER_PROMPT_TEMPLATE + + def build_prompt( + self, lead: LeadModel, context: RankedContext, sales_rep_name: str = None + ) -> tuple[str, str]: + sales_rep = sales_rep_name or settings.sales_rep_name + + context_bullets_text = ( + "\n".join(context.summary_bullets) + if context.summary_bullets + else "Общая информация о платформе Консоль.Про для автоматизации работы с самозанятыми." + ) + + user_prompt = self.user_template.format( + contact_first_name=lead.contact_first_name or lead.contact_name, + role_title=lead.role_title, + role_category=lead.role_category or "специалист", + company_name=lead.company_name, + industry_tag=lead.industry_tag or lead.industry_segment, + context_bullets=context_bullets_text, + sales_rep_name=sales_rep, + locale=lead.locale, + ) + + return self.system_prompt, user_prompt + + def get_greeting(self, lead: LeadModel) -> str: + if lead.contact_first_name: + return f"{lead.contact_first_name}, добрый день!" + else: + return "Коллеги, добрый день!" + + def get_industry_hook(self, industry_tag: str) -> str: + hooks = { + "marketing_agency": "В маркетинговых агентствах часто сложно быстро подключать десятки подрядчиков для проектов", + "logistics": "В логистических компаниях управление множественными исполнителями требует особого внимания к документообороту", + "software": "В IT-компаниях работа с фрилансерами и ИП — важная часть команды разработки", + "retail": "В розничной торговле сезонные пики требуют быстрого масштабирования команды исполнителей", + "consulting": "В консалтинговых компаниях привлечение экспертов под проекты — критичный процесс", + "construction": "В строительстве координация подрядчиков и документооборот — ключевые вызовы", + "other": "При работе с внешними исполнителями всегда актуальны вопросы автоматизации процессов", + } + return hooks.get(industry_tag, hooks["other"]) + + def get_key_benefits(self, industry_tag: str) -> list[str]: + benefits_map = { + "marketing_agency": [ + "Подключение нового исполнителя за ~15 минут", + "Мгновенные массовые выплаты", + "Сбор 100% закрывающих документов", + ], + "logistics": [ + "Снижение времени онбординга с 2 дней до ~20 минут", + "Автоматический сбор чеков и документов", + "Управление сотнями исполнителей одним сотрудником", + ], + "software": [ + "API интеграции для автоматизации процессов", + "Работа с ИП и самозанятыми в одной системе", + "Снижение ошибок до 95%", + ], + "retail": [ + "Массовые выплаты в течение минут", + "Быстрое подключение сезонного персонала", + "Автоматизация документооборота", + ], + "other": [ + "Подключение исполнителя за ~15 минут", + "Выплаты в течение минут vs часы", + "Снижение ошибок до 95%", + ], + } + return benefits_map.get(industry_tag, benefits_map["other"]) diff --git a/src/services/retrieval.py b/src/services/retrieval.py new file mode 100644 index 0000000..3b3a145 --- /dev/null +++ b/src/services/retrieval.py @@ -0,0 +1,170 @@ +from typing import List, Dict, Any, Optional +import tiktoken +import hashlib +import time + +from src.models.email import DocChunk, RetrievalQuery, RankedContext +from src.models.lead import LeadModel, LeadFeatures +from src.services.chroma_store import ChromaStore +from src.services.embeddings import EmbeddingService +from src.app.config import settings + + +class RetrievalService: + def __init__(self, chroma_store: ChromaStore, embedding_service: EmbeddingService): + self.chroma_store = 
chroma_store
+        self.embedding_service = embedding_service
+        self.encoding = tiktoken.get_encoding("cl100k_base")
+        self._embedding_cache = {}
+        self._search_cache = {}
+        self._cache_ttl = 3600
+
+    def _get_query_hash(self, text_query: str) -> str:
+        return hashlib.md5(text_query.encode()).hexdigest()
+
+    def _get_cached_embedding(self, text_query: str) -> Optional[List[float]]:
+        query_hash = self._get_query_hash(text_query)
+        cached = self._embedding_cache.get(query_hash)
+
+        if cached and time.time() - cached["timestamp"] < self._cache_ttl:
+            return cached["embedding"]
+        return None
+
+    def _cache_embedding(self, text_query: str, embedding: List[float]):
+        query_hash = self._get_query_hash(text_query)
+        self._embedding_cache[query_hash] = {
+            "embedding": embedding,
+            "timestamp": time.time(),
+        }
+
+        # Simple eviction: drop the oldest entry once the cache grows past
+        # 1000 embeddings.
+        if len(self._embedding_cache) > 1000:
+            oldest_key = min(
+                self._embedding_cache.keys(),
+                key=lambda k: self._embedding_cache[k]["timestamp"],
+            )
+            del self._embedding_cache[oldest_key]
+
+    def build_retrieval_query(self, lead_features: LeadFeatures) -> RetrievalQuery:
+        search_terms = [
+            "самозанятые",
+            "автоматизация",
+            "исполнители",
+            lead_features.industry_tag,
+            lead_features.role_category,
+        ]
+
+        pain_points_map = {
+            "marketing_agency": ["онбординг", "выплаты", "подрядчики", "отчетность"],
+            "logistics": ["документооборот", "исполнители", "чеки", "география"],
+            "software": ["фриланс", "ИП", "API", "интеграции"],
+            "retail": ["сезонность", "временные", "массовые", "выплаты"],
+            "consulting": ["проекты", "экспертиза", "временные", "специалисты"],
+            "construction": ["подрядчики", "документы", "безопасность", "сроки"],
+            "other": ["онбординг", "выплаты", "документооборот"],
+        }
+
+        industry_terms = pain_points_map.get(
+            lead_features.industry_tag, pain_points_map["other"]
+        )
+        search_terms.extend(industry_terms)
+
+        text_query = " ".join(search_terms)
+
+        # NOTE: Chroma metadata `where` filters only support exact-match
+        # operators ($eq, $in, ...); $contains is valid only in
+        # where_document, and the list-valued fields are flattened to
+        # comma-joined strings by ChromaStore._sanitize_metadata anyway.
+        # A $contains metadata filter here would raise at query time, so
+        # industry/role relevance is applied as a soft boost in
+        # rank_chunks() instead of as a hard filter.
+        return RetrievalQuery(
+            text_query=text_query,
+            metadata_filters=None,
+            k=settings.top_k,
+        )
+
+    def search_relevant_chunks(self, query: RetrievalQuery) -> List[DocChunk]:
+        cached_embedding = self._get_cached_embedding(query.text_query)
+
+        if cached_embedding:
+            query_embedding = cached_embedding
+        else:
+            query_embedding = self.embedding_service.embed_text(query.text_query)
+            self._cache_embedding(query.text_query, query_embedding)
+
+        chunks = self.chroma_store.search_similar(
+            query_embedding=query_embedding, k=query.k, where=query.metadata_filters
+        )
+
+        return chunks
+
+    def rank_chunks(
+        self, chunks: List[DocChunk], lead_features: LeadFeatures
+    ) -> List[DocChunk]:
+        for chunk in chunks:
+            score = chunk.similarity_score or 0.0
+
+            # After the Chroma round-trip, "industry" and "roles_relevant"
+            # are comma-joined strings, so `in` does substring matching here.
+            industry_boost = 0.0
+            if lead_features.industry_tag in chunk.metadata.get("industry", []):
+                industry_boost = 0.2
+
+            role_boost = 0.0
+            if lead_features.role_category in chunk.metadata.get("roles_relevant", []):
+                role_boost = 0.1
+
+            metrics_boost = 0.0
+            if chunk.metadata.get("metrics"):
+                metrics_boost = 0.1
+
+            final_score = 0.6 * score + industry_boost + role_boost + metrics_boost
+            chunk.similarity_score = final_score
+
+        chunks.sort(key=lambda x: x.similarity_score or 0.0, reverse=True)
+
+        seen_content = set()
+        unique_chunks = []
+        for chunk in chunks:
+            content_hash = hash(chunk.content_text[:100])
+            if 
content_hash not in seen_content: + seen_content.add(content_hash) + unique_chunks.append(chunk) + + return unique_chunks[: settings.top_n_context] + + def build_context( + self, chunks: List[DocChunk], max_tokens: int = 3000 + ) -> RankedContext: + selected_chunks = [] + total_tokens = 0 + + for chunk in chunks: + chunk_tokens = len(self.encoding.encode(chunk.content_text)) + if total_tokens + chunk_tokens <= max_tokens: + selected_chunks.append(chunk) + total_tokens += chunk_tokens + else: + break + + summary_bullets = [] + for chunk in selected_chunks: + content = chunk.content_text.strip() + if len(content) > 200: + content = content[:200] + "..." + + metrics_info = "" + if chunk.metadata.get("metrics"): + metrics = chunk.metadata["metrics"] + metrics_parts = [] + for key, value in metrics.items(): + if isinstance(value, (int, float)): + metrics_parts.append(f"{key}: {value}") + if metrics_parts: + metrics_info = f" ({', '.join(metrics_parts)})" + + summary_bullets.append(f"• {content}{metrics_info}") + + return RankedContext( + chunks=selected_chunks, + total_tokens=total_tokens, + summary_bullets=summary_bullets, + ) diff --git a/src/tests/__init__.py b/src/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/tests/test_api.py b/src/tests/test_api.py new file mode 100644 index 0000000..94b120a --- /dev/null +++ b/src/tests/test_api.py @@ -0,0 +1,126 @@ +import pytest +import json +from fastapi.testclient import TestClient +import sys +import os + +sys.path.append( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) + +from src.app.main import app + +client = TestClient(app) + + +def test_health_endpoint(): + response = client.get("/healthz") + assert response.status_code == 200 + data = response.json() + assert "status" in data + assert data["status"] == "healthy" + + +def test_readiness_endpoint(): + response = client.get("/readiness") + assert response.status_code == 200 + + +def test_api_documentation(): + response = client.get("/docs") + assert response.status_code == 200 + + +def test_root_redirect(): + response = client.get("/") + assert response.status_code == 200 + + +def test_generate_email_missing_fields(): + incomplete_payload = {"contact": ""} + + response = client.post("/api/v1/generate_email", json=incomplete_payload) + assert response.status_code == 422 + + +def test_generate_email_basic_validation(): + valid_payload = {"contact": "Test Person"} + + response = client.post("/api/v1/generate_email", json=valid_payload) + assert response.status_code == 422 + + +def test_generate_email_success(): + payload = { + "contact": "Помящий Никита", + "position": "Технический директор", + "company_name": "FIVE", + "segment": "маркетинговое агентство", + } + + response = client.post("/api/v1/generate_email", json=payload) + + if response.status_code == 200: + data = response.json() + assert "subject" in data + assert "body" in data + assert "meta" in data + else: + print(f"Response: {response.status_code}, {response.text}") + + +def test_generate_email_with_email(): + payload = { + "contact": "Помящий Никита", + "position": "Технический директор", + "company_name": "FIVE", + "segment": "маркетинговое агентство", + "email": "nikita@five.agency", + "locale": "ru", + } + + response = client.post("/api/v1/generate_email", json=payload) + assert response.status_code in [200, 500] + + +def test_api_status(): + response = client.get("/api/v1/status") + assert response.status_code == 200 + data = response.json() + assert data["status"] 
== "operational" + + +def test_invalid_json(): + response = client.post( + "/api/v1/generate_email", + content="invalid json".encode("utf-8"), + headers={"Content-Type": "application/json"}, + ) + assert response.status_code == 422 + + +def test_admin_endpoints_without_auth(): + response = client.get("/api/v1/admin/knowledge-base/stats") + assert response.status_code == 401 + + response = client.post("/api/v1/admin/ingest") + assert response.status_code == 401 + + +def test_admin_endpoints_with_auth(): + from src.app.config import settings + + headers = {"Authorization": f"Bearer {settings.api_secret_key}"} + + response = client.get("/api/v1/admin/knowledge-base/stats", headers=headers) + assert response.status_code in [200, 500] + + response = client.post("/api/v1/admin/ingest", headers=headers) + assert response.status_code in [200, 500] + + +def test_admin_endpoints_with_wrong_auth(): + headers = {"Authorization": "Bearer wrong_token"} + + response = client.get("/api/v1/admin/knowledge-base/stats", headers=headers) + assert response.status_code == 403 diff --git a/src/tests/test_ingest.py b/src/tests/test_ingest.py new file mode 100644 index 0000000..46b7f5f --- /dev/null +++ b/src/tests/test_ingest.py @@ -0,0 +1,225 @@ +import pytest +import tempfile +import os +import shutil +import sys +from unittest.mock import Mock, patch + +sys.path.append( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) + +from src.ingest.loader import MarkdownLoader, Document, create_platform_overview +from src.ingest.chunker import DocumentChunker, ChunkConfig +from src.models.lead import LeadModel +from src.services.chroma_store import ChromaStore +from src.services.embeddings import EmbeddingService + + +class TestMarkdownLoader: + def test_create_platform_overview(self): + doc = create_platform_overview() + + assert doc.doc_id == "platform_overview" + assert doc.metadata["title"] == "Платформа Консоль.Про - Обзор" + assert "generic" in doc.metadata["industry"] + assert len(doc.content) > 100 + + def test_load_file(self): + with tempfile.NamedTemporaryFile( + mode="w", suffix=".md", delete=False, encoding="utf-8" + ) as f: + f.write( + """# Тестовый кейс + +Компания тестирует автоматизацию с самозанятыми. + +## Результаты +- Сокращение времени на 20 минут +- Снижение ошибок на 85% + +Технический директор доволен результатами. +""" + ) + temp_file = f.name + + try: + loader = MarkdownLoader() + doc = loader.load_file(temp_file) + + assert doc.content is not None + assert len(doc.content) > 0 + assert "Тестовый кейс" in doc.metadata["title"] + assert doc.metadata["doc_type"] == "info" + assert "industry" in doc.metadata + assert "roles_relevant" in doc.metadata + assert "tech" in doc.metadata["roles_relevant"] + + finally: + os.unlink(temp_file) + + def test_extract_metadata(self): + content = """# Кейс маркетингового агентства + +Технический директор внедрил автоматизацию. +Снижение ошибок на 95%. +Онбординг теперь занимает 15 минут.""" + + loader = MarkdownLoader() + metadata = loader._extract_metadata(content, "test.md") + + assert "marketing_agency" in metadata["industry"] + assert "tech" in metadata["roles_relevant"] + assert "error_reduction_pct" in metadata["metrics"] + assert metadata["metrics"]["error_reduction_pct"] == 95 + + +class TestDocumentChunker: + def test_chunk_document(self): + content = """# Заголовок документа + +Вводный текст документа. + +## Проблема +Описание проблемы клиента. +Это было сложно. + +## Решение +Как мы решили проблему. 
+Внедрили автоматизацию. + +## Результат +Получили отличные результаты. +Снижение на 95%.""" + + document = Document( + content=content, + metadata={"title": "Test", "industry": ["tech"], "roles_relevant": ["ceo"]}, + doc_id="test_doc", + source="test.md", + ) + + chunker = DocumentChunker() + chunks = chunker.chunk_document(document) + + assert len(chunks) > 0 + assert all(chunk.chunk_id.startswith("test_doc#c") for chunk in chunks) + assert all(chunk.parent_doc_id == "test_doc" for chunk in chunks) + + def test_split_by_semantic_blocks(self): + content = """# Заголовок 1 +Контент секции 1 + +## Заголовок 2 +Контент секции 2 + +### Заголовок 3 +Контент секции 3""" + + chunker = DocumentChunker() + sections = chunker._split_by_semantic_blocks(content) + + assert len(sections) >= 2 + assert any("Заголовок" in section[0] for section in sections) + + def test_classify_section(self): + chunker = DocumentChunker() + + assert chunker._classify_section("проблема клиента") == "Проблема" + assert chunker._classify_section("как мы решили") == "Решение" + assert chunker._classify_section("результаты внедрения") == "Результат" + assert chunker._classify_section("о компании") == "О клиенте" + + def test_chunk_config(self): + config = ChunkConfig(chunk_size=300, chunk_overlap=50) + chunker = DocumentChunker(config) + + assert chunker.config.chunk_size == 300 + assert chunker.config.chunk_overlap == 50 + + +class TestLeadModelExtraction: + def test_role_category_extraction(self): + from src.models.lead import LeadModel + + assert LeadModel._extract_role_category("Технический директор") == "tech" + assert LeadModel._extract_role_category("Финансовый директор") == "finance" + assert LeadModel._extract_role_category("Генеральный директор") == "ceo" + assert LeadModel._extract_role_category("HR менеджер") == "hr" + + def test_industry_tag_extraction(self): + from src.models.lead import LeadModel + + assert ( + LeadModel._extract_industry_tag("маркетинговое агентство") + == "marketing_agency" + ) + assert LeadModel._extract_industry_tag("IT компания") == "software" + assert ( + LeadModel._extract_industry_tag("строительная компания") == "construction" + ) + + +class TestIntegration: + def test_full_ingest_flow(self): + with tempfile.NamedTemporaryFile( + mode="w", suffix=".md", delete=False, encoding="utf-8" + ) as f: + f.write( + """# Тестовый кейс компании + +Компания внедрила автоматизацию. + +## Результаты +Сокращение времени с 2 дней до 30 минут. +Снижение ошибок на 90%. 
+
+
+Технический директор доволен."""
+            )
+            temp_file = f.name
+
+        try:
+            loader = MarkdownLoader()
+            doc = loader.load_file(temp_file)
+
+            chunker = DocumentChunker()
+            chunks = chunker.chunk_document(doc)
+
+            assert len(chunks) > 0
+            assert all(chunk.token_count > 0 for chunk in chunks)
+            assert all(chunk.content_text.strip() for chunk in chunks)
+
+        finally:
+            os.unlink(temp_file)
+
+
+class TestChromaStore:
+    @patch("chromadb.PersistentClient")
+    def test_init(self, mock_client):
+        store = ChromaStore()
+        assert store.collection_name == "knowledge_base"
+        mock_client.assert_called_once()
+
+    @patch("chromadb.PersistentClient")
+    def test_get_existing_doc_ids(self, mock_client):
+        mock_collection = Mock()
+        mock_collection.get.return_value = {
+            "metadatas": [{"source": "test1.md"}, {"source": "test2.md"}]
+        }
+
+        mock_client.return_value.get_or_create_collection.return_value = mock_collection
+
+        store = ChromaStore()
+        doc_ids = store.get_existing_doc_ids()
+
+        assert isinstance(doc_ids, set)
+        assert doc_ids == {"test1", "test2"}
+
+
+class TestEmbeddingService:
+    @patch("src.services.embeddings.OpenAIEmbeddingProvider")
+    def test_embedding_service_init(self, mock_provider):
+        # Patch the settings reference actually used inside
+        # src.services.embeddings; patching src.app.config.settings does not
+        # rebind the module-level `from src.app.config import settings` name.
+        with patch("src.services.embeddings.settings") as mock_settings:
+            mock_settings.llm_provider = "openai"
+
+            service = EmbeddingService()
+            mock_provider.assert_called_once()
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
new file mode 100644
index 0000000..e69de29
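Two short usage sketches, not part of the diff; module paths, class names, and defaults are taken from the files above, but the script names and the standalone flow are illustrative only. The first mirrors the ingest pipeline's load → chunk → embed → store steps:

```python
# ingest_sketch.py — hypothetical helper; assumes OPENAI_API_KEY is set in .env
from src.ingest.loader import MarkdownLoader
from src.ingest.chunker import ChunkConfig, DocumentChunker
from src.services.chroma_store import ChromaStore
from src.services.embeddings import EmbeddingService

# Load and chunk the knowledge base (sizes match the Settings defaults).
docs = MarkdownLoader().load_directory("data/articles_konsol_pro")
chunker = DocumentChunker(ChunkConfig(chunk_size=500, chunk_overlap=100))
chunks = chunker.chunk_documents(docs)

# Embed and persist; ChromaStore defaults to settings.chroma_persist_dir.
store = ChromaStore()
embeddings = EmbeddingService().embed_batch([c.content_text for c in chunks])
store.add_chunks(chunks, embeddings)
print(f"Collection now holds {store.get_count()} chunks")
```

With the knowledge base in place and the API started on its defaults from src/app/main.py (`uvicorn src.app.main:app --host 0.0.0.0 --port 8000`), the endpoint can be smoke-tested with httpx (already in requirements.txt), reusing the sample lead from src/tests/test_api.py:

```python
# smoke_test.py — hypothetical helper, not part of this commit.
import httpx

payload = {
    "contact": "Помящий Никита",
    "position": "Технический директор",
    "company_name": "FIVE",
    "segment": "маркетинговое агентство",
    "locale": "ru",
}

# Generation runs embeddings plus an LLM call, so allow a generous timeout.
resp = httpx.post(
    "http://localhost:8000/api/v1/generate_email", json=payload, timeout=60.0
)
resp.raise_for_status()

email = resp.json()  # EmailResponse: {"subject": ..., "body": ..., "meta": {...}}
print(email["subject"])
print(email["body"])
```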