369 lines
12 KiB
Python
369 lines
12 KiB
Python
import tempfile
|
|
from pathlib import Path
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
|
|
from src.adapters import LLMProviderAdapter, RuWikiAdapter
|
|
from src.adapters.ruwiki import WikiPageInfo
|
|
from src.models import ProcessingResult, SimplifyCommand
|
|
from src.services import (
|
|
AsyncWriteQueue,
|
|
DatabaseService,
|
|
RecursiveCharacterTextSplitter,
|
|
SimplifyService,
|
|
)
|
|
|
|
|
|
class TestRecursiveCharacterTextSplitter:
    """Behavioral tests for RecursiveCharacterTextSplitter."""

    def test_split_short_text(self):
        """Text shorter than chunk_size comes back as one unchanged chunk."""
        text_in = "Это короткий текст."
        result = RecursiveCharacterTextSplitter(
            chunk_size=100, chunk_overlap=20
        ).split_text(text_in)

        assert len(result) == 1
        assert result[0] == text_in

    def test_split_long_text(self):
        """Text longer than chunk_size is split into several bounded chunks."""
        size, overlap = 50, 10
        splitter = RecursiveCharacterTextSplitter(chunk_size=size, chunk_overlap=overlap)

        pieces = splitter.split_text("Это очень длинный текст. " * 10)

        assert len(pieces) > 1
        # The test tolerates up to chunk_size + chunk_overlap (= 60) characters per chunk.
        assert all(len(piece) <= size + overlap for piece in pieces)

    def test_split_by_paragraphs(self):
        """Blank-line-separated paragraphs yield at least two chunks."""
        paragraphs = "Первый абзац.\n\nВторой абзац.\n\nТретий абзац."
        splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=10)

        assert len(splitter.split_text(paragraphs)) >= 2

    def test_split_empty_text(self):
        """An empty string produces no chunks at all."""
        splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)
        assert splitter.split_text("") == []

    def test_custom_length_function(self):
        """A custom length_function (word count) drives chunk sizing."""

        def count_words(value: str) -> int:
            return len(value.split())

        size, overlap = 5, 2
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=size,
            chunk_overlap=overlap,
            length_function=count_words,
        )

        pieces = splitter.split_text(
            "Один два три четыре пять шесть семь восемь девять десять"
        )

        assert len(pieces) > 1
        # Each chunk may hold at most chunk_size + chunk_overlap (= 7) words.
        assert all(len(piece.split()) <= size + overlap for piece in pieces)

    def test_create_chunks_with_metadata(self):
        """Every chunk dict carries text, title, its index and the chunk count."""
        article_title = "Тестовая статья"
        splitter = RecursiveCharacterTextSplitter(chunk_size=50, chunk_overlap=10)

        records = splitter.create_chunks_with_metadata(
            "Это тестовый текст. " * 10, article_title
        )
        expected_total = len(records)

        assert expected_total > 1
        for position, record in enumerate(records):
            assert "text" in record
            assert record["title"] == article_title
            assert record["chunk_index"] == position
            assert record["total_chunks"] == expected_total
            assert "chunk_size" in record
|
|
|
|
|
|
class TestDatabaseService:
    """Tests for DatabaseService against the database described by test_config."""

    @pytest.mark.asyncio
    async def test_initialize_database(self, test_config):
        """initialize_database creates the DB file and the service reports healthy."""
        db_service = DatabaseService(test_config)
        await db_service.initialize_database()

        try:
            assert Path(test_config.db_path).exists()
            assert await db_service.health_check() is True
        finally:
            # Close even when an assertion fails so the service's resources
            # are not left open for subsequent tests.
            db_service.close()

    @pytest.mark.asyncio
    async def test_get_connection(self, test_config):
        """get_connection yields a usable connection that can run a query."""
        db_service = DatabaseService(test_config)
        await db_service.initialize_database()

        try:
            async with await db_service.get_connection() as conn:
                cursor = await conn.execute("SELECT 1")
                row = await cursor.fetchone()
                # Guard first: indexing a None row would raise TypeError and
                # hide the real failure.
                assert row is not None
                assert row[0] == 1
        finally:
            db_service.close()
|
|
|
|
|
|
class TestAsyncWriteQueue:
    """Tests for AsyncWriteQueue using a mocked repository."""

    @pytest.mark.asyncio
    async def test_start_stop(self):
        """start() creates a worker task and stop() leaves it finished."""
        mock_repository = AsyncMock()
        queue = AsyncWriteQueue(mock_repository, max_batch_size=5)

        await queue.start()
        try:
            assert queue._worker_task is not None
        finally:
            # Always stop the worker, otherwise a failed assertion leaves a
            # running task behind in the event loop.
            await queue.stop(timeout=1.0)

        assert queue._worker_task.done()

    @pytest.mark.asyncio
    async def test_update_from_result_success(self, sample_article, simplified_text):
        """A success result is written back with the simplified text."""
        mock_repository = AsyncMock()
        mock_repository.get_by_url.return_value = sample_article
        mock_repository.update_article.return_value = sample_article

        queue = AsyncWriteQueue(mock_repository, max_batch_size=1)
        await queue.start()

        try:
            result = ProcessingResult.success_result(
                url=sample_article.url,
                title=sample_article.title,
                raw_text=sample_article.raw_text,
                simplified_text=simplified_text,
                token_count_raw=100,
                token_count_simplified=50,
                processing_time_seconds=2.0,
            )

            updated_article = await queue.update_from_result(result)

            assert updated_article.simplified_text == simplified_text
            mock_repository.get_by_url.assert_called_once_with(sample_article.url)
            mock_repository.update_article.assert_called_once()
        finally:
            await queue.stop(timeout=1.0)

    @pytest.mark.asyncio
    async def test_update_from_result_failure(self, sample_article):
        """A failure result is written back with its error message."""
        mock_repository = AsyncMock()
        mock_repository.get_by_url.return_value = sample_article
        mock_repository.update_article.return_value = sample_article

        queue = AsyncWriteQueue(mock_repository, max_batch_size=1)
        await queue.start()

        try:
            result = ProcessingResult.failure_result(
                url=sample_article.url,
                error_message="Тестовая ошибка",
            )

            updated_article = await queue.update_from_result(result)

            assert updated_article.error_message == "Тестовая ошибка"
            mock_repository.update_article.assert_called_once()
        finally:
            await queue.stop(timeout=1.0)

    def test_stats(self):
        """A fresh queue exposes the stats keys and reports zero operations."""
        mock_repository = AsyncMock()
        queue = AsyncWriteQueue(mock_repository)

        stats = queue.stats

        assert "total_operations" in stats
        assert "failed_operations" in stats
        assert "queue_size" in stats
        assert stats["total_operations"] == 0
|
|
|
|
|
|
class TestSimplifyService:
    """Unit tests for SimplifyService with every collaborator mocked."""

    @pytest.fixture
    def mock_adapters_and_queue(self, test_config):
        """Return (ruwiki, llm, repository, write_queue) mocks for SimplifyService.

        The adapter mocks are spec'd against the real adapter classes, so
        accessing an attribute the adapter does not define raises
        AttributeError instead of silently returning a new mock.
        """
        mock_ruwiki = AsyncMock(spec=RuWikiAdapter)
        mock_llm = AsyncMock(spec=LLMProviderAdapter)
        mock_repository = AsyncMock()
        mock_write_queue = AsyncMock()

        return mock_ruwiki, mock_llm, mock_repository, mock_write_queue

    @staticmethod
    def _write_prompt_file(content: str) -> str:
        """Write *content* to a temporary ``.txt`` file and return its path.

        ``delete=False`` keeps the file on disk so the service can reopen it by
        name; the caller must unlink it. UTF-8 is set explicitly because the
        prompts contain Cyrillic text and the platform default encoding may
        not be able to encode it.
        NOTE(review): assumes SimplifyService reads the template as UTF-8 — confirm.
        """
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".txt", delete=False, encoding="utf-8"
        ) as f:
            f.write(content)
        return f.name

    @staticmethod
    def _build_service(config, mocks):
        """Construct a SimplifyService from *config* and the fixture's mock tuple."""
        mock_ruwiki, mock_llm, mock_repository, mock_write_queue = mocks
        return SimplifyService(
            config=config,
            ruwiki_adapter=mock_ruwiki,
            llm_adapter=mock_llm,
            repository=mock_repository,
            write_queue=mock_write_queue,
        )

    def test_service_initialization(self, test_config, mock_adapters_and_queue):
        """The constructor stores its collaborators and builds a text splitter."""
        mock_ruwiki, mock_llm, _mock_repository, _mock_write_queue = mock_adapters_and_queue

        service = self._build_service(test_config, mock_adapters_and_queue)

        assert service.config == test_config
        assert service.ruwiki_adapter == mock_ruwiki
        assert service.llm_adapter == mock_llm
        assert isinstance(service.text_splitter, RecursiveCharacterTextSplitter)

    @pytest.mark.asyncio
    async def test_get_prompt_template(self, test_config, mock_adapters_and_queue):
        """The template file's content is returned, and repeated calls agree."""
        prompt_path = self._write_prompt_file("### role: system\nТы помощник")

        # Everything after file creation is inside try so the file is removed
        # even if service construction itself fails.
        try:
            test_config.prompt_template_path = prompt_path
            service = self._build_service(test_config, mock_adapters_and_queue)

            template = await service.get_prompt_template()
            assert "### role: system" in template
            assert "Ты помощник" in template

            template2 = await service.get_prompt_template()
            assert template == template2
        finally:
            Path(prompt_path).unlink(missing_ok=True)

    @pytest.mark.asyncio
    async def test_get_prompt_template_not_found(self, test_config, mock_adapters_and_queue):
        """A missing template file raises FileNotFoundError."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            # A name inside a fresh empty directory is guaranteed to be absent,
            # unlike a bare relative name that depends on the working directory.
            test_config.prompt_template_path = str(Path(tmp_dir) / "nonexistent.txt")
            service = self._build_service(test_config, mock_adapters_and_queue)

            with pytest.raises(FileNotFoundError):
                await service.get_prompt_template()

    @pytest.mark.asyncio
    async def test_process_command_success(
        self, test_config, mock_adapters_and_queue, sample_wikitext, simplified_text
    ):
        """A new URL is fetched, simplified and handed to the write queue."""
        mock_ruwiki, mock_llm, mock_repository, mock_write_queue = mock_adapters_and_queue

        wiki_page_info = WikiPageInfo(
            title="Тест",
            content=sample_wikitext,
        )
        mock_ruwiki.fetch_page_cleaned.return_value = wiki_page_info
        mock_llm.simplify_text.return_value = (simplified_text, 100, 50)
        mock_llm.count_tokens.return_value = 100

        mock_repository.get_by_url.return_value = None
        mock_repository.create_article.return_value = MagicMock(id=1)
        mock_repository.update_article.return_value = MagicMock()

        mock_write_queue.update_from_result.return_value = MagicMock()

        prompt_path = self._write_prompt_file("### role: user\n{wiki_source_text}")
        try:
            test_config.prompt_template_path = prompt_path
            service = self._build_service(test_config, mock_adapters_and_queue)

            command = SimplifyCommand(url="https://ru.wikipedia.org/wiki/Тест")
            result = await service.process_command(command)

            assert result.success is True
            assert result.title == "Тест"
            assert result.simplified_text == simplified_text
            assert result.token_count_raw == 100
            assert result.token_count_simplified == 50

            mock_ruwiki.fetch_page_cleaned.assert_called_once()
            mock_llm.simplify_text.assert_called_once()
            mock_write_queue.update_from_result.assert_called_once()
        finally:
            Path(prompt_path).unlink(missing_ok=True)

    @pytest.mark.asyncio
    async def test_process_command_skip_existing(
        self, test_config, mock_adapters_and_queue, completed_article
    ):
        """An already-completed article is returned without fetching or simplifying."""
        mock_ruwiki, mock_llm, mock_repository, _mock_write_queue = mock_adapters_and_queue

        mock_repository.get_by_url.return_value = completed_article

        service = self._build_service(test_config, mock_adapters_and_queue)

        command = SimplifyCommand(url=completed_article.url, force_reprocess=False)
        result = await service.process_command(command)

        assert result.success is True
        assert result.title == completed_article.title

        mock_ruwiki.fetch_page_cleaned.assert_not_called()
        mock_llm.simplify_text.assert_not_called()

    @pytest.mark.asyncio
    async def test_health_check(self, test_config, mock_adapters_and_queue):
        """health_check reports ruwiki, llm and prompt_template as healthy."""
        mock_ruwiki, mock_llm, _mock_repository, _mock_write_queue = mock_adapters_and_queue

        mock_ruwiki.health_check.return_value = True
        mock_llm.health_check.return_value = True

        prompt_path = self._write_prompt_file("test prompt")
        try:
            test_config.prompt_template_path = prompt_path
            service = self._build_service(test_config, mock_adapters_and_queue)

            checks = await service.health_check()

            assert checks["ruwiki"] is True
            assert checks["llm"] is True
            assert checks["prompt_template"] is True
        finally:
            Path(prompt_path).unlink(missing_ok=True)
|