ruwiki-test/tests/test_adapters.py

303 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import time
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from openai import APIError, RateLimitError
from src.adapters import (
CircuitBreaker,
CircuitBreakerError,
LLMProviderAdapter,
LLMRateLimitError,
LLMTokenLimitError,
RateLimiter,
RuWikiAdapter,
)
class TestCircuitBreaker:
    """Tests for ``CircuitBreaker.call``: pass-through, tripping, recovery."""

    @pytest.mark.asyncio
    async def test_successful_call(self):
        """A succeeding coroutine's return value is handed back unchanged."""
        breaker = CircuitBreaker(failure_threshold=3, recovery_timeout=1)

        async def always_ok():
            return "success"

        assert await breaker.call(always_ok) == "success"

    @pytest.mark.asyncio
    async def test_failure_accumulation(self):
        """After ``failure_threshold`` failures the breaker rejects calls."""
        breaker = CircuitBreaker(failure_threshold=2, recovery_timeout=1)

        async def always_fails():
            raise ValueError("Test error")

        # The first two calls are expected to surface the original error.
        for _ in range(2):
            with pytest.raises(ValueError):
                await breaker.call(always_fails)

        # The third call is expected to be refused by the breaker itself.
        with pytest.raises(CircuitBreakerError):
            await breaker.call(always_fails)

    @pytest.mark.asyncio
    async def test_recovery(self):
        """Calls succeed again once ``recovery_timeout`` has elapsed."""
        breaker = CircuitBreaker(failure_threshold=1, recovery_timeout=0.1)

        async def always_fails():
            raise ValueError("Test error")

        async def always_ok():
            return "recovered"

        # One failure reaches the threshold of 1 ...
        with pytest.raises(ValueError):
            await breaker.call(always_fails)
        # ... so the very next call must be rejected.
        with pytest.raises(CircuitBreakerError):
            await breaker.call(always_fails)

        # Sleep longer than the 0.1 s recovery timeout before retrying.
        await asyncio.sleep(0.2)
        assert await breaker.call(always_ok) == "recovered"
class TestRateLimiter:
    """Tests for ``RateLimiter`` used as an async context manager."""

    @pytest.mark.asyncio
    async def test_concurrency_limit(self):
        """No more than ``max_concurrent`` tasks are inside the limiter at once."""
        limiter = RateLimiter(max_concurrent=2)
        events = []

        async def worker(task_id: int):
            # Record when the task enters and leaves the guarded section.
            async with limiter:
                events.append(f"start_{task_id}")
                await asyncio.sleep(0.1)
                events.append(f"end_{task_id}")

        await asyncio.gather(*(worker(i) for i in range(3)))

        # Replay the event log: +1 on every start, -1 on every end, and
        # remember the highest number of simultaneously running tasks.
        running = 0
        peak = 0
        for event in events:
            running += 1 if event.startswith("start_") else -1
            peak = max(peak, running)

        assert peak <= 2
class TestRuWikiAdapter:
    """Tests for ``RuWikiAdapter``: URL parsing, wikitext cleanup, health check."""

    def test_extract_title_from_url(self):
        """Titles are taken from /wiki/ URLs, with underscores and %-escapes decoded."""
        cases = [
            ("https://ru.ruwiki.ru/wiki/Тест", "Тест"),
            ("https://ru.ruwiki.ru/wiki/Тест_статья", "Тест статья"),
            ("https://ru.ruwiki.ru/wiki/%D0%A2%D0%B5%D1%81%D1%82", "Тест"),
        ]
        # Called on the class itself, no instance is created here.
        for url, expected_title in cases:
            assert RuWikiAdapter.extract_title_from_url(url) == expected_title

    def test_extract_title_invalid_url(self):
        """URLs that are not ruwiki article links raise ``ValueError``."""
        bad_urls = (
            "https://example.com/invalid",
            "https://ru.ruwiki.ru/invalid",
        )
        for url in bad_urls:
            with pytest.raises(ValueError):
                RuWikiAdapter.extract_title_from_url(url)

    def test_clean_wikitext(self, test_config, sample_wikitext):
        """Navigation templates and categories go; bold text and headings stay."""
        cleaned = RuWikiAdapter(test_config)._clean_wikitext(sample_wikitext)

        for removed in ("{{навигация", "[[Категория:"):
            assert removed not in cleaned
        for kept in ("'''Тест'''", "== Определение =="):
            assert kept in cleaned

    @pytest.mark.asyncio
    async def test_health_check_success(self, test_config):
        """``health_check`` returns True when the mocked API call yields data."""
        adapter = RuWikiAdapter(test_config)

        with patch.object(adapter, "_get_client") as get_client_mock, patch(
            "asyncio.to_thread"
        ) as to_thread_mock:
            get_client_mock.return_value = AsyncMock()
            to_thread_mock.return_value = {"query": {"general": {}}}
            healthy = await adapter.health_check()

        assert healthy is True

    @pytest.mark.asyncio
    async def test_health_check_failure(self, test_config):
        """``health_check`` returns False when obtaining the client fails."""
        adapter = RuWikiAdapter(test_config)

        with patch.object(
            adapter, "_get_client", side_effect=ConnectionError("Network error")
        ):
            healthy = await adapter.health_check()

        assert healthy is False
class TestLLMProviderAdapter:
    """Tests for ``LLMProviderAdapter``.

    Covers token counting, RPM throttling, ``simplify_text`` success and error
    paths, prompt-template parsing and the health check. Every OpenAI call is
    replaced by an ``AsyncMock`` on ``adapter.client.chat.completions.create``.
    """

    def test_count_tokens(self, test_config):
        """Non-empty text yields a positive token count; empty text yields 0."""
        adapter = LLMProviderAdapter(test_config)

        assert adapter.count_tokens("Hello world") > 0
        assert adapter.count_tokens("") == 0

    @pytest.mark.asyncio
    async def test_rpm_limiting(self, test_config):
        """``_check_rpm_limit`` takes measurable time once the RPM budget is used."""
        test_config.openai_rpm = 2
        adapter = LLMProviderAdapter(test_config)

        # Pre-fill the request log with as many recent entries as the limit
        # allows. These stay on time.time() because that is the clock the
        # original test used to build them.
        # NOTE(review): depending on how long _check_rpm_limit waits for the
        # oldest entry to expire, this test may be slow -- confirm.
        now = time.time()
        adapter.request_times = [now - 10, now - 5]

        # Measure the wait with a monotonic, high-resolution clock so a
        # wall-clock adjustment or coarse timer tick cannot skew the interval.
        started = time.perf_counter()
        await adapter._check_rpm_limit()
        elapsed = time.perf_counter() - started

        assert elapsed > 0.01

    @pytest.mark.asyncio
    async def test_simplify_text_token_limit_error(self, test_config):
        """``LLMTokenLimitError`` from the completion request reaches the caller."""
        adapter = LLMProviderAdapter(test_config)
        long_text = "word " * 2000

        with patch.object(adapter, "_check_rpm_limit"), patch.object(
            adapter, "count_tokens", return_value=50000
        ), patch.object(
            adapter,
            "_make_completion_request",
            side_effect=LLMTokenLimitError("Token limit exceeded"),
        ):
            with pytest.raises(LLMTokenLimitError):
                await adapter.simplify_text("Test", long_text, "template")

    @pytest.mark.asyncio
    async def test_simplify_text_success(self, test_config, mock_openai_response):
        """A successful call returns ``(text, input_tokens, output_tokens)``."""
        adapter = LLMProviderAdapter(test_config)

        with patch.object(
            adapter.client.chat.completions, "create", new_callable=AsyncMock
        ) as mock_create, patch.object(adapter, "_check_rpm_limit"):
            mock_create.return_value = mock_openai_response
            result = await adapter.simplify_text(
                title="Тест",
                wiki_text="Тестовый текст",
                prompt_template="### role: user\n{wiki_source_text}",
            )

        simplified_text, input_tokens, output_tokens = result
        assert "Упрощённый текст для школьников" in simplified_text
        # The end-of-answer marker must not appear in the returned text.
        assert "###END###" not in simplified_text
        assert input_tokens > 0
        assert output_tokens > 0

    @pytest.mark.asyncio
    async def test_simplify_text_openai_error(self, test_config):
        """An OpenAI ``RateLimitError`` is surfaced as ``LLMRateLimitError``."""
        from tenacity import AsyncRetrying, before_sleep_log
        import structlog
        import logging

        adapter = LLMProviderAdapter(test_config)
        good_logger = structlog.get_logger("tenacity")

        def fixed_before_sleep_log(logger, level):
            # Accept a level name as well as a numeric level before handing
            # it to tenacity's before_sleep_log.
            if isinstance(level, str):
                level = getattr(logging, level.upper(), logging.WARNING)
            return before_sleep_log(logger, level)

        with patch("src.adapters.base.AsyncRetrying") as mock_retrying:
            # Build a real AsyncRetrying from the adapter's own keyword
            # arguments, overriding only the before_sleep hook.
            mock_retrying.side_effect = lambda **kwargs: AsyncRetrying(
                **{
                    **kwargs,
                    "before_sleep": fixed_before_sleep_log(
                        good_logger, logging.WARNING
                    ),
                }
            )
            with patch.object(
                adapter.client.chat.completions, "create", new_callable=AsyncMock
            ) as mock_create:
                mock_response = MagicMock()
                mock_create.side_effect = RateLimitError(
                    "Rate limit exceeded", response=mock_response, body=None
                )
                with patch.object(adapter, "_check_rpm_limit"):
                    with pytest.raises(LLMRateLimitError):
                        await adapter.simplify_text(
                            title="Тест",
                            wiki_text="Тестовый текст",
                            prompt_template="### role: user\n{wiki_source_text}",
                        )

    def test_parse_prompt_template(self, test_config):
        """``### role:`` headers split a template into role/content messages."""
        adapter = LLMProviderAdapter(test_config)
        template = (
            "### role: system\n"
            "Ты помощник.\n"
            "### role: user\n"
            "Задание: {task}"
        )

        messages = adapter._parse_prompt_template(template)

        assert len(messages) == 2
        assert messages[0]["role"] == "system"
        assert messages[0]["content"] == "Ты помощник."
        assert messages[1]["role"] == "user"
        assert messages[1]["content"] == "Задание: {task}"

    def test_parse_prompt_template_fallback(self, test_config):
        """A template without role headers becomes one ``user`` message."""
        adapter = LLMProviderAdapter(test_config)
        template = "Обычный текст без ролей"

        messages = adapter._parse_prompt_template(template)

        assert len(messages) == 1
        assert messages[0]["role"] == "user"
        assert messages[0]["content"] == template

    @pytest.mark.asyncio
    async def test_health_check_success(self, test_config, mock_openai_response):
        """``health_check`` returns True when the completion call succeeds."""
        adapter = LLMProviderAdapter(test_config)

        with patch.object(
            adapter.client.chat.completions, "create", new_callable=AsyncMock
        ) as mock_create:
            mock_create.return_value = mock_openai_response
            result = await adapter.health_check()

        assert result is True

    @pytest.mark.asyncio
    async def test_health_check_failure(self, test_config):
        """``health_check`` returns False when the completion call raises ``APIError``."""
        adapter = LLMProviderAdapter(test_config)

        with patch.object(
            adapter.client.chat.completions, "create", new_callable=AsyncMock
        ) as mock_create:
            mock_request = MagicMock()
            mock_create.side_effect = APIError(
                "API Error", body=None, request=mock_request
            )
            result = await adapter.health_check()

        assert result is False