Add proxy support for OpenAI

itqop 2025-07-12 10:37:00 +03:00
parent c5306bb56e
commit bf830fd330
4 changed files with 50 additions and 11 deletions

View File

@@ -1,6 +1,7 @@
 OPENAI_API_KEY=your_openai_api_key_here
-OPENAI_MODEL=gpt-4o-mini
+OPENAI_MODEL=gpt-4o
 OPENAI_TEMPERATURE=0.0
+OPENAI_PROXY_URL='socks5h://37.18.73.60:5566' # socks5 recommended
 DB_PATH=./data/wiki.db
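
The example proxy URL uses the socks5h scheme, which httpx only supports when installed with its SOCKS extra (pip install 'httpx[socks]'). A minimal sketch for checking the proxy on its own before wiring it into the adapter, assuming httpx 0.26+ where the proxy keyword is available (older releases use proxies=):

import asyncio

import httpx


async def check_proxy() -> None:
    # socks5h:// resolves DNS on the proxy side; requires the httpx[socks] extra.
    async with httpx.AsyncClient(proxy="socks5h://37.18.73.60:5566", timeout=10.0) as client:
        response = await client.get("https://api.openai.com/v1/models")
        # A 401 without an API key still proves the proxy path is usable.
        print(response.status_code)


asyncio.run(check_proxy())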

View File

@@ -1,12 +1,15 @@
 import asyncio
 import time
+import httpx
 import openai
 import structlog
 import tiktoken
 from openai import AsyncOpenAI
 from openai.types.chat import ChatCompletion
+from src.models.constants import LLM_MAX_INPUT_TOKENS, MAX_TOKEN_LIMIT_WITH_BUFFER
 from ..models import AppConfig
 from .base import BaseAdapter, CircuitBreaker, RateLimiter, with_retry
@@ -31,7 +34,10 @@ class LLMProviderAdapter(BaseAdapter):
         super().__init__("llm_adapter")
         self.config = config
-        self.client = AsyncOpenAI(api_key=config.openai_api_key)
+        self.client = AsyncOpenAI(
+            api_key=config.openai_api_key,
+            http_client=self._build_http_client()
+        )
 
         try:
             self.tokenizer = tiktoken.encoding_for_model(config.openai_model)
@@ -87,7 +93,7 @@ class LLMProviderAdapter(BaseAdapter):
                 model=self.config.openai_model,
                 messages=messages,
                 temperature=self.config.openai_temperature,
-                max_tokens=1500,
+                max_tokens=MAX_TOKEN_LIMIT_WITH_BUFFER,
             )
             return response
         except openai.RateLimitError as e:
@@ -102,8 +108,8 @@ class LLMProviderAdapter(BaseAdapter):
         prompt_template: str,
     ) -> tuple[str, int, int]:
         input_tokens = self.count_tokens(wiki_text)
-        if input_tokens > 6000:
-            raise LLMTokenLimitError(f"Текст слишком длинный: {input_tokens} токенов (лимит 6000)")
+        if input_tokens > LLM_MAX_INPUT_TOKENS:
+            raise LLMTokenLimitError(f"Текст слишком длинный: {input_tokens} токенов")
 
         try:
             prompt_text = prompt_template.format(
@@ -142,7 +148,7 @@ class LLMProviderAdapter(BaseAdapter):
         output_tokens = self.count_tokens(simplified_text)
 
-        if output_tokens > 1200:
+        if output_tokens > MAX_TOKEN_LIMIT_WITH_BUFFER:
             self.logger.warning(
                 "Упрощённый текст превышает лимит",
                 output_tokens=output_tokens,
@@ -179,6 +185,16 @@ class LLMProviderAdapter(BaseAdapter):
         return messages
 
+    def _build_http_client(self) -> httpx.AsyncClient:
+        if self.config.openai_proxy_url:
+            return httpx.AsyncClient(
+                proxy=self.config.openai_proxy_url,
+                timeout=60.0
+            )
+        return httpx.AsyncClient(timeout=60.0)
+
     async def health_check(self) -> bool:
         try:
             test_messages = [{"role": "user", "content": "Ответь 'OK' если всё работает."}]
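
With this change every OpenAI request from the adapter is routed through the configured proxy, because _build_http_client() is handed to AsyncOpenAI as its transport. A minimal sketch of the same wiring outside the adapter, assuming httpx 0.26+ and the openai v1 SDK (the build_client name and env-var lookup are illustrative, not part of the project):

import os

import httpx
from openai import AsyncOpenAI


def build_client() -> AsyncOpenAI:
    proxy_url = os.getenv("OPENAI_PROXY_URL")
    # Route traffic through the proxy only when one is configured.
    http_client = (
        httpx.AsyncClient(proxy=proxy_url, timeout=60.0)
        if proxy_url
        else httpx.AsyncClient(timeout=60.0)
    )
    return AsyncOpenAI(api_key=os.environ["OPENAI_API_KEY"], http_client=http_client)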

View File

@@ -19,6 +19,7 @@ class AppConfig(BaseSettings):
     openai_temperature: float = Field(
         default=0.0, ge=0.0, le=2.0, description="Температура для LLM"
     )
+    openai_proxy_url: str | None = Field(default=None, description="Proxy URL для OpenAI")
 
     db_path: str = Field(default="./data/wiki.db", description="Путь к файлу SQLite")
@@ -33,7 +34,7 @@ class AppConfig(BaseSettings):
     log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = Field(default="INFO")
     log_format: Literal["json", "text"] = Field(default="json")
-    chunk_size: int = Field(default=2000, ge=500, le=8000, description="Размер чанка для текста")
+    chunk_size: int = Field(default=2000, ge=500, le=122000, description="Размер чанка для текста")
     chunk_overlap: int = Field(default=200, ge=0, le=1000, description="Перекрытие между чанками")
     max_retries: int = Field(default=3, ge=1, le=10, description="Максимум попыток повтора")
     retry_delay: float = Field(
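
With pydantic-settings, the new field is populated from the OPENAI_PROXY_URL variable in the environment or the .env file shown above, and stays None when it is unset. A trimmed sketch of that behaviour, assuming pydantic-settings v2 (ProxySettings is a stand-in for the project's AppConfig):

from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict


class ProxySettings(BaseSettings):
    model_config = SettingsConfigDict(env_file=".env")

    # Env var names map onto field names case-insensitively, so OPENAI_PROXY_URL lands here.
    openai_proxy_url: str | None = Field(default=None, description="Proxy URL for OpenAI")


settings = ProxySettings()
print(settings.openai_proxy_url)  # None when no proxy is configured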

View File

@@ -65,6 +65,8 @@ class AsyncWriteQueue:
         await self._queue.put(operation)
 
     async def update_from_result(self, result: ProcessingResult) -> ArticleDTO:
+        self.logger.info("Получен результат для записи", url=result.url, success=result.success)
+
         future: asyncio.Future[ArticleDTO] = asyncio.Future()
         operation = WriteOperation(
@@ -73,15 +75,20 @@
         )
         operation.future = future
 
+        self.logger.info("Добавляем операцию в очередь", url=result.url)
         await self._queue.put(operation)
-        return await future
+        self.logger.info("Операция добавлена в очередь, ожидаем результат", url=result.url)
+
+        result_article = await future
+        self.logger.info("Получен результат из очереди", url=result.url)
+        return result_article
 
     async def _worker_loop(self) -> None:
         batch: list[WriteOperation] = []
 
         while not self._shutdown_event.is_set():
             batch = await self._collect_batch(batch)
-            if batch and (len(batch) >= self.max_batch_size or self._shutdown_event.is_set()):
+            if batch:
                 await self._process_batch(batch)
                 batch.clear()
@@ -90,7 +97,7 @@
     async def _collect_batch(self, batch: list[WriteOperation]) -> list[WriteOperation]:
         try:
-            timeout = 0.1 if batch else 1.0
+            timeout = 1.0 if not batch else 0.1
             operation = await asyncio.wait_for(self._queue.get(), timeout=timeout)
             batch.append(operation)
             return batch
@@ -117,12 +124,18 @@
     async def _process_operation_safely(self, operation: WriteOperation) -> None:
         try:
+            self.logger.info("Начинаем обработку операции",
+                operation_type=operation.operation_type,
+                url=operation.result.url if operation.result else "N/A")
             await self._process_single_operation(operation)
             self._total_operations += 1
 
             if operation.future and not operation.future.done():
                 if operation.operation_type == "update_from_result" and operation.result:
+                    self.logger.info("Получаем статью из репозитория", url=operation.result.url)
                     article = await self.repository.get_by_url(operation.result.url)
+                    self.logger.info("Статья получена, устанавливаем результат", url=operation.result.url)
                     operation.future.set_result(article)
 
         except Exception as e:
@@ -145,11 +158,15 @@
             raise ValueError(msg)
 
     async def _update_article_from_result(self, result: ProcessingResult) -> ArticleDTO:
+        self.logger.info("Начинаем обновление статьи из результата", url=result.url)
+
         article = await self.repository.get_by_url(result.url)
         if not article:
             msg = f"Статья с URL {result.url} не найдена"
             raise ValueError(msg)
 
+        self.logger.info("Статья найдена, обновляем поля", url=result.url, success=result.success)
+
         if result.success:
             if not (result.title and result.raw_text and result.simplified_text):
                 msg = "Неполные данные в успешном результате"
@@ -162,7 +179,11 @@
             from src.models.article_dto import ArticleStatus
 
             article.status = ArticleStatus.FAILED
 
-        return await self.repository.update_article(article)
+        self.logger.info("Сохраняем обновлённую статью", url=result.url)
+        updated_article = await self.repository.update_article(article)
+        self.logger.info("Статья успешно обновлена", url=result.url)
+        return updated_article
 
     @property
     def queue_size(self) -> int:
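
The update_from_result changes keep the queue's existing handshake: the caller attaches a future to a WriteOperation, the worker loop drains the queue, and the future is resolved once the article is persisted. A stripped-down sketch of that pattern, with hypothetical names (Operation, worker, update_from_result) standing in for the project's classes:

import asyncio
from dataclasses import dataclass


@dataclass
class Operation:
    payload: str
    future: asyncio.Future[str]


async def worker(queue: asyncio.Queue[Operation]) -> None:
    while True:
        op = await queue.get()
        # Stand-in for _process_single_operation: persist, then resolve the caller's future.
        op.future.set_result(f"saved:{op.payload}")
        queue.task_done()


async def update_from_result(queue: asyncio.Queue[Operation], payload: str) -> str:
    future: asyncio.Future[str] = asyncio.Future()
    await queue.put(Operation(payload=payload, future=future))
    # The caller blocks here until the worker resolves the future.
    return await future


async def main() -> None:
    queue: asyncio.Queue[Operation] = asyncio.Queue()
    worker_task = asyncio.create_task(worker(queue))
    print(await update_from_result(queue, "https://example.org/article"))
    worker_task.cancel()


asyncio.run(main())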