Add proxy support for OpenAI
This commit is contained in:
parent
c5306bb56e
commit
bf830fd330
|
@ -1,6 +1,7 @@
|
|||
OPENAI_API_KEY=your_openai_api_key_here
|
||||
OPENAI_MODEL=gpt-4o-mini
|
||||
OPENAI_MODEL=gpt-4o
|
||||
OPENAI_TEMPERATURE=0.0
|
||||
OPENAI_PROXY_URL='socks5h://37.18.73.60:5566' # socks5 recommended
|
||||
|
||||
DB_PATH=./data/wiki.db
|
||||
|
||||
|
|
|
@ -1,12 +1,15 @@
|
|||
import asyncio
|
||||
import time
|
||||
|
||||
import httpx
|
||||
import openai
|
||||
import structlog
|
||||
import tiktoken
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat import ChatCompletion
|
||||
|
||||
from src.models.constants import LLM_MAX_INPUT_TOKENS, MAX_TOKEN_LIMIT_WITH_BUFFER
|
||||
|
||||
from ..models import AppConfig
|
||||
from .base import BaseAdapter, CircuitBreaker, RateLimiter, with_retry
|
||||
|
||||
|
@ -31,7 +34,10 @@ class LLMProviderAdapter(BaseAdapter):
|
|||
super().__init__("llm_adapter")
|
||||
self.config = config
|
||||
|
||||
self.client = AsyncOpenAI(api_key=config.openai_api_key)
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=config.openai_api_key,
|
||||
http_client=self._build_http_client()
|
||||
)
|
||||
|
||||
try:
|
||||
self.tokenizer = tiktoken.encoding_for_model(config.openai_model)
|
||||
|
@ -87,7 +93,7 @@ class LLMProviderAdapter(BaseAdapter):
|
|||
model=self.config.openai_model,
|
||||
messages=messages,
|
||||
temperature=self.config.openai_temperature,
|
||||
max_tokens=1500,
|
||||
max_tokens=MAX_TOKEN_LIMIT_WITH_BUFFER,
|
||||
)
|
||||
return response
|
||||
except openai.RateLimitError as e:
|
||||
|
@ -102,8 +108,8 @@ class LLMProviderAdapter(BaseAdapter):
|
|||
prompt_template: str,
|
||||
) -> tuple[str, int, int]:
|
||||
input_tokens = self.count_tokens(wiki_text)
|
||||
if input_tokens > 6000:
|
||||
raise LLMTokenLimitError(f"Текст слишком длинный: {input_tokens} токенов (лимит 6000)")
|
||||
if input_tokens > LLM_MAX_INPUT_TOKENS:
|
||||
raise LLMTokenLimitError(f"Текст слишком длинный: {input_tokens} токенов")
|
||||
|
||||
try:
|
||||
prompt_text = prompt_template.format(
|
||||
|
@ -142,7 +148,7 @@ class LLMProviderAdapter(BaseAdapter):
|
|||
|
||||
output_tokens = self.count_tokens(simplified_text)
|
||||
|
||||
if output_tokens > 1200:
|
||||
if output_tokens > MAX_TOKEN_LIMIT_WITH_BUFFER:
|
||||
self.logger.warning(
|
||||
"Упрощённый текст превышает лимит",
|
||||
output_tokens=output_tokens,
|
||||
|
@ -179,6 +185,16 @@ class LLMProviderAdapter(BaseAdapter):
|
|||
|
||||
return messages
|
||||
|
||||
def _build_http_client(self) -> httpx.AsyncClient:
|
||||
if self.config.openai_proxy_url:
|
||||
return httpx.AsyncClient(
|
||||
proxy=self.config.openai_proxy_url,
|
||||
timeout=60.0
|
||||
)
|
||||
return httpx.AsyncClient(timeout=60.0)
|
||||
|
||||
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
try:
|
||||
test_messages = [{"role": "user", "content": "Ответь 'OK' если всё работает."}]
|
||||
|
|
|
@ -19,6 +19,7 @@ class AppConfig(BaseSettings):
|
|||
openai_temperature: float = Field(
|
||||
default=0.0, ge=0.0, le=2.0, description="Температура для LLM"
|
||||
)
|
||||
openai_proxy_url: str | None = Field(description="Proxy URL для OpenAI")
|
||||
|
||||
db_path: str = Field(default="./data/wiki.db", description="Путь к файлу SQLite")
|
||||
|
||||
|
@ -33,7 +34,7 @@ class AppConfig(BaseSettings):
|
|||
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = Field(default="INFO")
|
||||
log_format: Literal["json", "text"] = Field(default="json")
|
||||
|
||||
chunk_size: int = Field(default=2000, ge=500, le=8000, description="Размер чанка для текста")
|
||||
chunk_size: int = Field(default=2000, ge=500, le=122000, description="Размер чанка для текста")
|
||||
chunk_overlap: int = Field(default=200, ge=0, le=1000, description="Перекрытие между чанками")
|
||||
max_retries: int = Field(default=3, ge=1, le=10, description="Максимум попыток повтора")
|
||||
retry_delay: float = Field(
|
||||
|
|
|
@ -65,6 +65,8 @@ class AsyncWriteQueue:
|
|||
await self._queue.put(operation)
|
||||
|
||||
async def update_from_result(self, result: ProcessingResult) -> ArticleDTO:
|
||||
self.logger.info("Получен результат для записи", url=result.url, success=result.success)
|
||||
|
||||
future: asyncio.Future[ArticleDTO] = asyncio.Future()
|
||||
|
||||
operation = WriteOperation(
|
||||
|
@ -73,15 +75,20 @@ class AsyncWriteQueue:
|
|||
)
|
||||
operation.future = future
|
||||
|
||||
self.logger.info("Добавляем операцию в очередь", url=result.url)
|
||||
await self._queue.put(operation)
|
||||
return await future
|
||||
self.logger.info("Операция добавлена в очередь, ожидаем результат", url=result.url)
|
||||
|
||||
result_article = await future
|
||||
self.logger.info("Получен результат из очереди", url=result.url)
|
||||
return result_article
|
||||
|
||||
async def _worker_loop(self) -> None:
|
||||
batch: list[WriteOperation] = []
|
||||
|
||||
while not self._shutdown_event.is_set():
|
||||
batch = await self._collect_batch(batch)
|
||||
if batch and (len(batch) >= self.max_batch_size or self._shutdown_event.is_set()):
|
||||
if batch:
|
||||
await self._process_batch(batch)
|
||||
batch.clear()
|
||||
|
||||
|
@ -90,7 +97,7 @@ class AsyncWriteQueue:
|
|||
|
||||
async def _collect_batch(self, batch: list[WriteOperation]) -> list[WriteOperation]:
|
||||
try:
|
||||
timeout = 0.1 if batch else 1.0
|
||||
timeout = 1.0 if not batch else 0.1
|
||||
operation = await asyncio.wait_for(self._queue.get(), timeout=timeout)
|
||||
batch.append(operation)
|
||||
return batch
|
||||
|
@ -117,12 +124,18 @@ class AsyncWriteQueue:
|
|||
|
||||
async def _process_operation_safely(self, operation: WriteOperation) -> None:
|
||||
try:
|
||||
self.logger.info("Начинаем обработку операции",
|
||||
operation_type=operation.operation_type,
|
||||
url=operation.result.url if operation.result else "N/A")
|
||||
|
||||
await self._process_single_operation(operation)
|
||||
self._total_operations += 1
|
||||
|
||||
if operation.future and not operation.future.done():
|
||||
if operation.operation_type == "update_from_result" and operation.result:
|
||||
self.logger.info("Получаем статью из репозитория", url=operation.result.url)
|
||||
article = await self.repository.get_by_url(operation.result.url)
|
||||
self.logger.info("Статья получена, устанавливаем результат", url=operation.result.url)
|
||||
operation.future.set_result(article)
|
||||
|
||||
except Exception as e:
|
||||
|
@ -145,11 +158,15 @@ class AsyncWriteQueue:
|
|||
raise ValueError(msg)
|
||||
|
||||
async def _update_article_from_result(self, result: ProcessingResult) -> ArticleDTO:
|
||||
self.logger.info("Начинаем обновление статьи из результата", url=result.url)
|
||||
|
||||
article = await self.repository.get_by_url(result.url)
|
||||
if not article:
|
||||
msg = f"Статья с URL {result.url} не найдена"
|
||||
raise ValueError(msg)
|
||||
|
||||
self.logger.info("Статья найдена, обновляем поля", url=result.url, success=result.success)
|
||||
|
||||
if result.success:
|
||||
if not (result.title and result.raw_text and result.simplified_text):
|
||||
msg = "Неполные данные в успешном результате"
|
||||
|
@ -162,7 +179,11 @@ class AsyncWriteQueue:
|
|||
from src.models.article_dto import ArticleStatus
|
||||
article.status = ArticleStatus.FAILED
|
||||
|
||||
return await self.repository.update_article(article)
|
||||
self.logger.info("Сохраняем обновлённую статью", url=result.url)
|
||||
updated_article = await self.repository.update_article(article)
|
||||
self.logger.info("Статья успешно обновлена", url=result.url)
|
||||
|
||||
return updated_article
|
||||
|
||||
@property
|
||||
def queue_size(self) -> int:
|
||||
|
|
Loading…
Reference in New Issue