ruwiki-test/src/adapters/llm.py

import asyncio
import time

import httpx
import openai
import structlog
import tiktoken
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion

from src.models.constants import LLM_MAX_INPUT_TOKENS, MAX_TOKEN_LIMIT_WITH_BUFFER

from ..models import AppConfig
from .base import BaseAdapter, CircuitBreaker, RateLimiter, with_retry

logger = structlog.get_logger()


class LLMError(Exception):
    """Base error raised by the LLM adapter."""


class LLMTokenLimitError(LLMError):
    """Raised when the input text exceeds the allowed token budget."""


class LLMRateLimitError(LLMError):
    """Raised when the provider reports a rate-limit violation."""
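

# LLMProviderAdapter wraps AsyncOpenAI with three layers of protection, all
# implemented by helpers from .base: a concurrency RateLimiter, a
# sliding-window RPM check, and a CircuitBreaker around retried completion
# requests. The exact semantics of those helpers live in
# src/adapters/base.py (not shown here).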
class LLMProviderAdapter(BaseAdapter):
    def __init__(self, config: AppConfig) -> None:
        super().__init__("llm_adapter")
        self.config = config
        self.client = AsyncOpenAI(
            api_key=config.openai_api_key, http_client=self._build_http_client()
        )
        try:
            self.tokenizer = tiktoken.encoding_for_model(config.openai_model)
        except KeyError:
            self.tokenizer = tiktoken.get_encoding("cl100k_base")
        self.rate_limiter = RateLimiter(
            max_concurrent=config.max_concurrent_llm,
            name="llm_limiter",
        )
        self.circuit_breaker = CircuitBreaker(
            failure_threshold=config.circuit_failure_threshold,
            recovery_timeout=config.circuit_recovery_timeout,
            name="llm_circuit",
        )
        self.request_times: list[float] = []
        self.rpm_lock = asyncio.Lock()
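
    # tiktoken can fail on unexpected input; the except branch below falls
    # back to a rough heuristic of about four characters per token, which is
    # close enough for budget checks.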
    def count_tokens(self, text: str) -> int:
        try:
            return len(self.tokenizer.encode(text))
        except Exception as e:
            self.logger.warning("Token counting failed", error=str(e))
            return len(text) // 4
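
    # Sliding-window RPM limiter: timestamps of requests issued in the last
    # 60 seconds are kept in self.request_times; when the window is full, the
    # coroutine sleeps until the oldest entry ages out.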
    async def _check_rpm_limit(self) -> None:
        async with self.rpm_lock:
            current_time = time.time()
            self.request_times = [
                req_time for req_time in self.request_times if current_time - req_time < 60
            ]
            if len(self.request_times) >= self.config.openai_rpm:
                oldest_request = min(self.request_times)
                wait_time = 60 - (current_time - oldest_request)
                if wait_time > 0:
                    self.logger.info(
                        "Waiting due to RPM limit",
                        wait_seconds=wait_time,
                        current_rpm=len(self.request_times),
                    )
                    await asyncio.sleep(wait_time)
            # Record the moment the request actually proceeds, not the stale
            # pre-sleep timestamp, so the window is not under-counted.
            self.request_times.append(time.time())
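
    # Thin wrapper around chat.completions.create that maps OpenAI SDK errors
    # onto the adapter's own exception hierarchy, so with_retry and the
    # circuit breaker can key off LLMRateLimitError specifically.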
    async def _make_completion_request(
        self,
        messages: list[dict[str, str]],
    ) -> ChatCompletion:
        try:
            response = await self.client.chat.completions.create(
                model=self.config.openai_model,
                messages=messages,
                temperature=self.config.openai_temperature,
                max_tokens=MAX_TOKEN_LIMIT_WITH_BUFFER,
            )
            return response
        except openai.RateLimitError as e:
            raise LLMRateLimitError(f"Rate limit exceeded: {e}") from e
        except openai.APIError as e:
            raise LLMError(f"OpenAI API error: {e}") from e
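
    # simplify_text is the main entry point: it enforces the input token
    # budget, renders the prompt template, splits it into chat messages, and
    # sends the request through all three protection layers before
    # post-processing the model output.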
    async def simplify_text(
        self,
        title: str,
        wiki_text: str,
        prompt_template: str,
    ) -> tuple[str, int, int]:
        input_tokens = self.count_tokens(wiki_text)
        if input_tokens > LLM_MAX_INPUT_TOKENS:
            raise LLMTokenLimitError(f"Input text is too long: {input_tokens} tokens")
        try:
            prompt_text = prompt_template.format(
                title=title,
                wiki_source_text=wiki_text,
            )
        except KeyError as e:
            raise LLMError(f"Prompt template error: missing key {e}") from e
        messages = self._parse_prompt_template(prompt_text)
        total_input_tokens = sum(self.count_tokens(msg["content"]) for msg in messages)
        async with self.rate_limiter:
            await self._check_rpm_limit()
            # The circuit breaker wraps the retry loop, which in turn wraps
            # the actual request; both take zero-argument callables.
            response = await self.circuit_breaker.call(
                lambda: with_retry(
                    lambda: self._make_completion_request(messages),
                    max_attempts=self.config.max_retries,
                    min_wait=self.config.retry_delay,
                    max_wait=self.config.retry_delay * 4,
                    retry_exceptions=(LLMRateLimitError, ConnectionError, TimeoutError),
                    name=f"simplify_{title}",
                )
            )
        if not response.choices:
            raise LLMError("Empty response from OpenAI")
        simplified_text = response.choices[0].message.content
        if not simplified_text:
            raise LLMError("OpenAI returned empty text")
        # Strip the ###END### sentinel, if present, before measuring output.
        simplified_text = simplified_text.replace("###END###", "").strip()
        output_tokens = self.count_tokens(simplified_text)
        if output_tokens > MAX_TOKEN_LIMIT_WITH_BUFFER:
            self.logger.warning(
                "Simplified text exceeds the token limit",
                output_tokens=output_tokens,
                title=title,
            )
        self.logger.info(
            "Text simplified successfully",
            title=title,
            input_tokens=total_input_tokens,
            output_tokens=output_tokens,
        )
        return simplified_text, total_input_tokens, output_tokens
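
    # Prompt templates use "### role:" section markers. An illustrative
    # template (hypothetical content; real templates live elsewhere):
    #
    #   ### role: system
    #   You simplify Russian Wikipedia articles for a broad audience.
    #   ### role: user
    #   Simplify the article "{title}":
    #   {wiki_source_text}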
    def _parse_prompt_template(self, prompt_text: str) -> list[dict[str, str]]:
        messages: list[dict[str, str]] = []
        parts = prompt_text.split("### role:")
        for part in parts[1:]:
            lines = part.strip().split("\n", 1)
            if len(lines) < 2:
                continue
            role = lines[0].strip()
            content = lines[1].strip()
            if role in ("system", "user", "assistant"):
                messages.append({"role": role, "content": content})
        if not messages:
            # No role markers: send the whole template as a single user message.
            messages = [{"role": "user", "content": prompt_text}]
        return messages
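
    # Route OpenAI traffic through an optional proxy; a generous 60-second
    # timeout covers long completions.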
    def _build_http_client(self) -> httpx.AsyncClient:
        if self.config.openai_proxy_url:
            return httpx.AsyncClient(proxy=self.config.openai_proxy_url, timeout=60.0)
        return httpx.AsyncClient(timeout=60.0)
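
    # Cheap liveness probe: a tiny real completion verifies the API key,
    # network path, and model name in one call.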
    async def health_check(self) -> bool:
        try:
            test_messages = [{"role": "user", "content": "Reply 'OK' if everything works."}]
            response = await self.client.chat.completions.create(
                model=self.config.openai_model,
                messages=test_messages,
                temperature=0,
                max_tokens=10,
            )
            return bool(response.choices and response.choices[0].message.content)
        except Exception as e:
            self.logger.error("LLM health check failed", error=str(e))
            return False
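

# Usage sketch (illustrative; AppConfig construction and the prompt template
# are assumptions, real code wires these up elsewhere):
#
#   config = AppConfig()  # hypothetical: actual loading may read env vars
#   adapter = LLMProviderAdapter(config)
#   simplified, in_tok, out_tok = await adapter.simplify_text(
#       title="Пример",
#       wiki_text=raw_wikitext,
#       prompt_template=template_text,
#   )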