ruwiki-test/src/services/write_queue.py

183 lines
6.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
import structlog
from src.models import Article, ProcessingResult
from src.models.constants import WRITE_QUEUE_BATCH_SIZE
from src.services.repository import ArticleRepository
@dataclass
class WriteOperation:
    """One pending write request consumed by the ``AsyncWriteQueue`` worker.

    Which payload field is used depends on ``operation_type``: ``"update"``
    operations carry ``article``, ``"update_from_result"`` operations carry
    ``result``.
    """

    # Discriminator selecting how the worker applies this operation.
    operation_type: str
    # Article to persist as-is ("update" operations).
    article: None | Article = field(default=None)
    # Processing outcome to apply to the stored article ("update_from_result").
    result: None | ProcessingResult = field(default=None)
    # Completion handle for a caller awaiting the outcome; not a constructor
    # argument, it is attached to the instance after creation.
    future: None | asyncio.Future[Article] = field(init=False, default=None)
class AsyncWriteQueue:
    """Serialize article writes through a single background worker.

    Operations are put on an ``asyncio.Queue`` and consumed by one worker task
    created in :meth:`start`. The worker groups operations into batches of up
    to ``max_batch_size`` and applies them one at a time through
    ``repository``. A batch is flushed when it is full, when no new operation
    arrives within a short idle window, or when the queue is being stopped.
    """

    def __init__(
        self, repository: ArticleRepository, max_batch_size: int = WRITE_QUEUE_BATCH_SIZE
    ) -> None:
        """Create a stopped queue.

        Args:
            repository: Storage used to read and persist articles.
            max_batch_size: Maximum number of operations collected per batch.
        """
        self.repository = repository
        self.max_batch_size = max_batch_size
        self.logger = structlog.get_logger().bind(service="write_queue")
        self._queue: asyncio.Queue[WriteOperation] = asyncio.Queue()
        self._worker_task: asyncio.Task[None] | None = None
        self._shutdown_event = asyncio.Event()
        # _total_operations counts every processed operation; the ones that
        # raised are additionally counted in _failed_operations.
        self._total_operations = 0
        self._failed_operations = 0

    async def start(self) -> None:
        """Start the background worker.

        Raises:
            RuntimeError: If the worker is already running.
        """
        if self._worker_task is not None:
            msg = "Write queue уже запущена"
            raise RuntimeError(msg)
        self._worker_task = asyncio.create_task(self._worker_loop())
        self.logger.info("Write queue запущена")

    async def stop(self, timeout: float = 10.0) -> None:
        """Stop the worker, letting it flush already-queued operations first.

        Does nothing if the queue is not running. After it returns, the queue
        can be started again with :meth:`start`.

        Args:
            timeout: Seconds to wait for the worker to finish. After that the
                worker is cancelled; callers still awaiting operations it did
                not complete get their futures cancelled.
        """
        if self._worker_task is None:
            return
        self.logger.info("Остановка write queue")
        self._shutdown_event.set()
        try:
            await asyncio.wait_for(self._worker_task, timeout=timeout)
        except asyncio.TimeoutError:
            self.logger.warning("Таймаут остановки write queue, принудительная отмена")
            # Explicit safeguard; cancel() is a no-op if the task is already done.
            self._worker_task.cancel()
        finally:
            # Reset state so that start() works again after stop().
            self._worker_task = None
            self._shutdown_event.clear()
        self.logger.info("Write queue остановлена")

    async def update_article(self, article: Article) -> None:
        """Queue ``article`` to be persisted; returns without waiting for the write."""
        operation = WriteOperation(
            operation_type="update",
            article=article,
        )
        await self._queue.put(operation)

    async def update_from_result(self, result: ProcessingResult) -> Article:
        """Queue ``result`` to be applied to its stored article and wait for it.

        Returns:
            The article returned by the repository after the update.

        Raises:
            ValueError: If no article exists for ``result.url`` or a successful
                result is missing data (re-raised from the worker).
            asyncio.CancelledError: If the worker was cancelled before the
                operation completed.
        """
        future: asyncio.Future[Article] = asyncio.get_running_loop().create_future()
        operation = WriteOperation(
            operation_type="update_from_result",
            result=result,
        )
        operation.future = future
        await self._queue.put(operation)
        return await future

    async def _worker_loop(self) -> None:
        """Consume the queue in batches until shutdown, then flush what is left."""
        batch: list[WriteOperation] = []
        try:
            while not self._shutdown_event.is_set():
                size_before = len(batch)
                batch = await self._collect_batch(batch)
                # Nothing arrived within the collection timeout: the queue is
                # idle, so flush the partial batch instead of holding it (and
                # the callers awaiting its futures) until it fills up.
                queue_idle = len(batch) == size_before
                if batch and (
                    queue_idle
                    or len(batch) >= self.max_batch_size
                    or self._shutdown_event.is_set()
                ):
                    await self._process_batch(batch)
                    batch = []
            # Flush the partial batch plus anything still queued (including
            # operations enqueued while flushing) so no future is left pending.
            batch.extend(self._drain_queue())
            while batch:
                await self._process_batch(batch)
                batch = self._drain_queue()
        finally:
            # Non-empty only if the loop was cancelled or crashed: release
            # callers awaiting operations that will never be processed.
            for operation in (*batch, *self._drain_queue()):
                if operation.future and not operation.future.done():
                    operation.future.cancel()

    def _drain_queue(self) -> list[WriteOperation]:
        """Remove and return every operation currently in the queue without waiting."""
        drained: list[WriteOperation] = []
        while True:
            try:
                drained.append(self._queue.get_nowait())
            except asyncio.QueueEmpty:
                return drained

    async def _collect_batch(self, batch: list[WriteOperation]) -> list[WriteOperation]:
        """Wait for one more operation and append it to ``batch``.

        Waits up to 0.1 s when ``batch`` already has items and up to 1 s when
        it is empty. Returns ``batch`` (unchanged on timeout). On an unexpected
        error, the futures in ``batch`` receive the error and a new empty list
        is returned.
        """
        try:
            timeout = 0.1 if batch else 1.0
            operation = await asyncio.wait_for(self._queue.get(), timeout=timeout)
            batch.append(operation)
            return batch
        except asyncio.TimeoutError:
            return batch
        except Exception as e:
            self.logger.exception("Ошибка в worker loop")
            self._handle_batch_error(batch, e)
            return []

    def _handle_batch_error(self, batch: list[WriteOperation], error: Exception) -> None:
        """Set ``error`` on every still-pending future in ``batch``."""
        for op in batch:
            if op.future and not op.future.done():
                op.future.set_exception(error)

    async def _process_batch(self, batch: list[WriteOperation]) -> None:
        """Apply the operations in ``batch`` sequentially; no-op for an empty batch."""
        if not batch:
            return
        self.logger.debug("Обработка батча операций", batch_size=len(batch))
        for operation in batch:
            await self._process_operation_safely(operation)

    async def _process_operation_safely(self, operation: WriteOperation) -> None:
        """Apply one operation, update counters and resolve its future.

        Exceptions are logged and forwarded to the operation's future (if any)
        instead of propagating, so one failing write does not stop the worker.
        """
        try:
            article = await self._process_single_operation(operation)
        except Exception as e:
            self._total_operations += 1
            self._failed_operations += 1
            self.logger.exception(
                "Ошибка при обработке операции",
                operation_type=operation.operation_type,
            )
            if operation.future and not operation.future.done():
                operation.future.set_exception(e)
        else:
            self._total_operations += 1
            # Resolve with the article the write returned; no second lookup.
            if operation.future and not operation.future.done():
                operation.future.set_result(article)

    async def _process_single_operation(self, operation: WriteOperation) -> Article:
        """Dispatch ``operation`` to the matching repository write.

        Returns:
            The article returned by the repository after the write.

        Raises:
            ValueError: If the operation type is unknown or its payload is missing.
        """
        if operation.operation_type == "update" and operation.article:
            return await self.repository.update_article(operation.article)
        if operation.operation_type == "update_from_result" and operation.result:
            return await self._update_article_from_result(operation.result)
        msg = f"Неизвестный тип операции: {operation.operation_type}"
        raise ValueError(msg)

    async def _update_article_from_result(self, result: ProcessingResult) -> Article:
        """Mark the stored article for ``result.url`` completed or failed and save it.

        Raises:
            ValueError: If no article exists for ``result.url``, or if a
                successful result lacks title, raw text or simplified text.
        """
        article = await self.repository.get_by_url(result.url)
        if not article:
            msg = f"Статья с URL {result.url} не найдена"
            raise ValueError(msg)
        if result.success:
            if not (result.title and result.raw_text and result.simplified_text):
                msg = "Неполные данные в успешном результате"
                raise ValueError(msg)
            article.mark_completed(
                simplified_text=result.simplified_text,
                token_count_raw=result.token_count_raw or 0,
                token_count_simplified=result.token_count_simplified or 0,
                processing_time=result.processing_time_seconds or 0,
            )
        else:
            article.mark_failed(result.error_message or "Неизвестная ошибка")
        return await self.repository.update_article(article)

    @property
    def queue_size(self) -> int:
        """Number of operations currently waiting in the queue (not yet collected)."""
        return self._queue.qsize()

    @property
    def stats(self) -> dict[str, float]:
        """Processing counters.

        ``total_operations`` counts every processed operation,
        ``failed_operations`` the ones that raised, and ``success_rate`` is the
        percentage that succeeded (0 before anything has been processed).
        """
        total = self._total_operations
        return {
            "total_operations": total,
            "failed_operations": self._failed_operations,
            "success_rate": (
                (total - self._failed_operations) / total * 100
                if total > 0
                else 0
            ),
        }