first commit
commit de33dad47e

.gitignore
@@ -0,0 +1,42 @@
__pycache__/
*.py[cod]
*$py.class
*.so
*.egg-info/
dist/
build/
.Python

.venv/
venv/
env/
ENV/
.env
config.env

.vscode/
.idea/
*.swp
*.swo

.DS_Store
Thumbs.db
Desktop.ini

storage/chroma/
!storage/chroma/.gitkeep
*.log
logs/
.temporary/
*.tmp
*.temp

.pytest_cache/
.coverage
htmlcov/
.tox/
.ruff_cache/

data/local/
data/temp/
articles_konsol_pro/

requirements.txt
@@ -0,0 +1,19 @@
fastapi
uvicorn[standard]
pydantic
langchain
langchain-community
langgraph
chromadb
openai
google-generativeai
python-multipart
python-dotenv
pypdf
markdown
python-jose[cryptography]
passlib[bcrypt]
tiktoken
httpx
pytest
pytest-asyncio

src/app/config.py
@@ -0,0 +1,30 @@
from typing import Literal

from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import Field


class Settings(BaseSettings):
    llm_provider: Literal["openai", "gemini"] = "openai"
    llm_model: str = "gpt-4o-mini"
    embedding_model: str = "text-embedding-3-large"

    openai_api_key: str = Field(default="", json_schema_extra={"env": "OPENAI_API_KEY"})
    gemini_api_key: str = Field(default="", json_schema_extra={"env": "GEMINI_API_KEY"})
    api_secret_key: str = Field(
        default="secret", json_schema_extra={"env": "API_SECRET_KEY"}
    )

    chroma_persist_dir: str = "./storage/chroma"
    top_k: int = 30
    top_n_context: int = 6
    max_tokens_completion: int = 1024
    langgraph_tracing: bool = False
    service_lang: str = "ru"
    sales_rep_name: str = "Команда Консоль.Про"
    chunk_size: int = 500
    chunk_overlap: int = 100

    model_config = SettingsConfigDict(env_file=".env")


settings = Settings()

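The settings above resolve at import time from the process environment and the .env file, falling back to the defaults shown. Note that this module imports pydantic_settings, which is not listed in requirements.txt and must be installed separately. A minimal usage sketch (not part of the commit), assuming pydantic-settings v2, where field names match environment variables case-insensitively:

# Usage sketch: env vars override .env values, which override the defaults.
import os

os.environ["LLM_PROVIDER"] = "gemini"      # overrides the "openai" default
os.environ["OPENAI_API_KEY"] = "sk-test"   # matched to openai_api_key

from src.app.config import Settings

s = Settings()
assert s.llm_provider == "gemini"
assert s.openai_api_key == "sk-test"
assert s.top_k == 30                       # untouched default
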
src/app/main.py
@@ -0,0 +1,107 @@
import os
import sys
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import time
import json
import logging

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.app.routers import generate, ingest, health
from src.app.config import settings

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="AI Email Assistant",
    description="Персонализированная генерация холодных писем с использованием RAG и LangGraph",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.middleware("http")
async def logging_middleware(request: Request, call_next):
    start_time = time.time()

    request_id = str(time.time()).replace(".", "")

    response = await call_next(request)

    process_time = time.time() - start_time

    log_data = {
        "request_id": request_id,
        "method": request.method,
        "url": str(request.url),
        "status_code": response.status_code,
        "process_time": round(process_time, 4),
        "client_ip": request.client.host if request.client else "unknown",
    }

    logger.info(json.dumps(log_data))

    response.headers["X-Request-ID"] = request_id
    response.headers["X-Process-Time"] = str(round(process_time, 4))

    return response


@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    logger.error(f"Global exception handler caught: {exc}", exc_info=True)

    return JSONResponse(
        status_code=500,
        content={
            "error": "Internal server error",
            "code": "INTERNAL_ERROR",
            "detail": (
                str(exc) if settings.llm_provider else "An unexpected error occurred"
            ),
        },
    )


app.include_router(health.router)
app.include_router(generate.router)
app.include_router(ingest.router)


@app.get("/")
async def root():
    return {
        "service": "AI Email Assistant",
        "version": "1.0.0",
        "description": "Персонализированная генерация холодных писем",
        "endpoints": {
            "health": "/healthz",
            "readiness": "/readiness",
            "generate": "/api/v1/generate_email",
            "admin": "/api/v1/admin/",
            "docs": "/docs",
        },
        "status": "operational",
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "src.app.main:app", host="0.0.0.0", port=8000, reload=True, log_level="info"
    )

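One caveat in the exception handler: settings.llm_provider always has a value, so the detail field always leaks str(exc) to clients; a dedicated debug flag would be a safer gate. A quick local smoke test of the app, a sketch (not part of the commit) assuming the default host and port from the uvicorn.run call; httpx is already in requirements.txt:

# Usage sketch. Start the app first:
#   uvicorn src.app.main:app --port 8000
import httpx

resp = httpx.get("http://localhost:8000/")
print(resp.json()["endpoints"])        # {"health": "/healthz", ...}
print(resp.headers["X-Request-ID"])    # set by logging_middleware
print(resp.headers["X-Process-Time"])  # wall-clock seconds, rounded to 4 dp
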
src/app/routers/generate.py
@@ -0,0 +1,110 @@
import time
from fastapi import APIRouter, HTTPException, Depends
from typing import Dict, Any

from src.models.lead import LeadInput
from src.models.email import EmailResponse, ErrorResponse
from src.graph.build_graph import email_generation_graph
from src.app.config import settings

router = APIRouter(prefix="/api/v1", tags=["generation"])


@router.post("/generate_email", response_model=EmailResponse)
async def generate_email(lead_data: LeadInput):
    start_time = time.time()

    try:
        request = lead_data.model_dump()

        initial_state = {
            "raw_input": request,
            "lead_input": None,
            "lead_model": None,
            "lead_features": None,
            "retrieval_query": None,
            "retrieved_chunks": None,
            "ranked_context": None,
            "prompt_payload": None,
            "llm_output": None,
            "email_draft": None,
            "email_clean": None,
            "email_response": None,
            "error": None,
            "error_code": None,
            "trace_meta": None,
        }

        result_state = email_generation_graph.invoke(initial_state)

        if result_state.get("error"):
            error_code = result_state.get("error_code", "UNKNOWN_ERROR")
            error_message = result_state.get("error", "Unknown error occurred")
            trace_meta = result_state.get("trace_meta", {})

            if error_code == "VALIDATION_ERROR":
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": error_message,
                        "code": error_code,
                        "trace": trace_meta,
                    },
                )
            elif error_code == "NO_RESULTS":
                raise HTTPException(
                    status_code=404,
                    detail={
                        "error": "No relevant knowledge found for this lead profile",
                        "code": error_code,
                    },
                )
            elif error_code in ["LLM_ERROR", "OPENAI_ERROR", "GEMINI_ERROR"]:
                raise HTTPException(
                    status_code=502,
                    detail={
                        "error": "External service error",
                        "code": error_code,
                        "details": error_message,
                    },
                )
            else:
                raise HTTPException(
                    status_code=500,
                    detail={
                        "error": error_message,
                        "code": error_code,
                        "trace": trace_meta,
                    },
                )

        email_response = result_state.get("email_response")
        if not email_response:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "Failed to generate email response",
                    "code": "MISSING_RESPONSE",
                },
            )

        return email_response

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail={"error": f"Unexpected error: {str(e)}", "code": "INTERNAL_ERROR"},
        )


@router.get("/status")
async def get_status():
    return {
        "status": "operational",
        "service": "email-generation",
        "version": "1.0.0",
        "provider": settings.llm_provider,
        "model": settings.llm_model,
    }

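Calling the endpoint looks like the sketch below (not part of the commit). LeadInput is defined elsewhere in the repo; the four fields shown are the ones input_validation.py requires to be non-empty, and any further fields are unknown from this commit:

# Usage sketch: POST a lead and read back the generated email.
import httpx

payload = {
    "contact": "Ivan Petrov",
    "position": "CTO",
    "company_name": "Acme Logistics",
    "segment": "logistics",
}
resp = httpx.post(
    "http://localhost:8000/api/v1/generate_email", json=payload, timeout=60
)
if resp.status_code == 200:
    email = resp.json()
    print(email["subject"])
    print(email["body"])
else:
    # 400 / 404 / 502 / 500, mapped from the pipeline's error_code above
    print(resp.status_code, resp.json())
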
src/app/routers/health.py
@@ -0,0 +1,100 @@
from fastapi import APIRouter
from src.app.config import settings

router = APIRouter(tags=["health"])


@router.get("/healthz")
async def health_check():
    try:
        from src.services.chroma_store import ChromaStore
        from src.services.embeddings import EmbeddingService

        chroma_store = ChromaStore()
        doc_count = chroma_store.get_count()

        embedding_service = EmbeddingService()

        health_status = {
            "status": "healthy",
            "service": "ai-email-assistant",
            "version": "1.0.0",
            "timestamp": None,
            "components": {
                "knowledge_base": {
                    "status": "healthy" if doc_count > 0 else "warning",
                    "documents": doc_count,
                    "persist_dir": settings.chroma_persist_dir,
                },
                "embedding_service": {
                    "status": "healthy",
                    "provider": settings.llm_provider,
                    "model": settings.embedding_model,
                },
                "llm_service": {
                    "status": "healthy",
                    "provider": settings.llm_provider,
                    "model": settings.llm_model,
                },
            },
            "configuration": {
                "top_k": settings.top_k,
                "top_n_context": settings.top_n_context,
                "max_tokens": settings.max_tokens_completion,
                "language": settings.service_lang,
            },
        }

        import datetime

        health_status["timestamp"] = datetime.datetime.utcnow().isoformat() + "Z"

        if doc_count == 0:
            health_status["status"] = "degraded"
            health_status["warnings"] = ["No documents in knowledge base"]

        return health_status

    except Exception as e:
        return {
            "status": "unhealthy",
            "service": "ai-email-assistant",
            "error": str(e),
            "timestamp": None,
        }


@router.get("/readiness")
async def readiness_check():
    try:
        from src.services.chroma_store import ChromaStore

        chroma_store = ChromaStore()
        doc_count = chroma_store.get_count()

        if doc_count == 0:
            return {
                "ready": False,
                "reason": "Knowledge base is empty - run ingestion first",
            }

        api_key_configured = bool(
            settings.openai_api_key
            if settings.llm_provider == "openai"
            else settings.gemini_api_key
        )

        if not api_key_configured:
            return {
                "ready": False,
                "reason": f"API key not configured for {settings.llm_provider}",
            }

        return {
            "ready": True,
            "documents": doc_count,
            "provider": settings.llm_provider,
        }

    except Exception as e:
        return {"ready": False, "reason": f"Service check failed: {str(e)}"}

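Since /readiness reports False until documents are ingested and an API key is configured, a deployment script can gate traffic on it. A polling sketch (not part of the commit):

# Usage sketch: wait for the service to become ready before routing traffic.
import time
import httpx

for _ in range(10):
    r = httpx.get("http://localhost:8000/readiness").json()
    if r.get("ready"):
        print("ready:", r["documents"], "documents via", r["provider"])
        break
    print("not ready:", r.get("reason"))
    time.sleep(3)
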
src/app/routers/ingest.py
@@ -0,0 +1,108 @@
import os
from fastapi import APIRouter, HTTPException, Depends, Header
from typing import Optional

from src.app.config import settings

router = APIRouter(prefix="/api/v1/admin", tags=["administration"])


def verify_admin_token(authorization: Optional[str] = Header(None)):
    if not authorization:
        raise HTTPException(status_code=401, detail="Authorization header required")

    if not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid authorization format")

    token = authorization.replace("Bearer ", "")
    if token != settings.api_secret_key:
        raise HTTPException(status_code=403, detail="Invalid admin token")

    return True


@router.post("/ingest")
async def trigger_ingest(recreate: bool = False, _: bool = Depends(verify_admin_token)):
    try:
        from src.ingest.ingest_cli import run_ingest

        data_dir = "articles_konsol_pro"
        platform_file = "data/platform_overview.md"
        persist_dir = settings.chroma_persist_dir

        if not os.path.exists(data_dir):
            raise HTTPException(
                status_code=404, detail=f"Data directory not found: {data_dir}"
            )

        result = run_ingest(
            data_dir=data_dir,
            platform_md=platform_file,
            persist_dir=persist_dir,
            recreate=recreate,
            chunk_size=settings.chunk_size,
            chunk_overlap=settings.chunk_overlap,
            embedding_model=settings.embedding_model,
        )

        return {
            "status": "success",
            "message": "Knowledge base updated successfully",
            "details": result,
        }

    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail={"error": f"Ingest failed: {str(e)}", "code": "INGEST_ERROR"},
        )


@router.get("/knowledge-base/stats")
async def get_knowledge_base_stats(_: bool = Depends(verify_admin_token)):
    try:
        from src.services.chroma_store import ChromaStore

        chroma_store = ChromaStore()
        doc_count = chroma_store.get_count()

        collection = chroma_store.get_or_create_collection()

        sample_docs = collection.peek(limit=5)

        return {
            "total_documents": doc_count,
            "collection_name": chroma_store.collection_name,
            "persist_directory": chroma_store.persist_directory,
            "sample_metadata": [
                doc.get("industry", []) if doc else []
                for doc in (sample_docs.get("metadatas", []) or [])
            ],
            "embedding_model": settings.embedding_model,
        }

    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail={"error": f"Failed to get stats: {str(e)}", "code": "STATS_ERROR"},
        )


@router.delete("/knowledge-base")
async def clear_knowledge_base(_: bool = Depends(verify_admin_token)):
    try:
        from src.services.chroma_store import ChromaStore

        chroma_store = ChromaStore()
        chroma_store.delete_collection()

        return {"status": "success", "message": "Knowledge base cleared successfully"}

    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail={
                "error": f"Failed to clear knowledge base: {str(e)}",
                "code": "CLEAR_ERROR",
            },
        )

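One caveat: the broad except Exception in trigger_ingest also catches the HTTPException(404) raised just above it and remaps it to a 500; an "except HTTPException: raise" clause, as generate.py uses, would preserve the 404. Calling the admin routes looks like this sketch (not part of the commit), where the Bearer token is api_secret_key from config.py (default "secret"):

# Usage sketch: admin routes expect "Authorization: Bearer <api_secret_key>".
import httpx

headers = {"Authorization": "Bearer secret"}

# Re-ingest the knowledge base, rebuilding the collection from scratch.
r = httpx.post(
    "http://localhost:8000/api/v1/admin/ingest",
    params={"recreate": "true"},
    headers=headers,
    timeout=600,
)
print(r.json())

# Inspect what landed in Chroma.
stats = httpx.get(
    "http://localhost:8000/api/v1/admin/knowledge-base/stats", headers=headers
)
print(stats.json())
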
src/graph/build_graph.py
@@ -0,0 +1,82 @@
from langgraph.graph import StateGraph, END
from typing import Dict, Any

from src.graph.state import EmailGenerationState
from src.graph.nodes.input_validation import input_validation_node
from src.graph.nodes.feature_extract import feature_extract_node
from src.graph.nodes.build_query import build_query_node
from src.graph.nodes.vector_search import vector_search_node
from src.graph.nodes.context_rank import context_rank_node
from src.graph.nodes.prompt_build import prompt_build_node
from src.graph.nodes.llm_generate import llm_generate_node
from src.graph.nodes.parse_output import parse_output_node
from src.graph.nodes.guardrails import guardrails_node
from src.graph.nodes.return_result import return_result_node


def should_continue(state: EmailGenerationState) -> str:
    if state.get("error"):
        return "error"
    return "continue"


def create_email_graph() -> StateGraph:
    workflow = StateGraph(EmailGenerationState)

    workflow.add_node("input_validation", input_validation_node)
    workflow.add_node("feature_extract", feature_extract_node)
    workflow.add_node("build_query", build_query_node)
    workflow.add_node("vector_search", vector_search_node)
    workflow.add_node("context_rank", context_rank_node)
    workflow.add_node("prompt_build", prompt_build_node)
    workflow.add_node("llm_generate", llm_generate_node)
    workflow.add_node("parse_output", parse_output_node)
    workflow.add_node("guardrails", guardrails_node)
    workflow.add_node("return_result", return_result_node)

    workflow.set_entry_point("input_validation")

    workflow.add_conditional_edges(
        "input_validation",
        should_continue,
        {"continue": "feature_extract", "error": END},
    )

    workflow.add_conditional_edges(
        "feature_extract", should_continue, {"continue": "build_query", "error": END}
    )

    workflow.add_conditional_edges(
        "build_query", should_continue, {"continue": "vector_search", "error": END}
    )

    workflow.add_conditional_edges(
        "vector_search", should_continue, {"continue": "context_rank", "error": END}
    )

    workflow.add_conditional_edges(
        "context_rank", should_continue, {"continue": "prompt_build", "error": END}
    )

    workflow.add_conditional_edges(
        "prompt_build", should_continue, {"continue": "llm_generate", "error": END}
    )

    workflow.add_conditional_edges(
        "llm_generate", should_continue, {"continue": "parse_output", "error": END}
    )

    workflow.add_conditional_edges(
        "parse_output", should_continue, {"continue": "guardrails", "error": END}
    )

    workflow.add_conditional_edges(
        "guardrails", should_continue, {"continue": "return_result", "error": END}
    )

    workflow.add_edge("return_result", END)

    return workflow.compile()


email_generation_graph = create_email_graph()

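The compiled graph can also be invoked directly, bypassing FastAPI, which is handy for tests. A sketch (not part of the commit); the dict shape mirrors initial_state in src/app/routers/generate.py, where every key is present even if None:

# Usage sketch: drive the pipeline without the HTTP layer.
from src.graph.build_graph import email_generation_graph

state = {key: None for key in (
    "raw_input", "lead_input", "lead_model", "lead_features",
    "retrieval_query", "retrieved_chunks", "ranked_context", "prompt_payload",
    "llm_output", "email_draft", "email_clean", "email_response",
    "error", "error_code", "trace_meta",
)}
state["raw_input"] = {
    "contact": "Ivan Petrov", "position": "CTO",
    "company_name": "Acme Logistics", "segment": "logistics",
}

result = email_generation_graph.invoke(state)
# On failure the nodes set error/error_code instead of raising.
print(result.get("error") or result["email_response"])
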
src/graph/nodes/build_query.py
@@ -0,0 +1,22 @@
from src.graph.state import EmailGenerationState
from src.services.retrieval import RetrievalService


def build_query_node(state: EmailGenerationState) -> EmailGenerationState:
    try:
        lead_features = state.get("lead_features")
        if not lead_features:
            state["error"] = "Lead features are required for query building"
            state["error_code"] = "MISSING_FEATURES"
            return state

        retrieval_service = RetrievalService(None, None)
        retrieval_query = retrieval_service.build_retrieval_query(lead_features)

        state["retrieval_query"] = retrieval_query
        return state

    except Exception as e:
        state["error"] = f"Query building error: {str(e)}"
        state["error_code"] = "QUERY_BUILD_ERROR"
        return state

src/graph/nodes/context_rank.py
@@ -0,0 +1,36 @@
from src.graph.state import EmailGenerationState
from src.services.chroma_store import ChromaStore
from src.services.embeddings import EmbeddingService
from src.services.retrieval import RetrievalService


def context_rank_node(state: EmailGenerationState) -> EmailGenerationState:
    try:
        retrieved_chunks = state.get("retrieved_chunks")
        lead_features = state.get("lead_features")

        if not retrieved_chunks:
            state["error"] = "Retrieved chunks are required for ranking"
            state["error_code"] = "MISSING_CHUNKS"
            return state

        if not lead_features:
            state["error"] = "Lead features are required for ranking"
            state["error_code"] = "MISSING_FEATURES"
            return state

        chroma_store = ChromaStore()
        embedding_service = EmbeddingService()
        retrieval_service = RetrievalService(chroma_store, embedding_service)

        ranked_chunks = retrieval_service.rank_chunks(retrieved_chunks, lead_features)

        ranked_context = retrieval_service.build_context(ranked_chunks)

        state["ranked_context"] = ranked_context
        return state

    except Exception as e:
        state["error"] = f"Context ranking error: {str(e)}"
        state["error_code"] = "RANKING_ERROR"
        return state

src/graph/nodes/feature_extract.py
@@ -0,0 +1,126 @@
from typing import List
from src.graph.state import EmailGenerationState
from src.models.lead import LeadModel, LeadFeatures


def feature_extract_node(state: EmailGenerationState) -> EmailGenerationState:
    try:
        lead_input = state.get("lead_input")
        if not lead_input:
            state["error"] = "Lead input is required for feature extraction"
            state["error_code"] = "MISSING_INPUT"
            return state

        lead_model = LeadModel.from_input(lead_input)

        role_category = lead_model.role_category or "other"
        industry_tag = lead_model.industry_tag or "other"

        pain_points = _get_pain_points(industry_tag, role_category)
        key_benefits = _get_key_benefits(industry_tag, role_category)
        search_keywords = _get_search_keywords(industry_tag, role_category)

        lead_features = LeadFeatures(
            role_category=role_category,
            industry_tag=industry_tag,
            pain_points=pain_points,
            key_benefits=key_benefits,
            search_keywords=search_keywords,
        )

        state["lead_model"] = lead_model
        state["lead_features"] = lead_features
        return state

    except Exception as e:
        state["error"] = f"Feature extraction error: {str(e)}"
        state["error_code"] = "FEATURE_EXTRACTION_ERROR"
        return state


def _get_pain_points(industry_tag: str, role_category: str) -> List[str]:
    industry_pain_points = {
        "marketing_agency": [
            "быстрое подключение подрядчиков",
            "массовые выплаты",
            "отчетность",
        ],
        "logistics": ["управление исполнителями", "документооборот", "география"],
        "software": ["работа с фрилансерами", "ИП", "интеграции"],
        "retail": ["сезонность", "временные сотрудники", "масштабирование"],
        "consulting": ["привлечение экспертов", "проектная работа"],
        "construction": ["подрядчики", "документы", "сроки"],
        "other": ["автоматизация", "выплаты", "документооборот"],
    }

    role_pain_points = {
        "tech": ["интеграции", "автоматизация", "API"],
        "finance": ["выплаты", "отчетность", "налоги"],
        "ops": ["процессы", "управление", "координация"],
        "hr": ["онбординг", "документы", "персонал"],
        "ceo": ["эффективность", "масштабирование", "затраты"],
        "sales": ["процессы продаж", "клиенты"],
        "marketing": ["кампании", "подрядчики"],
        "other": ["общие бизнес-процессы"],
    }

    pain_points = industry_pain_points.get(industry_tag, industry_pain_points["other"])
    pain_points.extend(role_pain_points.get(role_category, role_pain_points["other"]))

    return list(set(pain_points))


def _get_key_benefits(industry_tag: str, role_category: str) -> List[str]:
    benefits = [
        "быстрое подключение исполнителей",
        "автоматические выплаты",
        "сбор документов",
        "снижение ошибок",
        "экономия времени",
    ]

    if industry_tag == "marketing_agency":
        benefits.extend(["масштабирование команды", "управление проектами"])
    elif industry_tag == "logistics":
        benefits.extend(["географическое покрытие", "отслеживание"])
    elif industry_tag == "software":
        benefits.extend(["API интеграции", "техническая поддержка"])

    if role_category == "tech":
        benefits.extend(["техническая интеграция", "автоматизация"])
    elif role_category == "finance":
        benefits.extend(["финансовая отчетность", "налоговое планирование"])
    elif role_category == "ops":
        benefits.extend(["операционная эффективность", "процессы"])

    return list(set(benefits))


def _get_search_keywords(industry_tag: str, role_category: str) -> List[str]:
    keywords = ["самозанятые", "исполнители", "автоматизация", "выплаты"]

    industry_keywords = {
        "marketing_agency": ["маркетинг", "агентство", "подрядчики", "креатив"],
        "logistics": ["логистика", "доставка", "перевозки", "склад"],
        "software": ["разработка", "IT", "программирование", "фриланс"],
        "retail": ["розница", "торговля", "продажи", "сезон"],
        "consulting": ["консалтинг", "эксперты", "проекты"],
        "construction": ["строительство", "подрядчики", "объекты"],
        "other": ["бизнес", "процессы", "команда"],
    }

    role_keywords = {
        "tech": ["технический", "IT", "разработка", "интеграция"],
        "finance": ["финансы", "бухгалтерия", "учет", "налоги"],
        "ops": ["операции", "управление", "процессы"],
        "hr": ["персонал", "HR", "найм", "кадры"],
        "ceo": ["руководство", "стратегия", "развитие"],
        "sales": ["продажи", "клиенты", "сделки"],
        "marketing": ["маркетинг", "реклама", "продвижение"],
        "other": ["общие", "универсальные"],
    }

    keywords.extend(industry_keywords.get(industry_tag, industry_keywords["other"]))
    keywords.extend(role_keywords.get(role_category, role_keywords["other"]))

    return list(set(keywords))

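A worked example of the keyword merge above, runnable without the rest of the repo (not part of the commit). One thing to watch: list(set(...)) deduplicates but makes the ordering nondeterministic across runs, so sorted(set(...)) would make retrieval queries reproducible:

# Worked example: merging base, industry and role keywords for a
# software-industry tech lead, as _get_search_keywords does.
base = ["самозанятые", "исполнители", "автоматизация", "выплаты"]
industry = ["разработка", "IT", "программирование", "фриланс"]   # "software"
role = ["технический", "IT", "разработка", "интеграция"]         # "tech"

merged = base + industry + role      # 12 items; "IT" and "разработка" appear twice
unique = set(merged)                 # 10 unique keywords
print(len(merged), len(unique))      # 12 10
# list(set(...)) yields arbitrary order; sorted(...) would be stable:
print(sorted(unique))
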
src/graph/nodes/guardrails.py
@@ -0,0 +1,201 @@
import re
from typing import List, Dict, Any
from src.graph.state import EmailGenerationState
from src.models.email import EmailDraftClean
from src.models.errors import GuardrailsError


def guardrails_node(state: EmailGenerationState) -> EmailGenerationState:
    try:
        email_draft = state.get("email_draft")
        if not email_draft:
            state["error"] = "Email draft is required for guardrails check"
            state["error_code"] = "MISSING_DRAFT"
            return state

        violations = []

        pii_violations = _check_pii_leakage(
            email_draft.subject + " " + email_draft.body
        )
        if pii_violations:
            violations.extend(pii_violations)

        false_claims = _check_false_claims(email_draft.body, email_draft.used_chunks)
        if false_claims:
            violations.extend(false_claims)

        tone_violations = _check_tone_appropriateness(email_draft.body)
        if tone_violations:
            violations.extend(tone_violations)

        if violations:
            cleaned_subject, cleaned_body = _apply_fixes(
                email_draft.subject, email_draft.body, violations
            )
        else:
            cleaned_subject = email_draft.subject
            cleaned_body = email_draft.body

        llm_output = state.get("llm_output")
        lead_model = state.get("lead_model")
        ranked_context = state.get("ranked_context")

        meta = {
            "locale": lead_model.locale if lead_model else "ru",
            "lead_normalized": lead_model.dict() if lead_model else {},
            "used_chunks": email_draft.used_chunks,
            "model": llm_output.model if llm_output else "unknown",
            "tokens_prompt": llm_output.tokens_prompt if llm_output else 0,
            "tokens_completion": llm_output.tokens_completion if llm_output else 0,
            "guardrails_violations": len(violations),
            "context_chunks_used": len(ranked_context.chunks) if ranked_context else 0,
        }

        email_clean = EmailDraftClean(
            subject=cleaned_subject, body=cleaned_body, meta=meta
        )

        state["email_clean"] = email_clean
        return state

    except GuardrailsError as e:
        state["error"] = e.message
        state["error_code"] = "GUARDRAILS_ERROR"
        state["trace_meta"] = {"violation_type": e.violation_type, "details": e.details}
        return state

    except Exception as e:
        state["error"] = f"Guardrails check error: {str(e)}"
        state["error_code"] = "GUARDRAILS_ERROR"
        return state


def _check_pii_leakage(text: str) -> List[Dict[str, Any]]:
    violations = []

    email_pattern = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b"
    phone_pattern = r"\b(?:\+7|8)?[\s\-]?\(?[0-9]{3}\)?[\s\-]?[0-9]{3}[\s\-]?[0-9]{2}[\s\-]?[0-9]{2}\b"

    if re.search(email_pattern, text):
        violations.append(
            {
                "type": "pii_email",
                "message": "Email address detected in content",
                "severity": "medium",
            }
        )

    if re.search(phone_pattern, text):
        violations.append(
            {
                "type": "pii_phone",
                "message": "Phone number detected in content",
                "severity": "medium",
            }
        )

    return violations


def _check_false_claims(body: str, used_chunks: List[str]) -> List[Dict[str, Any]]:
    violations = []

    specific_numbers = re.findall(
        r"\b\d+%|\b\d+\s*минут|\b\d+\s*дн[ейя]|\b\d+\s*раз", body
    )

    if specific_numbers and not used_chunks:
        violations.append(
            {
                "type": "unverified_metrics",
                "message": "Specific metrics mentioned without supporting context",
                "severity": "high",
                "numbers": specific_numbers,
            }
        )

    guarantee_words = ["гарантируем", "гарантия", "обещаем", "100% результат"]
    for word in guarantee_words:
        if word.lower() in body.lower():
            violations.append(
                {
                    "type": "false_guarantee",
                    "message": f"Strong guarantee detected: {word}",
                    "severity": "high",
                }
            )

    return violations


def _check_tone_appropriateness(body: str) -> List[Dict[str, Any]]:
    violations = []

    aggressive_words = [
        "срочно",
        "немедленно",
        "прямо сейчас",
        "только сегодня",
        "последний шанс",
    ]
    for word in aggressive_words:
        if word.lower() in body.lower():
            violations.append(
                {
                    "type": "aggressive_tone",
                    "message": f"Aggressive language detected: {word}",
                    "severity": "medium",
                }
            )

    if body.count("!") > 2:
        violations.append(
            {
                "type": "excessive_exclamation",
                "message": "Too many exclamation marks",
                "severity": "low",
            }
        )

    return violations


def _apply_fixes(
    subject: str, body: str, violations: List[Dict[str, Any]]
) -> tuple[str, str]:
    fixed_subject = subject
    fixed_body = body

    for violation in violations:
        if violation["type"] == "false_guarantee":
            replacements = {
                "гарантируем": "помогаем",
                "гарантия": "возможность",
                "обещаем": "стремимся",
                "100% результат": "хорошие результаты",
            }
            for old, new in replacements.items():
                fixed_body = re.sub(old, new, fixed_body, flags=re.IGNORECASE)

        elif violation["type"] == "aggressive_tone":
            replacements = {
                "срочно": "",
                "немедленно": "",
                "прямо сейчас": "",
                "только сегодня": "",
                "последний шанс": "",
            }
            for old, new in replacements.items():
                fixed_body = re.sub(old, new, fixed_body, flags=re.IGNORECASE)

        elif violation["type"] == "excessive_exclamation":
            fixed_body = re.sub(r"!{2,}", "!", fixed_body)
            fixed_subject = re.sub(r"!{2,}", "!", fixed_subject)

        elif violation["type"] == "unverified_metrics":
            for number in violation.get("numbers", []):
                replacement = f"в кейсах клиентов {number}"
                fixed_body = fixed_body.replace(number, replacement)

    return fixed_subject.strip(), fixed_body.strip()

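A worked example (not part of the commit) of how the checks and fixes above behave on a draft containing a hard guarantee; only the re module is needed. Note that word-for-word substitution can leave grammatically rough Russian, which is an inherent limitation of this dictionary-based approach:

# Worked example: _check_false_claims flags "гарантируем" and, when
# used_chunks is empty, the unsupported "30%"; _apply_fixes then rewrites.
import re

body = "Мы гарантируем результат и сэкономим вам 30% бюджета!!"

print("гарантируем" in body.lower())   # True  -> false_guarantee violation
print(re.findall(r"\b\d+%", body))     # ['30%'] -> unverified_metrics

# Soften the guarantee and collapse "!!" to "!", as _apply_fixes does:
fixed = re.sub("гарантируем", "помогаем", body, flags=re.IGNORECASE)
fixed = re.sub(r"!{2,}", "!", fixed)
print(fixed)   # "Мы помогаем результат и сэкономим вам 30% бюджета!"
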
src/graph/nodes/input_validation.py
@@ -0,0 +1,43 @@
from typing import Dict, Any
from src.graph.state import EmailGenerationState
from src.models.lead import LeadInput, LeadModel
from src.models.errors import ValidationError


def input_validation_node(state: EmailGenerationState) -> EmailGenerationState:
    try:
        raw_input = state.get("raw_input")
        if not raw_input:
            state["error"] = "Raw input is required"
            state["error_code"] = "MISSING_INPUT"
            return state

        lead_input = LeadInput(**raw_input)

        if not lead_input.contact.strip():
            raise ValidationError("Contact cannot be empty")

        if not lead_input.position.strip():
            raise ValidationError("Position cannot be empty")

        if not lead_input.company_name.strip():
            raise ValidationError("Company name cannot be empty")

        if not lead_input.segment.strip():
            raise ValidationError("Segment cannot be empty")

        lead_model = LeadModel.from_input(lead_input)

        state["lead_input"] = lead_input
        state["lead_model"] = lead_model

        return state

    except ValidationError as e:
        state["error"] = str(e)
        state["error_code"] = "VALIDATION_ERROR"
        return state
    except Exception as e:
        state["error"] = f"Input validation failed: {str(e)}"
        state["error_code"] = "VALIDATION_FAILED"
        return state

src/graph/nodes/llm_generate.py
@@ -0,0 +1,48 @@
from src.graph.state import EmailGenerationState
from src.llm_clients.openai_client import OpenAIClient
from src.llm_clients.gemini_client import GeminiClient
from src.models.errors import LLMError
from src.app.config import settings


def llm_generate_node(state: EmailGenerationState) -> EmailGenerationState:
    try:
        prompt_payload = state.get("prompt_payload")
        if not prompt_payload:
            state["error"] = "Prompt payload is required for LLM generation"
            state["error_code"] = "MISSING_PROMPT"
            return state

        if settings.llm_provider == "openai":
            llm_client = OpenAIClient()
        elif settings.llm_provider == "gemini":
            llm_client = GeminiClient()
        else:
            state["error"] = f"Unsupported LLM provider: {settings.llm_provider}"
            state["error_code"] = "UNSUPPORTED_PROVIDER"
            return state

        llm_output = llm_client.generate_completion(
            system_prompt=prompt_payload.system_prompt,
            user_prompt=prompt_payload.user_prompt,
            max_tokens=prompt_payload.max_tokens,
            temperature=prompt_payload.temperature,
        )

        state["llm_output"] = llm_output
        return state

    except LLMError as e:
        state["error"] = e.message
        state["error_code"] = "LLM_ERROR"
        state["trace_meta"] = {
            "provider": e.provider,
            "model": e.model,
            "details": e.details,
        }
        return state

    except Exception as e:
        state["error"] = f"LLM generation error: {str(e)}"
        state["error_code"] = "GENERATION_ERROR"
        return state

src/graph/nodes/parse_output.py
@@ -0,0 +1,141 @@
import json
import re
from src.graph.state import EmailGenerationState
from src.models.email import EmailDraft
from src.models.errors import ParseError


def parse_output_node(state: EmailGenerationState) -> EmailGenerationState:
    try:
        llm_output = state.get("llm_output")
        if not llm_output:
            state["error"] = "LLM output is required for parsing"
            state["error_code"] = "MISSING_LLM_OUTPUT"
            return state

        content = llm_output.content.strip()

        content = _clean_json_content(content)

        try:
            parsed_data = json.loads(content)
        except json.JSONDecodeError as e:
            parsed_data = _fallback_parse(content)

        if not isinstance(parsed_data, dict):
            raise ParseError("Response is not a JSON object", content)

        subject = parsed_data.get("subject", "").strip()
        body = parsed_data.get("body", "").strip()
        short_reasoning = parsed_data.get("short_reasoning", "")
        used_chunks = parsed_data.get("used_chunks", [])

        if not subject:
            raise ParseError("Subject is required", content)

        if not body:
            raise ParseError("Body is required", content)

        subject = _validate_subject(subject)
        body = _validate_body(body)

        email_draft = EmailDraft(
            subject=subject,
            body=body,
            short_reasoning=short_reasoning,
            used_chunks=used_chunks if isinstance(used_chunks, list) else [],
        )

        state["email_draft"] = email_draft
        return state

    except ParseError as e:
        state["error"] = e.message
        state["error_code"] = "PARSE_ERROR"
        state["trace_meta"] = {"raw_output": e.raw_output[:500], "details": e.details}
        return state

    except Exception as e:
        state["error"] = f"Output parsing error: {str(e)}"
        state["error_code"] = "PARSING_ERROR"
        return state


def _clean_json_content(content: str) -> str:
    content = re.sub(r"^```json\s*", "", content)
    content = re.sub(r"\s*```$", "", content)
    content = re.sub(r"^```\s*", "", content)
    content = content.strip()
    return content


def _fallback_parse(content: str) -> dict:
    lines = content.split("\n")
    result = {}

    current_key = None
    current_value = []

    for line in lines:
        line = line.strip()
        if ":" in line and line.startswith('"') and line.count('"') >= 4:
            if current_key:
                result[current_key] = "\n".join(current_value)

            parts = line.split(":", 1)
            current_key = parts[0].strip('"').strip()
            current_value = [parts[1].strip().strip(",").strip('"')]
        elif current_key:
            current_value.append(line.strip(",").strip('"'))

    if current_key:
        result[current_key] = "\n".join(current_value)

    return result


def _validate_subject(subject: str) -> str:
    if len(subject) > 80:
        words = subject.split()
        truncated = []
        char_count = 0

        for word in words:
            if char_count + len(word) + 1 <= 77:
                truncated.append(word)
                char_count += len(word) + 1
            else:
                break

        subject = " ".join(truncated) + "..."

    spam_patterns = [
        r"(!{2,})",
        r"(СКИДКА|АКЦИЯ|СРОЧНО|БЕСПЛАТНО)",
        r"(\$|\€|\₽)",
    ]

    for pattern in spam_patterns:
        subject = re.sub(pattern, "", subject, flags=re.IGNORECASE)

    return subject.strip()


def _validate_body(body: str) -> str:
    if len(body) > 2000:
        body = body[:1950] + "..."

    required_elements = {"greeting": False, "company_mention": False, "cta": False}

    greetings = ["добрый день", "здравствуйте", "приветствую"]
    if any(greeting in body.lower() for greeting in greetings):
        required_elements["greeting"] = True

    if "консоль" in body.lower():
        required_elements["company_mention"] = True

    cta_phrases = ["звонок", "демо", "встреча", "обсудить", "покажу"]
    if any(phrase in body.lower() for phrase in cta_phrases):
        required_elements["cta"] = True

    return body.strip()

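Two observations: _validate_body computes required_elements but does not yet act on it, and the fence-stripping in _clean_json_content is what lets json.loads succeed on typical LLM output. A worked example of the latter, runnable standalone (not part of the commit):

# Worked example: strip the markdown fences LLMs often wrap around JSON,
# as _clean_json_content does, then parse.
import json
import re

raw = '```json\n{"subject": "Тема", "body": "Текст", "used_chunks": []}\n```'

content = re.sub(r"^```json\s*", "", raw)
content = re.sub(r"\s*```$", "", content)

data = json.loads(content)
print(data["subject"])   # Тема
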
src/graph/nodes/prompt_build.py
@@ -0,0 +1,40 @@
from src.graph.state import EmailGenerationState
from src.services.prompt_templates import PromptBuilder
from src.models.email import PromptPayload
from src.app.config import settings


def prompt_build_node(state: EmailGenerationState) -> EmailGenerationState:
    try:
        lead_model = state.get("lead_model")
        ranked_context = state.get("ranked_context")

        if not lead_model:
            state["error"] = "Lead model is required for prompt building"
            state["error_code"] = "MISSING_LEAD"
            return state

        if not ranked_context:
            state["error"] = "Ranked context is required for prompt building"
            state["error_code"] = "MISSING_CONTEXT"
            return state

        prompt_builder = PromptBuilder()
        system_prompt, user_prompt = prompt_builder.build_prompt(
            lead_model, ranked_context
        )

        prompt_payload = PromptPayload(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            max_tokens=settings.max_tokens_completion,
            temperature=0.7,
        )

        state["prompt_payload"] = prompt_payload
        return state

    except Exception as e:
        state["error"] = f"Prompt building error: {str(e)}"
        state["error_code"] = "PROMPT_BUILD_ERROR"
        return state

src/graph/nodes/return_result.py
@@ -0,0 +1,23 @@
from src.graph.state import EmailGenerationState
from src.models.email import EmailResponse


def return_result_node(state: EmailGenerationState) -> EmailGenerationState:
    try:
        email_clean = state.get("email_clean")
        if not email_clean:
            state["error"] = "Clean email is required for result return"
            state["error_code"] = "MISSING_CLEAN_EMAIL"
            return state

        email_response = EmailResponse(
            subject=email_clean.subject, body=email_clean.body, meta=email_clean.meta
        )

        state["email_response"] = email_response
        return state

    except Exception as e:
        state["error"] = f"Result return error: {str(e)}"
        state["error_code"] = "RETURN_ERROR"
        return state

src/graph/nodes/vector_search.py
@@ -0,0 +1,39 @@
from src.graph.state import EmailGenerationState
from src.services.chroma_store import ChromaStore
from src.services.embeddings import EmbeddingService
from src.services.retrieval import RetrievalService
from src.models.errors import RetrievalError


def vector_search_node(state: EmailGenerationState) -> EmailGenerationState:
    try:
        retrieval_query = state.get("retrieval_query")
        if not retrieval_query:
            state["error"] = "Retrieval query is required for vector search"
            state["error_code"] = "MISSING_QUERY"
            return state

        chroma_store = ChromaStore()
        embedding_service = EmbeddingService()
        retrieval_service = RetrievalService(chroma_store, embedding_service)

        chunks = retrieval_service.search_relevant_chunks(retrieval_query)

        if not chunks:
            state["error"] = "No relevant documents found"
            state["error_code"] = "NO_RESULTS"
            return state

        state["retrieved_chunks"] = chunks
        return state

    except RetrievalError as e:
        state["error"] = e.message
        state["error_code"] = "RETRIEVAL_ERROR"
        state["trace_meta"] = e.details
        return state

    except Exception as e:
        state["error"] = f"Vector search error: {str(e)}"
        state["error_code"] = "SEARCH_ERROR"
        return state

src/graph/state.py
@@ -0,0 +1,30 @@
from typing import Optional, List, Dict, Any, TypedDict
from src.models.lead import LeadInput, LeadModel, LeadFeatures
from src.models.email import (
    DocChunk,
    RetrievalQuery,
    RankedContext,
    PromptPayload,
    LLMRawOutput,
    EmailDraft,
    EmailDraftClean,
    EmailResponse,
)


class EmailGenerationState(TypedDict):
    raw_input: Optional[Dict[str, Any]]
    lead_input: Optional[LeadInput]
    lead_model: Optional[LeadModel]
    lead_features: Optional[LeadFeatures]
    retrieval_query: Optional[RetrievalQuery]
    retrieved_chunks: Optional[List[DocChunk]]
    ranked_context: Optional[RankedContext]
    prompt_payload: Optional[PromptPayload]
    llm_output: Optional[LLMRawOutput]
    email_draft: Optional[EmailDraft]
    email_clean: Optional[EmailDraftClean]
    email_response: Optional[EmailResponse]
    error: Optional[str]
    error_code: Optional[str]
    trace_meta: Optional[Dict[str, Any]]

src/ingest/chunker.py
@@ -0,0 +1,329 @@
import re
import tiktoken
from typing import List, Dict, Any
from dataclasses import dataclass

from src.ingest.loader import Document
from src.models.email import DocChunk


@dataclass
class ChunkConfig:
    chunk_size: int = 500
    chunk_overlap: int = 100
    min_chunk_size: int = 50
    preserve_context: bool = True


class DocumentChunker:
    def __init__(self, config: ChunkConfig = None):
        self.config = config or ChunkConfig()
        self.encoding = tiktoken.get_encoding("cl100k_base")

    def chunk_document(self, document: Document) -> List[DocChunk]:
        chunks = []

        sections = self._split_by_semantic_blocks(document.content)

        if self.config.preserve_context:
            sections = self._add_context_bridges(sections)

        chunk_counter = 0
        for section_title, section_content in sections:
            section_chunks = self._chunk_section(
                section_content, document, section_title, chunk_counter
            )
            chunks.extend(section_chunks)
            chunk_counter += len(section_chunks)

        return chunks

    def chunk_documents(self, documents: List[Document]) -> List[DocChunk]:
        all_chunks = []

        for document in documents:
            doc_chunks = self.chunk_document(document)
            all_chunks.extend(doc_chunks)

        return all_chunks

    def _add_context_bridges(
        self, sections: List[tuple[str, str]]
    ) -> List[tuple[str, str]]:
        if len(sections) <= 1:
            return sections

        enhanced_sections = []

        for i, (title, content) in enumerate(sections):
            enhanced_content = content

            if i > 0:
                prev_title, prev_content = sections[i - 1]
                if self._should_add_context(prev_title, title):
                    context_snippet = self._extract_key_context(prev_content)
                    if context_snippet:
                        enhanced_content = f"Контекст: {context_snippet}\n\n{content}"

            enhanced_sections.append((title, enhanced_content))

        return enhanced_sections

    def _should_add_context(self, prev_title: str, current_title: str) -> bool:
        context_pairs = [
            ("Проблема", "Решение"),
            ("Решение", "Результат"),
            ("О клиенте", "Проблема"),
        ]

        for prev_type, curr_type in context_pairs:
            if (
                prev_type.lower() in prev_title.lower()
                and curr_type.lower() in current_title.lower()
            ):
                return True
        return False

    def _extract_key_context(self, content: str) -> str:
        sentences = content.split(".")[:2]
        context = ". ".join(sentences).strip()
        if len(context) > 150:
            context = context[:150] + "..."
        return context

    def _split_by_semantic_blocks(self, content: str) -> List[tuple[str, str]]:
        sections = []
        lines = content.split("\n")

        current_section = ""
        current_title = "Введение"

        semantic_markers = [
            ("Проблема", ["проблема", "вызов", "трудност", "сложност"]),
            ("Решение", ["решение", "как", "что сделали", "подход"]),
            ("Результат", ["результат", "итог", "достижени", "эффект"]),
            ("О клиенте", ["о клиенте", "о компании", "клиент"]),
        ]

        for line in lines:
            line_stripped = line.strip()

            header_match = re.match(r"^(#{1,3})\s+(.+)$", line_stripped)

            if header_match:
                if current_section.strip():
                    sections.append((current_title, current_section.strip()))

                new_title = header_match.group(2)

                semantic_title = self._classify_section(new_title.lower())
                current_title = semantic_title if semantic_title else new_title
                current_section = ""
            else:
                current_section += line + "\n"

        if current_section.strip():
            sections.append((current_title, current_section.strip()))

        if not sections:
            sections = [("Основной контент", content)]

        merged_sections = self._merge_small_sections(sections)

        return merged_sections

    def _classify_section(self, title_lower: str) -> str:
        classifications = {
            "Проблема": [
                "проблема",
                "вызов",
                "трудност",
                "сложност",
                "боль",
                "до внедрения",
                "было",
                "раньше",
            ],
            "Решение": [
                "решение",
                "как",
                "что сделали",
                "подход",
                "внедрение",
                "переход",
                "процесс",
                "теперь",
            ],
            "Результат": [
                "результат",
                "итог",
                "достижени",
                "эффект",
                "выгода",
                "метрики",
                "экономия",
                "улучшен",
            ],
            "О клиенте": ["о клиенте", "о компании", "клиент", "заказчик", "компания"],
            "Процесс": [
                "как работает",
                "алгоритм",
                "этапы",
                "процесс",
                "схема",
                "шаги",
            ],
        }

        for semantic_name, keywords in classifications.items():
            if any(keyword in title_lower for keyword in keywords):
                return semantic_name

        return None

    def _merge_small_sections(
        self, sections: List[tuple[str, str]]
    ) -> List[tuple[str, str]]:
        merged = []
        i = 0

        while i < len(sections):
            title, content = sections[i]
            token_count = len(self.encoding.encode(content))

            if token_count < self.config.min_chunk_size and i + 1 < len(sections):
                next_title, next_content = sections[i + 1]
                merged_content = f"{content}\n\n{next_content}"
                merged_title = f"{title} / {next_title}"
                merged.append((merged_title, merged_content))
                i += 2
            else:
                merged.append((title, content))
                i += 1

        return merged

    def _chunk_section(
        self,
        section_content: str,
        document: Document,
        section_title: str,
        start_counter: int,
    ) -> List[DocChunk]:
        chunks = []

        tokens = self.encoding.encode(section_content)
        total_tokens = len(tokens)

        if total_tokens <= self.config.chunk_size:
            chunk = self._create_chunk(
                section_content, document, section_title, start_counter
            )
            return [chunk]

        sentences = self._split_into_sentences(section_content)

        current_chunk = ""
        current_tokens = 0
        chunk_idx = 0

        for sentence in sentences:
            sentence_tokens = len(self.encoding.encode(sentence))

            if current_tokens + sentence_tokens > self.config.chunk_size:
                if current_chunk.strip():
                    chunk = self._create_chunk(
                        current_chunk.strip(),
                        document,
                        section_title,
                        start_counter + chunk_idx,
                    )
                    chunks.append(chunk)
                    chunk_idx += 1

                overlap_text = self._get_overlap_text(
                    current_chunk, self.config.chunk_overlap
                )
                current_chunk = overlap_text + sentence
                current_tokens = len(self.encoding.encode(current_chunk))
            else:
                current_chunk += sentence
                current_tokens += sentence_tokens

        if current_chunk.strip():
            chunk = self._create_chunk(
                current_chunk.strip(),
                document,
                section_title,
                start_counter + chunk_idx,
            )
            chunks.append(chunk)

        return chunks

    def _split_into_sentences(self, text: str) -> List[str]:
        sentence_endings = r"[.!?]+\s+"
        sentences = re.split(sentence_endings, text)

        result = []
        for i, sentence in enumerate(sentences):
            if sentence.strip():
                if i < len(sentences) - 1:
                    sentence += ". "
                result.append(sentence)

        return result

    def _get_overlap_text(self, text: str, overlap_tokens: int) -> str:
        tokens = self.encoding.encode(text)
        if len(tokens) <= overlap_tokens:
            return text

        overlap_tokens_slice = tokens[-overlap_tokens:]
        overlap_text = self.encoding.decode(overlap_tokens_slice)

        sentences = overlap_text.split(".")
        if len(sentences) > 1:
            return ". ".join(sentences[1:]) + ". "

        return overlap_text + " "

    def _create_chunk(
        self, content: str, document: Document, section_title: str, chunk_index: int
    ) -> DocChunk:
        chunk_id = f"{document.doc_id}#c{chunk_index}"

        token_count = len(self.encoding.encode(content))

        metadata = document.metadata.copy()
        metadata.update(
            {
                "section_title": section_title,
                "parent_doc_title": document.metadata.get("title", "Unknown"),
                "chunk_index": chunk_index,
                "doc_source": document.source,
                "semantic_type": self._get_semantic_type(section_title),
            }
        )

        return DocChunk(
            chunk_id=chunk_id,
            parent_doc_id=document.doc_id,
            content_text=content,
            token_count=token_count,
            metadata=metadata,
        )

    def _get_semantic_type(self, section_title: str) -> str:
        title_lower = section_title.lower()

        if any(word in title_lower for word in ["проблема", "вызов", "трудност"]):
            return "problem"
        elif any(word in title_lower for word in ["решение", "подход", "внедрение"]):
            return "solution"
        elif any(word in title_lower for word in ["результат", "итог", "эффект"]):
            return "result"
        elif any(word in title_lower for word in ["клиент", "компани"]):
|
||||||
|
return "client_info"
|
||||||
|
else:
|
||||||
|
return "general"
|
|
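
# Minimal smoke run for the chunker (illustrative only: the Document below is
# hand-built and all values are hypothetical):
if __name__ == "__main__":
    from src.ingest.loader import Document

    demo_doc = Document(
        content="# Проблема\nДолгий онбординг исполнителей.\n\n# Результат\nОнбординг занимает 15 минут.",
        metadata={"title": "Демо-кейс"},
        doc_id="demo_case",
        source="demo_case.md",
    )
    chunker = DocumentChunker(ChunkConfig(chunk_size=500, chunk_overlap=100))
    for chunk in chunker.chunk_documents([demo_doc]):
        print(chunk.chunk_id, chunk.metadata.get("section_title"))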
@ -0,0 +1,241 @@
import os
import sys
import argparse
import hashlib
from typing import Dict, Any

sys.path.append(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)

from src.ingest.loader import MarkdownLoader, create_platform_overview
from src.ingest.chunker import DocumentChunker, ChunkConfig
from src.services.chroma_store import ChromaStore
from src.services.embeddings import EmbeddingService
from src.app.config import settings


def get_file_hash(file_path: str) -> str:
    # Helper for content hashing (not yet wired into the incremental logic).
    with open(file_path, "rb") as f:
        return hashlib.md5(f.read()).hexdigest()


def run_ingest(
    data_dir: str,
    platform_md: str = None,
    persist_dir: str = None,
    recreate: bool = False,
    incremental: bool = True,
    chunk_size: int = 500,
    chunk_overlap: int = 100,
    embedding_model: str = None,
) -> Dict[str, Any]:
    print("Starting ingestion process...")
    print(f"Data directory: {data_dir}")
    print(f"Persist directory: {persist_dir or settings.chroma_persist_dir}")
    print(f"Recreate: {recreate}")
    print(f"Incremental: {incremental}")

    if not os.path.exists(data_dir):
        raise FileNotFoundError(f"Data directory not found: {data_dir}")

    chroma_store = ChromaStore(persist_dir)

    if recreate:
        print("Recreating collection...")
        chroma_store.recreate_collection()
        incremental = False

    loader = MarkdownLoader()

    documents_to_process = []
    existing_doc_ids = set()

    if incremental and not recreate:
        existing_doc_ids = chroma_store.get_existing_doc_ids()
        print(f"Found {len(existing_doc_ids)} existing documents in collection")

    print("Loading markdown files...")
    all_documents = loader.load_directory(data_dir)

    new_docs = 0
    updated_docs = 0
    skipped_docs = 0

    for doc in all_documents:
        doc_id = doc.doc_id

        if incremental and doc_id in existing_doc_ids:
            if chroma_store.document_exists(doc_id):
                print(f"Skipping existing document: {doc_id}")
                skipped_docs += 1
                continue
            # The doc id is known but its chunks are missing: re-ingest it.
            print(f"Updating document: {doc_id}")
            chroma_store.delete_document_chunks(doc_id)
            updated_docs += 1
        else:
            new_docs += 1

        documents_to_process.append(doc)

    if platform_md and os.path.exists(platform_md):
        platform_doc = loader.load_file(platform_md)
        if not incremental or "platform_overview" not in existing_doc_ids:
            documents_to_process.append(platform_doc)
            print("Added platform overview from file")
    else:
        platform_doc = create_platform_overview()
        if not incremental or "platform_overview" not in existing_doc_ids:
            documents_to_process.append(platform_doc)
            print("Added generated platform overview")

    print(
        f"Processing {len(documents_to_process)} documents ({new_docs} new, {updated_docs} updated, {skipped_docs} skipped)"
    )

    if not documents_to_process:
        print("No documents to process. Collection is up to date.")
        return {
            "documents_loaded": 0,
            "chunks_created": 0,
            "embeddings_generated": 0,
            "total_in_collection": chroma_store.get_count(),
            "new_documents": new_docs,
            "updated_documents": updated_docs,
            "skipped_documents": skipped_docs,
            "embedding_model": embedding_model or settings.embedding_model,
            "chunk_config": {"size": chunk_size, "overlap": chunk_overlap},
        }

    chunk_config = ChunkConfig(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    chunker = DocumentChunker(chunk_config)

    print("Chunking documents...")
    chunks = chunker.chunk_documents(documents_to_process)
    print(f"Created {len(chunks)} chunks")

    embeddings = []
    if chunks:
        embedding_service = EmbeddingService()

        print("Generating embeddings...")
        texts = [chunk.content_text for chunk in chunks]
        embeddings = embedding_service.embed_batch(texts)
        print(f"Generated embeddings for {len(embeddings)} chunks")

        print("Storing in ChromaDB...")
        chroma_store.add_chunks(chunks, embeddings)

    final_count = chroma_store.get_count()
    print(f"Ingestion complete. Total documents in collection: {final_count}")

    return {
        "documents_loaded": len(documents_to_process),
        "chunks_created": len(chunks),
        "embeddings_generated": len(embeddings),
        "total_in_collection": final_count,
        "new_documents": new_docs,
        "updated_documents": updated_docs,
        "skipped_documents": skipped_docs,
        "embedding_model": embedding_model or settings.embedding_model,
        "chunk_config": {"size": chunk_size, "overlap": chunk_overlap},
    }


def create_data_directory():
    os.makedirs("data/articles_konsol_pro", exist_ok=True)

    sample_content = """# Пример кейса

Компания FIVE (маркетинговое агентство) автоматизировала работу с самозанятыми через Консоль.Про.

## Результаты
- Сокращение времени онбординга с 2 дней до 15 минут
- Автоматизация выплат - теперь занимает минуты вместо часов
- Снижение ошибок на 95%
- Управление 200+ исполнителями одним сотрудником

## Внедрение
Технический директор Никита Помящий отмечает, что внедрение заняло всего 1 день.
"""

    with open("data/articles_konsol_pro/sample_case.md", "w", encoding="utf-8") as f:
        f.write(sample_content)

    print("Created sample data directory and file")


def main():
    parser = argparse.ArgumentParser(
        description="Ingest knowledge base data into ChromaDB"
    )

    parser.add_argument(
        "--data-dir",
        default="data/articles_konsol_pro",
        help="Directory containing markdown files",
    )
    parser.add_argument(
        "--platform-md",
        default="data/platform_overview.md",
        help="Platform overview markdown file",
    )
    parser.add_argument(
        "--persist-dir", default=None, help="ChromaDB persist directory"
    )
    parser.add_argument(
        "--recreate",
        action="store_true",
        help="Recreate the collection (delete existing)",
    )
    parser.add_argument(
        "--no-incremental",
        action="store_true",
        help="Disable incremental updates (process all files)",
    )
    parser.add_argument(
        "--chunk-size", type=int, default=500, help="Chunk size in tokens"
    )
    parser.add_argument(
        "--chunk-overlap", type=int, default=100, help="Chunk overlap in tokens"
    )
    parser.add_argument(
        "--embedding-model", default=None, help="Embedding model to use"
    )
    parser.add_argument(
        "--create-sample",
        action="store_true",
        help="Create sample data directory and files",
    )

    args = parser.parse_args()

    if args.create_sample:
        create_data_directory()
        return

    try:
        result = run_ingest(
            data_dir=args.data_dir,
            platform_md=args.platform_md,
            persist_dir=args.persist_dir,
            recreate=args.recreate,
            incremental=not args.no_incremental,
            chunk_size=args.chunk_size,
            chunk_overlap=args.chunk_overlap,
            embedding_model=args.embedding_model,
        )

        print("\nIngestion Summary:")
        for key, value in result.items():
            print(f"  {key}: {value}")

    except Exception as e:
        print(f"Error during ingestion: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()
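
# Example invocations (the scripts/ path is an assumption about where this
# file lives in the repo; the flags are the ones defined above):
#
#   python scripts/ingest.py --create-sample
#   python scripts/ingest.py --data-dir data/articles_konsol_pro --recreate
#   python scripts/ingest.py --chunk-size 400 --chunk-overlap 80 --no-incremental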
@ -0,0 +1,337 @@
import os
import re
from typing import List, Dict, Any
from dataclasses import dataclass
import markdown


@dataclass
class Document:
    content: str
    metadata: Dict[str, Any]
    doc_id: str
    source: str


class MarkdownLoader:
    def __init__(self):
        self.md = markdown.Markdown(extensions=["meta", "toc"])

    def load_file(self, file_path: str) -> Document:
        with open(file_path, "r", encoding="utf-8") as f:
            content = f.read()

        # convert() also populates self.md.Meta as a side effect of the
        # "meta" extension; the HTML itself is not used further.
        html_content = self.md.convert(content)

        title = self._extract_title(content)
        doc_id = self._generate_doc_id(file_path)

        metadata = self._extract_metadata(content, file_path)
        metadata.update(
            {
                "title": title,
                "doc_type": "case" if "case" in file_path.lower() else "info",
                "source": file_path,
                "file_size": len(content),
            }
        )

        clean_content = self._clean_content(content)

        return Document(
            content=clean_content, metadata=metadata, doc_id=doc_id, source=file_path
        )

    def load_directory(self, dir_path: str) -> List[Document]:
        documents = []

        for root, dirs, files in os.walk(dir_path):
            for file in files:
                if file.endswith(".md"):
                    file_path = os.path.join(root, file)
                    try:
                        doc = self.load_file(file_path)
                        documents.append(doc)
                    except Exception as e:
                        print(f"Error loading {file_path}: {e}")

        return documents

    def _extract_title(self, content: str) -> str:
        lines = content.strip().split("\n")
        for line in lines:
            line = line.strip()
            if line.startswith("# "):
                title = line[2:].strip()
                if title and not title.startswith("["):
                    return title

        return "Untitled"

    def _generate_doc_id(self, file_path: str) -> str:
        filename = os.path.basename(file_path)
        name_without_ext = os.path.splitext(filename)[0]
        return name_without_ext.replace(" ", "_").replace("-", "_").lower()

    def _extract_metadata(self, content: str, file_path: str) -> Dict[str, Any]:
        metadata = {}

        filename = os.path.basename(file_path).lower()
        content_lower = content.lower()

        industry_mapping = {
            "маркетинг": "marketing_agency",
            "агентство": "marketing_agency",
            "реклам": "marketing_agency",
            "блогер": "marketing_agency",
            "mediar": "marketing_agency",
            "büro": "marketing_agency",
            "логист": "logistics",
            "достав": "logistics",
            "склад": "logistics",
            "грузчик": "logistics",
            "разраб": "software",
            "програм": "software",
            "progkids": "software",
            "it": "software",
            "диджитал": "software",
            "строит": "construction",
            "недвиж": "construction",
            "этажи": "construction",
            "рознич": "retail",
            "торгов": "retail",
            "консалт": "consulting",
            "экобренд": "manufacturing",
            "wonder": "manufacturing",
            "производ": "manufacturing",
            "колл-центр": "call_center",
            "звонки": "call_center",
        }

        industries = []
        for keyword, industry in industry_mapping.items():
            if keyword in content_lower or keyword in filename:
                if industry not in industries:
                    industries.append(industry)

        if not industries:
            industries = ["other"]

        metadata["industry"] = industries

        roles_mapping = {
            "технический директор": "tech",
            "техн": "tech",
            "cto": "tech",
            "операционный директор": "ops",
            "директор": "ceo",
            "руководи": "ceo",
            "основатель": "ceo",
            "фин": "finance",
            "бухгалт": "finance",
            "cfo": "finance",
            "операц": "ops",
            "coo": "ops",
            "hr": "hr",
            "кадр": "hr",
            "маркет": "marketing",
            "продаж": "sales",
            "менеджер": "other",
        }

        roles = []
        for keyword, role in roles_mapping.items():
            if keyword in content_lower:
                if role not in roles:
                    roles.append(role)

        if not roles:
            roles = ["other"]

        metadata["roles_relevant"] = roles

        metrics = self._extract_metrics(content)
        if metrics:
            metadata["metrics"] = metrics

        metadata["language"] = "ru"

        import datetime

        metadata["created_at"] = datetime.datetime.now().isoformat()
        metadata["updated_at"] = datetime.datetime.now().isoformat()

        return metadata

    def _extract_metrics(self, content: str) -> Dict[str, Any]:
        metrics = {}

        time_patterns = [
            (r"(\d+)\s*минут[ауы]?", "processing_minutes"),
            (r"(\d+)\s*час[ауов]?", "processing_hours"),
            (r"(\d+)\s*дн[ейяах]", "processing_days"),
            (
                r"с\s+(\d+)\s*дн[ейя]\s+до\s+(\d+)\s*минут",
                "improvement_days_to_minutes",
            ),
            (
                r"с\s+(\d+)\s*час[ауов]?\s+до\s+(\d+)\s*минут",
                "improvement_hours_to_minutes",
            ),
            (r"(\d+)\s*секунд", "processing_seconds"),
        ]

        for pattern, key in time_patterns:
            matches = re.findall(pattern, content, re.IGNORECASE)
            if matches:
                try:
                    if key.startswith("improvement_"):
                        # "before/after" patterns capture two numbers.
                        if len(matches[0]) == 2:
                            metrics[f"{key}_before"] = int(matches[0][0])
                            metrics[f"{key}_after"] = int(matches[0][1])
                        else:
                            metrics[key] = int(matches[0])
                    else:
                        metrics[key] = int(matches[0])
                except (ValueError, IndexError):
                    pass

        percentage_patterns = [
            (r"(\d+)%\s*снижени", "error_reduction_pct"),
            (r"снижение[^0-9]*(\d+)%", "error_reduction_pct"),
            (r"(\d+)%\s*документ", "document_collection_pct"),
            (r"(\d+)%\s*точност", "accuracy_pct"),
            (r"увеличи[лв]\w*\s+в\s+(\d+)\s*раз", "growth_multiplier"),
        ]

        for pattern, key in percentage_patterns:
            matches = re.findall(pattern, content, re.IGNORECASE)
            if matches:
                try:
                    metrics[key] = int(matches[0])
                except ValueError:
                    pass

        volume_patterns = [
            (r"(\d+)\s*блогер", "bloggers_count"),
            (r"(\d+)\s*исполнител", "contractors_count"),
            (r"(\d+)\s*сотрудник", "employees_count"),
            (r"бол[ье]+\s+(\d+)", "more_than_count"),
            (r"свыше\s+(\d+)", "over_count"),
        ]

        for pattern, key in volume_patterns:
            matches = re.findall(pattern, content, re.IGNORECASE)
            if matches:
                try:
                    metrics[key] = int(matches[0])
                except ValueError:
                    pass

        return metrics

    def _clean_content(self, content: str) -> str:
        content = re.sub(r"^\s*#+\s*", "", content, flags=re.MULTILINE)

        content = re.sub(r"\*\*(.*?)\*\*", r"\1", content)
        content = re.sub(r"\*(.*?)\*", r"\1", content)

        content = re.sub(r"\[([^\]]+)\]\([^\)]*\)", r"\1", content)

        content = re.sub(r"^\s*[-*+]\s+", "• ", content, flags=re.MULTILINE)

        content = re.sub(
            r"\d+\s+\d+\s+\[Комментировать\]\(\)\s+\d{2}\.\d{2}\.\d{2}", "", content
        )

        # Site chrome and scraping residue that should never reach the index.
        noise_patterns = [
            r"Автор и редактор журнала Консоль",
            r"Автор\s+\[.*?\]\(\)",
            r"Поделиться",
            r"Ваше мнение\?",
            r"Отлично\s+Хорошо\s+Нормально\s+Плохо\s+Ужасно",
            r"Сайт использует файлы cookie.*?Принять",
            r"\[Политика конфиденциальности\]\(\)",
            r"\[Пользовательское соглашение\]\(\)",
            r"hello@konsol\.pro",
            r"\+7 \(\d{3}\) \d{3}-\d{2}-\d{2}",
            r"125047.*?дом \d+",
            r"\[Разработка — SKDO\]\(\)",
            r"\[Подключиться к Консоли\]\(\)",
            r"\[Кейсы наших клиентов\]\(\)",
            r"\[Делимся экспертизой\]\(\)",
            r"^\s*\d+\s*$",
            r"^\s*\[\d+\]\(\)\s*\d{2}\.\d{2}\.\d{2}\s*$",
        ]

        for pattern in noise_patterns:
            content = re.sub(pattern, "", content, flags=re.MULTILINE | re.IGNORECASE)

        related_articles_pattern = r"###\s+\[.*?\]\(\).*?(?=###|\Z)"
        content = re.sub(related_articles_pattern, "", content, flags=re.DOTALL)

        content = re.sub(r"\n{3,}", "\n\n", content)

        lines = content.split("\n")
        filtered_lines = []
        for line in lines:
            line = line.strip()
            if line and not (line.startswith("[") and line.endswith("]()")):
                if not re.match(r"^\d+\s*$", line):
                    if len(line) > 10 or line.startswith("•"):
                        filtered_lines.append(line)

        content = "\n".join(filtered_lines)

        return content.strip()


def create_platform_overview() -> Document:
    content = """
Консоль.Про – платформа автоматизации работы с самозанятыми, ИП и физлицами.

Основные возможности:
• Подключение нового исполнителя за ~15 минут
• Выплаты в течение минут вместо часов
• Сбор 100% закрывающих документов
• Снижение ошибок до 95%
• Управление сотнями исполнителей одним сотрудником
• API интеграции для автоматизации процессов
• Автоматический сбор чеков и документов
• Снижение времени онбординга с 2 дней до ~20 минут

Платформа решает ключевые задачи бизнеса:
• Быстрое масштабирование команды исполнителей
• Автоматизация документооборота и выплат
• Снижение операционных затрат
• Обеспечение налогового соответствия
• Упрощение работы с подрядчиками

Внедрение платформы занимает около 1 дня.
"""

    metadata = {
        "title": "Платформа Консоль.Про - Обзор",
        "doc_type": "platform_overview",
        "source": "internal",
        "industry": ["generic"],
        "roles_relevant": ["tech", "finance", "ops", "ceo"],
        "metrics": {
            "onboarding_minutes": 15,
            "onboarding_days_before": 2,
            "onboarding_minutes_after": 20,
            "error_reduction_pct": 95,
            "document_collection_pct": 100,
            "implementation_days": 1,
        },
        "language": "ru",
        "created_at": "2024-01-01T00:00:00",
        "updated_at": "2024-01-01T00:00:00",
    }

    return Document(
        content=content.strip(),
        metadata=metadata,
        doc_id="platform_overview",
        source="platform_overview.md",
    )
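
# Illustrative smoke run (hypothetical directory; load_directory only
# returns something when markdown files exist on disk):
if __name__ == "__main__":
    loader = MarkdownLoader()
    for doc in loader.load_directory("data/articles_konsol_pro"):
        print(doc.doc_id, doc.metadata["industry"], doc.metadata.get("metrics", {}))

    overview = create_platform_overview()
    print(overview.doc_id, len(overview.content), "chars")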
@ -0,0 +1,20 @@
from abc import ABC, abstractmethod

from src.models.email import LLMRawOutput


class LLMClient(ABC):
    @abstractmethod
    def generate_completion(
        self,
        system_prompt: str,
        user_prompt: str,
        max_tokens: int = 1024,
        temperature: float = 0.7,
        **kwargs
    ) -> LLMRawOutput:
        pass

    @abstractmethod
    def count_tokens(self, text: str) -> int:
        pass
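
# A minimal provider stub showing the contract subclasses must satisfy
# (EchoClient is purely illustrative and not part of the service):
#
#     class EchoClient(LLMClient):
#         def generate_completion(self, system_prompt, user_prompt,
#                                 max_tokens=1024, temperature=0.7, **kwargs):
#             return LLMRawOutput(
#                 content='{"subject": "stub", "body": "stub", '
#                         '"short_reasoning": "", "used_chunks": []}',
#                 tokens_prompt=0, tokens_completion=0, model="echo",
#             )
#
#         def count_tokens(self, text):
#             return len(text.split())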
@ -0,0 +1,56 @@
import google.generativeai as genai
import tiktoken

from src.llm_clients.base import LLMClient
from src.models.email import LLMRawOutput
from src.models.errors import LLMError
from src.app.config import settings


class GeminiClient(LLMClient):
    def __init__(self):
        if not settings.gemini_api_key:
            raise ValueError("Gemini API key is required")

        genai.configure(api_key=settings.gemini_api_key)
        self.model = genai.GenerativeModel("gemini-pro")
        # tiktoken does not cover Gemini models; cl100k_base serves as a
        # rough token-count approximation.
        self.encoding = tiktoken.get_encoding("cl100k_base")

    def generate_completion(
        self,
        system_prompt: str,
        user_prompt: str,
        max_tokens: int = 1024,
        temperature: float = 0.7,
        **kwargs,
    ) -> LLMRawOutput:
        try:
            # generate_content takes a single prompt string here, so the
            # system instruction is inlined ahead of the user request.
            prompt = f"Системная инструкция: {system_prompt}\n\nЗапрос пользователя: {user_prompt}"

            generation_config = genai.types.GenerationConfig(
                max_output_tokens=max_tokens,
                temperature=temperature,
            )

            response = self.model.generate_content(
                prompt, generation_config=generation_config
            )

            content = response.text

            prompt_tokens = self.count_tokens(prompt)
            completion_tokens = self.count_tokens(content)

            return LLMRawOutput(
                content=content,
                tokens_prompt=prompt_tokens,
                tokens_completion=completion_tokens,
                model="gemini-pro",
            )

        except Exception as e:
            raise LLMError(f"Gemini generation error: {str(e)}", "gemini", "gemini-pro")

    def count_tokens(self, text: str) -> int:
        return len(self.encoding.encode(text))
@ -0,0 +1,55 @@
import openai
import tiktoken

from src.llm_clients.base import LLMClient
from src.models.email import LLMRawOutput
from src.models.errors import LLMError
from src.app.config import settings


class OpenAIClient(LLMClient):
    def __init__(self):
        if not settings.openai_api_key:
            raise ValueError("OpenAI API key is required")

        self.client = openai.OpenAI(api_key=settings.openai_api_key)
        self.model = settings.llm_model
        self.encoding = tiktoken.encoding_for_model(self.model)

    def generate_completion(
        self,
        system_prompt: str,
        user_prompt: str,
        max_tokens: int = 1024,
        temperature: float = 0.7,
        **kwargs,
    ) -> LLMRawOutput:
        try:
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ]

            # response_format enforces JSON mode, matching the strict-JSON
            # contract stated in the system prompt.
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                response_format={"type": "json_object"},
            )

            content = response.choices[0].message.content

            return LLMRawOutput(
                content=content,
                tokens_prompt=response.usage.prompt_tokens,
                tokens_completion=response.usage.completion_tokens,
                model=self.model,
            )

        except Exception as e:
            raise LLMError(f"OpenAI generation error: {str(e)}", "openai", self.model)

    def count_tokens(self, text: str) -> int:
        return len(self.encoding.encode(text))
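
# Usage sketch (requires OPENAI_API_KEY in the environment; output values
# depend on the model and are illustrative only):
#
#     client = OpenAIClient()
#     out = client.generate_completion(
#         system_prompt='Отвечай строго JSON вида {"subject": str, "body": str}.',
#         user_prompt="Поприветствуй CTO компании Acme.",
#         max_tokens=256,
#     )
#     print(out.model, out.tokens_prompt, out.tokens_completion)
#     print(out.content)  # JSON string to be parsed downstream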
@ -0,0 +1,63 @@
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field


class DocChunk(BaseModel):
    chunk_id: str
    parent_doc_id: str
    content_text: str
    token_count: int
    metadata: Dict[str, Any]
    similarity_score: Optional[float] = None


class RetrievalQuery(BaseModel):
    text_query: str
    vector_query: Optional[List[float]] = None
    metadata_filters: Optional[Dict[str, Any]] = None
    k: int = 30


class RankedContext(BaseModel):
    chunks: List[DocChunk]
    total_tokens: int
    summary_bullets: List[str]


class PromptPayload(BaseModel):
    system_prompt: str
    user_prompt: str
    max_tokens: int
    temperature: float = 0.7


class LLMRawOutput(BaseModel):
    content: str
    tokens_prompt: int
    tokens_completion: int
    model: str


class EmailDraft(BaseModel):
    subject: str = Field(max_length=80)
    body: str = Field(max_length=2000)
    short_reasoning: Optional[str] = None
    used_chunks: List[str] = Field(default_factory=list)


class EmailDraftClean(BaseModel):
    subject: str
    body: str
    meta: Dict[str, Any]


class EmailResponse(BaseModel):
    subject: str
    body: str
    meta: Dict[str, Any] = Field(default_factory=dict)


class ErrorResponse(BaseModel):
    error: str
    code: str
    details: Optional[Dict[str, Any]] = None
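
# Parsing sketch for the JSON contract the LLM clients return (the payload
# below is hand-written, not real model output):
if __name__ == "__main__":
    raw = (
        '{"subject": "Автоматизация выплат", "body": "Добрый день!", '
        '"short_reasoning": "demo", "used_chunks": ["platform_overview#c0"]}'
    )
    draft = EmailDraft.model_validate_json(raw)
    print(draft.subject, "|", len(draft.body), "chars |", draft.used_chunks)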
@ -0,0 +1,53 @@
from typing import Optional, Dict, Any


class ValidationError(Exception):
    def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
        self.message = message
        self.details = details or {}
        super().__init__(self.message)


class RetrievalError(Exception):
    def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
        self.message = message
        self.details = details or {}
        super().__init__(self.message)


class LLMError(Exception):
    def __init__(
        self,
        message: str,
        provider: str,
        model: str,
        details: Optional[Dict[str, Any]] = None,
    ):
        self.message = message
        self.provider = provider
        self.model = model
        self.details = details or {}
        super().__init__(self.message)


class ParseError(Exception):
    def __init__(
        self, message: str, raw_output: str, details: Optional[Dict[str, Any]] = None
    ):
        self.message = message
        self.raw_output = raw_output
        self.details = details or {}
        super().__init__(self.message)


class GuardrailsError(Exception):
    def __init__(
        self,
        message: str,
        violation_type: str,
        details: Optional[Dict[str, Any]] = None,
    ):
        self.message = message
        self.violation_type = violation_type
        self.details = details or {}
        super().__init__(self.message)
@ -0,0 +1,105 @@
from typing import Literal, Optional
from pydantic import BaseModel, ConfigDict, EmailStr, Field


class LeadInput(BaseModel):
    contact: str = Field(min_length=1, max_length=100)
    position: str = Field(min_length=1, max_length=100)
    company_name: str = Field(min_length=1, max_length=100)
    segment: str = Field(min_length=1, max_length=100)
    email: Optional[EmailStr] = None
    locale: Literal["ru", "en"] = "ru"
    notes: Optional[str] = Field(default=None, max_length=500)

    model_config = ConfigDict(populate_by_name=True)


class LeadModel(BaseModel):
    contact_name: str
    contact_first_name: Optional[str] = None
    contact_last_name: Optional[str] = None
    role_title: str
    role_category: Optional[
        Literal["tech", "finance", "ops", "hr", "ceo", "sales", "marketing", "other"]
    ] = None
    company_name: str
    industry_segment: Optional[str] = None
    industry_tag: Optional[str] = None
    email: Optional[EmailStr] = None
    locale: Literal["ru", "en"] = "ru"
    notes: Optional[str] = None

    @classmethod
    def from_input(cls, lead_input: LeadInput) -> "LeadModel":
        contact_parts = lead_input.contact.strip().split()
        first_name = contact_parts[0] if contact_parts else ""
        last_name = " ".join(contact_parts[1:]) if len(contact_parts) > 1 else None

        role_category = cls._extract_role_category(lead_input.position)
        industry_tag = cls._extract_industry_tag(lead_input.segment)

        return cls(
            contact_name=lead_input.contact,
            contact_first_name=first_name,
            contact_last_name=last_name,
            role_title=lead_input.position,
            role_category=role_category,
            company_name=lead_input.company_name,
            industry_segment=lead_input.segment,
            industry_tag=industry_tag,
            email=lead_input.email,
            locale=lead_input.locale,
            notes=lead_input.notes,
        )

    @staticmethod
    def _extract_role_category(role_title: str) -> Optional[str]:
        role_lower = role_title.lower()

        if any(
            word in role_lower for word in ["техн", "cto", "it", "разраб", "программ"]
        ):
            return "tech"
        elif any(word in role_lower for word in ["фин", "бухг", "казнач", "cfo"]):
            return "finance"
        elif any(word in role_lower for word in ["операц", "coo", "управл"]):
            return "ops"
        elif any(word in role_lower for word in ["hr", "кадр", "персонал"]):
            return "hr"
        elif any(
            word in role_lower for word in ["гендир", "ceo", "директор", "руковод"]
        ):
            return "ceo"
        elif any(word in role_lower for word in ["продаж", "коммерч", "sales"]):
            return "sales"
        elif any(word in role_lower for word in ["маркет", "реклам", "pr"]):
            return "marketing"
        else:
            return "other"

    @staticmethod
    def _extract_industry_tag(segment: str) -> Optional[str]:
        segment_lower = segment.lower()

        if any(word in segment_lower for word in ["маркет", "реклам", "агентс"]):
            return "marketing_agency"
        elif any(word in segment_lower for word in ["логист", "достав", "склад"]):
            return "logistics"
        elif any(word in segment_lower for word in ["разраб", "софт", "it", "програм"]):
            return "software"
        elif any(word in segment_lower for word in ["рознич", "торгов", "магазин"]):
            return "retail"
        elif any(word in segment_lower for word in ["консалт", "консульт"]):
            return "consulting"
        elif any(word in segment_lower for word in ["строит", "недвиж"]):
            return "construction"
        else:
            return "other"


class LeadFeatures(BaseModel):
    role_category: str
    industry_tag: str
    pain_points: list[str]
    key_benefits: list[str]
    search_keywords: list[str]
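
# Derivation sketch (hypothetical lead data):
if __name__ == "__main__":
    lead = LeadModel.from_input(
        LeadInput(
            contact="Анна Петрова",
            position="Операционный директор",
            company_name="Acme Logistics",
            segment="логистика и доставка",
        )
    )
    print(lead.contact_first_name, lead.role_category, lead.industry_tag)
    # -> Анна ops logistics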
@ -0,0 +1,169 @@
import os
import json
from typing import List, Dict, Any, Optional, Set
import chromadb
from chromadb.config import Settings as ChromaSettings
from chromadb.api.models.Collection import Collection

from src.models.email import DocChunk
from src.app.config import settings


class ChromaStore:
    def __init__(self, persist_directory: str = None):
        self.persist_directory = persist_directory or settings.chroma_persist_dir
        os.makedirs(self.persist_directory, exist_ok=True)

        self.client = chromadb.PersistentClient(
            path=self.persist_directory,
            settings=ChromaSettings(anonymized_telemetry=False),
        )
        self.collection_name = "knowledge_base"
        self.collection: Optional[Collection] = None

    def get_or_create_collection(self) -> Collection:
        if self.collection is None:
            self.collection = self.client.get_or_create_collection(
                name=self.collection_name, metadata={"hnsw:space": "cosine"}
            )
        return self.collection

    def _sanitize_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
        # Chroma stores only scalar metadata values: lists are flattened to
        # comma-separated strings and dicts are serialized to JSON.
        sanitized = {}

        for key, value in metadata.items():
            if isinstance(value, list):
                sanitized[key] = ", ".join(str(item) for item in value)
            elif isinstance(value, dict):
                sanitized[key] = json.dumps(value, ensure_ascii=False)
            elif isinstance(value, (str, int, float, bool)) or value is None:
                sanitized[key] = value
            else:
                sanitized[key] = str(value)

        return sanitized

    def get_existing_doc_ids(self) -> Set[str]:
        collection = self.get_or_create_collection()

        try:
            result = collection.get(include=["metadatas"])
            if result and result.get("metadatas"):
                doc_ids = set()
                for metadata in result["metadatas"]:
                    if metadata and "source" in metadata:
                        # Re-derive the doc id from the source filename,
                        # mirroring MarkdownLoader._generate_doc_id.
                        source = metadata["source"]
                        filename = os.path.basename(source)
                        doc_id = (
                            os.path.splitext(filename)[0]
                            .replace(" ", "_")
                            .replace("-", "_")
                            .lower()
                        )
                        doc_ids.add(doc_id)
                return doc_ids
        except Exception:
            pass

        return set()

    def document_exists(self, doc_id: str) -> bool:
        collection = self.get_or_create_collection()

        try:
            result = collection.get(where={"parent_doc_id": doc_id}, limit=1)
            return bool(result and result.get("ids"))
        except Exception:
            return False

    def delete_document_chunks(self, doc_id: str):
        collection = self.get_or_create_collection()

        try:
            result = collection.get(
                where={"parent_doc_id": doc_id}, include=["metadatas"]
            )

            if result and result.get("ids"):
                chunk_ids = result["ids"]
                collection.delete(ids=chunk_ids)
                print(f"Deleted {len(chunk_ids)} chunks for document {doc_id}")
        except Exception as e:
            print(f"Error deleting chunks for {doc_id}: {e}")

    def add_chunks(self, chunks: List[DocChunk], embeddings: List[List[float]]):
        collection = self.get_or_create_collection()

        ids = [chunk.chunk_id for chunk in chunks]
        documents = [chunk.content_text for chunk in chunks]
        metadatas = []

        for chunk in chunks:
            metadata = chunk.metadata.copy()
            metadata.update(
                {"parent_doc_id": chunk.parent_doc_id, "token_count": chunk.token_count}
            )
            sanitized_metadata = self._sanitize_metadata(metadata)
            metadatas.append(sanitized_metadata)

        collection.add(
            ids=ids, embeddings=embeddings, documents=documents, metadatas=metadatas
        )

    def search_similar(
        self,
        query_embedding: List[float],
        k: int = 30,
        where: Optional[Dict[str, Any]] = None,
    ) -> List[DocChunk]:
        collection = self.get_or_create_collection()

        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=k,
            where=where,
            include=["documents", "metadatas", "distances"],
        )

        chunks = []
        if results["ids"] and results["ids"][0]:
            for i, chunk_id in enumerate(results["ids"][0]):
                # Cosine distance is converted to a similarity score.
                similarity_score = (
                    1.0 - results["distances"][0][i] if results["distances"] else None
                )

                metadata = results["metadatas"][0][i] if results["metadatas"] else {}
                parent_doc_id = metadata.pop("parent_doc_id", "")
                token_count = metadata.pop("token_count", 0)

                chunk = DocChunk(
                    chunk_id=chunk_id,
                    parent_doc_id=parent_doc_id,
                    content_text=(
                        results["documents"][0][i] if results["documents"] else ""
                    ),
                    token_count=token_count,
                    metadata=metadata,
                    similarity_score=similarity_score,
                )
                chunks.append(chunk)

        return chunks

    def get_count(self) -> int:
        collection = self.get_or_create_collection()
        return collection.count()

    def delete_collection(self):
        try:
            self.client.delete_collection(self.collection_name)
            self.collection = None
        except Exception:
            pass

    def recreate_collection(self):
        self.delete_collection()
        self.collection = self.client.create_collection(
            name=self.collection_name, metadata={"hnsw:space": "cosine"}
        )
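
# Round-trip sketch with a dummy 3-dimensional embedding (illustrative only;
# in the service the vectors come from EmbeddingService and are much larger,
# and the persist path below is hypothetical):
if __name__ == "__main__":
    store = ChromaStore("./storage/chroma_demo")
    chunk = DocChunk(
        chunk_id="demo#c0",
        parent_doc_id="demo",
        content_text="Выплаты занимают минуты вместо часов.",
        token_count=8,
        metadata={"industry": ["logistics"], "metrics": {"error_reduction_pct": 95}},
    )
    store.add_chunks([chunk], [[0.1, 0.2, 0.3]])
    hits = store.search_similar([0.1, 0.2, 0.3], k=1)
    print(hits[0].chunk_id, hits[0].similarity_score)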
@ -0,0 +1,88 @@
from typing import List
from abc import ABC, abstractmethod
import openai
import google.generativeai as genai

from src.app.config import settings
from src.models.errors import LLMError


class EmbeddingProvider(ABC):
    @abstractmethod
    def embed_text(self, text: str) -> List[float]:
        pass

    @abstractmethod
    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        pass


class OpenAIEmbeddingProvider(EmbeddingProvider):
    def __init__(self):
        if not settings.openai_api_key:
            raise ValueError("OpenAI API key is required")
        self.client = openai.OpenAI(api_key=settings.openai_api_key)
        self.model = settings.embedding_model

    def embed_text(self, text: str) -> List[float]:
        try:
            response = self.client.embeddings.create(model=self.model, input=text)
            return response.data[0].embedding
        except Exception as e:
            raise LLMError(f"OpenAI embedding error: {str(e)}", "openai", self.model)

    def embed_batch(self, texts: List[str], batch_size: int = 64) -> List[List[float]]:
        # The embeddings endpoint accepts batched input; requests are sent in
        # slices of batch_size to keep payloads bounded.
        embeddings = []

        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            try:
                response = self.client.embeddings.create(model=self.model, input=batch)
                batch_embeddings = [data.embedding for data in response.data]
                embeddings.extend(batch_embeddings)
            except Exception as e:
                raise LLMError(
                    f"OpenAI batch embedding error: {str(e)}", "openai", self.model
                )

        return embeddings


class GeminiEmbeddingProvider(EmbeddingProvider):
    def __init__(self):
        if not settings.gemini_api_key:
            raise ValueError("Gemini API key is required")
        genai.configure(api_key=settings.gemini_api_key)
        self.model = "models/embedding-001"

    def embed_text(self, text: str) -> List[float]:
        try:
            result = genai.embed_content(
                model=self.model, content=text, task_type="retrieval_document"
            )
            return result["embedding"]
        except Exception as e:
            raise LLMError(f"Gemini embedding error: {str(e)}", "gemini", self.model)

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        # This provider embeds one text per call; batching is sequential.
        embeddings = []
        for text in texts:
            embeddings.append(self.embed_text(text))
        return embeddings


class EmbeddingService:
    def __init__(self):
        if settings.llm_provider == "openai":
            self.provider = OpenAIEmbeddingProvider()
        elif settings.llm_provider == "gemini":
            self.provider = GeminiEmbeddingProvider()
        else:
            raise ValueError(f"Unsupported embedding provider: {settings.llm_provider}")

    def embed_text(self, text: str) -> List[float]:
        return self.provider.embed_text(text)

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        return self.provider.embed_batch(texts)
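
# Usage sketch (needs a valid provider key in the environment; the vector
# length depends on the configured embedding model):
#
#     service = EmbeddingService()
#     vectors = service.embed_batch(["онбординг исполнителей", "выплаты"])
#     print(len(vectors), len(vectors[0]))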
@ -0,0 +1,113 @@
from src.models.lead import LeadModel
from src.models.email import RankedContext
from src.app.config import settings


SYSTEM_PROMPT = """Вы — AI-ассистент отдела продаж платформы «Консоль.Про». Ваша задача: на основе профиля лида и предоставленного контекста сформировать одно персонализированное холодное письмо (первое касание) на языке лида. Вы пишете кратко, уважительно, без давления. Вы показываете ценность через конкретные результаты клиентов и цифры. Вы не придумываете факты; используете только контекст. Если нет данных — говорите общими преимуществами платформы.
Формат ответа строго JSON: {"subject": str, "body": str, "short_reasoning": str, "used_chunks": [ids...]}. Без тройных кавычек, без Markdown."""


USER_PROMPT_TEMPLATE = """[ПРОФИЛЬ ЛИДА]
Имя: {contact_first_name}
Должность: {role_title} (категория: {role_category})
Компания: {company_name}
Индустрия: {industry_tag}

[КОНТЕКСТ ИЗ БАЗЫ ЗНАНИЙ]
{context_bullets}

[СТИЛЬ ПИСЬМА]
- Язык: {locale}
- Тон: деловой, лаконичный, дружелюбный.
- Письмо до ~1600 символов.
- 1 CTA: предложить 15-мин звонок/демо. Не навязчиво.
- Можно упомянуть, что внедрение платформы занимает ~1 день (если релевантно).

[ШАБЛОН СТРУКТУРЫ ТЕКСТА]
1. Приветствие по имени.
2. Короткий хук, отсылающий к индустрии/процессам с исполнителями.
3. Ценность Консоль.Про + 1–3 релевантные выгоды (цифры, если в контексте).
4. Мини-соцдоказательство (кейс, факт).
5. CTA.
6. Подпись: Имя менеджера ({sales_rep_name}) | Консоль.Про.

Напомню: ответ только JSON."""


class PromptBuilder:
    def __init__(self):
        self.system_prompt = SYSTEM_PROMPT
        self.user_template = USER_PROMPT_TEMPLATE

    def build_prompt(
        self, lead: LeadModel, context: RankedContext, sales_rep_name: str = None
    ) -> tuple[str, str]:
        sales_rep = sales_rep_name or settings.sales_rep_name

        context_bullets_text = (
            "\n".join(context.summary_bullets)
            if context.summary_bullets
            else "Общая информация о платформе Консоль.Про для автоматизации работы с самозанятыми."
        )

        user_prompt = self.user_template.format(
            contact_first_name=lead.contact_first_name or lead.contact_name,
            role_title=lead.role_title,
            role_category=lead.role_category or "специалист",
            company_name=lead.company_name,
            industry_tag=lead.industry_tag or lead.industry_segment,
            context_bullets=context_bullets_text,
            sales_rep_name=sales_rep,
            locale=lead.locale,
        )

        return self.system_prompt, user_prompt

    def get_greeting(self, lead: LeadModel) -> str:
        if lead.contact_first_name:
            return f"{lead.contact_first_name}, добрый день!"
        else:
            return "Коллеги, добрый день!"

    def get_industry_hook(self, industry_tag: str) -> str:
        hooks = {
            "marketing_agency": "В маркетинговых агентствах часто сложно быстро подключать десятки подрядчиков для проектов",
            "logistics": "В логистических компаниях управление множественными исполнителями требует особого внимания к документообороту",
            "software": "В IT-компаниях работа с фрилансерами и ИП — важная часть команды разработки",
            "retail": "В розничной торговле сезонные пики требуют быстрого масштабирования команды исполнителей",
            "consulting": "В консалтинговых компаниях привлечение экспертов под проекты — критичный процесс",
            "construction": "В строительстве координация подрядчиков и документооборот — ключевые вызовы",
            "other": "При работе с внешними исполнителями всегда актуальны вопросы автоматизации процессов",
        }
        return hooks.get(industry_tag, hooks["other"])

    def get_key_benefits(self, industry_tag: str) -> list[str]:
        benefits_map = {
            "marketing_agency": [
                "Подключение нового исполнителя за ~15 минут",
                "Мгновенные массовые выплаты",
                "Сбор 100% закрывающих документов",
            ],
            "logistics": [
                "Снижение времени онбординга с 2 дней до ~20 минут",
                "Автоматический сбор чеков и документов",
                "Управление сотнями исполнителей одним сотрудником",
            ],
            "software": [
                "API интеграции для автоматизации процессов",
                "Работа с ИП и самозанятыми в одной системе",
                "Снижение ошибок до 95%",
            ],
            "retail": [
                "Массовые выплаты в течение минут",
                "Быстрое подключение сезонного персонала",
                "Автоматизация документооборота",
            ],
            "other": [
                "Подключение исполнителя за ~15 минут",
                "Выплаты в течение минут vs часы",
                "Снижение ошибок до 95%",
            ],
        }
        return benefits_map.get(industry_tag, benefits_map["other"])
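
# Assembly sketch (hypothetical lead, empty context; the builder then falls
# back to its generic platform-description bullet):
if __name__ == "__main__":
    from src.models.lead import LeadInput

    lead = LeadModel.from_input(
        LeadInput(
            contact="Иван Орлов",
            position="CTO",
            company_name="Acme Soft",
            segment="разработка ПО",
        )
    )
    context = RankedContext(chunks=[], total_tokens=0, summary_bullets=[])
    system_prompt, user_prompt = PromptBuilder().build_prompt(lead, context)
    print(user_prompt)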
@ -0,0 +1,170 @@
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
import tiktoken
|
||||||
|
import hashlib
|
||||||
|
import time
|
||||||
|
|
||||||
|
from src.models.email import DocChunk, RetrievalQuery, RankedContext
|
||||||
|
from src.models.lead import LeadModel, LeadFeatures
|
||||||
|
from src.services.chroma_store import ChromaStore
|
||||||
|
from src.services.embeddings import EmbeddingService
|
||||||
|
from src.app.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
class RetrievalService:
|
||||||
|
def __init__(self, chroma_store: ChromaStore, embedding_service: EmbeddingService):
|
||||||
|
self.chroma_store = chroma_store
|
||||||
|
self.embedding_service = embedding_service
|
||||||
|
self.encoding = tiktoken.get_encoding("cl100k_base")
|
||||||
|
self._embedding_cache = {}
|
||||||
|
self._search_cache = {}
|
||||||
|
self._cache_ttl = 3600
|
||||||
|
|
||||||
|
def _get_query_hash(self, text_query: str) -> str:
|
||||||
|
return hashlib.md5(text_query.encode()).hexdigest()
|
||||||
|
|
||||||
|
def _get_cached_embedding(self, text_query: str) -> Optional[List[float]]:
|
||||||
|
query_hash = self._get_query_hash(text_query)
|
||||||
|
cached = self._embedding_cache.get(query_hash)
|
||||||
|
|
||||||
|
if cached and time.time() - cached["timestamp"] < self._cache_ttl:
|
||||||
|
return cached["embedding"]
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _cache_embedding(self, text_query: str, embedding: List[float]):
|
||||||
|
query_hash = self._get_query_hash(text_query)
|
||||||
|
self._embedding_cache[query_hash] = {
|
||||||
|
"embedding": embedding,
|
||||||
|
"timestamp": time.time(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(self._embedding_cache) > 1000:
|
||||||
|
oldest_key = min(
|
||||||
|
self._embedding_cache.keys(),
|
||||||
|
key=lambda k: self._embedding_cache[k]["timestamp"],
|
||||||
|
)
|
||||||
|
del self._embedding_cache[oldest_key]
|
||||||
|
|
||||||
|
def build_retrieval_query(self, lead_features: LeadFeatures) -> RetrievalQuery:
|
||||||
|
search_terms = [
|
||||||
|
"самозанятые",
|
||||||
|
"автоматизация",
|
||||||
|
"исполнители",
|
||||||
|
lead_features.industry_tag,
|
||||||
|
lead_features.role_category,
|
||||||
|
]
|
||||||
|
|
||||||
|
pain_points_map = {
|
||||||
|
"marketing_agency": ["онбординг", "выплаты", "подрядчики", "отчетность"],
|
||||||
|
"logistics": ["документооборот", "исполнители", "чеки", "география"],
|
||||||
|
"software": ["фриланс", "ИП", "API", "интеграции"],
|
||||||
|
"retail": ["сезонность", "временные", "массовые", "выплаты"],
|
||||||
|
"consulting": ["проекты", "экспертиза", "временные", "специалисты"],
|
||||||
|
"construction": ["подрядчики", "документы", "безопасность", "сроки"],
|
||||||
|
"other": ["онбординг", "выплаты", "документооборот"],
|
||||||
|
}
|
||||||
|
|
||||||
|
industry_terms = pain_points_map.get(
|
||||||
|
lead_features.industry_tag, pain_points_map["other"]
|
||||||
|
)
|
||||||
|
search_terms.extend(industry_terms)
|
||||||
|
|
||||||
|
text_query = " ".join(search_terms)
|
||||||
|
|
||||||
|
metadata_filters = {}
|
||||||
|
if lead_features.industry_tag != "other":
|
||||||
|
metadata_filters["$or"] = [
|
||||||
|
{"industry": {"$contains": lead_features.industry_tag}},
|
||||||
|
{"roles_relevant": {"$contains": lead_features.role_category}},
|
||||||
|
]
|
||||||
|
|
||||||
|
return RetrievalQuery(
|
||||||
|
text_query=text_query,
|
||||||
|
metadata_filters=metadata_filters if metadata_filters else None,
|
||||||
|
k=settings.top_k,
|
||||||
|
)
|
||||||
|
|
||||||
|
def search_relevant_chunks(self, query: RetrievalQuery) -> List[DocChunk]:
|
||||||
|
cached_embedding = self._get_cached_embedding(query.text_query)
|
||||||
|
|
||||||
|
if cached_embedding:
|
||||||
|
query_embedding = cached_embedding
|
||||||
|
else:
|
||||||
|
query_embedding = self.embedding_service.embed_text(query.text_query)
|
||||||
|
self._cache_embedding(query.text_query, query_embedding)
|
||||||
|
|
||||||
|
chunks = self.chroma_store.search_similar(
|
||||||
|
query_embedding=query_embedding, k=query.k, where=query.metadata_filters
|
||||||
|
)
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
    def rank_chunks(
        self, chunks: List[DocChunk], lead_features: LeadFeatures
    ) -> List[DocChunk]:
        # Blend the raw similarity score with lead-specific boosts.
        for chunk in chunks:
            score = chunk.similarity_score or 0.0

            industry_boost = 0.0
            if lead_features.industry_tag in chunk.metadata.get("industry", []):
                industry_boost = 0.2

            role_boost = 0.0
            if lead_features.role_category in chunk.metadata.get("roles_relevant", []):
                role_boost = 0.1

            metrics_boost = 0.0
            if chunk.metadata.get("metrics"):
                metrics_boost = 0.1

            final_score = 0.6 * score + industry_boost + role_boost + metrics_boost
            chunk.similarity_score = final_score

        chunks.sort(key=lambda x: x.similarity_score or 0.0, reverse=True)

        # Deduplicate near-identical chunks by hashing their first 100 chars.
        seen_content = set()
        unique_chunks = []
        for chunk in chunks:
            content_hash = hash(chunk.content_text[:100])
            if content_hash not in seen_content:
                seen_content.add(content_hash)
                unique_chunks.append(chunk)

        return unique_chunks[: settings.top_n_context]
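    # Worked example: a chunk with raw similarity 0.8 that matches the lead's
    # industry and role and carries metrics scores
    # 0.6 * 0.8 + 0.2 + 0.1 + 0.1 = 0.88; the maximum possible blended score
    # is 0.6 + 0.2 + 0.1 + 0.1 = 1.0.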
    def build_context(
        self, chunks: List[DocChunk], max_tokens: int = 3000
    ) -> RankedContext:
        # Greedily take ranked chunks until the token budget is exhausted.
        selected_chunks = []
        total_tokens = 0

        for chunk in chunks:
            chunk_tokens = len(self.encoding.encode(chunk.content_text))
            if total_tokens + chunk_tokens <= max_tokens:
                selected_chunks.append(chunk)
                total_tokens += chunk_tokens
            else:
                break

        # One bullet per selected chunk: a <=200-char excerpt plus any numeric
        # metrics found in its metadata.
        summary_bullets = []
        for chunk in selected_chunks:
            content = chunk.content_text.strip()
            if len(content) > 200:
                content = content[:200] + "..."

            metrics_info = ""
            if chunk.metadata.get("metrics"):
                metrics = chunk.metadata["metrics"]
                metrics_parts = []
                for key, value in metrics.items():
                    if isinstance(value, (int, float)):
                        metrics_parts.append(f"{key}: {value}")
                if metrics_parts:
                    metrics_info = f" ({', '.join(metrics_parts)})"

            summary_bullets.append(f"• {content}{metrics_info}")

        return RankedContext(
            chunks=selected_chunks,
            total_tokens=total_tokens,
            summary_bullets=summary_bullets,
        )
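    # End-to-end usage sketch (the names of the service instance and lead
    # object are illustrative, not from this module):
    #
    #   query = retriever.build_retrieval_query(lead.features)
    #   chunks = retriever.search_relevant_chunks(query)
    #   top = retriever.rank_chunks(chunks, lead.features)
    #   context = retriever.build_context(top, max_tokens=3000)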
@ -0,0 +1,126 @@
import pytest
import json
from fastapi.testclient import TestClient
import sys
import os

sys.path.append(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)

from src.app.main import app

client = TestClient(app)
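# Note: the sys.path tweak above makes the suite runnable from the repo root
# without installing the package; a conftest.py at the project root or an
# editable install (pip install -e .) would be the more conventional setup.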
def test_health_endpoint():
    response = client.get("/healthz")
    assert response.status_code == 200
    data = response.json()
    assert "status" in data
    assert data["status"] == "healthy"


def test_readiness_endpoint():
    response = client.get("/readiness")
    assert response.status_code == 200


def test_api_documentation():
    response = client.get("/docs")
    assert response.status_code == 200


def test_root_redirect():
    # TestClient follows redirects by default, so the page reached after the
    # root redirect should answer 200.
    response = client.get("/")
    assert response.status_code == 200
def test_generate_email_missing_fields():
    incomplete_payload = {"contact": ""}

    response = client.post("/api/v1/generate_email", json=incomplete_payload)
    assert response.status_code == 422


def test_generate_email_basic_validation():
    # Only "contact" is supplied; the other required lead fields are missing,
    # so validation should reject the payload.
    payload_missing_fields = {"contact": "Test Person"}

    response = client.post("/api/v1/generate_email", json=payload_missing_fields)
    assert response.status_code == 422
def test_generate_email_success():
    payload = {
        "contact": "Помящий Никита",
        "position": "Технический директор",
        "company_name": "FIVE",
        "segment": "маркетинговое агентство",
    }

    response = client.post("/api/v1/generate_email", json=payload)

    # Generation needs a live LLM backend; assert the response schema on
    # success and log the failure otherwise instead of hard-failing the suite.
    if response.status_code == 200:
        data = response.json()
        assert "subject" in data
        assert "body" in data
        assert "meta" in data
    else:
        print(f"Response: {response.status_code}, {response.text}")


def test_generate_email_with_email():
    payload = {
        "contact": "Помящий Никита",
        "position": "Технический директор",
        "company_name": "FIVE",
        "segment": "маркетинговое агентство",
        "email": "nikita@five.agency",
        "locale": "ru",
    }

    response = client.post("/api/v1/generate_email", json=payload)
    # 500 is tolerated because the endpoint may fail in environments without
    # configured LLM credentials.
    assert response.status_code in [200, 500]
def test_api_status():
    response = client.get("/api/v1/status")
    assert response.status_code == 200
    data = response.json()
    assert data["status"] == "operational"


def test_invalid_json():
    response = client.post(
        "/api/v1/generate_email",
        content="invalid json".encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    assert response.status_code == 422
def test_admin_endpoints_without_auth():
    # No Authorization header at all -> 401.
    response = client.get("/api/v1/admin/knowledge-base/stats")
    assert response.status_code == 401

    response = client.post("/api/v1/admin/ingest")
    assert response.status_code == 401


def test_admin_endpoints_with_auth():
    from src.app.config import settings

    headers = {"Authorization": f"Bearer {settings.api_secret_key}"}

    # 500 is tolerated: the endpoints are reachable with valid auth even when
    # the backing services are not configured in the test environment.
    response = client.get("/api/v1/admin/knowledge-base/stats", headers=headers)
    assert response.status_code in [200, 500]

    response = client.post("/api/v1/admin/ingest", headers=headers)
    assert response.status_code in [200, 500]


def test_admin_endpoints_with_wrong_auth():
    # A present but invalid bearer token -> 403.
    headers = {"Authorization": "Bearer wrong_token"}

    response = client.get("/api/v1/admin/knowledge-base/stats", headers=headers)
    assert response.status_code == 403
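# A fixture could factor out the admin header used above (sketch, not part of
# the original suite):
#
#   @pytest.fixture
#   def admin_headers():
#       from src.app.config import settings
#       return {"Authorization": f"Bearer {settings.api_secret_key}"}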
@ -0,0 +1,225 @@
import pytest
import tempfile
import os
import shutil
import sys
from unittest.mock import Mock, patch

sys.path.append(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)

from src.ingest.loader import MarkdownLoader, Document, create_platform_overview
from src.ingest.chunker import DocumentChunker, ChunkConfig
from src.models.lead import LeadModel
from src.services.chroma_store import ChromaStore
from src.services.embeddings import EmbeddingService
class TestMarkdownLoader:
    def test_create_platform_overview(self):
        doc = create_platform_overview()

        assert doc.doc_id == "platform_overview"
        assert doc.metadata["title"] == "Платформа Консоль.Про - Обзор"
        assert "generic" in doc.metadata["industry"]
        assert len(doc.content) > 100
    def test_load_file(self):
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".md", delete=False, encoding="utf-8"
        ) as f:
            f.write(
                """# Тестовый кейс

Компания тестирует автоматизацию с самозанятыми.

## Результаты
- Сокращение времени на 20 минут
- Снижение ошибок на 85%

Технический директор доволен результатами.
"""
            )
            temp_file = f.name

        try:
            loader = MarkdownLoader()
            doc = loader.load_file(temp_file)

            assert doc.content is not None
            assert len(doc.content) > 0
            assert "Тестовый кейс" in doc.metadata["title"]
            assert doc.metadata["doc_type"] == "info"
            assert "industry" in doc.metadata
            assert "roles_relevant" in doc.metadata
            assert "tech" in doc.metadata["roles_relevant"]

        finally:
            os.unlink(temp_file)
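    # The same setup could lean on pytest's tmp_path fixture instead of
    # NamedTemporaryFile(delete=False) plus manual os.unlink (sketch):
    #
    #   def test_load_file(self, tmp_path):
    #       md_file = tmp_path / "case.md"
    #       md_file.write_text("# Тестовый кейс\n...", encoding="utf-8")
    #       doc = MarkdownLoader().load_file(str(md_file))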
    def test_extract_metadata(self):
        content = """# Кейс маркетингового агентства

Технический директор внедрил автоматизацию.
Снижение ошибок на 95%.
Онбординг теперь занимает 15 минут."""

        loader = MarkdownLoader()
        metadata = loader._extract_metadata(content, "test.md")

        assert "marketing_agency" in metadata["industry"]
        assert "tech" in metadata["roles_relevant"]
        assert "error_reduction_pct" in metadata["metrics"]
        assert metadata["metrics"]["error_reduction_pct"] == 95
class TestDocumentChunker:
    def test_chunk_document(self):
        content = """# Заголовок документа

Вводный текст документа.

## Проблема
Описание проблемы клиента.
Это было сложно.

## Решение
Как мы решили проблему.
Внедрили автоматизацию.

## Результат
Получили отличные результаты.
Снижение на 95%."""

        document = Document(
            content=content,
            metadata={"title": "Test", "industry": ["tech"], "roles_relevant": ["ceo"]},
            doc_id="test_doc",
            source="test.md",
        )

        chunker = DocumentChunker()
        chunks = chunker.chunk_document(document)

        # Chunk ids follow the "<doc_id>#c<n>" convention.
        assert len(chunks) > 0
        assert all(chunk.chunk_id.startswith("test_doc#c") for chunk in chunks)
        assert all(chunk.parent_doc_id == "test_doc" for chunk in chunks)

    def test_split_by_semantic_blocks(self):
        content = """# Заголовок 1
Контент секции 1

## Заголовок 2
Контент секции 2

### Заголовок 3
Контент секции 3"""

        chunker = DocumentChunker()
        sections = chunker._split_by_semantic_blocks(content)

        assert len(sections) >= 2
        assert any("Заголовок" in section[0] for section in sections)

    def test_classify_section(self):
        chunker = DocumentChunker()

        assert chunker._classify_section("проблема клиента") == "Проблема"
        assert chunker._classify_section("как мы решили") == "Решение"
        assert chunker._classify_section("результаты внедрения") == "Результат"
        assert chunker._classify_section("о компании") == "О клиенте"

    def test_chunk_config(self):
        config = ChunkConfig(chunk_size=300, chunk_overlap=50)
        chunker = DocumentChunker(config)

        assert chunker.config.chunk_size == 300
        assert chunker.config.chunk_overlap == 50
class TestLeadModelExtraction:
    def test_role_category_extraction(self):
        # LeadModel is already imported at module level.
        assert LeadModel._extract_role_category("Технический директор") == "tech"
        assert LeadModel._extract_role_category("Финансовый директор") == "finance"
        assert LeadModel._extract_role_category("Генеральный директор") == "ceo"
        assert LeadModel._extract_role_category("HR менеджер") == "hr"

    def test_industry_tag_extraction(self):
        assert (
            LeadModel._extract_industry_tag("маркетинговое агентство")
            == "marketing_agency"
        )
        assert LeadModel._extract_industry_tag("IT компания") == "software"
        assert (
            LeadModel._extract_industry_tag("строительная компания") == "construction"
        )
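# A compact parametrized variant of the role checks above (equivalent
# coverage, shown as an optional alternative):
@pytest.mark.parametrize(
    "position,expected",
    [
        ("Технический директор", "tech"),
        ("Финансовый директор", "finance"),
        ("Генеральный директор", "ceo"),
        ("HR менеджер", "hr"),
    ],
)
def test_role_category_parametrized(position, expected):
    assert LeadModel._extract_role_category(position) == expected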
class TestIntegration:
    def test_full_ingest_flow(self):
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".md", delete=False, encoding="utf-8"
        ) as f:
            f.write(
                """# Тестовый кейс компании

Компания внедрила автоматизацию.

## Результаты
Сокращение времени с 2 дней до 30 минут.
Снижение ошибок на 90%.

Технический директор доволен."""
            )
            temp_file = f.name

        try:
            # Load, then chunk: the loader output feeds the chunker directly.
            loader = MarkdownLoader()
            doc = loader.load_file(temp_file)

            chunker = DocumentChunker()
            chunks = chunker.chunk_document(doc)

            assert len(chunks) > 0
            assert all(chunk.token_count > 0 for chunk in chunks)
            assert all(chunk.content_text.strip() for chunk in chunks)

        finally:
            os.unlink(temp_file)
class TestChromaStore:
    @patch("chromadb.PersistentClient")
    def test_init(self, mock_client):
        store = ChromaStore()
        assert store.collection_name == "knowledge_base"
        mock_client.assert_called_once()

    @patch("chromadb.PersistentClient")
    def test_get_existing_doc_ids(self, mock_client):
        mock_collection = Mock()
        mock_collection.get.return_value = {
            "metadatas": [{"source": "test1.md"}, {"source": "test2.md"}]
        }

        mock_client.return_value.get_or_create_collection.return_value = mock_collection

        store = ChromaStore()
        doc_ids = store.get_existing_doc_ids()

        assert isinstance(doc_ids, set)
class TestEmbeddingService:
    @patch("src.services.embeddings.OpenAIEmbeddingProvider")
    def test_embedding_service_init(self, mock_provider):
        # Force the OpenAI code path so the patched provider is constructed.
        with patch("src.app.config.settings") as mock_settings:
            mock_settings.llm_provider = "openai"

            service = EmbeddingService()
            mock_provider.assert_called_once()