Add proxy support for OpenAI

parent c5306bb56e
commit bf830fd330
@@ -1,6 +1,7 @@
 OPENAI_API_KEY=your_openai_api_key_here
-OPENAI_MODEL=gpt-4o-mini
+OPENAI_MODEL=gpt-4o
 OPENAI_TEMPERATURE=0.0
+OPENAI_PROXY_URL='socks5h://37.18.73.60:5566' # socks5 recommended
 
 DB_PATH=./data/wiki.db
 
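Note on the new variable (this note and the snippet below are not part of the commit): the socks5h:// scheme, supported by recent httpx/httpcore releases, resolves hostnames on the proxy itself rather than locally, which matters when api.openai.com is unreachable or DNS-filtered from the host network. httpx only speaks SOCKS with the optional extra installed (pip install 'httpx[socks]'). A minimal connectivity check, using a placeholder proxy URL:

import asyncio

import httpx

async def check_proxy(proxy_url: str) -> None:
    # Any HTTP status proves the tunnel works; 401 is expected without a key.
    async with httpx.AsyncClient(proxy=proxy_url, timeout=30.0) as client:
        response = await client.get("https://api.openai.com/v1/models")
        print(response.status_code)

asyncio.run(check_proxy("socks5h://user:pass@proxy.example:1080"))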
@@ -1,12 +1,15 @@
 import asyncio
 import time
 
+import httpx
 import openai
 import structlog
 import tiktoken
 from openai import AsyncOpenAI
 from openai.types.chat import ChatCompletion
 
+from src.models.constants import LLM_MAX_INPUT_TOKENS, MAX_TOKEN_LIMIT_WITH_BUFFER
+
 from ..models import AppConfig
 from .base import BaseAdapter, CircuitBreaker, RateLimiter, with_retry
 
@@ -31,7 +34,10 @@ class LLMProviderAdapter(BaseAdapter):
         super().__init__("llm_adapter")
         self.config = config
 
-        self.client = AsyncOpenAI(api_key=config.openai_api_key)
+        self.client = AsyncOpenAI(
+            api_key=config.openai_api_key,
+            http_client=self._build_http_client()
+        )
 
         try:
             self.tokenizer = tiktoken.encoding_for_model(config.openai_model)
@@ -87,7 +93,7 @@ class LLMProviderAdapter(BaseAdapter):
                 model=self.config.openai_model,
                 messages=messages,
                 temperature=self.config.openai_temperature,
-                max_tokens=1500,
+                max_tokens=MAX_TOKEN_LIMIT_WITH_BUFFER,
             )
             return response
         except openai.RateLimitError as e:
@@ -102,8 +108,8 @@ class LLMProviderAdapter(BaseAdapter):
         prompt_template: str,
     ) -> tuple[str, int, int]:
         input_tokens = self.count_tokens(wiki_text)
-        if input_tokens > 6000:
-            raise LLMTokenLimitError(f"Текст слишком длинный: {input_tokens} токенов (лимит 6000)")
+        if input_tokens > LLM_MAX_INPUT_TOKENS:
+            raise LLMTokenLimitError(f"Текст слишком длинный: {input_tokens} токенов")
 
         try:
             prompt_text = prompt_template.format(
@@ -142,7 +148,7 @@ class LLMProviderAdapter(BaseAdapter):
 
         output_tokens = self.count_tokens(simplified_text)
 
-        if output_tokens > 1200:
+        if output_tokens > MAX_TOKEN_LIMIT_WITH_BUFFER:
             self.logger.warning(
                 "Упрощённый текст превышает лимит",
                 output_tokens=output_tokens,
@@ -179,6 +185,16 @@ class LLMProviderAdapter(BaseAdapter):
 
         return messages
 
+    def _build_http_client(self) -> httpx.AsyncClient:
+        if self.config.openai_proxy_url:
+            return httpx.AsyncClient(
+                proxy=self.config.openai_proxy_url,
+                timeout=60.0
+            )
+        return httpx.AsyncClient(timeout=60.0)
+
+
+
     async def health_check(self) -> bool:
         try:
             test_messages = [{"role": "user", "content": "Ответь 'OK' если всё работает."}]
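Taken together, the adapter changes route all OpenAI traffic through an optional proxy via the client's http_client hook. A condensed, self-contained sketch of the wiring (DemoConfig is an illustrative stand-in for the project's AppConfig, not its real definition):

import httpx
from openai import AsyncOpenAI

class DemoConfig:
    # Stand-in for AppConfig; only the two fields used below.
    openai_api_key = "sk-placeholder"
    openai_proxy_url: str | None = "socks5h://127.0.0.1:1080"

def build_http_client(config: DemoConfig) -> httpx.AsyncClient:
    # Same logic as the new _build_http_client: route through the proxy
    # when configured, direct connection otherwise; 60 s timeout either way.
    if config.openai_proxy_url:
        return httpx.AsyncClient(proxy=config.openai_proxy_url, timeout=60.0)
    return httpx.AsyncClient(timeout=60.0)

config = DemoConfig()
client = AsyncOpenAI(api_key=config.openai_api_key, http_client=build_http_client(config))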
@@ -19,6 +19,7 @@ class AppConfig(BaseSettings):
     openai_temperature: float = Field(
         default=0.0, ge=0.0, le=2.0, description="Температура для LLM"
     )
+    openai_proxy_url: str | None = Field(description="Proxy URL для OpenAI")
 
     db_path: str = Field(default="./data/wiki.db", description="Путь к файлу SQLite")
 
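A side note, not part of the commit: in pydantic v2 a field annotated str | None is still required unless it has a default, so this declaration makes AppConfig fail validation when OPENAI_PROXY_URL is unset, even though _build_http_client treats the proxy as optional. If the proxy is meant to be opt-in, the field would presumably need:

    openai_proxy_url: str | None = Field(default=None, description="Proxy URL для OpenAI")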
@@ -33,7 +34,7 @@ class AppConfig(BaseSettings):
     log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = Field(default="INFO")
     log_format: Literal["json", "text"] = Field(default="json")
 
-    chunk_size: int = Field(default=2000, ge=500, le=8000, description="Размер чанка для текста")
+    chunk_size: int = Field(default=2000, ge=500, le=122000, description="Размер чанка для текста")
     chunk_overlap: int = Field(default=200, ge=0, le=1000, description="Перекрытие между чанками")
     max_retries: int = Field(default=3, ge=1, le=10, description="Максимум попыток повтора")
     retry_delay: float = Field(
@@ -65,6 +65,8 @@ class AsyncWriteQueue:
         await self._queue.put(operation)
 
     async def update_from_result(self, result: ProcessingResult) -> ArticleDTO:
+        self.logger.info("Получен результат для записи", url=result.url, success=result.success)
+
         future: asyncio.Future[ArticleDTO] = asyncio.Future()
 
         operation = WriteOperation(
@@ -73,15 +75,20 @@ class AsyncWriteQueue:
         )
         operation.future = future
 
+        self.logger.info("Добавляем операцию в очередь", url=result.url)
         await self._queue.put(operation)
-        return await future
+        self.logger.info("Операция добавлена в очередь, ожидаем результат", url=result.url)
+
+        result_article = await future
+        self.logger.info("Получен результат из очереди", url=result.url)
+        return result_article
 
     async def _worker_loop(self) -> None:
         batch: list[WriteOperation] = []
 
         while not self._shutdown_event.is_set():
             batch = await self._collect_batch(batch)
-            if batch and (len(batch) >= self.max_batch_size or self._shutdown_event.is_set()):
+            if batch:
                 await self._process_batch(batch)
                 batch.clear()
 
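The logging added here brackets a request/response round trip built on asyncio.Future: the caller enqueues an operation carrying a future and suspends on it until the worker resolves it. A stripped-down sketch of that pattern, independent of the project's types:

import asyncio

async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()

    async def worker() -> None:
        payload, future = await queue.get()
        # ... the actual write would happen here ...
        future.set_result(f"written: {payload}")  # wakes the waiting caller

    asyncio.create_task(worker())

    future = asyncio.get_running_loop().create_future()
    await queue.put(("article", future))
    print(await future)  # suspends until worker() calls set_result

asyncio.run(main())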
@@ -90,7 +97,7 @@ class AsyncWriteQueue:
 
     async def _collect_batch(self, batch: list[WriteOperation]) -> list[WriteOperation]:
         try:
-            timeout = 0.1 if batch else 1.0
+            timeout = 1.0 if not batch else 0.1
             operation = await asyncio.wait_for(self._queue.get(), timeout=timeout)
             batch.append(operation)
             return batch
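The rewritten condition is behavior-identical to the old expression, just spelled from the idle case first: block up to 1 s waiting for the first operation, then poll in 0.1 s steps while a batch is open. Note that together with the looser `if batch:` check in _worker_loop above, each collected operation is now flushed on the next loop iteration instead of accumulating to max_batch_size. A rough paraphrase of the collection step:

import asyncio

async def collect_batch(queue: asyncio.Queue, batch: list) -> list:
    # Idle: wait up to 1 s for work; batch open: short 0.1 s poll so the
    # worker loop gets a chance to flush promptly.
    timeout = 1.0 if not batch else 0.1
    try:
        batch.append(await asyncio.wait_for(queue.get(), timeout=timeout))
    except asyncio.TimeoutError:
        pass
    return batch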
@@ -117,12 +124,18 @@ class AsyncWriteQueue:
 
     async def _process_operation_safely(self, operation: WriteOperation) -> None:
         try:
+            self.logger.info("Начинаем обработку операции",
+                           operation_type=operation.operation_type,
+                           url=operation.result.url if operation.result else "N/A")
+
             await self._process_single_operation(operation)
             self._total_operations += 1
 
             if operation.future and not operation.future.done():
                 if operation.operation_type == "update_from_result" and operation.result:
+                    self.logger.info("Получаем статью из репозитория", url=operation.result.url)
                     article = await self.repository.get_by_url(operation.result.url)
+                    self.logger.info("Статья получена, устанавливаем результат", url=operation.result.url)
                     operation.future.set_result(article)
 
         except Exception as e:
@@ -145,11 +158,15 @@ class AsyncWriteQueue:
             raise ValueError(msg)
 
     async def _update_article_from_result(self, result: ProcessingResult) -> ArticleDTO:
+        self.logger.info("Начинаем обновление статьи из результата", url=result.url)
+
         article = await self.repository.get_by_url(result.url)
         if not article:
             msg = f"Статья с URL {result.url} не найдена"
             raise ValueError(msg)
 
+        self.logger.info("Статья найдена, обновляем поля", url=result.url, success=result.success)
+
         if result.success:
             if not (result.title and result.raw_text and result.simplified_text):
                 msg = "Неполные данные в успешном результате"
@@ -162,7 +179,11 @@ class AsyncWriteQueue:
             from src.models.article_dto import ArticleStatus
             article.status = ArticleStatus.FAILED
 
-        return await self.repository.update_article(article)
+        self.logger.info("Сохраняем обновлённую статью", url=result.url)
+        updated_article = await self.repository.update_article(article)
+        self.logger.info("Статья успешно обновлена", url=result.url)
+
+        return updated_article
 
     @property
     def queue_size(self) -> int: