Add proxy support for OpenAI
parent c5306bb56e
commit bf830fd330
@@ -1,6 +1,7 @@
 OPENAI_API_KEY=your_openai_api_key_here
-OPENAI_MODEL=gpt-4o-mini
+OPENAI_MODEL=gpt-4o
 OPENAI_TEMPERATURE=0.0
+OPENAI_PROXY_URL='socks5h://37.18.73.60:5566' # socks5 recommended
 
 DB_PATH=./data/wiki.db
 
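The new OPENAI_PROXY_URL entry routes OpenAI traffic through a SOCKS proxy; the socks5h scheme resolves hostnames on the proxy side rather than locally, which matters when the API host is blocked by local DNS. As a quick way to verify such a proxy before wiring it into the adapter, a small standalone check along these lines can help (hypothetical script, not part of this commit; it assumes httpx is installed with the SOCKS extra, `pip install "httpx[socks]"`, and uses the same proxy= keyword as the adapter code below, which requires a reasonably recent httpx):

```python
import asyncio
import os

import httpx


async def check_proxy() -> None:
    # Read the same variables the application uses; OPENAI_PROXY_URL may be unset,
    # in which case httpx simply connects directly.
    proxy_url = os.environ.get("OPENAI_PROXY_URL")
    api_key = os.environ["OPENAI_API_KEY"]

    async with httpx.AsyncClient(proxy=proxy_url, timeout=60.0) as client:
        # Listing models is a cheap authenticated call against api.openai.com.
        response = await client.get(
            "https://api.openai.com/v1/models",
            headers={"Authorization": f"Bearer {api_key}"},
        )
        print(response.status_code)


if __name__ == "__main__":
    asyncio.run(check_proxy())
```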
@@ -1,12 +1,15 @@
 import asyncio
 import time
 
+import httpx
 import openai
 import structlog
 import tiktoken
 from openai import AsyncOpenAI
 from openai.types.chat import ChatCompletion
 
+from src.models.constants import LLM_MAX_INPUT_TOKENS, MAX_TOKEN_LIMIT_WITH_BUFFER
+
 from ..models import AppConfig
 from .base import BaseAdapter, CircuitBreaker, RateLimiter, with_retry
 
@@ -31,7 +34,10 @@ class LLMProviderAdapter(BaseAdapter):
         super().__init__("llm_adapter")
         self.config = config
 
-        self.client = AsyncOpenAI(api_key=config.openai_api_key)
+        self.client = AsyncOpenAI(
+            api_key=config.openai_api_key,
+            http_client=self._build_http_client()
+        )
 
         try:
             self.tokenizer = tiktoken.encoding_for_model(config.openai_model)
@@ -87,7 +93,7 @@ class LLMProviderAdapter(BaseAdapter):
                 model=self.config.openai_model,
                 messages=messages,
                 temperature=self.config.openai_temperature,
-                max_tokens=1500,
+                max_tokens=MAX_TOKEN_LIMIT_WITH_BUFFER,
             )
             return response
         except openai.RateLimitError as e:
@@ -102,8 +108,8 @@ class LLMProviderAdapter(BaseAdapter):
         prompt_template: str,
     ) -> tuple[str, int, int]:
         input_tokens = self.count_tokens(wiki_text)
-        if input_tokens > 6000:
-            raise LLMTokenLimitError(f"Текст слишком длинный: {input_tokens} токенов (лимит 6000)")
+        if input_tokens > LLM_MAX_INPUT_TOKENS:
+            raise LLMTokenLimitError(f"Текст слишком длинный: {input_tokens} токенов")
 
         try:
             prompt_text = prompt_template.format(
@@ -142,7 +148,7 @@ class LLMProviderAdapter(BaseAdapter):
 
         output_tokens = self.count_tokens(simplified_text)
 
-        if output_tokens > 1200:
+        if output_tokens > MAX_TOKEN_LIMIT_WITH_BUFFER:
             self.logger.warning(
                 "Упрощённый текст превышает лимит",
                 output_tokens=output_tokens,
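The three hunks above replace hard-coded limits (max_tokens=1500, the 6000-token input cap, and the 1200-token output warning threshold) with names imported from src.models.constants. That module is not part of this diff; a plausible sketch of it, where only the constant names are confirmed by the commit and the values are inferred from the literals they replace, looks like this:

```python
# src/models/constants.py: hypothetical reconstruction, values assumed.

# Hard cap on input size for simplify_text (previously the literal 6000).
LLM_MAX_INPUT_TOKENS = 6000

# Output budget: passed as max_tokens and reused as the "too long" warning
# threshold. The old code used 1500 for the first and 1200 for the second,
# so a single constant necessarily shifts one of those two behaviours.
MAX_TOKEN_LIMIT_WITH_BUFFER = 1500
```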
@@ -179,6 +185,16 @@ class LLMProviderAdapter(BaseAdapter):
 
         return messages
 
+
+    def _build_http_client(self) -> httpx.AsyncClient:
+        if self.config.openai_proxy_url:
+            return httpx.AsyncClient(
+                proxy=self.config.openai_proxy_url,
+                timeout=60.0
+            )
+        return httpx.AsyncClient(timeout=60.0)
+
+
     async def health_check(self) -> bool:
         try:
             test_messages = [{"role": "user", "content": "Ответь 'OK' если всё работает."}]
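Together with the constructor change above, _build_http_client is what actually pushes requests through the proxy: the SDK accepts a pre-configured httpx.AsyncClient via its http_client parameter, so every chat completion inherits the proxy and the 60-second timeout. A condensed, self-contained sketch of the same wiring (illustrative only; it mirrors the adapter code rather than defining it):

```python
import httpx
from openai import AsyncOpenAI


def make_openai_client(api_key: str, proxy_url: str | None) -> AsyncOpenAI:
    # With a proxy URL configured, tunnel all traffic through it; otherwise use
    # a plain client so behaviour without a proxy only gains the explicit timeout.
    if proxy_url:
        http_client = httpx.AsyncClient(proxy=proxy_url, timeout=60.0)
    else:
        http_client = httpx.AsyncClient(timeout=60.0)
    return AsyncOpenAI(api_key=api_key, http_client=http_client)
```

One side note: the httpx client built this way is never explicitly closed; that is usually acceptable for a process-lifetime singleton, but a shutdown hook could call `await http_client.aclose()` if clean teardown matters.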

@@ -19,6 +19,7 @@ class AppConfig(BaseSettings):
     openai_temperature: float = Field(
         default=0.0, ge=0.0, le=2.0, description="Температура для LLM"
     )
+    openai_proxy_url: str | None = Field(description="Proxy URL для OpenAI")
 
     db_path: str = Field(default="./data/wiki.db", description="Путь к файлу SQLite")
 
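One detail worth flagging in this hunk: if the project is on Pydantic v2, annotating a field as str | None does not make it optional on its own, so openai_proxy_url as declared is a required setting and configuration loading would fail whenever the environment variable is missing (on Pydantic v1 the Optional annotation alone already implied a None default). If the intended behaviour is "no proxy unless configured", the declaration would need an explicit default, roughly:

```python
openai_proxy_url: str | None = Field(default=None, description="Proxy URL для OpenAI")
```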
@@ -33,7 +34,7 @@ class AppConfig(BaseSettings):
     log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = Field(default="INFO")
     log_format: Literal["json", "text"] = Field(default="json")
 
-    chunk_size: int = Field(default=2000, ge=500, le=8000, description="Размер чанка для текста")
+    chunk_size: int = Field(default=2000, ge=500, le=122000, description="Размер чанка для текста")
     chunk_overlap: int = Field(default=200, ge=0, le=1000, description="Перекрытие между чанками")
     max_retries: int = Field(default=3, ge=1, le=10, description="Максимум попыток повтора")
     retry_delay: float = Field(

@@ -65,6 +65,8 @@ class AsyncWriteQueue:
         await self._queue.put(operation)
 
     async def update_from_result(self, result: ProcessingResult) -> ArticleDTO:
+        self.logger.info("Получен результат для записи", url=result.url, success=result.success)
+
         future: asyncio.Future[ArticleDTO] = asyncio.Future()
 
         operation = WriteOperation(
@@ -73,15 +75,20 @@ class AsyncWriteQueue:
         )
         operation.future = future
 
+        self.logger.info("Добавляем операцию в очередь", url=result.url)
         await self._queue.put(operation)
-        return await future
+        self.logger.info("Операция добавлена в очередь, ожидаем результат", url=result.url)
+
+        result_article = await future
+        self.logger.info("Получен результат из очереди", url=result.url)
+        return result_article
 
     async def _worker_loop(self) -> None:
         batch: list[WriteOperation] = []
 
         while not self._shutdown_event.is_set():
             batch = await self._collect_batch(batch)
-            if batch and (len(batch) >= self.max_batch_size or self._shutdown_event.is_set()):
+            if batch:
                 await self._process_batch(batch)
                 batch.clear()
 
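The reworked update_from_result makes the handshake with the background worker explicit: the caller creates an asyncio.Future, attaches it to the queued WriteOperation, awaits it, and logs each stage, while the worker resolves the future once the write lands (see _process_operation_safely further down). A stripped-down sketch of that producer/worker future handoff, independent of the project's classes:

```python
import asyncio
from dataclasses import dataclass


@dataclass
class Operation:
    payload: str
    future: asyncio.Future


async def worker(queue: asyncio.Queue) -> None:
    while True:
        op = await queue.get()
        result = op.payload.upper()      # stand-in for the real repository write
        op.future.set_result(result)     # hand the result back to the waiting producer
        queue.task_done()


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    asyncio.create_task(worker(queue))

    future: asyncio.Future = asyncio.get_running_loop().create_future()
    await queue.put(Operation(payload="hello", future=future))
    print(await future)                  # -> HELLO


asyncio.run(main())
```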
@@ -90,7 +97,7 @@ class AsyncWriteQueue:
 
     async def _collect_batch(self, batch: list[WriteOperation]) -> list[WriteOperation]:
         try:
-            timeout = 0.1 if batch else 1.0
+            timeout = 1.0 if not batch else 0.1
             operation = await asyncio.wait_for(self._queue.get(), timeout=timeout)
             batch.append(operation)
             return batch
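The timeout rewrite in this hunk is behaviour-preserving (both expressions wait 1.0 s while the batch is empty and 0.1 s once it holds something); the behavioural change is the `if batch:` in the worker loop above, which now flushes whatever has accumulated on every iteration instead of waiting for max_batch_size. A minimal, self-contained illustration of that collect-then-flush rhythm (demo code only, not from the repository):

```python
import asyncio


async def collect_batch(queue: asyncio.Queue, batch: list) -> list:
    # Wait up to 1 s for the first item; once the batch is non-empty, give
    # stragglers only 0.1 s before the caller flushes.
    timeout = 1.0 if not batch else 0.1
    try:
        batch.append(await asyncio.wait_for(queue.get(), timeout=timeout))
    except asyncio.TimeoutError:
        pass
    return batch


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    for item in ("a", "b", "c"):
        queue.put_nowait(item)

    batch: list = []
    for _ in range(4):                   # a few worker-loop iterations
        batch = await collect_batch(queue, batch)
        if batch:                        # flush as soon as anything is collected
            print("flushing", batch)
            batch.clear()


asyncio.run(main())
```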
@@ -117,12 +124,18 @@ class AsyncWriteQueue:
 
     async def _process_operation_safely(self, operation: WriteOperation) -> None:
         try:
+            self.logger.info("Начинаем обработку операции",
+                             operation_type=operation.operation_type,
+                             url=operation.result.url if operation.result else "N/A")
+
             await self._process_single_operation(operation)
             self._total_operations += 1
 
             if operation.future and not operation.future.done():
                 if operation.operation_type == "update_from_result" and operation.result:
+                    self.logger.info("Получаем статью из репозитория", url=operation.result.url)
                     article = await self.repository.get_by_url(operation.result.url)
+                    self.logger.info("Статья получена, устанавливаем результат", url=operation.result.url)
                     operation.future.set_result(article)
 
         except Exception as e:
@@ -145,11 +158,15 @@ class AsyncWriteQueue:
             raise ValueError(msg)
 
     async def _update_article_from_result(self, result: ProcessingResult) -> ArticleDTO:
+        self.logger.info("Начинаем обновление статьи из результата", url=result.url)
+
         article = await self.repository.get_by_url(result.url)
         if not article:
             msg = f"Статья с URL {result.url} не найдена"
             raise ValueError(msg)
 
+        self.logger.info("Статья найдена, обновляем поля", url=result.url, success=result.success)
+
         if result.success:
             if not (result.title and result.raw_text and result.simplified_text):
                 msg = "Неполные данные в успешном результате"
@@ -162,7 +179,11 @@ class AsyncWriteQueue:
             from src.models.article_dto import ArticleStatus
             article.status = ArticleStatus.FAILED
 
-        return await self.repository.update_article(article)
+        self.logger.info("Сохраняем обновлённую статью", url=result.url)
+        updated_article = await self.repository.update_article(article)
+        self.logger.info("Статья успешно обновлена", url=result.url)
+
+        return updated_article
 
     @property
     def queue_size(self) -> int: