brief-rags-bench/app/services/rag_service.py

320 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
RAG Service for interacting with RAG backends.
Поддерживает два режима:
1. Bench mode - batch запросы (все вопросы сразу)
2. Backend mode - вопросы по одному с возможностью сброса сессии
"""
import logging
import uuid
from datetime import datetime, timezone
from typing import List, Dict, Optional, Any

import httpx

from app.config import settings
from app.models.query import QuestionRequest, RagResponseBenchList

logger = logging.getLogger(__name__)
class RagService:
    """
    Service for talking to the RAG backend in three environments
    (IFT, PSI, PROD).

    Supports mTLS, configurable headers and two modes of operation:
    - bench: batch requests (all questions in a single call)
    - backend: sequential requests, one question per call, optionally
      resetting the backend session after each question
    """

    # Lower-cased names of headers whose values must never reach the logs.
    _SENSITIVE_HEADERS = frozenset({"authorization", "proxy-authorization"})

    def __init__(self):
        """Create one HTTP client for each of the three environments."""
        self.clients = {
            'ift': self._create_client('ift'),
            'psi': self._create_client('psi'),
            'prod': self._create_client('prod')
        }
        logger.info("RagService initialized for all environments")

    def _create_client(self, environment: str) -> httpx.AsyncClient:
        """
        Create the HTTP client for one environment, with optional mTLS.

        Certificate paths are read from the ``<ENV>_RAG_CERT_CERT``,
        ``<ENV>_RAG_CERT_KEY`` and ``<ENV>_RAG_CERT_CA`` settings; a missing
        setting is treated the same as an empty one.

        Args:
            environment: Environment name (ift/psi/prod).

        Returns:
            A configured httpx.AsyncClient.
        """
        env_upper = environment.upper()
        cert_cert_path = getattr(settings, f"{env_upper}_RAG_CERT_CERT", "")
        cert_key_path = getattr(settings, f"{env_upper}_RAG_CERT_KEY", "")
        cert_ca_path = getattr(settings, f"{env_upper}_RAG_CERT_CA", "")

        # A client certificate is only used when BOTH the certificate and its
        # key are configured; otherwise the client runs without mTLS.
        cert = None
        if cert_cert_path and cert_key_path:
            cert = (cert_cert_path, cert_key_path)
            logger.info(f"mTLS enabled for {environment} with cert: {cert_cert_path}")

        # Default to normal server verification; a configured CA bundle path
        # replaces the default trust store.
        verify = True
        if cert_ca_path:
            verify = cert_ca_path
            logger.info(f"Custom CA for {environment}: {cert_ca_path}")

        # NOTE(review): recent httpx releases deprecate passing `cert` and a
        # string `verify` in favour of an ssl.SSLContext - confirm against the
        # pinned httpx version before upgrading.
        return httpx.AsyncClient(
            # 1800 s for every timeout phase - presumably sized for long
            # batch runs; TODO confirm.
            timeout=httpx.Timeout(1800.0),
            cert=cert,
            verify=verify,
            follow_redirects=True
        )

    async def close(self):
        """
        Close all HTTP clients.

        Every client is closed even if closing an earlier one fails, so a
        single failure does not leave the remaining connections open. The
        first error encountered (if any) is re-raised once all clients have
        been processed.
        """
        first_error = None
        for env, client in self.clients.items():
            try:
                await client.aclose()
            except Exception as exc:
                logger.error(f"Failed to close client for {env}: {exc}")
                if first_error is None:
                    first_error = exc
            else:
                logger.info(f"Client closed for {env}")
        if first_error is not None:
            raise first_error

    async def __aenter__(self):
        """Async context manager entry: return the service itself."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: close all HTTP clients."""
        await self.close()

    @staticmethod
    def _join_url(base_url: str, path: str) -> str:
        """Append *path* (leading slashes stripped) to *base_url* after one '/'."""
        return f"{base_url}/{path.lstrip('/')}"

    @staticmethod
    def _utc_timestamp() -> str:
        """
        Return the current UTC time as a naive ISO-8601 string with a 'Z' suffix.

        Uses a timezone-aware "now" (``datetime.utcnow`` is deprecated) and
        drops the tzinfo so the text is formatted as ``YYYY-MM-DDTHH:MM:SS[.ffffff]Z``
        rather than ending in ``+00:00``.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"

    @classmethod
    def _redact_headers(cls, headers: Dict[str, str]) -> Dict[str, str]:
        """
        Return a copy of *headers* that is safe to log.

        Values of credential-bearing headers (see ``_SENSITIVE_HEADERS``) are
        replaced with ``***``; the input dict is not modified.
        """
        return {
            name: "***" if name.lower() in cls._SENSITIVE_HEADERS else value
            for name, value in headers.items()
        }

    def _get_base_url(self, environment: str) -> str:
        """
        Build the base URL for an environment.

        Host and port come from the ``<ENV>_RAG_HOST`` / ``<ENV>_RAG_PORT``
        settings; these have no fallback, so a missing setting raises
        AttributeError.

        Args:
            environment: Environment name (ift/psi/prod).

        Returns:
            Base URL in the form https://host:port.
        """
        env_upper = environment.upper()
        host = getattr(settings, f"{env_upper}_RAG_HOST")
        port = getattr(settings, f"{env_upper}_RAG_PORT")
        return f"https://{host}:{port}"

    def _get_bench_endpoint(self, environment: str) -> str:
        """
        Build the full URL used for bench mode requests.

        Args:
            environment: Environment name (ift/psi/prod).

        Returns:
            Base URL joined with the ``<ENV>_RAG_ENDPOINT`` setting.
        """
        env_upper = environment.upper()
        base_url = self._get_base_url(environment)
        endpoint = getattr(settings, f"{env_upper}_RAG_ENDPOINT")
        return self._join_url(base_url, endpoint)

    def _get_backend_endpoints(self, environment: str, user_settings: Dict) -> Dict[str, str]:
        """
        Build the URLs used for backend mode (ask + reset).

        Args:
            environment: Environment name (ift/psi/prod).
            user_settings: User settings for the environment. The optional
                keys ``backendAskEndpoint`` and ``backendResetEndpoint``
                override the default paths ``ask`` and ``reset``.

        Returns:
            Dict with the keys ``ask`` and ``reset`` mapping to full URLs.
        """
        base_url = self._get_base_url(environment)
        # `or` (not just a .get default) so that a key that is present but
        # None/empty also falls back to the default instead of crashing on
        # None.lstrip().
        ask_endpoint = user_settings.get('backendAskEndpoint') or 'ask'
        reset_endpoint = user_settings.get('backendResetEndpoint') or 'reset'
        return {
            'ask': self._join_url(base_url, ask_endpoint),
            'reset': self._join_url(base_url, reset_endpoint)
        }

    def _build_bench_headers(
        self,
        environment: str,
        user_settings: Dict,
        request_id: Optional[str] = None
    ) -> Dict[str, str]:
        """
        Build the headers for a bench mode request.

        Args:
            environment: Environment name (ift/psi/prod); used verbatim in
                the ``System-Id`` header.
            user_settings: User settings. ``bearerToken`` and
                ``systemPlatform`` add the ``Authorization`` and
                ``System-Platform`` headers when they are set.
            request_id: Request ID; a random UUID4 is generated when omitted
                or empty.

        Returns:
            Dict with the request headers.
        """
        headers = {
            "Content-Type": "application/json",
            "Request-Id": request_id or str(uuid.uuid4()),
            "System-Id": f"brief-bench-{environment}"
        }
        if user_settings.get('bearerToken'):
            headers["Authorization"] = f"Bearer {user_settings['bearerToken']}"
        if user_settings.get('systemPlatform'):
            headers["System-Platform"] = user_settings['systemPlatform']
        return headers

    def _build_backend_headers(self, user_settings: Dict) -> Dict[str, str]:
        """
        Build the headers for a backend mode request.

        Args:
            user_settings: User settings. ``bearerToken``, ``platformUserId``
                and ``platformId`` add the ``Authorization``,
                ``Platform-User-Id`` and ``Platform-Id`` headers when set.

        Returns:
            Dict with the request headers.
        """
        headers = {
            "Content-Type": "application/json"
        }
        if user_settings.get('bearerToken'):
            headers["Authorization"] = f"Bearer {user_settings['bearerToken']}"
        if user_settings.get('platformUserId'):
            headers["Platform-User-Id"] = user_settings['platformUserId']
        if user_settings.get('platformId'):
            headers["Platform-Id"] = user_settings['platformId']
        return headers

    async def send_bench_query(
        self,
        environment: str,
        questions: List[QuestionRequest],
        user_settings: Dict,
        request_id: Optional[str] = None
    ) -> RagResponseBenchList:
        """
        Send a batch request to the RAG backend (bench mode).

        All questions are serialized with ``model_dump()`` and posted as one
        JSON array.

        Args:
            environment: Environment name (ift/psi/prod).
            questions: List of questions.
            user_settings: User settings for the environment.
            request_id: Request ID (optional).

        Returns:
            RagResponseBenchList built from the JSON response.

        Raises:
            httpx.HTTPStatusError: On a non-success HTTP status.
            KeyError: If there is no client for ``environment``.
            Exception: Any other error is logged and re-raised unchanged.
        """
        client = self.clients[environment.lower()]
        url = self._get_bench_endpoint(environment)
        headers = self._build_bench_headers(environment, user_settings, request_id)
        body = [q.model_dump() for q in questions]

        logger.info(f"Sending bench query to {environment}: {len(questions)} questions")
        # Log a redacted copy: the real headers may carry the bearer token.
        logger.debug(f"URL: {url}, Headers: {self._redact_headers(headers)}")

        try:
            response = await client.post(url, json=body, headers=headers)
            response.raise_for_status()
            response_data = response.json()
            # Validate the response through the Pydantic model.
            return RagResponseBenchList(**response_data)
        except httpx.HTTPStatusError as e:
            logger.error(f"Bench query failed for {environment}: {e.response.status_code} - {e.response.text}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error in bench query for {environment}: {e}")
            raise

    async def send_backend_query(
        self,
        environment: str,
        questions: List[QuestionRequest],
        user_settings: Dict,
        reset_session: bool = True
    ) -> List[Dict[str, Any]]:
        """
        Send questions one at a time to the RAG backend (backend mode).

        After each question the session is reset when both ``reset_session``
        and the ``resetSessionMode`` user setting (default True) are truthy.
        Any failure aborts the run: the error is logged and re-raised, and
        the responses collected so far are not returned.

        Args:
            environment: Environment name (ift/psi/prod).
            questions: List of questions.
            user_settings: User settings for the environment.
            reset_session: Whether to reset the session after each question.

        Returns:
            List with the parsed JSON response for each question, in order.

        Raises:
            httpx.HTTPStatusError: On a non-success HTTP status from either
                the ask or the reset endpoint.
            KeyError: If there is no client for ``environment``.
            Exception: Any other error is logged and re-raised unchanged.
        """
        client = self.clients[environment.lower()]
        endpoints = self._get_backend_endpoints(environment, user_settings)
        headers = self._build_backend_headers(user_settings)

        logger.info(
            f"Sending backend query to {environment}: {len(questions)} questions "
            f"(reset_session={reset_session})"
        )

        # These settings do not change between questions, so read them once.
        with_classify = user_settings.get('withClassify', False)
        should_reset = reset_session and user_settings.get('resetSessionMode', True)

        responses = []
        for idx, question in enumerate(questions, start=1):
            # The same timestamp is sent with the question and with the reset
            # request that follows it.
            now = self._utc_timestamp()
            body = {
                "question": question.body,
                "user_message_id": idx,
                "user_message_datetime": now,
                "with_classify": with_classify
            }
            logger.debug(f"Sending question {idx}/{len(questions)}: {question.body[:50]}...")

            try:
                response = await client.post(endpoints['ask'], json=body, headers=headers)
                response.raise_for_status()
                response_data = response.json()
                responses.append(response_data)

                if should_reset:
                    reset_body = {"user_message_datetime": now}
                    logger.debug(f"Resetting session after question {idx}")
                    reset_response = await client.post(
                        endpoints['reset'],
                        json=reset_body,
                        headers=headers
                    )
                    reset_response.raise_for_status()
            except httpx.HTTPStatusError as e:
                logger.error(
                    f"Backend query failed for {environment} (question {idx}): "
                    f"{e.response.status_code} - {e.response.text}"
                )
                raise
            except Exception as e:
                logger.error(f"Unexpected error in backend query for {environment} (question {idx}): {e}")
                raise

        logger.info(f"Backend query completed for {environment}: {len(responses)} responses")
        return responses