commit c41a2907b879f258ecdaabaa1a05e6783977405e Author: itqop Date: Fri Jul 11 22:28:58 2025 +0300 First commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..8e1c352 --- /dev/null +++ b/.gitignore @@ -0,0 +1,114 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.pytest_cache/ + +# Environments +.env +.venv +env/ +venv/ +ENV/ + +# IDE files +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# OS files +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db + +# Project specific - RuWiki SchoolNotes +# Database files +data/ +*.db +*.sqlite +*.sqlite3 +*.db-journal +*.sqlite-journal +*.sqlite3-journal + +# Configuration files with secrets +.env.local +.env.production +config.local.toml +secrets.json + +# Logs and monitoring +logs/ +*.log +*.log.* +monitoring/ +metrics/ + +# Temporary processing files +temp/ +tmp/ +cache/ +.cache/ +processing_temp/ + +# Input files with sensitive URLs (keep example only) +input_production.txt +input_large.txt +urls_private.txt + +# Test artifacts specific to our project +test_outputs/ +test_db/ +mock_data/ +test_logs/ + +# mypy +.mypy_cache/ +. +# ruff +.ruff_cache/ + +# Performance profiling +*.prof +profile_results/ + +# API response caches +api_cache/ +wikipedia_cache/ +openai_cache/ \ No newline at end of file diff --git a/env_example.txt b/env_example.txt new file mode 100644 index 0000000..a7fe4c4 --- /dev/null +++ b/env_example.txt @@ -0,0 +1,20 @@ +OPENAI_API_KEY=your_openai_api_key_here +OPENAI_MODEL=gpt-4o-mini +OPENAI_TEMPERATURE=0.0 + +DB_PATH=./data/wiki.db + +MAX_CONCURRENT_LLM=5 +OPENAI_RPM=200 +MAX_CONCURRENT_WIKI=10 + +LOG_LEVEL=INFO +LOG_FORMAT=json + +CHUNK_SIZE=2000 +CHUNK_OVERLAP=200 +MAX_RETRIES=3 +RETRY_DELAY=1.0 + +CIRCUIT_FAILURE_THRESHOLD=5 +CIRCUIT_RECOVERY_TIMEOUT=60 \ No newline at end of file diff --git a/input.txt b/input.txt new file mode 100644 index 0000000..b3bf97c --- /dev/null +++ b/input.txt @@ -0,0 +1,3 @@ +https://ru.ruwiki.ru/wiki/Изотопы +https://ru.ruwiki.ru/wiki/Вещественное_число +https://ru.ruwiki.ru/wiki/Митоз \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..570e3ba --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,83 @@ +[tool.poetry] +name = "ruwiki-schoolnotes" +version = "1.0.0" +description = "Конвейер для упрощения статей RuWiki с помощью LLM" +authors = ["Leon K. 
"] +readme = "README.md" +packages = [{include = "src"}] + +[tool.poetry.dependencies] +python = "^3.10" +anyio = "^4.2.0" +aiohttp = "^3.9.0" +aiosqlite = "^0.19.0" +sqlmodel = "^0.0.14" +openai = "^1.13.0" +tiktoken = "^0.5.2" +mwclient = "^0.10.1" +pydantic = "^2.5.0" +pydantic-settings = "^2.1.0" +structlog = "^23.2.0" +tenacity = "^8.2.3" +click = "^8.1.7" + +[tool.poetry.group.dev.dependencies] +black = "^23.12.0" +ruff = "^0.1.8" +mypy = "^1.8.0" +pytest = "^7.4.3" +pytest-asyncio = "^0.21.1" +pytest-cov = "^4.1.0" +pytest-vcr = "^1.0.2" +bandit = "^1.7.5" +pip-audit = "^2.6.2" + +[tool.poetry.scripts] +schoolnotes = "src.cli:main" + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" + +[tool.black] +line-length = 100 +target-version = ['py310'] + +[tool.ruff] +target-version = "py310" +line-length = 100 +select = ["ALL"] +ignore = [ + "D100", "D101", "D102", "D103", "D104", "D105", "D106", "D107", # missing docstrings + "ANN101", "ANN102", # missing type annotation for self/cls + "COM812", "ISC001", # incompatible with black +] + +[tool.ruff.per-file-ignores] +"tests/*" = ["S101", "PLR2004", "ANN"] # allow assert, magic values, no annotations + +[tool.mypy] +python_version = "3.10" +strict = true +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = true +disallow_incomplete_defs = true +check_untyped_defs = true +disallow_untyped_decorators = true + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +addopts = "--cov=src --cov-report=html --cov-report=term-missing --cov-fail-under=80" + +[tool.coverage.run] +source = ["src"] +omit = ["tests/*"] + +[tool.bandit] +exclude_dirs = ["tests"] +skips = ["B101"] \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..0b6f9ff --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,12 @@ +-r requirements.txt + +black>=25.1.0 +ruff>=0.12.0 + +mypy>=1.16.0 + +pytest>=8.4.0 +pytest-asyncio>=1.0.0 +pytest-cov>=6.2.0 + +bandit>=1.8.0 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..01e9b89 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,18 @@ +anyio>=4.2.0,<5.0.0 +aiohttp>=3.9.0,<4.0.0 + +aiosqlite>=0.19.0,<0.20.0 +sqlmodel>=0.0.14,<0.0.15 + +openai>=1.13.0,<2.0.0 +tiktoken>=0.5.2,<0.6.0 + +mwclient>=0.10.1,<0.11.0 + +pydantic>=2.5.0,<3.0.0 +pydantic-settings>=2.1.0,<3.0.0 + +structlog>=23.2.0,<24.0.0 +tenacity>=8.2.3,<9.0.0 + +click>=8.1.7,<9.0.0 \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..5becc17 --- /dev/null +++ b/src/__init__.py @@ -0,0 +1 @@ +__version__ = "1.0.0" diff --git a/src/adapters/__init__.py b/src/adapters/__init__.py new file mode 100644 index 0000000..095d5b7 --- /dev/null +++ b/src/adapters/__init__.py @@ -0,0 +1,18 @@ +from .base import BaseAdapter, CircuitBreaker, CircuitBreakerError, RateLimiter +from .llm import LLMError, LLMProviderAdapter, LLMRateLimitError, LLMTokenLimitError +from .ruwiki import RuWikiAdapter, WikiPageInfo, WikiPageNotFoundError, WikiPageRedirectError + +__all__ = [ + "BaseAdapter", + "CircuitBreaker", + "CircuitBreakerError", + "LLMError", + "LLMProviderAdapter", + "LLMRateLimitError", + "LLMTokenLimitError", + "RateLimiter", + "RuWikiAdapter", + "WikiPageInfo", + "WikiPageNotFoundError", + "WikiPageRedirectError", +] diff --git a/src/adapters/base.py b/src/adapters/base.py 
new file mode 100644 index 0000000..f950027 --- /dev/null +++ b/src/adapters/base.py @@ -0,0 +1,128 @@ +import asyncio +import time +from abc import ABC, abstractmethod +from collections.abc import Awaitable, Callable +from typing import Any, TypeVar + +import structlog +from tenacity import ( + AsyncRetrying, + before_sleep_log, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +logger = structlog.get_logger() + +T = TypeVar("T") + + +class CircuitBreakerError(Exception): + pass + + +class CircuitBreaker: + + def __init__( + self, + failure_threshold: int = 5, + recovery_timeout: int = 60, + name: str = "circuit_breaker", + ) -> None: + self.failure_threshold = failure_threshold + self.recovery_timeout = recovery_timeout + self.name = name + + self._failure_count = 0 + self._last_failure_time: float | None = None + self._state: str = "closed" + + async def call(self, func: Callable[[], Awaitable[T]]) -> T: + if self._state == "open": + if self._should_attempt_reset(): + self._state = "half_open" + logger.info("Circuit breaker перешёл в half_open", name=self.name) + else: + raise CircuitBreakerError(f"Circuit breaker {self.name} открыт") + + try: + result = await func() + self._on_success() + return result + except Exception as e: + self._on_failure() + raise e + + def _should_attempt_reset(self) -> bool: + if self._last_failure_time is None: + return True + return time.time() - self._last_failure_time >= self.recovery_timeout + + def _on_success(self) -> None: + if self._state == "half_open": + self._state = "closed" + logger.info("Circuit breaker восстановлен", name=self.name) + self._failure_count = 0 + + def _on_failure(self) -> None: + self._failure_count += 1 + self._last_failure_time = time.time() + + if self._failure_count >= self.failure_threshold: + self._state = "open" + logger.warning( + "Circuit breaker открыт из-за превышения порога ошибок", + name=self.name, + failure_count=self._failure_count, + threshold=self.failure_threshold, + ) + + +class RateLimiter: + + def __init__(self, max_concurrent: int, name: str = "rate_limiter") -> None: + self.semaphore = asyncio.Semaphore(max_concurrent) + self.name = name + self.max_concurrent = max_concurrent + + async def __aenter__(self) -> None: + await self.semaphore.acquire() + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.semaphore.release() + + +async def with_retry( + func: Callable[[], Awaitable[T]], + max_attempts: int = 3, + min_wait: float = 1.0, + max_wait: float = 10.0, + retry_exceptions: tuple[type[Exception], ...] 
= (Exception,),
+    name: str = "retry_operation",
+) -> T:
+    async for attempt in AsyncRetrying(
+        stop=stop_after_attempt(max_attempts),
+        wait=wait_exponential(multiplier=1, min=min_wait, max=max_wait),
+        retry=retry_if_exception_type(retry_exceptions),
+        # before_sleep_log из tenacity ожидает stdlib-логгер и числовой уровень,
+        # поэтому для structlog используется явный колбэк
+        before_sleep=lambda retry_state: logger.warning(
+            "Повтор операции после ошибки",
+            operation=name,
+            attempt_number=retry_state.attempt_number,
+        ),
+        reraise=True,
+    ):
+        with attempt:
+            logger.debug(
+                "Попытка выполнения операции",
+                operation=name,
+                attempt_number=attempt.retry_state.attempt_number,
+            )
+            return await func()
+
+    # Сюда управление не доходит: AsyncRetrying с reraise=True либо возвращает
+    # результат, либо пробрасывает последнее исключение
+    msg = f"Операция {name} завершилась без результата"
+    raise RuntimeError(msg)
+
+
+class BaseAdapter(ABC):
+
+    def __init__(self, name: str) -> None:
+        self.name = name
+        self.logger = structlog.get_logger().bind(adapter=name)
+
+    @abstractmethod
+    async def health_check(self) -> bool:
+        pass
diff --git a/src/adapters/llm.py b/src/adapters/llm.py
new file mode 100644
index 0000000..ee0716a
--- /dev/null
+++ b/src/adapters/llm.py
@@ -0,0 +1,196 @@
+import asyncio
+import time
+
+import openai
+import structlog
+import tiktoken
+from openai import AsyncOpenAI
+from openai.types.chat import ChatCompletion
+
+from ..models import AppConfig
+from .base import BaseAdapter, CircuitBreaker, RateLimiter, with_retry
+
+logger = structlog.get_logger()
+
+
+class LLMError(Exception):
+    pass
+
+
+class LLMTokenLimitError(LLMError):
+    pass
+
+
+class LLMRateLimitError(LLMError):
+    pass
+
+
+class LLMProviderAdapter(BaseAdapter):
+
+    def __init__(self, config: AppConfig) -> None:
+        super().__init__("llm_adapter")
+        self.config = config
+
+        self.client = AsyncOpenAI(api_key=config.openai_api_key)
+
+        try:
+            self.tokenizer = tiktoken.encoding_for_model(config.openai_model)
+        except KeyError:
+            # Для неизвестных моделей используется универсальная кодировка
+            self.tokenizer = tiktoken.get_encoding("cl100k_base")
+
+        self.rate_limiter = RateLimiter(
+            max_concurrent=config.max_concurrent_llm,
+            name="llm_limiter",
+        )
+        self.circuit_breaker = CircuitBreaker(
+            failure_threshold=config.circuit_failure_threshold,
+            recovery_timeout=config.circuit_recovery_timeout,
+            name="llm_circuit",
+        )
+
+        self.request_times: list[float] = []
+        self.rpm_lock = asyncio.Lock()
+
+    def count_tokens(self, text: str) -> int:
+        try:
+            return len(self.tokenizer.encode(text))
+        except Exception as e:
+            self.logger.warning("Ошибка подсчёта токенов", error=str(e))
+            # Грубая оценка: в среднем ~4 символа на токен
+            return len(text) // 4
+
+    async def _check_rpm_limit(self) -> None:
+        async with self.rpm_lock:
+            current_time = time.time()
+            # В скользящем окне остаются только запросы за последнюю минуту
+            self.request_times = [
+                req_time for req_time in self.request_times if current_time - req_time < 60
+            ]
+
+            if len(self.request_times) >= self.config.openai_rpm:
+                oldest_request = min(self.request_times)
+                wait_time = 60 - (current_time - oldest_request)
+                if wait_time > 0:
+                    self.logger.info(
+                        "Ожидание из-за RPM лимита",
+                        wait_seconds=wait_time,
+                        current_rpm=len(self.request_times),
+                    )
+                    await asyncio.sleep(wait_time)
+
+            # Метка времени берётся заново: после sleep значение current_time устарело
+            self.request_times.append(time.time())
+
+    async def _make_completion_request(
+        self,
+        messages: list[dict[str, str]],
+    ) -> ChatCompletion:
+        try:
+            response = await self.client.chat.completions.create(
+                model=self.config.openai_model,
+                messages=messages,
+                temperature=self.config.openai_temperature,
+                max_tokens=1500,
+            )
+            return response
+        except openai.RateLimitError as e:
+            raise LLMRateLimitError(f"Rate limit exceeded: {e}") from e
+        except openai.APIError as e:
+            raise LLMError(f"OpenAI API error: {e}") from e
+
+    async def simplify_text(
+        self,
+        title: str,
+        wiki_text: str,
+        prompt_template: str,
+    ) -> tuple[str, int, int]:
+        input_tokens = self.count_tokens(wiki_text)
+        if input_tokens > 6000:
+            raise LLMTokenLimitError(f"Текст слишком длинный: {input_tokens} 
токенов (лимит 6000)") + + try: + prompt_text = prompt_template.format( + title=title, + wiki_source_text=wiki_text, + ) + except KeyError as e: + raise LLMError(f"Ошибка в шаблоне промпта: отсутствует ключ {e}") from e + + messages = self._parse_prompt_template(prompt_text) + + total_input_tokens = sum(self.count_tokens(msg["content"]) for msg in messages) + + async with self.rate_limiter: + await self._check_rpm_limit() + + response = await self.circuit_breaker.call( + lambda: with_retry( + lambda: self._make_completion_request(messages), + max_attempts=self.config.max_retries, + min_wait=self.config.retry_delay, + max_wait=self.config.retry_delay * 4, + retry_exceptions=(LLMRateLimitError, ConnectionError, TimeoutError), + name=f"simplify_{title}", + ) + ) + + if not response.choices: + raise LLMError("Пустой ответ от OpenAI") + + simplified_text = response.choices[0].message.content + if not simplified_text: + raise LLMError("OpenAI вернул пустой текст") + + simplified_text = simplified_text.replace("###END###", "").strip() + + output_tokens = self.count_tokens(simplified_text) + + if output_tokens > 1200: + self.logger.warning( + "Упрощённый текст превышает лимит", + output_tokens=output_tokens, + title=title, + ) + + self.logger.info( + "Текст успешно упрощён", + title=title, + input_tokens=total_input_tokens, + output_tokens=output_tokens, + ) + + return simplified_text, total_input_tokens, output_tokens + + def _parse_prompt_template(self, prompt_text: str) -> list[dict[str, str]]: + messages: list[dict[str, str]] = [] + + parts = prompt_text.split("### role:") + + for part in parts[1:]: + lines = part.strip().split("\n", 1) + if len(lines) < 2: + continue + + role = lines[0].strip() + content = lines[1].strip() + + if role in ("system", "user", "assistant"): + messages.append({"role": role, "content": content}) + + if not messages: + messages = [{"role": "user", "content": prompt_text}] + + return messages + + async def health_check(self) -> bool: + try: + test_messages = [{"role": "user", "content": "Ответь 'OK' если всё работает."}] + + response = await self.client.chat.completions.create( + model=self.config.openai_model, + messages=test_messages, + temperature=0, + max_tokens=10, + ) + + return bool(response.choices and response.choices[0].message.content) + except Exception as e: + self.logger.error("LLM health check failed", error=str(e)) + return False diff --git a/src/adapters/ruwiki.py b/src/adapters/ruwiki.py new file mode 100644 index 0000000..8bc71d4 --- /dev/null +++ b/src/adapters/ruwiki.py @@ -0,0 +1,168 @@ +import asyncio +import re +from typing import NamedTuple +from urllib.parse import unquote, urlparse + +import mwclient +import structlog +from mwclient.errors import InvalidPageTitle, LoginError + +from ..models import AppConfig +from .base import BaseAdapter, CircuitBreaker, RateLimiter, with_retry + +logger = structlog.get_logger() + + +class WikiPageNotFoundError(Exception): + pass + + +class WikiPageRedirectError(Exception): + pass + + +class WikiPageInfo(NamedTuple): + title: str + content: str + is_redirect: bool = False + redirect_target: str | None = None + + +class RuWikiAdapter(BaseAdapter): + + def __init__(self, config: AppConfig) -> None: + super().__init__("ruwiki_adapter") + self.config = config + + self.rate_limiter = RateLimiter( + max_concurrent=config.max_concurrent_wiki, + name="ruwiki_limiter", + ) + self.circuit_breaker = CircuitBreaker( + failure_threshold=config.circuit_failure_threshold, + 
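+            # Circuit breaker отсекает обращения к вики после серии отказов,
+            # чтобы не бомбардировать запросами недоступный внешний сервис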
recovery_timeout=config.circuit_recovery_timeout,
+            name="ruwiki_circuit",
+        )
+
+        self._client: mwclient.Site | None = None
+
+    async def _get_client(self) -> mwclient.Site:
+        if self._client is None:
+            self._client = await asyncio.to_thread(
+                self._create_client,
+            )
+        return self._client
+
+    def _create_client(self) -> mwclient.Site:
+        try:
+            site = mwclient.Site("ru.wikipedia.org")
+            site.api("query", meta="siteinfo")
+            self.logger.info("Соединение с RuWiki установлено")
+            return site
+        except (LoginError, ConnectionError) as e:
+            self.logger.error("Ошибка подключения к RuWiki", error=str(e))
+            raise
+
+    @staticmethod
+    def extract_title_from_url(url: str) -> str:
+        parsed = urlparse(url)
+        # input.txt содержит ссылки на ru.ruwiki.ru, поэтому принимаются оба домена;
+        # статьи при этом запрашиваются с ru.wikipedia.org (см. _create_client),
+        # где названия статей, как правило, совпадают
+        if not parsed.netloc.endswith(("wikipedia.org", "ruwiki.ru")):
+            raise ValueError(f"Не является URL википедии или Рувики: {url}")
+
+        path_parts = parsed.path.split("/")
+        if len(path_parts) < 3 or path_parts[1] != "wiki":
+            raise ValueError(f"Неверный формат URL: {url}")
+
+        title = unquote(path_parts[2])
+        title = title.replace("_", " ")
+
+        return title
+
+    async def _fetch_page_content(self, title: str) -> WikiPageInfo:
+        client = await self._get_client()
+
+        def _sync_fetch() -> WikiPageInfo:
+            try:
+                page = client.pages[title]
+
+                if not page.exists:
+                    raise WikiPageNotFoundError(f"Страница '{title}' не найдена")
+
+                if page.redirect:
+                    redirect_target = page.redirects_to()
+                    if redirect_target:
+                        redirect_title = redirect_target.name
+                        self.logger.info(
+                            "Страница является редиректом",
+                            original=title,
+                            target=redirect_title,
+                        )
+                        raise WikiPageRedirectError(
+                            f"Страница '{title}' перенаправляет на '{redirect_title}'"
+                        )
+
+                content = page.text()
+                if not content or len(content.strip()) < 100:
+                    raise WikiPageNotFoundError(f"Страница '{title}' слишком короткая или пустая")
+
+                return WikiPageInfo(
+                    title=title,
+                    content=content,
+                    is_redirect=False,
+                )
+
+            except InvalidPageTitle as e:
+                raise WikiPageNotFoundError(f"Неверное название страницы: {e}") from e
+
+        return await asyncio.to_thread(_sync_fetch)
+
+    def _clean_wikitext(self, text: str) -> str:
+        # Навигационные шаблоны, карточки и дизамбиги не несут пользы для упрощения
+        text = re.sub(r"\{\{[Нн]авигация.*?\}\}", "", text, flags=re.DOTALL)
+        text = re.sub(r"\{\{[Кк]арточка.*?\}\}", "", text, flags=re.DOTALL)
+        text = re.sub(r"\{\{[Дд]исамбиг.*?\}\}", "", text, flags=re.DOTALL)
+
+        text = re.sub(r"\[\[[Кк]атегория:.*?\]\]", "", text)
+
+        text = re.sub(r"\[\[[Фф]айл:.*?\]\]", "", text, flags=re.DOTALL)
+        text = re.sub(r"\[\[[Ii]mage:.*?\]\]", "", text, flags=re.DOTALL)
+
+        # HTML-комментарии <!-- ... --> удаляются целиком
+        text = re.sub(r"<!--.*?-->", "", text, flags=re.DOTALL)
+
+        text = re.sub(r"\n\s*\n", "\n\n", text)
+
+        return text.strip()
+
+    async def fetch_page(self, url: str) -> WikiPageInfo:
+        title = self.extract_title_from_url(url)
+
+        async with self.rate_limiter:
+            return await self.circuit_breaker.call(
+                lambda: with_retry(
+                    lambda: self._fetch_page_content(title),
+                    max_attempts=self.config.max_retries,
+                    min_wait=self.config.retry_delay,
+                    max_wait=self.config.retry_delay * 4,
+                    retry_exceptions=(ConnectionError, TimeoutError),
+                    name=f"fetch_page_{title}",
+                )
+            )
+
+    async def fetch_page_cleaned(self, url: str) -> WikiPageInfo:
+        page_info = await self.fetch_page(url)
+        cleaned_content = self._clean_wikitext(page_info.content)
+
+        return WikiPageInfo(
+            title=page_info.title,
+            content=cleaned_content,
+            is_redirect=page_info.is_redirect,
+            redirect_target=page_info.redirect_target,
+        )
+
+    async def health_check(self) -> bool:
+        try:
+            client = await self._get_client()
+            await asyncio.to_thread(lambda: client.api("query", meta="siteinfo", siprop="general"))
+            return True
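+        # Любая ошибка siteinfo-запроса означает, что вики недоступна для пайплайна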
except Exception as e: + self.logger.error("Health check failed", error=str(e)) + return False diff --git a/src/cli.py b/src/cli.py new file mode 100644 index 0000000..0f857fd --- /dev/null +++ b/src/cli.py @@ -0,0 +1,278 @@ +import asyncio +import json +import sys + +import click +import structlog + +from .dependency_injection import get_container +from .models import AppConfig + + +def setup_logging(log_level: str, log_format: str) -> None: + processors = [ + structlog.stdlib.filter_by_level, + structlog.stdlib.add_logger_name, + structlog.stdlib.add_log_level, + structlog.stdlib.PositionalArgumentsFormatter(), + structlog.processors.TimeStamper(fmt="iso"), + structlog.processors.StackInfoRenderer(), + structlog.processors.format_exc_info, + ] + + if log_format == "json": + processors.append(structlog.processors.JSONRenderer()) + else: + processors.append(structlog.dev.ConsoleRenderer()) + + structlog.configure( + processors=processors, + wrapper_class=structlog.stdlib.BoundLogger, + logger_factory=structlog.stdlib.LoggerFactory(), + cache_logger_on_first_use=True, + ) + + import logging + + logging.basicConfig(level=getattr(logging, log_level.upper())) + + +@click.group() +@click.option( + "--config-file", + type=click.Path(exists=True), + help="Путь к файлу конфигурации .env", +) +@click.option("--log-level", default="INFO", help="Уровень логирования") +@click.option( + "--log-format", default="text", type=click.Choice(["json", "text"]), help="Формат логов" +) +@click.pass_context +def main(ctx: click.Context, config_file: str | None, log_level: str, log_format: str) -> None: + setup_logging(log_level, log_format) + + if config_file: + config = AppConfig(_env_file=config_file) + else: + config = AppConfig() + + ctx.ensure_object(dict) + ctx.obj["config"] = config + + +@main.command() +@click.argument("input_file", type=click.Path(exists=True)) +@click.option( + "--force", + is_flag=True, + help="Принудительно обработать даже уже обработанные статьи", +) +@click.option( + "--max-articles", + type=int, + help="Максимальное количество статей для обработки", +) +@click.option( + "--max-workers", + type=int, + help="Максимальное количество worker корутин", +) +@click.pass_context +def process( + ctx: click.Context, + input_file: str, + force: bool, + max_articles: int | None, + max_workers: int | None, +) -> None: + config: AppConfig = ctx.obj["config"] + + async def _run() -> None: + container = get_container(config) + + try: + await container.initialize() + + runner = container.create_runner(max_workers) + + click.echo(f"Запуск обработки статей из файла: {input_file}") + click.echo(f"Принудительная обработка: {force}") + click.echo(f"Максимум статей: {max_articles or 'без ограничений'}") + click.echo(f"Workers: {runner.max_workers}") + click.echo() + + stats = await runner.run_from_file( + input_file=input_file, + force_reprocess=force, + max_articles=max_articles, + ) + + click.echo("\n" + "=" * 50) + click.echo("РЕЗУЛЬТАТЫ ОБРАБОТКИ") + click.echo("=" * 50) + click.echo(f"Всего обработано: {stats.total_processed}") + click.echo(f"Успешно: {stats.successful}") + click.echo(f"Ошибок: {stats.failed}") + click.echo(f"Пропущено: {stats.skipped}") + click.echo(f"Процент успеха: {stats.success_rate:.1f}%") + + if stats.successful > 0: + click.echo(f"Среднее время обработки: {stats.average_processing_time:.2f}с") + + except Exception as e: + click.echo(f"Ошибка: {e}", err=True) + sys.exit(1) + finally: + await container.cleanup() + + asyncio.run(_run()) + + +@main.command() +@click.pass_context 
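+# Команда `health` опрашивает компоненты через DependencyContainer.health_check()
+# и завершает процесс с кодом 1, если хотя бы один из них не отвечает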
+def health(ctx: click.Context) -> None: + config: AppConfig = ctx.obj["config"] + + async def _check() -> None: + container = get_container(config) + + try: + await container.initialize() + + click.echo("Проверка работоспособности системы...") + checks = await container.health_check() + + click.echo("\nРезультаты проверки:") + all_ok = True + + for component, status in checks.items(): + status_str = "✓ OK" if status else "✗ FAILED" + click.echo(f" {component}: {status_str}") + if not status: + all_ok = False + + if all_ok: + click.echo("\nВсе компоненты работают нормально") + else: + click.echo("\nОбнаружены проблемы с компонентами") + sys.exit(1) + + except Exception as e: + click.echo(f"Ошибка при проверке: {e}", err=True) + sys.exit(1) + finally: + await container.cleanup() + + asyncio.run(_check()) + + +@main.command() +@click.argument("input_file", type=click.Path(exists=True)) +@click.pass_context +def stats(ctx: click.Context, input_file: str) -> None: + from .sources import FileSource + + async def _stats() -> None: + try: + source = FileSource(input_file) + total_urls = await source.count_urls() + + click.echo(f"Файл: {input_file}") + click.echo(f"Валидных URL: {total_urls}") + + except Exception as e: + click.echo(f"Ошибка: {e}", err=True) + sys.exit(1) + + asyncio.run(_stats()) + + +@main.command() +@click.option( + "--limit", + type=int, + default=10, + help="Количество статей для вывода", +) +@click.option( + "--status", + type=click.Choice(["pending", "processing", "completed", "failed"]), + help="Фильтр по статусу", +) +@click.option( + "--format", + "output_format", + type=click.Choice(["table", "json"]), + default="table", + help="Формат вывода", +) +@click.pass_context +def list_articles( + ctx: click.Context, + limit: int, + status: str | None, + output_format: str, +) -> None: + config: AppConfig = ctx.obj["config"] + + async def _list() -> None: + container = get_container(config) + + try: + await container.initialize() + repository = container.get_repository() + + if status: + from .models import ProcessingStatus + + status_enum = ProcessingStatus(status) + articles = await repository.get_articles_by_status(status_enum, limit) + else: + articles = await repository.get_all_articles(limit) + + if output_format == "json": + data = [ + { + "id": article.id, + "url": article.url, + "title": article.title, + "status": article.status.value, + "created_at": article.created_at.isoformat(), + "token_count_raw": article.token_count_raw, + "token_count_simplified": article.token_count_simplified, + } + for article in articles + ] + click.echo(json.dumps(data, ensure_ascii=False, indent=2)) + else: + if not articles: + click.echo("Статьи не найдены") + return + + click.echo(f"{'ID':<5} {'Статус':<12} {'Название':<50} {'Токены (исх/упр)':<15}") + click.echo("-" * 87) + + for article in articles: + tokens_info = "" + if article.token_count_raw and article.token_count_simplified: + tokens_info = f"{article.token_count_raw}/{article.token_count_simplified}" + elif article.token_count_raw: + tokens_info = f"{article.token_count_raw}/-" + + title = article.title[:47] + "..." 
if len(article.title) > 50 else article.title + + click.echo( + f"{article.id:<5} {article.status.value:<12} {title:<50} {tokens_info:<15}" + ) + + except Exception as e: + click.echo(f"Ошибка: {e}", err=True) + sys.exit(1) + finally: + await container.cleanup() + + asyncio.run(_list()) + + +if __name__ == "__main__": + main() diff --git a/src/dependency_injection.py b/src/dependency_injection.py new file mode 100644 index 0000000..30e8412 --- /dev/null +++ b/src/dependency_injection.py @@ -0,0 +1,149 @@ +from functools import lru_cache + +import structlog + +from .adapters import LLMProviderAdapter, RuWikiAdapter +from .models import AppConfig +from .runner import AsyncRunner +from .services import ArticleRepository, AsyncWriteQueue, DatabaseService, SimplifyService + +logger = structlog.get_logger() + + +class DependencyContainer: + def __init__(self, config: AppConfig) -> None: + self.config = config + self._database_service: DatabaseService | None = None + self._repository: ArticleRepository | None = None + self._write_queue: AsyncWriteQueue | None = None + self._ruwiki_adapter: RuWikiAdapter | None = None + self._llm_adapter: LLMProviderAdapter | None = None + self._simplify_service: SimplifyService | None = None + self._initialized = False + + async def initialize(self) -> None: + if self._initialized: + return + + logger.info("Инициализация системы...") + + db_service = self.get_database_service() + await db_service.initialize_database() + + write_queue = self.get_write_queue() + await write_queue.start() + + self._initialized = True + logger.info("Система инициализирована") + + async def cleanup(self) -> None: + if not self._initialized: + return + + logger.info("Очистка ресурсов...") + + if self._write_queue: + await self._write_queue.stop() + + if self._database_service: + self._database_service.close() + + self._initialized = False + logger.info("Ресурсы очищены") + + @lru_cache(maxsize=1) + def get_database_service(self) -> DatabaseService: + if self._database_service is None: + self._database_service = DatabaseService(self.config) + return self._database_service + + @lru_cache(maxsize=1) + def get_repository(self) -> ArticleRepository: + if self._repository is None: + db_service = self.get_database_service() + self._repository = ArticleRepository(db_service) + return self._repository + + @lru_cache(maxsize=1) + def get_write_queue(self) -> AsyncWriteQueue: + if self._write_queue is None: + repository = self.get_repository() + self._write_queue = AsyncWriteQueue(repository, max_batch_size=10) + return self._write_queue + + @lru_cache(maxsize=1) + def get_ruwiki_adapter(self) -> RuWikiAdapter: + if self._ruwiki_adapter is None: + self._ruwiki_adapter = RuWikiAdapter(self.config) + return self._ruwiki_adapter + + @lru_cache(maxsize=1) + def get_llm_adapter(self) -> LLMProviderAdapter: + if self._llm_adapter is None: + self._llm_adapter = LLMProviderAdapter(self.config) + return self._llm_adapter + + @lru_cache(maxsize=1) + def get_simplify_service(self) -> SimplifyService: + if self._simplify_service is None: + self._simplify_service = SimplifyService( + config=self.config, + ruwiki_adapter=self.get_ruwiki_adapter(), + llm_adapter=self.get_llm_adapter(), + repository=self.get_repository(), + write_queue=self.get_write_queue(), + ) + return self._simplify_service + + def create_runner(self, max_workers: int | None = None) -> AsyncRunner: + if max_workers is None: + max_workers = min( + self.config.max_concurrent_llm, + self.config.max_concurrent_wiki, + 10, + ) + + return 
AsyncRunner(
+            config=self.config,
+            simplify_service=self.get_simplify_service(),
+            max_workers=max_workers,
+        )
+
+    async def health_check(self) -> dict[str, bool]:
+        checks = {}
+
+        try:
+            db_service = self.get_database_service()
+            checks["database"] = await db_service.health_check()
+        except Exception:
+            checks["database"] = False
+
+        try:
+            write_queue = self.get_write_queue()
+            checks["write_queue"] = (
+                write_queue._worker_task is not None and not write_queue._worker_task.done()
+            )
+        except Exception:
+            checks["write_queue"] = False
+
+        try:
+            ruwiki = self.get_ruwiki_adapter()
+            checks["ruwiki"] = await ruwiki.health_check()
+        except Exception:
+            checks["ruwiki"] = False
+
+        try:
+            llm = self.get_llm_adapter()
+            checks["llm"] = await llm.health_check()
+        except Exception:
+            checks["llm"] = False
+
+        return checks
+
+
+_container: DependencyContainer | None = None
+
+
+def get_container(config: AppConfig | None = None) -> DependencyContainer:
+    # lru_cache здесь не подходит: AppConfig (pydantic-модель) нехешируема,
+    # и вызов get_container(config) падал бы с TypeError. Синглтон контейнера
+    # хранится в переменной модуля.
+    global _container
+    if _container is None:
+        _container = DependencyContainer(config or AppConfig())
+    return _container
diff --git a/src/models/__init__.py b/src/models/__init__.py
new file mode 100644
index 0000000..1974d53
--- /dev/null
+++ b/src/models/__init__.py
@@ -0,0 +1,15 @@
+from .article import Article, ArticleCreate, ArticleRead, ProcessingStatus
+from .commands import ProcessingResult, ProcessingStats, SimplifyCommand
+from .config import AppConfig
+from .constants import *
+
+__all__ = [
+    "AppConfig",
+    "Article",
+    "ArticleCreate",
+    "ArticleRead",
+    "ProcessingResult",
+    "ProcessingStats",
+    "ProcessingStatus",
+    "SimplifyCommand",
+]
diff --git a/src/models/article.py b/src/models/article.py
new file mode 100644
index 0000000..c6aed9c
--- /dev/null
+++ b/src/models/article.py
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+from datetime import datetime, timezone
+from enum import Enum
+
+from sqlmodel import Field, SQLModel
+
+
+class ProcessingStatus(str, Enum):
+    PENDING = "pending"
+    PROCESSING = "processing"
+    COMPLETED = "completed"
+    FAILED = "failed"
+
+
+class Article(SQLModel, table=True):
+
+    __tablename__ = "articles"
+
+    id: int | None = Field(default=None, primary_key=True)
+    url: str = Field(index=True, unique=True, max_length=500)
+    title: str = Field(max_length=300)
+    raw_text: str = Field(description="Исходный wiki-текст")
+    simplified_text: str | None = Field(
+        default=None,
+        description="Упрощённый текст для школьников",
+    )
+    status: ProcessingStatus = Field(default=ProcessingStatus.PENDING)
+    error_message: str | None = Field(default=None, max_length=1000)
+    token_count_raw: int | None = Field(
+        default=None, description="Количество токенов в исходном тексте"
+    )
+    token_count_simplified: int | None = Field(
+        default=None,
+        description="Количество токенов в упрощённом тексте",
+    )
+    processing_time_seconds: float | None = Field(default=None)
+    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
+    updated_at: datetime | None = Field(default=None)
+
+    def mark_processing(self) -> None:
+        self.status = ProcessingStatus.PROCESSING
+        self.updated_at = datetime.now(timezone.utc)
+
+    def mark_completed(
+        self,
+        simplified_text: str,
+        token_count_raw: int,
+        token_count_simplified: int,
+        processing_time: float,
+    ) -> None:
+        self.simplified_text = simplified_text
+        self.token_count_raw = token_count_raw
+        self.token_count_simplified = token_count_simplified
+        self.processing_time_seconds = processing_time
+        self.status = ProcessingStatus.COMPLETED
+        self.error_message = None
+        self.updated_at = datetime.now(timezone.utc)
+
+    def 
mark_failed(self, error_message: str) -> None: + self.status = ProcessingStatus.FAILED + self.error_message = error_message[:1000] + self.updated_at = datetime.now(timezone.utc) + + +class ArticleCreate(SQLModel): + url: str + title: str + raw_text: str + + +class ArticleRead(SQLModel): + id: int + url: str + title: str + raw_text: str + simplified_text: str | None + status: ProcessingStatus + token_count_raw: int | None + token_count_simplified: int | None + created_at: datetime diff --git a/src/models/commands.py b/src/models/commands.py new file mode 100644 index 0000000..53f0fb9 --- /dev/null +++ b/src/models/commands.py @@ -0,0 +1,89 @@ +from dataclasses import dataclass + + +@dataclass(frozen=True) +class SimplifyCommand: + url: str + force_reprocess: bool = False + + def __str__(self) -> str: + return f"SimplifyCommand(url='{self.url}', force={self.force_reprocess})" + + +@dataclass +class ProcessingResult: + + url: str + success: bool + title: str | None = None + raw_text: str | None = None + simplified_text: str | None = None + token_count_raw: int | None = None + token_count_simplified: int | None = None + processing_time_seconds: float | None = None + error_message: str | None = None + + @classmethod + def success_result( + cls, + url: str, + title: str, + raw_text: str, + simplified_text: str, + token_count_raw: int, + token_count_simplified: int, + processing_time_seconds: float, + ) -> "ProcessingResult": + return cls( + url=url, + success=True, + title=title, + raw_text=raw_text, + simplified_text=simplified_text, + token_count_raw=token_count_raw, + token_count_simplified=token_count_simplified, + processing_time_seconds=processing_time_seconds, + ) + + @classmethod + def failure_result(cls, url: str, error_message: str) -> "ProcessingResult": + return cls( + url=url, + success=False, + error_message=error_message, + ) + + +@dataclass +class ProcessingStats: + + total_processed: int = 0 + successful: int = 0 + failed: int = 0 + skipped: int = 0 + total_processing_time: float = 0.0 + + @property + def success_rate(self) -> float: + if self.total_processed == 0: + return 0.0 + return (self.successful / self.total_processed) * 100.0 + + @property + def average_processing_time(self) -> float: + if self.successful == 0: + return 0.0 + return self.total_processing_time / self.successful + + def add_result(self, result: ProcessingResult) -> None: + self.total_processed += 1 + + if result.success: + self.successful += 1 + if result.processing_time_seconds: + self.total_processing_time += result.processing_time_seconds + else: + self.failed += 1 + + def add_skipped(self) -> None: + self.skipped += 1 diff --git a/src/models/config.py b/src/models/config.py new file mode 100644 index 0000000..e93b975 --- /dev/null +++ b/src/models/config.py @@ -0,0 +1,73 @@ +from pathlib import Path +from typing import Literal + +from pydantic import Field, field_validator +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class AppConfig(BaseSettings): + + model_config = SettingsConfigDict( + env_file=".env", + env_file_encoding="utf-8", + case_sensitive=False, + extra="ignore", + ) + + openai_api_key: str = Field(description="API ключ OpenAI") + openai_model: str = Field(default="gpt-4o-mini", description="Модель OpenAI для упрощения") + openai_temperature: float = Field( + default=0.0, ge=0.0, le=2.0, description="Температура для LLM" + ) + + db_path: str = Field(default="./data/wiki.db", description="Путь к файлу SQLite") + + max_concurrent_llm: int = Field( + default=5, ge=1, 
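+        # Верхняя граница страхует от исчерпания лимитов OpenAI при опечатке в .env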
le=50, description="Максимум одновременных LLM запросов" + ) + openai_rpm: int = Field(default=200, ge=1, description="Лимит запросов в минуту для OpenAI") + max_concurrent_wiki: int = Field( + default=10, ge=1, le=100, description="Максимум одновременных wiki запросов" + ) + + log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = Field(default="INFO") + log_format: Literal["json", "text"] = Field(default="json") + + chunk_size: int = Field(default=2000, ge=500, le=8000, description="Размер чанка для текста") + chunk_overlap: int = Field(default=200, ge=0, le=1000, description="Перекрытие между чанками") + max_retries: int = Field(default=3, ge=1, le=10, description="Максимум попыток повтора") + retry_delay: float = Field( + default=1.0, ge=0.1, le=60.0, description="Задержка между попытками (сек)" + ) + + circuit_failure_threshold: int = Field( + default=5, ge=1, le=20, description="Порог отказов для circuit breaker" + ) + circuit_recovery_timeout: int = Field( + default=60, + ge=10, + le=600, + description="Время восстановления circuit breaker (сек)", + ) + + prompt_template_path: str = Field( + default="src/prompt.txt", description="Путь к файлу с prompt-шаблоном" + ) + input_file_path: str = Field( + default="input.txt", description="Путь к файлу с URL для обработки" + ) + + @field_validator("db_path") + @classmethod + def validate_db_path(cls, v: str) -> str: + db_path = Path(v) + db_path.parent.mkdir(parents=True, exist_ok=True) + return str(db_path) + + @property + def db_url(self) -> str: + return f"sqlite+aiosqlite:///{self.db_path}" + + @property + def sync_db_url(self) -> str: + return f"sqlite:///{self.db_path}" diff --git a/src/models/constants.py b/src/models/constants.py new file mode 100644 index 0000000..e69de29 diff --git a/src/prompt.txt b/src/prompt.txt new file mode 100644 index 0000000..5536924 --- /dev/null +++ b/src/prompt.txt @@ -0,0 +1,28 @@ +### role: system +Ты — опытный редактор Рувики и педагог-методист. Твоя задача — адаптировать научные статьи для школьного образования. + +ПРАВИЛА УПРОЩЕНИЯ: +1. Сократи текст до ≤ 1000 токенов, сохранив ключевую информацию +2. Замени сложные термины на простые аналоги с объяснениями +3. Убери избыточные детали, оставь только суть +4. Сохрани корректную wiki-разметку (== заголовки ==, '''жирный''', ''курсив'', [[ссылки]]) +5. Структурируй материал логично: определение → основные свойства → примеры +6. Добавь простые примеры для лучшего понимания +7. Убери технические подробности, не нужные школьникам + +ЦЕЛЬ: Сделать статью понятной для учеников 8-11 классов, сохранив научную точность. + +ФОРМАТ ОТВЕТА: +- Начни сразу с упрощённого wiki-текста +- Используй простые предложения +- Избегай сложных конструкций +- Заверши ответ маркером ###END### + +### role: user +Статья: {title} + + +{wiki_source_text} + + +Задание: сократи и упрости текст, следуя инструкциям system-сообщения. 
\ No newline at end of file diff --git a/src/runner.py b/src/runner.py new file mode 100644 index 0000000..e4a5e38 --- /dev/null +++ b/src/runner.py @@ -0,0 +1,214 @@ +import asyncio +import signal +import time + +import structlog + +from .models import AppConfig, ProcessingStats, SimplifyCommand +from .services import SimplifyService +from .sources import FileSource + +logger = structlog.get_logger() + + +class AsyncRunner: + def __init__( + self, + config: AppConfig, + simplify_service: SimplifyService, + max_workers: int = 10, + ) -> None: + self.config = config + self.simplify_service = simplify_service + self.max_workers = max_workers + + self._task_queue: asyncio.Queue[SimplifyCommand] = asyncio.Queue() + self._workers: list[asyncio.Task[None]] = [] + self._shutdown_event = asyncio.Event() + + self.stats = ProcessingStats() + self._start_time: float | None = None + + self.logger = structlog.get_logger().bind(service="runner") + + async def run_from_file( + self, + input_file: str, + force_reprocess: bool = False, + max_articles: int | None = None, + ) -> ProcessingStats: + self.logger.info( + "Запуск обработки статей из файла", + input_file=input_file, + force_reprocess=force_reprocess, + max_workers=self.max_workers, + max_articles=max_articles, + ) + + self._setup_signal_handlers() + + try: + source = FileSource(input_file) + await self._load_tasks_from_source(source, force_reprocess, max_articles) + + await self._run_processing() + + except Exception as e: + self.logger.error("Ошибка при выполнении runner", error=str(e)) + raise + finally: + await self._cleanup() + + return self.stats + + async def _load_tasks_from_source( + self, + source: FileSource, + force_reprocess: bool, + max_articles: int | None, + ) -> None: + loaded_count = 0 + + async for command in source.read_urls(force_reprocess): + if max_articles and loaded_count >= max_articles: + break + + await self._task_queue.put(command) + loaded_count += 1 + + self.logger.info("Задачи загружены в очередь", count=loaded_count) + + async def _run_processing(self) -> None: + self._start_time = time.time() + + self.logger.info("Запуск worker корутин", count=self.max_workers) + + for i in range(self.max_workers): + worker = asyncio.create_task(self._worker_loop(worker_id=i)) + self._workers.append(worker) + + await self._task_queue.join() + + self._shutdown_event.set() + + if self._workers: + await asyncio.gather(*self._workers, return_exceptions=True) + + async def _worker_loop(self, worker_id: int) -> None: + worker_logger = self.logger.bind(worker_id=worker_id) + worker_logger.info("Worker запущен") + + processed_count = 0 + + while not self._shutdown_event.is_set(): + try: + try: + command = await asyncio.wait_for( + self._task_queue.get(), + timeout=1.0, + ) + except asyncio.TimeoutError: + continue + + try: + result = await self.simplify_service.process_command(command) + + self.stats.add_result(result) + processed_count += 1 + + if result.success: + worker_logger.info( + "Статья обработана успешно", + url=command.url, + title=result.title, + tokens_in=result.token_count_raw, + tokens_out=result.token_count_simplified, + ) + else: + worker_logger.warning( + "Ошибка при обработке статьи", + url=command.url, + error=result.error_message, + ) + + except Exception as e: + worker_logger.error( + "Неожиданная ошибка в worker", + url=command.url, + error=str(e), + ) + + from .models import ProcessingResult + + error_result = ProcessingResult.failure_result( + command.url, + f"Неожиданная ошибка: {e!s}", + ) + 
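+                        # Неожиданные ошибки тоже фиксируются в статистике,
+                        # чтобы итог run_from_file отражал все задачи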
self.stats.add_result(error_result)
+
+                    finally:
+                        self._task_queue.task_done()
+
+            except Exception as e:
+                worker_logger.error("Критическая ошибка в worker loop", error=str(e))
+                break
+
+        worker_logger.info("Worker завершён", processed_articles=processed_count)
+
+    def _setup_signal_handlers(self) -> None:
+        def signal_handler(signum: int, frame: object) -> None:
+            signal_name = signal.Signals(signum).name
+            self.logger.info(f"Получен сигнал {signal_name}, начинаем graceful shutdown")
+            self._shutdown_event.set()
+
+        try:
+            signal.signal(signal.SIGINT, signal_handler)
+            signal.signal(signal.SIGTERM, signal_handler)
+        except ValueError:
+            # signal.signal доступен только из главного потока
+            self.logger.warning("Не удалось настроить обработчики сигналов")
+
+    async def _cleanup(self) -> None:
+        self.logger.info("Начинаем очистку ресурсов")
+
+        for worker in self._workers:
+            if not worker.done():
+                worker.cancel()
+
+        if self._workers:
+            results = await asyncio.gather(*self._workers, return_exceptions=True)
+            cancelled_count = sum(1 for r in results if isinstance(r, asyncio.CancelledError))
+            if cancelled_count > 0:
+                self.logger.info("Workers отменены", count=cancelled_count)
+
+        self._workers.clear()
+
+    def get_progress_info(self) -> dict[str, object]:
+        elapsed_time = time.time() - self._start_time if self._start_time else 0
+
+        articles_per_minute = 0.0
+        if elapsed_time > 0:
+            articles_per_minute = (self.stats.successful * 60) / elapsed_time
+
+        return {
+            "total_processed": self.stats.total_processed,
+            "successful": self.stats.successful,
+            "failed": self.stats.failed,
+            "success_rate": self.stats.success_rate,
+            "elapsed_time": elapsed_time,
+            "articles_per_minute": articles_per_minute,
+            "queue_size": self._task_queue.qsize(),
+            "active_workers": len([w for w in self._workers if not w.done()]),
+        }
+
+    async def health_check(self) -> dict[str, object]:
+        checks = await self.simplify_service.health_check()
+
+        checks.update(
+            {
+                "runner_active": bool(self._workers and not self._shutdown_event.is_set()),
+                "queue_size": self._task_queue.qsize(),
+                "workers_count": len(self._workers),
+            }
+        )
+
+        return checks
diff --git a/src/services/__init__.py b/src/services/__init__.py
new file mode 100644
index 0000000..c8b4357
--- /dev/null
+++ b/src/services/__init__.py
@@ -0,0 +1,14 @@
+from .database import DatabaseService
+from .repository import ArticleRepository
+from .simplify_service import SimplifyService
+from .text_splitter import RecursiveCharacterTextSplitter
+from .write_queue import AsyncWriteQueue, WriteOperation
+
+__all__ = [
+    "ArticleRepository",
+    "AsyncWriteQueue",
+    "DatabaseService",
+    "RecursiveCharacterTextSplitter",
+    "SimplifyService",
+    "WriteOperation",
+]
diff --git a/src/services/database.py b/src/services/database.py
new file mode 100644
index 0000000..406e2aa
--- /dev/null
+++ b/src/services/database.py
@@ -0,0 +1,65 @@
+"""Сервис для управления базой данных."""
+
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
+from pathlib import Path
+
+import aiosqlite
+import structlog
+from sqlmodel import SQLModel, create_engine
+
+from ..models import AppConfig
+
+logger = structlog.get_logger()
+
+
+class DatabaseService:
+    def __init__(self, config: AppConfig) -> None:
+        self.config = config
+        self.logger = structlog.get_logger().bind(service="database")
+
+        self._sync_engine = create_engine(
+            config.sync_db_url,
+            echo=False,
+            connect_args={"check_same_thread": False},
+        )
+
+    async def initialize_database(self) -> None:
+        db_path = Path(self.config.db_path)
+        db_path.parent.mkdir(parents=True, exist_ok=True)
+
+        self.logger.info("Создание схемы базы данных", db_path=self.config.db_path)
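+        # Схему создаёт синхронный движок SQLModel; дальнейший доступ идёт асинхронно через aiosqlite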
+        SQLModel.metadata.create_all(self._sync_engine)
+
+        await self._configure_sqlite()
+
+        self.logger.info("База данных инициализирована", db_path=self.config.db_path)
+
+    async def _configure_sqlite(self) -> None:
+        async with aiosqlite.connect(self.config.db_path) as conn:
+            await conn.execute("PRAGMA journal_mode=WAL")
+
+            await conn.execute("PRAGMA cache_size=10000")
+
+            await conn.execute("PRAGMA synchronous=NORMAL")
+
+            await conn.execute("PRAGMA busy_timeout=30000")
+
+            await conn.commit()
+            self.logger.info("SQLite настроен для оптимальной производительности")
+
+    @asynccontextmanager
+    async def get_connection(self) -> AsyncIterator[aiosqlite.Connection]:
+        # Соединение отдаётся как async context manager: вызовы вида
+        # `async with db_service.get_connection() as conn` в репозитории иначе
+        # получали бы корутину без __aenter__. row_factory включает доступ
+        # к колонкам по имени, на который рассчитывает _row_to_article.
+        async with aiosqlite.connect(self.config.db_path, timeout=30.0) as conn:
+            conn.row_factory = aiosqlite.Row
+            yield conn
+
+    async def health_check(self) -> bool:
+        try:
+            async with self.get_connection() as conn:
+                await conn.execute("SELECT 1")
+                return True
+        except Exception as e:
+            self.logger.error("Database health check failed", error=str(e))
+            return False
+
+    def close(self) -> None:
+        self._sync_engine.dispose()
diff --git a/src/services/repository.py b/src/services/repository.py
new file mode 100644
index 0000000..4966087
--- /dev/null
+++ b/src/services/repository.py
@@ -0,0 +1,188 @@
+from typing import Any
+
+import aiosqlite
+import structlog
+
+from ..models import Article, ArticleCreate, ProcessingStatus
+from .database import DatabaseService
+
+logger = structlog.get_logger()
+
+
+class ArticleRepository:
+
+    def __init__(self, db_service: DatabaseService) -> None:
+        self.db_service = db_service
+        self.logger = structlog.get_logger().bind(repository="article")
+
+    async def create_article(self, article_data: ArticleCreate) -> Article:
+        existing = await self.get_by_url(article_data.url)
+        if existing:
+            raise ValueError(f"Статья с URL {article_data.url} уже существует")
+
+        article = Article(
+            url=article_data.url,
+            title=article_data.title,
+            raw_text=article_data.raw_text,
+            status=ProcessingStatus.PENDING,
+        )
+
+        async with self.db_service.get_connection() as conn:
+            cursor = await conn.execute(
+                """
+                INSERT INTO articles (url, title, raw_text, status, created_at)
+                VALUES (?, ?, ?, ?, ?)
+                """,
+                (
+                    article.url,
+                    article.title,
+                    article.raw_text,
+                    article.status.value,
+                    article.created_at,
+                ),
+            )
+            await conn.commit()
+
+            article.id = cursor.lastrowid
+
+        self.logger.info("Статья создана", article_id=article.id, url=article.url)
+        return article
+
+    async def get_by_id(self, article_id: int) -> Article | None:
+        async with self.db_service.get_connection() as conn:
+            cursor = await conn.execute(
+                "SELECT * FROM articles WHERE id = ?",
+                (article_id,),
+            )
+            row = await cursor.fetchone()
+
+            if not row:
+                return None
+
+            return self._row_to_article(row)
+
+    async def get_by_url(self, url: str) -> Article | None:
+        async with self.db_service.get_connection() as conn:
+            cursor = await conn.execute(
+                "SELECT * FROM articles WHERE url = ?",
+                (url,),
+            )
+            row = await cursor.fetchone()
+
+            if not row:
+                return None
+
+            return self._row_to_article(row)
+
+    async def update_article(self, article: Article) -> Article:
+        if not article.id:
+            raise ValueError("ID статьи не может быть None для обновления")
+
+        async with self.db_service.get_connection() as conn:
+            cursor = await conn.execute(
+                """
+                UPDATE articles SET
+                    title = ?,
+                    raw_text = ?,
+                    simplified_text = ?,
+                    status = ?,
+                    error_message = ?,
+                    token_count_raw = ?,
+                    token_count_simplified = ?,
+                    processing_time_seconds = ?,
+                    updated_at = ?
+                WHERE id = ?
+                """,
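+                # Значения перечислены строго в порядке колонок SET-списка выше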
+                (
+                    article.title,
+                    article.raw_text,
+                    article.simplified_text,
+                    article.status.value,
+                    article.error_message,
+                    article.token_count_raw,
+                    article.token_count_simplified,
+                    article.processing_time_seconds,
+                    article.updated_at,
+                    article.id,
+                ),
+            )
+            await conn.commit()
+
+            if cursor.rowcount == 0:
+                raise ValueError(f"Статья с ID {article.id} не найдена")
+
+        self.logger.info("Статья обновлена", article_id=article.id, status=article.status)
+        return article
+
+    async def get_articles_by_status(
+        self, status: ProcessingStatus, limit: int | None = None
+    ) -> list[Article]:
+        query = "SELECT * FROM articles WHERE status = ?"
+        params: tuple[Any, ...] = (status.value,)
+
+        if limit:
+            query += " LIMIT ?"
+            params = params + (limit,)
+
+        async with self.db_service.get_connection() as conn:
+            cursor = await conn.execute(query, params)
+            rows = await cursor.fetchall()
+
+            return [self._row_to_article(row) for row in rows]
+
+    async def get_pending_articles(self, limit: int | None = None) -> list[Article]:
+        return await self.get_articles_by_status(ProcessingStatus.PENDING, limit)
+
+    async def count_by_status(self, status: ProcessingStatus) -> int:
+        async with self.db_service.get_connection() as conn:
+            cursor = await conn.execute(
+                "SELECT COUNT(*) FROM articles WHERE status = ?",
+                (status.value,),
+            )
+            result = await cursor.fetchone()
+
+            return result[0] if result else 0
+
+    async def get_all_articles(self, limit: int | None = None, offset: int = 0) -> list[Article]:
+        query = "SELECT * FROM articles ORDER BY created_at DESC"
+        params: tuple[Any, ...] = ()
+
+        if limit:
+            query += " LIMIT ? OFFSET ?"
+            params = (limit, offset)
+
+        async with self.db_service.get_connection() as conn:
+            cursor = await conn.execute(query, params)
+            rows = await cursor.fetchall()
+
+            return [self._row_to_article(row) for row in rows]
+
+    async def delete_article(self, article_id: int) -> bool:
+        async with self.db_service.get_connection() as conn:
+            cursor = await conn.execute(
+                "DELETE FROM articles WHERE id = ?",
+                (article_id,),
+            )
+            await conn.commit()
+
+            deleted = cursor.rowcount > 0
+            if deleted:
+                self.logger.info("Статья удалена", article_id=article_id)
+
+            return deleted
+
+    def _row_to_article(self, row: aiosqlite.Row) -> Article:
+        # Доступ по имени колонки опирается на row_factory = aiosqlite.Row,
+        # который устанавливается в DatabaseService.get_connection
+        return Article(
+            id=row["id"],
+            url=row["url"],
+            title=row["title"],
+            raw_text=row["raw_text"],
+            simplified_text=row["simplified_text"],
+            status=ProcessingStatus(row["status"]),
+            error_message=row["error_message"],
+            token_count_raw=row["token_count_raw"],
+            token_count_simplified=row["token_count_simplified"],
+            processing_time_seconds=row["processing_time_seconds"],
+            created_at=row["created_at"],
+            updated_at=row["updated_at"],
+        )
diff --git a/src/services/simplify_service.py b/src/services/simplify_service.py
new file mode 100644
index 0000000..703cac2
--- /dev/null
+++ b/src/services/simplify_service.py
@@ -0,0 +1,272 @@
+from __future__ import annotations
+
+import asyncio
+import time
+from pathlib import Path
+
+import structlog
+
+from src.adapters.llm import LLMProviderAdapter, LLMTokenLimitError
+from src.adapters.ruwiki import RuWikiAdapter
+from src.models import AppConfig, ArticleCreate, ProcessingResult, SimplifyCommand
+from src.services.repository import ArticleRepository
+from src.services.text_splitter import RecursiveCharacterTextSplitter
+from src.services.write_queue import AsyncWriteQueue
+
+# Модуль src/models/constants.py пуст, поэтому лимит задан здесь: целевой размер
+# ответа из промпта (~1000 токенов) плюс буфер; llm.py предупреждает, когда
+# результат длиннее 1200 токенов.
+MAX_TOKEN_LIMIT_WITH_BUFFER = 1200
+
+
+class SimplifyService:
+    def __init__(
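+        # Все зависимости создаются и передаются DependencyContainer-ом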
self, + config: AppConfig, + ruwiki_adapter: RuWikiAdapter, + llm_adapter: LLMProviderAdapter, + repository: ArticleRepository, + write_queue: AsyncWriteQueue, + ) -> None: + self.config = config + self.ruwiki_adapter = ruwiki_adapter + self.llm_adapter = llm_adapter + self.repository = repository + self.write_queue = write_queue + + self.text_splitter = RecursiveCharacterTextSplitter( + chunk_size=config.chunk_size, + chunk_overlap=config.chunk_overlap, + length_function=self.llm_adapter.count_tokens, + ) + + self._prompt_template: str | None = None + self.logger = structlog.get_logger().bind(service="simplify") + + async def get_prompt_template(self) -> str: + if self._prompt_template is None: + prompt_path = Path(self.config.prompt_template_path) + if not prompt_path.exists(): + msg = f"Prompt template не найден: {prompt_path}" + raise FileNotFoundError(msg) + + self._prompt_template = prompt_path.read_text(encoding="utf-8") + + return self._prompt_template + + async def process_command(self, command: SimplifyCommand) -> ProcessingResult: + start_time = time.time() + self.logger.info("Начало обработки статьи", url=command.url) + + try: + return await self._process_command_impl(command, start_time) + except Exception as e: + return await self._handle_processing_error(command, e, start_time) + + async def _process_command_impl( + self, command: SimplifyCommand, start_time: float + ) -> ProcessingResult: + if not command.force_reprocess: + existing_result = await self._check_existing_article(command.url) + if existing_result: + return existing_result + + page_info = await self.ruwiki_adapter.fetch_page_cleaned(command.url) + article = await self._create_or_update_article(command, page_info) + + article.mark_processing() + await self.repository.update_article(article) + + simplified_text, input_tokens, output_tokens = await self._simplify_article_text( + title=page_info.title, + raw_text=page_info.content, + ) + + processing_time = time.time() - start_time + result = ProcessingResult.success_result( + url=command.url, + title=page_info.title, + raw_text=page_info.content, + simplified_text=simplified_text, + token_count_raw=input_tokens, + token_count_simplified=output_tokens, + processing_time_seconds=processing_time, + ) + + await self.write_queue.update_from_result(result) + + self.logger.info( + "Статья успешно обработана", + url=command.url, + title=page_info.title, + processing_time=processing_time, + input_tokens=input_tokens, + output_tokens=output_tokens, + ) + + return result + + async def _check_existing_article(self, url: str) -> ProcessingResult | None: + existing_article = await self.repository.get_by_url(url) + if existing_article and existing_article.simplified_text: + self.logger.info("Статья уже обработана, пропускаем", url=url) + return ProcessingResult.success_result( + url=url, + title=existing_article.title, + raw_text=existing_article.raw_text, + simplified_text=existing_article.simplified_text, + token_count_raw=existing_article.token_count_raw or 0, + token_count_simplified=existing_article.token_count_simplified or 0, + processing_time_seconds=existing_article.processing_time_seconds or 0, + ) + return None + + async def _create_or_update_article(self, command, page_info): + article_data = ArticleCreate( + url=command.url, + title=page_info.title, + raw_text=page_info.content, + ) + + try: + return await self.repository.create_article(article_data) + except ValueError: + article = await self.repository.get_by_url(command.url) + if not article: + msg = f"Не удалось 
+
+    async def _create_or_update_article(self, command, page_info):
+        article_data = ArticleCreate(
+            url=command.url,
+            title=page_info.title,
+            raw_text=page_info.content,
+        )
+
+        try:
+            return await self.repository.create_article(article_data)
+        except ValueError:
+            article = await self.repository.get_by_url(command.url)
+            if not article:
+                msg = f"Не удалось найти статью после создания: {command.url}"
+                raise ValueError(msg) from None
+
+            if command.force_reprocess:
+                article.title = page_info.title
+                article.raw_text = page_info.content
+                article.mark_processing()
+                await self.repository.update_article(article)
+
+            return article
+
+    async def _handle_processing_error(
+        self, command: SimplifyCommand, error: Exception, start_time: float
+    ) -> ProcessingResult:
+        processing_time = time.time() - start_time
+        error_message = f"{type(error).__name__}: {error!s}"
+
+        self.logger.exception(
+            "Ошибка при обработке статьи",
+            url=command.url,
+            processing_time=processing_time,
+        )
+
+        error_result = ProcessingResult.failure_result(command.url, error_message)
+        try:
+            await self.write_queue.update_from_result(error_result)
+        except Exception:
+            self.logger.exception("Ошибка записи результата с ошибкой")
+
+        return error_result
+
+    async def _simplify_article_text(self, title: str, raw_text: str) -> tuple[str, int, int]:
+        prompt_template = await self.get_prompt_template()
+        text_tokens = self.llm_adapter.count_tokens(raw_text)
+
+        if text_tokens <= self.config.chunk_size:
+            return await self.llm_adapter.simplify_text(title, raw_text, prompt_template)
+
+        return await self._process_long_text(title, raw_text, prompt_template)
+
+    async def _process_long_text(
+        self, title: str, raw_text: str, prompt_template: str
+    ) -> tuple[str, int, int]:
+        self.logger.info(
+            "Разбиение длинного текста на части",
+            title=title,
+            total_tokens=self.llm_adapter.count_tokens(raw_text),
+            chunk_size=self.config.chunk_size,
+        )
+
+        chunks = self.text_splitter.split_text(raw_text)
+        simplified_chunks = []
+        total_input_tokens = 0
+        total_output_tokens = 0
+
+        for i, chunk in enumerate(chunks):
+            self.logger.debug(
+                "Обработка части текста",
+                title=title,
+                chunk_index=i + 1,
+                total_chunks=len(chunks),
+                chunk_tokens=self.llm_adapter.count_tokens(chunk),
+            )
+
+            try:
+                simplified_chunk, input_tokens, output_tokens = (
+                    await self.llm_adapter.simplify_text(
+                        title=f"{title} (часть {i+1}/{len(chunks)})",
+                        wiki_text=chunk,
+                        prompt_template=prompt_template,
+                    )
+                )
+
+                simplified_chunks.append(simplified_chunk)
+                total_input_tokens += input_tokens
+                total_output_tokens += output_tokens
+
+            except Exception as e:
+                self.logger.warning(
+                    "Ошибка при обработке части текста",
+                    title=title,
+                    chunk_index=i + 1,
+                    error=str(e),
+                )
+
+        if not simplified_chunks:
+            msg = "Не удалось обработать ни одной части текста"
+            raise LLMTokenLimitError(msg)
+
+        combined_text = self._combine_simplified_chunks(simplified_chunks)
+        return self._ensure_token_limit(combined_text, total_input_tokens, total_output_tokens)
+
+    def _ensure_token_limit(
+        self, combined_text: str, total_input_tokens: int, total_output_tokens: int
+    ) -> tuple[str, int, int]:
+        final_tokens = self.llm_adapter.count_tokens(combined_text)
+        if final_tokens > MAX_TOKEN_LIMIT_WITH_BUFFER:
+            self.logger.warning(
+                "Объединённый текст превышает лимит, обрезаем",
+                final_tokens=final_tokens,
+            )
+            combined_text = self._truncate_to_token_limit(combined_text, 1000)
+            total_output_tokens = self.llm_adapter.count_tokens(combined_text)
+
+        return combined_text, total_input_tokens, total_output_tokens
+
+    def _combine_simplified_chunks(self, chunks: list[str]) -> str:
+        combined = "\n\n".join(chunk.strip() for chunk in chunks if chunk.strip())
+        return "\n".join(line for line in combined.split("\n") if line.strip())
+
+    def _truncate_to_token_limit(self, text: str, token_limit: int) -> str:
+        current_tokens = self.llm_adapter.count_tokens(text)
+        if current_tokens <= token_limit:
+            return text
+
+        sentences = text.split(". ")
+        truncated = ""
+
+        for sentence in sentences:
+            test_text = truncated + sentence + ". "
+            if self.llm_adapter.count_tokens(test_text) > token_limit:
+                break
+            truncated = test_text
+
+        return truncated.strip()
+
+    async def health_check(self) -> dict[str, bool]:
+        checks = {}
+
+        checks["ruwiki"] = await self._safe_health_check(self.ruwiki_adapter.health_check)
+        checks["llm"] = await self._safe_health_check(self.llm_adapter.health_check)
+        checks["prompt_template"] = await self._safe_health_check(self.get_prompt_template)
+
+        return checks
+
+    async def _safe_health_check(self, check_func) -> bool:
+        try:
+            await check_func()
+            return True
+        except Exception:
+            return False
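The long-text branch above reduces to split, simplify each chunk, recombine, tolerating per-chunk failures as long as at least one chunk succeeds. A minimal standalone sketch of that control flow; `StubLLM` is a hypothetical stand-in, not the project's adapter:

import asyncio


class StubLLM:
    # Hypothetical stand-in for LLMProviderAdapter.
    def count_tokens(self, text: str) -> int:
        return len(text.split())  # crude substitute for tiktoken

    async def simplify_text(
        self, title: str, wiki_text: str, prompt_template: str
    ) -> tuple[str, int, int]:
        simplified = wiki_text.strip()
        return simplified, self.count_tokens(wiki_text), self.count_tokens(simplified)


async def simplify_in_chunks(chunks: list[str], llm: StubLLM) -> str:
    simplified: list[str] = []
    for i, chunk in enumerate(chunks):
        try:
            text, _inp, _out = await llm.simplify_text(
                f"Статья (часть {i + 1}/{len(chunks)})", chunk, "{wiki_source_text}"
            )
            simplified.append(text)
        except Exception:
            continue  # mirror the service: skip a failed chunk, keep the rest
    if not simplified:
        raise RuntimeError("no chunk succeeded")
    return "\n\n".join(simplified)


print(asyncio.run(simplify_in_chunks(["Первый кусок.", "Второй кусок."], StubLLM())))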
diff --git a/src/services/text_splitter.py b/src/services/text_splitter.py
new file mode 100644
index 0000000..fd5fa7e
--- /dev/null
+++ b/src/services/text_splitter.py
@@ -0,0 +1,165 @@
+import re
+from collections.abc import Callable
+
+import structlog
+
+logger = structlog.get_logger()
+
+
+class RecursiveCharacterTextSplitter:
+    def __init__(
+        self,
+        chunk_size: int = 2000,
+        chunk_overlap: int = 200,
+        length_function: Callable[[str], int] | None = None,
+        separators: list[str] | None = None,
+    ) -> None:
+        self.chunk_size = chunk_size
+        self.chunk_overlap = chunk_overlap
+        self.length_function = length_function or len
+
+        self.separators = separators or [
+            "\n\n",
+            "\n",
+            ". ",
+            "! ",
+            "? ",
+            "; ",
+            ", ",
+            " ",
+            "",
+        ]
+
+        self.logger = structlog.get_logger().bind(service="text_splitter")
+
+    def split_text(self, text: str) -> list[str]:
+        if not text.strip():
+            return []
+
+        if self.length_function(text) <= self.chunk_size:
+            return [text]
+
+        chunks = self._split_text_recursive(text, self.separators)
+
+        merged_chunks = self._merge_splits(chunks)
+
+        self.logger.debug(
+            "Текст разбит на части",
+            original_length=self.length_function(text),
+            chunks_count=len(merged_chunks),
+            avg_chunk_size=(
+                sum(self.length_function(chunk) for chunk in merged_chunks) / len(merged_chunks)
+                if merged_chunks
+                else 0
+            ),
+        )
+
+        return merged_chunks
+
+    def _split_text_recursive(self, text: str, separators: list[str]) -> list[str]:
+        final_chunks = []
+        separator = separators[-1]
+        new_separators = []
+
+        for i, sep in enumerate(separators):
+            if sep == "":
+                separator = sep
+                break
+            if re.search(re.escape(sep), text):
+                separator = sep
+                new_separators = separators[i + 1 :]
+                break
+
+        splits = self._split_by_separator(text, separator)
+
+        good_splits = []
+        for split in splits:
+            if self.length_function(split) < self.chunk_size:
+                good_splits.append(split)
+            else:
+                if good_splits:
+                    merged_text = self._merge_splits(good_splits)
+                    final_chunks.extend(merged_text)
+                    good_splits = []
+
+                if not new_separators:
+                    final_chunks.extend(self._split_by_length(split))
+                else:
+                    other_info = self._split_text_recursive(split, new_separators)
+                    final_chunks.extend(other_info)
+
+        if good_splits:
+            merged_text = self._merge_splits(good_splits)
+            final_chunks.extend(merged_text)
+
+        return final_chunks
+
+    def _split_by_separator(self, text: str, separator: str) -> list[str]:
+        if separator == "":
+            return list(text)
+
+        return text.split(separator)
+
+    def _split_by_length(self, text: str) -> list[str]:
+        chunks = []
+        start = 0
+
+        while start < len(text):
+            end = start + self.chunk_size
+
+            if end < len(text):
+                for offset in range(min(100, self.chunk_size // 4)):
+                    if end - offset > start and text[end - offset] in " \n\t.,;!?":
+                        end = end - offset + 1
+                        break
+
+            chunk = text[start:end].strip()
+            if chunk:
+                chunks.append(chunk)
+
+            start = max(start + 1, end - self.chunk_overlap)
+
+        return chunks
+
+    def _merge_splits(self, splits: list[str]) -> list[str]:
+        if not splits:
+            return []
+
+        merged_chunks = []
+        current_chunk = ""
+
+        for split in splits:
+            test_chunk = current_chunk
+            if current_chunk and not current_chunk.endswith(("\n", " ")):
+                if not split.startswith(("\n", " ")):
+                    test_chunk += " "
+            test_chunk += split
+
+            if self.length_function(test_chunk) <= self.chunk_size:
+                current_chunk = test_chunk
+            else:
+                if current_chunk.strip():
+                    merged_chunks.append(current_chunk.strip())
+
+                current_chunk = split
+
+        if current_chunk.strip():
+            merged_chunks.append(current_chunk.strip())
+
+        return merged_chunks
+
+    def create_chunks_with_metadata(
+        self, text: str, title: str = ""
+    ) -> list[dict[str, str | int]]:
+        chunks = self.split_text(text)
+
+        return [
+            {
+                "text": chunk,
+                "title": title,
+                "chunk_index": i,
+                "total_chunks": len(chunks),
+                "chunk_size": self.length_function(chunk),
+            }
+            for i, chunk in enumerate(chunks)
+        ]
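A quick usage sketch for the splitter: lengths default to characters, and any callable (such as a token counter) can be passed as `length_function`. The sample text is arbitrary:

from src.services.text_splitter import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=80, chunk_overlap=10)
text = "Первый абзац про изотопы.\n\nВторой абзац про митоз. " * 4
for meta in splitter.create_chunks_with_metadata(text, title="Пример"):
    print(meta["chunk_index"], meta["chunk_size"], meta["text"][:30])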
diff --git a/src/services/write_queue.py b/src/services/write_queue.py
new file mode 100644
index 0000000..38e8bd1
--- /dev/null
+++ b/src/services/write_queue.py
@@ -0,0 +1,191 @@
+from __future__ import annotations
+
+import asyncio
+from dataclasses import dataclass, field
+
+import structlog
+
+from src.models import Article, ProcessingResult
+from src.models.constants import WRITE_QUEUE_BATCH_SIZE
+from src.services.repository import ArticleRepository
+
+
+@dataclass
+class WriteOperation:
+    operation_type: str
+    article: Article | None = None
+    result: ProcessingResult | None = None
+    future: asyncio.Future[Article] | None = field(default=None, init=False)
+
+
+class AsyncWriteQueue:
+    def __init__(
+        self, repository: ArticleRepository, max_batch_size: int = WRITE_QUEUE_BATCH_SIZE
+    ) -> None:
+        self.repository = repository
+        self.max_batch_size = max_batch_size
+        self.logger = structlog.get_logger().bind(service="write_queue")
+
+        self._queue: asyncio.Queue[WriteOperation] = asyncio.Queue()
+        self._worker_task: asyncio.Task[None] | None = None
+        self._shutdown_event = asyncio.Event()
+
+        self._total_operations = 0
+        self._failed_operations = 0
+
+    async def start(self) -> None:
+        if self._worker_task is not None:
+            msg = "Write queue уже запущена"
+            raise RuntimeError(msg)
+
+        self._worker_task = asyncio.create_task(self._worker_loop())
+        self.logger.info("Write queue запущена")
+
+    async def stop(self, timeout: float = 10.0) -> None:
+        if self._worker_task is None:
+            return
+
+        self.logger.info("Остановка write queue")
+        self._shutdown_event.set()
+
+        try:
+            await asyncio.wait_for(self._worker_task, timeout=timeout)
+        except asyncio.TimeoutError:
+            self.logger.warning("Таймаут остановки write queue, принудительная отмена")
+            self._worker_task.cancel()
+
+        self.logger.info("Write queue остановлена")
+
+    async def update_article(self, article: Article) -> None:
+        operation = WriteOperation(
+            operation_type="update",
+            article=article,
+        )
+        await self._queue.put(operation)
+
+    async def update_from_result(self, result: ProcessingResult) -> Article:
+        future: asyncio.Future[Article] = asyncio.Future()
+
+        operation = WriteOperation(
+            operation_type="update_from_result",
+            result=result,
+        )
+        operation.future = future
+
+        await self._queue.put(operation)
+        return await future
+
+    async def _worker_loop(self) -> None:
+        batch: list[WriteOperation] = []
+
+        while not self._shutdown_event.is_set():
+            size_before = len(batch)
+            batch = await self._collect_batch(batch)
+            # Flush when the batch is full, when shutdown is requested, or when
+            # the collect timeout expired without new operations, so that a
+            # single operation is not stuck waiting for a full batch.
+            if batch and (
+                len(batch) >= self.max_batch_size
+                or len(batch) == size_before
+                or self._shutdown_event.is_set()
+            ):
+                await self._process_batch(batch)
+                batch.clear()
+
+        if batch:
+            await self._process_batch(batch)
+
+    async def _collect_batch(self, batch: list[WriteOperation]) -> list[WriteOperation]:
+        try:
+            timeout = 0.1 if batch else 1.0
+            operation = await asyncio.wait_for(self._queue.get(), timeout=timeout)
+            batch.append(operation)
+            return batch
+        except asyncio.TimeoutError:
+            return batch
+        except Exception as e:
+            self.logger.exception("Ошибка в worker loop")
+            self._handle_batch_error(batch, e)
+            return []
+
+    def _handle_batch_error(self, batch: list[WriteOperation], error: Exception) -> None:
+        for op in batch:
+            if op.future and not op.future.done():
+                op.future.set_exception(error)
+
+    async def _process_batch(self, batch: list[WriteOperation]) -> None:
+        if not batch:
+            return
+
+        self.logger.debug("Обработка батча операций", batch_size=len(batch))
+
+        for operation in batch:
+            await self._process_operation_safely(operation)
+
+    async def _process_operation_safely(self, operation: WriteOperation) -> None:
+        try:
+            await self._process_single_operation(operation)
+            self._total_operations += 1
+
+            if operation.future and not operation.future.done():
+                if operation.operation_type == "update_from_result" and operation.result:
+                    article = await self.repository.get_by_url(operation.result.url)
+                    operation.future.set_result(article)
+
+        except Exception as e:
+            self._failed_operations += 1
+            self.logger.exception(
+                "Ошибка при обработке операции",
+                operation_type=operation.operation_type,
+            )
+
+            if operation.future and not operation.future.done():
+                operation.future.set_exception(e)
+
+    async def _process_single_operation(self, operation: WriteOperation) -> None:
+        if operation.operation_type == "update" and operation.article:
+            await self.repository.update_article(operation.article)
+        elif operation.operation_type == "update_from_result" and operation.result:
+            await self._update_article_from_result(operation.result)
+        else:
+            msg = f"Неизвестный тип операции: {operation.operation_type}"
+            raise ValueError(msg)
+
+    async def _update_article_from_result(self, result: ProcessingResult) -> Article:
+        article = await self.repository.get_by_url(result.url)
+        if not article:
+            msg = f"Статья с URL {result.url} не найдена"
+            raise ValueError(msg)
+
+        if result.success:
+            if not (result.title and result.raw_text and result.simplified_text):
+                msg = "Неполные данные в успешном результате"
+                raise ValueError(msg)
+
+            article.mark_completed(
+                simplified_text=result.simplified_text,
+                token_count_raw=result.token_count_raw or 0,
+                token_count_simplified=result.token_count_simplified or 0,
+                processing_time=result.processing_time_seconds or 0,
+            )
+        else:
+            article.mark_failed(result.error_message or "Неизвестная ошибка")
+
+        return await self.repository.update_article(article)
+
+    @property
+    def queue_size(self) -> int:
+        return self._queue.qsize()
+
+    @property
+    def stats(self) -> dict[str, int | float]:
+        return {
+            "total_operations": self._total_operations,
+            "failed_operations": self._failed_operations,
+            "queue_size": self.queue_size,
+            "success_rate": (
+                (self._total_operations - self._failed_operations) / self._total_operations * 100
+                if self._total_operations > 0
+                else 0
+            ),
+        }
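Intended usage of the queue in isolation; the AsyncMock repository here is a stand-in, in the application the repository comes from the dependency container:

import asyncio
from unittest.mock import AsyncMock

from src.models import ProcessingResult
from src.services.write_queue import AsyncWriteQueue


async def main() -> None:
    repository = AsyncMock()  # stand-in for ArticleRepository
    queue = AsyncWriteQueue(repository, max_batch_size=1)
    await queue.start()
    try:
        result = ProcessingResult.failure_result("https://ru.ruwiki.ru/wiki/Пример", "boom")
        await queue.update_from_result(result)  # resolves once the write is flushed
        print(queue.stats)
    finally:
        await queue.stop()


asyncio.run(main())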
diff --git a/src/sources.py b/src/sources.py
new file mode 100644
index 0000000..9e15bd2
--- /dev/null
+++ b/src/sources.py
@@ -0,0 +1,89 @@
+from __future__ import annotations
+
+import asyncio
+from collections.abc import AsyncGenerator
+from pathlib import Path
+from urllib.parse import urlparse
+
+import structlog
+
+from src.models import SimplifyCommand
+from src.models.constants import ARTICLE_NAME_INDEX, MIN_WIKI_PATH_PARTS, WIKI_PATH_INDEX
+
+
+class FileSource:
+    def __init__(self, file_path: str) -> None:
+        self.file_path = Path(file_path)
+        self.logger = structlog.get_logger().bind(source="file", path=str(self.file_path))
+
+    async def read_urls(
+        self, *, force_reprocess: bool = False
+    ) -> AsyncGenerator[SimplifyCommand, None]:
+        if not self.file_path.exists():
+            msg = f"Файл с URL не найден: {self.file_path}"
+            raise FileNotFoundError(msg)
+
+        self.logger.info("Начинаем чтение URL из файла")
+
+        content = await asyncio.to_thread(self._read_file_sync)
+
+        seen_urls = set()
+        valid_count = 0
+        invalid_count = 0
+
+        for line_num, original_line in enumerate(content.splitlines(), 1):
+            line = original_line.strip()
+
+            if not line or line.startswith("#"):
+                continue
+
+            if not self._is_valid_wikipedia_url(line):
+                self.logger.warning("Невалидный URL", line_number=line_num, url=line)
+                invalid_count += 1
+                continue
+
+            if line in seen_urls:
+                self.logger.debug("Дубликат URL пропущен", line_number=line_num, url=line)
+                continue
+
+            seen_urls.add(line)
+            valid_count += 1
+
+            yield SimplifyCommand(url=line, force_reprocess=force_reprocess)
+
+        self.logger.info(
+            "Завершено чтение URL",
+            valid_count=valid_count,
+            invalid_count=invalid_count,
+            total_unique=len(seen_urls),
+        )
+
+    def _read_file_sync(self) -> str:
+        return self.file_path.read_text(encoding="utf-8")
+
+    def _is_valid_wikipedia_url(self, url: str) -> bool:
+        try:
+            parsed = urlparse(url)
+
+            if parsed.scheme not in ("http", "https"):
+                return False
+
+            allowed_hosts = ("ru.wikipedia.org", "ru.ruwiki.ru")
+            if parsed.netloc not in allowed_hosts:
+                return False
+
+            path_parts = parsed.path.split("/")
+            if len(path_parts) < MIN_WIKI_PATH_PARTS or path_parts[WIKI_PATH_INDEX] != "wiki":
+                return False
+
+            article_name = path_parts[ARTICLE_NAME_INDEX]
+            return bool(article_name and article_name not in ("Main_Page", "Заглавная_страница"))
+
+        except Exception:
+            return False
+
+    async def count_urls(self) -> int:
+        count = 0
+        async for _ in self.read_urls():
+            count += 1
+        return count
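Reading commands from a URL file then looks like this (the file name is illustrative):

import asyncio

from src.sources import FileSource


async def main() -> None:
    source = FileSource("input.txt")
    async for command in source.read_urls(force_reprocess=False):
        print(command.url)


asyncio.run(main())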
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..1eb54ef
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,179 @@
+import asyncio
+import tempfile
+from collections.abc import Generator
+from pathlib import Path
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+from openai.types.chat import ChatCompletion, ChatCompletionMessage
+from openai.types.chat.chat_completion import Choice
+
+from src.models import AppConfig, Article, ArticleCreate, ProcessingStatus
+
+
+@pytest.fixture(scope="session")
+def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
+    """Create an event loop for the whole test session."""
+    loop = asyncio.new_event_loop()
+    yield loop
+    loop.close()
+
+
+@pytest.fixture
+def test_config() -> Generator[AppConfig, None, None]:
+    """Test configuration."""
+    # Yield instead of return so the TemporaryDirectory stays alive for the test.
+    with tempfile.TemporaryDirectory() as temp_dir:
+        db_path = Path(temp_dir) / "test.db"
+        yield AppConfig(
+            openai_api_key="test_key",
+            openai_model="gpt-4o-mini",
+            db_path=str(db_path),
+            max_concurrent_llm=2,
+            openai_rpm=10,
+            max_concurrent_wiki=5,
+            prompt_template_path="src/prompt.txt",
+            log_level="DEBUG",
+        )
+
+
+@pytest.fixture
+def sample_wiki_urls() -> list[str]:
+    """A list of test Wikipedia URLs."""
+    return [
+        "https://ru.wikipedia.org/wiki/Тест",
+        "https://ru.wikipedia.org/wiki/Пример",
+        "https://ru.wikipedia.org/wiki/Образец",
+    ]
+
+
+@pytest.fixture
+def invalid_urls() -> list[str]:
+    """A list of invalid URLs."""
+    return [
+        "https://example.com/invalid",
+        "https://en.wikipedia.org/wiki/English",
+        "not_a_url",
+        "",
+        "https://ru.wikipedia.org/wiki/",
+    ]
+
+
+@pytest.fixture
+def sample_wikitext() -> str:
+    return """'''Тест''' — это проверка чего-либо.
+
+== Определение ==
+Тест может проводиться для различных целей:
+* Проверка знаний
+* Проверка работоспособности
+* Проверка качества
+
+== История ==
+Тесты использовались с древних времён.
+
+{{навигация|тема=Тестирование}}
+
+[[Категория:Тестирование]]"""
+
+
+@pytest.fixture
+def simplified_text() -> str:
+    return """'''Тест''' — это проверка чего-либо для школьников.
+
+== Что такое тест ==
+Тест помогает проверить:
+* Знания учеников
+* Как работают устройства
+* Качество продуктов
+
+== Когда появились тесты ==
+Люди проверяли друг друга очень давно.
+
+###END###"""
+
+
+@pytest.fixture
+def sample_article_data() -> ArticleCreate:
+    return ArticleCreate(
+        url="https://ru.wikipedia.org/wiki/Тест",
+        title="Тест",
+        raw_text="Тестовый wiki-текст",
+    )
+
+
+@pytest.fixture
+def sample_article(sample_article_data: ArticleCreate) -> Article:
+    return Article(
+        id=1,
+        url=sample_article_data.url,
+        title=sample_article_data.title,
+        raw_text=sample_article_data.raw_text,
+        status=ProcessingStatus.PENDING,
+    )
+
+
+@pytest.fixture
+def completed_article(sample_article: Article, simplified_text: str) -> Article:
+    article = sample_article.model_copy()
+    article.mark_completed(
+        simplified_text=simplified_text,
+        token_count_raw=100,
+        token_count_simplified=50,
+        processing_time=2.5,
+    )
+    return article
+
+
+@pytest.fixture
+def mock_openai_response() -> ChatCompletion:
+    return ChatCompletion(
+        id="test_completion",
+        object="chat.completion",
+        created=1234567890,
+        model="gpt-4o-mini",
+        choices=[
+            Choice(
+                index=0,
+                message=ChatCompletionMessage(
+                    role="assistant",
+                    content="Упрощённый текст для школьников.\n\n###END###",
+                ),
+                finish_reason="stop",
+            )
+        ],
+        usage=None,
+    )
+
+
+@pytest.fixture
+def temp_input_file(sample_wiki_urls: list[str]) -> Generator[str, None, None]:
+    with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
+        for url in sample_wiki_urls:
+            f.write(f"{url}\n")
+        f.write("# Комментарий\n")
+        f.write("\n")
+        f.write("https://ru.wikipedia.org/wiki/Дубликат\n")
+        f.write("https://ru.wikipedia.org/wiki/Дубликат\n")
+        temp_path = f.name
+
+    yield temp_path
+
+    Path(temp_path).unlink(missing_ok=True)
+
+
+@pytest.fixture
+async def mock_wiki_client() -> AsyncMock:
+    mock_client = AsyncMock()
+    mock_page = MagicMock()
+    mock_page.exists = True
+    mock_page.redirect = False
+    mock_page.text.return_value = "Тестовый wiki-текст"
+    mock_client.pages = {"Тест": mock_page}
+    return mock_client
+
+
+@pytest.fixture
+async def mock_openai_client() -> AsyncMock:
+    mock_client = AsyncMock()
+    return mock_client
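The simplified_text fixture ends with the ###END### marker apparently used as an end-of-output sentinel for the model; the adapter tests below assert the marker is stripped from the final text. A tiny illustration of that contract (this helper is illustrative only, the real stripping lives in the adapter):

def strip_end_marker(text: str, marker: str = "###END###") -> str:
    # Illustrative: mirrors the behaviour asserted in the adapter tests.
    return text.replace(marker, "").strip()


assert "###END###" not in strip_end_marker("Упрощённый текст.\n\n###END###")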
diff --git a/tests/test_adapters.py b/tests/test_adapters.py
new file mode 100644
index 0000000..a7de050
--- /dev/null
+++ b/tests/test_adapters.py
@@ -0,0 +1,284 @@
+import asyncio
+import time
+from unittest.mock import AsyncMock, patch
+
+import httpx
+import pytest
+from openai import APIError, RateLimitError
+
+from src.adapters import (
+    CircuitBreaker,
+    CircuitBreakerError,
+    LLMProviderAdapter,
+    LLMRateLimitError,
+    LLMTokenLimitError,
+    RateLimiter,
+    RuWikiAdapter,
+)
+
+
+class TestCircuitBreaker:
+
+    @pytest.mark.asyncio
+    async def test_successful_call(self):
+        cb = CircuitBreaker(failure_threshold=3, recovery_timeout=1)
+
+        async def test_func():
+            return "success"
+
+        result = await cb.call(test_func)
+        assert result == "success"
+
+    @pytest.mark.asyncio
+    async def test_failure_accumulation(self):
+        cb = CircuitBreaker(failure_threshold=2, recovery_timeout=1)
+
+        async def failing_func():
+            raise ValueError("Test error")
+
+        with pytest.raises(ValueError):
+            await cb.call(failing_func)
+
+        with pytest.raises(ValueError):
+            await cb.call(failing_func)
+
+        with pytest.raises(CircuitBreakerError):
+            await cb.call(failing_func)
+
+    @pytest.mark.asyncio
+    async def test_recovery(self):
+        cb = CircuitBreaker(failure_threshold=1, recovery_timeout=0.1)
+
+        async def failing_func():
+            raise ValueError("Test error")
+
+        async def success_func():
+            return "recovered"
+
+        with pytest.raises(ValueError):
+            await cb.call(failing_func)
+
+        with pytest.raises(CircuitBreakerError):
+            await cb.call(failing_func)
+
+        await asyncio.sleep(0.2)
+
+        result = await cb.call(success_func)
+        assert result == "recovered"
+
+
+class TestRateLimiter:
+
+    @pytest.mark.asyncio
+    async def test_concurrency_limit(self):
+        limiter = RateLimiter(max_concurrent=2)
+        results = []
+
+        async def test_task(task_id: int):
+            async with limiter:
+                results.append(f"start_{task_id}")
+                await asyncio.sleep(0.1)
+                results.append(f"end_{task_id}")
+
+        tasks = [test_task(i) for i in range(3)]
+        await asyncio.gather(*tasks)
+
+        start_count = 0
+        max_concurrent = 0
+
+        for result in results:
+            if result.startswith("start_"):
+                start_count += 1
+                max_concurrent = max(max_concurrent, start_count)
+            elif result.startswith("end_"):
+                start_count -= 1
+
+        assert max_concurrent <= 2
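The two primitives tested above compose naturally: the limiter bounds concurrency while the breaker short-circuits a failing dependency. A sketch under those assumptions (`flaky_call` is hypothetical):

import asyncio

from src.adapters import CircuitBreaker, CircuitBreakerError, RateLimiter

breaker = CircuitBreaker(failure_threshold=3, recovery_timeout=5)
limiter = RateLimiter(max_concurrent=2)


async def flaky_call() -> str:
    return "ok"


async def guarded() -> str:
    async with limiter:  # at most two concurrent callers
        try:
            return await breaker.call(flaky_call)
        except CircuitBreakerError:
            return "circuit open, skipping call"


print(asyncio.run(guarded()))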
"_get_client") as mock_get_client: + mock_get_client.side_effect = ConnectionError("Network error") + + result = await adapter.health_check() + assert result is False + + +class TestLLMProviderAdapter: + + def test_count_tokens(self, test_config): + """Тест подсчёта токенов.""" + adapter = LLMProviderAdapter(test_config) + + count = adapter.count_tokens("Hello world") + assert count > 0 + + count = adapter.count_tokens("") + assert count == 0 + + @pytest.mark.asyncio + async def test_rpm_limiting(self, test_config): + test_config.openai_rpm = 2 + adapter = LLMProviderAdapter(test_config) + + current_time = time.time() + adapter.request_times = [current_time - 10, current_time - 5] + + start_time = time.time() + await adapter._check_rpm_limit() + elapsed = time.time() - start_time + + assert elapsed > 0.01 + + @pytest.mark.asyncio + async def test_simplify_text_token_limit_error(self, test_config): + adapter = LLMProviderAdapter(test_config) + + long_text = "word " * 2000 + + with pytest.raises(LLMTokenLimitError): + await adapter.simplify_text("Test", long_text, "template") + + @pytest.mark.asyncio + async def test_simplify_text_success(self, test_config, mock_openai_response): + adapter = LLMProviderAdapter(test_config) + + with patch.object(adapter.client.chat.completions, "create") as mock_create: + mock_create.return_value = mock_openai_response + + with patch.object(adapter, "_check_rpm_limit"): + result = await adapter.simplify_text( + title="Тест", + wiki_text="Тестовый текст", + prompt_template="### role: user\n{wiki_source_text}", + ) + + simplified_text, input_tokens, output_tokens = result + + assert "Упрощённый текст для школьников" in simplified_text + assert "###END###" not in simplified_text + assert input_tokens > 0 + assert output_tokens > 0 + + @pytest.mark.asyncio + async def test_simplify_text_openai_error(self, test_config): + adapter = LLMProviderAdapter(test_config) + + with patch.object(adapter.client.chat.completions, "create") as mock_create: + mock_create.side_effect = RateLimitError( + "Rate limit exceeded", response=None, body=None + ) + + with patch.object(adapter, "_check_rpm_limit"): + with pytest.raises(LLMRateLimitError): + await adapter.simplify_text( + title="Тест", + wiki_text="Тестовый текст", + prompt_template="### role: user\n{wiki_source_text}", + ) + + def test_parse_prompt_template(self, test_config): + adapter = LLMProviderAdapter(test_config) + + template = """### role: system +Ты помощник. + +### role: user +Задание: {task}""" + + messages = adapter._parse_prompt_template(template) + + assert len(messages) == 2 + assert messages[0]["role"] == "system" + assert messages[0]["content"] == "Ты помощник." 
diff --git a/tests/test_integration.py b/tests/test_integration.py
new file mode 100644
index 0000000..bc459dc
--- /dev/null
+++ b/tests/test_integration.py
@@ -0,0 +1,311 @@
+import asyncio
+import tempfile
+from pathlib import Path
+from unittest.mock import AsyncMock, patch
+
+import pytest
+
+from src.dependency_injection import DependencyContainer
+from src.models import ProcessingStatus
+from src.sources import FileSource
+
+
+class TestFileSourceIntegration:
+
+    @pytest.mark.asyncio
+    async def test_read_urls_from_file(self, temp_input_file):
+        source = FileSource(temp_input_file)
+
+        commands = []
+        async for command in source.read_urls():
+            commands.append(command)
+
+        assert len(commands) >= 3
+
+        for command in commands:
+            assert command.url.startswith("https://ru.wikipedia.org/wiki/")
+            assert command.force_reprocess is False
+
+    @pytest.mark.asyncio
+    async def test_count_urls(self, temp_input_file):
+        source = FileSource(temp_input_file)
+
+        count = await source.count_urls()
+        assert count >= 3
+
+    @pytest.mark.asyncio
+    async def test_file_not_found(self):
+        source = FileSource("nonexistent.txt")
+
+        with pytest.raises(FileNotFoundError):
+            async for _ in source.read_urls():
+                pass
+
+
+class TestDatabaseIntegration:
+
+    @pytest.mark.asyncio
+    async def test_full_article_lifecycle(self, test_config, sample_article_data):
+        container = DependencyContainer(test_config)
+
+        try:
+            await container.initialize()
+
+            repository = container.get_repository()
+
+            article = await repository.create_article(sample_article_data)
+            assert article.id is not None
+            assert article.status == ProcessingStatus.PENDING
+
+            found_article = await repository.get_by_url(sample_article_data.url)
+            assert found_article is not None
+            assert found_article.id == article.id
+
+            article.mark_processing()
+            updated_article = await repository.update_article(article)
+            assert updated_article.status == ProcessingStatus.PROCESSING
+
+            article.mark_completed(
+                simplified_text="Упрощённый текст",
+                token_count_raw=100,
+                token_count_simplified=50,
+                processing_time=2.5,
+            )
+            final_article = await repository.update_article(article)
+            assert final_article.status == ProcessingStatus.COMPLETED
+            assert final_article.simplified_text == "Упрощённый текст"
+
+            completed_count = await repository.count_by_status(ProcessingStatus.COMPLETED)
+            assert completed_count == 1
+
+        finally:
+            await container.cleanup()
+
+    @pytest.mark.asyncio
+    async def test_write_queue_integration(self, test_config, sample_article_data):
+        container = DependencyContainer(test_config)
+
+        try:
+            await container.initialize()
+
+            repository = container.get_repository()
+            write_queue = container.get_write_queue()
+
+            article = await repository.create_article(sample_article_data)
+
+            from src.models import ProcessingResult
+
+            result = ProcessingResult.success_result(
+                url=article.url,
+                title=article.title,
+                raw_text=article.raw_text,
+                simplified_text="Упрощённый текст",
+                token_count_raw=100,
+                token_count_simplified=50,
+                processing_time_seconds=2.0,
+            )
+
+            updated_article = await write_queue.update_from_result(result)
+
+            assert updated_article.status == ProcessingStatus.COMPLETED
+            assert updated_article.simplified_text == "Упрощённый текст"
+
+        finally:
+            await container.cleanup()
+
+
+class TestSystemIntegration:
+
+    @pytest.mark.asyncio
+    async def test_dependency_container_initialization(self, test_config):
+        container = DependencyContainer(test_config)
+
+        try:
+            await container.initialize()
+
+            db_service = container.get_database_service()
+            repository = container.get_repository()
+            write_queue = container.get_write_queue()
+            ruwiki_adapter = container.get_ruwiki_adapter()
+            llm_adapter = container.get_llm_adapter()
+            simplify_service = container.get_simplify_service()
+
+            assert db_service is not None
+            assert repository is not None
+            assert write_queue is not None
+            assert ruwiki_adapter is not None
+            assert llm_adapter is not None
+            assert simplify_service is not None
+
+            checks = await container.health_check()
+            assert "database" in checks
+            assert "write_queue" in checks
+
+        finally:
+            await container.cleanup()
+
+    @pytest.mark.asyncio
+    async def test_runner_with_mocked_adapters(self, test_config, temp_input_file):
+        container = DependencyContainer(test_config)
+
+        try:
+            await container.initialize()
+
+            with (
+                patch.object(container, "get_ruwiki_adapter") as mock_ruwiki,
+                patch.object(container, "get_llm_adapter") as mock_llm,
+            ):
+
+                mock_ruwiki_instance = AsyncMock()
+                mock_ruwiki.return_value = mock_ruwiki_instance
+
+                from src.adapters.ruwiki import WikiPageInfo
+
+                mock_ruwiki_instance.fetch_page_cleaned.return_value = WikiPageInfo(
+                    title="Тест",
+                    content="Тестовый контент",
+                )
+
+                mock_llm_instance = AsyncMock()
+                mock_llm.return_value = mock_llm_instance
+                mock_llm_instance.simplify_text.return_value = ("Упрощённый текст", 100, 50)
+                mock_llm_instance.count_tokens.return_value = 100
+
+                with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
+                    f.write("### role: user\n{wiki_source_text}")
+                    test_config.prompt_template_path = f.name
+
+                try:
+                    runner = container.create_runner(max_workers=2)
+
+                    stats = await runner.run_from_file(
+                        input_file=temp_input_file,
+                        max_articles=2,
+                    )
+
+                    assert stats.total_processed >= 1
+                    assert stats.successful >= 0
+
+                finally:
+                    Path(test_config.prompt_template_path).unlink(missing_ok=True)
+
+        finally:
+            await container.cleanup()
+
+    @pytest.mark.asyncio
+    async def test_error_handling_in_runner(self, test_config, temp_input_file):
+        container = DependencyContainer(test_config)
+
+        try:
+            await container.initialize()
+
+            with patch.object(container, "get_ruwiki_adapter") as mock_ruwiki:
+                mock_ruwiki_instance = AsyncMock()
+                mock_ruwiki.return_value = mock_ruwiki_instance
+
+                from src.adapters import WikiPageNotFoundError
+
+                mock_ruwiki_instance.fetch_page_cleaned.side_effect = WikiPageNotFoundError(
+                    "Страница не найдена"
+                )
+
+                runner = container.create_runner(max_workers=1)
+
+                stats = await runner.run_from_file(
+                    input_file=temp_input_file,
+                    max_articles=1,
+                )
+
+                assert stats.total_processed >= 1
+                assert stats.failed >= 1
+                assert stats.success_rate < 100.0
+
+        finally:
+            await container.cleanup()
+
+    @pytest.mark.asyncio
+    async def test_concurrent_processing(self, test_config, temp_input_file):
+        container = DependencyContainer(test_config)
+
+        try:
+            await container.initialize()
+
+            with (
+                patch.object(container, "get_ruwiki_adapter") as mock_ruwiki,
+                patch.object(container, "get_llm_adapter") as mock_llm,
+            ):
+
+                async def delayed_fetch(*args, **kwargs):
+                    await asyncio.sleep(0.1)
+                    from src.adapters.ruwiki import WikiPageInfo
+
+                    return WikiPageInfo(title="Тест", content="Контент")
+
+                async def delayed_simplify(*args, **kwargs):
+                    await asyncio.sleep(0.1)
+                    return ("Упрощённый", 100, 50)
+
+                mock_ruwiki_instance = AsyncMock()
+                mock_ruwiki.return_value = mock_ruwiki_instance
+                mock_ruwiki_instance.fetch_page_cleaned.side_effect = delayed_fetch
+
+                mock_llm_instance = AsyncMock()
+                mock_llm.return_value = mock_llm_instance
+                mock_llm_instance.simplify_text.side_effect = delayed_simplify
+                mock_llm_instance.count_tokens.return_value = 100
+
+                with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
+                    f.write("### role: user\n{wiki_source_text}")
+                    test_config.prompt_template_path = f.name
+
+                try:
+                    import time
+
+                    start_time = time.time()
+
+                    runner = container.create_runner(max_workers=3)
+                    stats = await runner.run_from_file(
+                        input_file=temp_input_file,
+                        max_articles=3,
+                    )
+
+                    elapsed_time = time.time() - start_time
+
+                    assert elapsed_time < 1.0
+                    assert stats.total_processed >= 1
+
+                finally:
+                    Path(test_config.prompt_template_path).unlink(missing_ok=True)
+
+        finally:
+            await container.cleanup()
+
+    @pytest.mark.asyncio
+    async def test_health_check_integration(self, test_config):
+        container = DependencyContainer(test_config)
+
+        try:
+            await container.initialize()
+
+            with (
+                patch.object(container, "get_ruwiki_adapter") as mock_ruwiki,
+                patch.object(container, "get_llm_adapter") as mock_llm,
+            ):
+
+                mock_ruwiki_instance = AsyncMock()
+                mock_ruwiki.return_value = mock_ruwiki_instance
+                mock_ruwiki_instance.health_check.return_value = True
+
+                mock_llm_instance = AsyncMock()
+                mock_llm.return_value = mock_llm_instance
+                mock_llm_instance.health_check.return_value = True
+
+                checks = await container.health_check()
+
+                assert checks["database"] is True
+                assert checks["write_queue"] is True
+                assert checks["ruwiki"] is True
+                assert checks["llm"] is True
+
+        finally:
+            await container.cleanup()
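Stripped of mocks, the container-driven sequence these integration tests exercise looks roughly like this (error handling omitted; the API key and input file name are placeholders):

import asyncio

from src.dependency_injection import DependencyContainer
from src.models import AppConfig


async def main() -> None:
    config = AppConfig(openai_api_key="your_key_here", db_path="./data/wiki.db")
    container = DependencyContainer(config)
    await container.initialize()
    try:
        runner = container.create_runner(max_workers=3)
        stats = await runner.run_from_file(input_file="input.txt", max_articles=10)
        print(stats.success_rate)
    finally:
        await container.cleanup()


asyncio.run(main())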
diff --git a/tests/test_models.py b/tests/test_models.py
new file mode 100644
index 0000000..c8d9d17
--- /dev/null
+++ b/tests/test_models.py
@@ -0,0 +1,263 @@
+from datetime import datetime
+
+import pytest
+
+from src.models import (
+    AppConfig,
+    Article,
+    ProcessingResult,
+    ProcessingStats,
+    ProcessingStatus,
+    SimplifyCommand,
+)
+
+
+class TestAppConfig:
+
+    def test_default_values(self):
+        with pytest.raises(ValueError):
+            AppConfig()
+
+    def test_valid_config(self):
+        config = AppConfig(
+            openai_api_key="test_key",
+            db_path="./test.db",
+        )
+
+        assert config.openai_api_key == "test_key"
+        assert config.openai_model == "gpt-4o-mini"
+        assert config.openai_temperature == 0.0
+        assert config.max_concurrent_llm == 5
+        assert config.openai_rpm == 200
+
+    def test_db_url_generation(self):
+        config = AppConfig(
+            openai_api_key="test_key",
+            db_path="./test.db",
+        )
+
+        assert config.db_url == "sqlite+aiosqlite:///test.db"
+        assert config.sync_db_url == "sqlite:///test.db"
+
+    def test_validation_constraints(self):
+        with pytest.raises(ValueError):
+            AppConfig(
+                openai_api_key="test_key",
+                openai_temperature=3.0,
+            )
+
+        with pytest.raises(ValueError):
+            AppConfig(
+                openai_api_key="test_key",
+                max_concurrent_llm=100,
+            )
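Constructing the config directly, as the tests above do; the SQLite URLs are derived from db_path:

from src.models import AppConfig

config = AppConfig(openai_api_key="test_key", db_path="./test.db")
assert config.openai_model == "gpt-4o-mini"  # default, per test_valid_config
assert config.db_url == "sqlite+aiosqlite:///test.db"
assert config.sync_db_url == "sqlite:///test.db"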
+
+
+class TestArticle:
+
+    def test_article_creation(self, sample_article_data):
+        article = Article(
+            url=sample_article_data.url,
+            title=sample_article_data.title,
+            raw_text=sample_article_data.raw_text,
+        )
+
+        assert article.url == sample_article_data.url
+        assert article.title == sample_article_data.title
+        assert article.status == ProcessingStatus.PENDING
+        assert article.simplified_text is None
+        assert isinstance(article.created_at, datetime)
+
+    def test_mark_processing(self, sample_article):
+        article = sample_article
+        original_updated = article.updated_at
+
+        article.mark_processing()
+
+        assert article.status == ProcessingStatus.PROCESSING
+        assert article.updated_at != original_updated
+
+    def test_mark_completed(self, sample_article):
+        article = sample_article
+        simplified_text = "Упрощённый текст"
+
+        article.mark_completed(
+            simplified_text=simplified_text,
+            token_count_raw=100,
+            token_count_simplified=50,
+            processing_time=2.5,
+        )
+
+        assert article.status == ProcessingStatus.COMPLETED
+        assert article.simplified_text == simplified_text
+        assert article.token_count_raw == 100
+        assert article.token_count_simplified == 50
+        assert article.processing_time_seconds == 2.5
+        assert article.error_message is None
+        assert article.updated_at is not None
+
+    def test_mark_failed(self, sample_article):
+        article = sample_article
+        error_message = "Тестовая ошибка"
+
+        article.mark_failed(error_message)
+
+        assert article.status == ProcessingStatus.FAILED
+        assert article.error_message == error_message
+        assert article.updated_at is not None
+
+    def test_mark_failed_long_error(self, sample_article):
+        article = sample_article
+        long_error = "x" * 1500
+
+        article.mark_failed(long_error)
+
+        assert len(article.error_message) == 1000
+        assert article.error_message == "x" * 1000
+
+
+class TestSimplifyCommand:
+
+    def test_command_creation(self):
+        url = "https://ru.wikipedia.org/wiki/Тест"
+        command = SimplifyCommand(url=url)
+
+        assert command.url == url
+        assert command.force_reprocess is False
+
+    def test_command_with_force(self):
+        url = "https://ru.wikipedia.org/wiki/Тест"
+        command = SimplifyCommand(url=url, force_reprocess=True)
+
+        assert command.url == url
+        assert command.force_reprocess is True
+
+    def test_command_string_representation(self):
+        url = "https://ru.wikipedia.org/wiki/Тест"
+        command = SimplifyCommand(url=url, force_reprocess=True)
+
+        expected = f"SimplifyCommand(url='{url}', force=True)"
+        assert str(command) == expected
+
+
+class TestProcessingResult:
+
+    def test_success_result_creation(self):
+        result = ProcessingResult.success_result(
+            url="https://ru.wikipedia.org/wiki/Тест",
+            title="Тест",
+            raw_text="Исходный текст",
+            simplified_text="Упрощённый текст",
+            token_count_raw=100,
+            token_count_simplified=50,
+            processing_time_seconds=2.5,
+        )
+
+        assert result.success is True
+        assert result.url == "https://ru.wikipedia.org/wiki/Тест"
+        assert result.title == "Тест"
+        assert result.raw_text == "Исходный текст"
+        assert result.simplified_text == "Упрощённый текст"
+        assert result.token_count_raw == 100
+        assert result.token_count_simplified == 50
+        assert result.processing_time_seconds == 2.5
+        assert result.error_message is None
+
+    def test_failure_result_creation(self):
+        result = ProcessingResult.failure_result(
+            url="https://ru.wikipedia.org/wiki/Тест",
+            error_message="Тестовая ошибка",
+        )
+
+        assert result.success is False
+        assert result.url == "https://ru.wikipedia.org/wiki/Тест"
+        assert result.error_message == "Тестовая ошибка"
+        assert result.title is None
+        assert result.raw_text is None
+        assert result.simplified_text is None
+
+
+class TestProcessingStats:
+
+    def test_initial_stats(self):
+        stats = ProcessingStats()
+
+        assert stats.total_processed == 0
+        assert stats.successful == 0
+        assert stats.failed == 0
+        assert stats.skipped == 0
+        assert stats.success_rate == 0.0
+        assert stats.average_processing_time == 0.0
+
+    def test_add_successful_result(self):
+        stats = ProcessingStats()
+        result = ProcessingResult.success_result(
+            url="test",
+            title="Test",
+            raw_text="text",
+            simplified_text="simple",
+            token_count_raw=100,
+            token_count_simplified=50,
+            processing_time_seconds=2.0,
+        )
+
+        stats.add_result(result)
+
+        assert stats.total_processed == 1
+        assert stats.successful == 1
+        assert stats.failed == 0
+        assert stats.success_rate == 100.0
+        assert stats.average_processing_time == 2.0
+
+    def test_add_failed_result(self):
+        stats = ProcessingStats()
+        result = ProcessingResult.failure_result("test", "error")
+
+        stats.add_result(result)
+
+        assert stats.total_processed == 1
+        assert stats.successful == 0
+        assert stats.failed == 1
+        assert stats.success_rate == 0.0
+
+    def test_mixed_results(self):
+        stats = ProcessingStats()
+
+        success_result = ProcessingResult.success_result(
+            url="test1",
+            title="Test1",
+            raw_text="text",
+            simplified_text="simple",
+            token_count_raw=100,
+            token_count_simplified=50,
+            processing_time_seconds=3.0,
+        )
+        stats.add_result(success_result)
+
+        failure_result = ProcessingResult.failure_result("test2", "error")
+        stats.add_result(failure_result)
+
+        success_result2 = ProcessingResult.success_result(
+            url="test3",
+            title="Test3",
+            raw_text="text",
+            simplified_text="simple",
+            token_count_raw=100,
+            token_count_simplified=50,
+            processing_time_seconds=1.0,
+        )
+        stats.add_result(success_result2)
+
+        assert stats.total_processed == 3
+        assert stats.successful == 2
+        assert stats.failed == 1
+        assert stats.success_rate == pytest.approx(66.67, rel=1e-2)
+        assert stats.average_processing_time == 2.0
+
+    def test_add_skipped(self):
+        stats = ProcessingStats()
+
+        stats.add_skipped()
+        stats.add_skipped()
+
+        assert stats.skipped == 2
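The stats aggregation exercised above, as one runnable snippet (expected output noted inline):

from src.models import ProcessingResult, ProcessingStats

stats = ProcessingStats()
stats.add_result(
    ProcessingResult.success_result(
        url="test",
        title="Test",
        raw_text="text",
        simplified_text="simple",
        token_count_raw=100,
        token_count_simplified=50,
        processing_time_seconds=2.0,
    )
)
stats.add_result(ProcessingResult.failure_result("test2", "error"))
stats.add_skipped()
print(stats.success_rate, stats.average_processing_time, stats.skipped)  # 50.0 2.0 1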
diff --git a/tests/test_services.py b/tests/test_services.py
new file mode 100644
index 0000000..b0e9c0d
--- /dev/null
+++ b/tests/test_services.py
@@ -0,0 +1,370 @@
+"""Tests for services."""
+
+import tempfile
+from pathlib import Path
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+
+from src.adapters import LLMProviderAdapter, RuWikiAdapter
+from src.adapters.ruwiki import WikiPageInfo
+from src.models import ProcessingResult, SimplifyCommand
+from src.services import (
+    AsyncWriteQueue,
+    DatabaseService,
+    RecursiveCharacterTextSplitter,
+    SimplifyService,
+)
+
+
+class TestRecursiveCharacterTextSplitter:
+    def test_split_short_text(self):
+        splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)
+
+        short_text = "Это короткий текст."
+        chunks = splitter.split_text(short_text)
+
+        assert len(chunks) == 1
+        assert chunks[0] == short_text
+
+    def test_split_long_text(self):
+        splitter = RecursiveCharacterTextSplitter(chunk_size=50, chunk_overlap=10)
+
+        long_text = "Это очень длинный текст. " * 10
+        chunks = splitter.split_text(long_text)
+
+        assert len(chunks) > 1
+
+        for chunk in chunks:
+            assert len(chunk) <= 60
+
+    def test_split_by_paragraphs(self):
+        splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=10)
+
+        text = "Первый абзац.\n\nВторой абзац.\n\nТретий абзац."
+        chunks = splitter.split_text(text)
+
+        assert len(chunks) >= 2
+
+    def test_split_empty_text(self):
+        splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)
+
+        chunks = splitter.split_text("")
+        assert chunks == []
+
+    def test_custom_length_function(self):
+        def word_count(text: str) -> int:
+            return len(text.split())
+
+        splitter = RecursiveCharacterTextSplitter(
+            chunk_size=5,
+            chunk_overlap=2,
+            length_function=word_count,
+        )
+
+        text = "Один два три четыре пять шесть семь восемь девять десять"
+        chunks = splitter.split_text(text)
+
+        assert len(chunks) > 1
+
+        for chunk in chunks:
+            word_count_in_chunk = len(chunk.split())
+            assert word_count_in_chunk <= 7
+
+    def test_create_chunks_with_metadata(self):
+        splitter = RecursiveCharacterTextSplitter(chunk_size=50, chunk_overlap=10)
+
+        text = "Это тестовый текст. " * 10
+        title = "Тестовая статья"
+
+        chunks_with_metadata = splitter.create_chunks_with_metadata(text, title)
+
+        assert len(chunks_with_metadata) > 1
+
+        for i, chunk_data in enumerate(chunks_with_metadata):
+            assert "text" in chunk_data
+            assert chunk_data["title"] == title
+            assert chunk_data["chunk_index"] == i
+            assert chunk_data["total_chunks"] == len(chunks_with_metadata)
+            assert "chunk_size" in chunk_data
+
+
+class TestDatabaseService:
+
+    @pytest.mark.asyncio
+    async def test_initialize_database(self, test_config):
+        db_service = DatabaseService(test_config)
+
+        await db_service.initialize_database()
+
+        assert Path(test_config.db_path).exists()
+
+        assert await db_service.health_check() is True
+
+        db_service.close()
+
+    @pytest.mark.asyncio
+    async def test_get_connection(self, test_config):
+        db_service = DatabaseService(test_config)
+        await db_service.initialize_database()
+
+        async with db_service.get_connection() as conn:
+            cursor = await conn.execute("SELECT 1")
+            result = await cursor.fetchone()
+            assert result[0] == 1
+
+        db_service.close()
+
+
+class TestAsyncWriteQueue:
+
+    @pytest.mark.asyncio
+    async def test_start_stop(self):
+        mock_repository = AsyncMock()
+        queue = AsyncWriteQueue(mock_repository, max_batch_size=5)
+
+        await queue.start()
+        assert queue._worker_task is not None
+
+        await queue.stop(timeout=1.0)
+        assert queue._worker_task.done()
+
+    @pytest.mark.asyncio
+    async def test_update_from_result_success(self, sample_article, simplified_text):
+        mock_repository = AsyncMock()
+        mock_repository.get_by_url.return_value = sample_article
+        mock_repository.update_article.return_value = sample_article
+
+        queue = AsyncWriteQueue(mock_repository, max_batch_size=1)
+        await queue.start()
+
+        try:
+            result = ProcessingResult.success_result(
+                url=sample_article.url,
+                title=sample_article.title,
+                raw_text=sample_article.raw_text,
+                simplified_text=simplified_text,
+                token_count_raw=100,
+                token_count_simplified=50,
+                processing_time_seconds=2.0,
+            )
+
+            updated_article = await queue.update_from_result(result)
+
+            assert updated_article.simplified_text == simplified_text
+            mock_repository.get_by_url.assert_called_once_with(sample_article.url)
+            mock_repository.update_article.assert_called_once()
+
+        finally:
+            await queue.stop(timeout=1.0)
+
+    @pytest.mark.asyncio
+    async def test_update_from_result_failure(self, sample_article):
+        mock_repository = AsyncMock()
+        mock_repository.get_by_url.return_value = sample_article
+        mock_repository.update_article.return_value = sample_article
+
+        queue = AsyncWriteQueue(mock_repository, max_batch_size=1)
+        await queue.start()
+
+        try:
+            result = ProcessingResult.failure_result(
+                url=sample_article.url,
+                error_message="Тестовая ошибка",
+            )
+
+            updated_article = await queue.update_from_result(result)
+
+            assert updated_article.error_message == "Тестовая ошибка"
+            mock_repository.update_article.assert_called_once()
+
+        finally:
+            await queue.stop(timeout=1.0)
+
+    def test_stats(self):
+        mock_repository = AsyncMock()
+        queue = AsyncWriteQueue(mock_repository)
+
+        stats = queue.stats
+
+        assert "total_operations" in stats
+        assert "failed_operations" in stats
+        assert "queue_size" in stats
+        assert stats["total_operations"] == 0
+
+
+class TestSimplifyService:
+
+    @pytest.fixture
+    def mock_adapters_and_queue(self, test_config):
+        mock_ruwiki = AsyncMock(spec=RuWikiAdapter)
+        mock_llm = AsyncMock(spec=LLMProviderAdapter)
+        mock_repository = AsyncMock()
+        mock_write_queue = AsyncMock()
+
+        return mock_ruwiki, mock_llm, mock_repository, mock_write_queue
+
+    def test_service_initialization(self, test_config, mock_adapters_and_queue):
+        mock_ruwiki, mock_llm, mock_repository, mock_write_queue = mock_adapters_and_queue
+
+        service = SimplifyService(
+            config=test_config,
+            ruwiki_adapter=mock_ruwiki,
+            llm_adapter=mock_llm,
+            repository=mock_repository,
+            write_queue=mock_write_queue,
+        )
+
+        assert service.config == test_config
+        assert service.ruwiki_adapter == mock_ruwiki
+        assert service.llm_adapter == mock_llm
+        assert isinstance(service.text_splitter, RecursiveCharacterTextSplitter)
+
+    @pytest.mark.asyncio
+    async def test_get_prompt_template(self, test_config, mock_adapters_and_queue):
+        mock_ruwiki, mock_llm, mock_repository, mock_write_queue = mock_adapters_and_queue
+
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
+            f.write("### role: system\nТы помощник")
+            temp_prompt_path = f.name
+
+        test_config.prompt_template_path = temp_prompt_path
+
+        service = SimplifyService(
+            config=test_config,
+            ruwiki_adapter=mock_ruwiki,
+            llm_adapter=mock_llm,
+            repository=mock_repository,
+            write_queue=mock_write_queue,
+        )
+
+        try:
+            template = await service.get_prompt_template()
+            assert "### role: system" in template
+            assert "Ты помощник" in template
+
+            template2 = await service.get_prompt_template()
+            assert template == template2
+
+        finally:
+            Path(temp_prompt_path).unlink(missing_ok=True)
+
+    @pytest.mark.asyncio
+    async def test_get_prompt_template_not_found(self, test_config, mock_adapters_and_queue):
+        mock_ruwiki, mock_llm, mock_repository, mock_write_queue = mock_adapters_and_queue
+
+        test_config.prompt_template_path = "nonexistent.txt"
+
+        service = SimplifyService(
+            config=test_config,
+            ruwiki_adapter=mock_ruwiki,
+            llm_adapter=mock_llm,
+            repository=mock_repository,
+            write_queue=mock_write_queue,
+        )
+
+        with pytest.raises(FileNotFoundError):
+            await service.get_prompt_template()
+
+    @pytest.mark.asyncio
+    async def test_process_command_success(
+        self, test_config, mock_adapters_and_queue, sample_wikitext, simplified_text
+    ):
+        mock_ruwiki, mock_llm, mock_repository, mock_write_queue = mock_adapters_and_queue
+
+        wiki_page_info = WikiPageInfo(
+            title="Тест",
+            content=sample_wikitext,
+        )
+        mock_ruwiki.fetch_page_cleaned.return_value = wiki_page_info
+        mock_llm.simplify_text.return_value = (simplified_text, 100, 50)
+        mock_llm.count_tokens.return_value = 100
+
+        mock_repository.get_by_url.return_value = None
+        mock_repository.create_article.return_value = MagicMock(id=1)
+        mock_repository.update_article.return_value = MagicMock()
+
+        mock_write_queue.update_from_result.return_value = MagicMock()
+
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
+            f.write("### role: user\n{wiki_source_text}")
+            test_config.prompt_template_path = f.name
+
+        service = SimplifyService(
+            config=test_config,
+            ruwiki_adapter=mock_ruwiki,
+            llm_adapter=mock_llm,
+            repository=mock_repository,
+            write_queue=mock_write_queue,
+        )
+
+        try:
+            command = SimplifyCommand(url="https://ru.wikipedia.org/wiki/Тест")
+            result = await service.process_command(command)
+
+            assert result.success is True
+            assert result.title == "Тест"
+            assert result.simplified_text == simplified_text
+            assert result.token_count_raw == 100
+            assert result.token_count_simplified == 50
+
+            mock_ruwiki.fetch_page_cleaned.assert_called_once()
+            mock_llm.simplify_text.assert_called_once()
+            mock_write_queue.update_from_result.assert_called_once()
+
+        finally:
+            Path(test_config.prompt_template_path).unlink(missing_ok=True)
+
+    @pytest.mark.asyncio
+    async def test_process_command_skip_existing(
+        self, test_config, mock_adapters_and_queue, completed_article
+    ):
+        mock_ruwiki, mock_llm, mock_repository, mock_write_queue = mock_adapters_and_queue
+
+        mock_repository.get_by_url.return_value = completed_article
+
+        service = SimplifyService(
+            config=test_config,
+            ruwiki_adapter=mock_ruwiki,
+            llm_adapter=mock_llm,
+            repository=mock_repository,
+            write_queue=mock_write_queue,
+        )
+
+        command = SimplifyCommand(url=completed_article.url, force_reprocess=False)
+        result = await service.process_command(command)
+
+        assert result.success is True
+        assert result.title == completed_article.title
+
+        mock_ruwiki.fetch_page_cleaned.assert_not_called()
+        mock_llm.simplify_text.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_health_check(self, test_config, mock_adapters_and_queue):
+        mock_ruwiki, mock_llm, mock_repository, mock_write_queue = mock_adapters_and_queue
+
+        mock_ruwiki.health_check.return_value = True
+        mock_llm.health_check.return_value = True
+
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
+            f.write("test prompt")
+            test_config.prompt_template_path = f.name
+
+        service = SimplifyService(
+            config=test_config,
+            ruwiki_adapter=mock_ruwiki,
+            llm_adapter=mock_llm,
+            repository=mock_repository,
+            write_queue=mock_write_queue,
+        )
+
+        try:
+            checks = await service.health_check()
+
+            assert checks["ruwiki"] is True
+            assert checks["llm"] is True
+            assert checks["prompt_template"] is True
+
+        finally:
+            Path(test_config.prompt_template_path).unlink(missing_ok=True)