# tests/unit/test_workers_base.py
from __future__ import annotations
|
||
|
||
import asyncio
|
||
from datetime import datetime, timezone
|
||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||
import pytest
|
||
|
||
from dataloader.workers.base import PGWorker, WorkerConfig
|
||
|
||
|
||
@pytest.mark.unit
class TestPGWorker:
    """Unit tests for PGWorker.

    Each test patches the collaborators that ``PGWorker`` looks up in
    ``dataloader.workers.base`` (``APP_CTX`` and, where needed, ``APP_CONFIG``,
    ``PGNotifyListener``, ``QueueRepository`` or ``resolve_pipeline``) so no
    real database or LISTEN/NOTIFY connection is used.
    """

    # Slack allowed when checking that a timed wait lasted "at least" N seconds:
    # the asyncio loop runs timers that are due before now + clock resolution,
    # so a wait can end marginally early on platforms with a coarse clock.
    _TIMER_TOLERANCE_SEC = 0.05

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _make_sessionmaker() -> tuple[MagicMock, AsyncMock]:
        """Return ``(sessionmaker, session)`` mocks.

        Calling ``sessionmaker()`` yields an async context manager whose
        ``__aenter__`` resolves to ``session``.
        """
        session = AsyncMock()
        sessionmaker = MagicMock()
        sessionmaker.return_value.__aenter__.return_value = session
        # A truthy __aexit__ result suppresses exceptions raised inside
        # ``async with``; return False so failures propagate as they would
        # with a real session instead of being silently swallowed.
        sessionmaker.return_value.__aexit__.return_value = False
        return sessionmaker, session

    @staticmethod
    def _claimed_job(args=None) -> dict:
        """Return the job row the tests feed to a mocked ``claim_one``.

        :param args: value for the ``"args"`` key; defaults to an empty dict.
        """
        return {
            "job_id": "test-job-id",
            "lease_ttl_sec": 60,
            "task": "test.task",
            "args": {} if args is None else args,
        }

    @staticmethod
    def _stop_after_claims(stop_event: asyncio.Event, calls: int = 2):
        """Build an async stand-in for ``PGWorker._claim_and_execute_once``.

        The stub always returns ``False`` ("no job claimed") and sets
        ``stop_event`` on its ``calls``-th invocation; the ``run()`` tests rely
        on that event to make the worker loop exit.
        """
        invocations = 0

        async def fake_claim() -> bool:
            nonlocal invocations
            invocations += 1
            if invocations >= calls:
                stop_event.set()
            return False

        return fake_claim

    @staticmethod
    def _call_mentions(call, text: str) -> bool:
        """Return True if any positional or keyword ``str`` argument of ``call`` contains ``text``."""
        values = (*call.args, *call.kwargs.values())
        return any(isinstance(value, str) and text in value for value in values)

    # -------------------------------------------------------------------- tests

    def test_init_creates_worker_with_config(self):
        """A new worker keeps its config and stop event, with no listener and no pending wakeup."""
        cfg = WorkerConfig(queue="test_queue", heartbeat_sec=10, claim_backoff_sec=5)
        stop_event = asyncio.Event()

        with patch("dataloader.workers.base.APP_CTX") as mock_ctx:
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = Mock()

            worker = PGWorker(cfg, stop_event)

            assert worker._cfg == cfg
            assert worker._stop is stop_event
            assert worker._listener is None
            assert not worker._notify_wakeup.is_set()

    @pytest.mark.asyncio
    async def test_run_starts_listener_and_processes_jobs(self):
        """run() starts the listener, polls for jobs until stopped, then stops the listener."""
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=1)
        stop_event = asyncio.Event()

        with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
             patch("dataloader.workers.base.APP_CONFIG") as mock_cfg, \
             patch("dataloader.workers.base.PGNotifyListener") as mock_listener_cls:
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = Mock()
            mock_cfg.pg.url = "postgresql+asyncpg://test"

            mock_listener = Mock()
            mock_listener.start = AsyncMock()
            mock_listener.stop = AsyncMock()
            mock_listener_cls.return_value = mock_listener

            worker = PGWorker(cfg, stop_event)
            fake_claim = self._stop_after_claims(stop_event, calls=2)

            with patch.object(worker, "_claim_and_execute_once", side_effect=fake_claim):
                await worker.run()

            mock_listener.start.assert_called_once()
            mock_listener.stop.assert_called_once()

    @pytest.mark.asyncio
    async def test_run_falls_back_to_polling_if_listener_fails(self):
        """If LISTEN/NOTIFY fails to start, run() logs one warning and keeps polling without a listener."""
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=1)
        stop_event = asyncio.Event()

        with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
             patch("dataloader.workers.base.APP_CONFIG") as mock_cfg, \
             patch("dataloader.workers.base.PGNotifyListener") as mock_listener_cls:
            mock_logger = Mock()
            mock_ctx.get_logger.return_value = mock_logger
            mock_ctx.sessionmaker = Mock()
            mock_cfg.pg.url = "postgresql+asyncpg://test"

            mock_listener = Mock()
            mock_listener.start = AsyncMock(side_effect=Exception("Connection failed"))
            mock_listener_cls.return_value = mock_listener

            worker = PGWorker(cfg, stop_event)
            fake_claim = self._stop_after_claims(stop_event, calls=2)

            with patch.object(worker, "_claim_and_execute_once", side_effect=fake_claim):
                await worker.run()

            assert worker._listener is None
            mock_logger.warning.assert_called_once()

    @pytest.mark.asyncio
    async def test_listen_or_sleep_with_listener_waits_for_notify(self):
        """With a listener, a pending NOTIFY wakeup is consumed and the event is cleared."""
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
        stop_event = asyncio.Event()

        with patch("dataloader.workers.base.APP_CTX") as mock_ctx:
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = Mock()

            worker = PGWorker(cfg, stop_event)
            worker._listener = Mock()
            # Simulate a NOTIFY having already arrived before the wait starts.
            worker._notify_wakeup.set()

            await worker._listen_or_sleep(1)

            assert not worker._notify_wakeup.is_set()

    @pytest.mark.asyncio
    async def test_listen_or_sleep_without_listener_uses_timeout(self):
        """Without a listener the worker waits out the full timeout."""
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=1)
        stop_event = asyncio.Event()

        with patch("dataloader.workers.base.APP_CTX") as mock_ctx:
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = Mock()

            worker = PGWorker(cfg, stop_event)

            loop = asyncio.get_running_loop()
            started = loop.time()
            await worker._listen_or_sleep(1)
            elapsed = loop.time() - started

            # Tolerate a timer firing up to one clock tick before its deadline.
            assert elapsed >= 1.0 - self._TIMER_TOLERANCE_SEC

    @pytest.mark.asyncio
    async def test_claim_and_execute_once_returns_false_when_no_job(self):
        """When claim_one finds nothing, the session is committed once and False is returned."""
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
        stop_event = asyncio.Event()

        with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
             patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
            mock_sm, mock_session = self._make_sessionmaker()
            mock_session.commit = AsyncMock()
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = mock_sm

            mock_repo = Mock()
            mock_repo.claim_one = AsyncMock(return_value=None)
            mock_repo_cls.return_value = mock_repo

            worker = PGWorker(cfg, stop_event)
            result = await worker._claim_and_execute_once()

            assert result is False
            mock_session.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_claim_and_execute_once_executes_job_successfully(self):
        """A claimed job that runs without cancellation is finished via finish_ok."""
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
        stop_event = asyncio.Event()

        with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
             patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
            mock_sm, _session = self._make_sessionmaker()
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = mock_sm

            mock_repo = Mock()
            mock_repo.claim_one = AsyncMock(return_value=self._claimed_job({"key": "value"}))
            mock_repo.finish_ok = AsyncMock()
            mock_repo_cls.return_value = mock_repo

            worker = PGWorker(cfg, stop_event)

            async def fake_pipeline(task, args):
                yield

            # _execute_with_heartbeat returning False means "not canceled".
            with patch.object(worker, "_pipeline", side_effect=fake_pipeline), \
                 patch.object(worker, "_execute_with_heartbeat", return_value=False):
                result = await worker._claim_and_execute_once()

            assert result is True
            mock_repo.finish_ok.assert_called_once()

    @pytest.mark.asyncio
    async def test_claim_and_execute_once_handles_cancellation(self):
        """A user-canceled job is recorded through finish_fail_or_retry with a cancel message."""
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
        stop_event = asyncio.Event()

        with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
             patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
            mock_sm, _session = self._make_sessionmaker()
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = mock_sm

            mock_repo = Mock()
            mock_repo.claim_one = AsyncMock(return_value=self._claimed_job())
            mock_repo.finish_fail_or_retry = AsyncMock()
            mock_repo_cls.return_value = mock_repo

            worker = PGWorker(cfg, stop_event)

            # _execute_with_heartbeat returning True signals cancellation.
            with patch.object(worker, "_execute_with_heartbeat", return_value=True):
                result = await worker._claim_and_execute_once()

            assert result is True
            mock_repo.finish_fail_or_retry.assert_called_once()
            assert self._call_mentions(
                mock_repo.finish_fail_or_retry.call_args, "canceled by user"
            )

    @pytest.mark.asyncio
    async def test_claim_and_execute_once_handles_exceptions(self):
        """An exception from job execution is recorded through finish_fail_or_retry with its message."""
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
        stop_event = asyncio.Event()

        with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
             patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
            mock_sm, _session = self._make_sessionmaker()
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = mock_sm

            mock_repo = Mock()
            mock_repo.claim_one = AsyncMock(return_value=self._claimed_job())
            mock_repo.finish_fail_or_retry = AsyncMock()
            mock_repo_cls.return_value = mock_repo

            worker = PGWorker(cfg, stop_event)

            with patch.object(worker, "_execute_with_heartbeat", side_effect=ValueError("Test error")):
                result = await worker._claim_and_execute_once()

            assert result is True
            mock_repo.finish_fail_or_retry.assert_called_once()
            assert self._call_mentions(mock_repo.finish_fail_or_retry.call_args, "Test error")

    @pytest.mark.asyncio
    async def test_execute_with_heartbeat_sends_heartbeats(self):
        """Heartbeats are sent while a long pipeline runs; no cancellation is reported."""
        cfg = WorkerConfig(queue="test", heartbeat_sec=1, claim_backoff_sec=5)
        stop_event = asyncio.Event()

        with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
             patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
            mock_sm, _session = self._make_sessionmaker()
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = mock_sm

            mock_repo = Mock()
            # Second tuple element False: the heartbeat reports no cancellation
            # (the cancellation test below flips it to True).
            mock_repo.heartbeat = AsyncMock(return_value=(True, False))
            mock_repo_cls.return_value = mock_repo

            worker = PGWorker(cfg, stop_event)

            async def slow_pipeline():
                # ~1.1 s in total: longer than heartbeat_sec=1.
                await asyncio.sleep(0.5)
                yield
                await asyncio.sleep(0.6)
                yield

            canceled = await worker._execute_with_heartbeat("job-id", 60, slow_pipeline())

            assert canceled is False
            assert mock_repo.heartbeat.call_count >= 1

    @pytest.mark.asyncio
    async def test_execute_with_heartbeat_detects_cancellation(self):
        """A heartbeat that reports cancellation makes _execute_with_heartbeat return True."""
        cfg = WorkerConfig(queue="test", heartbeat_sec=1, claim_backoff_sec=5)
        stop_event = asyncio.Event()

        with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
             patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
            mock_sm, _session = self._make_sessionmaker()
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = mock_sm

            mock_repo = Mock()
            mock_repo.heartbeat = AsyncMock(return_value=(True, True))
            mock_repo_cls.return_value = mock_repo

            worker = PGWorker(cfg, stop_event)

            async def slow_pipeline():
                # ~1.1 s in total: longer than heartbeat_sec=1.
                await asyncio.sleep(0.5)
                yield
                await asyncio.sleep(0.6)
                yield

            canceled = await worker._execute_with_heartbeat("job-id", 60, slow_pipeline())

            assert canceled is True

    @pytest.mark.asyncio
    async def test_pipeline_handles_sync_function(self):
        """A plain (synchronous) pipeline function produces exactly one step."""
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
        stop_event = asyncio.Event()

        with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
             patch("dataloader.workers.base.resolve_pipeline") as mock_resolve:
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = Mock()

            def sync_pipeline(args):
                return "result"

            mock_resolve.return_value = sync_pipeline

            worker = PGWorker(cfg, stop_event)
            results = [step async for step in worker._pipeline("test.task", {})]

            assert len(results) == 1

    @pytest.mark.asyncio
    async def test_pipeline_handles_async_function(self):
        """A coroutine pipeline function produces exactly one step."""
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
        stop_event = asyncio.Event()

        with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
             patch("dataloader.workers.base.resolve_pipeline") as mock_resolve:
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = Mock()

            async def async_pipeline(args):
                return "result"

            mock_resolve.return_value = async_pipeline

            worker = PGWorker(cfg, stop_event)
            results = [step async for step in worker._pipeline("test.task", {})]

            assert len(results) == 1

    @pytest.mark.asyncio
    async def test_pipeline_handles_async_generator(self):
        """An async-generator pipeline produces one step per yield."""
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
        stop_event = asyncio.Event()

        with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
             patch("dataloader.workers.base.resolve_pipeline") as mock_resolve:
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = Mock()

            async def async_gen_pipeline(args):
                yield
                yield
                yield

            mock_resolve.return_value = async_gen_pipeline

            worker = PGWorker(cfg, stop_event)
            results = [step async for step in worker._pipeline("test.task", {})]

            assert len(results) == 3
|