from typing import List, Optional
import tiktoken
import hashlib
import time
from src.models.email import DocChunk, RetrievalQuery, RankedContext
from src.models.lead import LeadFeatures
from src.services.chroma_store import ChromaStore
from src.services.embeddings import EmbeddingService
from src.app.config import settings


class RetrievalService:
    """Retrieval pipeline for lead-personalized context.

    Builds a text search query from a lead's features, embeds it (with a
    TTL-bounded in-memory cache), searches the Chroma vector store, re-ranks
    the returned chunks with industry/role/metrics boosts, and assembles a
    token-budgeted context block for downstream prompt construction.
    """

    def __init__(self, chroma_store: ChromaStore, embedding_service: EmbeddingService):
        self.chroma_store = chroma_store
        self.embedding_service = embedding_service
        # cl100k_base is the tokenizer family used by OpenAI chat/embedding
        # models; used here only to count tokens for the context budget.
        self.encoding = tiktoken.get_encoding("cl100k_base")
        # query-hash -> {"embedding": List[float], "timestamp": float}
        self._embedding_cache: dict = {}
        # NOTE(review): _search_cache is never read or written anywhere in
        # this class — kept to preserve instance state shape, but it looks
        # like dead state; confirm before removing.
        self._search_cache: dict = {}
        self._cache_ttl = 3600  # seconds; entries older than this are stale

    def _get_query_hash(self, text_query: str) -> str:
        """Return a stable cache key for a query string.

        MD5 is used purely as a fast non-cryptographic fingerprint here.
        """
        return hashlib.md5(text_query.encode()).hexdigest()

    def _get_cached_embedding(self, text_query: str) -> Optional[List[float]]:
        """Return the cached embedding for *text_query*, or None.

        A hit requires the entry to exist and to be younger than the TTL;
        expired entries are simply ignored (evicted lazily by _cache_embedding).
        """
        cached = self._embedding_cache.get(self._get_query_hash(text_query))
        if cached and time.time() - cached["timestamp"] < self._cache_ttl:
            return cached["embedding"]
        return None

    def _cache_embedding(self, text_query: str, embedding: List[float]) -> None:
        """Store *embedding* under the query's hash.

        When the cache grows past 1000 entries, the single oldest entry
        (by insertion timestamp) is evicted to bound memory use.
        """
        self._embedding_cache[self._get_query_hash(text_query)] = {
            "embedding": embedding,
            "timestamp": time.time(),
        }
        if len(self._embedding_cache) > 1000:
            oldest_key = min(
                self._embedding_cache,
                key=lambda k: self._embedding_cache[k]["timestamp"],
            )
            del self._embedding_cache[oldest_key]

    def build_retrieval_query(self, lead_features: LeadFeatures) -> RetrievalQuery:
        """Compose a RetrievalQuery from generic terms, the lead's industry
        and role tags, and industry-specific pain-point keywords.

        Falsy terms (e.g. a missing industry_tag) are dropped so the join
        never fails on None.
        """
        search_terms = [
            "самозанятые",
            "автоматизация",
            "исполнители",
            lead_features.industry_tag,
            lead_features.role_category,
        ]
        # Industry -> pain-point keywords; "other" is the fallback bucket.
        pain_points_map = {
            "marketing_agency": ["онбординг", "выплаты", "подрядчики", "отчетность"],
            "logistics": ["документооборот", "исполнители", "чеки", "география"],
            "software": ["фриланс", "ИП", "API", "интеграции"],
            "retail": ["сезонность", "временные", "массовые", "выплаты"],
            "consulting": ["проекты", "экспертиза", "временные", "специалисты"],
            "construction": ["подрядчики", "документы", "безопасность", "сроки"],
            "other": ["онбординг", "выплаты", "документооборот"],
        }
        industry_terms = pain_points_map.get(
            lead_features.industry_tag, pain_points_map["other"]
        )
        search_terms.extend(industry_terms)
        # Fix: filter out None/empty terms before joining — the original
        # would raise TypeError if industry_tag or role_category was None.
        text_query = " ".join(term for term in search_terms if term)
        return RetrievalQuery(
            text_query=text_query,
            metadata_filters=None,  # no metadata filtering is applied yet
            k=settings.top_k,
        )

    def search_relevant_chunks(self, query: RetrievalQuery) -> List[DocChunk]:
        """Embed the query text (cache-first) and run a similarity search
        against the Chroma store, returning the top-k chunks."""
        query_embedding = self._get_cached_embedding(query.text_query)
        if query_embedding is None:
            query_embedding = self.embedding_service.embed_text(query.text_query)
            self._cache_embedding(query.text_query, query_embedding)
        return self.chroma_store.search_similar(
            query_embedding=query_embedding, k=query.k, where=query.metadata_filters
        )

    def rank_chunks(
        self, chunks: List[DocChunk], lead_features: LeadFeatures
    ) -> List[DocChunk]:
        """Re-rank *chunks* for this lead and deduplicate near-identical ones.

        Final score = 0.6 * similarity + 0.2 industry match + 0.1 role match
        + 0.1 metrics presence. NOTE: mutates each chunk's similarity_score
        in place. Dedup uses a hash of the first 100 characters of content.
        Returns at most settings.top_n_context chunks, best first.
        """
        for chunk in chunks:
            score = chunk.similarity_score or 0.0
            industry_boost = (
                0.2
                if lead_features.industry_tag in chunk.metadata.get("industry", [])
                else 0.0
            )
            role_boost = (
                0.1
                if lead_features.role_category
                in chunk.metadata.get("roles_relevant", [])
                else 0.0
            )
            metrics_boost = 0.1 if chunk.metadata.get("metrics") else 0.0
            chunk.similarity_score = (
                0.6 * score + industry_boost + role_boost + metrics_boost
            )
        chunks.sort(key=lambda x: x.similarity_score or 0.0, reverse=True)

        # Drop chunks whose leading content duplicates an earlier (higher
        # scored) chunk.
        seen_content = set()
        unique_chunks = []
        for chunk in chunks:
            content_hash = hash(chunk.content_text[:100])
            if content_hash not in seen_content:
                seen_content.add(content_hash)
                unique_chunks.append(chunk)
        return unique_chunks[: settings.top_n_context]

    def build_context(
        self, chunks: List[DocChunk], max_tokens: int = 3000
    ) -> RankedContext:
        """Greedily pack *chunks* (in given order) into a token budget and
        build bullet summaries.

        Stops at the first chunk that would exceed *max_tokens*. Each bullet
        is the chunk's first 200 characters plus any metrics found in its
        metadata (dict of numeric values, or a plain string).
        """
        selected_chunks = []
        total_tokens = 0
        for chunk in chunks:
            chunk_tokens = len(self.encoding.encode(chunk.content_text))
            if total_tokens + chunk_tokens > max_tokens:
                break
            selected_chunks.append(chunk)
            total_tokens += chunk_tokens

        summary_bullets = []
        for chunk in selected_chunks:
            content = chunk.content_text.strip()
            if len(content) > 200:
                content = content[:200] + "..."
            metrics_info = ""
            metrics = chunk.metadata.get("metrics")
            if metrics:
                metrics_parts = []
                if isinstance(metrics, dict):
                    # Only numeric values are surfaced in the summary.
                    for key, value in metrics.items():
                        if isinstance(value, (int, float)):
                            metrics_parts.append(f"{key}: {value}")
                elif isinstance(metrics, str):
                    metrics_parts.append(metrics)
                if metrics_parts:
                    metrics_info = f" ({', '.join(metrics_parts)})"
            summary_bullets.append(f"• {content}{metrics_info}")

        return RankedContext(
            chunks=selected_chunks,
            total_tokens=total_tokens,
            summary_bullets=summary_bullets,
        )