roocode: try fixes
This commit is contained in:
parent 7a76dc1d84
commit 7152f4b61e
@@ -14,18 +14,18 @@ from dataloader.api.v1.schemas import (
    TriggerJobResponse,
)
from dataloader.api.v1.service import JobsService
from dataloader.storage.db import session_scope
from dataloader.context import get_session
from sqlalchemy.ext.asyncio import AsyncSession


router = APIRouter(prefix="/jobs", tags=["jobs"])


async def get_service() -> AsyncGenerator[JobsService, None]:
def get_service(session: Annotated[AsyncSession, Depends(get_session)]) -> JobsService:
    """
    Creates a JobsService with a new session and closes it correctly after the request.
    FastAPI dependency to create a JobsService instance with a database session.
    """
    async for s in session_scope():
        yield JobsService(s)
    return JobsService(session)


@router.post("/trigger", response_model=TriggerJobResponse, status_code=HTTPStatus.OK)
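For orientation, a minimal sketch of how the reworked dependency is expected to be consumed by a route handler; the handler below is illustrative only (its name and body are assumptions, not part of this commit):

# Hypothetical handler using the new get_service dependency (illustrative).
from typing import Annotated
from fastapi import Depends

@router.post("/trigger", response_model=TriggerJobResponse, status_code=HTTPStatus.OK)
async def trigger_job(
    body: TriggerJobRequest,
    service: Annotated[JobsService, Depends(get_service)],
) -> TriggerJobResponse:
    # FastAPI resolves get_session -> AsyncSession -> JobsService per request.
    return await service.trigger(body)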
@@ -17,7 +17,7 @@ from dataloader.storage.repositories import (
    CreateJobRequest,
    QueueRepository,
)
from dataloader.context import APP_CTX
from dataloader.logger.logger import get_logger


class JobsService:
@@ -27,7 +27,7 @@ class JobsService:
    def __init__(self, session: AsyncSession):
        self._s = session
        self._repo = QueueRepository(self._s)
        self._log = APP_CTX.get_logger()
        self._log = get_logger(__name__)

    async def trigger(self, req: TriggerJobRequest) -> TriggerJobResponse:
        """
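As a hedged sketch of what the constructor change implies, JobsService can now be built with any externally managed AsyncSession, for example outside FastAPI; the sessionmaker wiring assumed here comes from the context.py changes further down, not from this hunk:

# Illustrative only: constructing JobsService with an injected session.
from dataloader.context import APP_CTX
from dataloader.api.v1.service import JobsService

async def run_trigger_once(req):
    async with APP_CTX.sessionmaker() as session:
        service = JobsService(session)      # session is injected; no session_scope() needed
        return await service.trigger(req)   # logs through get_logger(__name__)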
@@ -1,124 +1,59 @@
# src/dataloader/context.py
from __future__ import annotations
from typing import AsyncGenerator
from logging import Logger

import typing
import pytz
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker

from dataloader.base import Singleton
from dataloader.config import APP_CONFIG, Secrets
from dataloader.logger import ContextVarsContainer, LoggerConfigurator
from sqlalchemy import event, select, func, text
from .config import APP_CONFIG
from .logger.context_vars import ContextVarsContainer


class AppContext(metaclass=Singleton):
    """
    Application context: logger, timezone, database connection and session factory.
    """
    def __init__(self, secrets: Secrets) -> None:
        self.timezone = pytz.timezone(secrets.app.timezone)
        self.context_vars_container = ContextVarsContainer()
        self._logger_manager = LoggerConfigurator(
            log_lvl=secrets.log.log_lvl,
            log_file_path=secrets.log.log_file_abs_path,
            metric_file_path=secrets.log.metric_file_abs_path,
            audit_file_path=secrets.log.audit_file_abs_path,
            audit_host_ip=secrets.log.audit_host_ip,
            audit_host_uid=secrets.log.audit_host_uid,
            context_vars_container=self.context_vars_container,
            timezone=self.timezone,
        )
        self.pg = secrets.pg
        self.dl = secrets.dl
class AppContext:
    def __init__(self) -> None:
        self._engine: AsyncEngine | None = None
        self._sessionmaker: async_sessionmaker[AsyncSession] | None = None
        self.logger.info("App context initialized.")

    @property
    def logger(self) -> "typing.Any":
        """
        Returns the asynchronous logger.
        """
        return self._logger_manager.async_logger
        self._context_vars_container = ContextVarsContainer()

    @property
    def engine(self) -> AsyncEngine:
        """
        Returns the current AsyncEngine.
        """
        assert self._engine is not None, "Engine is not initialized"
        if self._engine is None:
            raise RuntimeError("Database engine is not initialized.")
        return self._engine

    @property
    def sessionmaker(self) -> async_sessionmaker[AsyncSession]:
        """
        Returns the asynchronous session factory.
        """
        assert self._sessionmaker is not None, "Sessionmaker is not initialized"
        if self._sessionmaker is None:
            raise RuntimeError("Sessionmaker is not initialized.")
        return self._sessionmaker

    def get_logger(self) -> "typing.Any":
        """
        Returns the logger.
        """
        return self.logger
    def startup(self) -> None:
        from .storage.db import create_engine, create_sessionmaker
        from .logger.logger import setup_logging

        setup_logging()
        self._engine = create_engine(APP_CONFIG.pg.url)
        self._sessionmaker = create_sessionmaker(self._engine)

    async def shutdown(self) -> None:
        if self._engine:
            await self._engine.dispose()

    def get_logger(self, name: str | None = None) -> Logger:
        """Returns a configured logger instance."""
        from .logger.logger import get_logger as get_app_logger
        return get_app_logger(name)

    def get_context_vars_container(self) -> ContextVarsContainer:
        return self._context_vars_container


APP_CTX = AppContext()


async def get_session() -> AsyncGenerator[AsyncSession, None]:
    """
    Returns the logger's context variables container.
    FastAPI dependency to get a database session.
    Yields a session from the global sessionmaker and ensures it's closed.
    """
        return self.context_vars_container

    def get_pytz_timezone(self):
        """
        Returns the application timezone.
        """
        return self.timezone

    async def on_startup(self) -> None:
        """
        Initializes the database connection and prepares the session factory.
        """
        self.logger.info("Application is starting up.")
        dsn = APP_CONFIG.resolved_dsn
        self._engine = create_async_engine(
            dsn,
            pool_size=self.pg.pool_size if self.pg.use_pool else None,
            max_overflow=self.pg.max_overflow if self.pg.use_pool else 0,
            pool_recycle=self.pg.pool_recycle if self.pg.use_pool else -1,
        )

        schema = self.pg.schema_.strip()

        @event.listens_for(self._engine.sync_engine, "connect")
        def _set_search_path(dbapi_conn, _):
            cur = dbapi_conn.cursor()
            try:
                if schema:
                    cur.execute(f'SET SESSION search_path TO "{schema}"')
            finally:
                cur.close()

        self._sessionmaker = async_sessionmaker(self._engine, expire_on_commit=False, class_=AsyncSession)

        async with self._sessionmaker() as s:
            await s.execute(select(func.count()).select_from(text("dl_jobs")))
            await s.execute(select(func.count()).select_from(text("dl_job_events")))

        self.logger.info("All connections checked. Application is up and ready.")

    async def on_shutdown(self) -> None:
        """
        Stops the subsystems and releases resources.
        """
        self.logger.info("Application is shutting down.")
        if self._engine is not None:
            await self._engine.dispose()
        self._engine = None
        self._sessionmaker = None
        self._logger_manager.remove_logger_handlers()


APP_CTX = AppContext(APP_CONFIG)

__all__ = ["APP_CTX"]
    async with APP_CTX.sessionmaker() as session:
        yield session
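A minimal sketch of how the new AppContext is presumably driven from the application entrypoint; the lifespan wiring below is an assumption inferred from startup()/shutdown() and get_session() above, not code from this commit:

# Hypothetical FastAPI lifespan wiring for the reworked AppContext (illustrative).
from contextlib import asynccontextmanager
from fastapi import FastAPI
from dataloader.context import APP_CTX, get_session

@asynccontextmanager
async def lifespan(app: FastAPI):
    APP_CTX.startup()              # sets up logging, the engine and the sessionmaker
    try:
        yield
    finally:
        await APP_CTX.shutdown()   # disposes the engine

app = FastAPI(lifespan=lifespan)
# Request handlers then receive per-request sessions via Depends(get_session).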
@@ -1,5 +1,4 @@
# logger package
from .context_vars import ContextVarsContainer
from .logger import LoggerConfigurator
# src/dataloader/logger/__init__.py
from .logger import setup_logging, get_logger

__all__ = ["ContextVarsContainer", "LoggerConfigurator"]
__all__ = ["setup_logging", "get_logger"]
@@ -1,14 +1,14 @@
# Main application logger
# src/dataloader/logger/logger.py
import sys
import typing
from datetime import tzinfo

from logging import Logger
import logging

from loguru import logger

from .context_vars import ContextVarsContainer

from dataloader.config import APP_CONFIG


# Define filters for the different log types
@@ -16,131 +16,44 @@ def metric_only_filter(record: dict) -> bool:
    return "metric" in record["extra"]


def audit_only_filter(record: dict) -> bool:
    return "audit" in record["extra"]


def regular_log_filter(record: dict) -> bool:
    return "metric" not in record["extra"] and "audit" not in record["extra"]


class InterceptHandler(logging.Handler):
    def emit(self, record: logging.LogRecord) -> None:
        # Get corresponding Loguru level if it exists
        try:
            level = logger.level(record.levelname).name
        except ValueError:
            level = record.levelno

class LoggerConfigurator:
    def __init__(
        self,
        log_lvl: str,
        log_file_path: str,
        metric_file_path: str,
        audit_file_path: str,
        audit_host_ip: str,
        audit_host_uid: str,
        context_vars_container: ContextVarsContainer,
        timezone: tzinfo,
    ) -> None:
        self.context_vars_container = context_vars_container
        self.timezone = timezone
        self.log_lvl = log_lvl
        self.log_file_path = log_file_path
        self.metric_file_path = metric_file_path
        self.audit_file_path = audit_file_path
        self.audit_host_ip = audit_host_ip
        self.audit_host_uid = audit_host_uid
        self._handler_ids = []
        self.configure_logger()
        # Find caller from where originated the logged message
        frame, depth = logging.currentframe(), 2
        while frame.f_code.co_filename == logging.__file__:
            frame = frame.f_back
            depth += 1

        logger.opt(depth=depth, exception=record.exc_info).log(
            level, record.getMessage()
        )


    @property
    def async_logger(self) -> "typing.Any":
        return self._async_logger


    def patch_record_with_context(self, record: dict) -> None:
        context_data = self.context_vars_container.as_dict()
        record["extra"].update(context_data)
        if not record["extra"].get("request_id"):
            record["extra"]["request_id"] = "system_event"


    def configure_logger(self) -> None:
def setup_logging():
    """Configures the `loguru` logger with the required handlers."""
    logger.remove()
        logger.patch(self.patch_record_with_context)

    # Helper for safe console log formatting
    def console_format(record):
        request_id = record["extra"].get("request_id", "system_event")
        elapsed = record["elapsed"]
        level = record["level"].name
        name = record["name"]
        function = record["function"]
        line = record["line"]
        message = record["message"]
        time_str = record["time"].strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]

        return (
            f"<green>{time_str} ({elapsed})</green> | "
            f"<cyan>{request_id}</cyan> | "
            f"<level>{level: <8}</level> | "
            f"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
            f"<level>{message}</level>\n"
        )

    # Handler for regular logs (console)
    handler_id = logger.add(
    logger.add(
        sys.stdout,
        level=self.log_lvl,
        level=APP_CONFIG.log.log_level.upper(),
        filter=regular_log_filter,
        format=console_format,
        colorize=True,
    )
    self._handler_ids.append(handler_id)
    logging.basicConfig(handlers=[InterceptHandler()], level=0)

    # Handler for regular logs (file)
    handler_id = logger.add(
        self.log_file_path,
        level=self.log_lvl,
        filter=regular_log_filter,
        rotation="10 MB",
        compression="zip",
        enqueue=True,
        serialize=True,
    )
    self._handler_ids.append(handler_id)

    # Handler for metrics
    handler_id = logger.add(
        self.metric_file_path,
        level="INFO",
        filter=metric_only_filter,
        rotation="10 MB",
        compression="zip",
        enqueue=True,
        serialize=True,
    )
    self._handler_ids.append(handler_id)

    # Handler for audit logs
    handler_id = logger.add(
        self.audit_file_path,
        level="INFO",
        filter=audit_only_filter,
        rotation="10 MB",
        compression="zip",
        enqueue=True,
        serialize=True,
    )
    self._handler_ids.append(handler_id)

    self._async_logger = logger


    def remove_logger_handlers(self) -> None:
        """Removes all logger handlers."""
        for handler_id in self._handler_ids:
            self._async_logger.remove(handler_id)
def get_logger(name: str | None = None) -> Logger:
    return logging.getLogger(name or "dataloader")
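As a usage sketch, the intent of setup_logging() plus InterceptHandler is that standard-library loggers are funneled into loguru; roughly along these lines (the log messages are made up):

# Illustrative: stdlib logging routed through loguru after setup_logging().
import logging
from dataloader.logger import setup_logging, get_logger

setup_logging()                    # installs the loguru console sink and InterceptHandler
log = get_logger(__name__)         # a plain logging.Logger
log.info("job claimed")            # emitted via loguru's console handler
logging.getLogger("uvicorn").warning("also intercepted by InterceptHandler")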
@@ -117,6 +117,7 @@ class QueueRepository:
        """
        Idempotently creates a queue record and returns (job_id, status).
        """
        async with self.s.begin():
            if req.idempotency_key:
                q = select(DLJob).where(DLJob.idempotency_key == req.idempotency_key)
                r = await self.s.execute(q)
@@ -151,7 +152,6 @@ class QueueRepository:
            )
            self.s.add(row)
            await self._append_event(req.job_id, req.queue, "queued", {"task": req.task})
        await self.s.commit()
        return req.job_id, "queued"

    async def get_status(self, job_id: str) -> Optional[JobStatus]:
@@ -187,12 +187,12 @@ class QueueRepository:
        """
        Sets the cancellation flag for a job.
        """
        async with self.s.begin():
            job = await self._get(job_id)
            if not job:
                return False
            job.cancel_requested = True
            await self._append_event(job_id, job.queue, "cancel_requested", None)
        await self.s.commit()
        return True

    async def claim_one(self, queue: str, claim_backoff_sec: int) -> Optional[dict[str, Any]]:
@@ -246,6 +246,7 @@ class QueueRepository:
        Updates the heartbeat and extends the lease.
        Returns (success, cancel_requested).
        """
        async with self.s.begin():
            job = await self._get(job_id)
            if not job or job.status != "running":
                return False, False
@@ -259,13 +260,13 @@ class QueueRepository:
            )
            await self.s.execute(q)
            await self._append_event(job_id, await self._resolve_queue(job_id), "heartbeat", {"ttl": ttl_sec})
        await self.s.commit()
        return True, cancel_requested

    async def finish_ok(self, job_id: str) -> None:
        """
        Marks the job as successfully completed and releases the advisory lock.
        """
        async with self.s.begin():
            job = await self._get(job_id)
            if not job:
                return
@@ -274,12 +275,12 @@ class QueueRepository:
            job.lease_expires_at = None
            await self._append_event(job_id, job.queue, "succeeded", None)
            await self._advisory_unlock(job.lock_key)
        await self.s.commit()

    async def finish_fail_or_retry(self, job_id: str, err: str, is_canceled: bool = False) -> None:
        """
        Marks the job as failed or canceled, or returns it to the queue with a delay.
        """
        async with self.s.begin():
            job = await self._get(job_id)
            if not job:
                return
@@ -305,7 +306,6 @@ class QueueRepository:
            job.lease_expires_at = None
            await self._append_event(job_id, job.queue, "failed", {"error": err})
            await self._advisory_unlock(job.lock_key)
        await self.s.commit()

    async def requeue_lost(self, now: Optional[datetime] = None) -> list[str]:
        """
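For context, a hedged sketch of the worker-side protocol these repository methods support (claim, heartbeat, finish, plus the reaper); the loop itself is not part of this diff, and the interval values are placeholders:

# Illustrative worker step built on QueueRepository (not part of this commit).
from dataloader.context import APP_CTX
from dataloader.storage.repositories import QueueRepository

async def work_one(queue: str) -> None:
    async with APP_CTX.sessionmaker() as session:
        repo = QueueRepository(session)
        job = await repo.claim_one(queue, claim_backoff_sec=10)
        if job is None:
            return
        try:
            _, cancel_requested = await repo.heartbeat(job["job_id"], ttl_sec=30)
            if cancel_requested:
                await repo.finish_fail_or_retry(job["job_id"], "canceled", is_canceled=True)
                return
            # ... execute the task payload here ...
            await repo.finish_ok(job["job_id"])
        except Exception as exc:
            await repo.finish_fail_or_retry(job["job_id"], str(exc))

# A separate periodic reaper task would call repo.requeue_lost() to recover jobs with expired leases.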
@@ -0,0 +1,179 @@
# tests/conftest.py
import asyncio
import sys
from typing import AsyncGenerator

import pytest
from dotenv import load_dotenv
from httpx import AsyncClient, ASGITransport
from sqlalchemy import text
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.exc import ProgrammingError
from asyncpg.exceptions import InvalidCatalogNameError

# Load .env before other imports to ensure config is available
load_dotenv()

from dataloader.api import app_main
from dataloader.api.v1.router import get_service
from dataloader.api.v1.service import JobsService
from dataloader.context import APP_CTX
from dataloader.storage.db import Base
from dataloader.config import APP_CONFIG

# For Windows: use SelectorEventLoop which is more stable
if sys.platform == "win32":
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())


@pytest.fixture(scope="session")
async def db_engine() -> AsyncGenerator[AsyncEngine, None]:
    """
    Creates a temporary, isolated test database for the entire test session.
    - Connects to the default 'postgres' database to create a new test DB.
    - Yields an engine connected to the new test DB.
    - Drops the test DB after the session is complete.
    """
    pg_settings = APP_CONFIG.pg
    test_db_name = f"{pg_settings.database}_test"

    # DSN for connecting to 'postgres' DB to manage the test DB
    system_dsn = pg_settings.url.replace(f"/{pg_settings.database}", "/postgres")
    system_engine = create_async_engine(system_dsn, isolation_level="AUTOCOMMIT")

    # DSN for the new test database
    test_dsn = pg_settings.url.replace(f"/{pg_settings.database}", f"/{test_db_name}")

    async with system_engine.connect() as conn:
        # Drop the test DB if it exists from a previous failed run
        await conn.execute(text(f"DROP DATABASE IF EXISTS {test_db_name} WITH (FORCE)"))
        await conn.execute(text(f"CREATE DATABASE {test_db_name}"))

    # Create an engine connected to the new test database
    engine = create_async_engine(test_dsn)

    # Create all tables and DDL objects
    async with engine.begin() as conn:
        # Execute DDL statements one by one
        await conn.execute(text("CREATE TYPE dl_status AS ENUM ('queued','running','succeeded','failed','canceled','lost')"))

        # Create tables
        await conn.run_sync(Base.metadata.create_all)

        # Create indexes
        await conn.execute(text("CREATE INDEX ix_dl_jobs_claim ON dl_jobs(queue, available_at, priority, created_at) WHERE status = 'queued'"))
        await conn.execute(text("CREATE INDEX ix_dl_jobs_running_lease ON dl_jobs(lease_expires_at) WHERE status = 'running'"))
        await conn.execute(text("CREATE INDEX ix_dl_jobs_status_queue ON dl_jobs(status, queue)"))

        # Create function and triggers
        await conn.execute(
            text(
                """
                CREATE OR REPLACE FUNCTION notify_job_ready() RETURNS trigger AS $$
                BEGIN
                  IF (TG_OP = 'INSERT') THEN
                    PERFORM pg_notify('dl_jobs', NEW.queue);
                    RETURN NEW;
                  ELSIF (TG_OP = 'UPDATE') THEN
                    IF NEW.status = 'queued' AND NEW.available_at <= now()
                       AND (OLD.status IS DISTINCT FROM NEW.status OR OLD.available_at IS DISTINCT FROM NEW.available_at) THEN
                      PERFORM pg_notify('dl_jobs', NEW.queue);
                    END IF;
                    RETURN NEW;
                  END IF;
                  RETURN NEW;
                END $$ LANGUAGE plpgsql;
                """
            )
        )
        await conn.execute(text("CREATE TRIGGER dl_jobs_notify_ins AFTER INSERT ON dl_jobs FOR EACH ROW EXECUTE FUNCTION notify_job_ready()"))
        await conn.execute(text("CREATE TRIGGER dl_jobs_notify_upd AFTER UPDATE OF status, available_at ON dl_jobs FOR EACH ROW EXECUTE FUNCTION notify_job_ready()"))

    yield engine

    # --- Teardown ---
    await engine.dispose()

    # Drop the test database
    async with system_engine.connect() as conn:
        await conn.execute(text(f"DROP DATABASE {test_db_name} WITH (FORCE)"))
    await system_engine.dispose()


@pytest.fixture(scope="session")
async def app_context(db_engine: AsyncEngine) -> AsyncGenerator[None, None]:
    """
    Overrides the global APP_CTX with the test database engine and sessionmaker.
    This ensures that the app, when tested, uses the isolated test DB.
    """
    original_engine = APP_CTX._engine
    original_sessionmaker = APP_CTX._sessionmaker

    APP_CTX._engine = db_engine
    APP_CTX._sessionmaker = async_sessionmaker(db_engine, expire_on_commit=False, class_=AsyncSession)

    yield

    # Restore original context
    APP_CTX._engine = original_engine
    APP_CTX._sessionmaker = original_sessionmaker


@pytest.fixture
async def db_session(db_engine: AsyncEngine) -> AsyncGenerator[AsyncSession, None]:
    """
    Provides a transactional session for tests.
    A single connection is acquired from the engine, a transaction is started,
    and a new session is created bound to that connection. At the end of the test,
    the transaction is rolled back and the connection is closed, ensuring
    perfect test isolation.
    """
    connection = await db_engine.connect()
    trans = await connection.begin()

    # Create a sessionmaker bound to the single connection
    TestSession = async_sessionmaker(
        bind=connection,
        expire_on_commit=False,
        class_=AsyncSession
    )
    session = TestSession()

    try:
        yield session
    finally:
        # Clean up the session, rollback the transaction, and close the connection
        await session.close()
        await trans.rollback()
        await connection.close()


@pytest.fixture
async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
    """
    Provides an HTTP client for the FastAPI application.
    It overrides the 'get_session' dependency to use the test-scoped,
    transactional session provided by the 'db_session' fixture. This ensures
    that API calls operate within the same transaction as the test function,
    allowing for consistent state checks.
    """
    async def override_get_session() -> AsyncGenerator[AsyncSession, None]:
        """This override simply yields the session created by the db_session fixture."""
        yield db_session

    # The service depends on the session, so we override the session provider
    # that is used by the service via Depends(get_session)
    from dataloader.context import get_session
    app_main.dependency_overrides[get_session] = override_get_session

    transport = ASGITransport(app=app_main)
    async with AsyncClient(transport=transport, base_url="http://test") as c:
        yield c

    # Clean up the dependency override
    app_main.dependency_overrides.clear()
@@ -0,0 +1,93 @@
# tests/integrations/test_api_v1.py
import pytest
from httpx import AsyncClient


@pytest.mark.anyio
async def test_trigger_and_get_status_ok(client: AsyncClient):
    """
    Verifies that a job can be created successfully and its status retrieved.
    """
    # 1. Trigger the job
    trigger_payload = {
        "queue": "test_queue",
        "task": "test_task",
        "lock_key": "lock_123",
    }
    response = await client.post("/api/v1/jobs/trigger", json=trigger_payload)

    # Checks on the trigger response
    assert response.status_code == 200
    response_data = response.json()
    assert "job_id" in response_data
    assert response_data["status"] == "queued"
    job_id = response_data["job_id"]

    # 2. Get the status
    response = await client.get(f"/api/v1/jobs/{job_id}/status")

    # Checks on the status response
    assert response.status_code == 200
    status_data = response.json()
    assert status_data["job_id"] == job_id
    assert status_data["status"] == "queued"
    assert status_data["attempt"] == 0


@pytest.mark.anyio
async def test_cancel_job_ok(client: AsyncClient):
    """
    Verifies that a job can be canceled successfully.
    """
    # 1. Trigger the job
    trigger_payload = {
        "queue": "cancel_queue",
        "task": "cancel_task",
        "lock_key": "lock_cancel",
    }
    response = await client.post("/api/v1/jobs/trigger", json=trigger_payload)
    assert response.status_code == 200
    job_id = response.json()["job_id"]

    # 2. Request cancellation
    response = await client.post(f"/api/v1/jobs/{job_id}/cancel")
    assert response.status_code == 200
    cancel_data = response.json()
    assert cancel_data["job_id"] == job_id
    # The worker has not picked up the job yet, so the status is still queued, but cancel_requested is already true
    assert cancel_data["status"] == "queued"

    # 3. Check the status after cancellation
    response = await client.get(f"/api/v1/jobs/{job_id}/status")
    assert response.status_code == 200
    status_data = response.json()
    assert status_data["job_id"] == job_id
    # The job is not canceled yet, only requested to be canceled
    assert status_data["status"] == "queued"


@pytest.mark.anyio
async def test_trigger_duplicate_idempotency_key(client: AsyncClient):
    """
    Verifies that re-submitting with the same idempotency key
    returns the same job.
    """
    idempotency_key = "idem_key_123"
    trigger_payload = {
        "queue": "idempotent_queue",
        "task": "idempotent_task",
        "lock_key": "lock_idem",
        "idempotency_key": idempotency_key,
    }

    # 1. First request
    response1 = await client.post("/api/v1/jobs/trigger", json=trigger_payload)
    assert response1.status_code == 200
    job_id1 = response1.json()["job_id"]

    # 2. Second request with the same key
    response2 = await client.post("/api/v1/jobs/trigger", json=trigger_payload)
    assert response2.status_code == 200
    job_id2 = response2.json()["job_id"]

    # The job IDs must match
    assert job_id1 == job_id2
@@ -0,0 +1,155 @@
# tests/integrations/test_worker_protocol.py
import asyncio
from datetime import datetime, timedelta, timezone
from uuid import uuid4

import pytest
from sqlalchemy.ext.asyncio import AsyncSession

from dataloader.storage.repositories import QueueRepository, CreateJobRequest
from dataloader.context import APP_CTX


@pytest.mark.anyio
async def test_e2e_worker_protocol_ok(db_session: AsyncSession):
    """
    Verifies the full end-to-end job lifecycle:
    1. Enqueue (create)
    2. Claim
    3. Heartbeat
    4. Successful completion (finish_ok)
    5. Status check
    """
    repo = QueueRepository(db_session)
    job_id = str(uuid4())
    queue_name = "e2e_ok_queue"
    lock_key = f"lock_{job_id}"

    # 1. Enqueue the job
    create_req = CreateJobRequest(
        job_id=job_id,
        queue=queue_name,
        task="test_e2e_task",
        args={},
        idempotency_key=None,
        lock_key=lock_key,
        partition_key="",
        priority=100,
        available_at=datetime.now(timezone.utc),
        max_attempts=3,
        lease_ttl_sec=30,
        producer=None,
        consumer_group=None,
    )
    await repo.create_or_get(create_req)

    # 2. Claim the job
    claimed_job = await repo.claim_one(queue_name, claim_backoff_sec=10)
    assert claimed_job is not None
    assert claimed_job["job_id"] == job_id
    assert claimed_job["lock_key"] == lock_key

    # 3. Heartbeat
    success, cancel_requested = await repo.heartbeat(job_id, ttl_sec=60)
    assert success
    assert not cancel_requested

    # 4. Successful completion
    await repo.finish_ok(job_id)

    # 5. Check the status
    status = await repo.get_status(job_id)
    assert status is not None
    assert status.status == "succeeded"
    assert status.finished_at is not None


@pytest.mark.anyio
async def test_concurrency_claim_one_locks(db_session: AsyncSession):
    """
    Verifies that when jobs share the same lock_key, only one worker
    can claim a job at a time.
    """
    repo = QueueRepository(db_session)
    queue_name = "concurrency_queue"
    lock_key = "concurrent_lock_123"
    job_ids = [str(uuid4()), str(uuid4())]

    # 1. Create two jobs with the same lock_key
    for i, job_id in enumerate(job_ids):
        create_req = CreateJobRequest(
            job_id=job_id,
            queue=queue_name,
            task=f"task_{i}",
            args={},
            idempotency_key=f"idem_con_{i}",
            lock_key=lock_key,
            partition_key="",
            priority=100 + i,
            available_at=datetime.now(timezone.utc),
            max_attempts=1,
            lease_ttl_sec=30,
            producer="test",
            consumer_group="test_group",
        )
        await repo.create_or_get(create_req)

    # 2. The first worker claims a job
    claimed_job_1 = await repo.claim_one(queue_name, claim_backoff_sec=1)
    assert claimed_job_1 is not None
    assert claimed_job_1["job_id"] == job_ids[0]

    # 3. The second worker tries to claim a job but cannot (because of the advisory lock)
    claimed_job_2 = await repo.claim_one(queue_name, claim_backoff_sec=1)
    assert claimed_job_2 is None

    # 4. The first worker releases the advisory lock (as if it had finished its work)
    await repo._advisory_unlock(lock_key)

    # 5. The second worker can now claim the second job
    claimed_job_3 = await repo.claim_one(queue_name, claim_backoff_sec=1)
    assert claimed_job_3 is not None
    assert claimed_job_3["job_id"] == job_ids[1]


@pytest.mark.anyio
async def test_reaper_requeues_lost_jobs(db_session: AsyncSession):
    """
    Verifies that the reaper correctly returns "lost" jobs to the queue.
    """
    repo = QueueRepository(db_session)
    job_id = str(uuid4())
    queue_name = "reaper_queue"

    # 1. Create and claim a job
    create_req = CreateJobRequest(
        job_id=job_id,
        queue=queue_name,
        task="reaper_test_task",
        args={},
        idempotency_key="idem_reaper_1",
        lock_key="reaper_lock_1",
        partition_key="",
        priority=100,
        available_at=datetime.now(timezone.utc),
        max_attempts=3,
        lease_ttl_sec=1,  # Very short lease
        producer=None,
        consumer_group=None,
    )
    await repo.create_or_get(create_req)

    claimed_job = await repo.claim_one(queue_name, claim_backoff_sec=1)
    assert claimed_job is not None
    assert claimed_job["job_id"] == job_id

    # 2. Wait for the lease to expire
    await asyncio.sleep(2)

    # 3. Run the reaper
    requeued_ids = await repo.requeue_lost()
    assert requeued_ids == [job_id]

    # 4. Check the status
    status = await repo.get_status(job_id)
    assert status is not None
    assert status.status == "queued"