refactor: apply black, isort and ruff formatting across the codebase

itqop 2025-11-05 20:45:13 +03:00
parent c907e1d4da
commit bde0bb0e6f
66 changed files with 853 additions and 463 deletions

View File

@@ -34,3 +34,37 @@ dataloader = "dataloader.__main__:main"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[tool.black]
line-length = 88
target-version = ["py311"]
skip-string-normalization = false
preview = true
[tool.isort]
profile = "black"
line_length = 88
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
combine_as_imports = true
known_first_party = ["dataloader"]
src_paths = ["src"]
[tool.ruff]
line-length = 88
target-version = "py311"
fix = true
lint.select = ["E", "F", "W", "I", "N", "B"]
lint.ignore = [
"E501",
]
exclude = [
".git",
"__pycache__",
"build",
"dist",
".venv",
"venv",
]
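
Aside (not part of this commit): with these sections in place, the formatters can be applied locally; a minimal sketch, assuming the project is driven through Poetry as the [build-system] section suggests:

poetry run black src
poetry run isort src
poetry run ruff check src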

View File

@@ -5,7 +5,6 @@
from . import api
__all__ = [
    "api",
]

View File

@@ -1,13 +1,15 @@
import uvicorn
from dataloader.api import app_main
from dataloader.config import APP_CONFIG
from dataloader.logger.uvicorn_logging_config import (
    LOGGING_CONFIG,
    setup_uvicorn_logging,
)
def main() -> None:
-    # Инициализируем логирование uvicorn перед запуском
    setup_uvicorn_logging()
    uvicorn.run(

View File

@@ -1,20 +1,19 @@
-# src/dataloader/api/__init__.py
from __future__ import annotations
+from collections.abc import AsyncGenerator
import contextlib
import typing as tp
-from collections.abc import AsyncGenerator
from fastapi import FastAPI
+from dataloader.context import APP_CTX
+from dataloader.workers.manager import WorkerManager, build_manager_from_env
+from dataloader.workers.pipelines import load_all as load_pipelines
from .metric_router import router as metric_router
from .middleware import log_requests
from .os_router import router as service_router
from .v1 import router as v1_router
-from dataloader.context import APP_CTX
-from dataloader.workers.manager import build_manager_from_env, WorkerManager
-from dataloader.workers.pipelines import load_all as load_pipelines
_manager: WorkerManager | None = None

View File

@@ -1,10 +1,11 @@
-""" 🚨 НЕ РЕДАКТИРОВАТЬ !!!!!!
-"""
+"""🚨 НЕ РЕДАКТИРОВАТЬ !!!!!!"""
import uuid
from fastapi import APIRouter, Header, status
from dataloader.context import APP_CTX
from . import schemas
router = APIRouter()
@@ -17,8 +18,7 @@ logger = APP_CTX.get_logger()
    response_model=schemas.RateResponse,
)
async def like(
-    # pylint: disable=C0103,W0613
-    header_Request_Id: str = Header(uuid.uuid4(), alias="Request-Id")
+    header_request_id: str = Header(uuid.uuid4(), alias="Request-Id")
) -> dict[str, str]:
    logger.metric(
        metric_name="dataloader_likes_total",
@@ -33,8 +33,7 @@ async def like(
    response_model=schemas.RateResponse,
)
async def dislike(
-    # pylint: disable=C0103,W0613
-    header_Request_Id: str = Header(uuid.uuid4(), alias="Request-Id")
+    header_request_id: str = Header(uuid.uuid4(), alias="Request-Id")
) -> dict[str, str]:
    logger.metric(
        metric_name="dataloader_dislikes_total",

View File

@@ -38,8 +38,14 @@ async def log_requests(request: Request, call_next) -> any:
    logger = APP_CTX.get_logger()
    request_path = request.url.path
    allowed_headers_to_log = (
        (k, request.headers.get(k)) for k in HEADERS_WHITE_LIST_TO_LOG
    )
    headers_to_log = {
        header_name: header_value
        for header_name, header_value in allowed_headers_to_log
        if header_value
    }
    APP_CTX.get_context_vars_container().set_context_vars(
        request_id=headers_to_log.get("Request-Id", ""),
@@ -50,7 +56,9 @@
    if request_path in NON_LOGGED_ENDPOINTS:
        response = await call_next(request)
        logger.debug(
            f"Processed request for {request_path} with code {response.status_code}"
        )
    elif headers_to_log.get("Request-Id", None):
        raw_request_body = await request.body()
        request_body_decoded = _get_decoded_body(raw_request_body, "request", logger)
@@ -80,7 +88,7 @@
            event_params=json.dumps(
                request_body_decoded,
                ensure_ascii=False,
            ),
        )
        response = await call_next(request)
@@ -88,12 +96,16 @@
        response_body = [chunk async for chunk in response.body_iterator]
        response.body_iterator = iterate_in_threadpool(iter(response_body))
        headers_to_log["Response-Time"] = datetime.now(
            APP_CTX.get_pytz_timezone()
        ).isoformat()
        for header in headers_to_log:
            response.headers[header] = headers_to_log[header]
        response_body_extracted = response_body[0] if len(response_body) > 0 else b""
        decoded_response_body = _get_decoded_body(
            response_body_extracted, "response", logger
        )
        logger.info(
            "Outgoing response to client system",
@@ -115,11 +127,13 @@
            event_params=json.dumps(
                decoded_response_body,
                ensure_ascii=False,
            ),
        )
        processing_time_ms = int(round((time.time() - start_time), 3) * 1000)
        logger.info(
            f"Request processing time for {request_path}: {processing_time_ms} ms"
        )
        logger.metric(
            metric_name="dataloader_process_duration_ms",
            metric_value=processing_time_ms,
@@ -138,7 +152,9 @@
    else:
        logger.info(f"Incoming {request.method}-request with no id for {request_path}")
        response = await call_next(request)
        logger.info(
            f"Request with no id for {request_path} processing time: {time.time() - start_time:.3f} s"
        )
    return response

View File

@@ -1,6 +1,4 @@
-# Инфраструктурные endpoint'ы (/health, /status)
-""" 🚨 НЕ РЕДАКТИРОВАТЬ !!!!!!
-"""
+"""🚨 НЕ РЕДАКТИРОВАТЬ !!!!!!"""
from importlib.metadata import distribution

View File

@@ -3,22 +3,26 @@ from pydantic import BaseModel, ConfigDict, Field
class HealthResponse(BaseModel):
    """Ответ для ручки /health"""

    model_config = ConfigDict(json_schema_extra={"example": {"status": "running"}})
    status: str = Field(
        default="running", description="Service health check", max_length=7
    )


class InfoResponse(BaseModel):
    """Ответ для ручки /info"""

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "name": "rest-template",
                "description": (
                    "Python 'AI gateway' template for developing REST microservices"
                ),
                "type": "REST API",
                "version": "0.1.0",
            }
        }
    )
@@ -26,11 +30,14 @@ class InfoResponse(BaseModel):
    name: str = Field(description="Service name", max_length=50)
    description: str = Field(description="Service description", max_length=200)
    type: str = Field(default="REST API", description="Service type", max_length=20)
    version: str = Field(
        description="Service version", max_length=20, pattern=r"^\d+\.\d+\.\d+"
    )


class RateResponse(BaseModel):
    """Ответ для записи рейтинга"""

    model_config = ConfigDict(str_strip_whitespace=True)
    rating_result: str = Field(description="Rating that was recorded", max_length=50)

View File

@@ -8,9 +8,5 @@ class JobNotFoundError(HTTPException):
    def __init__(self, job_id: str):
        super().__init__(
            status_code=status.HTTP_404_NOT_FOUND, detail=f"Job {job_id} not found"
        )

View File

@@ -1,2 +1 @@
"""Модели данных для API v1."""

View File

@@ -1,12 +1,11 @@
-# src/dataloader/api/v1/router.py
from __future__ import annotations
-from collections.abc import AsyncGenerator
from http import HTTPStatus
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends
+from sqlalchemy.ext.asyncio import AsyncSession
from dataloader.api.v1.exceptions import JobNotFoundError
from dataloader.api.v1.schemas import (
@@ -16,8 +15,6 @@ from dataloader.api.v1.schemas import (
)
from dataloader.api.v1.service import JobsService
from dataloader.context import get_session
-from sqlalchemy.ext.asyncio import AsyncSession
router = APIRouter(prefix="/jobs", tags=["jobs"])
@@ -40,7 +37,9 @@ async def trigger_job(
    return await svc.trigger(payload)
@router.get(
    "/{job_id}/status", response_model=JobStatusResponse, status_code=HTTPStatus.OK
)
async def get_status(
    job_id: UUID,
    svc: Annotated[JobsService, Depends(get_service)],
@@ -54,7 +53,9 @@ async def get_status(
    return st
@router.post(
    "/{job_id}/cancel", response_model=JobStatusResponse, status_code=HTTPStatus.OK
)
async def cancel_job(
    job_id: UUID,
    svc: Annotated[JobsService, Depends(get_service)],

View File

@@ -1,4 +1,3 @@
-# src/dataloader/api/v1/schemas.py
from __future__ import annotations
from datetime import datetime, timezone
@@ -12,6 +11,7 @@ class TriggerJobRequest(BaseModel):
    """
    Запрос на постановку задачи в очередь.
    """

    model_config = ConfigDict(str_strip_whitespace=True)
    queue: str = Field(...)
@@ -39,6 +39,7 @@ class TriggerJobResponse(BaseModel):
    """
    Ответ на постановку задачи.
    """

    model_config = ConfigDict(str_strip_whitespace=True)
    job_id: UUID = Field(...)
@@ -49,6 +50,7 @@ class JobStatusResponse(BaseModel):
    """
    Текущий статус задачи.
    """

    model_config = ConfigDict(str_strip_whitespace=True)
    job_id: UUID = Field(...)
@@ -65,6 +67,7 @@ class CancelJobResponse(BaseModel):
    """
    Ответ на запрос отмены задачи.
    """

    model_config = ConfigDict(str_strip_whitespace=True)
    job_id: UUID = Field(...)

View File

@@ -1,4 +1,3 @@
-# src/dataloader/api/v1/service.py
from __future__ import annotations
from datetime import datetime, timezone
@@ -13,15 +12,16 @@ from dataloader.api.v1.schemas import (
    TriggerJobResponse,
)
from dataloader.api.v1.utils import new_job_id
-from dataloader.storage.schemas import CreateJobRequest
-from dataloader.storage.repositories import QueueRepository
from dataloader.logger.logger import get_logger
+from dataloader.storage.repositories import QueueRepository
+from dataloader.storage.schemas import CreateJobRequest
class JobsService:
    """
    Бизнес-логика работы с очередью задач.
    """

    def __init__(self, session: AsyncSession):
        self._s = session
        self._repo = QueueRepository(self._s)

View File

@@ -1,4 +1,3 @@
-# src/dataloader/api/v1/utils.py
from __future__ import annotations
from uuid import UUID, uuid4

View File

@@ -1,6 +1,5 @@
-# src/dataloader/config.py
import json
import os
from logging import DEBUG, INFO
from typing import Annotated, Any
@@ -32,6 +31,7 @@ class BaseAppSettings(BaseSettings):
    """
    Базовый класс для настроек.
    """

    local: bool = Field(validation_alias="LOCAL", default=False)
    debug: bool = Field(validation_alias="DEBUG", default=False)
@@ -44,6 +44,7 @@ class AppSettings(BaseAppSettings):
    """
    Настройки приложения.
    """

    app_host: str = Field(validation_alias="APP_HOST", default="0.0.0.0")
    app_port: int = Field(validation_alias="APP_PORT", default=8081)
    kube_net_name: str = Field(validation_alias="PROJECT_NAME", default="AIGATEWAY")
@@ -54,15 +55,28 @@ class LogSettings(BaseAppSettings):
    """
    Настройки логирования.
    """

    private_log_file_path: str = Field(validation_alias="LOG_PATH", default=os.getcwd())
    private_log_file_name: str = Field(
        validation_alias="LOG_FILE_NAME", default="app.log"
    )
    log_rotation: str = Field(validation_alias="LOG_ROTATION", default="10 MB")
    private_metric_file_path: str = Field(
        validation_alias="METRIC_PATH", default=os.getcwd()
    )
    private_metric_file_name: str = Field(
        validation_alias="METRIC_FILE_NAME", default="app-metric.log"
    )
    private_audit_file_path: str = Field(
        validation_alias="AUDIT_LOG_PATH", default=os.getcwd()
    )
    private_audit_file_name: str = Field(
        validation_alias="AUDIT_LOG_FILE_NAME", default="events.log"
    )
    audit_host_ip: str = Field(validation_alias="HOST_IP", default="127.0.0.1")
    audit_host_uid: str = Field(
        validation_alias="HOST_UID", default="63b6dcee-170b-49bf-a65c-3ec967398ccd"
    )

    @staticmethod
    def get_file_abs_path(path_name: str, file_name: str) -> str:
@@ -70,15 +84,21 @@ class LogSettings(BaseAppSettings):
    @property
    def log_file_abs_path(self) -> str:
        return self.get_file_abs_path(
            self.private_log_file_path, self.private_log_file_name
        )

    @property
    def metric_file_abs_path(self) -> str:
        return self.get_file_abs_path(
            self.private_metric_file_path, self.private_metric_file_name
        )

    @property
    def audit_file_abs_path(self) -> str:
        return self.get_file_abs_path(
            self.private_audit_file_path, self.private_audit_file_name
        )

    @property
    def log_lvl(self) -> int:
@@ -89,6 +109,7 @@ class PGSettings(BaseSettings):
    """
    Настройки подключения к Postgres.
    """

    host: str = Field(validation_alias="PG_HOST", default="localhost")
    port: int = Field(validation_alias="PG_PORT", default=5432)
    user: str = Field(validation_alias="PG_USER", default="postgres")
@@ -116,9 +137,12 @@ class WorkerSettings(BaseSettings):
    """
    Настройки очереди и воркеров.
    """

    workers_json: str = Field(validation_alias="WORKERS_JSON", default="[]")
    heartbeat_sec: int = Field(validation_alias="DL_HEARTBEAT_SEC", default=10)
    default_lease_ttl_sec: int = Field(
        validation_alias="DL_DEFAULT_LEASE_TTL_SEC", default=60
    )
    reaper_period_sec: int = Field(validation_alias="DL_REAPER_PERIOD_SEC", default=10)
    claim_backoff_sec: int = Field(validation_alias="DL_CLAIM_BACKOFF_SEC", default=15)
@@ -137,6 +161,7 @@ class CertsSettings(BaseSettings):
    """
    Настройки SSL сертификатов для локальной разработки.
    """

    ca_bundle_file: str = Field(validation_alias="CA_BUNDLE_FILE", default="")
    cert_file: str = Field(validation_alias="CERT_FILE", default="")
    key_file: str = Field(validation_alias="KEY_FILE", default="")
@@ -146,21 +171,24 @@ class SuperTeneraSettings(BaseAppSettings):
    """
    Настройки интеграции с SuperTenera.
    """

    host: Annotated[str, BeforeValidator(strip_slashes)] = Field(
        validation_alias="SUPERTENERA_HOST",
        default="ci03801737-ift-tenera-giga.delta.sbrf.ru/atlant360bc/",
    )
    port: str = Field(validation_alias="SUPERTENERA_PORT", default="443")
    quotes_endpoint: Annotated[str, BeforeValidator(strip_slashes)] = Field(
        validation_alias="SUPERTENERA_QUOTES_ENDPOINT",
        default="/get_gigaparser_quotes/",
    )
    timeout: int = Field(validation_alias="SUPERTENERA_TIMEOUT", default=20)

    @property
    def base_url(self) -> str:
        """Возвращает абсолютный URL для SuperTenera."""
        domain, raw_path = (
            self.host.split("/", 1) if "/" in self.host else (self.host, "")
        )
        return build_url(self.protocol, domain, self.port, raw_path)
@@ -168,22 +196,21 @@ class Gmap2BriefSettings(BaseAppSettings):
    """
    Настройки интеграции с Gmap2Brief (OPU API).
    """

    host: Annotated[str, BeforeValidator(strip_slashes)] = Field(
        validation_alias="GMAP2BRIEF_HOST",
        default="ci02533826-tib-brief.apps.ift-terra000024-edm.ocp.delta.sbrf.ru",
    )
    port: str = Field(validation_alias="GMAP2BRIEF_PORT", default="443")
    start_endpoint: Annotated[str, BeforeValidator(strip_slashes)] = Field(
        validation_alias="GMAP2BRIEF_START_ENDPOINT", default="/export/opu/start"
    )
    status_endpoint: Annotated[str, BeforeValidator(strip_slashes)] = Field(
        validation_alias="GMAP2BRIEF_STATUS_ENDPOINT", default="/export/{job_id}/status"
    )
    download_endpoint: Annotated[str, BeforeValidator(strip_slashes)] = Field(
        validation_alias="GMAP2BRIEF_DOWNLOAD_ENDPOINT",
        default="/export/{job_id}/download",
    )
    poll_interval: int = Field(validation_alias="GMAP2BRIEF_POLL_INTERVAL", default=2)
    timeout: int = Field(validation_alias="GMAP2BRIEF_TIMEOUT", default=3600)
@@ -198,6 +225,7 @@ class Secrets:
    """
    Агрегатор настроек приложения.
    """

    app: AppSettings = AppSettings()
    log: LogSettings = LogSettings()
    pg: PGSettings = PGSettings()
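
Aside (not part of this commit): these settings classes are pydantic-settings models, so any field above can be overridden through the environment variable named in its validation_alias. A minimal sketch with hypothetical values:

import os

os.environ["PG_HOST"] = "db.internal"           # hypothetical override
os.environ["DL_DEFAULT_LEASE_TTL_SEC"] = "120"  # hypothetical override

from dataloader.config import PGSettings, WorkerSettings

print(PGSettings().host)                       # -> "db.internal"
print(WorkerSettings().default_lease_ttl_sec)  # -> 120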

View File

@@ -1,8 +1,7 @@
-# src/dataloader/context.py
from __future__ import annotations
-from typing import AsyncGenerator
from logging import Logger
+from typing import AsyncGenerator
from zoneinfo import ZoneInfo
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
@@ -15,6 +14,7 @@ class AppContext:
    """
    Контекст приложения, хранящий глобальные зависимости (Singleton pattern).
    """

    def __init__(self) -> None:
        self._engine: AsyncEngine | None = None
        self._sessionmaker: async_sessionmaker[AsyncSession] | None = None
@@ -75,6 +75,7 @@ class AppContext:
        Экземпляр Logger
        """
        from .logger.logger import get_logger as get_app_logger

        return get_app_logger(name)

    def get_context_vars_container(self) -> ContextVarsContainer:

View File

@@ -1,2 +1 @@
"""Исключения уровня приложения."""

View File

@@ -10,7 +10,10 @@ from typing import TYPE_CHECKING
import httpx
from dataloader.config import APP_CONFIG
from dataloader.interfaces.gmap2_brief.schemas import (
    ExportJobStatus,
    StartExportResponse,
)
if TYPE_CHECKING:
    from logging import Logger
@@ -37,10 +40,11 @@ class Gmap2BriefInterface:
        self._ssl_context = None
        if APP_CONFIG.app.local and APP_CONFIG.certs.cert_file:
            self._ssl_context = ssl.create_default_context(
                cafile=APP_CONFIG.certs.ca_bundle_file
            )
            self._ssl_context.load_cert_chain(
                certfile=APP_CONFIG.certs.cert_file, keyfile=APP_CONFIG.certs.key_file
            )

    async def start_export(self) -> str:
@@ -56,9 +60,13 @@ class Gmap2BriefInterface:
        url = self.base_url + APP_CONFIG.gmap2brief.start_endpoint
        async with httpx.AsyncClient(
            cert=(
                (APP_CONFIG.certs.cert_file, APP_CONFIG.certs.key_file)
                if APP_CONFIG.app.local
                else None
            ),
            verify=APP_CONFIG.certs.ca_bundle_file if APP_CONFIG.app.local else True,
            timeout=self.timeout,
        ) as client:
            try:
                self.logger.info(f"Starting OPU export: POST {url}")
@@ -87,12 +95,18 @@ class Gmap2BriefInterface:
        Raises:
            Gmap2BriefConnectionError: При ошибке запроса
        """
        url = self.base_url + APP_CONFIG.gmap2brief.status_endpoint.format(
            job_id=job_id
        )
        async with httpx.AsyncClient(
            cert=(
                (APP_CONFIG.certs.cert_file, APP_CONFIG.certs.key_file)
                if APP_CONFIG.app.local
                else None
            ),
            verify=APP_CONFIG.certs.ca_bundle_file if APP_CONFIG.app.local else True,
            timeout=self.timeout,
        ) as client:
            try:
                response = await client.get(url)
@@ -106,7 +120,9 @@
            except httpx.RequestError as e:
                raise Gmap2BriefConnectionError(f"Request error: {e}") from e

    async def wait_for_completion(
        self, job_id: str, max_wait: int | None = None
    ) -> ExportJobStatus:
        """
        Ждет завершения задачи экспорта с периодическим polling.
@@ -127,14 +143,18 @@
            status = await self.get_status(job_id)
            if status.status == "completed":
                self.logger.info(
                    f"Job {job_id} completed, total_rows={status.total_rows}"
                )
                return status
            elif status.status == "failed":
                raise Gmap2BriefConnectionError(f"Job {job_id} failed: {status.error}")

            elapsed = asyncio.get_event_loop().time() - start_time
            if max_wait and elapsed > max_wait:
                raise Gmap2BriefConnectionError(
                    f"Job {job_id} timeout after {elapsed:.1f}s"
                )

            self.logger.debug(
                f"Job {job_id} status={status.status}, rows={status.total_rows}, elapsed={elapsed:.1f}s"
@@ -155,12 +175,18 @@
        Raises:
            Gmap2BriefConnectionError: При ошибке скачивания
        """
        url = self.base_url + APP_CONFIG.gmap2brief.download_endpoint.format(
            job_id=job_id
        )
        async with httpx.AsyncClient(
            cert=(
                (APP_CONFIG.certs.cert_file, APP_CONFIG.certs.key_file)
                if APP_CONFIG.app.local
                else None
            ),
            verify=APP_CONFIG.certs.ca_bundle_file if APP_CONFIG.app.local else True,
            timeout=self.timeout,
        ) as client:
            try:
                self.logger.info(f"Downloading export: GET {url}")

View File

@@ -17,7 +17,11 @@ class ExportJobStatus(BaseModel):
    """Статус задачи экспорта."""

    job_id: str = Field(..., description="Идентификатор задачи")
    status: Literal["pending", "running", "completed", "failed"] = Field(
        ..., description="Статус задачи"
    )
    total_rows: int = Field(default=0, description="Количество обработанных строк")
    error: str | None = Field(default=None, description="Текст ошибки (если есть)")
    temp_file_path: str | None = Field(
        default=None, description="Путь к временному файлу (для completed)"
    )

View File

@@ -47,8 +47,12 @@ class SuperTeneraInterface:
        self._ssl_context = None
        if APP_CONFIG.app.local:
            self._ssl_context = ssl.create_default_context(
                cafile=APP_CONFIG.certs.ca_bundle_file
            )
            self._ssl_context.load_cert_chain(
                certfile=APP_CONFIG.certs.cert_file, keyfile=APP_CONFIG.certs.key_file
            )

    def form_base_headers(self) -> dict:
        """Формирует базовые заголовки для запроса."""
@@ -57,7 +61,11 @@
            "request-time": str(datetime.now(tz=self.timezone).isoformat()),
            "system-id": APP_CONFIG.app.kube_net_name,
        }
        return {
            metakey: metavalue
            for metakey, metavalue in metadata_pairs.items()
            if metavalue
        }

    async def __aenter__(self) -> Self:
        """Async context manager enter."""
@@ -86,7 +94,9 @@
        async with self._session.get(url, **kwargs) as response:
            if APP_CONFIG.app.debug:
                self.logger.debug(
                    f"Response: {(await response.text(errors='ignore'))[:100]}"
                )
            return await response.json(encoding=encoding, content_type=content_type)

    async def get_quotes_data(self) -> MainData:
@@ -103,7 +113,9 @@
            kwargs["ssl"] = self._ssl_context
        try:
            async with self._session.get(
                APP_CONFIG.supertenera.quotes_endpoint,
                timeout=APP_CONFIG.supertenera.timeout,
                **kwargs,
            ) as resp:
                resp.raise_for_status()
                return True
@@ -112,7 +124,9 @@
                f"Ошибка подключения к SuperTenera API при проверке системы - {e.status}."
            ) from e
        except TimeoutError as e:
            raise SuperTeneraConnectionError(
                "Ошибка Timeout подключения к SuperTenera API при проверке системы."
            ) from e


def get_async_tenera_interface() -> SuperTeneraInterface:

View File

@@ -77,7 +77,9 @@ class InvestingCandlestick(TeneraBaseModel):
    value: str = Field(alias="V")


class InvestingTimePoint(
    RootModel[EmptyTimePoint | InvestingNumeric | InvestingCandlestick]
):
    """
    Union-модель точки времени для источника Investing.com.
@@ -340,5 +342,9 @@ class MainData(TeneraBaseModel):
        :return: Отфильтрованный объект
        """
        if isinstance(v, dict):
            return {
                key: value
                for key, value in v.items()
                if value is not None and not str(key).isdigit()
            }
        return v
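
Aside (not part of this commit): a small illustration of what this validator's filter keeps, with hypothetical input data:

raw = {"USD": 95.4, "rate": None, "1": "stray numeric key", "EUR": 103.2}
filtered = {
    key: value
    for key, value in raw.items()
    if value is not None and not str(key).isdigit()
}
# filtered == {"USD": 95.4, "EUR": 103.2}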

View File

@@ -1,4 +1,3 @@
-# src/dataloader/logger/__init__.py
-from .logger import setup_logging, get_logger
+from .logger import get_logger, setup_logging
__all__ = ["setup_logging", "get_logger"]

View File

@@ -1,9 +1,7 @@
-# Управление контекстом запросов для логирования
import uuid
from contextvars import ContextVar
from typing import Final

REQUEST_ID_CTX_VAR: Final[ContextVar[str]] = ContextVar("request_id", default="")
DEVICE_ID_CTX_VAR: Final[ContextVar[str]] = ContextVar("device_id", default="")
SESSION_ID_CTX_VAR: Final[ContextVar[str]] = ContextVar("session_id", default="")
@@ -66,7 +64,13 @@ class ContextVarsContainer:
    def gw_session_id(self, value: str) -> None:
        GW_SESSION_ID_CTX_VAR.set(value)

    def set_context_vars(
        self,
        request_id: str = "",
        request_time: str = "",
        system_id: str = "",
        gw_session_id: str = "",
    ) -> None:
        if request_id:
            self.set_request_id(request_id)
        if request_time:

View File

@@ -1,17 +1,12 @@
-# src/dataloader/logger/logger.py
-import sys
-import typing
-from datetime import tzinfo
-from logging import Logger
import logging
+import sys
+from logging import Logger
from loguru import logger
-from .context_vars import ContextVarsContainer
from dataloader.config import APP_CONFIG
-# Определяем фильтры для разных типов логов
def metric_only_filter(record: dict) -> bool:
    return "metric" in record["extra"]
@@ -26,13 +21,12 @@ def regular_log_filter(record: dict) -> bool:
class InterceptHandler(logging.Handler):
    def emit(self, record: logging.LogRecord) -> None:
-        # Get corresponding Loguru level if it exists
        try:
            level = logger.level(record.levelname).name
        except ValueError:
            level = record.levelno
-        # Find caller from where originated the logged message
        frame, depth = logging.currentframe(), 2
        while frame.f_code.co_filename == logging.__file__:
            frame = frame.f_back

View File

@@ -1 +1 @@
-# Модели логов, метрик, событий аудита

View File

@@ -1 +1 @@
-# Функции + маскирование args

View File

@@ -1,34 +1,28 @@
-# Конфигурация логирования uvicorn
import logging
-import sys
from loguru import logger


class InterceptHandler(logging.Handler):
    def emit(self, record: logging.LogRecord) -> None:
-        # Get corresponding Loguru level if it exists
        try:
            level = logger.level(record.levelname).name
        except ValueError:
            level = record.levelno
-        # Find caller from where originated the logged message
        frame, depth = logging.currentframe(), 2
        while frame.f_code.co_filename == logging.__file__:
            frame = frame.f_back
            depth += 1

        logger.opt(depth=depth, exception=record.exc_info).log(
            level, record.getMessage()
        )


def setup_uvicorn_logging() -> None:
-    # Set all uvicorn loggers to use InterceptHandler
    for logger_name in ["uvicorn", "uvicorn.error", "uvicorn.access"]:
        log = logging.getLogger(logger_name)
        log.handlers = [InterceptHandler()]
@@ -36,7 +30,6 @@ def setup_uvicorn_logging() -> None:
    log.propagate = False

-# uvicorn logging config
LOGGING_CONFIG = {
    "version": 1,
    "disable_existing_loggers": False,

View File

@@ -1,2 +1 @@
"""Модуль для работы с хранилищем данных."""

View File

@@ -1,7 +1,11 @@
-# src/dataloader/storage/engine.py
from __future__ import annotations
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from dataloader.config import APP_CONFIG

View File

@@ -1,8 +1,8 @@
-# src/dataloader/storage/models/__init__.py
"""
ORM модели для работы с базой данных.
Организованы по доменам для масштабируемости.
"""
from __future__ import annotations
from .base import Base

View File

@@ -1,4 +1,3 @@
-# src/dataloader/storage/models/base.py
from __future__ import annotations
from sqlalchemy.orm import DeclarativeBase
@@ -9,4 +8,5 @@ class Base(DeclarativeBase):
    Базовый класс для всех ORM моделей приложения.
    Используется SQLAlchemy 2.0+ declarative style.
    """

    pass

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
from datetime import date, datetime
from sqlalchemy import TIMESTAMP, BigInteger, Date, Integer, Numeric, String, text
from sqlalchemy.orm import Mapped, mapped_column
from dataloader.storage.models.base import Base
@@ -16,29 +16,47 @@ class BriefDigitalCertificateOpu(Base):
    __tablename__ = "brief_digital_certificate_opu"
    __table_args__ = ({"schema": "opu"},)
    object_id: Mapped[str] = mapped_column(
        String, primary_key=True, server_default=text("'-'")
    )
    object_nm: Mapped[str | None] = mapped_column(String)
    desk_nm: Mapped[str] = mapped_column(
        String, primary_key=True, server_default=text("'-'")
    )
    actdate: Mapped[date] = mapped_column(
        Date, primary_key=True, server_default=text("CURRENT_DATE")
    )
    layer_cd: Mapped[str] = mapped_column(
        String, primary_key=True, server_default=text("'-'")
    )
    layer_nm: Mapped[str | None] = mapped_column(String)
    opu_cd: Mapped[str] = mapped_column(String, primary_key=True)
    opu_nm_sh: Mapped[str | None] = mapped_column(String)
    opu_nm: Mapped[str | None] = mapped_column(String)
    opu_lvl: Mapped[int] = mapped_column(
        Integer, primary_key=True, server_default=text("'-1'")
    )
    opu_prnt_cd: Mapped[str] = mapped_column(
        String, primary_key=True, server_default=text("'-'")
    )
    opu_prnt_nm_sh: Mapped[str | None] = mapped_column(String)
    opu_prnt_nm: Mapped[str | None] = mapped_column(String)
    sum_amountrub_p_usd: Mapped[float | None] = mapped_column(Numeric)
    wf_load_id: Mapped[int] = mapped_column(
        BigInteger, nullable=False, server_default=text("'-1'")
    )
    wf_load_dttm: Mapped[datetime] = mapped_column(
        TIMESTAMP(timezone=False),
        nullable=False,
        server_default=text("CURRENT_TIMESTAMP"),
    )
    wf_row_id: Mapped[int] = mapped_column(
        BigInteger, nullable=False, server_default=text("'-1'")
    )
    object_tp: Mapped[str | None] = mapped_column(String)
    object_unit: Mapped[str] = mapped_column(
        String, primary_key=True, server_default=text("'-'")
    )
    measure: Mapped[str | None] = mapped_column(String)
    product_nm: Mapped[str | None] = mapped_column(String)
    product_prnt_nm: Mapped[str | None] = mapped_column(String)

View File

@@ -1,4 +1,3 @@
-# src/dataloader/storage/models/queue.py
from __future__ import annotations
from datetime import datetime
@@ -10,7 +9,6 @@ from sqlalchemy.orm import Mapped, mapped_column
from .base import Base

dl_status_enum = ENUM(
    "queued",
    "running",
@@ -29,6 +27,7 @@ class DLJob(Base):
    Модель таблицы очереди задач dl_jobs.
    Использует логическое имя схемы 'queue' для поддержки schema_translate_map.
    """

    __tablename__ = "dl_jobs"
    __table_args__ = {"schema": "queue"}
@@ -40,15 +39,23 @@
    lock_key: Mapped[str] = mapped_column(Text, nullable=False)
    partition_key: Mapped[str] = mapped_column(Text, default="", nullable=False)
    priority: Mapped[int] = mapped_column(nullable=False, default=100)
    available_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False
    )
    status: Mapped[str] = mapped_column(
        dl_status_enum, nullable=False, default="queued"
    )
    attempt: Mapped[int] = mapped_column(nullable=False, default=0)
    max_attempts: Mapped[int] = mapped_column(nullable=False, default=5)
    lease_ttl_sec: Mapped[int] = mapped_column(nullable=False, default=60)
    lease_expires_at: Mapped[Optional[datetime]] = mapped_column(
        DateTime(timezone=True)
    )
    heartbeat_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True))
    cancel_requested: Mapped[bool] = mapped_column(nullable=False, default=False)
    progress: Mapped[dict[str, Any]] = mapped_column(
        JSONB, default=dict, nullable=False
    )
    error: Mapped[Optional[str]] = mapped_column(Text)
    producer: Mapped[Optional[str]] = mapped_column(Text)
    consumer_group: Mapped[Optional[str]] = mapped_column(Text)
@@ -62,10 +69,13 @@ class DLJobEvent(Base):
    Модель таблицы журнала событий dl_job_events.
    Использует логическое имя схемы 'queue' для поддержки schema_translate_map.
    """

    __tablename__ = "dl_job_events"
    __table_args__ = {"schema": "queue"}

    event_id: Mapped[int] = mapped_column(
        BigInteger, primary_key=True, autoincrement=True
    )
    job_id: Mapped[str] = mapped_column(UUID(as_uuid=False), nullable=False)
    queue: Mapped[str] = mapped_column(Text, nullable=False)
    load_dttm: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)

View File

@@ -5,7 +5,15 @@ from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import (
    JSON,
    TIMESTAMP,
    BigInteger,
    ForeignKey,
    String,
    UniqueConstraint,
    func,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship
from dataloader.storage.models.base import Base
@@ -30,7 +38,9 @@ class Quote(Base):
    srce: Mapped[str | None] = mapped_column(String)
    ticker: Mapped[str | None] = mapped_column(String)
    quote_sect_id: Mapped[int] = mapped_column(
        ForeignKey(
            "quotes.quotes_sect.quote_sect_id", ondelete="CASCADE", onupdate="CASCADE"
        ),
        nullable=False,
    )
    last_update_dttm: Mapped[datetime | None] = mapped_column(TIMESTAMP(timezone=True))

View File

@@ -5,7 +5,17 @@ from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import (
    TIMESTAMP,
    BigInteger,
    Boolean,
    DateTime,
    Float,
    ForeignKey,
    String,
    UniqueConstraint,
    func,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship
from dataloader.storage.models.base import Base

View File

@@ -1,16 +1,23 @@
-# src/dataloader/storage/notify_listener.py
from __future__ import annotations

import asyncio
-import asyncpg
from typing import Callable, Optional

+import asyncpg


class PGNotifyListener:
    """
    Прослушиватель PostgreSQL NOTIFY для канала 'dl_jobs'.
    """

    def __init__(
        self,
        dsn: str,
        queue: str,
        callback: Callable[[], None],
        stop_event: asyncio.Event,
    ):
        self._dsn = dsn
        self._queue = queue
        self._callback = callback

View File

@@ -1,8 +1,8 @@
-# src/dataloader/storage/repositories/__init__.py
"""
Репозитории для работы с базой данных.
Организованы по доменам для масштабируемости.
"""
from __future__ import annotations
from .opu import OpuRepository

View File

@@ -69,7 +69,6 @@ class OpuRepository:
        if not records:
            return 0

-        # Получаем колонки для обновления (все кроме PK и технических)
        update_columns = {
            c.name
            for c in BriefDigitalCertificateOpu.__table__.columns

View File

@@ -1,4 +1,3 @@
-# src/dataloader/storage/repositories/queue.py
from __future__ import annotations
from datetime import datetime, timedelta, timezone
@@ -15,6 +14,7 @@ class QueueRepository:
    """
    Репозиторий для работы с очередью задач и журналом событий.
    """

    def __init__(self, session: AsyncSession):
        self.s = session
@@ -62,7 +62,9 @@ class QueueRepository:
            finished_at=None,
        )
        self.s.add(row)
        await self._append_event(
            req.job_id, req.queue, "queued", {"task": req.task}
        )
        return req.job_id, "queued"

    async def get_status(self, job_id: str) -> Optional[JobStatus]:
@@ -118,7 +120,9 @@
        await self._append_event(job_id, job.queue, "cancel_requested", None)
        return True

    async def claim_one(
        self, queue: str, claim_backoff_sec: int
    ) -> Optional[dict[str, Any]]:
        """
        Захватывает одну задачу из очереди с учётом блокировок и выставляет running.
@@ -150,15 +154,21 @@
        job.started_at = job.started_at or datetime.now(timezone.utc)
        job.attempt = int(job.attempt) + 1
        job.heartbeat_at = datetime.now(timezone.utc)
        job.lease_expires_at = datetime.now(timezone.utc) + timedelta(
            seconds=int(job.lease_ttl_sec)
        )

        ok = await self._try_advisory_lock(job.lock_key)
        if not ok:
            job.status = "queued"
            job.available_at = datetime.now(timezone.utc) + timedelta(
                seconds=claim_backoff_sec
            )
            return None

        await self._append_event(
            job.job_id, job.queue, "picked", {"attempt": job.attempt}
        )

        return {
            "job_id": job.job_id,
@@ -192,10 +202,15 @@
        q = (
            update(DLJob)
            .where(DLJob.job_id == job_id, DLJob.status == "running")
            .values(
                heartbeat_at=now,
                lease_expires_at=now + timedelta(seconds=int(ttl_sec)),
            )
        )
        await self.s.execute(q)
        await self._append_event(
            job_id, await self._resolve_queue(job_id), "heartbeat", {"ttl": ttl_sec}
        )
        return True, cancel_requested

    async def finish_ok(self, job_id: str) -> None:
@@ -215,7 +230,9 @@
        await self._append_event(job_id, job.queue, "succeeded", None)
        await self._advisory_unlock(job.lock_key)

    async def finish_fail_or_retry(
        self, job_id: str, err: str, is_canceled: bool = False
    ) -> None:
        """
        Помечает задачу как failed, canceled или возвращает в очередь с задержкой.
@@ -239,16 +256,25 @@
        can_retry = int(job.attempt) < int(job.max_attempts)
        if can_retry:
            job.status = "queued"
            job.available_at = datetime.now(timezone.utc) + timedelta(
                seconds=30 * int(job.attempt)
            )
            job.error = err
            job.lease_expires_at = None
            await self._append_event(
                job_id,
                job.queue,
                "requeue",
                {"attempt": job.attempt, "error": err},
            )
        else:
            job.status = "failed"
            job.error = err
            job.finished_at = datetime.now(timezone.utc)
            job.lease_expires_at = None
            await self._append_event(
                job_id, job.queue, "failed", {"error": err}
            )
        await self._advisory_unlock(job.lock_key)

    async def requeue_lost(self, now: Optional[datetime] = None) -> list[str]:
@@ -293,7 +319,11 @@
        Возвращает:
            ORM модель DLJob или None
        """
        r = await self.s.execute(
            select(DLJob)
            .where(DLJob.job_id == job_id)
            .with_for_update(skip_locked=True)
        )
        return r.scalar_one_or_none()

    async def _resolve_queue(self, job_id: str) -> str:
@ -310,7 +340,9 @@ class QueueRepository:
v = r.scalar_one_or_none() v = r.scalar_one_or_none()
return v or "" return v or ""
async def _append_event(self, job_id: str, queue: str, kind: str, payload: Optional[dict[str, Any]]) -> None: async def _append_event(
self, job_id: str, queue: str, kind: str, payload: Optional[dict[str, Any]]
) -> None:
""" """
Добавляет запись в журнал событий. Добавляет запись в журнал событий.
@ -339,7 +371,9 @@ class QueueRepository:
Возвращает: Возвращает:
True, если блокировка получена True, если блокировка получена
""" """
r = await self.s.execute(select(func.pg_try_advisory_lock(func.hashtext(lock_key)))) r = await self.s.execute(
select(func.pg_try_advisory_lock(func.hashtext(lock_key)))
)
return bool(r.scalar()) return bool(r.scalar())
async def _advisory_unlock(self, lock_key: str) -> None: async def _advisory_unlock(self, lock_key: str) -> None:
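
The claim path above is guarded by session-level advisory locks keyed through hashtext(lock_key); pg_try_advisory_lock is non-blocking, so a worker that loses the race requeues the job with a backoff instead of waiting. Reduced to standalone helpers (pg_advisory_unlock is the assumed counterpart of the unlock method, whose body is not shown in this hunk):

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession


async def try_advisory_lock(s: AsyncSession, lock_key: str) -> bool:
    # SELECT pg_try_advisory_lock(hashtext(:lock_key)) -- returns true only for the first holder
    r = await s.execute(select(func.pg_try_advisory_lock(func.hashtext(lock_key))))
    return bool(r.scalar())


async def advisory_unlock(s: AsyncSession, lock_key: str) -> None:
    # Assumed counterpart: SELECT pg_advisory_unlock(hashtext(:lock_key)) releases the same key.
    await s.execute(select(func.pg_advisory_unlock(func.hashtext(lock_key))))
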


@ -39,7 +39,9 @@ class QuotesRepository:
result = await self.s.execute(stmt) result = await self.s.execute(stmt)
return result.scalar_one_or_none() return result.scalar_one_or_none()
async def get_or_create_section(self, name: str, params: dict | None = None) -> QuoteSection: async def get_or_create_section(
self, name: str, params: dict | None = None
) -> QuoteSection:
""" """
Получить существующую секцию или создать новую. Получить существующую секцию или создать новую.


@ -1,8 +1,8 @@
# src/dataloader/storage/schemas/__init__.py
""" """
DTO (Data Transfer Objects) для слоя хранилища. DTO (Data Transfer Objects) для слоя хранилища.
Организованы по доменам для масштабируемости. Организованы по доменам для масштабируемости.
""" """
from __future__ import annotations from __future__ import annotations
from .queue import CreateJobRequest, JobStatus from .queue import CreateJobRequest, JobStatus


@ -1,4 +1,3 @@
# src/dataloader/storage/schemas/queue.py
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
@ -11,6 +10,7 @@ class CreateJobRequest:
""" """
DTO для создания задачи в очереди. DTO для создания задачи в очереди.
""" """
job_id: str job_id: str
queue: str queue: str
task: str task: str
@ -31,6 +31,7 @@ class JobStatus:
""" """
DTO для статуса задачи. DTO для статуса задачи.
""" """
job_id: str job_id: str
status: str status: str
attempt: int attempt: int


@ -1,2 +1 @@
"""Модуль воркеров для обработки задач.""" """Модуль воркеров для обработки задач."""


@ -1,4 +1,3 @@
# src/dataloader/workers/base.py
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
@ -9,8 +8,8 @@ from typing import AsyncIterator, Callable, Optional
from dataloader.config import APP_CONFIG from dataloader.config import APP_CONFIG
from dataloader.context import APP_CTX from dataloader.context import APP_CTX
from dataloader.storage.repositories import QueueRepository
from dataloader.storage.notify_listener import PGNotifyListener from dataloader.storage.notify_listener import PGNotifyListener
from dataloader.storage.repositories import QueueRepository
from dataloader.workers.pipelines.registry import resolve as resolve_pipeline from dataloader.workers.pipelines.registry import resolve as resolve_pipeline
@ -19,6 +18,7 @@ class WorkerConfig:
""" """
Конфигурация воркера. Конфигурация воркера.
""" """
queue: str queue: str
heartbeat_sec: int heartbeat_sec: int
claim_backoff_sec: int claim_backoff_sec: int
@ -28,6 +28,7 @@ class PGWorker:
""" """
Базовый асинхронный воркер очереди Postgres. Базовый асинхронный воркер очереди Postgres.
""" """
def __init__(self, cfg: WorkerConfig, stop_event: asyncio.Event) -> None: def __init__(self, cfg: WorkerConfig, stop_event: asyncio.Event) -> None:
self._cfg = cfg self._cfg = cfg
self._stop = stop_event self._stop = stop_event
@ -46,12 +47,14 @@ class PGWorker:
dsn=APP_CONFIG.pg.url, dsn=APP_CONFIG.pg.url,
queue=self._cfg.queue, queue=self._cfg.queue,
callback=lambda: self._notify_wakeup.set(), callback=lambda: self._notify_wakeup.set(),
stop_event=self._stop stop_event=self._stop,
) )
try: try:
await self._listener.start() await self._listener.start()
except Exception as e: except Exception as e:
self._log.warning(f"Failed to start LISTEN/NOTIFY, falling back to polling: {e}") self._log.warning(
f"Failed to start LISTEN/NOTIFY, falling back to polling: {e}"
)
self._listener = None self._listener = None
try: try:
@ -69,24 +72,27 @@ class PGWorker:
Ожидание появления задач через LISTEN/NOTIFY или с тайм-аутом. Ожидание появления задач через LISTEN/NOTIFY или с тайм-аутом.
""" """
if self._listener: if self._listener:
# Используем LISTEN/NOTIFY с fallback на таймаут
done, pending = await asyncio.wait( done, pending = await asyncio.wait(
[asyncio.create_task(self._notify_wakeup.wait()), asyncio.create_task(self._stop.wait())], [
asyncio.create_task(self._notify_wakeup.wait()),
asyncio.create_task(self._stop.wait()),
],
return_when=asyncio.FIRST_COMPLETED, return_when=asyncio.FIRST_COMPLETED,
timeout=timeout_sec timeout=timeout_sec,
) )
# Отменяем оставшиеся задачи
for task in pending: for task in pending:
task.cancel() task.cancel()
try: try:
await task await task
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
# Очищаем событие, если оно было установлено
if self._notify_wakeup.is_set(): if self._notify_wakeup.is_set():
self._notify_wakeup.clear() self._notify_wakeup.clear()
else: else:
# Fallback на простой таймаут
try: try:
await asyncio.wait_for(self._stop.wait(), timeout=timeout_sec) await asyncio.wait_for(self._stop.wait(), timeout=timeout_sec)
except asyncio.TimeoutError: except asyncio.TimeoutError:
@ -110,30 +116,42 @@ class PGWorker:
args = row["args"] args = row["args"]
try: try:
canceled = await self._execute_with_heartbeat(job_id, ttl, self._pipeline(task, args)) canceled = await self._execute_with_heartbeat(
job_id, ttl, self._pipeline(task, args)
)
if canceled: if canceled:
await repo.finish_fail_or_retry(job_id, "canceled by user", is_canceled=True) await repo.finish_fail_or_retry(
job_id, "canceled by user", is_canceled=True
)
else: else:
await repo.finish_ok(job_id) await repo.finish_ok(job_id)
return True return True
except asyncio.CancelledError: except asyncio.CancelledError:
await repo.finish_fail_or_retry(job_id, "cancelled by shutdown", is_canceled=True) await repo.finish_fail_or_retry(
job_id, "cancelled by shutdown", is_canceled=True
)
raise raise
except Exception as e: except Exception as e:
await repo.finish_fail_or_retry(job_id, str(e)) await repo.finish_fail_or_retry(job_id, str(e))
return True return True
async def _execute_with_heartbeat(self, job_id: str, ttl: int, it: AsyncIterator[None]) -> bool: async def _execute_with_heartbeat(
self, job_id: str, ttl: int, it: AsyncIterator[None]
) -> bool:
""" """
Исполняет конвейер с поддержкой heartbeat. Исполняет конвейер с поддержкой heartbeat.
Возвращает True, если задача была отменена (cancel_requested). Возвращает True, если задача была отменена (cancel_requested).
""" """
next_hb = datetime.now(timezone.utc) + timedelta(seconds=self._cfg.heartbeat_sec) next_hb = datetime.now(timezone.utc) + timedelta(
seconds=self._cfg.heartbeat_sec
)
async for _ in it: async for _ in it:
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
if now >= next_hb: if now >= next_hb:
async with self._sm() as s_hb: async with self._sm() as s_hb:
success, cancel_requested = await QueueRepository(s_hb).heartbeat(job_id, ttl) success, cancel_requested = await QueueRepository(s_hb).heartbeat(
job_id, ttl
)
if cancel_requested: if cancel_requested:
return True return True
next_hb = now + timedelta(seconds=self._cfg.heartbeat_sec) next_hb = now + timedelta(seconds=self._cfg.heartbeat_sec)
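
In other words, a pipeline is just an async generator: every yield is a checkpoint at which _execute_with_heartbeat may send a heartbeat and will stop early if cancel_requested is set. A minimal sketch of that contract (hypothetical pipeline, not part of this commit):

from typing import AsyncIterator


async def my_pipeline(args: dict) -> AsyncIterator[None]:
    # Hypothetical pipeline used only to illustrate the yield/heartbeat contract.
    total = int(args.get("batches", 10))
    for _ in range(total):
        # ... process one batch of work ...
        yield  # checkpoint: the worker heartbeats here and can stop early on cancel
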


@ -1,4 +1,3 @@
# src/dataloader/workers/manager.py
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
@ -16,6 +15,7 @@ class WorkerSpec:
""" """
Конфигурация набора воркеров для очереди. Конфигурация набора воркеров для очереди.
""" """
queue: str queue: str
concurrency: int concurrency: int
@ -24,6 +24,7 @@ class WorkerManager:
""" """
Управляет жизненным циклом асинхронных воркеров. Управляет жизненным циклом асинхронных воркеров.
""" """
def __init__(self, specs: list[WorkerSpec]) -> None: def __init__(self, specs: list[WorkerSpec]) -> None:
self._log = APP_CTX.get_logger() self._log = APP_CTX.get_logger()
self._specs = specs self._specs = specs
@ -40,15 +41,22 @@ class WorkerManager:
for spec in self._specs: for spec in self._specs:
for i in range(max(1, spec.concurrency)): for i in range(max(1, spec.concurrency)):
cfg = WorkerConfig(queue=spec.queue, heartbeat_sec=hb, claim_backoff_sec=backoff) cfg = WorkerConfig(
t = asyncio.create_task(PGWorker(cfg, self._stop).run(), name=f"worker:{spec.queue}:{i}") queue=spec.queue, heartbeat_sec=hb, claim_backoff_sec=backoff
)
t = asyncio.create_task(
PGWorker(cfg, self._stop).run(), name=f"worker:{spec.queue}:{i}"
)
self._tasks.append(t) self._tasks.append(t)
self._reaper_task = asyncio.create_task(self._reaper_loop(), name="reaper") self._reaper_task = asyncio.create_task(self._reaper_loop(), name="reaper")
self._log.info( self._log.info(
"worker_manager.started", "worker_manager.started",
extra={"specs": [spec.__dict__ for spec in self._specs], "total_tasks": len(self._tasks)}, extra={
"specs": [spec.__dict__ for spec in self._specs],
"total_tasks": len(self._tasks),
},
) )
async def stop(self) -> None: async def stop(self) -> None:
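
A rough sketch of driving the manager by hand; in the application it is built via build_manager_from_env inside the FastAPI lifespan. The queue name, the concurrency and the assumption that start() is awaitable like stop() are illustrative:

import asyncio

from dataloader.workers.manager import WorkerManager, WorkerSpec


async def run_workers() -> None:
    manager = WorkerManager([WorkerSpec(queue="load.opu", concurrency=2)])  # illustrative spec
    await manager.start()  # assumes start() is a coroutine, mirroring stop()
    try:
        await asyncio.Event().wait()  # park until the surrounding app shuts down
    finally:
        await manager.stop()
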


@ -1,6 +1,5 @@
"""Модуль пайплайнов обработки задач.""" """Модуль пайплайнов обработки задач."""
# src/dataloader/workers/pipelines/__init__.py
from __future__ import annotations from __future__ import annotations
import importlib import importlib


@ -59,7 +59,6 @@ def _parse_jsonl_from_zst(file_path: Path, chunk_size: int = 10000):
APP_CTX.logger.warning(f"Failed to parse JSON line: {e}") APP_CTX.logger.warning(f"Failed to parse JSON line: {e}")
continue continue
# Process remaining buffer
if buffer.strip(): if buffer.strip():
try: try:
record = orjson.loads(buffer) record = orjson.loads(buffer)
@ -85,11 +84,9 @@ def _convert_record(raw: dict[str, Any]) -> dict[str, Any]:
""" """
result = raw.copy() result = raw.copy()
# Преобразуем actdate из ISO строки в date
if "actdate" in result and isinstance(result["actdate"], str): if "actdate" in result and isinstance(result["actdate"], str):
result["actdate"] = datetime.fromisoformat(result["actdate"]).date() result["actdate"] = datetime.fromisoformat(result["actdate"]).date()
# Преобразуем wf_load_dttm из ISO строки в datetime
if "wf_load_dttm" in result and isinstance(result["wf_load_dttm"], str): if "wf_load_dttm" in result and isinstance(result["wf_load_dttm"], str):
result["wf_load_dttm"] = datetime.fromisoformat(result["wf_load_dttm"]) result["wf_load_dttm"] = datetime.fromisoformat(result["wf_load_dttm"])
@ -118,18 +115,15 @@ async def load_opu(args: dict) -> AsyncIterator[None]:
logger = APP_CTX.logger logger = APP_CTX.logger
logger.info("Starting OPU ETL pipeline") logger.info("Starting OPU ETL pipeline")
# Шаг 1: Запуск экспорта
interface = get_gmap2brief_interface() interface = get_gmap2brief_interface()
job_id = await interface.start_export() job_id = await interface.start_export()
logger.info(f"OPU export job started: {job_id}") logger.info(f"OPU export job started: {job_id}")
yield yield
# Шаг 2: Ожидание завершения
status = await interface.wait_for_completion(job_id) status = await interface.wait_for_completion(job_id)
logger.info(f"OPU export completed: {status.total_rows} rows") logger.info(f"OPU export completed: {status.total_rows} rows")
yield yield
# Шаг 3: Скачивание архива
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir) temp_path = Path(temp_dir)
archive_path = temp_path / f"opu_export_{job_id}.jsonl.zst" archive_path = temp_path / f"opu_export_{job_id}.jsonl.zst"
@ -138,7 +132,6 @@ async def load_opu(args: dict) -> AsyncIterator[None]:
logger.info(f"OPU archive downloaded: {archive_path.stat().st_size:,} bytes") logger.info(f"OPU archive downloaded: {archive_path.stat().st_size:,} bytes")
yield yield
# Шаг 4: Truncate таблицы
async with APP_CTX.sessionmaker() as session: async with APP_CTX.sessionmaker() as session:
repo = OpuRepository(session) repo = OpuRepository(session)
await repo.truncate() await repo.truncate()
@ -146,7 +139,6 @@ async def load_opu(args: dict) -> AsyncIterator[None]:
logger.info("OPU table truncated") logger.info("OPU table truncated")
yield yield
# Шаг 5: Загрузка данных стримингово
total_inserted = 0 total_inserted = 0
batch_num = 0 batch_num = 0
@ -154,17 +146,16 @@ async def load_opu(args: dict) -> AsyncIterator[None]:
for batch in _parse_jsonl_from_zst(archive_path, chunk_size=5000): for batch in _parse_jsonl_from_zst(archive_path, chunk_size=5000):
batch_num += 1 batch_num += 1
# Конвертируем записи
converted = [_convert_record(rec) for rec in batch] converted = [_convert_record(rec) for rec in batch]
# Вставляем батч
inserted = await repo.bulk_insert(converted) inserted = await repo.bulk_insert(converted)
await session.commit() await session.commit()
total_inserted += inserted total_inserted += inserted
logger.debug(f"Batch {batch_num}: inserted {inserted} rows (total: {total_inserted})") logger.debug(
f"Batch {batch_num}: inserted {inserted} rows (total: {total_inserted})"
)
# Heartbeat после каждого батча
yield yield
logger.info(f"OPU ETL completed: {total_inserted} rows inserted") logger.info(f"OPU ETL completed: {total_inserted} rows inserted")


@ -59,7 +59,9 @@ def _parse_ts_to_datetime(ts: str) -> datetime | None:
return None return None
def _build_value_row(source: str, dt: datetime, point: Any) -> dict[str, Any] | None: # noqa: C901 def _build_value_row(
source: str, dt: datetime, point: Any
) -> dict[str, Any] | None: # noqa: C901
"""Строит строку для `quotes_values` по источнику и типу точки.""" """Строит строку для `quotes_values` по источнику и типу точки."""
if isinstance(point, int): if isinstance(point, int):
return {"dt": dt, "key": point} return {"dt": dt, "key": point}
@ -82,7 +84,10 @@ def _build_value_row(source: str, dt: datetime, point: Any) -> dict[str, Any] |
if isinstance(deep_inner, InvestingCandlestick): if isinstance(deep_inner, InvestingCandlestick):
return { return {
"dt": dt, "dt": dt,
"price_o": _to_float(getattr(deep_inner, "open_", None) or getattr(deep_inner, "open", None)), "price_o": _to_float(
getattr(deep_inner, "open_", None)
or getattr(deep_inner, "open", None)
),
"price_h": _to_float(deep_inner.high), "price_h": _to_float(deep_inner.high),
"price_l": _to_float(deep_inner.low), "price_l": _to_float(deep_inner.low),
"price_c": _to_float(deep_inner.close), "price_c": _to_float(deep_inner.close),
@ -92,12 +97,16 @@ def _build_value_row(source: str, dt: datetime, point: Any) -> dict[str, Any] |
if isinstance(inner, TradingViewTimePoint | SgxTimePoint): if isinstance(inner, TradingViewTimePoint | SgxTimePoint):
return { return {
"dt": dt, "dt": dt,
"price_o": _to_float(getattr(inner, "open_", None) or getattr(inner, "open", None)), "price_o": _to_float(
getattr(inner, "open_", None) or getattr(inner, "open", None)
),
"price_h": _to_float(inner.high), "price_h": _to_float(inner.high),
"price_l": _to_float(inner.low), "price_l": _to_float(inner.low),
"price_c": _to_float(inner.close), "price_c": _to_float(inner.close),
"volume": _to_float( "volume": _to_float(
getattr(inner, "volume", None) or getattr(inner, "interest", None) or getattr(inner, "value", None) getattr(inner, "volume", None)
or getattr(inner, "interest", None)
or getattr(inner, "value", None)
), ),
} }
@ -132,7 +141,9 @@ def _build_value_row(source: str, dt: datetime, point: Any) -> dict[str, Any] |
"dt": dt, "dt": dt,
"value_last": _to_float(deep_inner.last), "value_last": _to_float(deep_inner.last),
"value_previous": _to_float(deep_inner.previous), "value_previous": _to_float(deep_inner.previous),
"unit": str(deep_inner.unit) if deep_inner.unit is not None else None, "unit": (
str(deep_inner.unit) if deep_inner.unit is not None else None
),
} }
if isinstance(deep_inner, TradingEconomicsStringPercent): if isinstance(deep_inner, TradingEconomicsStringPercent):
@ -218,7 +229,14 @@ async def load_tenera(args: dict) -> AsyncIterator[None]:
async with APP_CTX.sessionmaker() as session: async with APP_CTX.sessionmaker() as session:
repo = QuotesRepository(session) repo = QuotesRepository(session)
for source_name in ("cbr", "investing", "sgx", "tradingeconomics", "bloomberg", "trading_view"): for source_name in (
"cbr",
"investing",
"sgx",
"tradingeconomics",
"bloomberg",
"trading_view",
):
source_data = getattr(data, source_name) source_data = getattr(data, source_name)
if not source_data: if not source_data:
continue continue


@ -1,4 +1,3 @@
# src/dataloader/workers/pipelines/noop.py
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio


@ -1,4 +1,3 @@
# src/dataloader/workers/pipelines/registry.py
from __future__ import annotations from __future__ import annotations
from typing import Any, Callable, Dict, Iterable from typing import Any, Callable, Dict, Iterable
@ -6,13 +5,17 @@ from typing import Any, Callable, Dict, Iterable
_Registry: Dict[str, Callable[[dict[str, Any]], Any]] = {} _Registry: Dict[str, Callable[[dict[str, Any]], Any]] = {}
def register(task: str) -> Callable[[Callable[[dict[str, Any]], Any]], Callable[[dict[str, Any]], Any]]: def register(
task: str,
) -> Callable[[Callable[[dict[str, Any]], Any]], Callable[[dict[str, Any]], Any]]:
""" """
Регистрирует обработчик пайплайна под именем задачи. Регистрирует обработчик пайплайна под именем задачи.
""" """
def _wrap(fn: Callable[[dict[str, Any]], Any]) -> Callable[[dict[str, Any]], Any]: def _wrap(fn: Callable[[dict[str, Any]], Any]) -> Callable[[dict[str, Any]], Any]:
_Registry[task] = fn _Registry[task] = fn
return fn return fn
return _wrap return _wrap
@ -22,8 +25,8 @@ def resolve(task: str) -> Callable[[dict[str, Any]], Any]:
""" """
try: try:
return _Registry[task] return _Registry[task]
except KeyError: except KeyError as err:
raise KeyError(f"pipeline not found: {task}") raise KeyError(f"pipeline not found: {task}") from err
def tasks() -> Iterable[str]: def tasks() -> Iterable[str]:
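
For reference, how the registry is meant to be used; the task name below is made up, while real pipelines such as load_opu register themselves when the pipelines package loads its modules:

from dataloader.workers.pipelines.registry import register, resolve, tasks


@register("demo.echo")  # hypothetical task name
async def demo_echo(args: dict):
    yield


assert "demo.echo" in set(tasks())
handler = resolve("demo.echo")  # returns demo_echo
# resolve("unknown") raises KeyError("pipeline not found: unknown"),
# now chained from the underlying lookup error ("raise ... from err").
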


@ -1,7 +1,7 @@
# src/dataloader/workers/reaper.py
from __future__ import annotations from __future__ import annotations
from typing import Sequence from typing import Sequence
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from dataloader.storage.repositories import QueueRepository from dataloader.storage.repositories import QueueRepository


@ -1,4 +1,3 @@
# tests/conftest.py
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
@ -9,23 +8,23 @@ from uuid import uuid4
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from dotenv import load_dotenv from dotenv import load_dotenv
from httpx import AsyncClient, ASGITransport from httpx import ASGITransport, AsyncClient
from sqlalchemy import text from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
load_dotenv()
from dataloader.api import app_main from dataloader.api import app_main
from dataloader.config import APP_CONFIG from dataloader.config import APP_CONFIG
from dataloader.context import APP_CTX, get_session from dataloader.context import get_session
from dataloader.storage.models import Base from dataloader.storage.engine import create_engine
from dataloader.storage.engine import create_engine, create_sessionmaker
load_dotenv()
if sys.platform == "win32": if sys.platform == "win32":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
pytestmark = pytest.mark.asyncio pytestmark = pytest.mark.asyncio
@pytest_asyncio.fixture(scope="function") @pytest_asyncio.fixture(scope="function")
async def db_engine() -> AsyncGenerator[AsyncEngine, None]: async def db_engine() -> AsyncGenerator[AsyncEngine, None]:
""" """
@ -39,15 +38,15 @@ async def db_engine() -> AsyncGenerator[AsyncEngine, None]:
await engine.dispose() await engine.dispose()
@pytest_asyncio.fixture(scope="function") @pytest_asyncio.fixture(scope="function")
async def db_session(db_engine: AsyncEngine) -> AsyncGenerator[AsyncSession, None]: async def db_session(db_engine: AsyncEngine) -> AsyncGenerator[AsyncSession, None]:
""" """
Предоставляет сессию БД для каждого теста. Предоставляет сессию БД для каждого теста.
НЕ использует транзакцию, чтобы работали advisory locks. НЕ использует транзакцию, чтобы работали advisory locks.
""" """
sessionmaker = async_sessionmaker(bind=db_engine, expire_on_commit=False, class_=AsyncSession) sessionmaker = async_sessionmaker(
bind=db_engine, expire_on_commit=False, class_=AsyncSession
)
async with sessionmaker() as session: async with sessionmaker() as session:
yield session yield session
await session.rollback() await session.rollback()
@ -69,6 +68,7 @@ async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
""" """
HTTP клиент для тестирования API. HTTP клиент для тестирования API.
""" """
async def override_get_session() -> AsyncGenerator[AsyncSession, None]: async def override_get_session() -> AsyncGenerator[AsyncSession, None]:
yield db_session yield db_session


@ -1 +1 @@
# tests/integration_tests/__init__.py


@ -1,4 +1,3 @@
# tests/integration_tests/test_api_endpoints.py
from __future__ import annotations from __future__ import annotations
import pytest import pytest


@ -1,11 +1,11 @@
# tests/unit/test_api_router_not_found.py
from __future__ import annotations from __future__ import annotations
import pytest from uuid import UUID, uuid4
from uuid import uuid4, UUID
import pytest
from dataloader.api.v1.router import get_status, cancel_job
from dataloader.api.v1.exceptions import JobNotFoundError from dataloader.api.v1.exceptions import JobNotFoundError
from dataloader.api.v1.router import cancel_job, get_status
from dataloader.api.v1.schemas import JobStatusResponse from dataloader.api.v1.schemas import JobStatusResponse


@ -1,11 +1,11 @@
# tests/unit/test_api_router_success.py
from __future__ import annotations from __future__ import annotations
import pytest
from uuid import uuid4, UUID
from datetime import datetime, timezone from datetime import datetime, timezone
from uuid import UUID, uuid4
from dataloader.api.v1.router import get_status, cancel_job import pytest
from dataloader.api.v1.router import cancel_job, get_status
from dataloader.api.v1.schemas import JobStatusResponse from dataloader.api.v1.schemas import JobStatusResponse


@ -1,7 +1,6 @@
# tests/integration_tests/test_queue_repository.py
from __future__ import annotations from __future__ import annotations
from datetime import datetime, timezone, timedelta from datetime import datetime, timedelta, timezone
from uuid import uuid4 from uuid import uuid4
import pytest import pytest
@ -10,7 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from dataloader.storage.models import DLJob from dataloader.storage.models import DLJob
from dataloader.storage.repositories import QueueRepository from dataloader.storage.repositories import QueueRepository
from dataloader.storage.schemas import CreateJobRequest, JobStatus from dataloader.storage.schemas import CreateJobRequest
@pytest.mark.integration @pytest.mark.integration
@ -445,6 +444,7 @@ class TestQueueRepository:
await repo.claim_one(queue_name, claim_backoff_sec=15) await repo.claim_one(queue_name, claim_backoff_sec=15)
import asyncio import asyncio
await asyncio.sleep(2) await asyncio.sleep(2)
requeued = await repo.requeue_lost() requeued = await repo.requeue_lost()
@ -500,7 +500,9 @@ class TestQueueRepository:
assert st is not None assert st is not None
assert st.status == "queued" assert st.status == "queued"
row = (await db_session.execute(select(DLJob).where(DLJob.job_id == job_id))).scalar_one() row = (
await db_session.execute(select(DLJob).where(DLJob.job_id == job_id))
).scalar_one()
assert row.available_at >= before + timedelta(seconds=15) assert row.available_at >= before + timedelta(seconds=15)
assert row.available_at <= after + timedelta(seconds=60) assert row.available_at <= after + timedelta(seconds=60)
@ -569,7 +571,9 @@ class TestQueueRepository:
await repo.create_or_get(req) await repo.create_or_get(req)
await repo.claim_one(queue_name, claim_backoff_sec=5) await repo.claim_one(queue_name, claim_backoff_sec=5)
await repo.finish_fail_or_retry(job_id, err="Canceled by test", is_canceled=True) await repo.finish_fail_or_retry(
job_id, err="Canceled by test", is_canceled=True
)
st = await repo.get_status(job_id) st = await repo.get_status(job_id)
assert st is not None assert st is not None


@ -1 +1 @@
# tests/unit/__init__.py


@ -1,13 +1,13 @@
# tests/unit/test_api_service.py
from __future__ import annotations from __future__ import annotations
from datetime import datetime, timezone from datetime import datetime, timezone
from uuid import UUID
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch
from uuid import UUID
import pytest import pytest
from dataloader.api.v1.service import JobsService
from dataloader.api.v1.schemas import TriggerJobRequest from dataloader.api.v1.schemas import TriggerJobRequest
from dataloader.api.v1.service import JobsService
from dataloader.storage.schemas import JobStatus from dataloader.storage.schemas import JobStatus
@ -38,15 +38,19 @@ class TestJobsService:
""" """
mock_session = AsyncMock() mock_session = AsyncMock()
with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \ with (
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls, \ patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id: patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id,
):
mock_get_logger.return_value = Mock() mock_get_logger.return_value = Mock()
mock_new_job_id.return_value = UUID("12345678-1234-5678-1234-567812345678") mock_new_job_id.return_value = UUID("12345678-1234-5678-1234-567812345678")
mock_repo = Mock() mock_repo = Mock()
mock_repo.create_or_get = AsyncMock(return_value=("12345678-1234-5678-1234-567812345678", "queued")) mock_repo.create_or_get = AsyncMock(
return_value=("12345678-1234-5678-1234-567812345678", "queued")
)
mock_repo_cls.return_value = mock_repo mock_repo_cls.return_value = mock_repo
service = JobsService(mock_session) service = JobsService(mock_session)
@ -58,7 +62,7 @@ class TestJobsService:
lock_key="lock_1", lock_key="lock_1",
priority=100, priority=100,
max_attempts=5, max_attempts=5,
lease_ttl_sec=60 lease_ttl_sec=60,
) )
response = await service.trigger(req) response = await service.trigger(req)
@ -74,15 +78,19 @@ class TestJobsService:
""" """
mock_session = AsyncMock() mock_session = AsyncMock()
with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \ with (
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls, \ patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id: patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id,
):
mock_get_logger.return_value = Mock() mock_get_logger.return_value = Mock()
mock_new_job_id.return_value = UUID("12345678-1234-5678-1234-567812345678") mock_new_job_id.return_value = UUID("12345678-1234-5678-1234-567812345678")
mock_repo = Mock() mock_repo = Mock()
mock_repo.create_or_get = AsyncMock(return_value=("12345678-1234-5678-1234-567812345678", "queued")) mock_repo.create_or_get = AsyncMock(
return_value=("12345678-1234-5678-1234-567812345678", "queued")
)
mock_repo_cls.return_value = mock_repo mock_repo_cls.return_value = mock_repo
service = JobsService(mock_session) service = JobsService(mock_session)
@ -95,7 +103,7 @@ class TestJobsService:
lock_key="lock_1", lock_key="lock_1",
priority=100, priority=100,
max_attempts=5, max_attempts=5,
lease_ttl_sec=60 lease_ttl_sec=60,
) )
response = await service.trigger(req) response = await service.trigger(req)
@ -112,15 +120,19 @@ class TestJobsService:
""" """
mock_session = AsyncMock() mock_session = AsyncMock()
with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \ with (
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls, \ patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id: patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id,
):
mock_get_logger.return_value = Mock() mock_get_logger.return_value = Mock()
mock_new_job_id.return_value = UUID("12345678-1234-5678-1234-567812345678") mock_new_job_id.return_value = UUID("12345678-1234-5678-1234-567812345678")
mock_repo = Mock() mock_repo = Mock()
mock_repo.create_or_get = AsyncMock(return_value=("12345678-1234-5678-1234-567812345678", "queued")) mock_repo.create_or_get = AsyncMock(
return_value=("12345678-1234-5678-1234-567812345678", "queued")
)
mock_repo_cls.return_value = mock_repo mock_repo_cls.return_value = mock_repo
service = JobsService(mock_session) service = JobsService(mock_session)
@ -135,10 +147,10 @@ class TestJobsService:
available_at=future_time, available_at=future_time,
priority=100, priority=100,
max_attempts=5, max_attempts=5,
lease_ttl_sec=60 lease_ttl_sec=60,
) )
response = await service.trigger(req) await service.trigger(req)
call_args = mock_repo.create_or_get.call_args[0][0] call_args = mock_repo.create_or_get.call_args[0][0]
assert call_args.available_at == future_time assert call_args.available_at == future_time
@ -150,15 +162,19 @@ class TestJobsService:
""" """
mock_session = AsyncMock() mock_session = AsyncMock()
with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \ with (
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls, \ patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id: patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id,
):
mock_get_logger.return_value = Mock() mock_get_logger.return_value = Mock()
mock_new_job_id.return_value = UUID("12345678-1234-5678-1234-567812345678") mock_new_job_id.return_value = UUID("12345678-1234-5678-1234-567812345678")
mock_repo = Mock() mock_repo = Mock()
mock_repo.create_or_get = AsyncMock(return_value=("12345678-1234-5678-1234-567812345678", "queued")) mock_repo.create_or_get = AsyncMock(
return_value=("12345678-1234-5678-1234-567812345678", "queued")
)
mock_repo_cls.return_value = mock_repo mock_repo_cls.return_value = mock_repo
service = JobsService(mock_session) service = JobsService(mock_session)
@ -173,10 +189,10 @@ class TestJobsService:
consumer_group="test_group", consumer_group="test_group",
priority=100, priority=100,
max_attempts=5, max_attempts=5,
lease_ttl_sec=60 lease_ttl_sec=60,
) )
response = await service.trigger(req) await service.trigger(req)
call_args = mock_repo.create_or_get.call_args[0][0] call_args = mock_repo.create_or_get.call_args[0][0]
assert call_args.partition_key == "partition_1" assert call_args.partition_key == "partition_1"
@ -190,8 +206,10 @@ class TestJobsService:
""" """
mock_session = AsyncMock() mock_session = AsyncMock()
with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \ with (
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls: patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
):
mock_get_logger.return_value = Mock() mock_get_logger.return_value = Mock()
@ -204,7 +222,7 @@ class TestJobsService:
finished_at=None, finished_at=None,
heartbeat_at=datetime(2025, 1, 1, 12, 5, 0, tzinfo=timezone.utc), heartbeat_at=datetime(2025, 1, 1, 12, 5, 0, tzinfo=timezone.utc),
error=None, error=None,
progress={"step": 1} progress={"step": 1},
) )
mock_repo.get_status = AsyncMock(return_value=mock_status) mock_repo.get_status = AsyncMock(return_value=mock_status)
mock_repo_cls.return_value = mock_repo mock_repo_cls.return_value = mock_repo
@ -227,8 +245,10 @@ class TestJobsService:
""" """
mock_session = AsyncMock() mock_session = AsyncMock()
with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \ with (
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls: patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
):
mock_get_logger.return_value = Mock() mock_get_logger.return_value = Mock()
@ -250,8 +270,10 @@ class TestJobsService:
""" """
mock_session = AsyncMock() mock_session = AsyncMock()
with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \ with (
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls: patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
):
mock_get_logger.return_value = Mock() mock_get_logger.return_value = Mock()
@ -265,7 +287,7 @@ class TestJobsService:
finished_at=None, finished_at=None,
heartbeat_at=datetime(2025, 1, 1, 12, 5, 0, tzinfo=timezone.utc), heartbeat_at=datetime(2025, 1, 1, 12, 5, 0, tzinfo=timezone.utc),
error=None, error=None,
progress={} progress={},
) )
mock_repo.get_status = AsyncMock(return_value=mock_status) mock_repo.get_status = AsyncMock(return_value=mock_status)
mock_repo_cls.return_value = mock_repo mock_repo_cls.return_value = mock_repo
@ -287,8 +309,10 @@ class TestJobsService:
""" """
mock_session = AsyncMock() mock_session = AsyncMock()
with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \ with (
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls: patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
):
mock_get_logger.return_value = Mock() mock_get_logger.return_value = Mock()
@ -312,8 +336,10 @@ class TestJobsService:
""" """
mock_session = AsyncMock() mock_session = AsyncMock()
with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \ with (
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls: patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
):
mock_get_logger.return_value = Mock() mock_get_logger.return_value = Mock()
@ -326,7 +352,7 @@ class TestJobsService:
finished_at=None, finished_at=None,
heartbeat_at=None, heartbeat_at=None,
error=None, error=None,
progress=None progress=None,
) )
mock_repo.get_status = AsyncMock(return_value=mock_status) mock_repo.get_status = AsyncMock(return_value=mock_status)
mock_repo_cls.return_value = mock_repo mock_repo_cls.return_value = mock_repo


@ -1,18 +1,18 @@
# tests/unit/test_config.py
from __future__ import annotations from __future__ import annotations
import json import json
from logging import DEBUG, INFO from logging import DEBUG, INFO
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from dataloader.config import ( from dataloader.config import (
BaseAppSettings,
AppSettings, AppSettings,
BaseAppSettings,
LogSettings, LogSettings,
PGSettings, PGSettings,
WorkerSettings,
Secrets, Secrets,
WorkerSettings,
) )
@ -81,12 +81,15 @@ class TestAppSettings:
""" """
Тест загрузки из переменных окружения. Тест загрузки из переменных окружения.
""" """
with patch.dict("os.environ", { with patch.dict(
"APP_HOST": "127.0.0.1", "os.environ",
"APP_PORT": "9000", {
"PROJECT_NAME": "TestProject", "APP_HOST": "127.0.0.1",
"TIMEZONE": "UTC" "APP_PORT": "9000",
}): "PROJECT_NAME": "TestProject",
"TIMEZONE": "UTC",
},
):
settings = AppSettings() settings = AppSettings()
assert settings.app_host == "127.0.0.1" assert settings.app_host == "127.0.0.1"
@ -136,7 +139,9 @@ class TestLogSettings:
""" """
Тест свойства log_file_abs_path. Тест свойства log_file_abs_path.
""" """
with patch.dict("os.environ", {"LOG_PATH": "/var/log", "LOG_FILE_NAME": "test.log"}): with patch.dict(
"os.environ", {"LOG_PATH": "/var/log", "LOG_FILE_NAME": "test.log"}
):
settings = LogSettings() settings = LogSettings()
assert "test.log" in settings.log_file_abs_path assert "test.log" in settings.log_file_abs_path
@ -146,7 +151,10 @@ class TestLogSettings:
""" """
Тест свойства metric_file_abs_path. Тест свойства metric_file_abs_path.
""" """
with patch.dict("os.environ", {"METRIC_PATH": "/var/metrics", "METRIC_FILE_NAME": "metrics.log"}): with patch.dict(
"os.environ",
{"METRIC_PATH": "/var/metrics", "METRIC_FILE_NAME": "metrics.log"},
):
settings = LogSettings() settings = LogSettings()
assert "metrics.log" in settings.metric_file_abs_path assert "metrics.log" in settings.metric_file_abs_path
@ -156,7 +164,10 @@ class TestLogSettings:
""" """
Тест свойства audit_file_abs_path. Тест свойства audit_file_abs_path.
""" """
with patch.dict("os.environ", {"AUDIT_LOG_PATH": "/var/audit", "AUDIT_LOG_FILE_NAME": "audit.log"}): with patch.dict(
"os.environ",
{"AUDIT_LOG_PATH": "/var/audit", "AUDIT_LOG_FILE_NAME": "audit.log"},
):
settings = LogSettings() settings = LogSettings()
assert "audit.log" in settings.audit_file_abs_path assert "audit.log" in settings.audit_file_abs_path
@ -208,29 +219,37 @@ class TestPGSettings:
""" """
Тест формирования строки подключения. Тест формирования строки подключения.
""" """
with patch.dict("os.environ", { with patch.dict(
"PG_HOST": "db.example.com", "os.environ",
"PG_PORT": "5433", {
"PG_USER": "testuser", "PG_HOST": "db.example.com",
"PG_PASSWORD": "testpass", "PG_PORT": "5433",
"PG_DATABASE": "testdb" "PG_USER": "testuser",
}): "PG_PASSWORD": "testpass",
"PG_DATABASE": "testdb",
},
):
settings = PGSettings() settings = PGSettings()
expected = "postgresql+asyncpg://testuser:testpass@db.example.com:5433/testdb" expected = (
"postgresql+asyncpg://testuser:testpass@db.example.com:5433/testdb"
)
assert settings.url == expected assert settings.url == expected
def test_url_property_with_empty_password(self): def test_url_property_with_empty_password(self):
""" """
Тест строки подключения с пустым паролем. Тест строки подключения с пустым паролем.
""" """
with patch.dict("os.environ", { with patch.dict(
"PG_HOST": "localhost", "os.environ",
"PG_PORT": "5432", {
"PG_USER": "postgres", "PG_HOST": "localhost",
"PG_PASSWORD": "", "PG_PORT": "5432",
"PG_DATABASE": "testdb" "PG_USER": "postgres",
}): "PG_PASSWORD": "",
"PG_DATABASE": "testdb",
},
):
settings = PGSettings() settings = PGSettings()
expected = "postgresql+asyncpg://postgres:@localhost:5432/testdb" expected = "postgresql+asyncpg://postgres:@localhost:5432/testdb"
@ -240,15 +259,18 @@ class TestPGSettings:
""" """
Тест загрузки из переменных окружения. Тест загрузки из переменных окружения.
""" """
with patch.dict("os.environ", { with patch.dict(
"PG_HOST": "testhost", "os.environ",
"PG_PORT": "5433", {
"PG_USER": "testuser", "PG_HOST": "testhost",
"PG_PASSWORD": "testpass", "PG_PORT": "5433",
"PG_DATABASE": "testdb", "PG_USER": "testuser",
"PG_SCHEMA_QUEUE": "queue_schema", "PG_PASSWORD": "testpass",
"PG_POOL_SIZE": "20" "PG_DATABASE": "testdb",
}): "PG_SCHEMA_QUEUE": "queue_schema",
"PG_POOL_SIZE": "20",
},
):
settings = PGSettings() settings = PGSettings()
assert settings.host == "testhost" assert settings.host == "testhost"
@ -292,10 +314,12 @@ class TestWorkerSettings:
""" """
Тест парсинга валидного JSON. Тест парсинга валидного JSON.
""" """
workers_json = json.dumps([ workers_json = json.dumps(
{"queue": "queue1", "concurrency": 2}, [
{"queue": "queue2", "concurrency": 3} {"queue": "queue1", "concurrency": 2},
]) {"queue": "queue2", "concurrency": 3},
]
)
with patch.dict("os.environ", {"WORKERS_JSON": workers_json}): with patch.dict("os.environ", {"WORKERS_JSON": workers_json}):
settings = WorkerSettings() settings = WorkerSettings()
@ -311,12 +335,14 @@ class TestWorkerSettings:
""" """
Тест фильтрации не-словарей из JSON. Тест фильтрации не-словарей из JSON.
""" """
workers_json = json.dumps([ workers_json = json.dumps(
{"queue": "queue1", "concurrency": 2}, [
"invalid_item", {"queue": "queue1", "concurrency": 2},
123, "invalid_item",
{"queue": "queue2", "concurrency": 3} 123,
]) {"queue": "queue2", "concurrency": 3},
]
)
with patch.dict("os.environ", {"WORKERS_JSON": workers_json}): with patch.dict("os.environ", {"WORKERS_JSON": workers_json}):
settings = WorkerSettings() settings = WorkerSettings()


@ -1,7 +1,7 @@
# tests/unit/test_context.py
from __future__ import annotations from __future__ import annotations
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch
import pytest import pytest
from dataloader.context import AppContext, get_session from dataloader.context import AppContext, get_session
@ -71,10 +71,16 @@ class TestAppContext:
mock_engine = Mock() mock_engine = Mock()
mock_sm = Mock() mock_sm = Mock()
with patch("dataloader.logger.logger.setup_logging") as mock_setup_logging, \ with (
patch("dataloader.storage.engine.create_engine", return_value=mock_engine) as mock_create_engine, \ patch("dataloader.logger.logger.setup_logging") as mock_setup_logging,
patch("dataloader.storage.engine.create_sessionmaker", return_value=mock_sm) as mock_create_sm, \ patch(
patch("dataloader.context.APP_CONFIG") as mock_config: "dataloader.storage.engine.create_engine", return_value=mock_engine
) as mock_create_engine,
patch(
"dataloader.storage.engine.create_sessionmaker", return_value=mock_sm
) as mock_create_sm,
patch("dataloader.context.APP_CONFIG") as mock_config,
):
mock_config.pg.url = "postgresql://test" mock_config.pg.url = "postgresql://test"
@ -194,7 +200,7 @@ class TestGetSession:
with patch("dataloader.context.APP_CTX") as mock_ctx: with patch("dataloader.context.APP_CTX") as mock_ctx:
mock_ctx.sessionmaker = mock_sm mock_ctx.sessionmaker = mock_sm
async for session in get_session(): async for _session in get_session():
pass pass
assert mock_exit.call_count == 1 assert mock_exit.call_count == 1


@ -1,8 +1,8 @@
# tests/unit/test_notify_listener.py
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch
import pytest import pytest
from dataloader.storage.notify_listener import PGNotifyListener from dataloader.storage.notify_listener import PGNotifyListener
@ -25,7 +25,7 @@ class TestPGNotifyListener:
dsn="postgresql://test", dsn="postgresql://test",
queue="test_queue", queue="test_queue",
callback=callback, callback=callback,
stop_event=stop_event stop_event=stop_event,
) )
assert listener._dsn == "postgresql://test" assert listener._dsn == "postgresql://test"
@ -47,14 +47,16 @@ class TestPGNotifyListener:
dsn="postgresql://test", dsn="postgresql://test",
queue="test_queue", queue="test_queue",
callback=callback, callback=callback,
stop_event=stop_event stop_event=stop_event,
) )
mock_conn = AsyncMock() mock_conn = AsyncMock()
mock_conn.execute = AsyncMock() mock_conn.execute = AsyncMock()
mock_conn.add_listener = AsyncMock() mock_conn.add_listener = AsyncMock()
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn): with patch(
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
):
await listener.start() await listener.start()
assert listener._conn == mock_conn assert listener._conn == mock_conn
@ -76,14 +78,16 @@ class TestPGNotifyListener:
dsn="postgresql+asyncpg://test", dsn="postgresql+asyncpg://test",
queue="test_queue", queue="test_queue",
callback=callback, callback=callback,
stop_event=stop_event stop_event=stop_event,
) )
mock_conn = AsyncMock() mock_conn = AsyncMock()
mock_conn.execute = AsyncMock() mock_conn.execute = AsyncMock()
mock_conn.add_listener = AsyncMock() mock_conn.add_listener = AsyncMock()
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn) as mock_connect: with patch(
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
) as mock_connect:
await listener.start() await listener.start()
mock_connect.assert_called_once_with("postgresql://test") mock_connect.assert_called_once_with("postgresql://test")
@ -102,14 +106,16 @@ class TestPGNotifyListener:
dsn="postgresql://test", dsn="postgresql://test",
queue="test_queue", queue="test_queue",
callback=callback, callback=callback,
stop_event=stop_event stop_event=stop_event,
) )
mock_conn = AsyncMock() mock_conn = AsyncMock()
mock_conn.execute = AsyncMock() mock_conn.execute = AsyncMock()
mock_conn.add_listener = AsyncMock() mock_conn.add_listener = AsyncMock()
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn): with patch(
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
):
await listener.start() await listener.start()
handler = listener._on_notify_handler handler = listener._on_notify_handler
@ -131,14 +137,16 @@ class TestPGNotifyListener:
dsn="postgresql://test", dsn="postgresql://test",
queue="test_queue", queue="test_queue",
callback=callback, callback=callback,
stop_event=stop_event stop_event=stop_event,
) )
mock_conn = AsyncMock() mock_conn = AsyncMock()
mock_conn.execute = AsyncMock() mock_conn.execute = AsyncMock()
mock_conn.add_listener = AsyncMock() mock_conn.add_listener = AsyncMock()
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn): with patch(
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
):
await listener.start() await listener.start()
handler = listener._on_notify_handler handler = listener._on_notify_handler
@ -160,14 +168,16 @@ class TestPGNotifyListener:
dsn="postgresql://test", dsn="postgresql://test",
queue="test_queue", queue="test_queue",
callback=callback, callback=callback,
stop_event=stop_event stop_event=stop_event,
) )
mock_conn = AsyncMock() mock_conn = AsyncMock()
mock_conn.execute = AsyncMock() mock_conn.execute = AsyncMock()
mock_conn.add_listener = AsyncMock() mock_conn.add_listener = AsyncMock()
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn): with patch(
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
):
await listener.start() await listener.start()
handler = listener._on_notify_handler handler = listener._on_notify_handler
@ -189,14 +199,16 @@ class TestPGNotifyListener:
dsn="postgresql://test", dsn="postgresql://test",
queue="test_queue", queue="test_queue",
callback=callback, callback=callback,
stop_event=stop_event stop_event=stop_event,
) )
mock_conn = AsyncMock() mock_conn = AsyncMock()
mock_conn.execute = AsyncMock() mock_conn.execute = AsyncMock()
mock_conn.add_listener = AsyncMock() mock_conn.add_listener = AsyncMock()
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn): with patch(
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
):
await listener.start() await listener.start()
handler = listener._on_notify_handler handler = listener._on_notify_handler
@ -218,7 +230,7 @@ class TestPGNotifyListener:
dsn="postgresql://test", dsn="postgresql://test",
queue="test_queue", queue="test_queue",
callback=callback, callback=callback,
stop_event=stop_event stop_event=stop_event,
) )
mock_conn = AsyncMock() mock_conn = AsyncMock()
@ -227,7 +239,9 @@ class TestPGNotifyListener:
mock_conn.remove_listener = AsyncMock() mock_conn.remove_listener = AsyncMock()
mock_conn.close = AsyncMock() mock_conn.close = AsyncMock()
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn): with patch(
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
):
await listener.start() await listener.start()
assert listener._task is not None assert listener._task is not None
@ -250,7 +264,7 @@ class TestPGNotifyListener:
dsn="postgresql://test", dsn="postgresql://test",
queue="test_queue", queue="test_queue",
callback=callback, callback=callback,
stop_event=stop_event stop_event=stop_event,
) )
mock_conn = AsyncMock() mock_conn = AsyncMock()
@ -259,7 +273,9 @@ class TestPGNotifyListener:
mock_conn.remove_listener = AsyncMock() mock_conn.remove_listener = AsyncMock()
mock_conn.close = AsyncMock() mock_conn.close = AsyncMock()
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn): with patch(
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
):
await listener.start() await listener.start()
await listener.stop() await listener.stop()
@ -280,7 +296,7 @@ class TestPGNotifyListener:
dsn="postgresql://test", dsn="postgresql://test",
queue="test_queue", queue="test_queue",
callback=callback, callback=callback,
stop_event=stop_event stop_event=stop_event,
) )
mock_conn = AsyncMock() mock_conn = AsyncMock()
@ -289,7 +305,9 @@ class TestPGNotifyListener:
mock_conn.remove_listener = AsyncMock() mock_conn.remove_listener = AsyncMock()
mock_conn.close = AsyncMock() mock_conn.close = AsyncMock()
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn): with patch(
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
):
await listener.start() await listener.start()
stop_event.set() stop_event.set()
@ -311,7 +329,7 @@ class TestPGNotifyListener:
dsn="postgresql://test", dsn="postgresql://test",
queue="test_queue", queue="test_queue",
callback=callback, callback=callback,
stop_event=stop_event stop_event=stop_event,
) )
mock_conn = AsyncMock() mock_conn = AsyncMock()
@ -320,7 +338,9 @@ class TestPGNotifyListener:
mock_conn.remove_listener = AsyncMock(side_effect=Exception("Remove error")) mock_conn.remove_listener = AsyncMock(side_effect=Exception("Remove error"))
mock_conn.close = AsyncMock() mock_conn.close = AsyncMock()
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn): with patch(
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
):
await listener.start() await listener.start()
await listener.stop() await listener.stop()
@ -340,7 +360,7 @@ class TestPGNotifyListener:
dsn="postgresql://test", dsn="postgresql://test",
queue="test_queue", queue="test_queue",
callback=callback, callback=callback,
stop_event=stop_event stop_event=stop_event,
) )
mock_conn = AsyncMock() mock_conn = AsyncMock()
@ -349,7 +369,9 @@ class TestPGNotifyListener:
mock_conn.remove_listener = AsyncMock() mock_conn.remove_listener = AsyncMock()
mock_conn.close = AsyncMock(side_effect=Exception("Close error")) mock_conn.close = AsyncMock(side_effect=Exception("Close error"))
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn): with patch(
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
):
await listener.start() await listener.start()
await listener.stop() await listener.stop()
@ -368,7 +390,7 @@ class TestPGNotifyListener:
dsn="postgresql://test", dsn="postgresql://test",
queue="test_queue", queue="test_queue",
callback=callback, callback=callback,
stop_event=stop_event stop_event=stop_event,
) )
await listener.stop() await listener.stop()


@@ -1,9 +1,8 @@
-# tests/unit/test_pipeline_registry.py
 from __future__ import annotations
 import pytest
-from dataloader.workers.pipelines.registry import register, resolve, tasks, _Registry
+from dataloader.workers.pipelines.registry import _Registry, register, resolve, tasks
 @pytest.mark.unit
@@ -22,6 +21,7 @@ class TestPipelineRegistry:
 """
 Тест регистрации пайплайна.
 """
+
 @register("test.task")
 def test_pipeline(args: dict):
 return "result"
@@ -33,6 +33,7 @@ class TestPipelineRegistry:
 """
 Тест получения зарегистрированного пайплайна.
 """
+
 @register("test.resolve")
 def test_pipeline(args: dict):
 return "resolved"
@@ -54,6 +55,7 @@ class TestPipelineRegistry:
 """
 Тест получения списка зарегистрированных задач.
 """
+
 @register("task1")
 def pipeline1(args: dict):
 pass
@@ -70,6 +72,7 @@ class TestPipelineRegistry:
 """
 Тест перезаписи существующего пайплайна.
 """
+
 @register("overwrite.task")
 def first_pipeline(args: dict):
 return "first"

View File: tests/unit/test_workers_base.py
@@ -1,9 +1,8 @@
-# tests/unit/test_workers_base.py
 from __future__ import annotations
 import asyncio
-from datetime import datetime, timezone
 from unittest.mock import AsyncMock, MagicMock, Mock, patch
 import pytest
 from dataloader.workers.base import PGWorker, WorkerConfig
@@ -41,9 +40,11 @@ class TestPGWorker:
 cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=1)
 stop_event = asyncio.Event()
-with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
-     patch("dataloader.workers.base.APP_CONFIG") as mock_cfg, \
-     patch("dataloader.workers.base.PGNotifyListener") as mock_listener_cls:
+with (
+    patch("dataloader.workers.base.APP_CTX") as mock_ctx,
+    patch("dataloader.workers.base.APP_CONFIG") as mock_cfg,
+    patch("dataloader.workers.base.PGNotifyListener") as mock_listener_cls,
+):
 mock_ctx.get_logger.return_value = Mock()
 mock_ctx.sessionmaker = Mock()
@@ -64,7 +65,9 @@ class TestPGWorker:
 stop_event.set()
 return False
-with patch.object(worker, "_claim_and_execute_once", side_effect=mock_claim):
+with patch.object(
+    worker, "_claim_and_execute_once", side_effect=mock_claim
+):
 await worker.run()
 assert mock_listener.start.call_count == 1
@@ -78,9 +81,11 @@ class TestPGWorker:
 cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=1)
 stop_event = asyncio.Event()
-with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
-     patch("dataloader.workers.base.APP_CONFIG") as mock_cfg, \
-     patch("dataloader.workers.base.PGNotifyListener") as mock_listener_cls:
+with (
+    patch("dataloader.workers.base.APP_CTX") as mock_ctx,
+    patch("dataloader.workers.base.APP_CONFIG") as mock_cfg,
+    patch("dataloader.workers.base.PGNotifyListener") as mock_listener_cls,
+):
 mock_logger = Mock()
 mock_ctx.get_logger.return_value = mock_logger
@@ -101,7 +106,9 @@ class TestPGWorker:
 stop_event.set()
 return False
-with patch.object(worker, "_claim_and_execute_once", side_effect=mock_claim):
+with patch.object(
+    worker, "_claim_and_execute_once", side_effect=mock_claim
+):
 await worker.run()
 assert worker._listener is None
@@ -156,8 +163,10 @@ class TestPGWorker:
 cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
 stop_event = asyncio.Event()
-with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
-     patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
+with (
+    patch("dataloader.workers.base.APP_CTX") as mock_ctx,
+    patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
+):
 mock_session = AsyncMock()
 mock_session.commit = AsyncMock()
@@ -185,8 +194,10 @@ class TestPGWorker:
 cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
 stop_event = asyncio.Event()
-with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
-     patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
+with (
+    patch("dataloader.workers.base.APP_CTX") as mock_ctx,
+    patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
+):
 mock_session = AsyncMock()
 mock_sm = MagicMock()
@@ -196,12 +207,14 @@ class TestPGWorker:
 mock_ctx.sessionmaker = mock_sm
 mock_repo = Mock()
-mock_repo.claim_one = AsyncMock(return_value={
-    "job_id": "test-job-id",
-    "lease_ttl_sec": 60,
-    "task": "test.task",
-    "args": {"key": "value"}
-})
+mock_repo.claim_one = AsyncMock(
+    return_value={
+        "job_id": "test-job-id",
+        "lease_ttl_sec": 60,
+        "task": "test.task",
+        "args": {"key": "value"},
+    }
+)
 mock_repo.finish_ok = AsyncMock()
 mock_repo_cls.return_value = mock_repo
@@ -210,8 +223,10 @@ class TestPGWorker:
 async def mock_pipeline(task, args):
 yield
-with patch.object(worker, "_pipeline", side_effect=mock_pipeline), \
-     patch.object(worker, "_execute_with_heartbeat", return_value=False):
+with (
+    patch.object(worker, "_pipeline", side_effect=mock_pipeline),
+    patch.object(worker, "_execute_with_heartbeat", return_value=False),
+):
 result = await worker._claim_and_execute_once()
 assert result is True
@@ -225,8 +240,10 @@ class TestPGWorker:
 cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
 stop_event = asyncio.Event()
-with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
-     patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
+with (
+    patch("dataloader.workers.base.APP_CTX") as mock_ctx,
+    patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
+):
 mock_session = AsyncMock()
 mock_sm = MagicMock()
@@ -236,12 +253,14 @@ class TestPGWorker:
 mock_ctx.sessionmaker = mock_sm
 mock_repo = Mock()
-mock_repo.claim_one = AsyncMock(return_value={
-    "job_id": "test-job-id",
-    "lease_ttl_sec": 60,
-    "task": "test.task",
-    "args": {}
-})
+mock_repo.claim_one = AsyncMock(
+    return_value={
+        "job_id": "test-job-id",
+        "lease_ttl_sec": 60,
+        "task": "test.task",
+        "args": {},
+    }
+)
 mock_repo.finish_fail_or_retry = AsyncMock()
 mock_repo_cls.return_value = mock_repo
@@ -263,8 +282,10 @@ class TestPGWorker:
 cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
 stop_event = asyncio.Event()
-with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
-     patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
+with (
+    patch("dataloader.workers.base.APP_CTX") as mock_ctx,
+    patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
+):
 mock_session = AsyncMock()
 mock_sm = MagicMock()
@@ -274,18 +295,22 @@ class TestPGWorker:
 mock_ctx.sessionmaker = mock_sm
 mock_repo = Mock()
-mock_repo.claim_one = AsyncMock(return_value={
-    "job_id": "test-job-id",
-    "lease_ttl_sec": 60,
-    "task": "test.task",
-    "args": {}
-})
+mock_repo.claim_one = AsyncMock(
+    return_value={
+        "job_id": "test-job-id",
+        "lease_ttl_sec": 60,
+        "task": "test.task",
+        "args": {},
+    }
+)
 mock_repo.finish_fail_or_retry = AsyncMock()
 mock_repo_cls.return_value = mock_repo
 worker = PGWorker(cfg, stop_event)
-with patch.object(worker, "_execute_with_heartbeat", side_effect=ValueError("Test error")):
+with patch.object(
+    worker, "_execute_with_heartbeat", side_effect=ValueError("Test error")
+):
 result = await worker._claim_and_execute_once()
 assert result is True
@@ -301,8 +326,10 @@ class TestPGWorker:
 cfg = WorkerConfig(queue="test", heartbeat_sec=1, claim_backoff_sec=5)
 stop_event = asyncio.Event()
-with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
-     patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
+with (
+    patch("dataloader.workers.base.APP_CTX") as mock_ctx,
+    patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
+):
 mock_session = AsyncMock()
 mock_sm = MagicMock()
@@ -323,7 +350,9 @@ class TestPGWorker:
 await asyncio.sleep(0.6)
 yield
-canceled = await worker._execute_with_heartbeat("job-id", 60, slow_pipeline())
+canceled = await worker._execute_with_heartbeat(
+    "job-id", 60, slow_pipeline()
+)
 assert canceled is False
 assert mock_repo.heartbeat.call_count >= 1
@@ -336,8 +365,10 @@ class TestPGWorker:
 cfg = WorkerConfig(queue="test", heartbeat_sec=1, claim_backoff_sec=5)
 stop_event = asyncio.Event()
-with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
-     patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
+with (
+    patch("dataloader.workers.base.APP_CTX") as mock_ctx,
+    patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
+):
 mock_session = AsyncMock()
 mock_sm = MagicMock()
@@ -358,7 +389,9 @@ class TestPGWorker:
 await asyncio.sleep(0.6)
 yield
-canceled = await worker._execute_with_heartbeat("job-id", 60, slow_pipeline())
+canceled = await worker._execute_with_heartbeat(
+    "job-id", 60, slow_pipeline()
+)
 assert canceled is True
@@ -370,8 +403,10 @@ class TestPGWorker:
 cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
 stop_event = asyncio.Event()
-with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
-     patch("dataloader.workers.base.resolve_pipeline") as mock_resolve:
+with (
+    patch("dataloader.workers.base.APP_CTX") as mock_ctx,
+    patch("dataloader.workers.base.resolve_pipeline") as mock_resolve,
+):
 mock_ctx.get_logger.return_value = Mock()
 mock_ctx.sessionmaker = Mock()
@@ -397,8 +432,10 @@ class TestPGWorker:
 cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
 stop_event = asyncio.Event()
-with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
-     patch("dataloader.workers.base.resolve_pipeline") as mock_resolve:
+with (
+    patch("dataloader.workers.base.APP_CTX") as mock_ctx,
+    patch("dataloader.workers.base.resolve_pipeline") as mock_resolve,
+):
 mock_ctx.get_logger.return_value = Mock()
 mock_ctx.sessionmaker = Mock()
@@ -424,8 +461,10 @@ class TestPGWorker:
 cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
 stop_event = asyncio.Event()
-with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
-     patch("dataloader.workers.base.resolve_pipeline") as mock_resolve:
+with (
+    patch("dataloader.workers.base.APP_CTX") as mock_ctx,
+    patch("dataloader.workers.base.resolve_pipeline") as mock_resolve,
+):
 mock_ctx.get_logger.return_value = Mock()
 mock_ctx.sessionmaker = Mock()
@@ -450,8 +489,10 @@ class TestPGWorker:
 cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
 stop_event = asyncio.Event()
-with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
-     patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
+with (
+    patch("dataloader.workers.base.APP_CTX") as mock_ctx,
+    patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
+):
 mock_session = AsyncMock()
 mock_sm = MagicMock()
@@ -461,12 +502,14 @@ class TestPGWorker:
 mock_ctx.sessionmaker = mock_sm
 mock_repo = Mock()
-mock_repo.claim_one = AsyncMock(return_value={
-    "job_id": "test-job-id",
-    "lease_ttl_sec": 60,
-    "task": "test.task",
-    "args": {}
-})
+mock_repo.claim_one = AsyncMock(
+    return_value={
+        "job_id": "test-job-id",
+        "lease_ttl_sec": 60,
+        "task": "test.task",
+        "args": {},
+    }
+)
 mock_repo.finish_fail_or_retry = AsyncMock()
 mock_repo_cls.return_value = mock_repo
@@ -484,19 +527,16 @@ class TestPGWorker:
 assert "cancelled by shutdown" in args[1]
 assert kwargs.get("is_canceled") is True
 @pytest.mark.asyncio
 async def test_execute_with_heartbeat_raises_cancelled_when_stop_set(self):
 cfg = WorkerConfig(queue="test", heartbeat_sec=1000, claim_backoff_sec=5)
 stop_event = asyncio.Event()
 stop_event.set()
-with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
-     patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
+with (
+    patch("dataloader.workers.base.APP_CTX") as mock_ctx,
+    patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
+):
 mock_ctx.get_logger.return_value = Mock()
 mock_ctx.sessionmaker = Mock()
@@ -509,4 +549,3 @@ class TestPGWorker:
 with pytest.raises(asyncio.CancelledError):
 await worker._execute_with_heartbeat("job-id", 60, one_yield())

View File: tests/unit/test_workers_manager.py
@@ -1,8 +1,8 @@
-# tests/unit/test_workers_manager.py
 from __future__ import annotations
 import asyncio
 from unittest.mock import AsyncMock, MagicMock, Mock, patch
 import pytest
 from dataloader.workers.manager import WorkerManager, WorkerSpec, build_manager_from_env
@@ -39,9 +39,11 @@ class TestWorkerManager:
 WorkerSpec(queue="queue2", concurrency=1),
 ]
-with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
-     patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, \
-     patch("dataloader.workers.manager.PGWorker") as mock_worker_cls:
+with (
+    patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
+    patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
+    patch("dataloader.workers.manager.PGWorker") as mock_worker_cls,
+):
 mock_ctx.get_logger.return_value = Mock()
 mock_cfg.worker.heartbeat_sec = 10
@@ -67,9 +69,11 @@ class TestWorkerManager:
 """
 specs = [WorkerSpec(queue="test", concurrency=0)]
-with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
-     patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, \
-     patch("dataloader.workers.manager.PGWorker") as mock_worker_cls:
+with (
+    patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
+    patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
+    patch("dataloader.workers.manager.PGWorker") as mock_worker_cls,
+):
 mock_ctx.get_logger.return_value = Mock()
 mock_cfg.worker.heartbeat_sec = 10
@@ -93,9 +97,11 @@ class TestWorkerManager:
 """
 specs = [WorkerSpec(queue="test", concurrency=2)]
-with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
-     patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, \
-     patch("dataloader.workers.manager.PGWorker") as mock_worker_cls:
+with (
+    patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
+    patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
+    patch("dataloader.workers.manager.PGWorker") as mock_worker_cls,
+):
 mock_ctx.get_logger.return_value = Mock()
 mock_cfg.worker.heartbeat_sec = 10
@@ -108,8 +114,6 @@ class TestWorkerManager:
 manager = WorkerManager(specs)
 await manager.start()
-initial_task_count = len(manager._tasks)
 await manager.stop()
 assert manager._stop.is_set()
@@ -123,10 +127,12 @@ class TestWorkerManager:
 """
 specs = [WorkerSpec(queue="test", concurrency=1)]
-with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
-     patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, \
-     patch("dataloader.workers.manager.PGWorker") as mock_worker_cls, \
-     patch("dataloader.workers.manager.requeue_lost") as mock_requeue:
+with (
+    patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
+    patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
+    patch("dataloader.workers.manager.PGWorker") as mock_worker_cls,
+    patch("dataloader.workers.manager.requeue_lost") as mock_requeue,
+):
 mock_logger = Mock()
 mock_ctx.get_logger.return_value = mock_logger
@@ -161,10 +167,12 @@ class TestWorkerManager:
 """
 specs = [WorkerSpec(queue="test", concurrency=1)]
-with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
-     patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, \
-     patch("dataloader.workers.manager.PGWorker") as mock_worker_cls, \
-     patch("dataloader.workers.manager.requeue_lost") as mock_requeue:
+with (
+    patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
+    patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
+    patch("dataloader.workers.manager.PGWorker") as mock_worker_cls,
+    patch("dataloader.workers.manager.requeue_lost") as mock_requeue,
+):
 mock_logger = Mock()
 mock_ctx.get_logger.return_value = mock_logger
@@ -203,8 +211,10 @@ class TestBuildManagerFromEnv:
 """
 Тест создания менеджера из конфигурации.
 """
-with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
-     patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg:
+with (
+    patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
+    patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
+):
 mock_ctx.get_logger.return_value = Mock()
 mock_cfg.worker.parsed_workers.return_value = [
@@ -224,8 +234,10 @@ class TestBuildManagerFromEnv:
 """
 Тест, что пустые имена очередей пропускаются.
 """
-with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
-     patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg:
+with (
+    patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
+    patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
+):
 mock_ctx.get_logger.return_value = Mock()
 mock_cfg.worker.parsed_workers.return_value = [
@@ -243,8 +255,10 @@ class TestBuildManagerFromEnv:
 """
 Тест обработки отсутствующих полей с дефолтными значениями.
 """
-with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
-     patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg:
+with (
+    patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
+    patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
+):
 mock_ctx.get_logger.return_value = Mock()
 mock_cfg.worker.parsed_workers.return_value = [
@@ -262,8 +276,10 @@ class TestBuildManagerFromEnv:
 """
 Тест, что concurrency всегда минимум 1.
 """
-with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
-     patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg:
+with (
+    patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
+    patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
+):
 mock_ctx.get_logger.return_value = Mock()
 mock_cfg.worker.parsed_workers.return_value = [

View File: tests/unit/test_workers_reaper.py
@@ -1,8 +1,8 @@
-# tests/unit/test_workers_reaper.py
 from __future__ import annotations
+from unittest.mock import AsyncMock, Mock, patch
 import pytest
-from unittest.mock import AsyncMock, patch, Mock
 from dataloader.workers.reaper import requeue_lost