# ai-email-assistant/src/services/retrieval.py
from typing import List, Optional
import tiktoken
import hashlib
import time
from src.models.email import DocChunk, RetrievalQuery, RankedContext
from src.models.lead import LeadFeatures
from src.services.chroma_store import ChromaStore
from src.services.embeddings import EmbeddingService
from src.app.config import settings
class RetrievalService:
    """Retrieve and rank knowledge-base chunks to build lead-specific email context.

    Combines embedding-based similarity search (via :class:`ChromaStore`) with
    lead-aware re-ranking and token-budgeted context assembly.  Query embeddings
    are memoized in a small in-process TTL cache to avoid re-embedding identical
    queries.
    """

    def __init__(self, chroma_store: ChromaStore, embedding_service: EmbeddingService):
        """Store collaborators and initialize the tokenizer and caches.

        Args:
            chroma_store: Vector store used for similarity search.
            embedding_service: Service that converts text queries to embeddings.
        """
        self.chroma_store = chroma_store
        self.embedding_service = embedding_service
        # cl100k_base tokenizer; used only to budget context size in build_context().
        self.encoding = tiktoken.get_encoding("cl100k_base")
        # query-hash -> {"embedding": List[float], "timestamp": float}
        self._embedding_cache: dict = {}
        # NOTE(review): _search_cache is never read or written anywhere in this
        # class — kept for backward compatibility; candidate for removal.
        self._search_cache: dict = {}
        self._cache_ttl = 3600  # seconds a cached embedding stays valid

    def _get_query_hash(self, text_query: str) -> str:
        """Return a stable cache key for a query string (MD5 hex digest).

        MD5 is acceptable here: the hash is a cache key, not a security control.
        """
        return hashlib.md5(text_query.encode()).hexdigest()

    def _get_cached_embedding(self, text_query: str) -> Optional[List[float]]:
        """Return the cached embedding for ``text_query`` if still fresh, else None."""
        entry = self._embedding_cache.get(self._get_query_hash(text_query))
        # Explicit None check: an empty-but-valid embedding must still count as a hit.
        if entry is not None and time.time() - entry["timestamp"] < self._cache_ttl:
            return entry["embedding"]
        return None

    def _cache_embedding(self, text_query: str, embedding: List[float]) -> None:
        """Store ``embedding`` under the query's hash, evicting stale/old entries."""
        now = time.time()
        self._embedding_cache[self._get_query_hash(text_query)] = {
            "embedding": embedding,
            "timestamp": now,
        }
        # Purge expired entries first so the size cap counts only live ones
        # (the original never removed expired entries — a slow memory leak).
        expired = [
            key
            for key, entry in self._embedding_cache.items()
            if now - entry["timestamp"] >= self._cache_ttl
        ]
        for key in expired:
            del self._embedding_cache[key]
        # Bound the cache: evict oldest entries until at most 1000 remain.
        while len(self._embedding_cache) > 1000:
            oldest_key = min(
                self._embedding_cache,
                key=lambda k: self._embedding_cache[k]["timestamp"],
            )
            del self._embedding_cache[oldest_key]

    def build_retrieval_query(self, lead_features: LeadFeatures) -> RetrievalQuery:
        """Compose a text search query from base terms plus industry pain points.

        Args:
            lead_features: Lead attributes (industry tag, role category).

        Returns:
            A :class:`RetrievalQuery` with the joined search terms, no metadata
            filters, and ``k`` taken from settings.
        """
        search_terms = [
            "самозанятые",
            "автоматизация",
            "исполнители",
            lead_features.industry_tag,
            lead_features.role_category,
        ]
        # Industry -> typical pain-point keywords used to sharpen retrieval.
        pain_points_map = {
            "marketing_agency": ["онбординг", "выплаты", "подрядчики", "отчетность"],
            "logistics": ["документооборот", "исполнители", "чеки", "география"],
            "software": ["фриланс", "ИП", "API", "интеграции"],
            "retail": ["сезонность", "временные", "массовые", "выплаты"],
            "consulting": ["проекты", "экспертиза", "временные", "специалисты"],
            "construction": ["подрядчики", "документы", "безопасность", "сроки"],
            "other": ["онбординг", "выплаты", "документооборот"],
        }
        search_terms.extend(
            pain_points_map.get(lead_features.industry_tag, pain_points_map["other"])
        )
        return RetrievalQuery(
            text_query=" ".join(search_terms),
            metadata_filters=None,  # no metadata filtering currently applied
            k=settings.top_k,
        )

    def search_relevant_chunks(self, query: RetrievalQuery) -> List[DocChunk]:
        """Embed the query (cache-aware) and fetch the k nearest chunks."""
        query_embedding = self._get_cached_embedding(query.text_query)
        if query_embedding is None:
            query_embedding = self.embedding_service.embed_text(query.text_query)
            self._cache_embedding(query.text_query, query_embedding)
        return self.chroma_store.search_similar(
            query_embedding=query_embedding, k=query.k, where=query.metadata_filters
        )

    def rank_chunks(
        self, chunks: List[DocChunk], lead_features: LeadFeatures
    ) -> List[DocChunk]:
        """Re-score chunks with lead-specific boosts, dedupe, and keep the top N.

        Final score = 0.6 * similarity + 0.2 (industry match) + 0.1 (role match)
        + 0.1 (chunk carries metrics).  Mutates each chunk's ``similarity_score``
        in place and sorts the input list.
        """
        for chunk in chunks:
            base = chunk.similarity_score or 0.0
            industry_boost = (
                0.2
                if lead_features.industry_tag in chunk.metadata.get("industry", [])
                else 0.0
            )
            role_boost = (
                0.1
                if lead_features.role_category
                in chunk.metadata.get("roles_relevant", [])
                else 0.0
            )
            metrics_boost = 0.1 if chunk.metadata.get("metrics") else 0.0
            chunk.similarity_score = (
                0.6 * base + industry_boost + role_boost + metrics_boost
            )
        chunks.sort(key=lambda c: c.similarity_score or 0.0, reverse=True)
        # Deduplicate near-identical chunks by their first 100 characters.
        # Use the prefix itself rather than hash(): str hashes are randomized
        # per process, and a hash() collision would silently drop a distinct chunk.
        seen_prefixes = set()
        unique_chunks = []
        for chunk in chunks:
            prefix = chunk.content_text[:100]
            if prefix not in seen_prefixes:
                seen_prefixes.add(prefix)
                unique_chunks.append(chunk)
        return unique_chunks[: settings.top_n_context]

    @staticmethod
    def _format_metrics(metrics) -> str:
        """Render a metrics dict/str as a ' (k: v, ...)' suffix, or '' if none."""
        if not metrics:
            return ""
        parts: List[str] = []
        if isinstance(metrics, dict):
            # Only numeric values are rendered; other value types are skipped.
            parts = [
                f"{key}: {value}"
                for key, value in metrics.items()
                if isinstance(value, (int, float))
            ]
        elif isinstance(metrics, str):
            parts = [metrics]
        return f" ({', '.join(parts)})" if parts else ""

    def build_context(
        self, chunks: List[DocChunk], max_tokens: int = 3000
    ) -> RankedContext:
        """Greedily pack chunks into a token budget and build summary bullets.

        Args:
            chunks: Ranked chunks, best first; packed in order until the budget
                would be exceeded.
            max_tokens: Token budget measured with the cl100k_base tokenizer.

        Returns:
            A :class:`RankedContext` with the selected chunks, their total token
            count, and one truncated summary bullet per chunk.
        """
        selected_chunks: List[DocChunk] = []
        total_tokens = 0
        for chunk in chunks:
            chunk_tokens = len(self.encoding.encode(chunk.content_text))
            if total_tokens + chunk_tokens > max_tokens:
                break  # order-preserving greedy cut-off
            selected_chunks.append(chunk)
            total_tokens += chunk_tokens
        summary_bullets = []
        for chunk in selected_chunks:
            content = chunk.content_text.strip()
            if len(content) > 200:
                content = content[:200] + "..."
            metrics_info = self._format_metrics(chunk.metadata.get("metrics"))
            summary_bullets.append(f"{content}{metrics_info}")
        return RankedContext(
            chunks=selected_chunks,
            total_tokens=total_tokens,
            summary_bullets=summary_bullets,
        )