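"""LLM provider adapter: simplifies wiki text via the OpenAI chat completions API."""
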
import asyncio
import time

import httpx
import openai
import structlog
import tiktoken
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion

from src.models.constants import LLM_MAX_INPUT_TOKENS, MAX_TOKEN_LIMIT_WITH_BUFFER

from ..models import AppConfig
from .base import BaseAdapter, CircuitBreaker, RateLimiter, with_retry

logger = structlog.get_logger()


class LLMError(Exception):
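    """Base error for failures in the LLM adapter."""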
    pass


class LLMTokenLimitError(LLMError):
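    """Raised when the input text exceeds the allowed token budget."""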
    pass


class LLMRateLimitError(LLMError):
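    """Raised when the OpenAI API reports a rate limit."""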
    pass


class LLMProviderAdapter(BaseAdapter):
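    """Adapter for the OpenAI chat completions API.

    Wraps requests with a concurrency limiter, a client-side requests-per-minute
    window, retries, and a circuit breaker.
    """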

    def __init__(self, config: AppConfig) -> None:
        super().__init__("llm_adapter")
        self.config = config

        self.client = AsyncOpenAI(
            api_key=config.openai_api_key, http_client=self._build_http_client()
        )
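
        # Prefer the model-specific tokenizer; fall back to cl100k_base if the model is unknown.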
        try:
            self.tokenizer = tiktoken.encoding_for_model(config.openai_model)
        except KeyError:
            self.tokenizer = tiktoken.get_encoding("cl100k_base")

        self.rate_limiter = RateLimiter(
            max_concurrent=config.max_concurrent_llm,
            name="llm_limiter",
        )
        self.circuit_breaker = CircuitBreaker(
            failure_threshold=config.circuit_failure_threshold,
            recovery_timeout=config.circuit_recovery_timeout,
            name="llm_circuit",
        )
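
        # Sliding window of request timestamps used for client-side RPM throttling.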
        self.request_times: list[float] = []
        self.rpm_lock = asyncio.Lock()

    def count_tokens(self, text: str) -> int:
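        """Count tokens in text, falling back to a rough len(text) // 4 estimate on errors."""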
        try:
            return len(self.tokenizer.encode(text))
        except Exception as e:
            self.logger.warning("Failed to count tokens", error=str(e))
            return len(text) // 4

    async def _check_rpm_limit(self) -> None:
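        """Enforce the configured requests-per-minute limit, sleeping if the budget is exhausted."""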
        async with self.rpm_lock:
            current_time = time.time()
            self.request_times = [
                req_time for req_time in self.request_times if current_time - req_time < 60
            ]

            if len(self.request_times) >= self.config.openai_rpm:
                oldest_request = min(self.request_times)
                wait_time = 60 - (current_time - oldest_request)
                if wait_time > 0:
                    self.logger.info(
                        "Waiting due to RPM limit",
                        wait_seconds=wait_time,
                        current_rpm=len(self.request_times),
                    )
                    await asyncio.sleep(wait_time)

            self.request_times.append(current_time)

    async def _make_completion_request(
        self,
        messages: list[dict[str, str]],
    ) -> ChatCompletion:
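        """Send one chat completion request, mapping OpenAI errors to adapter errors."""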
        try:
            response = await self.client.chat.completions.create(
                model=self.config.openai_model,
                messages=messages,
                temperature=self.config.openai_temperature,
                max_tokens=MAX_TOKEN_LIMIT_WITH_BUFFER,
            )
            return response
        except openai.RateLimitError as e:
            raise LLMRateLimitError(f"Rate limit exceeded: {e}") from e
        except openai.APIError as e:
            raise LLMError(f"OpenAI API error: {e}") from e

    async def simplify_text(
        self,
        title: str,
        wiki_text: str,
        prompt_template: str,
    ) -> tuple[str, int, int]:
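        """Simplify wiki_text via the LLM; return (simplified_text, input_tokens, output_tokens)."""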
        input_tokens = self.count_tokens(wiki_text)
        if input_tokens > LLM_MAX_INPUT_TOKENS:
            raise LLMTokenLimitError(f"Text is too long: {input_tokens} tokens")

        try:
            prompt_text = prompt_template.format(
                title=title,
                wiki_source_text=wiki_text,
            )
        except KeyError as e:
            raise LLMError(f"Prompt template error: missing key {e}") from e

        messages = self._parse_prompt_template(prompt_text)

        total_input_tokens = sum(self.count_tokens(msg["content"]) for msg in messages)

        async with self.rate_limiter:
            await self._check_rpm_limit()
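
            # The completion request is retried inside a single circuit breaker call.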
            response = await self.circuit_breaker.call(
                lambda: with_retry(
                    lambda: self._make_completion_request(messages),
                    max_attempts=self.config.max_retries,
                    min_wait=self.config.retry_delay,
                    max_wait=self.config.retry_delay * 4,
                    retry_exceptions=(LLMRateLimitError, ConnectionError, TimeoutError),
                    name=f"simplify_{title}",
                )
            )

        if not response.choices:
            raise LLMError("Empty response from OpenAI")

        simplified_text = response.choices[0].message.content
        if not simplified_text:
            raise LLMError("OpenAI returned empty text")

        simplified_text = simplified_text.replace("###END###", "").strip()

        output_tokens = self.count_tokens(simplified_text)

        if output_tokens > MAX_TOKEN_LIMIT_WITH_BUFFER:
            self.logger.warning(
                "Simplified text exceeds the token limit",
                output_tokens=output_tokens,
                title=title,
            )

        self.logger.info(
            "Text simplified successfully",
            title=title,
            input_tokens=total_input_tokens,
            output_tokens=output_tokens,
        )

        return simplified_text, total_input_tokens, output_tokens

    def _parse_prompt_template(self, prompt_text: str) -> list[dict[str, str]]:
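        """Split a prompt template into chat messages on '### role:' markers.

        Falls back to a single user message if no markers are found.
        """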
        messages: list[dict[str, str]] = []

        parts = prompt_text.split("### role:")

        for part in parts[1:]:
            lines = part.strip().split("\n", 1)
            if len(lines) < 2:
                continue

            role = lines[0].strip()
            content = lines[1].strip()

            if role in ("system", "user", "assistant"):
                messages.append({"role": role, "content": content})

        if not messages:
            messages = [{"role": "user", "content": prompt_text}]

        return messages

    def _build_http_client(self) -> httpx.AsyncClient:
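        """Build the async HTTP client, routing through a proxy when one is configured."""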
        if self.config.openai_proxy_url:
            return httpx.AsyncClient(proxy=self.config.openai_proxy_url, timeout=60.0)
        return httpx.AsyncClient(timeout=60.0)

    async def health_check(self) -> bool:
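        """Check LLM availability with a minimal completion request."""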
        try:
            test_messages = [{"role": "user", "content": "Reply 'OK' if everything is working."}]

            response = await self.client.chat.completions.create(
                model=self.config.openai_model,
                messages=test_messages,
                temperature=0,
                max_tokens=10,
            )

            return bool(response.choices and response.choices[0].message.content)
        except Exception as e:
            self.logger.error("LLM health check failed", error=str(e))
            return False