ruwiki-test/tests/conftest.py

191 lines
5.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import tempfile
from collections.abc import Generator
from datetime import datetime, timezone
from pathlib import Path
from typing import AsyncGenerator
from unittest.mock import MagicMock
import pytest
import pytest_asyncio
import structlog
import logging
from src.models import AppConfig
from src.models.article_dto import ArticleDTO, ArticleStatus
from src.services import ArticleRepository, DatabaseService
def level_to_int(logger, method_name, event_dict):
    """Structlog processor: replace a string ``level`` entry with its numeric value.

    A level name such as ``"debug"`` or ``"WARNING"`` is resolved
    case-insensitively against the constants of the stdlib :mod:`logging`
    module. Names that do not resolve to an ``int`` are left unchanged. This
    covers unknown names and non-level module attributes such as
    ``logging.BASIC_FORMAT``. Non-string values and event dicts without a
    ``level`` key are also left untouched.

    Args:
        logger: The wrapped logger (unused; required by the processor signature).
        method_name: Name of the logging method that was called (unused).
        event_dict: The structlog event dictionary; modified in place.

    Returns:
        The same ``event_dict`` object.
    """
    level = event_dict.get("level")
    if isinstance(level, str):
        numeric = getattr(logging, level.upper(), None)
        # Only accept real level numbers: a bare getattr would also pick up
        # other upper-case module attributes (e.g. the BASIC_FORMAT string).
        if isinstance(numeric, int):
            event_dict["level"] = numeric
    return event_dict
@pytest.fixture(autouse=True, scope="session")
def configure_structlog():
    """Configure logging once for the whole test session.

    - Sets the stdlib root logger to DEBUG via ``logging.basicConfig``.
    - Installs a structlog processor pipeline: ``level_to_int``, then an ISO
      timestamp, then a console renderer.
    - Uses a bound-logger class that filters below DEBUG.
    - Assigns a structlog logger to the ``tenacity.logger`` module attribute.
    """
    import tenacity

    debug = logging.DEBUG
    logging.basicConfig(level=debug)

    pipeline = [
        level_to_int,
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.dev.ConsoleRenderer(),
    ]
    structlog.configure(
        processors=pipeline,
        wrapper_class=structlog.make_filtering_bound_logger(debug),
    )

    # NOTE(review): this only rebinds a module attribute; whether tenacity
    # ever reads ``tenacity.logger`` is not visible here — confirm it has
    # any effect.
    tenacity.logger = structlog.get_logger("tenacity")
@pytest.fixture(autouse=True, scope="session")
def patch_tenacity_before_sleep_log():
    """Let ``tenacity.before_sleep_log`` accept level names as well as numbers.

    For the duration of the session, ``tenacity.before_sleep_log`` is replaced
    by a wrapper. The wrapper converts a string ``log_level`` (e.g.
    ``"warning"``, case-insensitive) to the matching :mod:`logging` constant,
    then delegates to the original with all other arguments unchanged.
    Unrecognised names fall back to ``logging.WARNING``. The original
    function is restored on teardown.

    NOTE(review): only code that looks up ``tenacity.before_sleep_log`` after
    this fixture runs sees the wrapper. These call sites keep the original
    function:

    - a module that did ``from tenacity import before_sleep_log`` at import
      time;
    - a ``before_sleep`` callback built while modules were being imported
      during collection.
    """
    import tenacity

    original_before_sleep_log = tenacity.before_sleep_log

    def patched_before_sleep_log(logger, log_level, *args, **kwargs):
        if isinstance(log_level, str):
            resolved = getattr(logging, log_level.upper(), None)
            # Guard against non-level attributes such as logging.BASIC_FORMAT.
            log_level = resolved if isinstance(resolved, int) else logging.WARNING
        # Forward extra arguments (e.g. ``exc_info``) instead of dropping them.
        return original_before_sleep_log(logger, log_level, *args, **kwargs)

    tenacity.before_sleep_log = patched_before_sleep_log
    yield
    # Undo the global monkeypatch so it does not outlive the session.
    tenacity.before_sleep_log = original_before_sleep_log
@pytest.fixture(scope="session")
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
    """Provide one event loop shared by every async test and fixture in the session.

    On teardown, the loop first finalizes any still-open async generators and
    shuts down its default executor, then closes. This is the same cleanup
    ``asyncio.run`` performs.

    NOTE(review): recent pytest-asyncio releases discourage overriding
    ``event_loop`` in favour of loop-scope configuration — check the
    installed version.
    """
    loop = asyncio.new_event_loop()
    yield loop
    if not loop.is_closed():
        try:
            loop.run_until_complete(loop.shutdown_asyncgens())
            loop.run_until_complete(loop.shutdown_default_executor())
        finally:
            loop.close()
@pytest.fixture
def test_config() -> Generator[AppConfig, None, None]:
    """Yield an ``AppConfig`` for tests whose ``db_path`` lies in a fresh temp directory.

    The previous version returned from inside a ``TemporaryDirectory``
    context, which deleted the directory before the test ever ran. Here the
    directory stays alive for the whole test and is removed on teardown.

    NOTE(review): ``prompt_template_path`` is relative, so it presumably
    assumes pytest is run from the project root — confirm.
    """
    import shutil

    temp_dir = tempfile.mkdtemp()
    try:
        yield AppConfig(
            openai_api_key="test_key",
            openai_model="gpt-4o-mini",
            db_path=str(Path(temp_dir) / "test.db"),
            max_concurrent_llm=2,
            openai_rpm=10,
            max_concurrent_wiki=5,
            prompt_template_path="src/prompt.txt",
            log_level="DEBUG",
        )
    finally:
        # Best-effort cleanup: a database file that is still open (e.g. on
        # Windows) must not turn into a teardown error.
        shutil.rmtree(temp_dir, ignore_errors=True)
@pytest.fixture
def mock_openai_response():
    """Return a ``MagicMock`` shaped like an OpenAI chat-completion response.

    The mock exposes:

    - ``choices[0].message.content``: a fixed sample text;
    - ``usage.prompt_tokens`` (100) and ``usage.completion_tokens`` (50).

    The mock itself is awaitable: ``await mock_response`` evaluates to the
    same mock object.
    """
    mock_response = MagicMock()
    mock_response.choices = [MagicMock()]
    mock_response.choices[0].message.content = "Упрощённый текст для школьников"
    mock_response.usage.prompt_tokens = 100
    mock_response.usage.completion_tokens = 50

    async def _resolve(obj):
        return obj

    # Why the original ``mock_response.__await__ = ...`` did nothing:
    # - ``await`` looks ``__await__`` up on the type, not the instance;
    # - MagicMock does not treat ``__await__`` as a configurable magic method.
    # (``iter([mock])`` would also have yielded a non-future to the loop.)
    # unittest.mock gives each mock instance its own class, so setting the
    # method there affects only this mock.
    type(mock_response).__await__ = lambda self: _resolve(self).__await__()
    return mock_response
@pytest_asyncio.fixture
async def database_service(test_config: AppConfig) -> AsyncGenerator[DatabaseService, None]:
    """Yield a ``DatabaseService`` built from ``test_config``.

    ``initialize_database()`` is awaited before the service is handed to the
    test. Nothing runs after the test finishes.
    """
    db = DatabaseService(test_config)
    await db.initialize_database()
    # NOTE(review): there is no teardown — if DatabaseService holds
    # connections that need closing, close them after this yield.
    yield db
@pytest_asyncio.fixture
async def repository(database_service: DatabaseService) -> AsyncGenerator[ArticleRepository, None]:
    """Yield an ``ArticleRepository`` wrapping the ``database_service`` fixture."""
    yield ArticleRepository(database_service)
@pytest.fixture
def sample_wiki_urls() -> list[str]:
    """Return three ruwiki article URLs with Cyrillic titles, as a fresh list."""
    urls = (
        "https://ru.ruwiki.ru/wiki/Тест",
        "https://ru.ruwiki.ru/wiki/Пример",
        "https://ru.ruwiki.ru/wiki/Образец",
    )
    return list(urls)
@pytest.fixture
def sample_wikitext() -> str:
    """Return a short article in MediaWiki markup.

    The text has a bold title, two ``==`` section headings and a bulleted list.
    """
    wikitext = """'''Тест''' — это проверка чего-либо.
== Определение ==
Тест может проводиться для различных целей:
* Проверка знаний
* Проверка работоспособности
* Проверка качества
== История ==
Тесты использовались с древних времён."""
    return wikitext
@pytest.fixture
def simplified_text() -> str:
    """Return a plain-text, markup-free sample of a simplified article."""
    simplified = """Тест — это проверка чего-либо для школьников.
Что такое тест
Тест помогает проверить:
* Знания учеников
* Как работают устройства
* Качество продуктов
Когда появились тесты
Люди проверяли друг друга очень давно."""
    return simplified
@pytest.fixture
def sample_article_dto() -> ArticleDTO:
    """Return an in-memory ``ArticleDTO`` for the "Тест" article.

    The DTO has ``PENDING`` status and ``created_at`` set to the current,
    timezone-aware UTC time.
    """
    now_utc = datetime.now(timezone.utc)
    return ArticleDTO(
        url="https://ru.ruwiki.ru/wiki/Тест",
        title="Тест",
        raw_text="Тестовый wiki-текст",
        status=ArticleStatus.PENDING,
        created_at=now_utc,
    )
@pytest_asyncio.fixture
async def sample_article_in_db(
    repository: ArticleRepository, sample_article_dto: ArticleDTO
) -> AsyncGenerator[ArticleDTO, None]:
    """Store the sample article via ``repository.create_article`` and yield the result.

    Only the DTO's ``url``, ``title`` and ``raw_text`` are passed to
    ``create_article``; its status and timestamp are not. The fixture yields
    whatever ``create_article`` returns.
    """
    dto = sample_article_dto
    yield await repository.create_article(
        url=dto.url,
        title=dto.title,
        raw_text=dto.raw_text,
    )
@pytest.fixture
def temp_input_file(sample_wiki_urls: list[str]) -> Generator[str, None, None]:
    """Write the sample URLs, one per line, to a UTF-8 ``.txt`` temp file and yield its path.

    The file is closed before its path is yielded and deleted on teardown.
    """
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".txt", delete=False, encoding="utf-8"
    ) as handle:
        handle.writelines(f"{url}\n" for url in sample_wiki_urls)
        path = handle.name
    yield path
    Path(path).unlink(missing_ok=True)
@pytest_asyncio.fixture
async def multiple_articles_in_db(
    repository: ArticleRepository, sample_wiki_urls: list[str]
) -> AsyncGenerator[list[ArticleDTO], None]:
    """Create one article per sample URL and yield the results as a list.

    Articles are created one at a time, in URL order, and numbered from 1 in
    their titles and contents.
    """
    created = [
        await repository.create_article(
            url=url,
            title=f"Test Article {i+1}",
            raw_text=f"Content for article {i+1}",
        )
        for i, url in enumerate(sample_wiki_urls)
    ]
    yield created