roocode: try fixes

itqop 2025-11-05 13:31:19 +03:00
parent 7a76dc1d84
commit 7152f4b61e
9 changed files with 592 additions and 318 deletions

View File

@@ -14,18 +14,18 @@ from dataloader.api.v1.schemas import (
     TriggerJobResponse,
 )
 from dataloader.api.v1.service import JobsService
-from dataloader.storage.db import session_scope
+from dataloader.context import get_session
+from sqlalchemy.ext.asyncio import AsyncSession
 
 router = APIRouter(prefix="/jobs", tags=["jobs"])
 
-async def get_service() -> AsyncGenerator[JobsService, None]:
-    """
-    Creates a JobsService with a fresh session and closes it correctly after the request.
-    """
-    async for s in session_scope():
-        yield JobsService(s)
+def get_service(session: Annotated[AsyncSession, Depends(get_session)]) -> JobsService:
+    """
+    FastAPI dependency to create a JobsService instance with a database session.
+    """
+    return JobsService(session)
 
 @router.post("/trigger", response_model=TriggerJobResponse, status_code=HTTPStatus.OK)
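The hunk ends at the route decorator, so the handler body is not shown. A minimal sketch of how a handler would consume the reworked dependency chain (get_session → get_service → JobsService) might look like this; the handler name trigger_job and its body are assumptions for illustration, not part of this commit:

from http import HTTPStatus
from typing import Annotated
from fastapi import Depends

@router.post("/trigger", response_model=TriggerJobResponse, status_code=HTTPStatus.OK)
async def trigger_job(
    req: TriggerJobRequest,
    svc: Annotated[JobsService, Depends(get_service)],
) -> TriggerJobResponse:
    # FastAPI resolves get_session first, builds JobsService from it, then injects it here.
    return await svc.trigger(req)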

View File

@@ -17,7 +17,7 @@ from dataloader.storage.repositories import (
     CreateJobRequest,
     QueueRepository,
 )
-from dataloader.context import APP_CTX
+from dataloader.logger.logger import get_logger
 
 class JobsService:
@@ -27,7 +27,7 @@ class JobsService:
     def __init__(self, session: AsyncSession):
         self._s = session
         self._repo = QueueRepository(self._s)
-        self._log = APP_CTX.get_logger()
+        self._log = get_logger(__name__)
 
     async def trigger(self, req: TriggerJobRequest) -> TriggerJobResponse:
         """

View File

@@ -1,124 +1,59 @@
-# src/dataloader/context.py
-from __future__ import annotations
-
-import typing
-import pytz
-from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
-from dataloader.base import Singleton
-from dataloader.config import APP_CONFIG, Secrets
-from dataloader.logger import ContextVarsContainer, LoggerConfigurator
-from sqlalchemy import event, select, func, text
-
-
-class AppContext(metaclass=Singleton):
-    """
-    Application context: logger, timezone, database connection, and session factory.
-    """
-
-    def __init__(self, secrets: Secrets) -> None:
-        self.timezone = pytz.timezone(secrets.app.timezone)
-        self.context_vars_container = ContextVarsContainer()
-        self._logger_manager = LoggerConfigurator(
-            log_lvl=secrets.log.log_lvl,
-            log_file_path=secrets.log.log_file_abs_path,
-            metric_file_path=secrets.log.metric_file_abs_path,
-            audit_file_path=secrets.log.audit_file_abs_path,
-            audit_host_ip=secrets.log.audit_host_ip,
-            audit_host_uid=secrets.log.audit_host_uid,
-            context_vars_container=self.context_vars_container,
-            timezone=self.timezone,
-        )
-        self.pg = secrets.pg
-        self.dl = secrets.dl
-        self._engine: AsyncEngine | None = None
-        self._sessionmaker: async_sessionmaker[AsyncSession] | None = None
-        self.logger.info("App context initialized.")
-
-    @property
-    def logger(self) -> "typing.Any":
-        """
-        Returns the asynchronous logger.
-        """
-        return self._logger_manager.async_logger
-
-    @property
-    def engine(self) -> AsyncEngine:
-        """
-        Returns the current AsyncEngine.
-        """
-        assert self._engine is not None, "Engine is not initialized"
-        return self._engine
-
-    @property
-    def sessionmaker(self) -> async_sessionmaker[AsyncSession]:
-        """
-        Returns the asynchronous session factory.
-        """
-        assert self._sessionmaker is not None, "Sessionmaker is not initialized"
-        return self._sessionmaker
-
-    def get_logger(self) -> "typing.Any":
-        """
-        Returns the logger.
-        """
-        return self.logger
-
-    def get_context_vars_container(self) -> ContextVarsContainer:
-        """
-        Returns the logger's context variables container.
-        """
-        return self.context_vars_container
-
-    def get_pytz_timezone(self):
-        """
-        Returns the application timezone.
-        """
-        return self.timezone
-
-    async def on_startup(self) -> None:
-        """
-        Initializes the database connection and prepares the session factory.
-        """
-        self.logger.info("Application is starting up.")
-        dsn = APP_CONFIG.resolved_dsn
-        self._engine = create_async_engine(
-            dsn,
-            pool_size=self.pg.pool_size if self.pg.use_pool else None,
-            max_overflow=self.pg.max_overflow if self.pg.use_pool else 0,
-            pool_recycle=self.pg.pool_recycle if self.pg.use_pool else -1,
-        )
-        schema = self.pg.schema_.strip()
-
-        @event.listens_for(self._engine.sync_engine, "connect")
-        def _set_search_path(dbapi_conn, _):
-            cur = dbapi_conn.cursor()
-            try:
-                if schema:
-                    cur.execute(f'SET SESSION search_path TO "{schema}"')
-            finally:
-                cur.close()
-
-        self._sessionmaker = async_sessionmaker(self._engine, expire_on_commit=False, class_=AsyncSession)
-        async with self._sessionmaker() as s:
-            await s.execute(select(func.count()).select_from(text("dl_jobs")))
-            await s.execute(select(func.count()).select_from(text("dl_job_events")))
-        self.logger.info("All connections checked. Application is up and ready.")
-
-    async def on_shutdown(self) -> None:
-        """
-        Stops subsystems and releases resources.
-        """
-        self.logger.info("Application is shutting down.")
-        if self._engine is not None:
-            await self._engine.dispose()
-        self._engine = None
-        self._sessionmaker = None
-        self._logger_manager.remove_logger_handlers()
-
-
-APP_CTX = AppContext(APP_CONFIG)
-__all__ = ["APP_CTX"]
+# src/dataloader/context.py
+from typing import AsyncGenerator
+from logging import Logger
+
+from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
+
+from .config import APP_CONFIG
+from .logger.context_vars import ContextVarsContainer
+
+
+class AppContext:
+    def __init__(self) -> None:
+        self._engine: AsyncEngine | None = None
+        self._sessionmaker: async_sessionmaker[AsyncSession] | None = None
+        self._context_vars_container = ContextVarsContainer()
+
+    @property
+    def engine(self) -> AsyncEngine:
+        if self._engine is None:
+            raise RuntimeError("Database engine is not initialized.")
+        return self._engine
+
+    @property
+    def sessionmaker(self) -> async_sessionmaker[AsyncSession]:
+        if self._sessionmaker is None:
+            raise RuntimeError("Sessionmaker is not initialized.")
+        return self._sessionmaker
+
+    def startup(self) -> None:
+        from .storage.db import create_engine, create_sessionmaker
+        from .logger.logger import setup_logging
+
+        setup_logging()
+        self._engine = create_engine(APP_CONFIG.pg.url)
+        self._sessionmaker = create_sessionmaker(self._engine)
+
+    async def shutdown(self) -> None:
+        if self._engine:
+            await self._engine.dispose()
+
+    def get_logger(self, name: str | None = None) -> Logger:
+        """Returns a configured logger instance."""
+        from .logger.logger import get_logger as get_app_logger
+        return get_app_logger(name)
+
+    def get_context_vars_container(self) -> ContextVarsContainer:
+        return self._context_vars_container
+
+
+APP_CTX = AppContext()
+
+
+async def get_session() -> AsyncGenerator[AsyncSession, None]:
+    """
+    FastAPI dependency to get a database session.
+    Yields a session from the global sessionmaker and ensures it's closed.
+    """
+    async with APP_CTX.sessionmaker() as session:
+        yield session
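The new AppContext exposes startup() and shutdown(), but this commit does not show where they are called. One plausible wiring, sketched here as an assumption, is a FastAPI lifespan handler; the lifespan function and the app variable below are illustrative only:

from contextlib import asynccontextmanager
from fastapi import FastAPI

from dataloader.context import APP_CTX

@asynccontextmanager
async def lifespan(app: FastAPI):
    APP_CTX.startup()              # configure logging, build engine and sessionmaker
    try:
        yield
    finally:
        await APP_CTX.shutdown()   # dispose of the engine

app = FastAPI(lifespan=lifespan)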

View File

@@ -1,5 +1,4 @@
-# logger package
-from .context_vars import ContextVarsContainer
-from .logger import LoggerConfigurator
+# src/dataloader/logger/__init__.py
+from .logger import setup_logging, get_logger
 
-__all__ = ["ContextVarsContainer", "LoggerConfigurator"]
+__all__ = ["setup_logging", "get_logger"]

View File

@@ -1,14 +1,14 @@
-# Application's main logger
+# src/dataloader/logger/logger.py
 import sys
 import typing
 from datetime import tzinfo
+from logging import Logger
+import logging
 
 from loguru import logger
 
 from .context_vars import ContextVarsContainer
+from dataloader.config import APP_CONFIG
 
 
 # Define filters for different log types
@@ -16,131 +16,44 @@ def metric_only_filter(record: dict) -> bool:
     return "metric" in record["extra"]
 
 
 def audit_only_filter(record: dict) -> bool:
     return "audit" in record["extra"]
 
 
 def regular_log_filter(record: dict) -> bool:
     return "metric" not in record["extra"] and "audit" not in record["extra"]
 
 
-class LoggerConfigurator:
-    def __init__(
-        self,
-        log_lvl: str,
-        log_file_path: str,
-        metric_file_path: str,
-        audit_file_path: str,
-        audit_host_ip: str,
-        audit_host_uid: str,
-        context_vars_container: ContextVarsContainer,
-        timezone: tzinfo,
-    ) -> None:
-        self.context_vars_container = context_vars_container
-        self.timezone = timezone
-        self.log_lvl = log_lvl
-        self.log_file_path = log_file_path
-        self.metric_file_path = metric_file_path
-        self.audit_file_path = audit_file_path
-        self.audit_host_ip = audit_host_ip
-        self.audit_host_uid = audit_host_uid
-        self._handler_ids = []
-        self.configure_logger()
-
-    @property
-    def async_logger(self) -> "typing.Any":
-        return self._async_logger
-
-    def patch_record_with_context(self, record: dict) -> None:
-        context_data = self.context_vars_container.as_dict()
-        record["extra"].update(context_data)
-        if not record["extra"].get("request_id"):
-            record["extra"]["request_id"] = "system_event"
-
-    def configure_logger(self) -> None:
-        """Configure the loguru logger with the required handlers."""
-        logger.remove()
-        logger.patch(self.patch_record_with_context)
-
-        # Helper for safe formatting of console logs
-        def console_format(record):
-            request_id = record["extra"].get("request_id", "system_event")
-            elapsed = record["elapsed"]
-            level = record["level"].name
-            name = record["name"]
-            function = record["function"]
-            line = record["line"]
-            message = record["message"]
-            time_str = record["time"].strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
-            return (
-                f"<green>{time_str} ({elapsed})</green> | "
-                f"<cyan>{request_id}</cyan> | "
-                f"<level>{level: <8}</level> | "
-                f"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
-                f"<level>{message}</level>\n"
-            )
-
-        # Handler for regular logs (console)
-        handler_id = logger.add(
-            sys.stdout,
-            level=self.log_lvl,
-            filter=regular_log_filter,
-            format=console_format,
-            colorize=True,
-        )
-        self._handler_ids.append(handler_id)
-
-        # Handler for regular logs (file)
-        handler_id = logger.add(
-            self.log_file_path,
-            level=self.log_lvl,
-            filter=regular_log_filter,
-            rotation="10 MB",
-            compression="zip",
-            enqueue=True,
-            serialize=True,
-        )
-        self._handler_ids.append(handler_id)
-
-        # Handler for metrics
-        handler_id = logger.add(
-            self.metric_file_path,
-            level="INFO",
-            filter=metric_only_filter,
-            rotation="10 MB",
-            compression="zip",
-            enqueue=True,
-            serialize=True,
-        )
-        self._handler_ids.append(handler_id)
-
-        # Handler for audit logs
-        handler_id = logger.add(
-            self.audit_file_path,
-            level="INFO",
-            filter=audit_only_filter,
-            rotation="10 MB",
-            compression="zip",
-            enqueue=True,
-            serialize=True,
-        )
-        self._handler_ids.append(handler_id)
-
-        self._async_logger = logger
-
-    def remove_logger_handlers(self) -> None:
-        """Remove all logger handlers."""
-        for handler_id in self._handler_ids:
-            self._async_logger.remove(handler_id)
+class InterceptHandler(logging.Handler):
+    def emit(self, record: logging.LogRecord) -> None:
+        # Get corresponding Loguru level if it exists
+        try:
+            level = logger.level(record.levelname).name
+        except ValueError:
+            level = record.levelno
+
+        # Find caller from where originated the logged message
+        frame, depth = logging.currentframe(), 2
+        while frame.f_code.co_filename == logging.__file__:
+            frame = frame.f_back
+            depth += 1
+
+        logger.opt(depth=depth, exception=record.exc_info).log(
+            level, record.getMessage()
+        )
+
+
+def setup_logging():
+    """Configure the loguru logger with the required handlers."""
+    logger.remove()
+    logger.add(
+        sys.stdout,
+        level=APP_CONFIG.log.log_level.upper(),
+        filter=regular_log_filter,
+        colorize=True,
+    )
+    logging.basicConfig(handlers=[InterceptHandler()], level=0)
+
+
+def get_logger(name: str | None = None) -> Logger:
+    return logging.getLogger(name or "dataloader")
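As far as the diff shows, the intent of the new module is to route standard-library logging through loguru: setup_logging() installs a loguru stdout sink and registers InterceptHandler as the root logging handler, while get_logger() hands back plain logging.Logger objects. A short usage sketch, assumed rather than taken from the commit:

from dataloader.logger.logger import setup_logging, get_logger

setup_logging()                      # call once at process start
log = get_logger(__name__)           # an ordinary logging.Logger
log.info("worker started")           # the record is intercepted and emitted by loguru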

View File

@@ -117,41 +117,41 @@ class QueueRepository:
         """
         Idempotently creates a queue record and returns (job_id, status).
         """
-        if req.idempotency_key:
-            q = select(DLJob).where(DLJob.idempotency_key == req.idempotency_key)
-            r = await self.s.execute(q)
-            ex = r.scalar_one_or_none()
-            if ex:
-                return ex.job_id, ex.status
-
-        row = DLJob(
-            job_id=req.job_id,
-            queue=req.queue,
-            task=req.task,
-            args=req.args or {},
-            idempotency_key=req.idempotency_key,
-            lock_key=req.lock_key,
-            partition_key=req.partition_key or "",
-            priority=req.priority,
-            available_at=req.available_at,
-            status="queued",
-            attempt=0,
-            max_attempts=req.max_attempts,
-            lease_ttl_sec=req.lease_ttl_sec,
-            lease_expires_at=None,
-            heartbeat_at=None,
-            cancel_requested=False,
-            progress={},
-            error=None,
-            producer=req.producer,
-            consumer_group=req.consumer_group,
-            created_at=datetime.now(timezone.utc),
-            started_at=None,
-            finished_at=None,
-        )
-        self.s.add(row)
-        await self._append_event(req.job_id, req.queue, "queued", {"task": req.task})
-        await self.s.commit()
-        return req.job_id, "queued"
+        async with self.s.begin():
+            if req.idempotency_key:
+                q = select(DLJob).where(DLJob.idempotency_key == req.idempotency_key)
+                r = await self.s.execute(q)
+                ex = r.scalar_one_or_none()
+                if ex:
+                    return ex.job_id, ex.status
+
+            row = DLJob(
+                job_id=req.job_id,
+                queue=req.queue,
+                task=req.task,
+                args=req.args or {},
+                idempotency_key=req.idempotency_key,
+                lock_key=req.lock_key,
+                partition_key=req.partition_key or "",
+                priority=req.priority,
+                available_at=req.available_at,
+                status="queued",
+                attempt=0,
+                max_attempts=req.max_attempts,
+                lease_ttl_sec=req.lease_ttl_sec,
+                lease_expires_at=None,
+                heartbeat_at=None,
+                cancel_requested=False,
+                progress={},
+                error=None,
+                producer=req.producer,
+                consumer_group=req.consumer_group,
+                created_at=datetime.now(timezone.utc),
+                started_at=None,
+                finished_at=None,
+            )
+            self.s.add(row)
+            await self._append_event(req.job_id, req.queue, "queued", {"task": req.task})
+            return req.job_id, "queued"
 
     async def get_status(self, job_id: str) -> Optional[JobStatus]:
@@ -187,12 +187,12 @@ class QueueRepository:
         """
        Sets the cancel-requested flag for a job.
         """
-        job = await self._get(job_id)
-        if not job:
-            return False
-        job.cancel_requested = True
-        await self._append_event(job_id, job.queue, "cancel_requested", None)
-        await self.s.commit()
-        return True
+        async with self.s.begin():
+            job = await self._get(job_id)
+            if not job:
+                return False
+            job.cancel_requested = True
+            await self._append_event(job_id, job.queue, "cancel_requested", None)
+            return True
 
     async def claim_one(self, queue: str, claim_backoff_sec: int) -> Optional[dict[str, Any]]:
@@ -246,66 +246,66 @@ class QueueRepository:
         Updates the heartbeat and extends the lease.
         Returns (success, cancel_requested).
         """
-        job = await self._get(job_id)
-        if not job or job.status != "running":
-            return False, False
-
-        cancel_requested = bool(job.cancel_requested)
-        now = datetime.now(timezone.utc)
-        q = (
-            update(DLJob)
-            .where(DLJob.job_id == job_id, DLJob.status == "running")
-            .values(heartbeat_at=now, lease_expires_at=now + timedelta(seconds=int(ttl_sec)))
-        )
-        await self.s.execute(q)
-        await self._append_event(job_id, await self._resolve_queue(job_id), "heartbeat", {"ttl": ttl_sec})
-        await self.s.commit()
-        return True, cancel_requested
+        async with self.s.begin():
+            job = await self._get(job_id)
+            if not job or job.status != "running":
+                return False, False
+
+            cancel_requested = bool(job.cancel_requested)
+            now = datetime.now(timezone.utc)
+            q = (
+                update(DLJob)
+                .where(DLJob.job_id == job_id, DLJob.status == "running")
+                .values(heartbeat_at=now, lease_expires_at=now + timedelta(seconds=int(ttl_sec)))
+            )
+            await self.s.execute(q)
+            await self._append_event(job_id, await self._resolve_queue(job_id), "heartbeat", {"ttl": ttl_sec})
+            return True, cancel_requested
 
     async def finish_ok(self, job_id: str) -> None:
         """
         Marks the job as succeeded and releases the advisory lock.
         """
-        job = await self._get(job_id)
-        if not job:
-            return
-        job.status = "succeeded"
-        job.finished_at = datetime.now(timezone.utc)
-        job.lease_expires_at = None
-        await self._append_event(job_id, job.queue, "succeeded", None)
-        await self._advisory_unlock(job.lock_key)
-        await self.s.commit()
+        async with self.s.begin():
+            job = await self._get(job_id)
+            if not job:
+                return
+            job.status = "succeeded"
+            job.finished_at = datetime.now(timezone.utc)
+            job.lease_expires_at = None
+            await self._append_event(job_id, job.queue, "succeeded", None)
+            await self._advisory_unlock(job.lock_key)
 
     async def finish_fail_or_retry(self, job_id: str, err: str, is_canceled: bool = False) -> None:
         """
         Marks the job as failed or canceled, or returns it to the queue with a delay.
         """
-        job = await self._get(job_id)
-        if not job:
-            return
-
-        if is_canceled:
-            job.status = "canceled"
-            job.error = err
-            job.finished_at = datetime.now(timezone.utc)
-            job.lease_expires_at = None
-            await self._append_event(job_id, job.queue, "canceled", {"error": err})
-        else:
-            can_retry = int(job.attempt) < int(job.max_attempts)
-            if can_retry:
-                job.status = "queued"
-                job.available_at = datetime.now(timezone.utc) + timedelta(seconds=30 * int(job.attempt))
-                job.error = err
-                job.lease_expires_at = None
-                await self._append_event(job_id, job.queue, "requeue", {"attempt": job.attempt, "error": err})
-            else:
-                job.status = "failed"
-                job.error = err
-                job.finished_at = datetime.now(timezone.utc)
-                job.lease_expires_at = None
-                await self._append_event(job_id, job.queue, "failed", {"error": err})
-        await self._advisory_unlock(job.lock_key)
-        await self.s.commit()
+        async with self.s.begin():
+            job = await self._get(job_id)
+            if not job:
+                return
+
+            if is_canceled:
+                job.status = "canceled"
+                job.error = err
+                job.finished_at = datetime.now(timezone.utc)
+                job.lease_expires_at = None
+                await self._append_event(job_id, job.queue, "canceled", {"error": err})
+            else:
+                can_retry = int(job.attempt) < int(job.max_attempts)
+                if can_retry:
+                    job.status = "queued"
+                    job.available_at = datetime.now(timezone.utc) + timedelta(seconds=30 * int(job.attempt))
+                    job.error = err
+                    job.lease_expires_at = None
+                    await self._append_event(job_id, job.queue, "requeue", {"attempt": job.attempt, "error": err})
+                else:
+                    job.status = "failed"
+                    job.error = err
+                    job.finished_at = datetime.now(timezone.utc)
+                    job.lease_expires_at = None
+                    await self._append_event(job_id, job.queue, "failed", {"error": err})
+            await self._advisory_unlock(job.lock_key)
 
     async def requeue_lost(self, now: Optional[datetime] = None) -> list[str]:
         """

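With every repository method now opening its own transaction via async with self.s.begin(), callers no longer commit explicitly. A hedged sketch of a worker iteration built on these methods; the run_once function and its error handling are assumptions, only the repository calls come from this commit:

from dataloader.context import APP_CTX
from dataloader.storage.repositories import QueueRepository

async def run_once(queue: str) -> None:
    # One session per polling iteration; each repository call commits itself.
    async with APP_CTX.sessionmaker() as session:
        repo = QueueRepository(session)
        job = await repo.claim_one(queue, claim_backoff_sec=10)
        if job is None:
            return
        try:
            _, cancel_requested = await repo.heartbeat(job["job_id"], ttl_sec=60)
            if cancel_requested:
                await repo.finish_fail_or_retry(job["job_id"], "canceled by request", is_canceled=True)
                return
            # ... execute the actual task for job["job_id"] here ...
            await repo.finish_ok(job["job_id"])
        except Exception as exc:
            await repo.finish_fail_or_retry(job["job_id"], str(exc))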
tests/conftest.py (new file, 179 lines)
View File

@@ -0,0 +1,179 @@
# tests/conftest.py
import asyncio
import sys
from typing import AsyncGenerator

import pytest
from dotenv import load_dotenv
from httpx import AsyncClient, ASGITransport
from sqlalchemy import text
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.exc import ProgrammingError
from asyncpg.exceptions import InvalidCatalogNameError

# Load .env before other imports to ensure config is available
load_dotenv()

from dataloader.api import app_main
from dataloader.api.v1.router import get_service
from dataloader.api.v1.service import JobsService
from dataloader.context import APP_CTX
from dataloader.storage.db import Base
from dataloader.config import APP_CONFIG

# For Windows: use SelectorEventLoop which is more stable
if sys.platform == "win32":
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())


@pytest.fixture(scope="session")
async def db_engine() -> AsyncGenerator[AsyncEngine, None]:
    """
    Creates a temporary, isolated test database for the entire test session.
    - Connects to the default 'postgres' database to create a new test DB.
    - Yields an engine connected to the new test DB.
    - Drops the test DB after the session is complete.
    """
    pg_settings = APP_CONFIG.pg
    test_db_name = f"{pg_settings.database}_test"

    # DSN for connecting to 'postgres' DB to manage the test DB
    system_dsn = pg_settings.url.replace(f"/{pg_settings.database}", "/postgres")
    system_engine = create_async_engine(system_dsn, isolation_level="AUTOCOMMIT")

    # DSN for the new test database
    test_dsn = pg_settings.url.replace(f"/{pg_settings.database}", f"/{test_db_name}")

    async with system_engine.connect() as conn:
        # Drop the test DB if it exists from a previous failed run
        await conn.execute(text(f"DROP DATABASE IF EXISTS {test_db_name} WITH (FORCE)"))
        await conn.execute(text(f"CREATE DATABASE {test_db_name}"))

    # Create an engine connected to the new test database
    engine = create_async_engine(test_dsn)

    # Create all tables and DDL objects
    async with engine.begin() as conn:
        # Execute DDL statements one by one
        await conn.execute(text("CREATE TYPE dl_status AS ENUM ('queued','running','succeeded','failed','canceled','lost')"))
        # Create tables
        await conn.run_sync(Base.metadata.create_all)
        # Create indexes
        await conn.execute(text("CREATE INDEX ix_dl_jobs_claim ON dl_jobs(queue, available_at, priority, created_at) WHERE status = 'queued'"))
        await conn.execute(text("CREATE INDEX ix_dl_jobs_running_lease ON dl_jobs(lease_expires_at) WHERE status = 'running'"))
        await conn.execute(text("CREATE INDEX ix_dl_jobs_status_queue ON dl_jobs(status, queue)"))
        # Create function and triggers
        await conn.execute(
            text(
                """
                CREATE OR REPLACE FUNCTION notify_job_ready() RETURNS trigger AS $$
                BEGIN
                  IF (TG_OP = 'INSERT') THEN
                    PERFORM pg_notify('dl_jobs', NEW.queue);
                    RETURN NEW;
                  ELSIF (TG_OP = 'UPDATE') THEN
                    IF NEW.status = 'queued' AND NEW.available_at <= now()
                       AND (OLD.status IS DISTINCT FROM NEW.status OR OLD.available_at IS DISTINCT FROM NEW.available_at) THEN
                      PERFORM pg_notify('dl_jobs', NEW.queue);
                    END IF;
                    RETURN NEW;
                  END IF;
                  RETURN NEW;
                END $$ LANGUAGE plpgsql;
                """
            )
        )
        await conn.execute(text("CREATE TRIGGER dl_jobs_notify_ins AFTER INSERT ON dl_jobs FOR EACH ROW EXECUTE FUNCTION notify_job_ready()"))
        await conn.execute(text("CREATE TRIGGER dl_jobs_notify_upd AFTER UPDATE OF status, available_at ON dl_jobs FOR EACH ROW EXECUTE FUNCTION notify_job_ready()"))

    yield engine

    # --- Teardown ---
    await engine.dispose()
    # Drop the test database
    async with system_engine.connect() as conn:
        await conn.execute(text(f"DROP DATABASE {test_db_name} WITH (FORCE)"))
    await system_engine.dispose()


@pytest.fixture(scope="session")
async def app_context(db_engine: AsyncEngine) -> AsyncGenerator[None, None]:
    """
    Overrides the global APP_CTX with the test database engine and sessionmaker.
    This ensures that the app, when tested, uses the isolated test DB.
    """
    original_engine = APP_CTX._engine
    original_sessionmaker = APP_CTX._sessionmaker
    APP_CTX._engine = db_engine
    APP_CTX._sessionmaker = async_sessionmaker(db_engine, expire_on_commit=False, class_=AsyncSession)
    yield
    # Restore original context
    APP_CTX._engine = original_engine
    APP_CTX._sessionmaker = original_sessionmaker


@pytest.fixture
async def db_session(db_engine: AsyncEngine) -> AsyncGenerator[AsyncSession, None]:
    """
    Provides a transactional session for tests.
    A single connection is acquired from the engine, a transaction is started,
    and a new session is created bound to that connection. At the end of the test,
    the transaction is rolled back and the connection is closed, ensuring
    perfect test isolation.
    """
    connection = await db_engine.connect()
    trans = await connection.begin()

    # Create a sessionmaker bound to the single connection
    TestSession = async_sessionmaker(
        bind=connection,
        expire_on_commit=False,
        class_=AsyncSession
    )
    session = TestSession()

    try:
        yield session
    finally:
        # Clean up the session, rollback the transaction, and close the connection
        await session.close()
        await trans.rollback()
        await connection.close()


@pytest.fixture
async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
    """
    Provides an HTTP client for the FastAPI application.
    It overrides the 'get_session' dependency to use the test-scoped,
    transactional session provided by the 'db_session' fixture. This ensures
    that API calls operate within the same transaction as the test function,
    allowing for consistent state checks.
    """
    async def override_get_session() -> AsyncGenerator[AsyncSession, None]:
        """This override simply yields the session created by the db_session fixture."""
        yield db_session

    # The service depends on the session, so we override the session provider
    # that is used by the service via Depends(get_session)
    from dataloader.context import get_session
    app_main.dependency_overrides[get_session] = override_get_session

    transport = ASGITransport(app=app_main)
    async with AsyncClient(transport=transport, base_url="http://test") as c:
        yield c

    # Clean up the dependency override
    app_main.dependency_overrides.clear()

View File

@@ -0,0 +1,93 @@
# tests/integrations/test_api_v1.py
import pytest
from httpx import AsyncClient


@pytest.mark.anyio
async def test_trigger_and_get_status_ok(client: AsyncClient):
    """
    Verifies that a job is created successfully and that its status can be retrieved.
    """
    # 1. Trigger the job
    trigger_payload = {
        "queue": "test_queue",
        "task": "test_task",
        "lock_key": "lock_123",
    }
    response = await client.post("/api/v1/jobs/trigger", json=trigger_payload)

    # Check the trigger response
    assert response.status_code == 200
    response_data = response.json()
    assert "job_id" in response_data
    assert response_data["status"] == "queued"
    job_id = response_data["job_id"]

    # 2. Get the status
    response = await client.get(f"/api/v1/jobs/{job_id}/status")

    # Check the status response
    assert response.status_code == 200
    status_data = response.json()
    assert status_data["job_id"] == job_id
    assert status_data["status"] == "queued"
    assert status_data["attempt"] == 0


@pytest.mark.anyio
async def test_cancel_job_ok(client: AsyncClient):
    """
    Verifies that a job can be canceled successfully.
    """
    # 1. Trigger the job
    trigger_payload = {
        "queue": "cancel_queue",
        "task": "cancel_task",
        "lock_key": "lock_cancel",
    }
    response = await client.post("/api/v1/jobs/trigger", json=trigger_payload)
    assert response.status_code == 200
    job_id = response.json()["job_id"]

    # 2. Request cancellation
    response = await client.post(f"/api/v1/jobs/{job_id}/cancel")
    assert response.status_code == 200
    cancel_data = response.json()
    assert cancel_data["job_id"] == job_id
    # The worker has not picked up the job yet, so the status is still queued,
    # but cancel_requested is already true
    assert cancel_data["status"] == "queued"

    # 3. Check the status after cancellation
    response = await client.get(f"/api/v1/jobs/{job_id}/status")
    assert response.status_code == 200
    status_data = response.json()
    assert status_data["job_id"] == job_id
    # The job is not canceled yet; cancellation has only been requested
    assert status_data["status"] == "queued"


@pytest.mark.anyio
async def test_trigger_duplicate_idempotency_key(client: AsyncClient):
    """
    Verifies that resubmitting with the same idempotency key
    returns the same job.
    """
    idempotency_key = "idem_key_123"
    trigger_payload = {
        "queue": "idempotent_queue",
        "task": "idempotent_task",
        "lock_key": "lock_idem",
        "idempotency_key": idempotency_key,
    }

    # 1. First request
    response1 = await client.post("/api/v1/jobs/trigger", json=trigger_payload)
    assert response1.status_code == 200
    job_id1 = response1.json()["job_id"]

    # 2. Second request with the same key
    response2 = await client.post("/api/v1/jobs/trigger", json=trigger_payload)
    assert response2.status_code == 200
    job_id2 = response2.json()["job_id"]

    # The job IDs must match
    assert job_id1 == job_id2

View File

@@ -0,0 +1,155 @@
# tests/integrations/test_worker_protocol.py
import asyncio
from datetime import datetime, timedelta, timezone
from uuid import uuid4

import pytest
from sqlalchemy.ext.asyncio import AsyncSession

from dataloader.storage.repositories import QueueRepository, CreateJobRequest
from dataloader.context import APP_CTX


@pytest.mark.anyio
async def test_e2e_worker_protocol_ok(db_session: AsyncSession):
    """
    Verifies the full end-to-end job lifecycle:
    1. Enqueue (create)
    2. Claim
    3. Heartbeat
    4. Successful completion (finish_ok)
    5. Status check
    """
    repo = QueueRepository(db_session)
    job_id = str(uuid4())
    queue_name = "e2e_ok_queue"
    lock_key = f"lock_{job_id}"

    # 1. Enqueue the job
    create_req = CreateJobRequest(
        job_id=job_id,
        queue=queue_name,
        task="test_e2e_task",
        args={},
        idempotency_key=None,
        lock_key=lock_key,
        partition_key="",
        priority=100,
        available_at=datetime.now(timezone.utc),
        max_attempts=3,
        lease_ttl_sec=30,
        producer=None,
        consumer_group=None,
    )
    await repo.create_or_get(create_req)

    # 2. Claim the job
    claimed_job = await repo.claim_one(queue_name, claim_backoff_sec=10)
    assert claimed_job is not None
    assert claimed_job["job_id"] == job_id
    assert claimed_job["lock_key"] == lock_key

    # 3. Heartbeat
    success, cancel_requested = await repo.heartbeat(job_id, ttl_sec=60)
    assert success
    assert not cancel_requested

    # 4. Successful completion
    await repo.finish_ok(job_id)

    # 5. Status check
    status = await repo.get_status(job_id)
    assert status is not None
    assert status.status == "succeeded"
    assert status.finished_at is not None


@pytest.mark.anyio
async def test_concurrency_claim_one_locks(db_session: AsyncSession):
    """
    Verifies that when jobs share the same lock_key, only one worker
    can claim a job at a time.
    """
    repo = QueueRepository(db_session)
    queue_name = "concurrency_queue"
    lock_key = "concurrent_lock_123"
    job_ids = [str(uuid4()), str(uuid4())]

    # 1. Create two jobs with the same lock_key
    for i, job_id in enumerate(job_ids):
        create_req = CreateJobRequest(
            job_id=job_id,
            queue=queue_name,
            task=f"task_{i}",
            args={},
            idempotency_key=f"idem_con_{i}",
            lock_key=lock_key,
            partition_key="",
            priority=100 + i,
            available_at=datetime.now(timezone.utc),
            max_attempts=1,
            lease_ttl_sec=30,
            producer="test",
            consumer_group="test_group",
        )
        await repo.create_or_get(create_req)

    # 2. The first worker claims a job
    claimed_job_1 = await repo.claim_one(queue_name, claim_backoff_sec=1)
    assert claimed_job_1 is not None
    assert claimed_job_1["job_id"] == job_ids[0]

    # 3. The second worker tries to claim a job but cannot (because of the advisory lock)
    claimed_job_2 = await repo.claim_one(queue_name, claim_backoff_sec=1)
    assert claimed_job_2 is None

    # 4. The first worker releases the advisory lock (as if it had finished its work)
    await repo._advisory_unlock(lock_key)

    # 5. The second worker can now claim the second job
    claimed_job_3 = await repo.claim_one(queue_name, claim_backoff_sec=1)
    assert claimed_job_3 is not None
    assert claimed_job_3["job_id"] == job_ids[1]


@pytest.mark.anyio
async def test_reaper_requeues_lost_jobs(db_session: AsyncSession):
    """
    Verifies that the reaper correctly returns "lost" jobs to the queue.
    """
    repo = QueueRepository(db_session)
    job_id = str(uuid4())
    queue_name = "reaper_queue"

    # 1. Create and claim a job
    create_req = CreateJobRequest(
        job_id=job_id,
        queue=queue_name,
        task="reaper_test_task",
        args={},
        idempotency_key="idem_reaper_1",
        lock_key="reaper_lock_1",
        partition_key="",
        priority=100,
        available_at=datetime.now(timezone.utc),
        max_attempts=3,
        lease_ttl_sec=1,  # A very short lease
        producer=None,
        consumer_group=None,
    )
    await repo.create_or_get(create_req)

    claimed_job = await repo.claim_one(queue_name, claim_backoff_sec=1)
    assert claimed_job is not None
    assert claimed_job["job_id"] == job_id

    # 2. Wait for the lease to expire
    await asyncio.sleep(2)

    # 3. Run the reaper
    requeued_ids = await repo.requeue_lost()
    assert requeued_ids == [job_id]

    # 4. Check the status
    status = await repo.get_status(job_id)
    assert status is not None
    assert status.status == "queued"