dataloader/tests/unit/test_pipeline_registry.py

83 lines
2.4 KiB
Python

# tests/unit/test_pipeline_registry.py
from __future__ import annotations
import pytest
from dataloader.workers.pipelines.registry import register, resolve, tasks, _Registry
@pytest.mark.unit
class TestPipelineRegistry:
"""
Unit тесты для системы регистрации пайплайнов.
"""
def setup_method(self):
"""
Очищаем реестр перед каждым тестом.
"""
_Registry.clear()
def test_register_adds_pipeline_to_registry(self):
"""
Тест регистрации пайплайна.
"""
@register("test.task")
def test_pipeline(args: dict):
return "result"
assert "test.task" in _Registry
assert _Registry["test.task"] == test_pipeline
def test_resolve_returns_registered_pipeline(self):
"""
Тест получения зарегистрированного пайплайна.
"""
@register("test.resolve")
def test_pipeline(args: dict):
return "resolved"
resolved = resolve("test.resolve")
assert resolved == test_pipeline
assert resolved({}) == "resolved"
def test_resolve_raises_keyerror_for_unknown_task(self):
"""
Тест ошибки при запросе незарегистрированного пайплайна.
"""
with pytest.raises(KeyError) as exc_info:
resolve("unknown.task")
assert "pipeline not found: unknown.task" in str(exc_info.value)
def test_tasks_returns_registered_task_names(self):
"""
Тест получения списка зарегистрированных задач.
"""
@register("task1")
def pipeline1(args: dict):
pass
@register("task2")
def pipeline2(args: dict):
pass
task_list = list(tasks())
assert "task1" in task_list
assert "task2" in task_list
def test_register_overwrites_existing_pipeline(self):
"""
Тест перезаписи существующего пайплайна.
"""
@register("overwrite.task")
def first_pipeline(args: dict):
return "first"
@register("overwrite.task")
def second_pipeline(args: dict):
return "second"
resolved = resolve("overwrite.task")
assert resolved({}) == "second"