# dataloader/tests/unit/test_workers_base.py
#
# (Exported from a web source viewer: 552 lines, 19 KiB, Python. The viewer
# warned that this file may contain Unicode characters that can be confused
# with similar-looking characters.)

from __future__ import annotations
import asyncio
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from dataloader.workers.base import PGWorker, WorkerConfig
@pytest.mark.unit
class TestPGWorker:
    """
    Unit tests for PGWorker.

    Module-level collaborators of ``dataloader.workers.base`` (``APP_CTX`` and,
    where a test needs them, ``APP_CONFIG``, ``PGNotifyListener``,
    ``QueueRepository`` or ``resolve_pipeline``) are patched in each test.
    """

    @staticmethod
    def _make_sessionmaker(session: AsyncMock) -> MagicMock:
        """
        Build a mock to install as ``APP_CTX.sessionmaker``.

        Calling the mock returns an async context manager whose ``__aenter__``
        yields ``session``. ``__aexit__`` is pinned to ``False`` on purpose: a
        truthy return value from ``__aexit__`` suppresses any exception raised
        inside the ``async with`` body, which would hide errors in the code
        under test.
        """
        sessionmaker = MagicMock()
        sessionmaker.return_value.__aenter__.return_value = session
        sessionmaker.return_value.__aexit__.return_value = False
        return sessionmaker

    @staticmethod
    def _claimed_job(args: dict | None = None) -> dict:
        """
        Return a fake job row for the mocked ``QueueRepository.claim_one``.

        :param args: value for the row's ``"args"`` key (``{}`` when omitted).
        """
        return {
            "job_id": "test-job-id",
            "lease_ttl_sec": 60,
            "task": "test.task",
            "args": {} if args is None else args,
        }

    def test_init_creates_worker_with_config(self):
        """
        Constructing a worker keeps the config and stop event, starts with no
        listener, and with the notify-wakeup event cleared.
        """
        cfg = WorkerConfig(queue="test_queue", heartbeat_sec=10, claim_backoff_sec=5)
        stop_event = asyncio.Event()
        with patch("dataloader.workers.base.APP_CTX") as mock_ctx:
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = Mock()
            worker = PGWorker(cfg, stop_event)
            assert worker._cfg == cfg
            assert worker._stop is stop_event
            assert worker._listener is None
            assert not worker._notify_wakeup.is_set()

    @pytest.mark.asyncio
    async def test_run_starts_listener_and_processes_jobs(self):
        """
        ``run()`` starts the LISTEN/NOTIFY listener, keeps claiming until the
        stop event is set, and stops the listener exactly once.
        """
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=1)
        stop_event = asyncio.Event()
        with (
            patch("dataloader.workers.base.APP_CTX") as mock_ctx,
            patch("dataloader.workers.base.APP_CONFIG") as mock_cfg,
            patch("dataloader.workers.base.PGNotifyListener") as mock_listener_cls,
        ):
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = Mock()
            mock_cfg.pg.url = "postgresql+asyncpg://test"
            mock_listener = Mock()
            mock_listener.start = AsyncMock()
            mock_listener.stop = AsyncMock()
            mock_listener_cls.return_value = mock_listener
            worker = PGWorker(cfg, stop_event)

            claims = 0

            async def fake_claim():
                # Always report "no job"; request shutdown on the second claim
                # so that run() terminates.
                nonlocal claims
                claims += 1
                if claims >= 2:
                    stop_event.set()
                return False

            with patch.object(
                worker, "_claim_and_execute_once", side_effect=fake_claim
            ):
                await worker.run()

            mock_listener.start.assert_called_once()
            mock_listener.stop.assert_called_once()

    @pytest.mark.asyncio
    async def test_run_falls_back_to_polling_if_listener_fails(self):
        """
        If the LISTEN/NOTIFY listener fails to start, ``run()`` logs a single
        warning, drops the listener and keeps working by polling.
        """
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=1)
        stop_event = asyncio.Event()
        with (
            patch("dataloader.workers.base.APP_CTX") as mock_ctx,
            patch("dataloader.workers.base.APP_CONFIG") as mock_cfg,
            patch("dataloader.workers.base.PGNotifyListener") as mock_listener_cls,
        ):
            mock_logger = Mock()
            mock_ctx.get_logger.return_value = mock_logger
            mock_ctx.sessionmaker = Mock()
            mock_cfg.pg.url = "postgresql+asyncpg://test"
            mock_listener = Mock()
            mock_listener.start = AsyncMock(side_effect=Exception("Connection failed"))
            mock_listener_cls.return_value = mock_listener
            worker = PGWorker(cfg, stop_event)

            claims = 0

            async def fake_claim():
                # Always report "no job"; request shutdown on the second claim.
                nonlocal claims
                claims += 1
                if claims >= 2:
                    stop_event.set()
                return False

            with patch.object(
                worker, "_claim_and_execute_once", side_effect=fake_claim
            ):
                await worker.run()

            assert worker._listener is None
            mock_logger.warning.assert_called_once()

    @pytest.mark.asyncio
    async def test_listen_or_sleep_with_listener_waits_for_notify(self):
        """
        With a listener attached, ``_listen_or_sleep`` consumes a pending
        notify-wakeup: the event is cleared after the call.
        """
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
        stop_event = asyncio.Event()
        with patch("dataloader.workers.base.APP_CTX") as mock_ctx:
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = Mock()
            worker = PGWorker(cfg, stop_event)
            worker._listener = Mock()
            worker._notify_wakeup.set()
            await worker._listen_or_sleep(1)
            assert not worker._notify_wakeup.is_set()

    @pytest.mark.asyncio
    async def test_listen_or_sleep_without_listener_uses_timeout(self):
        """
        Without a listener, ``_listen_or_sleep`` falls back to waiting out the
        given timeout.
        """
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=1)
        stop_event = asyncio.Event()
        with patch("dataloader.workers.base.APP_CTX") as mock_ctx:
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = Mock()
            worker = PGWorker(cfg, stop_event)
            loop = asyncio.get_running_loop()
            started = loop.time()
            await worker._listen_or_sleep(1)
            elapsed = loop.time() - started
            # The asyncio loop may run a timer up to one clock-resolution tick
            # early (that tick can be ~15 ms on Windows), so a strict
            # ``>= 1.0`` check is flaky; a small tolerance still proves the wait.
            assert elapsed >= 0.95

    @pytest.mark.asyncio
    async def test_claim_and_execute_once_returns_false_when_no_job(self):
        """
        ``_claim_and_execute_once`` returns False when there is nothing to
        claim, and still commits the claim session once.
        """
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
        stop_event = asyncio.Event()
        with (
            patch("dataloader.workers.base.APP_CTX") as mock_ctx,
            patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
        ):
            mock_session = AsyncMock()
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = self._make_sessionmaker(mock_session)
            mock_repo = Mock()
            mock_repo.claim_one = AsyncMock(return_value=None)
            mock_repo_cls.return_value = mock_repo
            worker = PGWorker(cfg, stop_event)
            result = await worker._claim_and_execute_once()
            assert result is False
            mock_session.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_claim_and_execute_once_executes_job_successfully(self):
        """
        A claimed job that runs to completion is finished with ``finish_ok``
        and ``_claim_and_execute_once`` returns True.
        """
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
        stop_event = asyncio.Event()
        with (
            patch("dataloader.workers.base.APP_CTX") as mock_ctx,
            patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
        ):
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = self._make_sessionmaker(AsyncMock())
            mock_repo = Mock()
            mock_repo.claim_one = AsyncMock(
                return_value=self._claimed_job({"key": "value"})
            )
            mock_repo.finish_ok = AsyncMock()
            mock_repo_cls.return_value = mock_repo
            worker = PGWorker(cfg, stop_event)

            async def fake_pipeline(_task, _args):
                yield

            # _execute_with_heartbeat returning False means "not canceled".
            with (
                patch.object(worker, "_pipeline", side_effect=fake_pipeline),
                patch.object(worker, "_execute_with_heartbeat", return_value=False),
            ):
                result = await worker._claim_and_execute_once()
            assert result is True
            mock_repo.finish_ok.assert_called_once()

    @pytest.mark.asyncio
    async def test_claim_and_execute_once_handles_cancellation(self):
        """
        A job canceled by the user (``_execute_with_heartbeat`` returns True)
        is finished via ``finish_fail_or_retry`` with a "canceled by user"
        message.
        """
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
        stop_event = asyncio.Event()
        with (
            patch("dataloader.workers.base.APP_CTX") as mock_ctx,
            patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
        ):
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = self._make_sessionmaker(AsyncMock())
            mock_repo = Mock()
            mock_repo.claim_one = AsyncMock(return_value=self._claimed_job())
            mock_repo.finish_fail_or_retry = AsyncMock()
            mock_repo_cls.return_value = mock_repo
            worker = PGWorker(cfg, stop_event)
            with patch.object(worker, "_execute_with_heartbeat", return_value=True):
                result = await worker._claim_and_execute_once()
            assert result is True
            mock_repo.finish_fail_or_retry.assert_called_once()
            # Positional arguments are (job_id, message, ...).
            call_args, _ = mock_repo.finish_fail_or_retry.call_args
            assert call_args[0] == "test-job-id"
            assert "canceled by user" in call_args[1]

    @pytest.mark.asyncio
    async def test_claim_and_execute_once_handles_exceptions(self):
        """
        An exception raised while executing a job is recorded through
        ``finish_fail_or_retry`` with the error text, and the call still
        returns True.
        """
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
        stop_event = asyncio.Event()
        with (
            patch("dataloader.workers.base.APP_CTX") as mock_ctx,
            patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
        ):
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = self._make_sessionmaker(AsyncMock())
            mock_repo = Mock()
            mock_repo.claim_one = AsyncMock(return_value=self._claimed_job())
            mock_repo.finish_fail_or_retry = AsyncMock()
            mock_repo_cls.return_value = mock_repo
            worker = PGWorker(cfg, stop_event)
            with patch.object(
                worker, "_execute_with_heartbeat", side_effect=ValueError("Test error")
            ):
                result = await worker._claim_and_execute_once()
            assert result is True
            mock_repo.finish_fail_or_retry.assert_called_once()
            # Positional arguments are (job_id, message, ...).
            call_args, _ = mock_repo.finish_fail_or_retry.call_args
            assert call_args[0] == "test-job-id"
            assert "Test error" in call_args[1]

    @pytest.mark.asyncio
    async def test_execute_with_heartbeat_sends_heartbeats(self):
        """
        While a pipeline runs longer than ``heartbeat_sec`` (1 s here, the
        pipeline takes ~1.1 s), at least one heartbeat is sent, and the job is
        not reported as canceled.
        """
        cfg = WorkerConfig(queue="test", heartbeat_sec=1, claim_backoff_sec=5)
        stop_event = asyncio.Event()
        with (
            patch("dataloader.workers.base.APP_CTX") as mock_ctx,
            patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
        ):
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = self._make_sessionmaker(AsyncMock())
            mock_repo = Mock()
            # The second element is the cancellation flag (the detection test
            # below flips it to True); the first presumably means the lease is
            # still held — confirm against QueueRepository.heartbeat.
            mock_repo.heartbeat = AsyncMock(return_value=(True, False))
            mock_repo_cls.return_value = mock_repo
            worker = PGWorker(cfg, stop_event)

            async def slow_pipeline():
                await asyncio.sleep(0.5)
                yield
                await asyncio.sleep(0.6)
                yield

            canceled = await worker._execute_with_heartbeat(
                "job-id", 60, slow_pipeline()
            )
            assert canceled is False
            assert mock_repo.heartbeat.call_count >= 1

    @pytest.mark.asyncio
    async def test_execute_with_heartbeat_detects_cancellation(self):
        """
        When the heartbeat reports cancellation, ``_execute_with_heartbeat``
        returns True.
        """
        cfg = WorkerConfig(queue="test", heartbeat_sec=1, claim_backoff_sec=5)
        stop_event = asyncio.Event()
        with (
            patch("dataloader.workers.base.APP_CTX") as mock_ctx,
            patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
        ):
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = self._make_sessionmaker(AsyncMock())
            mock_repo = Mock()
            mock_repo.heartbeat = AsyncMock(return_value=(True, True))
            mock_repo_cls.return_value = mock_repo
            worker = PGWorker(cfg, stop_event)

            async def slow_pipeline():
                await asyncio.sleep(0.5)
                yield
                await asyncio.sleep(0.6)
                yield

            canceled = await worker._execute_with_heartbeat(
                "job-id", 60, slow_pipeline()
            )
            assert canceled is True

    @pytest.mark.asyncio
    async def test_pipeline_handles_sync_function(self):
        """
        A plain synchronous pipeline function is adapted by ``_pipeline`` into
        a single step.
        """
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
        stop_event = asyncio.Event()
        with (
            patch("dataloader.workers.base.APP_CTX") as mock_ctx,
            patch("dataloader.workers.base.resolve_pipeline") as mock_resolve,
        ):
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = Mock()

            def sync_pipeline(args):
                return "result"

            mock_resolve.return_value = sync_pipeline
            worker = PGWorker(cfg, stop_event)
            steps = [step async for step in worker._pipeline("test.task", {})]
            assert len(steps) == 1

    @pytest.mark.asyncio
    async def test_pipeline_handles_async_function(self):
        """
        A coroutine pipeline function is adapted by ``_pipeline`` into a
        single step.
        """
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
        stop_event = asyncio.Event()
        with (
            patch("dataloader.workers.base.APP_CTX") as mock_ctx,
            patch("dataloader.workers.base.resolve_pipeline") as mock_resolve,
        ):
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = Mock()

            async def async_pipeline(args):
                return "result"

            mock_resolve.return_value = async_pipeline
            worker = PGWorker(cfg, stop_event)
            steps = [step async for step in worker._pipeline("test.task", {})]
            assert len(steps) == 1

    @pytest.mark.asyncio
    async def test_pipeline_handles_async_generator(self):
        """
        An async-generator pipeline is passed through by ``_pipeline`` one step
        per ``yield``.
        """
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
        stop_event = asyncio.Event()
        with (
            patch("dataloader.workers.base.APP_CTX") as mock_ctx,
            patch("dataloader.workers.base.resolve_pipeline") as mock_resolve,
        ):
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = Mock()

            async def async_gen_pipeline(args):
                yield
                yield
                yield

            mock_resolve.return_value = async_gen_pipeline
            worker = PGWorker(cfg, stop_event)
            steps = [step async for step in worker._pipeline("test.task", {})]
            assert len(steps) == 3

    @pytest.mark.asyncio
    async def test_claim_and_execute_once_handles_shutdown_cancelled_error(self):
        """
        A ``CancelledError`` during execution (worker shutdown) is recorded via
        ``finish_fail_or_retry(job_id, "...cancelled by shutdown...",
        is_canceled=True)``.
        """
        cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
        stop_event = asyncio.Event()
        with (
            patch("dataloader.workers.base.APP_CTX") as mock_ctx,
            patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
        ):
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = self._make_sessionmaker(AsyncMock())
            mock_repo = Mock()
            mock_repo.claim_one = AsyncMock(return_value=self._claimed_job())
            mock_repo.finish_fail_or_retry = AsyncMock()
            mock_repo_cls.return_value = mock_repo
            worker = PGWorker(cfg, stop_event)

            async def raise_cancel(*_args, **_kwargs):
                raise asyncio.CancelledError()

            with patch.object(worker, "_execute_with_heartbeat", new=raise_cancel):
                try:
                    await worker._claim_and_execute_once()
                except asyncio.CancelledError:
                    # Propagating the shutdown cancellation after recording it
                    # is acceptable; this test only checks that it was recorded.
                    pass

            mock_repo.finish_fail_or_retry.assert_called_once()
            call_args, call_kwargs = mock_repo.finish_fail_or_retry.call_args
            assert call_args[0] == "test-job-id"
            assert "cancelled by shutdown" in call_args[1]
            assert call_kwargs.get("is_canceled") is True

    @pytest.mark.asyncio
    async def test_execute_with_heartbeat_raises_cancelled_when_stop_set(self):
        """
        If the stop event is already set, ``_execute_with_heartbeat`` raises
        ``asyncio.CancelledError`` instead of completing the pipeline.
        """
        # A huge heartbeat interval keeps the heartbeat path out of this test.
        cfg = WorkerConfig(queue="test", heartbeat_sec=1000, claim_backoff_sec=5)
        stop_event = asyncio.Event()
        stop_event.set()
        with (
            patch("dataloader.workers.base.APP_CTX") as mock_ctx,
            patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
        ):
            mock_ctx.get_logger.return_value = Mock()
            mock_ctx.sessionmaker = Mock()
            mock_repo_cls.return_value = Mock()
            worker = PGWorker(cfg, stop_event)

            async def one_yield():
                yield

            with pytest.raises(asyncio.CancelledError):
                await worker._execute_with_heartbeat("job-id", 60, one_yield())