"""Retrieval service: builds a search query from lead features, fetches
similar chunks from ChromaDB, re-ranks them, and assembles a token-bounded
context for downstream prompting."""

import hashlib
import time
from typing import List, Optional

import tiktoken

from src.app.config import settings
from src.models.email import DocChunk, RankedContext, RetrievalQuery
from src.models.lead import LeadFeatures
from src.services.chroma_store import ChromaStore
from src.services.embeddings import EmbeddingService


class RetrievalService:
    """Retrieves, ranks, and packs knowledge-base chunks relevant to a lead."""

    def __init__(self, chroma_store: ChromaStore, embedding_service: EmbeddingService):
        self.chroma_store = chroma_store
        self.embedding_service = embedding_service
        self.encoding = tiktoken.get_encoding("cl100k_base")
        # In-process caches keyed by query hash; entries expire after _cache_ttl seconds.
        self._embedding_cache = {}
        self._search_cache = {}
        self._cache_ttl = 3600

    def _get_query_hash(self, text_query: str) -> str:
        # MD5 serves only as a cache key here, not as a security measure.
        return hashlib.md5(text_query.encode()).hexdigest()

    def _get_cached_embedding(self, text_query: str) -> Optional[List[float]]:
        """Return the cached embedding for a query, or None if absent or expired."""
        query_hash = self._get_query_hash(text_query)
        cached = self._embedding_cache.get(query_hash)
        if cached and time.time() - cached["timestamp"] < self._cache_ttl:
            return cached["embedding"]
        return None

    def _cache_embedding(self, text_query: str, embedding: List[float]):
        query_hash = self._get_query_hash(text_query)
        self._embedding_cache[query_hash] = {
            "embedding": embedding,
            "timestamp": time.time(),
        }

        # Cap the cache at 1000 entries by evicting the oldest one.
        if len(self._embedding_cache) > 1000:
            oldest_key = min(
                self._embedding_cache,
                key=lambda k: self._embedding_cache[k]["timestamp"],
            )
            del self._embedding_cache[oldest_key]

    def build_retrieval_query(self, lead_features: LeadFeatures) -> RetrievalQuery:
        """Compose a text query from base terms plus industry pain-point keywords."""
        # Base terms (Russian): "self-employed", "automation", "contractors",
        # followed by the lead's industry and role tags.
        search_terms = [
            "самозанятые",
            "автоматизация",
            "исполнители",
            lead_features.industry_tag,
            lead_features.role_category,
        ]

        # Industry-specific pain-point keywords (Russian): onboarding, payouts,
        # subcontractors, reporting, paperwork, receipts, seasonality,
        # temporary/mass hiring, safety, deadlines, and similar.
        pain_points_map = {
            "marketing_agency": ["онбординг", "выплаты", "подрядчики", "отчетность"],
            "logistics": ["документооборот", "исполнители", "чеки", "география"],
            "software": ["фриланс", "ИП", "API", "интеграции"],
            "retail": ["сезонность", "временные", "массовые", "выплаты"],
            "consulting": ["проекты", "экспертиза", "временные", "специалисты"],
            "construction": ["подрядчики", "документы", "безопасность", "сроки"],
            "other": ["онбординг", "выплаты", "документооборот"],
        }

        industry_terms = pain_points_map.get(
            lead_features.industry_tag, pain_points_map["other"]
        )
        search_terms.extend(industry_terms)

        text_query = " ".join(search_terms)

        return RetrievalQuery(
            text_query=text_query,
            # No metadata filters yet; relevance is handled by rank_chunks boosts.
            metadata_filters=None,
            k=settings.top_k,
        )
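
    # Worked example: for a lead with industry_tag="logistics" and a
    # hypothetical role_category="founder", the composed query is
    #   "самозанятые автоматизация исполнители logistics founder
    #    документооборот исполнители чеки география"
    # (note "исполнители" appears twice, once as a base term and once from
    # the logistics pain-point list; terms are not deduplicated).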

    def search_relevant_chunks(self, query: RetrievalQuery) -> List[DocChunk]:
        # Reuse a cached query embedding when present; otherwise embed and cache.
        # The explicit None check avoids treating an empty embedding as a miss.
        query_embedding = self._get_cached_embedding(query.text_query)
        if query_embedding is None:
            query_embedding = self.embedding_service.embed_text(query.text_query)
            self._cache_embedding(query.text_query, query_embedding)

        return self.chroma_store.search_similar(
            query_embedding=query_embedding, k=query.k, where=query.metadata_filters
        )

    def rank_chunks(
        self, chunks: List[DocChunk], lead_features: LeadFeatures
    ) -> List[DocChunk]:
        """Re-score chunks with lead-specific boosts, then deduplicate."""
        for chunk in chunks:
            score = chunk.similarity_score or 0.0

            # Additive boosts: +0.2 for an industry match, +0.1 for a role
            # match, +0.1 if the chunk carries concrete metrics.
            industry_boost = 0.0
            if lead_features.industry_tag in chunk.metadata.get("industry", []):
                industry_boost = 0.2

            role_boost = 0.0
            if lead_features.role_category in chunk.metadata.get("roles_relevant", []):
                role_boost = 0.1

            metrics_boost = 0.0
            if chunk.metadata.get("metrics"):
                metrics_boost = 0.1

            # Similarity is down-weighted to 0.6 so the boosts can reorder hits.
            final_score = 0.6 * score + industry_boost + role_boost + metrics_boost
            chunk.similarity_score = final_score

        chunks.sort(key=lambda x: x.similarity_score or 0.0, reverse=True)

        # Drop near-duplicates: chunks sharing their first 100 characters.
        seen_content = set()
        unique_chunks = []
        for chunk in chunks:
            content_hash = hash(chunk.content_text[:100])
            if content_hash not in seen_content:
                seen_content.add(content_hash)
                unique_chunks.append(chunk)

        return unique_chunks[: settings.top_n_context]
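
    # Worked example of the scoring formula: a chunk with similarity 0.8
    # whose metadata matches the lead's industry and includes metrics, but
    # not the role, scores 0.6 * 0.8 + 0.2 + 0.0 + 0.1 = 0.78.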

    def build_context(
        self, chunks: List[DocChunk], max_tokens: int = 3000
    ) -> RankedContext:
        """Greedily pack ranked chunks into a token budget and summarize them."""
        selected_chunks = []
        total_tokens = 0

        # Take chunks in rank order until the next one would exceed the budget.
        for chunk in chunks:
            chunk_tokens = len(self.encoding.encode(chunk.content_text))
            if total_tokens + chunk_tokens > max_tokens:
                break
            selected_chunks.append(chunk)
            total_tokens += chunk_tokens

        # One bullet per selected chunk: a 200-character preview plus metrics.
        summary_bullets = []
        for chunk in selected_chunks:
            content = chunk.content_text.strip()
            if len(content) > 200:
                content = content[:200] + "..."

            metrics_info = ""
            metrics = chunk.metadata.get("metrics")
            if metrics:
                metrics_parts = []
                if isinstance(metrics, dict):
                    # Keep only numeric metric values, e.g. "roi: 3.2".
                    for key, value in metrics.items():
                        if isinstance(value, (int, float)):
                            metrics_parts.append(f"{key}: {value}")
                elif isinstance(metrics, str):
                    metrics_parts.append(metrics)
                if metrics_parts:
                    metrics_info = f" ({', '.join(metrics_parts)})"

            summary_bullets.append(f"• {content}{metrics_info}")

        return RankedContext(
            chunks=selected_chunks,
            total_tokens=total_tokens,
            summary_bullets=summary_bullets,
        )
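

# Minimal usage sketch (illustrative only). The zero-argument constructors for
# ChromaStore and EmbeddingService and the LeadFeatures keyword construction
# shown here are assumptions; check those modules for the real signatures.
if __name__ == "__main__":
    store = ChromaStore()  # assumed no-arg constructor
    embedder = EmbeddingService()  # assumed no-arg constructor
    service = RetrievalService(store, embedder)

    # industry_tag/role_category are real LeadFeatures attributes used above;
    # the field values here are hypothetical.
    lead = LeadFeatures(industry_tag="logistics", role_category="founder")

    query = service.build_retrieval_query(lead)
    chunks = service.search_relevant_chunks(query)
    ranked = service.rank_chunks(chunks, lead)
    context = service.build_context(ranked, max_tokens=3000)
    print("\n".join(context.summary_bullets))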