diff --git a/pyproject.toml b/pyproject.toml
index 74ed051..6c67d3f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,3 +34,37 @@ dataloader = "dataloader.__main__:main"
 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
+
+
+[tool.black]
+line-length = 88
+target-version = ["py311"]
+skip-string-normalization = false
+preview = true
+
+[tool.isort]
+profile = "black"
+line_length = 88
+multi_line_output = 3
+include_trailing_comma = true
+force_grid_wrap = 0
+combine_as_imports = true
+known_first_party = ["dataloader"]
+src_paths = ["src"]
+
+[tool.ruff]
+line-length = 88
+target-version = "py311"
+fix = true
+lint.select = ["E", "F", "W", "I", "N", "B"]
+lint.ignore = [
+    "E501",
+]
+exclude = [
+    ".git",
+    "__pycache__",
+    "build",
+    "dist",
+    ".venv",
+    "venv",
+]
\ No newline at end of file
diff --git a/src/dataloader/__init__.py b/src/dataloader/__init__.py
index b64b871..a1be0f2 100644
--- a/src/dataloader/__init__.py
+++ b/src/dataloader/__init__.py
@@ -5,7 +5,6 @@
 
 from . import api
 
-
 __all__ = [
     "api",
 ]
diff --git a/src/dataloader/__main__.py b/src/dataloader/__main__.py
index fae0177..b84b9e9 100644
--- a/src/dataloader/__main__.py
+++ b/src/dataloader/__main__.py
@@ -1,15 +1,17 @@
 import uvicorn
 
-
 from dataloader.api import app_main
 from dataloader.config import APP_CONFIG
-from dataloader.logger.uvicorn_logging_config import LOGGING_CONFIG, setup_uvicorn_logging
+from dataloader.logger.uvicorn_logging_config import (
+    LOGGING_CONFIG,
+    setup_uvicorn_logging,
+)
 
 
 def main() -> None:
-    # Инициализируем логирование uvicorn перед запуском
+
     setup_uvicorn_logging()
-    
+
     uvicorn.run(
         app_main,
         host=APP_CONFIG.app.app_host,
diff --git a/src/dataloader/api/__init__.py b/src/dataloader/api/__init__.py
index dd00155..8a4dba2 100644
--- a/src/dataloader/api/__init__.py
+++ b/src/dataloader/api/__init__.py
@@ -1,20 +1,19 @@
-# src/dataloader/api/__init__.py
 from __future__ import annotations
 
-from collections.abc import AsyncGenerator
 import contextlib
 import typing as tp
+from collections.abc import AsyncGenerator
 
 from fastapi import FastAPI
 
+from dataloader.context import APP_CTX
+from dataloader.workers.manager import WorkerManager, build_manager_from_env
+from dataloader.workers.pipelines import load_all as load_pipelines
+
 from .metric_router import router as metric_router
 from .middleware import log_requests
 from .os_router import router as service_router
 from .v1 import router as v1_router
-from dataloader.context import APP_CTX
-from dataloader.workers.manager import build_manager_from_env, WorkerManager
-from dataloader.workers.pipelines import load_all as load_pipelines
-
 
 _manager: WorkerManager | None = None
diff --git a/src/dataloader/api/metric_router.py b/src/dataloader/api/metric_router.py
index c9fa14a..4f28bcc 100644
--- a/src/dataloader/api/metric_router.py
+++ b/src/dataloader/api/metric_router.py
@@ -1,10 +1,11 @@
-""" 🚨 НЕ РЕДАКТИРОВАТЬ !!!!!!
-"""
+"""🚨 НЕ РЕДАКТИРОВАТЬ !!!!!!"""
 
 import uuid
 
 from fastapi import APIRouter, Header, status
+
 from dataloader.context import APP_CTX
+
 from . 
import schemas router = APIRouter() @@ -17,8 +18,7 @@ logger = APP_CTX.get_logger() response_model=schemas.RateResponse, ) async def like( - # pylint: disable=C0103,W0613 - header_Request_Id: str = Header(uuid.uuid4(), alias="Request-Id") + header_request_id: str = Header(uuid.uuid4(), alias="Request-Id") ) -> dict[str, str]: logger.metric( metric_name="dataloader_likes_total", @@ -33,8 +33,7 @@ async def like( response_model=schemas.RateResponse, ) async def dislike( - # pylint: disable=C0103,W0613 - header_Request_Id: str = Header(uuid.uuid4(), alias="Request-Id") + header_request_id: str = Header(uuid.uuid4(), alias="Request-Id") ) -> dict[str, str]: logger.metric( metric_name="dataloader_dislikes_total", diff --git a/src/dataloader/api/middleware.py b/src/dataloader/api/middleware.py index f72358c..4ad1fe6 100644 --- a/src/dataloader/api/middleware.py +++ b/src/dataloader/api/middleware.py @@ -38,8 +38,14 @@ async def log_requests(request: Request, call_next) -> any: logger = APP_CTX.get_logger() request_path = request.url.path - allowed_headers_to_log = ((k, request.headers.get(k)) for k in HEADERS_WHITE_LIST_TO_LOG) - headers_to_log = {header_name: header_value for header_name, header_value in allowed_headers_to_log if header_value} + allowed_headers_to_log = ( + (k, request.headers.get(k)) for k in HEADERS_WHITE_LIST_TO_LOG + ) + headers_to_log = { + header_name: header_value + for header_name, header_value in allowed_headers_to_log + if header_value + } APP_CTX.get_context_vars_container().set_context_vars( request_id=headers_to_log.get("Request-Id", ""), @@ -50,7 +56,9 @@ async def log_requests(request: Request, call_next) -> any: if request_path in NON_LOGGED_ENDPOINTS: response = await call_next(request) - logger.debug(f"Processed request for {request_path} with code {response.status_code}") + logger.debug( + f"Processed request for {request_path} with code {response.status_code}" + ) elif headers_to_log.get("Request-Id", None): raw_request_body = await request.body() request_body_decoded = _get_decoded_body(raw_request_body, "request", logger) @@ -80,7 +88,7 @@ async def log_requests(request: Request, call_next) -> any: event_params=json.dumps( request_body_decoded, ensure_ascii=False, - ) + ), ) response = await call_next(request) @@ -88,12 +96,16 @@ async def log_requests(request: Request, call_next) -> any: response_body = [chunk async for chunk in response.body_iterator] response.body_iterator = iterate_in_threadpool(iter(response_body)) - headers_to_log["Response-Time"] = datetime.now(APP_CTX.get_pytz_timezone()).isoformat() + headers_to_log["Response-Time"] = datetime.now( + APP_CTX.get_pytz_timezone() + ).isoformat() for header in headers_to_log: response.headers[header] = headers_to_log[header] response_body_extracted = response_body[0] if len(response_body) > 0 else b"" - decoded_response_body = _get_decoded_body(response_body_extracted, "response", logger) + decoded_response_body = _get_decoded_body( + response_body_extracted, "response", logger + ) logger.info( "Outgoing response to client system", @@ -115,11 +127,13 @@ async def log_requests(request: Request, call_next) -> any: event_params=json.dumps( decoded_response_body, ensure_ascii=False, - ) + ), ) processing_time_ms = int(round((time.time() - start_time), 3) * 1000) - logger.info(f"Request processing time for {request_path}: {processing_time_ms} ms") + logger.info( + f"Request processing time for {request_path}: {processing_time_ms} ms" + ) logger.metric( metric_name="dataloader_process_duration_ms", 
metric_value=processing_time_ms, @@ -138,7 +152,9 @@ async def log_requests(request: Request, call_next) -> any: else: logger.info(f"Incoming {request.method}-request with no id for {request_path}") response = await call_next(request) - logger.info(f"Request with no id for {request_path} processing time: {time.time() - start_time:.3f} s") + logger.info( + f"Request with no id for {request_path} processing time: {time.time() - start_time:.3f} s" + ) return response diff --git a/src/dataloader/api/os_router.py b/src/dataloader/api/os_router.py index 1dc9210..96b1e2f 100644 --- a/src/dataloader/api/os_router.py +++ b/src/dataloader/api/os_router.py @@ -1,6 +1,4 @@ -# Инфраструктурные endpoint'ы (/health, /status) -""" 🚨 НЕ РЕДАКТИРОВАТЬ !!!!!! -""" +"""🚨 НЕ РЕДАКТИРОВАТЬ !!!!!!""" from importlib.metadata import distribution diff --git a/src/dataloader/api/schemas.py b/src/dataloader/api/schemas.py index 8ee2035..480bc45 100644 --- a/src/dataloader/api/schemas.py +++ b/src/dataloader/api/schemas.py @@ -3,22 +3,26 @@ from pydantic import BaseModel, ConfigDict, Field class HealthResponse(BaseModel): """Ответ для ручки /health""" - model_config = ConfigDict( - json_schema_extra={"example": {"status": "running"}} - ) - status: str = Field(default="running", description="Service health check", max_length=7) + model_config = ConfigDict(json_schema_extra={"example": {"status": "running"}}) + + status: str = Field( + default="running", description="Service health check", max_length=7 + ) class InfoResponse(BaseModel): """Ответ для ручки /info""" + model_config = ConfigDict( json_schema_extra={ "example": { "name": "rest-template", - "description": "Python 'AI gateway' template for developing REST microservices", + "description": ( + "Python 'AI gateway' template for developing REST microservices" + ), "type": "REST API", - "version": "0.1.0" + "version": "0.1.0", } } ) @@ -26,11 +30,14 @@ class InfoResponse(BaseModel): name: str = Field(description="Service name", max_length=50) description: str = Field(description="Service description", max_length=200) type: str = Field(default="REST API", description="Service type", max_length=20) - version: str = Field(description="Service version", max_length=20, pattern=r"^\d+\.\d+\.\d+") + version: str = Field( + description="Service version", max_length=20, pattern=r"^\d+\.\d+\.\d+" + ) class RateResponse(BaseModel): """Ответ для записи рейтинга""" + model_config = ConfigDict(str_strip_whitespace=True) rating_result: str = Field(description="Rating that was recorded", max_length=50) diff --git a/src/dataloader/api/v1/exceptions.py b/src/dataloader/api/v1/exceptions.py index d2fd45b..743cb3b 100644 --- a/src/dataloader/api/v1/exceptions.py +++ b/src/dataloader/api/v1/exceptions.py @@ -5,12 +5,8 @@ from fastapi import HTTPException, status class JobNotFoundError(HTTPException): """Задача не найдена.""" - + def __init__(self, job_id: str): super().__init__( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Job {job_id} not found" + status_code=status.HTTP_404_NOT_FOUND, detail=f"Job {job_id} not found" ) - - - diff --git a/src/dataloader/api/v1/models.py b/src/dataloader/api/v1/models.py index d5c7bd5..4e3433f 100644 --- a/src/dataloader/api/v1/models.py +++ b/src/dataloader/api/v1/models.py @@ -1,2 +1 @@ """Модели данных для API v1.""" - diff --git a/src/dataloader/api/v1/router.py b/src/dataloader/api/v1/router.py index 62f7faa..18a0ca7 100644 --- a/src/dataloader/api/v1/router.py +++ b/src/dataloader/api/v1/router.py @@ -1,12 +1,11 @@ -# 
src/dataloader/api/v1/router.py from __future__ import annotations -from collections.abc import AsyncGenerator from http import HTTPStatus from typing import Annotated from uuid import UUID from fastapi import APIRouter, Depends +from sqlalchemy.ext.asyncio import AsyncSession from dataloader.api.v1.exceptions import JobNotFoundError from dataloader.api.v1.schemas import ( @@ -16,8 +15,6 @@ from dataloader.api.v1.schemas import ( ) from dataloader.api.v1.service import JobsService from dataloader.context import get_session -from sqlalchemy.ext.asyncio import AsyncSession - router = APIRouter(prefix="/jobs", tags=["jobs"]) @@ -40,7 +37,9 @@ async def trigger_job( return await svc.trigger(payload) -@router.get("/{job_id}/status", response_model=JobStatusResponse, status_code=HTTPStatus.OK) +@router.get( + "/{job_id}/status", response_model=JobStatusResponse, status_code=HTTPStatus.OK +) async def get_status( job_id: UUID, svc: Annotated[JobsService, Depends(get_service)], @@ -54,7 +53,9 @@ async def get_status( return st -@router.post("/{job_id}/cancel", response_model=JobStatusResponse, status_code=HTTPStatus.OK) +@router.post( + "/{job_id}/cancel", response_model=JobStatusResponse, status_code=HTTPStatus.OK +) async def cancel_job( job_id: UUID, svc: Annotated[JobsService, Depends(get_service)], diff --git a/src/dataloader/api/v1/schemas.py b/src/dataloader/api/v1/schemas.py index bd52ccb..212bbea 100644 --- a/src/dataloader/api/v1/schemas.py +++ b/src/dataloader/api/v1/schemas.py @@ -1,4 +1,3 @@ -# src/dataloader/api/v1/schemas.py from __future__ import annotations from datetime import datetime, timezone @@ -12,6 +11,7 @@ class TriggerJobRequest(BaseModel): """ Запрос на постановку задачи в очередь. """ + model_config = ConfigDict(str_strip_whitespace=True) queue: str = Field(...) @@ -39,6 +39,7 @@ class TriggerJobResponse(BaseModel): """ Ответ на постановку задачи. """ + model_config = ConfigDict(str_strip_whitespace=True) job_id: UUID = Field(...) @@ -49,6 +50,7 @@ class JobStatusResponse(BaseModel): """ Текущий статус задачи. """ + model_config = ConfigDict(str_strip_whitespace=True) job_id: UUID = Field(...) @@ -65,6 +67,7 @@ class CancelJobResponse(BaseModel): """ Ответ на запрос отмены задачи. """ + model_config = ConfigDict(str_strip_whitespace=True) job_id: UUID = Field(...) diff --git a/src/dataloader/api/v1/service.py b/src/dataloader/api/v1/service.py index 3151653..7469778 100644 --- a/src/dataloader/api/v1/service.py +++ b/src/dataloader/api/v1/service.py @@ -1,4 +1,3 @@ -# src/dataloader/api/v1/service.py from __future__ import annotations from datetime import datetime, timezone @@ -13,15 +12,16 @@ from dataloader.api.v1.schemas import ( TriggerJobResponse, ) from dataloader.api.v1.utils import new_job_id -from dataloader.storage.schemas import CreateJobRequest -from dataloader.storage.repositories import QueueRepository from dataloader.logger.logger import get_logger +from dataloader.storage.repositories import QueueRepository +from dataloader.storage.schemas import CreateJobRequest class JobsService: """ Бизнес-логика работы с очередью задач. 
""" + def __init__(self, session: AsyncSession): self._s = session self._repo = QueueRepository(self._s) diff --git a/src/dataloader/api/v1/utils.py b/src/dataloader/api/v1/utils.py index 8c6731d..ed8a09e 100644 --- a/src/dataloader/api/v1/utils.py +++ b/src/dataloader/api/v1/utils.py @@ -1,4 +1,3 @@ -# src/dataloader/api/v1/utils.py from __future__ import annotations from uuid import UUID, uuid4 diff --git a/src/dataloader/config.py b/src/dataloader/config.py index ffbd674..8831cd9 100644 --- a/src/dataloader/config.py +++ b/src/dataloader/config.py @@ -1,6 +1,5 @@ -# src/dataloader/config.py -import os import json +import os from logging import DEBUG, INFO from typing import Annotated, Any @@ -32,6 +31,7 @@ class BaseAppSettings(BaseSettings): """ Базовый класс для настроек. """ + local: bool = Field(validation_alias="LOCAL", default=False) debug: bool = Field(validation_alias="DEBUG", default=False) @@ -44,6 +44,7 @@ class AppSettings(BaseAppSettings): """ Настройки приложения. """ + app_host: str = Field(validation_alias="APP_HOST", default="0.0.0.0") app_port: int = Field(validation_alias="APP_PORT", default=8081) kube_net_name: str = Field(validation_alias="PROJECT_NAME", default="AIGATEWAY") @@ -54,15 +55,28 @@ class LogSettings(BaseAppSettings): """ Настройки логирования. """ + private_log_file_path: str = Field(validation_alias="LOG_PATH", default=os.getcwd()) - private_log_file_name: str = Field(validation_alias="LOG_FILE_NAME", default="app.log") + private_log_file_name: str = Field( + validation_alias="LOG_FILE_NAME", default="app.log" + ) log_rotation: str = Field(validation_alias="LOG_ROTATION", default="10 MB") - private_metric_file_path: str = Field(validation_alias="METRIC_PATH", default=os.getcwd()) - private_metric_file_name: str = Field(validation_alias="METRIC_FILE_NAME", default="app-metric.log") - private_audit_file_path: str = Field(validation_alias="AUDIT_LOG_PATH", default=os.getcwd()) - private_audit_file_name: str = Field(validation_alias="AUDIT_LOG_FILE_NAME", default="events.log") + private_metric_file_path: str = Field( + validation_alias="METRIC_PATH", default=os.getcwd() + ) + private_metric_file_name: str = Field( + validation_alias="METRIC_FILE_NAME", default="app-metric.log" + ) + private_audit_file_path: str = Field( + validation_alias="AUDIT_LOG_PATH", default=os.getcwd() + ) + private_audit_file_name: str = Field( + validation_alias="AUDIT_LOG_FILE_NAME", default="events.log" + ) audit_host_ip: str = Field(validation_alias="HOST_IP", default="127.0.0.1") - audit_host_uid: str = Field(validation_alias="HOST_UID", default="63b6dcee-170b-49bf-a65c-3ec967398ccd") + audit_host_uid: str = Field( + validation_alias="HOST_UID", default="63b6dcee-170b-49bf-a65c-3ec967398ccd" + ) @staticmethod def get_file_abs_path(path_name: str, file_name: str) -> str: @@ -70,15 +84,21 @@ class LogSettings(BaseAppSettings): @property def log_file_abs_path(self) -> str: - return self.get_file_abs_path(self.private_log_file_path, self.private_log_file_name) + return self.get_file_abs_path( + self.private_log_file_path, self.private_log_file_name + ) @property def metric_file_abs_path(self) -> str: - return self.get_file_abs_path(self.private_metric_file_path, self.private_metric_file_name) + return self.get_file_abs_path( + self.private_metric_file_path, self.private_metric_file_name + ) @property def audit_file_abs_path(self) -> str: - return self.get_file_abs_path(self.private_audit_file_path, self.private_audit_file_name) + return self.get_file_abs_path( + 
self.private_audit_file_path, self.private_audit_file_name + ) @property def log_lvl(self) -> int: @@ -89,6 +109,7 @@ class PGSettings(BaseSettings): """ Настройки подключения к Postgres. """ + host: str = Field(validation_alias="PG_HOST", default="localhost") port: int = Field(validation_alias="PG_PORT", default=5432) user: str = Field(validation_alias="PG_USER", default="postgres") @@ -116,9 +137,12 @@ class WorkerSettings(BaseSettings): """ Настройки очереди и воркеров. """ + workers_json: str = Field(validation_alias="WORKERS_JSON", default="[]") heartbeat_sec: int = Field(validation_alias="DL_HEARTBEAT_SEC", default=10) - default_lease_ttl_sec: int = Field(validation_alias="DL_DEFAULT_LEASE_TTL_SEC", default=60) + default_lease_ttl_sec: int = Field( + validation_alias="DL_DEFAULT_LEASE_TTL_SEC", default=60 + ) reaper_period_sec: int = Field(validation_alias="DL_REAPER_PERIOD_SEC", default=10) claim_backoff_sec: int = Field(validation_alias="DL_CLAIM_BACKOFF_SEC", default=15) @@ -137,6 +161,7 @@ class CertsSettings(BaseSettings): """ Настройки SSL сертификатов для локальной разработки. """ + ca_bundle_file: str = Field(validation_alias="CA_BUNDLE_FILE", default="") cert_file: str = Field(validation_alias="CERT_FILE", default="") key_file: str = Field(validation_alias="KEY_FILE", default="") @@ -146,21 +171,24 @@ class SuperTeneraSettings(BaseAppSettings): """ Настройки интеграции с SuperTenera. """ + host: Annotated[str, BeforeValidator(strip_slashes)] = Field( validation_alias="SUPERTENERA_HOST", - default="ci03801737-ift-tenera-giga.delta.sbrf.ru/atlant360bc/" + default="ci03801737-ift-tenera-giga.delta.sbrf.ru/atlant360bc/", ) port: str = Field(validation_alias="SUPERTENERA_PORT", default="443") quotes_endpoint: Annotated[str, BeforeValidator(strip_slashes)] = Field( validation_alias="SUPERTENERA_QUOTES_ENDPOINT", - default="/get_gigaparser_quotes/" + default="/get_gigaparser_quotes/", ) timeout: int = Field(validation_alias="SUPERTENERA_TIMEOUT", default=20) @property def base_url(self) -> str: """Возвращает абсолютный URL для SuperTenera.""" - domain, raw_path = self.host.split("/", 1) if "/" in self.host else (self.host, "") + domain, raw_path = ( + self.host.split("/", 1) if "/" in self.host else (self.host, "") + ) return build_url(self.protocol, domain, self.port, raw_path) @@ -168,22 +196,21 @@ class Gmap2BriefSettings(BaseAppSettings): """ Настройки интеграции с Gmap2Brief (OPU API). 
""" + host: Annotated[str, BeforeValidator(strip_slashes)] = Field( validation_alias="GMAP2BRIEF_HOST", - default="ci02533826-tib-brief.apps.ift-terra000024-edm.ocp.delta.sbrf.ru" + default="ci02533826-tib-brief.apps.ift-terra000024-edm.ocp.delta.sbrf.ru", ) port: str = Field(validation_alias="GMAP2BRIEF_PORT", default="443") start_endpoint: Annotated[str, BeforeValidator(strip_slashes)] = Field( - validation_alias="GMAP2BRIEF_START_ENDPOINT", - default="/export/opu/start" + validation_alias="GMAP2BRIEF_START_ENDPOINT", default="/export/opu/start" ) status_endpoint: Annotated[str, BeforeValidator(strip_slashes)] = Field( - validation_alias="GMAP2BRIEF_STATUS_ENDPOINT", - default="/export/{job_id}/status" + validation_alias="GMAP2BRIEF_STATUS_ENDPOINT", default="/export/{job_id}/status" ) download_endpoint: Annotated[str, BeforeValidator(strip_slashes)] = Field( validation_alias="GMAP2BRIEF_DOWNLOAD_ENDPOINT", - default="/export/{job_id}/download" + default="/export/{job_id}/download", ) poll_interval: int = Field(validation_alias="GMAP2BRIEF_POLL_INTERVAL", default=2) timeout: int = Field(validation_alias="GMAP2BRIEF_TIMEOUT", default=3600) @@ -198,6 +225,7 @@ class Secrets: """ Агрегатор настроек приложения. """ + app: AppSettings = AppSettings() log: LogSettings = LogSettings() pg: PGSettings = PGSettings() diff --git a/src/dataloader/context.py b/src/dataloader/context.py index 328767c..e0da114 100644 --- a/src/dataloader/context.py +++ b/src/dataloader/context.py @@ -1,8 +1,7 @@ -# src/dataloader/context.py from __future__ import annotations -from typing import AsyncGenerator from logging import Logger +from typing import AsyncGenerator from zoneinfo import ZoneInfo from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker @@ -15,6 +14,7 @@ class AppContext: """ Контекст приложения, хранящий глобальные зависимости (Singleton pattern). 
""" + def __init__(self) -> None: self._engine: AsyncEngine | None = None self._sessionmaker: async_sessionmaker[AsyncSession] | None = None @@ -75,6 +75,7 @@ class AppContext: Экземпляр Logger """ from .logger.logger import get_logger as get_app_logger + return get_app_logger(name) def get_context_vars_container(self) -> ContextVarsContainer: diff --git a/src/dataloader/exceptions.py b/src/dataloader/exceptions.py index d259d99..8d8e0d2 100644 --- a/src/dataloader/exceptions.py +++ b/src/dataloader/exceptions.py @@ -1,2 +1 @@ """Исключения уровня приложения.""" - diff --git a/src/dataloader/interfaces/gmap2_brief/interface.py b/src/dataloader/interfaces/gmap2_brief/interface.py index d949931..c15db93 100644 --- a/src/dataloader/interfaces/gmap2_brief/interface.py +++ b/src/dataloader/interfaces/gmap2_brief/interface.py @@ -10,7 +10,10 @@ from typing import TYPE_CHECKING import httpx from dataloader.config import APP_CONFIG -from dataloader.interfaces.gmap2_brief.schemas import ExportJobStatus, StartExportResponse +from dataloader.interfaces.gmap2_brief.schemas import ( + ExportJobStatus, + StartExportResponse, +) if TYPE_CHECKING: from logging import Logger @@ -37,10 +40,11 @@ class Gmap2BriefInterface: self._ssl_context = None if APP_CONFIG.app.local and APP_CONFIG.certs.cert_file: - self._ssl_context = ssl.create_default_context(cafile=APP_CONFIG.certs.ca_bundle_file) + self._ssl_context = ssl.create_default_context( + cafile=APP_CONFIG.certs.ca_bundle_file + ) self._ssl_context.load_cert_chain( - certfile=APP_CONFIG.certs.cert_file, - keyfile=APP_CONFIG.certs.key_file + certfile=APP_CONFIG.certs.cert_file, keyfile=APP_CONFIG.certs.key_file ) async def start_export(self) -> str: @@ -56,9 +60,13 @@ class Gmap2BriefInterface: url = self.base_url + APP_CONFIG.gmap2brief.start_endpoint async with httpx.AsyncClient( - cert=(APP_CONFIG.certs.cert_file, APP_CONFIG.certs.key_file) if APP_CONFIG.app.local else None, + cert=( + (APP_CONFIG.certs.cert_file, APP_CONFIG.certs.key_file) + if APP_CONFIG.app.local + else None + ), verify=APP_CONFIG.certs.ca_bundle_file if APP_CONFIG.app.local else True, - timeout=self.timeout + timeout=self.timeout, ) as client: try: self.logger.info(f"Starting OPU export: POST {url}") @@ -87,12 +95,18 @@ class Gmap2BriefInterface: Raises: Gmap2BriefConnectionError: При ошибке запроса """ - url = self.base_url + APP_CONFIG.gmap2brief.status_endpoint.format(job_id=job_id) + url = self.base_url + APP_CONFIG.gmap2brief.status_endpoint.format( + job_id=job_id + ) async with httpx.AsyncClient( - cert=(APP_CONFIG.certs.cert_file, APP_CONFIG.certs.key_file) if APP_CONFIG.app.local else None, + cert=( + (APP_CONFIG.certs.cert_file, APP_CONFIG.certs.key_file) + if APP_CONFIG.app.local + else None + ), verify=APP_CONFIG.certs.ca_bundle_file if APP_CONFIG.app.local else True, - timeout=self.timeout + timeout=self.timeout, ) as client: try: response = await client.get(url) @@ -106,7 +120,9 @@ class Gmap2BriefInterface: except httpx.RequestError as e: raise Gmap2BriefConnectionError(f"Request error: {e}") from e - async def wait_for_completion(self, job_id: str, max_wait: int | None = None) -> ExportJobStatus: + async def wait_for_completion( + self, job_id: str, max_wait: int | None = None + ) -> ExportJobStatus: """ Ждет завершения задачи экспорта с периодическим polling. 
@@ -127,14 +143,18 @@ class Gmap2BriefInterface: status = await self.get_status(job_id) if status.status == "completed": - self.logger.info(f"Job {job_id} completed, total_rows={status.total_rows}") + self.logger.info( + f"Job {job_id} completed, total_rows={status.total_rows}" + ) return status elif status.status == "failed": raise Gmap2BriefConnectionError(f"Job {job_id} failed: {status.error}") elapsed = asyncio.get_event_loop().time() - start_time if max_wait and elapsed > max_wait: - raise Gmap2BriefConnectionError(f"Job {job_id} timeout after {elapsed:.1f}s") + raise Gmap2BriefConnectionError( + f"Job {job_id} timeout after {elapsed:.1f}s" + ) self.logger.debug( f"Job {job_id} status={status.status}, rows={status.total_rows}, elapsed={elapsed:.1f}s" @@ -155,12 +175,18 @@ class Gmap2BriefInterface: Raises: Gmap2BriefConnectionError: При ошибке скачивания """ - url = self.base_url + APP_CONFIG.gmap2brief.download_endpoint.format(job_id=job_id) + url = self.base_url + APP_CONFIG.gmap2brief.download_endpoint.format( + job_id=job_id + ) async with httpx.AsyncClient( - cert=(APP_CONFIG.certs.cert_file, APP_CONFIG.certs.key_file) if APP_CONFIG.app.local else None, + cert=( + (APP_CONFIG.certs.cert_file, APP_CONFIG.certs.key_file) + if APP_CONFIG.app.local + else None + ), verify=APP_CONFIG.certs.ca_bundle_file if APP_CONFIG.app.local else True, - timeout=self.timeout + timeout=self.timeout, ) as client: try: self.logger.info(f"Downloading export: GET {url}") diff --git a/src/dataloader/interfaces/gmap2_brief/schemas.py b/src/dataloader/interfaces/gmap2_brief/schemas.py index bb33281..b5a63c4 100644 --- a/src/dataloader/interfaces/gmap2_brief/schemas.py +++ b/src/dataloader/interfaces/gmap2_brief/schemas.py @@ -17,7 +17,11 @@ class ExportJobStatus(BaseModel): """Статус задачи экспорта.""" job_id: str = Field(..., description="Идентификатор задачи") - status: Literal["pending", "running", "completed", "failed"] = Field(..., description="Статус задачи") + status: Literal["pending", "running", "completed", "failed"] = Field( + ..., description="Статус задачи" + ) total_rows: int = Field(default=0, description="Количество обработанных строк") error: str | None = Field(default=None, description="Текст ошибки (если есть)") - temp_file_path: str | None = Field(default=None, description="Путь к временному файлу (для completed)") + temp_file_path: str | None = Field( + default=None, description="Путь к временному файлу (для completed)" + ) diff --git a/src/dataloader/interfaces/tenera/interface.py b/src/dataloader/interfaces/tenera/interface.py index 4eb42d7..6899395 100644 --- a/src/dataloader/interfaces/tenera/interface.py +++ b/src/dataloader/interfaces/tenera/interface.py @@ -47,8 +47,12 @@ class SuperTeneraInterface: self._ssl_context = None if APP_CONFIG.app.local: - self._ssl_context = ssl.create_default_context(cafile=APP_CONFIG.certs.ca_bundle_file) - self._ssl_context.load_cert_chain(certfile=APP_CONFIG.certs.cert_file, keyfile=APP_CONFIG.certs.key_file) + self._ssl_context = ssl.create_default_context( + cafile=APP_CONFIG.certs.ca_bundle_file + ) + self._ssl_context.load_cert_chain( + certfile=APP_CONFIG.certs.cert_file, keyfile=APP_CONFIG.certs.key_file + ) def form_base_headers(self) -> dict: """Формирует базовые заголовки для запроса.""" @@ -57,7 +61,11 @@ class SuperTeneraInterface: "request-time": str(datetime.now(tz=self.timezone).isoformat()), "system-id": APP_CONFIG.app.kube_net_name, } - return {metakey: metavalue for metakey, metavalue in metadata_pairs.items() if metavalue} + 
return { + metakey: metavalue + for metakey, metavalue in metadata_pairs.items() + if metavalue + } async def __aenter__(self) -> Self: """Async context manager enter.""" @@ -86,7 +94,9 @@ class SuperTeneraInterface: async with self._session.get(url, **kwargs) as response: if APP_CONFIG.app.debug: - self.logger.debug(f"Response: {(await response.text(errors='ignore'))[:100]}") + self.logger.debug( + f"Response: {(await response.text(errors='ignore'))[:100]}" + ) return await response.json(encoding=encoding, content_type=content_type) async def get_quotes_data(self) -> MainData: @@ -103,7 +113,9 @@ class SuperTeneraInterface: kwargs["ssl"] = self._ssl_context try: async with self._session.get( - APP_CONFIG.supertenera.quotes_endpoint, timeout=APP_CONFIG.supertenera.timeout, **kwargs + APP_CONFIG.supertenera.quotes_endpoint, + timeout=APP_CONFIG.supertenera.timeout, + **kwargs, ) as resp: resp.raise_for_status() return True @@ -112,7 +124,9 @@ class SuperTeneraInterface: f"Ошибка подключения к SuperTenera API при проверке системы - {e.status}." ) from e except TimeoutError as e: - raise SuperTeneraConnectionError("Ошибка Timeout подключения к SuperTenera API при проверке системы.") from e + raise SuperTeneraConnectionError( + "Ошибка Timeout подключения к SuperTenera API при проверке системы." + ) from e def get_async_tenera_interface() -> SuperTeneraInterface: diff --git a/src/dataloader/interfaces/tenera/schemas.py b/src/dataloader/interfaces/tenera/schemas.py index ea82c64..06cf6cb 100644 --- a/src/dataloader/interfaces/tenera/schemas.py +++ b/src/dataloader/interfaces/tenera/schemas.py @@ -77,7 +77,9 @@ class InvestingCandlestick(TeneraBaseModel): value: str = Field(alias="V") -class InvestingTimePoint(RootModel[EmptyTimePoint | InvestingNumeric | InvestingCandlestick]): +class InvestingTimePoint( + RootModel[EmptyTimePoint | InvestingNumeric | InvestingCandlestick] +): """ Union-модель точки времени для источника Investing.com. 
@@ -340,5 +342,9 @@ class MainData(TeneraBaseModel): :return: Отфильтрованный объект """ if isinstance(v, dict): - return {key: value for key, value in v.items() if value is not None and not str(key).isdigit()} + return { + key: value + for key, value in v.items() + if value is not None and not str(key).isdigit() + } return v diff --git a/src/dataloader/logger/__init__.py b/src/dataloader/logger/__init__.py index da2dbbf..059238a 100644 --- a/src/dataloader/logger/__init__.py +++ b/src/dataloader/logger/__init__.py @@ -1,4 +1,3 @@ -# src/dataloader/logger/__init__.py -from .logger import setup_logging, get_logger +from .logger import get_logger, setup_logging __all__ = ["setup_logging", "get_logger"] diff --git a/src/dataloader/logger/context_vars.py b/src/dataloader/logger/context_vars.py index 91d98da..cb97169 100644 --- a/src/dataloader/logger/context_vars.py +++ b/src/dataloader/logger/context_vars.py @@ -1,9 +1,7 @@ -# Управление контекстом запросов для логирования import uuid from contextvars import ContextVar from typing import Final - REQUEST_ID_CTX_VAR: Final[ContextVar[str]] = ContextVar("request_id", default="") DEVICE_ID_CTX_VAR: Final[ContextVar[str]] = ContextVar("device_id", default="") SESSION_ID_CTX_VAR: Final[ContextVar[str]] = ContextVar("session_id", default="") @@ -66,7 +64,13 @@ class ContextVarsContainer: def gw_session_id(self, value: str) -> None: GW_SESSION_ID_CTX_VAR.set(value) - def set_context_vars(self, request_id: str = "", request_time: str = "", system_id: str = "", gw_session_id: str = "") -> None: + def set_context_vars( + self, + request_id: str = "", + request_time: str = "", + system_id: str = "", + gw_session_id: str = "", + ) -> None: if request_id: self.set_request_id(request_id) if request_time: diff --git a/src/dataloader/logger/logger.py b/src/dataloader/logger/logger.py index 3c94ca7..fc7913e 100644 --- a/src/dataloader/logger/logger.py +++ b/src/dataloader/logger/logger.py @@ -1,17 +1,12 @@ -# src/dataloader/logger/logger.py -import sys -import typing -from datetime import tzinfo -from logging import Logger import logging +import sys +from logging import Logger from loguru import logger -from .context_vars import ContextVarsContainer from dataloader.config import APP_CONFIG -# Определяем фильтры для разных типов логов def metric_only_filter(record: dict) -> bool: return "metric" in record["extra"] @@ -26,13 +21,12 @@ def regular_log_filter(record: dict) -> bool: class InterceptHandler(logging.Handler): def emit(self, record: logging.LogRecord) -> None: - # Get corresponding Loguru level if it exists + try: level = logger.level(record.levelname).name except ValueError: level = record.levelno - # Find caller from where originated the logged message frame, depth = logging.currentframe(), 2 while frame.f_code.co_filename == logging.__file__: frame = frame.f_back diff --git a/src/dataloader/logger/models.py b/src/dataloader/logger/models.py index 23d0453..8b13789 100644 --- a/src/dataloader/logger/models.py +++ b/src/dataloader/logger/models.py @@ -1 +1 @@ -# Модели логов, метрик, событий аудита + diff --git a/src/dataloader/logger/utils.py b/src/dataloader/logger/utils.py index 3481061..8b13789 100644 --- a/src/dataloader/logger/utils.py +++ b/src/dataloader/logger/utils.py @@ -1 +1 @@ -# Функции + маскирование args + diff --git a/src/dataloader/logger/uvicorn_logging_config.py b/src/dataloader/logger/uvicorn_logging_config.py index 3df67ab..d83017c 100644 --- a/src/dataloader/logger/uvicorn_logging_config.py +++ 
b/src/dataloader/logger/uvicorn_logging_config.py @@ -1,34 +1,28 @@ -# Конфигурация логирования uvicorn import logging -import sys - from loguru import logger class InterceptHandler(logging.Handler): def emit(self, record: logging.LogRecord) -> None: - # Get corresponding Loguru level if it exists + try: level = logger.level(record.levelname).name except ValueError: level = record.levelno - - # Find caller from where originated the logged message frame, depth = logging.currentframe(), 2 while frame.f_code.co_filename == logging.__file__: frame = frame.f_back depth += 1 - logger.opt(depth=depth, exception=record.exc_info).log( level, record.getMessage() ) def setup_uvicorn_logging() -> None: - # Set all uvicorn loggers to use InterceptHandler + for logger_name in ["uvicorn", "uvicorn.error", "uvicorn.access"]: log = logging.getLogger(logger_name) log.handlers = [InterceptHandler()] @@ -36,7 +30,6 @@ def setup_uvicorn_logging() -> None: log.propagate = False -# uvicorn logging config LOGGING_CONFIG = { "version": 1, "disable_existing_loggers": False, diff --git a/src/dataloader/storage/__init__.py b/src/dataloader/storage/__init__.py index f23621f..2a09814 100644 --- a/src/dataloader/storage/__init__.py +++ b/src/dataloader/storage/__init__.py @@ -1,2 +1 @@ """Модуль для работы с хранилищем данных.""" - diff --git a/src/dataloader/storage/engine.py b/src/dataloader/storage/engine.py index 7d376fc..4cee2ac 100644 --- a/src/dataloader/storage/engine.py +++ b/src/dataloader/storage/engine.py @@ -1,7 +1,11 @@ -# src/dataloader/storage/engine.py from __future__ import annotations -from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, + create_async_engine, +) from dataloader.config import APP_CONFIG diff --git a/src/dataloader/storage/models/__init__.py b/src/dataloader/storage/models/__init__.py index 6e74dc2..400f713 100644 --- a/src/dataloader/storage/models/__init__.py +++ b/src/dataloader/storage/models/__init__.py @@ -1,8 +1,8 @@ -# src/dataloader/storage/models/__init__.py """ ORM модели для работы с базой данных. Организованы по доменам для масштабируемости. """ + from __future__ import annotations from .base import Base diff --git a/src/dataloader/storage/models/base.py b/src/dataloader/storage/models/base.py index 1581e31..8ac10a9 100644 --- a/src/dataloader/storage/models/base.py +++ b/src/dataloader/storage/models/base.py @@ -1,4 +1,3 @@ -# src/dataloader/storage/models/base.py from __future__ import annotations from sqlalchemy.orm import DeclarativeBase @@ -9,4 +8,5 @@ class Base(DeclarativeBase): Базовый класс для всех ORM моделей приложения. Используется SQLAlchemy 2.0+ declarative style. 
""" + pass diff --git a/src/dataloader/storage/models/opu.py b/src/dataloader/storage/models/opu.py index 3da8e96..96e732d 100644 --- a/src/dataloader/storage/models/opu.py +++ b/src/dataloader/storage/models/opu.py @@ -4,7 +4,7 @@ from __future__ import annotations from datetime import date, datetime -from sqlalchemy import BigInteger, Date, Integer, Numeric, String, TIMESTAMP, text +from sqlalchemy import TIMESTAMP, BigInteger, Date, Integer, Numeric, String, text from sqlalchemy.orm import Mapped, mapped_column from dataloader.storage.models.base import Base @@ -16,29 +16,47 @@ class BriefDigitalCertificateOpu(Base): __tablename__ = "brief_digital_certificate_opu" __table_args__ = ({"schema": "opu"},) - object_id: Mapped[str] = mapped_column(String, primary_key=True, server_default=text("'-'")) + object_id: Mapped[str] = mapped_column( + String, primary_key=True, server_default=text("'-'") + ) object_nm: Mapped[str | None] = mapped_column(String) - desk_nm: Mapped[str] = mapped_column(String, primary_key=True, server_default=text("'-'")) - actdate: Mapped[date] = mapped_column(Date, primary_key=True, server_default=text("CURRENT_DATE")) - layer_cd: Mapped[str] = mapped_column(String, primary_key=True, server_default=text("'-'")) + desk_nm: Mapped[str] = mapped_column( + String, primary_key=True, server_default=text("'-'") + ) + actdate: Mapped[date] = mapped_column( + Date, primary_key=True, server_default=text("CURRENT_DATE") + ) + layer_cd: Mapped[str] = mapped_column( + String, primary_key=True, server_default=text("'-'") + ) layer_nm: Mapped[str | None] = mapped_column(String) opu_cd: Mapped[str] = mapped_column(String, primary_key=True) opu_nm_sh: Mapped[str | None] = mapped_column(String) opu_nm: Mapped[str | None] = mapped_column(String) - opu_lvl: Mapped[int] = mapped_column(Integer, primary_key=True, server_default=text("'-1'")) - opu_prnt_cd: Mapped[str] = mapped_column(String, primary_key=True, server_default=text("'-'")) + opu_lvl: Mapped[int] = mapped_column( + Integer, primary_key=True, server_default=text("'-1'") + ) + opu_prnt_cd: Mapped[str] = mapped_column( + String, primary_key=True, server_default=text("'-'") + ) opu_prnt_nm_sh: Mapped[str | None] = mapped_column(String) opu_prnt_nm: Mapped[str | None] = mapped_column(String) sum_amountrub_p_usd: Mapped[float | None] = mapped_column(Numeric) - wf_load_id: Mapped[int] = mapped_column(BigInteger, nullable=False, server_default=text("'-1'")) + wf_load_id: Mapped[int] = mapped_column( + BigInteger, nullable=False, server_default=text("'-1'") + ) wf_load_dttm: Mapped[datetime] = mapped_column( TIMESTAMP(timezone=False), nullable=False, server_default=text("CURRENT_TIMESTAMP"), ) - wf_row_id: Mapped[int] = mapped_column(BigInteger, nullable=False, server_default=text("'-1'")) + wf_row_id: Mapped[int] = mapped_column( + BigInteger, nullable=False, server_default=text("'-1'") + ) object_tp: Mapped[str | None] = mapped_column(String) - object_unit: Mapped[str] = mapped_column(String, primary_key=True, server_default=text("'-'")) + object_unit: Mapped[str] = mapped_column( + String, primary_key=True, server_default=text("'-'") + ) measure: Mapped[str | None] = mapped_column(String) product_nm: Mapped[str | None] = mapped_column(String) product_prnt_nm: Mapped[str | None] = mapped_column(String) diff --git a/src/dataloader/storage/models/queue.py b/src/dataloader/storage/models/queue.py index 8da17f9..3ec52ac 100644 --- a/src/dataloader/storage/models/queue.py +++ b/src/dataloader/storage/models/queue.py @@ -1,4 +1,3 @@ -# 
src/dataloader/storage/models/queue.py from __future__ import annotations from datetime import datetime @@ -10,7 +9,6 @@ from sqlalchemy.orm import Mapped, mapped_column from .base import Base - dl_status_enum = ENUM( "queued", "running", @@ -29,6 +27,7 @@ class DLJob(Base): Модель таблицы очереди задач dl_jobs. Использует логическое имя схемы 'queue' для поддержки schema_translate_map. """ + __tablename__ = "dl_jobs" __table_args__ = {"schema": "queue"} @@ -40,15 +39,23 @@ class DLJob(Base): lock_key: Mapped[str] = mapped_column(Text, nullable=False) partition_key: Mapped[str] = mapped_column(Text, default="", nullable=False) priority: Mapped[int] = mapped_column(nullable=False, default=100) - available_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) - status: Mapped[str] = mapped_column(dl_status_enum, nullable=False, default="queued") + available_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False + ) + status: Mapped[str] = mapped_column( + dl_status_enum, nullable=False, default="queued" + ) attempt: Mapped[int] = mapped_column(nullable=False, default=0) max_attempts: Mapped[int] = mapped_column(nullable=False, default=5) lease_ttl_sec: Mapped[int] = mapped_column(nullable=False, default=60) - lease_expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True)) + lease_expires_at: Mapped[Optional[datetime]] = mapped_column( + DateTime(timezone=True) + ) heartbeat_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True)) cancel_requested: Mapped[bool] = mapped_column(nullable=False, default=False) - progress: Mapped[dict[str, Any]] = mapped_column(JSONB, default=dict, nullable=False) + progress: Mapped[dict[str, Any]] = mapped_column( + JSONB, default=dict, nullable=False + ) error: Mapped[Optional[str]] = mapped_column(Text) producer: Mapped[Optional[str]] = mapped_column(Text) consumer_group: Mapped[Optional[str]] = mapped_column(Text) @@ -62,10 +69,13 @@ class DLJobEvent(Base): Модель таблицы журнала событий dl_job_events. Использует логическое имя схемы 'queue' для поддержки schema_translate_map. 
""" + __tablename__ = "dl_job_events" __table_args__ = {"schema": "queue"} - event_id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True) + event_id: Mapped[int] = mapped_column( + BigInteger, primary_key=True, autoincrement=True + ) job_id: Mapped[str] = mapped_column(UUID(as_uuid=False), nullable=False) queue: Mapped[str] = mapped_column(Text, nullable=False) load_dttm: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) diff --git a/src/dataloader/storage/models/quote.py b/src/dataloader/storage/models/quote.py index 9a0473c..93c4eaa 100644 --- a/src/dataloader/storage/models/quote.py +++ b/src/dataloader/storage/models/quote.py @@ -5,7 +5,15 @@ from __future__ import annotations from datetime import datetime from typing import TYPE_CHECKING -from sqlalchemy import JSON, TIMESTAMP, BigInteger, ForeignKey, String, UniqueConstraint, func +from sqlalchemy import ( + JSON, + TIMESTAMP, + BigInteger, + ForeignKey, + String, + UniqueConstraint, + func, +) from sqlalchemy.orm import Mapped, mapped_column, relationship from dataloader.storage.models.base import Base @@ -30,7 +38,9 @@ class Quote(Base): srce: Mapped[str | None] = mapped_column(String) ticker: Mapped[str | None] = mapped_column(String) quote_sect_id: Mapped[int] = mapped_column( - ForeignKey("quotes.quotes_sect.quote_sect_id", ondelete="CASCADE", onupdate="CASCADE"), + ForeignKey( + "quotes.quotes_sect.quote_sect_id", ondelete="CASCADE", onupdate="CASCADE" + ), nullable=False, ) last_update_dttm: Mapped[datetime | None] = mapped_column(TIMESTAMP(timezone=True)) diff --git a/src/dataloader/storage/models/quote_value.py b/src/dataloader/storage/models/quote_value.py index eedd174..ed45151 100644 --- a/src/dataloader/storage/models/quote_value.py +++ b/src/dataloader/storage/models/quote_value.py @@ -5,7 +5,17 @@ from __future__ import annotations from datetime import datetime from typing import TYPE_CHECKING -from sqlalchemy import TIMESTAMP, BigInteger, Boolean, DateTime, Float, ForeignKey, String, UniqueConstraint, func +from sqlalchemy import ( + TIMESTAMP, + BigInteger, + Boolean, + DateTime, + Float, + ForeignKey, + String, + UniqueConstraint, + func, +) from sqlalchemy.orm import Mapped, mapped_column, relationship from dataloader.storage.models.base import Base diff --git a/src/dataloader/storage/notify_listener.py b/src/dataloader/storage/notify_listener.py index 248eec9..9e8a6ed 100644 --- a/src/dataloader/storage/notify_listener.py +++ b/src/dataloader/storage/notify_listener.py @@ -1,16 +1,23 @@ -# src/dataloader/storage/notify_listener.py from __future__ import annotations import asyncio -import asyncpg from typing import Callable, Optional +import asyncpg + class PGNotifyListener: """ Прослушиватель PostgreSQL NOTIFY для канала 'dl_jobs'. """ - def __init__(self, dsn: str, queue: str, callback: Callable[[], None], stop_event: asyncio.Event): + + def __init__( + self, + dsn: str, + queue: str, + callback: Callable[[], None], + stop_event: asyncio.Event, + ): self._dsn = dsn self._queue = queue self._callback = callback diff --git a/src/dataloader/storage/repositories/__init__.py b/src/dataloader/storage/repositories/__init__.py index 314bfb6..f5ef597 100644 --- a/src/dataloader/storage/repositories/__init__.py +++ b/src/dataloader/storage/repositories/__init__.py @@ -1,8 +1,8 @@ -# src/dataloader/storage/repositories/__init__.py """ Репозитории для работы с базой данных. Организованы по доменам для масштабируемости. 
""" + from __future__ import annotations from .opu import OpuRepository diff --git a/src/dataloader/storage/repositories/opu.py b/src/dataloader/storage/repositories/opu.py index ec2dfd3..189a093 100644 --- a/src/dataloader/storage/repositories/opu.py +++ b/src/dataloader/storage/repositories/opu.py @@ -69,7 +69,6 @@ class OpuRepository: if not records: return 0 - # Получаем колонки для обновления (все кроме PK и технических) update_columns = { c.name for c in BriefDigitalCertificateOpu.__table__.columns diff --git a/src/dataloader/storage/repositories/queue.py b/src/dataloader/storage/repositories/queue.py index 2b3e0da..4508ac9 100644 --- a/src/dataloader/storage/repositories/queue.py +++ b/src/dataloader/storage/repositories/queue.py @@ -1,4 +1,3 @@ -# src/dataloader/storage/repositories/queue.py from __future__ import annotations from datetime import datetime, timedelta, timezone @@ -15,6 +14,7 @@ class QueueRepository: """ Репозиторий для работы с очередью задач и журналом событий. """ + def __init__(self, session: AsyncSession): self.s = session @@ -62,7 +62,9 @@ class QueueRepository: finished_at=None, ) self.s.add(row) - await self._append_event(req.job_id, req.queue, "queued", {"task": req.task}) + await self._append_event( + req.job_id, req.queue, "queued", {"task": req.task} + ) return req.job_id, "queued" async def get_status(self, job_id: str) -> Optional[JobStatus]: @@ -118,7 +120,9 @@ class QueueRepository: await self._append_event(job_id, job.queue, "cancel_requested", None) return True - async def claim_one(self, queue: str, claim_backoff_sec: int) -> Optional[dict[str, Any]]: + async def claim_one( + self, queue: str, claim_backoff_sec: int + ) -> Optional[dict[str, Any]]: """ Захватывает одну задачу из очереди с учётом блокировок и выставляет running. 
@@ -150,15 +154,21 @@ class QueueRepository: job.started_at = job.started_at or datetime.now(timezone.utc) job.attempt = int(job.attempt) + 1 job.heartbeat_at = datetime.now(timezone.utc) - job.lease_expires_at = datetime.now(timezone.utc) + timedelta(seconds=int(job.lease_ttl_sec)) + job.lease_expires_at = datetime.now(timezone.utc) + timedelta( + seconds=int(job.lease_ttl_sec) + ) ok = await self._try_advisory_lock(job.lock_key) if not ok: job.status = "queued" - job.available_at = datetime.now(timezone.utc) + timedelta(seconds=claim_backoff_sec) + job.available_at = datetime.now(timezone.utc) + timedelta( + seconds=claim_backoff_sec + ) return None - await self._append_event(job.job_id, job.queue, "picked", {"attempt": job.attempt}) + await self._append_event( + job.job_id, job.queue, "picked", {"attempt": job.attempt} + ) return { "job_id": job.job_id, @@ -192,10 +202,15 @@ class QueueRepository: q = ( update(DLJob) .where(DLJob.job_id == job_id, DLJob.status == "running") - .values(heartbeat_at=now, lease_expires_at=now + timedelta(seconds=int(ttl_sec))) + .values( + heartbeat_at=now, + lease_expires_at=now + timedelta(seconds=int(ttl_sec)), + ) ) await self.s.execute(q) - await self._append_event(job_id, await self._resolve_queue(job_id), "heartbeat", {"ttl": ttl_sec}) + await self._append_event( + job_id, await self._resolve_queue(job_id), "heartbeat", {"ttl": ttl_sec} + ) return True, cancel_requested async def finish_ok(self, job_id: str) -> None: @@ -215,7 +230,9 @@ class QueueRepository: await self._append_event(job_id, job.queue, "succeeded", None) await self._advisory_unlock(job.lock_key) - async def finish_fail_or_retry(self, job_id: str, err: str, is_canceled: bool = False) -> None: + async def finish_fail_or_retry( + self, job_id: str, err: str, is_canceled: bool = False + ) -> None: """ Помечает задачу как failed, canceled или возвращает в очередь с задержкой. 
@@ -239,16 +256,25 @@ class QueueRepository: can_retry = int(job.attempt) < int(job.max_attempts) if can_retry: job.status = "queued" - job.available_at = datetime.now(timezone.utc) + timedelta(seconds=30 * int(job.attempt)) + job.available_at = datetime.now(timezone.utc) + timedelta( + seconds=30 * int(job.attempt) + ) job.error = err job.lease_expires_at = None - await self._append_event(job_id, job.queue, "requeue", {"attempt": job.attempt, "error": err}) + await self._append_event( + job_id, + job.queue, + "requeue", + {"attempt": job.attempt, "error": err}, + ) else: job.status = "failed" job.error = err job.finished_at = datetime.now(timezone.utc) job.lease_expires_at = None - await self._append_event(job_id, job.queue, "failed", {"error": err}) + await self._append_event( + job_id, job.queue, "failed", {"error": err} + ) await self._advisory_unlock(job.lock_key) async def requeue_lost(self, now: Optional[datetime] = None) -> list[str]: @@ -293,7 +319,11 @@ class QueueRepository: Возвращает: ORM модель DLJob или None """ - r = await self.s.execute(select(DLJob).where(DLJob.job_id == job_id).with_for_update(skip_locked=True)) + r = await self.s.execute( + select(DLJob) + .where(DLJob.job_id == job_id) + .with_for_update(skip_locked=True) + ) return r.scalar_one_or_none() async def _resolve_queue(self, job_id: str) -> str: @@ -310,7 +340,9 @@ class QueueRepository: v = r.scalar_one_or_none() return v or "" - async def _append_event(self, job_id: str, queue: str, kind: str, payload: Optional[dict[str, Any]]) -> None: + async def _append_event( + self, job_id: str, queue: str, kind: str, payload: Optional[dict[str, Any]] + ) -> None: """ Добавляет запись в журнал событий. @@ -339,7 +371,9 @@ class QueueRepository: Возвращает: True, если блокировка получена """ - r = await self.s.execute(select(func.pg_try_advisory_lock(func.hashtext(lock_key)))) + r = await self.s.execute( + select(func.pg_try_advisory_lock(func.hashtext(lock_key))) + ) return bool(r.scalar()) async def _advisory_unlock(self, lock_key: str) -> None: diff --git a/src/dataloader/storage/repositories/quotes.py b/src/dataloader/storage/repositories/quotes.py index 308ad42..c9a8599 100644 --- a/src/dataloader/storage/repositories/quotes.py +++ b/src/dataloader/storage/repositories/quotes.py @@ -39,7 +39,9 @@ class QuotesRepository: result = await self.s.execute(stmt) return result.scalar_one_or_none() - async def get_or_create_section(self, name: str, params: dict | None = None) -> QuoteSection: + async def get_or_create_section( + self, name: str, params: dict | None = None + ) -> QuoteSection: """ Получить существующую секцию или создать новую. diff --git a/src/dataloader/storage/schemas/__init__.py b/src/dataloader/storage/schemas/__init__.py index c0d536d..0f0d421 100644 --- a/src/dataloader/storage/schemas/__init__.py +++ b/src/dataloader/storage/schemas/__init__.py @@ -1,8 +1,8 @@ -# src/dataloader/storage/schemas/__init__.py """ DTO (Data Transfer Objects) для слоя хранилища. Организованы по доменам для масштабируемости. 
""" + from __future__ import annotations from .queue import CreateJobRequest, JobStatus diff --git a/src/dataloader/storage/schemas/queue.py b/src/dataloader/storage/schemas/queue.py index be07545..4758566 100644 --- a/src/dataloader/storage/schemas/queue.py +++ b/src/dataloader/storage/schemas/queue.py @@ -1,4 +1,3 @@ -# src/dataloader/storage/schemas/queue.py from __future__ import annotations from dataclasses import dataclass @@ -11,6 +10,7 @@ class CreateJobRequest: """ DTO для создания задачи в очереди. """ + job_id: str queue: str task: str @@ -31,6 +31,7 @@ class JobStatus: """ DTO для статуса задачи. """ + job_id: str status: str attempt: int diff --git a/src/dataloader/workers/__init__.py b/src/dataloader/workers/__init__.py index 20cd673..4aeed8a 100644 --- a/src/dataloader/workers/__init__.py +++ b/src/dataloader/workers/__init__.py @@ -1,2 +1 @@ """Модуль воркеров для обработки задач.""" - diff --git a/src/dataloader/workers/base.py b/src/dataloader/workers/base.py index 3c809d6..d3b50f5 100644 --- a/src/dataloader/workers/base.py +++ b/src/dataloader/workers/base.py @@ -1,4 +1,3 @@ -# src/dataloader/workers/base.py from __future__ import annotations import asyncio @@ -9,8 +8,8 @@ from typing import AsyncIterator, Callable, Optional from dataloader.config import APP_CONFIG from dataloader.context import APP_CTX -from dataloader.storage.repositories import QueueRepository from dataloader.storage.notify_listener import PGNotifyListener +from dataloader.storage.repositories import QueueRepository from dataloader.workers.pipelines.registry import resolve as resolve_pipeline @@ -19,6 +18,7 @@ class WorkerConfig: """ Конфигурация воркера. """ + queue: str heartbeat_sec: int claim_backoff_sec: int @@ -28,6 +28,7 @@ class PGWorker: """ Базовый асинхронный воркер очереди Postgres. """ + def __init__(self, cfg: WorkerConfig, stop_event: asyncio.Event) -> None: self._cfg = cfg self._stop = stop_event @@ -46,14 +47,16 @@ class PGWorker: dsn=APP_CONFIG.pg.url, queue=self._cfg.queue, callback=lambda: self._notify_wakeup.set(), - stop_event=self._stop + stop_event=self._stop, ) try: await self._listener.start() except Exception as e: - self._log.warning(f"Failed to start LISTEN/NOTIFY, falling back to polling: {e}") + self._log.warning( + f"Failed to start LISTEN/NOTIFY, falling back to polling: {e}" + ) self._listener = None - + try: while not self._stop.is_set(): claimed = await self._claim_and_execute_once() @@ -69,24 +72,27 @@ class PGWorker: Ожидание появления задач через LISTEN/NOTIFY или с тайм-аутом. 
""" if self._listener: - # Используем LISTEN/NOTIFY с fallback на таймаут + done, pending = await asyncio.wait( - [asyncio.create_task(self._notify_wakeup.wait()), asyncio.create_task(self._stop.wait())], + [ + asyncio.create_task(self._notify_wakeup.wait()), + asyncio.create_task(self._stop.wait()), + ], return_when=asyncio.FIRST_COMPLETED, - timeout=timeout_sec + timeout=timeout_sec, ) - # Отменяем оставшиеся задачи + for task in pending: task.cancel() try: await task except asyncio.CancelledError: pass - # Очищаем событие, если оно было установлено + if self._notify_wakeup.is_set(): self._notify_wakeup.clear() else: - # Fallback на простой таймаут + try: await asyncio.wait_for(self._stop.wait(), timeout=timeout_sec) except asyncio.TimeoutError: @@ -110,30 +116,42 @@ class PGWorker: args = row["args"] try: - canceled = await self._execute_with_heartbeat(job_id, ttl, self._pipeline(task, args)) + canceled = await self._execute_with_heartbeat( + job_id, ttl, self._pipeline(task, args) + ) if canceled: - await repo.finish_fail_or_retry(job_id, "canceled by user", is_canceled=True) + await repo.finish_fail_or_retry( + job_id, "canceled by user", is_canceled=True + ) else: await repo.finish_ok(job_id) return True except asyncio.CancelledError: - await repo.finish_fail_or_retry(job_id, "cancelled by shutdown", is_canceled=True) + await repo.finish_fail_or_retry( + job_id, "cancelled by shutdown", is_canceled=True + ) raise except Exception as e: await repo.finish_fail_or_retry(job_id, str(e)) return True - async def _execute_with_heartbeat(self, job_id: str, ttl: int, it: AsyncIterator[None]) -> bool: + async def _execute_with_heartbeat( + self, job_id: str, ttl: int, it: AsyncIterator[None] + ) -> bool: """ Исполняет конвейер с поддержкой heartbeat. Возвращает True, если задача была отменена (cancel_requested). """ - next_hb = datetime.now(timezone.utc) + timedelta(seconds=self._cfg.heartbeat_sec) + next_hb = datetime.now(timezone.utc) + timedelta( + seconds=self._cfg.heartbeat_sec + ) async for _ in it: now = datetime.now(timezone.utc) if now >= next_hb: async with self._sm() as s_hb: - success, cancel_requested = await QueueRepository(s_hb).heartbeat(job_id, ttl) + success, cancel_requested = await QueueRepository(s_hb).heartbeat( + job_id, ttl + ) if cancel_requested: return True next_hb = now + timedelta(seconds=self._cfg.heartbeat_sec) diff --git a/src/dataloader/workers/manager.py b/src/dataloader/workers/manager.py index 6d4e722..37bf61f 100644 --- a/src/dataloader/workers/manager.py +++ b/src/dataloader/workers/manager.py @@ -1,4 +1,3 @@ -# src/dataloader/workers/manager.py from __future__ import annotations import asyncio @@ -16,6 +15,7 @@ class WorkerSpec: """ Конфигурация набора воркеров для очереди. """ + queue: str concurrency: int @@ -24,6 +24,7 @@ class WorkerManager: """ Управляет жизненным циклом асинхронных воркеров. 
""" + def __init__(self, specs: list[WorkerSpec]) -> None: self._log = APP_CTX.get_logger() self._specs = specs @@ -40,15 +41,22 @@ class WorkerManager: for spec in self._specs: for i in range(max(1, spec.concurrency)): - cfg = WorkerConfig(queue=spec.queue, heartbeat_sec=hb, claim_backoff_sec=backoff) - t = asyncio.create_task(PGWorker(cfg, self._stop).run(), name=f"worker:{spec.queue}:{i}") + cfg = WorkerConfig( + queue=spec.queue, heartbeat_sec=hb, claim_backoff_sec=backoff + ) + t = asyncio.create_task( + PGWorker(cfg, self._stop).run(), name=f"worker:{spec.queue}:{i}" + ) self._tasks.append(t) self._reaper_task = asyncio.create_task(self._reaper_loop(), name="reaper") self._log.info( "worker_manager.started", - extra={"specs": [spec.__dict__ for spec in self._specs], "total_tasks": len(self._tasks)}, + extra={ + "specs": [spec.__dict__ for spec in self._specs], + "total_tasks": len(self._tasks), + }, ) async def stop(self) -> None: diff --git a/src/dataloader/workers/pipelines/__init__.py b/src/dataloader/workers/pipelines/__init__.py index 9dd033c..bc6ed0e 100644 --- a/src/dataloader/workers/pipelines/__init__.py +++ b/src/dataloader/workers/pipelines/__init__.py @@ -1,6 +1,5 @@ """Модуль пайплайнов обработки задач.""" -# src/dataloader/workers/pipelines/__init__.py from __future__ import annotations import importlib diff --git a/src/dataloader/workers/pipelines/load_opu.py b/src/dataloader/workers/pipelines/load_opu.py index b4c5872..288512a 100644 --- a/src/dataloader/workers/pipelines/load_opu.py +++ b/src/dataloader/workers/pipelines/load_opu.py @@ -59,7 +59,6 @@ def _parse_jsonl_from_zst(file_path: Path, chunk_size: int = 10000): APP_CTX.logger.warning(f"Failed to parse JSON line: {e}") continue - # Process remaining buffer if buffer.strip(): try: record = orjson.loads(buffer) @@ -85,11 +84,9 @@ def _convert_record(raw: dict[str, Any]) -> dict[str, Any]: """ result = raw.copy() - # Преобразуем actdate из ISO строки в date if "actdate" in result and isinstance(result["actdate"], str): result["actdate"] = datetime.fromisoformat(result["actdate"]).date() - # Преобразуем wf_load_dttm из ISO строки в datetime if "wf_load_dttm" in result and isinstance(result["wf_load_dttm"], str): result["wf_load_dttm"] = datetime.fromisoformat(result["wf_load_dttm"]) @@ -118,18 +115,15 @@ async def load_opu(args: dict) -> AsyncIterator[None]: logger = APP_CTX.logger logger.info("Starting OPU ETL pipeline") - # Шаг 1: Запуск экспорта interface = get_gmap2brief_interface() job_id = await interface.start_export() logger.info(f"OPU export job started: {job_id}") yield - # Шаг 2: Ожидание завершения status = await interface.wait_for_completion(job_id) logger.info(f"OPU export completed: {status.total_rows} rows") yield - # Шаг 3: Скачивание архива with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) archive_path = temp_path / f"opu_export_{job_id}.jsonl.zst" @@ -138,7 +132,6 @@ async def load_opu(args: dict) -> AsyncIterator[None]: logger.info(f"OPU archive downloaded: {archive_path.stat().st_size:,} bytes") yield - # Шаг 4: Truncate таблицы async with APP_CTX.sessionmaker() as session: repo = OpuRepository(session) await repo.truncate() @@ -146,7 +139,6 @@ async def load_opu(args: dict) -> AsyncIterator[None]: logger.info("OPU table truncated") yield - # Шаг 5: Загрузка данных стримингово total_inserted = 0 batch_num = 0 @@ -154,17 +146,16 @@ async def load_opu(args: dict) -> AsyncIterator[None]: for batch in _parse_jsonl_from_zst(archive_path, chunk_size=5000): batch_num += 1 - # 
Конвертируем записи converted = [_convert_record(rec) for rec in batch] - # Вставляем батч inserted = await repo.bulk_insert(converted) await session.commit() total_inserted += inserted - logger.debug(f"Batch {batch_num}: inserted {inserted} rows (total: {total_inserted})") + logger.debug( + f"Batch {batch_num}: inserted {inserted} rows (total: {total_inserted})" + ) - # Heartbeat после каждого батча yield logger.info(f"OPU ETL completed: {total_inserted} rows inserted") diff --git a/src/dataloader/workers/pipelines/load_tenera.py b/src/dataloader/workers/pipelines/load_tenera.py index 77df2cb..9340bdc 100644 --- a/src/dataloader/workers/pipelines/load_tenera.py +++ b/src/dataloader/workers/pipelines/load_tenera.py @@ -59,7 +59,9 @@ def _parse_ts_to_datetime(ts: str) -> datetime | None: return None -def _build_value_row(source: str, dt: datetime, point: Any) -> dict[str, Any] | None: # noqa: C901 +def _build_value_row( + source: str, dt: datetime, point: Any +) -> dict[str, Any] | None: # noqa: C901 """Строит строку для `quotes_values` по источнику и типу точки.""" if isinstance(point, int): return {"dt": dt, "key": point} @@ -82,7 +84,10 @@ def _build_value_row(source: str, dt: datetime, point: Any) -> dict[str, Any] | if isinstance(deep_inner, InvestingCandlestick): return { "dt": dt, - "price_o": _to_float(getattr(deep_inner, "open_", None) or getattr(deep_inner, "open", None)), + "price_o": _to_float( + getattr(deep_inner, "open_", None) + or getattr(deep_inner, "open", None) + ), "price_h": _to_float(deep_inner.high), "price_l": _to_float(deep_inner.low), "price_c": _to_float(deep_inner.close), @@ -92,12 +97,16 @@ def _build_value_row(source: str, dt: datetime, point: Any) -> dict[str, Any] | if isinstance(inner, TradingViewTimePoint | SgxTimePoint): return { "dt": dt, - "price_o": _to_float(getattr(inner, "open_", None) or getattr(inner, "open", None)), + "price_o": _to_float( + getattr(inner, "open_", None) or getattr(inner, "open", None) + ), "price_h": _to_float(inner.high), "price_l": _to_float(inner.low), "price_c": _to_float(inner.close), "volume": _to_float( - getattr(inner, "volume", None) or getattr(inner, "interest", None) or getattr(inner, "value", None) + getattr(inner, "volume", None) + or getattr(inner, "interest", None) + or getattr(inner, "value", None) ), } @@ -132,7 +141,9 @@ def _build_value_row(source: str, dt: datetime, point: Any) -> dict[str, Any] | "dt": dt, "value_last": _to_float(deep_inner.last), "value_previous": _to_float(deep_inner.previous), - "unit": str(deep_inner.unit) if deep_inner.unit is not None else None, + "unit": ( + str(deep_inner.unit) if deep_inner.unit is not None else None + ), } if isinstance(deep_inner, TradingEconomicsStringPercent): @@ -218,7 +229,14 @@ async def load_tenera(args: dict) -> AsyncIterator[None]: async with APP_CTX.sessionmaker() as session: repo = QuotesRepository(session) - for source_name in ("cbr", "investing", "sgx", "tradingeconomics", "bloomberg", "trading_view"): + for source_name in ( + "cbr", + "investing", + "sgx", + "tradingeconomics", + "bloomberg", + "trading_view", + ): source_data = getattr(data, source_name) if not source_data: continue diff --git a/src/dataloader/workers/pipelines/noop.py b/src/dataloader/workers/pipelines/noop.py index 4b20c00..4df8be9 100644 --- a/src/dataloader/workers/pipelines/noop.py +++ b/src/dataloader/workers/pipelines/noop.py @@ -1,4 +1,3 @@ -# src/dataloader/workers/pipelines/noop.py from __future__ import annotations import asyncio diff --git 
a/src/dataloader/workers/pipelines/registry.py b/src/dataloader/workers/pipelines/registry.py index e7e0fc5..a92b2e5 100644 --- a/src/dataloader/workers/pipelines/registry.py +++ b/src/dataloader/workers/pipelines/registry.py @@ -1,4 +1,3 @@ -# src/dataloader/workers/pipelines/registry.py from __future__ import annotations from typing import Any, Callable, Dict, Iterable @@ -6,13 +5,17 @@ from typing import Any, Callable, Dict, Iterable _Registry: Dict[str, Callable[[dict[str, Any]], Any]] = {} -def register(task: str) -> Callable[[Callable[[dict[str, Any]], Any]], Callable[[dict[str, Any]], Any]]: +def register( + task: str, +) -> Callable[[Callable[[dict[str, Any]], Any]], Callable[[dict[str, Any]], Any]]: """ Регистрирует обработчик пайплайна под именем задачи. """ + def _wrap(fn: Callable[[dict[str, Any]], Any]) -> Callable[[dict[str, Any]], Any]: _Registry[task] = fn return fn + return _wrap @@ -22,8 +25,8 @@ def resolve(task: str) -> Callable[[dict[str, Any]], Any]: """ try: return _Registry[task] - except KeyError: - raise KeyError(f"pipeline not found: {task}") + except KeyError as err: + raise KeyError(f"pipeline not found: {task}") from err def tasks() -> Iterable[str]: diff --git a/src/dataloader/workers/reaper.py b/src/dataloader/workers/reaper.py index 0b72ddd..ca9f670 100644 --- a/src/dataloader/workers/reaper.py +++ b/src/dataloader/workers/reaper.py @@ -1,7 +1,7 @@ -# src/dataloader/workers/reaper.py from __future__ import annotations from typing import Sequence + from sqlalchemy.ext.asyncio import AsyncSession from dataloader.storage.repositories import QueueRepository diff --git a/tests/conftest.py b/tests/conftest.py index 212645d..411ccd1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,3 @@ -# tests/conftest.py from __future__ import annotations import asyncio @@ -9,23 +8,23 @@ from uuid import uuid4 import pytest import pytest_asyncio from dotenv import load_dotenv -from httpx import AsyncClient, ASGITransport +from httpx import ASGITransport, AsyncClient from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker -load_dotenv() - from dataloader.api import app_main from dataloader.config import APP_CONFIG -from dataloader.context import APP_CTX, get_session -from dataloader.storage.models import Base -from dataloader.storage.engine import create_engine, create_sessionmaker +from dataloader.context import get_session +from dataloader.storage.engine import create_engine + +load_dotenv() if sys.platform == "win32": asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) pytestmark = pytest.mark.asyncio + @pytest_asyncio.fixture(scope="function") async def db_engine() -> AsyncGenerator[AsyncEngine, None]: """ @@ -39,15 +38,15 @@ async def db_engine() -> AsyncGenerator[AsyncEngine, None]: await engine.dispose() - - @pytest_asyncio.fixture(scope="function") async def db_session(db_engine: AsyncEngine) -> AsyncGenerator[AsyncSession, None]: """ Предоставляет сессию БД для каждого теста. НЕ использует транзакцию, чтобы работали advisory locks. """ - sessionmaker = async_sessionmaker(bind=db_engine, expire_on_commit=False, class_=AsyncSession) + sessionmaker = async_sessionmaker( + bind=db_engine, expire_on_commit=False, class_=AsyncSession + ) async with sessionmaker() as session: yield session await session.rollback() @@ -69,6 +68,7 @@ async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]: """ HTTP клиент для тестирования API. 
""" + async def override_get_session() -> AsyncGenerator[AsyncSession, None]: yield db_session diff --git a/tests/integration_tests/__init__.py b/tests/integration_tests/__init__.py index d6c43f2..8b13789 100644 --- a/tests/integration_tests/__init__.py +++ b/tests/integration_tests/__init__.py @@ -1 +1 @@ -# tests/integration_tests/__init__.py + diff --git a/tests/integration_tests/test_api_endpoints.py b/tests/integration_tests/test_api_endpoints.py index 216cccc..bfd1375 100644 --- a/tests/integration_tests/test_api_endpoints.py +++ b/tests/integration_tests/test_api_endpoints.py @@ -1,4 +1,3 @@ -# tests/integration_tests/test_api_endpoints.py from __future__ import annotations import pytest diff --git a/tests/integration_tests/test_api_router_not_found.py b/tests/integration_tests/test_api_router_not_found.py index ba96dec..a90375b 100644 --- a/tests/integration_tests/test_api_router_not_found.py +++ b/tests/integration_tests/test_api_router_not_found.py @@ -1,11 +1,11 @@ -# tests/unit/test_api_router_not_found.py from __future__ import annotations -import pytest -from uuid import uuid4, UUID +from uuid import UUID, uuid4 + +import pytest -from dataloader.api.v1.router import get_status, cancel_job from dataloader.api.v1.exceptions import JobNotFoundError +from dataloader.api.v1.router import cancel_job, get_status from dataloader.api.v1.schemas import JobStatusResponse diff --git a/tests/integration_tests/test_api_router_success.py b/tests/integration_tests/test_api_router_success.py index a19246a..fb18827 100644 --- a/tests/integration_tests/test_api_router_success.py +++ b/tests/integration_tests/test_api_router_success.py @@ -1,11 +1,11 @@ -# tests/unit/test_api_router_success.py from __future__ import annotations -import pytest -from uuid import uuid4, UUID from datetime import datetime, timezone +from uuid import UUID, uuid4 -from dataloader.api.v1.router import get_status, cancel_job +import pytest + +from dataloader.api.v1.router import cancel_job, get_status from dataloader.api.v1.schemas import JobStatusResponse diff --git a/tests/integration_tests/test_queue_repository.py b/tests/integration_tests/test_queue_repository.py index 028755d..ebda277 100644 --- a/tests/integration_tests/test_queue_repository.py +++ b/tests/integration_tests/test_queue_repository.py @@ -1,7 +1,6 @@ -# tests/integration_tests/test_queue_repository.py from __future__ import annotations -from datetime import datetime, timezone, timedelta +from datetime import datetime, timedelta, timezone from uuid import uuid4 import pytest @@ -10,7 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from dataloader.storage.models import DLJob from dataloader.storage.repositories import QueueRepository -from dataloader.storage.schemas import CreateJobRequest, JobStatus +from dataloader.storage.schemas import CreateJobRequest @pytest.mark.integration @@ -445,6 +444,7 @@ class TestQueueRepository: await repo.claim_one(queue_name, claim_backoff_sec=15) import asyncio + await asyncio.sleep(2) requeued = await repo.requeue_lost() @@ -500,7 +500,9 @@ class TestQueueRepository: assert st is not None assert st.status == "queued" - row = (await db_session.execute(select(DLJob).where(DLJob.job_id == job_id))).scalar_one() + row = ( + await db_session.execute(select(DLJob).where(DLJob.job_id == job_id)) + ).scalar_one() assert row.available_at >= before + timedelta(seconds=15) assert row.available_at <= after + timedelta(seconds=60) @@ -569,7 +571,9 @@ class TestQueueRepository: await repo.create_or_get(req) await 
repo.claim_one(queue_name, claim_backoff_sec=5) - await repo.finish_fail_or_retry(job_id, err="Canceled by test", is_canceled=True) + await repo.finish_fail_or_retry( + job_id, err="Canceled by test", is_canceled=True + ) st = await repo.get_status(job_id) assert st is not None diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py index f28a885..8b13789 100644 --- a/tests/unit/__init__.py +++ b/tests/unit/__init__.py @@ -1 +1 @@ -# tests/unit/__init__.py + diff --git a/tests/unit/test_api_service.py b/tests/unit/test_api_service.py index 31ae26a..e670915 100644 --- a/tests/unit/test_api_service.py +++ b/tests/unit/test_api_service.py @@ -1,13 +1,13 @@ -# tests/unit/test_api_service.py from __future__ import annotations from datetime import datetime, timezone -from uuid import UUID from unittest.mock import AsyncMock, Mock, patch +from uuid import UUID + import pytest -from dataloader.api.v1.service import JobsService from dataloader.api.v1.schemas import TriggerJobRequest +from dataloader.api.v1.service import JobsService from dataloader.storage.schemas import JobStatus @@ -38,15 +38,19 @@ class TestJobsService: """ mock_session = AsyncMock() - with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \ - patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls, \ - patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id: + with ( + patch("dataloader.api.v1.service.get_logger") as mock_get_logger, + patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls, + patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id, + ): mock_get_logger.return_value = Mock() mock_new_job_id.return_value = UUID("12345678-1234-5678-1234-567812345678") mock_repo = Mock() - mock_repo.create_or_get = AsyncMock(return_value=("12345678-1234-5678-1234-567812345678", "queued")) + mock_repo.create_or_get = AsyncMock( + return_value=("12345678-1234-5678-1234-567812345678", "queued") + ) mock_repo_cls.return_value = mock_repo service = JobsService(mock_session) @@ -58,7 +62,7 @@ class TestJobsService: lock_key="lock_1", priority=100, max_attempts=5, - lease_ttl_sec=60 + lease_ttl_sec=60, ) response = await service.trigger(req) @@ -74,15 +78,19 @@ class TestJobsService: """ mock_session = AsyncMock() - with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \ - patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls, \ - patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id: + with ( + patch("dataloader.api.v1.service.get_logger") as mock_get_logger, + patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls, + patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id, + ): mock_get_logger.return_value = Mock() mock_new_job_id.return_value = UUID("12345678-1234-5678-1234-567812345678") mock_repo = Mock() - mock_repo.create_or_get = AsyncMock(return_value=("12345678-1234-5678-1234-567812345678", "queued")) + mock_repo.create_or_get = AsyncMock( + return_value=("12345678-1234-5678-1234-567812345678", "queued") + ) mock_repo_cls.return_value = mock_repo service = JobsService(mock_session) @@ -95,7 +103,7 @@ class TestJobsService: lock_key="lock_1", priority=100, max_attempts=5, - lease_ttl_sec=60 + lease_ttl_sec=60, ) response = await service.trigger(req) @@ -112,15 +120,19 @@ class TestJobsService: """ mock_session = AsyncMock() - with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \ - patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls, \ - 
patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id: + with ( + patch("dataloader.api.v1.service.get_logger") as mock_get_logger, + patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls, + patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id, + ): mock_get_logger.return_value = Mock() mock_new_job_id.return_value = UUID("12345678-1234-5678-1234-567812345678") mock_repo = Mock() - mock_repo.create_or_get = AsyncMock(return_value=("12345678-1234-5678-1234-567812345678", "queued")) + mock_repo.create_or_get = AsyncMock( + return_value=("12345678-1234-5678-1234-567812345678", "queued") + ) mock_repo_cls.return_value = mock_repo service = JobsService(mock_session) @@ -135,10 +147,10 @@ class TestJobsService: available_at=future_time, priority=100, max_attempts=5, - lease_ttl_sec=60 + lease_ttl_sec=60, ) - response = await service.trigger(req) + await service.trigger(req) call_args = mock_repo.create_or_get.call_args[0][0] assert call_args.available_at == future_time @@ -150,15 +162,19 @@ class TestJobsService: """ mock_session = AsyncMock() - with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \ - patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls, \ - patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id: + with ( + patch("dataloader.api.v1.service.get_logger") as mock_get_logger, + patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls, + patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id, + ): mock_get_logger.return_value = Mock() mock_new_job_id.return_value = UUID("12345678-1234-5678-1234-567812345678") mock_repo = Mock() - mock_repo.create_or_get = AsyncMock(return_value=("12345678-1234-5678-1234-567812345678", "queued")) + mock_repo.create_or_get = AsyncMock( + return_value=("12345678-1234-5678-1234-567812345678", "queued") + ) mock_repo_cls.return_value = mock_repo service = JobsService(mock_session) @@ -173,10 +189,10 @@ class TestJobsService: consumer_group="test_group", priority=100, max_attempts=5, - lease_ttl_sec=60 + lease_ttl_sec=60, ) - response = await service.trigger(req) + await service.trigger(req) call_args = mock_repo.create_or_get.call_args[0][0] assert call_args.partition_key == "partition_1" @@ -190,8 +206,10 @@ class TestJobsService: """ mock_session = AsyncMock() - with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \ - patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls: + with ( + patch("dataloader.api.v1.service.get_logger") as mock_get_logger, + patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls, + ): mock_get_logger.return_value = Mock() @@ -204,7 +222,7 @@ class TestJobsService: finished_at=None, heartbeat_at=datetime(2025, 1, 1, 12, 5, 0, tzinfo=timezone.utc), error=None, - progress={"step": 1} + progress={"step": 1}, ) mock_repo.get_status = AsyncMock(return_value=mock_status) mock_repo_cls.return_value = mock_repo @@ -227,8 +245,10 @@ class TestJobsService: """ mock_session = AsyncMock() - with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \ - patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls: + with ( + patch("dataloader.api.v1.service.get_logger") as mock_get_logger, + patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls, + ): mock_get_logger.return_value = Mock() @@ -250,8 +270,10 @@ class TestJobsService: """ mock_session = AsyncMock() - with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \ - 
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls: + with ( + patch("dataloader.api.v1.service.get_logger") as mock_get_logger, + patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls, + ): mock_get_logger.return_value = Mock() @@ -265,7 +287,7 @@ class TestJobsService: finished_at=None, heartbeat_at=datetime(2025, 1, 1, 12, 5, 0, tzinfo=timezone.utc), error=None, - progress={} + progress={}, ) mock_repo.get_status = AsyncMock(return_value=mock_status) mock_repo_cls.return_value = mock_repo @@ -287,8 +309,10 @@ class TestJobsService: """ mock_session = AsyncMock() - with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \ - patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls: + with ( + patch("dataloader.api.v1.service.get_logger") as mock_get_logger, + patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls, + ): mock_get_logger.return_value = Mock() @@ -312,8 +336,10 @@ class TestJobsService: """ mock_session = AsyncMock() - with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \ - patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls: + with ( + patch("dataloader.api.v1.service.get_logger") as mock_get_logger, + patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls, + ): mock_get_logger.return_value = Mock() @@ -326,7 +352,7 @@ class TestJobsService: finished_at=None, heartbeat_at=None, error=None, - progress=None + progress=None, ) mock_repo.get_status = AsyncMock(return_value=mock_status) mock_repo_cls.return_value = mock_repo diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 0c82df8..57b2e6d 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -1,18 +1,18 @@ -# tests/unit/test_config.py from __future__ import annotations import json from logging import DEBUG, INFO from unittest.mock import patch + import pytest from dataloader.config import ( - BaseAppSettings, AppSettings, + BaseAppSettings, LogSettings, PGSettings, - WorkerSettings, Secrets, + WorkerSettings, ) @@ -81,12 +81,15 @@ class TestAppSettings: """ Тест загрузки из переменных окружения. """ - with patch.dict("os.environ", { - "APP_HOST": "127.0.0.1", - "APP_PORT": "9000", - "PROJECT_NAME": "TestProject", - "TIMEZONE": "UTC" - }): + with patch.dict( + "os.environ", + { + "APP_HOST": "127.0.0.1", + "APP_PORT": "9000", + "PROJECT_NAME": "TestProject", + "TIMEZONE": "UTC", + }, + ): settings = AppSettings() assert settings.app_host == "127.0.0.1" @@ -136,7 +139,9 @@ class TestLogSettings: """ Тест свойства log_file_abs_path. """ - with patch.dict("os.environ", {"LOG_PATH": "/var/log", "LOG_FILE_NAME": "test.log"}): + with patch.dict( + "os.environ", {"LOG_PATH": "/var/log", "LOG_FILE_NAME": "test.log"} + ): settings = LogSettings() assert "test.log" in settings.log_file_abs_path @@ -146,7 +151,10 @@ class TestLogSettings: """ Тест свойства metric_file_abs_path. """ - with patch.dict("os.environ", {"METRIC_PATH": "/var/metrics", "METRIC_FILE_NAME": "metrics.log"}): + with patch.dict( + "os.environ", + {"METRIC_PATH": "/var/metrics", "METRIC_FILE_NAME": "metrics.log"}, + ): settings = LogSettings() assert "metrics.log" in settings.metric_file_abs_path @@ -156,7 +164,10 @@ class TestLogSettings: """ Тест свойства audit_file_abs_path. 
""" - with patch.dict("os.environ", {"AUDIT_LOG_PATH": "/var/audit", "AUDIT_LOG_FILE_NAME": "audit.log"}): + with patch.dict( + "os.environ", + {"AUDIT_LOG_PATH": "/var/audit", "AUDIT_LOG_FILE_NAME": "audit.log"}, + ): settings = LogSettings() assert "audit.log" in settings.audit_file_abs_path @@ -208,29 +219,37 @@ class TestPGSettings: """ Тест формирования строки подключения. """ - with patch.dict("os.environ", { - "PG_HOST": "db.example.com", - "PG_PORT": "5433", - "PG_USER": "testuser", - "PG_PASSWORD": "testpass", - "PG_DATABASE": "testdb" - }): + with patch.dict( + "os.environ", + { + "PG_HOST": "db.example.com", + "PG_PORT": "5433", + "PG_USER": "testuser", + "PG_PASSWORD": "testpass", + "PG_DATABASE": "testdb", + }, + ): settings = PGSettings() - expected = "postgresql+asyncpg://testuser:testpass@db.example.com:5433/testdb" + expected = ( + "postgresql+asyncpg://testuser:testpass@db.example.com:5433/testdb" + ) assert settings.url == expected def test_url_property_with_empty_password(self): """ Тест строки подключения с пустым паролем. """ - with patch.dict("os.environ", { - "PG_HOST": "localhost", - "PG_PORT": "5432", - "PG_USER": "postgres", - "PG_PASSWORD": "", - "PG_DATABASE": "testdb" - }): + with patch.dict( + "os.environ", + { + "PG_HOST": "localhost", + "PG_PORT": "5432", + "PG_USER": "postgres", + "PG_PASSWORD": "", + "PG_DATABASE": "testdb", + }, + ): settings = PGSettings() expected = "postgresql+asyncpg://postgres:@localhost:5432/testdb" @@ -240,15 +259,18 @@ class TestPGSettings: """ Тест загрузки из переменных окружения. """ - with patch.dict("os.environ", { - "PG_HOST": "testhost", - "PG_PORT": "5433", - "PG_USER": "testuser", - "PG_PASSWORD": "testpass", - "PG_DATABASE": "testdb", - "PG_SCHEMA_QUEUE": "queue_schema", - "PG_POOL_SIZE": "20" - }): + with patch.dict( + "os.environ", + { + "PG_HOST": "testhost", + "PG_PORT": "5433", + "PG_USER": "testuser", + "PG_PASSWORD": "testpass", + "PG_DATABASE": "testdb", + "PG_SCHEMA_QUEUE": "queue_schema", + "PG_POOL_SIZE": "20", + }, + ): settings = PGSettings() assert settings.host == "testhost" @@ -292,10 +314,12 @@ class TestWorkerSettings: """ Тест парсинга валидного JSON. """ - workers_json = json.dumps([ - {"queue": "queue1", "concurrency": 2}, - {"queue": "queue2", "concurrency": 3} - ]) + workers_json = json.dumps( + [ + {"queue": "queue1", "concurrency": 2}, + {"queue": "queue2", "concurrency": 3}, + ] + ) with patch.dict("os.environ", {"WORKERS_JSON": workers_json}): settings = WorkerSettings() @@ -311,12 +335,14 @@ class TestWorkerSettings: """ Тест фильтрации не-словарей из JSON. 
""" - workers_json = json.dumps([ - {"queue": "queue1", "concurrency": 2}, - "invalid_item", - 123, - {"queue": "queue2", "concurrency": 3} - ]) + workers_json = json.dumps( + [ + {"queue": "queue1", "concurrency": 2}, + "invalid_item", + 123, + {"queue": "queue2", "concurrency": 3}, + ] + ) with patch.dict("os.environ", {"WORKERS_JSON": workers_json}): settings = WorkerSettings() diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py index 11aa5bd..d12f3b5 100644 --- a/tests/unit/test_context.py +++ b/tests/unit/test_context.py @@ -1,7 +1,7 @@ -# tests/unit/test_context.py from __future__ import annotations from unittest.mock import AsyncMock, Mock, patch + import pytest from dataloader.context import AppContext, get_session @@ -71,10 +71,16 @@ class TestAppContext: mock_engine = Mock() mock_sm = Mock() - with patch("dataloader.logger.logger.setup_logging") as mock_setup_logging, \ - patch("dataloader.storage.engine.create_engine", return_value=mock_engine) as mock_create_engine, \ - patch("dataloader.storage.engine.create_sessionmaker", return_value=mock_sm) as mock_create_sm, \ - patch("dataloader.context.APP_CONFIG") as mock_config: + with ( + patch("dataloader.logger.logger.setup_logging") as mock_setup_logging, + patch( + "dataloader.storage.engine.create_engine", return_value=mock_engine + ) as mock_create_engine, + patch( + "dataloader.storage.engine.create_sessionmaker", return_value=mock_sm + ) as mock_create_sm, + patch("dataloader.context.APP_CONFIG") as mock_config, + ): mock_config.pg.url = "postgresql://test" @@ -194,7 +200,7 @@ class TestGetSession: with patch("dataloader.context.APP_CTX") as mock_ctx: mock_ctx.sessionmaker = mock_sm - async for session in get_session(): + async for _session in get_session(): pass assert mock_exit.call_count == 1 diff --git a/tests/unit/test_notify_listener.py b/tests/unit/test_notify_listener.py index 2fdc6d2..d70eb60 100644 --- a/tests/unit/test_notify_listener.py +++ b/tests/unit/test_notify_listener.py @@ -1,8 +1,8 @@ -# tests/unit/test_notify_listener.py from __future__ import annotations import asyncio from unittest.mock import AsyncMock, Mock, patch + import pytest from dataloader.storage.notify_listener import PGNotifyListener @@ -25,7 +25,7 @@ class TestPGNotifyListener: dsn="postgresql://test", queue="test_queue", callback=callback, - stop_event=stop_event + stop_event=stop_event, ) assert listener._dsn == "postgresql://test" @@ -47,14 +47,16 @@ class TestPGNotifyListener: dsn="postgresql://test", queue="test_queue", callback=callback, - stop_event=stop_event + stop_event=stop_event, ) mock_conn = AsyncMock() mock_conn.execute = AsyncMock() mock_conn.add_listener = AsyncMock() - with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn): + with patch( + "dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn + ): await listener.start() assert listener._conn == mock_conn @@ -76,14 +78,16 @@ class TestPGNotifyListener: dsn="postgresql+asyncpg://test", queue="test_queue", callback=callback, - stop_event=stop_event + stop_event=stop_event, ) mock_conn = AsyncMock() mock_conn.execute = AsyncMock() mock_conn.add_listener = AsyncMock() - with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn) as mock_connect: + with patch( + "dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn + ) as mock_connect: await listener.start() mock_connect.assert_called_once_with("postgresql://test") @@ -102,14 +106,16 @@ class 
TestPGNotifyListener: dsn="postgresql://test", queue="test_queue", callback=callback, - stop_event=stop_event + stop_event=stop_event, ) mock_conn = AsyncMock() mock_conn.execute = AsyncMock() mock_conn.add_listener = AsyncMock() - with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn): + with patch( + "dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn + ): await listener.start() handler = listener._on_notify_handler @@ -131,14 +137,16 @@ class TestPGNotifyListener: dsn="postgresql://test", queue="test_queue", callback=callback, - stop_event=stop_event + stop_event=stop_event, ) mock_conn = AsyncMock() mock_conn.execute = AsyncMock() mock_conn.add_listener = AsyncMock() - with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn): + with patch( + "dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn + ): await listener.start() handler = listener._on_notify_handler @@ -160,14 +168,16 @@ class TestPGNotifyListener: dsn="postgresql://test", queue="test_queue", callback=callback, - stop_event=stop_event + stop_event=stop_event, ) mock_conn = AsyncMock() mock_conn.execute = AsyncMock() mock_conn.add_listener = AsyncMock() - with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn): + with patch( + "dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn + ): await listener.start() handler = listener._on_notify_handler @@ -189,14 +199,16 @@ class TestPGNotifyListener: dsn="postgresql://test", queue="test_queue", callback=callback, - stop_event=stop_event + stop_event=stop_event, ) mock_conn = AsyncMock() mock_conn.execute = AsyncMock() mock_conn.add_listener = AsyncMock() - with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn): + with patch( + "dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn + ): await listener.start() handler = listener._on_notify_handler @@ -218,7 +230,7 @@ class TestPGNotifyListener: dsn="postgresql://test", queue="test_queue", callback=callback, - stop_event=stop_event + stop_event=stop_event, ) mock_conn = AsyncMock() @@ -227,7 +239,9 @@ class TestPGNotifyListener: mock_conn.remove_listener = AsyncMock() mock_conn.close = AsyncMock() - with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn): + with patch( + "dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn + ): await listener.start() assert listener._task is not None @@ -250,7 +264,7 @@ class TestPGNotifyListener: dsn="postgresql://test", queue="test_queue", callback=callback, - stop_event=stop_event + stop_event=stop_event, ) mock_conn = AsyncMock() @@ -259,7 +273,9 @@ class TestPGNotifyListener: mock_conn.remove_listener = AsyncMock() mock_conn.close = AsyncMock() - with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn): + with patch( + "dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn + ): await listener.start() await listener.stop() @@ -280,7 +296,7 @@ class TestPGNotifyListener: dsn="postgresql://test", queue="test_queue", callback=callback, - stop_event=stop_event + stop_event=stop_event, ) mock_conn = AsyncMock() @@ -289,7 +305,9 @@ class TestPGNotifyListener: mock_conn.remove_listener = AsyncMock() mock_conn.close = AsyncMock() - with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn): + with patch( + 
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn + ): await listener.start() stop_event.set() @@ -311,7 +329,7 @@ class TestPGNotifyListener: dsn="postgresql://test", queue="test_queue", callback=callback, - stop_event=stop_event + stop_event=stop_event, ) mock_conn = AsyncMock() @@ -320,7 +338,9 @@ class TestPGNotifyListener: mock_conn.remove_listener = AsyncMock(side_effect=Exception("Remove error")) mock_conn.close = AsyncMock() - with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn): + with patch( + "dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn + ): await listener.start() await listener.stop() @@ -340,7 +360,7 @@ class TestPGNotifyListener: dsn="postgresql://test", queue="test_queue", callback=callback, - stop_event=stop_event + stop_event=stop_event, ) mock_conn = AsyncMock() @@ -349,7 +369,9 @@ class TestPGNotifyListener: mock_conn.remove_listener = AsyncMock() mock_conn.close = AsyncMock(side_effect=Exception("Close error")) - with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn): + with patch( + "dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn + ): await listener.start() await listener.stop() @@ -368,7 +390,7 @@ class TestPGNotifyListener: dsn="postgresql://test", queue="test_queue", callback=callback, - stop_event=stop_event + stop_event=stop_event, ) await listener.stop() diff --git a/tests/unit/test_pipeline_registry.py b/tests/unit/test_pipeline_registry.py index 50c8a67..764aff9 100644 --- a/tests/unit/test_pipeline_registry.py +++ b/tests/unit/test_pipeline_registry.py @@ -1,9 +1,8 @@ -# tests/unit/test_pipeline_registry.py from __future__ import annotations import pytest -from dataloader.workers.pipelines.registry import register, resolve, tasks, _Registry +from dataloader.workers.pipelines.registry import _Registry, register, resolve, tasks @pytest.mark.unit @@ -22,6 +21,7 @@ class TestPipelineRegistry: """ Тест регистрации пайплайна. """ + @register("test.task") def test_pipeline(args: dict): return "result" @@ -33,6 +33,7 @@ class TestPipelineRegistry: """ Тест получения зарегистрированного пайплайна. """ + @register("test.resolve") def test_pipeline(args: dict): return "resolved" @@ -54,6 +55,7 @@ class TestPipelineRegistry: """ Тест получения списка зарегистрированных задач. """ + @register("task1") def pipeline1(args: dict): pass @@ -70,6 +72,7 @@ class TestPipelineRegistry: """ Тест перезаписи существующего пайплайна. 
""" + @register("overwrite.task") def first_pipeline(args: dict): return "first" diff --git a/tests/unit/test_workers_base.py b/tests/unit/test_workers_base.py index 383c920..57c1c02 100644 --- a/tests/unit/test_workers_base.py +++ b/tests/unit/test_workers_base.py @@ -1,9 +1,8 @@ -# tests/unit/test_workers_base.py from __future__ import annotations import asyncio -from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock, Mock, patch + import pytest from dataloader.workers.base import PGWorker, WorkerConfig @@ -41,9 +40,11 @@ class TestPGWorker: cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=1) stop_event = asyncio.Event() - with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \ - patch("dataloader.workers.base.APP_CONFIG") as mock_cfg, \ - patch("dataloader.workers.base.PGNotifyListener") as mock_listener_cls: + with ( + patch("dataloader.workers.base.APP_CTX") as mock_ctx, + patch("dataloader.workers.base.APP_CONFIG") as mock_cfg, + patch("dataloader.workers.base.PGNotifyListener") as mock_listener_cls, + ): mock_ctx.get_logger.return_value = Mock() mock_ctx.sessionmaker = Mock() @@ -64,7 +65,9 @@ class TestPGWorker: stop_event.set() return False - with patch.object(worker, "_claim_and_execute_once", side_effect=mock_claim): + with patch.object( + worker, "_claim_and_execute_once", side_effect=mock_claim + ): await worker.run() assert mock_listener.start.call_count == 1 @@ -78,9 +81,11 @@ class TestPGWorker: cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=1) stop_event = asyncio.Event() - with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \ - patch("dataloader.workers.base.APP_CONFIG") as mock_cfg, \ - patch("dataloader.workers.base.PGNotifyListener") as mock_listener_cls: + with ( + patch("dataloader.workers.base.APP_CTX") as mock_ctx, + patch("dataloader.workers.base.APP_CONFIG") as mock_cfg, + patch("dataloader.workers.base.PGNotifyListener") as mock_listener_cls, + ): mock_logger = Mock() mock_ctx.get_logger.return_value = mock_logger @@ -101,7 +106,9 @@ class TestPGWorker: stop_event.set() return False - with patch.object(worker, "_claim_and_execute_once", side_effect=mock_claim): + with patch.object( + worker, "_claim_and_execute_once", side_effect=mock_claim + ): await worker.run() assert worker._listener is None @@ -156,8 +163,10 @@ class TestPGWorker: cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5) stop_event = asyncio.Event() - with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \ - patch("dataloader.workers.base.QueueRepository") as mock_repo_cls: + with ( + patch("dataloader.workers.base.APP_CTX") as mock_ctx, + patch("dataloader.workers.base.QueueRepository") as mock_repo_cls, + ): mock_session = AsyncMock() mock_session.commit = AsyncMock() @@ -185,8 +194,10 @@ class TestPGWorker: cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5) stop_event = asyncio.Event() - with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \ - patch("dataloader.workers.base.QueueRepository") as mock_repo_cls: + with ( + patch("dataloader.workers.base.APP_CTX") as mock_ctx, + patch("dataloader.workers.base.QueueRepository") as mock_repo_cls, + ): mock_session = AsyncMock() mock_sm = MagicMock() @@ -196,12 +207,14 @@ class TestPGWorker: mock_ctx.sessionmaker = mock_sm mock_repo = Mock() - mock_repo.claim_one = AsyncMock(return_value={ - "job_id": "test-job-id", - "lease_ttl_sec": 60, - "task": "test.task", - "args": {"key": "value"} - }) + mock_repo.claim_one 
= AsyncMock( + return_value={ + "job_id": "test-job-id", + "lease_ttl_sec": 60, + "task": "test.task", + "args": {"key": "value"}, + } + ) mock_repo.finish_ok = AsyncMock() mock_repo_cls.return_value = mock_repo @@ -210,8 +223,10 @@ class TestPGWorker: async def mock_pipeline(task, args): yield - with patch.object(worker, "_pipeline", side_effect=mock_pipeline), \ - patch.object(worker, "_execute_with_heartbeat", return_value=False): + with ( + patch.object(worker, "_pipeline", side_effect=mock_pipeline), + patch.object(worker, "_execute_with_heartbeat", return_value=False), + ): result = await worker._claim_and_execute_once() assert result is True @@ -225,8 +240,10 @@ class TestPGWorker: cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5) stop_event = asyncio.Event() - with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \ - patch("dataloader.workers.base.QueueRepository") as mock_repo_cls: + with ( + patch("dataloader.workers.base.APP_CTX") as mock_ctx, + patch("dataloader.workers.base.QueueRepository") as mock_repo_cls, + ): mock_session = AsyncMock() mock_sm = MagicMock() @@ -236,12 +253,14 @@ class TestPGWorker: mock_ctx.sessionmaker = mock_sm mock_repo = Mock() - mock_repo.claim_one = AsyncMock(return_value={ - "job_id": "test-job-id", - "lease_ttl_sec": 60, - "task": "test.task", - "args": {} - }) + mock_repo.claim_one = AsyncMock( + return_value={ + "job_id": "test-job-id", + "lease_ttl_sec": 60, + "task": "test.task", + "args": {}, + } + ) mock_repo.finish_fail_or_retry = AsyncMock() mock_repo_cls.return_value = mock_repo @@ -263,8 +282,10 @@ class TestPGWorker: cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5) stop_event = asyncio.Event() - with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \ - patch("dataloader.workers.base.QueueRepository") as mock_repo_cls: + with ( + patch("dataloader.workers.base.APP_CTX") as mock_ctx, + patch("dataloader.workers.base.QueueRepository") as mock_repo_cls, + ): mock_session = AsyncMock() mock_sm = MagicMock() @@ -274,18 +295,22 @@ class TestPGWorker: mock_ctx.sessionmaker = mock_sm mock_repo = Mock() - mock_repo.claim_one = AsyncMock(return_value={ - "job_id": "test-job-id", - "lease_ttl_sec": 60, - "task": "test.task", - "args": {} - }) + mock_repo.claim_one = AsyncMock( + return_value={ + "job_id": "test-job-id", + "lease_ttl_sec": 60, + "task": "test.task", + "args": {}, + } + ) mock_repo.finish_fail_or_retry = AsyncMock() mock_repo_cls.return_value = mock_repo worker = PGWorker(cfg, stop_event) - with patch.object(worker, "_execute_with_heartbeat", side_effect=ValueError("Test error")): + with patch.object( + worker, "_execute_with_heartbeat", side_effect=ValueError("Test error") + ): result = await worker._claim_and_execute_once() assert result is True @@ -301,8 +326,10 @@ class TestPGWorker: cfg = WorkerConfig(queue="test", heartbeat_sec=1, claim_backoff_sec=5) stop_event = asyncio.Event() - with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \ - patch("dataloader.workers.base.QueueRepository") as mock_repo_cls: + with ( + patch("dataloader.workers.base.APP_CTX") as mock_ctx, + patch("dataloader.workers.base.QueueRepository") as mock_repo_cls, + ): mock_session = AsyncMock() mock_sm = MagicMock() @@ -323,7 +350,9 @@ class TestPGWorker: await asyncio.sleep(0.6) yield - canceled = await worker._execute_with_heartbeat("job-id", 60, slow_pipeline()) + canceled = await worker._execute_with_heartbeat( + "job-id", 60, slow_pipeline() + ) assert canceled is False assert 
mock_repo.heartbeat.call_count >= 1 @@ -336,8 +365,10 @@ class TestPGWorker: cfg = WorkerConfig(queue="test", heartbeat_sec=1, claim_backoff_sec=5) stop_event = asyncio.Event() - with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \ - patch("dataloader.workers.base.QueueRepository") as mock_repo_cls: + with ( + patch("dataloader.workers.base.APP_CTX") as mock_ctx, + patch("dataloader.workers.base.QueueRepository") as mock_repo_cls, + ): mock_session = AsyncMock() mock_sm = MagicMock() @@ -358,7 +389,9 @@ class TestPGWorker: await asyncio.sleep(0.6) yield - canceled = await worker._execute_with_heartbeat("job-id", 60, slow_pipeline()) + canceled = await worker._execute_with_heartbeat( + "job-id", 60, slow_pipeline() + ) assert canceled is True @@ -370,8 +403,10 @@ class TestPGWorker: cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5) stop_event = asyncio.Event() - with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \ - patch("dataloader.workers.base.resolve_pipeline") as mock_resolve: + with ( + patch("dataloader.workers.base.APP_CTX") as mock_ctx, + patch("dataloader.workers.base.resolve_pipeline") as mock_resolve, + ): mock_ctx.get_logger.return_value = Mock() mock_ctx.sessionmaker = Mock() @@ -397,8 +432,10 @@ class TestPGWorker: cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5) stop_event = asyncio.Event() - with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \ - patch("dataloader.workers.base.resolve_pipeline") as mock_resolve: + with ( + patch("dataloader.workers.base.APP_CTX") as mock_ctx, + patch("dataloader.workers.base.resolve_pipeline") as mock_resolve, + ): mock_ctx.get_logger.return_value = Mock() mock_ctx.sessionmaker = Mock() @@ -424,8 +461,10 @@ class TestPGWorker: cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5) stop_event = asyncio.Event() - with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \ - patch("dataloader.workers.base.resolve_pipeline") as mock_resolve: + with ( + patch("dataloader.workers.base.APP_CTX") as mock_ctx, + patch("dataloader.workers.base.resolve_pipeline") as mock_resolve, + ): mock_ctx.get_logger.return_value = Mock() mock_ctx.sessionmaker = Mock() @@ -450,8 +489,10 @@ class TestPGWorker: cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5) stop_event = asyncio.Event() - with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \ - patch("dataloader.workers.base.QueueRepository") as mock_repo_cls: + with ( + patch("dataloader.workers.base.APP_CTX") as mock_ctx, + patch("dataloader.workers.base.QueueRepository") as mock_repo_cls, + ): mock_session = AsyncMock() mock_sm = MagicMock() @@ -461,12 +502,14 @@ class TestPGWorker: mock_ctx.sessionmaker = mock_sm mock_repo = Mock() - mock_repo.claim_one = AsyncMock(return_value={ - "job_id": "test-job-id", - "lease_ttl_sec": 60, - "task": "test.task", - "args": {} - }) + mock_repo.claim_one = AsyncMock( + return_value={ + "job_id": "test-job-id", + "lease_ttl_sec": 60, + "task": "test.task", + "args": {}, + } + ) mock_repo.finish_fail_or_retry = AsyncMock() mock_repo_cls.return_value = mock_repo @@ -484,19 +527,16 @@ class TestPGWorker: assert "cancelled by shutdown" in args[1] assert kwargs.get("is_canceled") is True - - - - - @pytest.mark.asyncio async def test_execute_with_heartbeat_raises_cancelled_when_stop_set(self): cfg = WorkerConfig(queue="test", heartbeat_sec=1000, claim_backoff_sec=5) stop_event = asyncio.Event() stop_event.set() - with patch("dataloader.workers.base.APP_CTX") as 
mock_ctx, \ - patch("dataloader.workers.base.QueueRepository") as mock_repo_cls: + with ( + patch("dataloader.workers.base.APP_CTX") as mock_ctx, + patch("dataloader.workers.base.QueueRepository") as mock_repo_cls, + ): mock_ctx.get_logger.return_value = Mock() mock_ctx.sessionmaker = Mock() @@ -509,4 +549,3 @@ class TestPGWorker: with pytest.raises(asyncio.CancelledError): await worker._execute_with_heartbeat("job-id", 60, one_yield()) - diff --git a/tests/unit/test_workers_manager.py b/tests/unit/test_workers_manager.py index c8c5485..2538cbd 100644 --- a/tests/unit/test_workers_manager.py +++ b/tests/unit/test_workers_manager.py @@ -1,8 +1,8 @@ -# tests/unit/test_workers_manager.py from __future__ import annotations import asyncio from unittest.mock import AsyncMock, MagicMock, Mock, patch + import pytest from dataloader.workers.manager import WorkerManager, WorkerSpec, build_manager_from_env @@ -39,9 +39,11 @@ class TestWorkerManager: WorkerSpec(queue="queue2", concurrency=1), ] - with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \ - patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, \ - patch("dataloader.workers.manager.PGWorker") as mock_worker_cls: + with ( + patch("dataloader.workers.manager.APP_CTX") as mock_ctx, + patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, + patch("dataloader.workers.manager.PGWorker") as mock_worker_cls, + ): mock_ctx.get_logger.return_value = Mock() mock_cfg.worker.heartbeat_sec = 10 @@ -67,9 +69,11 @@ class TestWorkerManager: """ specs = [WorkerSpec(queue="test", concurrency=0)] - with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \ - patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, \ - patch("dataloader.workers.manager.PGWorker") as mock_worker_cls: + with ( + patch("dataloader.workers.manager.APP_CTX") as mock_ctx, + patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, + patch("dataloader.workers.manager.PGWorker") as mock_worker_cls, + ): mock_ctx.get_logger.return_value = Mock() mock_cfg.worker.heartbeat_sec = 10 @@ -93,9 +97,11 @@ class TestWorkerManager: """ specs = [WorkerSpec(queue="test", concurrency=2)] - with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \ - patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, \ - patch("dataloader.workers.manager.PGWorker") as mock_worker_cls: + with ( + patch("dataloader.workers.manager.APP_CTX") as mock_ctx, + patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, + patch("dataloader.workers.manager.PGWorker") as mock_worker_cls, + ): mock_ctx.get_logger.return_value = Mock() mock_cfg.worker.heartbeat_sec = 10 @@ -108,8 +114,6 @@ class TestWorkerManager: manager = WorkerManager(specs) await manager.start() - initial_task_count = len(manager._tasks) - await manager.stop() assert manager._stop.is_set() @@ -123,10 +127,12 @@ class TestWorkerManager: """ specs = [WorkerSpec(queue="test", concurrency=1)] - with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \ - patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, \ - patch("dataloader.workers.manager.PGWorker") as mock_worker_cls, \ - patch("dataloader.workers.manager.requeue_lost") as mock_requeue: + with ( + patch("dataloader.workers.manager.APP_CTX") as mock_ctx, + patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, + patch("dataloader.workers.manager.PGWorker") as mock_worker_cls, + patch("dataloader.workers.manager.requeue_lost") as mock_requeue, + ): mock_logger = Mock() mock_ctx.get_logger.return_value = mock_logger @@ -161,10 +167,12 @@ class 
TestWorkerManager: """ specs = [WorkerSpec(queue="test", concurrency=1)] - with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \ - patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, \ - patch("dataloader.workers.manager.PGWorker") as mock_worker_cls, \ - patch("dataloader.workers.manager.requeue_lost") as mock_requeue: + with ( + patch("dataloader.workers.manager.APP_CTX") as mock_ctx, + patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, + patch("dataloader.workers.manager.PGWorker") as mock_worker_cls, + patch("dataloader.workers.manager.requeue_lost") as mock_requeue, + ): mock_logger = Mock() mock_ctx.get_logger.return_value = mock_logger @@ -203,8 +211,10 @@ class TestBuildManagerFromEnv: """ Тест создания менеджера из конфигурации. """ - with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \ - patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg: + with ( + patch("dataloader.workers.manager.APP_CTX") as mock_ctx, + patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, + ): mock_ctx.get_logger.return_value = Mock() mock_cfg.worker.parsed_workers.return_value = [ @@ -224,8 +234,10 @@ class TestBuildManagerFromEnv: """ Тест, что пустые имена очередей пропускаются. """ - with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \ - patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg: + with ( + patch("dataloader.workers.manager.APP_CTX") as mock_ctx, + patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, + ): mock_ctx.get_logger.return_value = Mock() mock_cfg.worker.parsed_workers.return_value = [ @@ -243,8 +255,10 @@ class TestBuildManagerFromEnv: """ Тест обработки отсутствующих полей с дефолтными значениями. """ - with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \ - patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg: + with ( + patch("dataloader.workers.manager.APP_CTX") as mock_ctx, + patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, + ): mock_ctx.get_logger.return_value = Mock() mock_cfg.worker.parsed_workers.return_value = [ @@ -262,8 +276,10 @@ class TestBuildManagerFromEnv: """ Тест, что concurrency всегда минимум 1. """ - with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \ - patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg: + with ( + patch("dataloader.workers.manager.APP_CTX") as mock_ctx, + patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, + ): mock_ctx.get_logger.return_value = Mock() mock_cfg.worker.parsed_workers.return_value = [ diff --git a/tests/unit/test_workers_reaper.py b/tests/unit/test_workers_reaper.py index 43669db..cf6c51d 100644 --- a/tests/unit/test_workers_reaper.py +++ b/tests/unit/test_workers_reaper.py @@ -1,8 +1,8 @@ -# tests/unit/test_workers_reaper.py from __future__ import annotations +from unittest.mock import AsyncMock, Mock, patch + import pytest -from unittest.mock import AsyncMock, patch, Mock from dataloader.workers.reaper import requeue_lost
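
Note on the two behavioural changes hidden among the reformatting above: `pipelines/registry.resolve` now re-raises with `raise KeyError(...) from err`, so the original lookup failure is kept as `__cause__` in tracebacks, and `PGWorker._wait_for_wakeup` keeps its wake-on-NOTIFY-or-stop logic, merely reflowed. The following is a minimal, standalone sketch of that wake-up pattern as it appears in the hunk for `src/dataloader/workers/base.py`; the function and demo names are illustrative only and are not part of the patch.

# Standalone sketch of the "wake on notify or stop, with a polling timeout"
# pattern reformatted in PGWorker._wait_for_wakeup. Names are illustrative.
from __future__ import annotations

import asyncio


async def wait_for_wakeup(
    notify_event: asyncio.Event,
    stop_event: asyncio.Event,
    timeout_sec: float,
) -> None:
    """Return when either event fires or the timeout elapses."""
    waiters = [
        asyncio.create_task(notify_event.wait()),
        asyncio.create_task(stop_event.wait()),
    ]
    _done, pending = await asyncio.wait(
        waiters,
        return_when=asyncio.FIRST_COMPLETED,
        timeout=timeout_sec,
    )
    # Cancel whichever waiter did not complete so no task leaks per poll cycle.
    for task in pending:
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
    # Clear the notify flag so the next iteration does not wake up spuriously.
    if notify_event.is_set():
        notify_event.clear()


async def _demo() -> None:
    notify, stop = asyncio.Event(), asyncio.Event()
    # Simulate a NOTIFY arriving shortly after the worker starts waiting.
    asyncio.get_running_loop().call_later(0.1, notify.set)
    await wait_for_wakeup(notify, stop, timeout_sec=5.0)


if __name__ == "__main__":
    asyncio.run(_demo())

Cancelling the pending waiter and then awaiting it mirrors the patched code and keeps the event loop free of orphaned tasks; the timeout preserves the plain-polling fallback used when the LISTEN/NOTIFY listener fails to start.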