diff --git a/src/dataloader/api/v1/router.py b/src/dataloader/api/v1/router.py index 3b89ee8..a911ca5 100644 --- a/src/dataloader/api/v1/router.py +++ b/src/dataloader/api/v1/router.py @@ -14,18 +14,18 @@ from dataloader.api.v1.schemas import ( TriggerJobResponse, ) from dataloader.api.v1.service import JobsService -from dataloader.storage.db import session_scope +from dataloader.context import get_session +from sqlalchemy.ext.asyncio import AsyncSession router = APIRouter(prefix="/jobs", tags=["jobs"]) -async def get_service() -> AsyncGenerator[JobsService, None]: +def get_service(session: Annotated[AsyncSession, Depends(get_session)]) -> JobsService: """ - Создаёт JobsService с новой сессией и корректно закрывает её после запроса. + FastAPI dependency to create a JobsService instance with a database session. """ - async for s in session_scope(): - yield JobsService(s) + return JobsService(session) @router.post("/trigger", response_model=TriggerJobResponse, status_code=HTTPStatus.OK) diff --git a/src/dataloader/api/v1/service.py b/src/dataloader/api/v1/service.py index ca8f5fb..622e288 100644 --- a/src/dataloader/api/v1/service.py +++ b/src/dataloader/api/v1/service.py @@ -17,7 +17,7 @@ from dataloader.storage.repositories import ( CreateJobRequest, QueueRepository, ) -from dataloader.context import APP_CTX +from dataloader.logger.logger import get_logger class JobsService: @@ -27,7 +27,7 @@ class JobsService: def __init__(self, session: AsyncSession): self._s = session self._repo = QueueRepository(self._s) - self._log = APP_CTX.get_logger() + self._log = get_logger(__name__) async def trigger(self, req: TriggerJobRequest) -> TriggerJobResponse: """ diff --git a/src/dataloader/context.py b/src/dataloader/context.py index 6f03580..8d05762 100644 --- a/src/dataloader/context.py +++ b/src/dataloader/context.py @@ -1,124 +1,59 @@ # src/dataloader/context.py -from __future__ import annotations +from typing import AsyncGenerator +from logging import Logger -import typing -import pytz -from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker -from dataloader.base import Singleton -from dataloader.config import APP_CONFIG, Secrets -from dataloader.logger import ContextVarsContainer, LoggerConfigurator -from sqlalchemy import event, select, func, text +from .config import APP_CONFIG +from .logger.context_vars import ContextVarsContainer -class AppContext(metaclass=Singleton): - """ - Контекст приложения: логгер, таймзона, подключение к БД и фабрика сессий. - """ - def __init__(self, secrets: Secrets) -> None: - self.timezone = pytz.timezone(secrets.app.timezone) - self.context_vars_container = ContextVarsContainer() - self._logger_manager = LoggerConfigurator( - log_lvl=secrets.log.log_lvl, - log_file_path=secrets.log.log_file_abs_path, - metric_file_path=secrets.log.metric_file_abs_path, - audit_file_path=secrets.log.audit_file_abs_path, - audit_host_ip=secrets.log.audit_host_ip, - audit_host_uid=secrets.log.audit_host_uid, - context_vars_container=self.context_vars_container, - timezone=self.timezone, - ) - self.pg = secrets.pg - self.dl = secrets.dl +class AppContext: + def __init__(self) -> None: self._engine: AsyncEngine | None = None self._sessionmaker: async_sessionmaker[AsyncSession] | None = None - self.logger.info("App context initialized.") - - @property - def logger(self) -> "typing.Any": - """ - Возвращает асинхронный логгер. 
- """ - return self._logger_manager.async_logger + self._context_vars_container = ContextVarsContainer() @property def engine(self) -> AsyncEngine: - """ - Возвращает текущий AsyncEngine. - """ - assert self._engine is not None, "Engine is not initialized" + if self._engine is None: + raise RuntimeError("Database engine is not initialized.") return self._engine @property def sessionmaker(self) -> async_sessionmaker[AsyncSession]: - """ - Возвращает фабрику асинхронных сессий. - """ - assert self._sessionmaker is not None, "Sessionmaker is not initialized" + if self._sessionmaker is None: + raise RuntimeError("Sessionmaker is not initialized.") return self._sessionmaker - def get_logger(self) -> "typing.Any": - """ - Возвращает логгер. - """ - return self.logger + def startup(self) -> None: + from .storage.db import create_engine, create_sessionmaker + from .logger.logger import setup_logging + + setup_logging() + self._engine = create_engine(APP_CONFIG.pg.url) + self._sessionmaker = create_sessionmaker(self._engine) + + async def shutdown(self) -> None: + if self._engine: + await self._engine.dispose() + + def get_logger(self, name: str | None = None) -> Logger: + """Returns a configured logger instance.""" + from .logger.logger import get_logger as get_app_logger + return get_app_logger(name) def get_context_vars_container(self) -> ContextVarsContainer: - """ - Возвращает контейнер контекстных переменных логгера. - """ - return self.context_vars_container - - def get_pytz_timezone(self): - """ - Возвращает таймзону приложения. - """ - return self.timezone - - async def on_startup(self) -> None: - """ - Инициализирует подключение к БД и готовит фабрику сессий. - """ - self.logger.info("Application is starting up.") - dsn = APP_CONFIG.resolved_dsn - self._engine = create_async_engine( - dsn, - pool_size=self.pg.pool_size if self.pg.use_pool else None, - max_overflow=self.pg.max_overflow if self.pg.use_pool else 0, - pool_recycle=self.pg.pool_recycle if self.pg.use_pool else -1, - ) - - schema = self.pg.schema_.strip() - - @event.listens_for(self._engine.sync_engine, "connect") - def _set_search_path(dbapi_conn, _): - cur = dbapi_conn.cursor() - try: - if schema: - cur.execute(f'SET SESSION search_path TO "{schema}"') - finally: - cur.close() - - self._sessionmaker = async_sessionmaker(self._engine, expire_on_commit=False, class_=AsyncSession) - - async with self._sessionmaker() as s: - await s.execute(select(func.count()).select_from(text("dl_jobs"))) - await s.execute(select(func.count()).select_from(text("dl_job_events"))) - - self.logger.info("All connections checked. Application is up and ready.") - - async def on_shutdown(self) -> None: - """ - Останавливает подсистемы и освобождает ресурсы. - """ - self.logger.info("Application is shutting down.") - if self._engine is not None: - await self._engine.dispose() - self._engine = None - self._sessionmaker = None - self._logger_manager.remove_logger_handlers() + return self._context_vars_container -APP_CTX = AppContext(APP_CONFIG) +APP_CTX = AppContext() -__all__ = ["APP_CTX"] + +async def get_session() -> AsyncGenerator[AsyncSession, None]: + """ + FastAPI dependency to get a database session. + Yields a session from the global sessionmaker and ensures it's closed. 
+ """ + async with APP_CTX.sessionmaker() as session: + yield session diff --git a/src/dataloader/logger/__init__.py b/src/dataloader/logger/__init__.py index 9d0c9d9..da2dbbf 100644 --- a/src/dataloader/logger/__init__.py +++ b/src/dataloader/logger/__init__.py @@ -1,5 +1,4 @@ -# logger package -from .context_vars import ContextVarsContainer -from .logger import LoggerConfigurator +# src/dataloader/logger/__init__.py +from .logger import setup_logging, get_logger -__all__ = ["ContextVarsContainer", "LoggerConfigurator"] +__all__ = ["setup_logging", "get_logger"] diff --git a/src/dataloader/logger/logger.py b/src/dataloader/logger/logger.py index 9c52f51..b8c06f5 100644 --- a/src/dataloader/logger/logger.py +++ b/src/dataloader/logger/logger.py @@ -1,14 +1,14 @@ -# Основной логгер приложения +# src/dataloader/logger/logger.py import sys import typing from datetime import tzinfo - +from logging import Logger +import logging from loguru import logger - from .context_vars import ContextVarsContainer - +from dataloader.config import APP_CONFIG # Определяем фильтры для разных типов логов @@ -16,131 +16,44 @@ def metric_only_filter(record: dict) -> bool: return "metric" in record["extra"] - def audit_only_filter(record: dict) -> bool: return "audit" in record["extra"] - def regular_log_filter(record: dict) -> bool: return "metric" not in record["extra"] and "audit" not in record["extra"] +class InterceptHandler(logging.Handler): + def emit(self, record: logging.LogRecord) -> None: + # Get corresponding Loguru level if it exists + try: + level = logger.level(record.levelname).name + except ValueError: + level = record.levelno -class LoggerConfigurator: - def __init__( - self, - log_lvl: str, - log_file_path: str, - metric_file_path: str, - audit_file_path: str, - audit_host_ip: str, - audit_host_uid: str, - context_vars_container: ContextVarsContainer, - timezone: tzinfo, - ) -> None: - self.context_vars_container = context_vars_container - self.timezone = timezone - self.log_lvl = log_lvl - self.log_file_path = log_file_path - self.metric_file_path = metric_file_path - self.audit_file_path = audit_file_path - self.audit_host_ip = audit_host_ip - self.audit_host_uid = audit_host_uid - self._handler_ids = [] - self.configure_logger() + # Find caller from where originated the logged message + frame, depth = logging.currentframe(), 2 + while frame.f_code.co_filename == logging.__file__: + frame = frame.f_back + depth += 1 - - @property - def async_logger(self) -> "typing.Any": - return self._async_logger - - - def patch_record_with_context(self, record: dict) -> None: - context_data = self.context_vars_container.as_dict() - record["extra"].update(context_data) - if not record["extra"].get("request_id"): - record["extra"]["request_id"] = "system_event" - - - def configure_logger(self) -> None: - """Настройка логгера `loguru` с необходимыми обработчиками.""" - logger.remove() - logger.patch(self.patch_record_with_context) - - # Функция для безопасного форматирования консольных логов - def console_format(record): - request_id = record["extra"].get("request_id", "system_event") - elapsed = record["elapsed"] - level = record["level"].name - name = record["name"] - function = record["function"] - line = record["line"] - message = record["message"] - time_str = record["time"].strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] - - return ( - f"{time_str} ({elapsed}) | " - f"{request_id} | " - f"{level: <8} | " - f"{name}:{function}:{line} - " - f"{message}\n" - ) - - # Обработчик для обычных логов (консоль) - handler_id 
= logger.add( - sys.stdout, - level=self.log_lvl, - filter=regular_log_filter, - format=console_format, - colorize=True, + logger.opt(depth=depth, exception=record.exc_info).log( + level, record.getMessage() ) - self._handler_ids.append(handler_id) - # Обработчик для обычных логов (файл) - handler_id = logger.add( - self.log_file_path, - level=self.log_lvl, - filter=regular_log_filter, - rotation="10 MB", - compression="zip", - enqueue=True, - serialize=True, - ) - self._handler_ids.append(handler_id) +def setup_logging(): + """Настройка логгера `loguru` с необходимыми обработчиками.""" + logger.remove() + logger.add( + sys.stdout, + level=APP_CONFIG.log.log_level.upper(), + filter=regular_log_filter, + colorize=True, + ) + logging.basicConfig(handlers=[InterceptHandler()], level=0) - # Обработчик для метрик - handler_id = logger.add( - self.metric_file_path, - level="INFO", - filter=metric_only_filter, - rotation="10 MB", - compression="zip", - enqueue=True, - serialize=True, - ) - self._handler_ids.append(handler_id) - - - # Обработчик для аудита - handler_id = logger.add( - self.audit_file_path, - level="INFO", - filter=audit_only_filter, - rotation="10 MB", - compression="zip", - enqueue=True, - serialize=True, - ) - self._handler_ids.append(handler_id) - - - self._async_logger = logger - - - def remove_logger_handlers(self) -> None: - """Удаление всех обработчиков логгера.""" - for handler_id in self._handler_ids: - self._async_logger.remove(handler_id) +def get_logger(name: str | None = None) -> Logger: + return logging.getLogger(name or "dataloader") diff --git a/src/dataloader/storage/repositories.py b/src/dataloader/storage/repositories.py index c8e5e80..aa66fb5 100644 --- a/src/dataloader/storage/repositories.py +++ b/src/dataloader/storage/repositories.py @@ -117,41 +117,41 @@ class QueueRepository: """ Идемпотентно создаёт запись в очереди и возвращает (job_id, status). 
""" - if req.idempotency_key: - q = select(DLJob).where(DLJob.idempotency_key == req.idempotency_key) - r = await self.s.execute(q) - ex = r.scalar_one_or_none() - if ex: - return ex.job_id, ex.status + async with self.s.begin(): + if req.idempotency_key: + q = select(DLJob).where(DLJob.idempotency_key == req.idempotency_key) + r = await self.s.execute(q) + ex = r.scalar_one_or_none() + if ex: + return ex.job_id, ex.status - row = DLJob( - job_id=req.job_id, - queue=req.queue, - task=req.task, - args=req.args or {}, - idempotency_key=req.idempotency_key, - lock_key=req.lock_key, - partition_key=req.partition_key or "", - priority=req.priority, - available_at=req.available_at, - status="queued", - attempt=0, - max_attempts=req.max_attempts, - lease_ttl_sec=req.lease_ttl_sec, - lease_expires_at=None, - heartbeat_at=None, - cancel_requested=False, - progress={}, - error=None, - producer=req.producer, - consumer_group=req.consumer_group, - created_at=datetime.now(timezone.utc), - started_at=None, - finished_at=None, - ) - self.s.add(row) - await self._append_event(req.job_id, req.queue, "queued", {"task": req.task}) - await self.s.commit() + row = DLJob( + job_id=req.job_id, + queue=req.queue, + task=req.task, + args=req.args or {}, + idempotency_key=req.idempotency_key, + lock_key=req.lock_key, + partition_key=req.partition_key or "", + priority=req.priority, + available_at=req.available_at, + status="queued", + attempt=0, + max_attempts=req.max_attempts, + lease_ttl_sec=req.lease_ttl_sec, + lease_expires_at=None, + heartbeat_at=None, + cancel_requested=False, + progress={}, + error=None, + producer=req.producer, + consumer_group=req.consumer_group, + created_at=datetime.now(timezone.utc), + started_at=None, + finished_at=None, + ) + self.s.add(row) + await self._append_event(req.job_id, req.queue, "queued", {"task": req.task}) return req.job_id, "queued" async def get_status(self, job_id: str) -> Optional[JobStatus]: @@ -187,12 +187,12 @@ class QueueRepository: """ Устанавливает флаг отмены для задачи. """ - job = await self._get(job_id) - if not job: - return False - job.cancel_requested = True - await self._append_event(job_id, job.queue, "cancel_requested", None) - await self.s.commit() + async with self.s.begin(): + job = await self._get(job_id) + if not job: + return False + job.cancel_requested = True + await self._append_event(job_id, job.queue, "cancel_requested", None) return True async def claim_one(self, queue: str, claim_backoff_sec: int) -> Optional[dict[str, Any]]: @@ -246,66 +246,66 @@ class QueueRepository: Обновляет heartbeat и продлевает lease. Возвращает (success, cancel_requested). 
""" - job = await self._get(job_id) - if not job or job.status != "running": - return False, False - - cancel_requested = bool(job.cancel_requested) - now = datetime.now(timezone.utc) - q = ( - update(DLJob) - .where(DLJob.job_id == job_id, DLJob.status == "running") - .values(heartbeat_at=now, lease_expires_at=now + timedelta(seconds=int(ttl_sec))) - ) - await self.s.execute(q) - await self._append_event(job_id, await self._resolve_queue(job_id), "heartbeat", {"ttl": ttl_sec}) - await self.s.commit() - return True, cancel_requested + async with self.s.begin(): + job = await self._get(job_id) + if not job or job.status != "running": + return False, False + + cancel_requested = bool(job.cancel_requested) + now = datetime.now(timezone.utc) + q = ( + update(DLJob) + .where(DLJob.job_id == job_id, DLJob.status == "running") + .values(heartbeat_at=now, lease_expires_at=now + timedelta(seconds=int(ttl_sec))) + ) + await self.s.execute(q) + await self._append_event(job_id, await self._resolve_queue(job_id), "heartbeat", {"ttl": ttl_sec}) + return True, cancel_requested async def finish_ok(self, job_id: str) -> None: """ Помечает задачу как выполненную успешно и снимает advisory-lock. """ - job = await self._get(job_id) - if not job: - return - job.status = "succeeded" - job.finished_at = datetime.now(timezone.utc) - job.lease_expires_at = None - await self._append_event(job_id, job.queue, "succeeded", None) - await self._advisory_unlock(job.lock_key) - await self.s.commit() + async with self.s.begin(): + job = await self._get(job_id) + if not job: + return + job.status = "succeeded" + job.finished_at = datetime.now(timezone.utc) + job.lease_expires_at = None + await self._append_event(job_id, job.queue, "succeeded", None) + await self._advisory_unlock(job.lock_key) async def finish_fail_or_retry(self, job_id: str, err: str, is_canceled: bool = False) -> None: """ Помечает задачу как failed, canceled или возвращает в очередь с задержкой. 
""" - job = await self._get(job_id) - if not job: - return - - if is_canceled: - job.status = "canceled" - job.error = err - job.finished_at = datetime.now(timezone.utc) - job.lease_expires_at = None - await self._append_event(job_id, job.queue, "canceled", {"error": err}) - else: - can_retry = int(job.attempt) < int(job.max_attempts) - if can_retry: - job.status = "queued" - job.available_at = datetime.now(timezone.utc) + timedelta(seconds=30 * int(job.attempt)) - job.error = err - job.lease_expires_at = None - await self._append_event(job_id, job.queue, "requeue", {"attempt": job.attempt, "error": err}) - else: - job.status = "failed" + async with self.s.begin(): + job = await self._get(job_id) + if not job: + return + + if is_canceled: + job.status = "canceled" job.error = err job.finished_at = datetime.now(timezone.utc) job.lease_expires_at = None - await self._append_event(job_id, job.queue, "failed", {"error": err}) - await self._advisory_unlock(job.lock_key) - await self.s.commit() + await self._append_event(job_id, job.queue, "canceled", {"error": err}) + else: + can_retry = int(job.attempt) < int(job.max_attempts) + if can_retry: + job.status = "queued" + job.available_at = datetime.now(timezone.utc) + timedelta(seconds=30 * int(job.attempt)) + job.error = err + job.lease_expires_at = None + await self._append_event(job_id, job.queue, "requeue", {"attempt": job.attempt, "error": err}) + else: + job.status = "failed" + job.error = err + job.finished_at = datetime.now(timezone.utc) + job.lease_expires_at = None + await self._append_event(job_id, job.queue, "failed", {"error": err}) + await self._advisory_unlock(job.lock_key) async def requeue_lost(self, now: Optional[datetime] = None) -> list[str]: """ diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..fce0f1e --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,179 @@ +# tests/conftest.py +import asyncio +import sys +from typing import AsyncGenerator + +import pytest +from dotenv import load_dotenv +from httpx import AsyncClient, ASGITransport +from sqlalchemy import text +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, + create_async_engine, +) +from sqlalchemy.exc import ProgrammingError +from asyncpg.exceptions import InvalidCatalogNameError + +# Load .env before other imports to ensure config is available +load_dotenv() + +from dataloader.api import app_main +from dataloader.api.v1.router import get_service +from dataloader.api.v1.service import JobsService +from dataloader.context import APP_CTX +from dataloader.storage.db import Base +from dataloader.config import APP_CONFIG + +# For Windows: use SelectorEventLoop which is more stable +if sys.platform == "win32": + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + + +@pytest.fixture(scope="session") +async def db_engine() -> AsyncGenerator[AsyncEngine, None]: + """ + Creates a temporary, isolated test database for the entire test session. + - Connects to the default 'postgres' database to create a new test DB. + - Yields an engine connected to the new test DB. + - Drops the test DB after the session is complete. 
+ """ + pg_settings = APP_CONFIG.pg + test_db_name = f"{pg_settings.database}_test" + + # DSN for connecting to 'postgres' DB to manage the test DB + system_dsn = pg_settings.url.replace(f"/{pg_settings.database}", "/postgres") + system_engine = create_async_engine(system_dsn, isolation_level="AUTOCOMMIT") + + # DSN for the new test database + test_dsn = pg_settings.url.replace(f"/{pg_settings.database}", f"/{test_db_name}") + + async with system_engine.connect() as conn: + # Drop the test DB if it exists from a previous failed run + await conn.execute(text(f"DROP DATABASE IF EXISTS {test_db_name} WITH (FORCE)")) + await conn.execute(text(f"CREATE DATABASE {test_db_name}")) + + # Create an engine connected to the new test database + engine = create_async_engine(test_dsn) + + # Create all tables and DDL objects + async with engine.begin() as conn: + # Execute DDL statements one by one + await conn.execute(text("CREATE TYPE dl_status AS ENUM ('queued','running','succeeded','failed','canceled','lost')")) + + # Create tables + await conn.run_sync(Base.metadata.create_all) + + # Create indexes + await conn.execute(text("CREATE INDEX ix_dl_jobs_claim ON dl_jobs(queue, available_at, priority, created_at) WHERE status = 'queued'")) + await conn.execute(text("CREATE INDEX ix_dl_jobs_running_lease ON dl_jobs(lease_expires_at) WHERE status = 'running'")) + await conn.execute(text("CREATE INDEX ix_dl_jobs_status_queue ON dl_jobs(status, queue)")) + + # Create function and triggers + await conn.execute( + text( + """ + CREATE OR REPLACE FUNCTION notify_job_ready() RETURNS trigger AS $$ + BEGIN + IF (TG_OP = 'INSERT') THEN + PERFORM pg_notify('dl_jobs', NEW.queue); + RETURN NEW; + ELSIF (TG_OP = 'UPDATE') THEN + IF NEW.status = 'queued' AND NEW.available_at <= now() + AND (OLD.status IS DISTINCT FROM NEW.status OR OLD.available_at IS DISTINCT FROM NEW.available_at) THEN + PERFORM pg_notify('dl_jobs', NEW.queue); + END IF; + RETURN NEW; + END IF; + RETURN NEW; + END $$ LANGUAGE plpgsql; + """ + ) + ) + await conn.execute(text("CREATE TRIGGER dl_jobs_notify_ins AFTER INSERT ON dl_jobs FOR EACH ROW EXECUTE FUNCTION notify_job_ready()")) + await conn.execute(text("CREATE TRIGGER dl_jobs_notify_upd AFTER UPDATE OF status, available_at ON dl_jobs FOR EACH ROW EXECUTE FUNCTION notify_job_ready()")) + + yield engine + + # --- Teardown --- + await engine.dispose() + + # Drop the test database + async with system_engine.connect() as conn: + await conn.execute(text(f"DROP DATABASE {test_db_name} WITH (FORCE)")) + await system_engine.dispose() + + +@pytest.fixture(scope="session") +async def app_context(db_engine: AsyncEngine) -> AsyncGenerator[None, None]: + """ + Overrides the global APP_CTX with the test database engine and sessionmaker. + This ensures that the app, when tested, uses the isolated test DB. + """ + original_engine = APP_CTX._engine + original_sessionmaker = APP_CTX._sessionmaker + + APP_CTX._engine = db_engine + APP_CTX._sessionmaker = async_sessionmaker(db_engine, expire_on_commit=False, class_=AsyncSession) + + yield + + # Restore original context + APP_CTX._engine = original_engine + APP_CTX._sessionmaker = original_sessionmaker + + +@pytest.fixture +async def db_session(db_engine: AsyncEngine) -> AsyncGenerator[AsyncSession, None]: + """ + Provides a transactional session for tests. + A single connection is acquired from the engine, a transaction is started, + and a new session is created bound to that connection. 
At the end of the test, + the transaction is rolled back and the connection is closed, ensuring + perfect test isolation. + """ + connection = await db_engine.connect() + trans = await connection.begin() + + # Create a sessionmaker bound to the single connection + TestSession = async_sessionmaker( + bind=connection, + expire_on_commit=False, + class_=AsyncSession + ) + session = TestSession() + + try: + yield session + finally: + # Clean up the session, rollback the transaction, and close the connection + await session.close() + await trans.rollback() + await connection.close() + + +@pytest.fixture +async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]: + """ + Provides an HTTP client for the FastAPI application. + It overrides the 'get_session' dependency to use the test-scoped, + transactional session provided by the 'db_session' fixture. This ensures + that API calls operate within the same transaction as the test function, + allowing for consistent state checks. + """ + async def override_get_session() -> AsyncGenerator[AsyncSession, None]: + """This override simply yields the session created by the db_session fixture.""" + yield db_session + + # The service depends on the session, so we override the session provider + # that is used by the service via Depends(get_session) + from dataloader.context import get_session + app_main.dependency_overrides[get_session] = override_get_session + + transport = ASGITransport(app=app_main) + async with AsyncClient(transport=transport, base_url="http://test") as c: + yield c + + # Clean up the dependency override + app_main.dependency_overrides.clear() diff --git a/tests/integrations/test_api_v1.py b/tests/integrations/test_api_v1.py new file mode 100644 index 0000000..e474f41 --- /dev/null +++ b/tests/integrations/test_api_v1.py @@ -0,0 +1,93 @@ +# tests/integrations/test_api_v1.py +import pytest +from httpx import AsyncClient + +@pytest.mark.anyio +async def test_trigger_and_get_status_ok(client: AsyncClient): + """ + Тест проверяет успешное создание задачи и получение её статуса. + """ + # 1. Триггер задачи + trigger_payload = { + "queue": "test_queue", + "task": "test_task", + "lock_key": "lock_123", + } + response = await client.post("/api/v1/jobs/trigger", json=trigger_payload) + + # Проверки ответа на триггер + assert response.status_code == 200 + response_data = response.json() + assert "job_id" in response_data + assert response_data["status"] == "queued" + job_id = response_data["job_id"] + + # 2. Получение статуса + response = await client.get(f"/api/v1/jobs/{job_id}/status") + + # Проверки ответа на статус + assert response.status_code == 200 + status_data = response.json() + assert status_data["job_id"] == job_id + assert status_data["status"] == "queued" + assert status_data["attempt"] == 0 + + +@pytest.mark.anyio +async def test_cancel_job_ok(client: AsyncClient): + """ + Тест проверяет успешную отмену задачи. + """ + # 1. Триггер задачи + trigger_payload = { + "queue": "cancel_queue", + "task": "cancel_task", + "lock_key": "lock_cancel", + } + response = await client.post("/api/v1/jobs/trigger", json=trigger_payload) + assert response.status_code == 200 + job_id = response.json()["job_id"] + + # 2. 
Запрос на отмену + response = await client.post(f"/api/v1/jobs/{job_id}/cancel") + assert response.status_code == 200 + cancel_data = response.json() + assert cancel_data["job_id"] == job_id + # Воркер еще не взял задачу, поэтому статус queued, но cancel_requested уже true + assert cancel_data["status"] == "queued" + + # 3. Проверка статуса после отмены + response = await client.get(f"/api/v1/jobs/{job_id}/status") + assert response.status_code == 200 + status_data = response.json() + assert status_data["job_id"] == job_id + # Задача еще не отменена, а только запрошена к отмене + assert status_data["status"] == "queued" + + +@pytest.mark.anyio +async def test_trigger_duplicate_idempotency_key(client: AsyncClient): + """ + Тест проверяет, что повторная отправка с тем же ключом идемпотентности + возвращает ту же самую задачу. + """ + idempotency_key = "idem_key_123" + trigger_payload = { + "queue": "idempotent_queue", + "task": "idempotent_task", + "lock_key": "lock_idem", + "idempotency_key": idempotency_key, + } + + # 1. Первый запрос + response1 = await client.post("/api/v1/jobs/trigger", json=trigger_payload) + assert response1.status_code == 200 + job_id1 = response1.json()["job_id"] + + # 2. Второй запрос с тем же ключом + response2 = await client.post("/api/v1/jobs/trigger", json=trigger_payload) + assert response2.status_code == 200 + job_id2 = response2.json()["job_id"] + + # ID задач должны совпадать + assert job_id1 == job_id2 \ No newline at end of file diff --git a/tests/integrations/test_worker_protocol.py b/tests/integrations/test_worker_protocol.py new file mode 100644 index 0000000..47f0981 --- /dev/null +++ b/tests/integrations/test_worker_protocol.py @@ -0,0 +1,155 @@ +# tests/integrations/test_worker_protocol.py +import asyncio +from datetime import datetime, timedelta, timezone +from uuid import uuid4 + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from dataloader.storage.repositories import QueueRepository, CreateJobRequest +from dataloader.context import APP_CTX + +@pytest.mark.anyio +async def test_e2e_worker_protocol_ok(db_session: AsyncSession): + """ + Проверяет полный E2E-сценарий жизненного цикла задачи: + 1. Постановка (create) + 2. Захват (claim) + 3. Пульс (heartbeat) + 4. Успешное завершение (finish_ok) + 5. Проверка статуса + """ + repo = QueueRepository(db_session) + job_id = str(uuid4()) + queue_name = "e2e_ok_queue" + lock_key = f"lock_{job_id}" + + # 1. Постановка задачи + create_req = CreateJobRequest( + job_id=job_id, + queue=queue_name, + task="test_e2e_task", + args={}, + idempotency_key=None, + lock_key=lock_key, + partition_key="", + priority=100, + available_at=datetime.now(timezone.utc), + max_attempts=3, + lease_ttl_sec=30, + producer=None, + consumer_group=None, + ) + await repo.create_or_get(create_req) + + # 2. Захват задачи + claimed_job = await repo.claim_one(queue_name, claim_backoff_sec=10) + assert claimed_job is not None + assert claimed_job["job_id"] == job_id + assert claimed_job["lock_key"] == lock_key + + # 3. Пульс + success, cancel_requested = await repo.heartbeat(job_id, ttl_sec=60) + assert success + assert not cancel_requested + + # 4. Успешное завершение + await repo.finish_ok(job_id) + + # 5. 
Проверка статуса + status = await repo.get_status(job_id) + assert status is not None + assert status.status == "succeeded" + assert status.finished_at is not None + + +@pytest.mark.anyio +async def test_concurrency_claim_one_locks(db_session: AsyncSession): + """ + Проверяет, что при конкурентном доступе к задачам с одинаковым + lock_key только один воркер может захватить задачу. + """ + repo = QueueRepository(db_session) + queue_name = "concurrency_queue" + lock_key = "concurrent_lock_123" + job_ids = [str(uuid4()), str(uuid4())] + + # 1. Создание двух задач с одинаковым lock_key + for i, job_id in enumerate(job_ids): + create_req = CreateJobRequest( + job_id=job_id, + queue=queue_name, + task=f"task_{i}", + args={}, + idempotency_key=f"idem_con_{i}", + lock_key=lock_key, + partition_key="", + priority=100 + i, + available_at=datetime.now(timezone.utc), + max_attempts=1, + lease_ttl_sec=30, + producer="test", + consumer_group="test_group", + ) + await repo.create_or_get(create_req) + + # 2. Первый воркер захватывает задачу + claimed_job_1 = await repo.claim_one(queue_name, claim_backoff_sec=1) + assert claimed_job_1 is not None + assert claimed_job_1["job_id"] == job_ids[0] + + # 3. Второй воркер пытается захватить задачу, но не может (из-за advisory lock) + claimed_job_2 = await repo.claim_one(queue_name, claim_backoff_sec=1) + assert claimed_job_2 is None + + # 4. Первый воркер освобождает advisory lock (как будто завершил работу) + await repo._advisory_unlock(lock_key) + + # 5. Второй воркер теперь может захватить вторую задачу + claimed_job_3 = await repo.claim_one(queue_name, claim_backoff_sec=1) + assert claimed_job_3 is not None + assert claimed_job_3["job_id"] == job_ids[1] + + +@pytest.mark.anyio +async def test_reaper_requeues_lost_jobs(db_session: AsyncSession): + """ + Проверяет, что reaper корректно возвращает "потерянные" задачи в очередь. + """ + repo = QueueRepository(db_session) + job_id = str(uuid4()) + queue_name = "reaper_queue" + + # 1. Создаем и захватываем задачу + create_req = CreateJobRequest( + job_id=job_id, + queue=queue_name, + task="reaper_test_task", + args={}, + idempotency_key="idem_reaper_1", + lock_key="reaper_lock_1", + partition_key="", + priority=100, + available_at=datetime.now(timezone.utc), + max_attempts=3, + lease_ttl_sec=1, # Очень короткий lease + producer=None, + consumer_group=None, + ) + await repo.create_or_get(create_req) + + claimed_job = await repo.claim_one(queue_name, claim_backoff_sec=1) + assert claimed_job is not None + assert claimed_job["job_id"] == job_id + + # 2. Ждем истечения lease + await asyncio.sleep(2) + + # 3. Запускаем reaper + requeued_ids = await repo.requeue_lost() + assert requeued_ids == [job_id] + + # 4. Проверяем статус + status = await repo.get_status(job_id) + assert status is not None + assert status.status == "queued"
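
Reviewer note: the new integration tests drive the queue protocol (claim_one → heartbeat → finish_ok / finish_fail_or_retry) directly against QueueRepository, but the worker loop that would run it in production is not part of this diff. Below is a minimal sketch of how such a loop could use the API introduced here; the loop structure, poll interval, and error handling are assumptions for illustration, not code from this change.

# hypothetical_worker.py — illustrative sketch only; not part of this diff
import asyncio

from dataloader.context import APP_CTX
from dataloader.storage.repositories import QueueRepository


async def run_worker(queue: str, poll_interval: float = 1.0) -> None:
    """Claim queued jobs and drive them through the lease protocol."""
    while True:
        async with APP_CTX.sessionmaker() as session:
            repo = QueueRepository(session)
            job = await repo.claim_one(queue, claim_backoff_sec=10)
            if job is None:
                # Nothing ready yet: back off before polling again.
                await asyncio.sleep(poll_interval)
                continue
            job_id = job["job_id"]
            try:
                # ... execute the claimed task here, calling repo.heartbeat()
                # periodically to extend the lease and observe cancellation ...
                alive, cancel_requested = await repo.heartbeat(job_id, ttl_sec=30)
                if cancel_requested:
                    await repo.finish_fail_or_retry(job_id, "cancel requested", is_canceled=True)
                    continue
                await repo.finish_ok(job_id)
            except Exception as exc:
                # Requeue with backoff or mark failed, per max_attempts.
                await repo.finish_fail_or_retry(job_id, str(exc))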