# Tests for src.adapters: CircuitBreaker, RateLimiter, RuWikiAdapter, LLMProviderAdapter.
# (Original listing header: 303 lines, 10 KiB, Python.)
import asyncio
|
||
import time
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
import pytest
|
||
from openai import APIError, RateLimitError
|
||
|
||
from src.adapters import (
|
||
CircuitBreaker,
|
||
CircuitBreakerError,
|
||
LLMProviderAdapter,
|
||
LLMRateLimitError,
|
||
LLMTokenLimitError,
|
||
RateLimiter,
|
||
RuWikiAdapter,
|
||
)
|
||
|
||
|
||
class TestCircuitBreaker:
    """Behavioural tests for ``CircuitBreaker``: pass-through, tripping, recovery."""

    @pytest.mark.asyncio
    async def test_successful_call(self):
        """A coroutine that succeeds has its return value passed straight through."""
        breaker = CircuitBreaker(failure_threshold=3, recovery_timeout=1)

        async def succeed():
            return "success"

        outcome = await breaker.call(succeed)
        assert outcome == "success"

    @pytest.mark.asyncio
    async def test_failure_accumulation(self):
        """Once ``failure_threshold`` failures are recorded, calls are rejected
        with ``CircuitBreakerError`` instead of the underlying error."""
        breaker = CircuitBreaker(failure_threshold=2, recovery_timeout=1)

        async def explode():
            raise ValueError("Test error")

        # Below the threshold the wrapped function's own exception propagates.
        for _attempt in range(2):
            with pytest.raises(ValueError):
                await breaker.call(explode)

        # With the threshold reached, the breaker raises its own error.
        with pytest.raises(CircuitBreakerError):
            await breaker.call(explode)

    @pytest.mark.asyncio
    async def test_recovery(self):
        """After ``recovery_timeout`` has elapsed, a successful call goes through."""
        breaker = CircuitBreaker(failure_threshold=1, recovery_timeout=0.1)

        async def explode():
            raise ValueError("Test error")

        async def recover():
            return "recovered"

        # threshold=1: one failure is enough to make the next call rejected.
        with pytest.raises(ValueError):
            await breaker.call(explode)
        with pytest.raises(CircuitBreakerError):
            await breaker.call(explode)

        # Sleep past the 0.1 s recovery timeout before probing again.
        await asyncio.sleep(0.2)

        outcome = await breaker.call(recover)
        assert outcome == "recovered"


class TestRateLimiter:
    """Tests for ``RateLimiter`` used as an ``async with`` concurrency guard."""

    @pytest.mark.asyncio
    async def test_concurrency_limit(self):
        """Three tasks through a ``max_concurrent=2`` limiter: the limit is
        never exceeded, and it is actually reached (two tasks overlap)."""
        limiter = RateLimiter(max_concurrent=2)
        events = []

        async def worker(task_id: int):
            async with limiter:
                events.append(f"start_{task_id}")
                # Hold the slot long enough for the other tasks to contend.
                await asyncio.sleep(0.1)
                events.append(f"end_{task_id}")

        await asyncio.gather(*(worker(i) for i in range(3)))

        # Replay the event log to find the peak number of tasks that were
        # inside the limiter at the same time.
        in_flight = 0
        peak = 0
        for event in events:
            if event.startswith("start_"):
                in_flight += 1
                peak = max(peak, in_flight)
            elif event.startswith("end_"):
                in_flight -= 1

        assert peak <= 2, f"limiter allowed {peak} concurrent tasks (limit is 2)"
        # Without this check, a limiter that serialises every task (peak == 1)
        # would pass unnoticed.
        assert peak == 2, "limiter never let two tasks run concurrently"


class TestRuWikiAdapter:
    """Tests for ``RuWikiAdapter``: title extraction, wikitext cleanup, health checks."""

    def test_extract_title_from_url(self):
        """The title comes from the ``/wiki/`` path segment: underscores become
        spaces and percent-encoded UTF-8 is decoded."""
        # Invoked on the class itself, not on an instance.
        extract_title = RuWikiAdapter.extract_title_from_url

        expectations = (
            ("https://ru.ruwiki.ru/wiki/Тест", "Тест"),
            ("https://ru.ruwiki.ru/wiki/Тест_статья", "Тест статья"),
            ("https://ru.ruwiki.ru/wiki/%D0%A2%D0%B5%D1%81%D1%82", "Тест"),
        )
        for url, expected_title in expectations:
            title = extract_title(url)
            assert title == expected_title

    def test_extract_title_invalid_url(self):
        """URLs on a foreign host or without a ``/wiki/`` path raise ``ValueError``."""
        extract_title = RuWikiAdapter.extract_title_from_url

        for bad_url in ("https://example.com/invalid", "https://ru.ruwiki.ru/invalid"):
            with pytest.raises(ValueError):
                extract_title(bad_url)

    def test_clean_wikitext(self, test_config, sample_wikitext):
        """Navigation templates and category links must be removed; bold markup
        and section headings must survive."""
        cleaned = RuWikiAdapter(test_config)._clean_wikitext(sample_wikitext)

        for removed in ("{{навигация", "[[Категория:"):
            assert removed not in cleaned

        for kept in ("'''Тест'''", "== Определение =="):
            assert kept in cleaned

    @pytest.mark.asyncio
    async def test_health_check_success(self, test_config):
        """``health_check`` returns True when the API query yields a response."""
        adapter = RuWikiAdapter(test_config)

        with patch.object(adapter, "_get_client") as get_client, patch(
            "asyncio.to_thread"
        ) as to_thread:
            get_client.return_value = AsyncMock()
            # Canned response for whatever the adapter runs via asyncio.to_thread.
            to_thread.return_value = {"query": {"general": {}}}

            healthy = await adapter.health_check()
            assert healthy is True

    @pytest.mark.asyncio
    async def test_health_check_failure(self, test_config):
        """``health_check`` returns False (rather than raising) when obtaining
        the client fails."""
        adapter = RuWikiAdapter(test_config)

        with patch.object(
            adapter, "_get_client", side_effect=ConnectionError("Network error")
        ):
            healthy = await adapter.health_check()
            assert healthy is False


class TestLLMProviderAdapter:
    """Tests for ``LLMProviderAdapter``: token counting, RPM limiting,
    ``simplify_text`` success/error paths, prompt parsing and health checks."""

    def test_count_tokens(self, test_config):
        """Non-empty text has a positive token count; empty text has zero."""
        adapter = LLMProviderAdapter(test_config)

        count = adapter.count_tokens("Hello world")
        assert count > 0

        count = adapter.count_tokens("")
        assert count == 0

    @pytest.mark.asyncio
    async def test_rpm_limiting(self, test_config):
        """With the RPM budget already used up, ``_check_rpm_limit`` must wait."""
        # NOTE(review): mutates the fixture; assumes ``test_config`` is
        # function-scoped — confirm in conftest.
        test_config.openai_rpm = 2
        adapter = LLMProviderAdapter(test_config)

        # Two requests in the last few seconds exhaust a 2-RPM budget. The seeds
        # stay wall-clock (time.time()) values, as the adapter's request log
        # presumably stores them that way.
        now = time.time()
        adapter.request_times = [now - 10, now - 5]

        # NOTE(review): if the limiter sleeps until the oldest request leaves a
        # 60 s window, this blocks for roughly 50 s; consider patching the sleep.
        # Elapsed time is measured on the monotonic clock, which cannot jump.
        started = time.monotonic()
        await adapter._check_rpm_limit()
        elapsed = time.monotonic() - started

        assert elapsed > 0.01

    @pytest.mark.asyncio
    async def test_simplify_text_token_limit_error(self, test_config):
        """A token-limit failure from the completion request surfaces to the
        caller as ``LLMTokenLimitError``."""
        adapter = LLMProviderAdapter(test_config)
        long_text = "word " * 2000

        # _check_rpm_limit is awaited by the adapter, so patch it with an
        # explicit AsyncMock rather than relying on autodetection.
        with patch.object(adapter, "_check_rpm_limit", new_callable=AsyncMock):
            with patch.object(adapter, "count_tokens", return_value=50000):
                with patch.object(
                    adapter,
                    "_make_completion_request",
                    side_effect=LLMTokenLimitError("Token limit exceeded"),
                ):
                    with pytest.raises(LLMTokenLimitError):
                        await adapter.simplify_text("Test", long_text, "template")

    @pytest.mark.asyncio
    async def test_simplify_text_success(self, test_config, mock_openai_response):
        """A successful completion yields ``(text, input_tokens, output_tokens)``
        with the ``###END###`` marker absent from the text."""
        adapter = LLMProviderAdapter(test_config)

        with patch.object(
            adapter.client.chat.completions, "create", new_callable=AsyncMock
        ) as mock_create:
            mock_create.return_value = mock_openai_response

            with patch.object(adapter, "_check_rpm_limit", new_callable=AsyncMock):
                result = await adapter.simplify_text(
                    title="Тест",
                    wiki_text="Тестовый текст",
                    prompt_template="### role: user\n{wiki_source_text}",
                )

        simplified_text, input_tokens, output_tokens = result

        assert "Упрощённый текст для школьников" in simplified_text
        assert "###END###" not in simplified_text
        assert input_tokens > 0
        assert output_tokens > 0

    @pytest.mark.asyncio
    async def test_simplify_text_openai_error(self, test_config):
        """An OpenAI ``RateLimitError`` surfaces to the caller as ``LLMRateLimitError``."""
        import logging

        import structlog
        from tenacity import AsyncRetrying, before_sleep_log

        adapter = LLMProviderAdapter(test_config)

        # Every AsyncRetrying built in src.adapters.base gets tenacity's stock
        # before_sleep hook, logging at a numeric WARNING level.
        # NOTE(review): this works around the hook configured in
        # src.adapters.base (presumably it passes a string log level, which
        # before_sleep_log does not accept). Fix it there and drop this patch.
        before_sleep = before_sleep_log(
            structlog.get_logger("tenacity"), logging.WARNING
        )

        with patch("src.adapters.base.AsyncRetrying") as mock_retrying:
            mock_retrying.side_effect = lambda **kwargs: AsyncRetrying(
                **{**kwargs, "before_sleep": before_sleep}
            )

            with patch.object(
                adapter.client.chat.completions, "create", new_callable=AsyncMock
            ) as mock_create:
                mock_create.side_effect = RateLimitError(
                    "Rate limit exceeded", response=MagicMock(), body=None
                )
                with patch.object(adapter, "_check_rpm_limit", new_callable=AsyncMock):
                    with pytest.raises(LLMRateLimitError):
                        await adapter.simplify_text(
                            title="Тест",
                            wiki_text="Тестовый текст",
                            prompt_template="### role: user\n{wiki_source_text}",
                        )

    def test_parse_prompt_template(self, test_config):
        """``### role: <name>`` headers split the template into chat messages;
        surrounding blank lines are stripped and placeholders are left as-is."""
        adapter = LLMProviderAdapter(test_config)

        # Spelled out with explicit "\n" so the value does not depend on how
        # this method is indented. The continuation lines of a triple-quoted
        # literal would pick up that indentation.
        template = (
            "### role: system\n"
            "Ты помощник.\n"
            "\n"
            "### role: user\n"
            "Задание: {task}"
        )

        messages = adapter._parse_prompt_template(template)

        assert len(messages) == 2
        assert messages[0]["role"] == "system"
        assert messages[0]["content"] == "Ты помощник."
        assert messages[1]["role"] == "user"
        assert messages[1]["content"] == "Задание: {task}"

    def test_parse_prompt_template_fallback(self, test_config):
        """A template without role headers becomes a single ``user`` message."""
        adapter = LLMProviderAdapter(test_config)

        template = "Обычный текст без ролей"
        messages = adapter._parse_prompt_template(template)

        assert len(messages) == 1
        assert messages[0]["role"] == "user"
        assert messages[0]["content"] == template

    @pytest.mark.asyncio
    async def test_health_check_success(self, test_config, mock_openai_response):
        """``health_check`` returns True when the completion call succeeds."""
        adapter = LLMProviderAdapter(test_config)

        with patch.object(
            adapter.client.chat.completions, "create", new_callable=AsyncMock
        ) as mock_create:
            mock_create.return_value = mock_openai_response

            result = await adapter.health_check()
            assert result is True

    @pytest.mark.asyncio
    async def test_health_check_failure(self, test_config):
        """``health_check`` returns False (rather than raising) on an ``APIError``."""
        adapter = LLMProviderAdapter(test_config)

        with patch.object(
            adapter.client.chat.completions, "create", new_callable=AsyncMock
        ) as mock_create:
            mock_create.side_effect = APIError(
                "API Error", body=None, request=MagicMock()
            )

            result = await adapter.health_check()
            assert result is False