refactor: apply black/isort/ruff formatting and import ordering across the codebase

itqop 2025-11-05 20:45:13 +03:00
parent c907e1d4da
commit bde0bb0e6f
66 changed files with 853 additions and 463 deletions

View File

@ -34,3 +34,37 @@ dataloader = "dataloader.__main__:main"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[tool.black]
line-length = 88
target-version = ["py311"]
skip-string-normalization = false
preview = true
[tool.isort]
profile = "black"
line_length = 88
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
combine_as_imports = true
known_first_party = ["dataloader"]
src_paths = ["src"]
[tool.ruff]
line-length = 88
target-version = "py311"
fix = true
lint.select = ["E", "F", "W", "I", "N", "B"]
lint.ignore = [
"E501",
]
exclude = [
".git",
"__pycache__",
"build",
"dist",
".venv",
"venv",
]
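
The three tool sections above make black, isort and ruff read their settings from pyproject.toml. A minimal sketch (not part of this commit) of a helper that runs all three locally; the "src" path comes from the isort config above, while "tests" and the script name are assumptions:

# format_all.py - hypothetical convenience script for the formatters configured above
import subprocess
import sys

def main() -> int:
    for cmd in (
        ["isort", "src", "tests"],   # sort imports per [tool.isort]
        ["black", "src", "tests"],   # reformat per [tool.black]
        ["ruff", "check", "src", "tests"],  # lint; autofix is enabled by fix = true
    ):
        result = subprocess.run(cmd)
        if result.returncode != 0:
            return result.returncode
    return 0

if __name__ == "__main__":
    sys.exit(main())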

View File

@ -5,7 +5,6 @@
from . import api
__all__ = [
"api",
]

View File

@ -1,15 +1,17 @@
import uvicorn
from dataloader.api import app_main
from dataloader.config import APP_CONFIG
from dataloader.logger.uvicorn_logging_config import LOGGING_CONFIG, setup_uvicorn_logging
from dataloader.logger.uvicorn_logging_config import (
LOGGING_CONFIG,
setup_uvicorn_logging,
)
def main() -> None:
# Initialize uvicorn logging before startup
setup_uvicorn_logging()
uvicorn.run(
app_main,
host=APP_CONFIG.app.app_host,

View File

@ -1,20 +1,19 @@
# src/dataloader/api/__init__.py
from __future__ import annotations
from collections.abc import AsyncGenerator
import contextlib
import typing as tp
from collections.abc import AsyncGenerator
from fastapi import FastAPI
from dataloader.context import APP_CTX
from dataloader.workers.manager import WorkerManager, build_manager_from_env
from dataloader.workers.pipelines import load_all as load_pipelines
from .metric_router import router as metric_router
from .middleware import log_requests
from .os_router import router as service_router
from .v1 import router as v1_router
from dataloader.context import APP_CTX
from dataloader.workers.manager import build_manager_from_env, WorkerManager
from dataloader.workers.pipelines import load_all as load_pipelines
_manager: WorkerManager | None = None
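
The imports above (contextlib, AsyncGenerator, build_manager_from_env, load_pipelines) suggest the worker manager is wired into a FastAPI lifespan. A minimal sketch under that assumption, reusing the imports shown above; the function name is hypothetical and the manager is assumed to expose async start()/stop():

@contextlib.asynccontextmanager
async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    # Register pipeline handlers, then start the queue workers for the app's lifetime.
    global _manager
    load_pipelines()
    _manager = build_manager_from_env()
    await _manager.start()
    try:
        yield
    finally:
        await _manager.stop()
        _manager = None

# The application would then be created as FastAPI(lifespan=_lifespan).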

View File

@ -1,10 +1,11 @@
""" 🚨 НЕ РЕДАКТИРОВАТЬ !!!!!!
"""
"""🚨 НЕ РЕДАКТИРОВАТЬ !!!!!!"""
import uuid
from fastapi import APIRouter, Header, status
from dataloader.context import APP_CTX
from . import schemas
router = APIRouter()
@ -17,8 +18,7 @@ logger = APP_CTX.get_logger()
response_model=schemas.RateResponse,
)
async def like(
# pylint: disable=C0103,W0613
header_Request_Id: str = Header(uuid.uuid4(), alias="Request-Id")
header_request_id: str = Header(uuid.uuid4(), alias="Request-Id")
) -> dict[str, str]:
logger.metric(
metric_name="dataloader_likes_total",
@ -33,8 +33,7 @@ async def like(
response_model=schemas.RateResponse,
)
async def dislike(
# pylint: disable=C0103,W0613
header_Request_Id: str = Header(uuid.uuid4(), alias="Request-Id")
header_request_id: str = Header(uuid.uuid4(), alias="Request-Id")
) -> dict[str, str]:
logger.metric(
metric_name="dataloader_dislikes_total",
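
A hedged test sketch for the handlers above, assuming the routes are mounted at /like and /dislike (the route paths are not visible in this hunk) and that app_main is the ASGI application object, as in tests/conftest.py:

import uuid

import httpx
import pytest

from dataloader.api import app_main


@pytest.mark.asyncio
async def test_like_records_rating() -> None:
    transport = httpx.ASGITransport(app=app_main)
    async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
        # Request-Id feeds the header alias declared on the handler above.
        resp = await client.post("/like", headers={"Request-Id": str(uuid.uuid4())})
    assert resp.status_code < 400
    assert "rating_result" in resp.json()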

View File

@ -38,8 +38,14 @@ async def log_requests(request: Request, call_next) -> any:
logger = APP_CTX.get_logger()
request_path = request.url.path
allowed_headers_to_log = ((k, request.headers.get(k)) for k in HEADERS_WHITE_LIST_TO_LOG)
headers_to_log = {header_name: header_value for header_name, header_value in allowed_headers_to_log if header_value}
allowed_headers_to_log = (
(k, request.headers.get(k)) for k in HEADERS_WHITE_LIST_TO_LOG
)
headers_to_log = {
header_name: header_value
for header_name, header_value in allowed_headers_to_log
if header_value
}
APP_CTX.get_context_vars_container().set_context_vars(
request_id=headers_to_log.get("Request-Id", ""),
@ -50,7 +56,9 @@ async def log_requests(request: Request, call_next) -> any:
if request_path in NON_LOGGED_ENDPOINTS:
response = await call_next(request)
logger.debug(f"Processed request for {request_path} with code {response.status_code}")
logger.debug(
f"Processed request for {request_path} with code {response.status_code}"
)
elif headers_to_log.get("Request-Id", None):
raw_request_body = await request.body()
request_body_decoded = _get_decoded_body(raw_request_body, "request", logger)
@ -80,7 +88,7 @@ async def log_requests(request: Request, call_next) -> any:
event_params=json.dumps(
request_body_decoded,
ensure_ascii=False,
)
),
)
response = await call_next(request)
@ -88,12 +96,16 @@ async def log_requests(request: Request, call_next) -> any:
response_body = [chunk async for chunk in response.body_iterator]
response.body_iterator = iterate_in_threadpool(iter(response_body))
headers_to_log["Response-Time"] = datetime.now(APP_CTX.get_pytz_timezone()).isoformat()
headers_to_log["Response-Time"] = datetime.now(
APP_CTX.get_pytz_timezone()
).isoformat()
for header in headers_to_log:
response.headers[header] = headers_to_log[header]
response_body_extracted = response_body[0] if len(response_body) > 0 else b""
decoded_response_body = _get_decoded_body(response_body_extracted, "response", logger)
decoded_response_body = _get_decoded_body(
response_body_extracted, "response", logger
)
logger.info(
"Outgoing response to client system",
@ -115,11 +127,13 @@ async def log_requests(request: Request, call_next) -> any:
event_params=json.dumps(
decoded_response_body,
ensure_ascii=False,
)
),
)
processing_time_ms = int(round((time.time() - start_time), 3) * 1000)
logger.info(f"Request processing time for {request_path}: {processing_time_ms} ms")
logger.info(
f"Request processing time for {request_path}: {processing_time_ms} ms"
)
logger.metric(
metric_name="dataloader_process_duration_ms",
metric_value=processing_time_ms,
@ -138,7 +152,9 @@ async def log_requests(request: Request, call_next) -> any:
else:
logger.info(f"Incoming {request.method}-request with no id for {request_path}")
response = await call_next(request)
logger.info(f"Request with no id for {request_path} processing time: {time.time() - start_time:.3f} s")
logger.info(
f"Request with no id for {request_path} processing time: {time.time() - start_time:.3f} s"
)
return response
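
log_requests above is a plain HTTP middleware function. The actual registration lives in dataloader/api/__init__.py and is not shown in this hunk; a standalone sketch of the mechanism:

from fastapi import FastAPI

from dataloader.api.middleware import log_requests

app = FastAPI()
# Register log_requests as an "http" middleware so every request/response pair is logged.
app.middleware("http")(log_requests)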

View File

@ -1,6 +1,4 @@
# Infrastructure endpoints (/health, /status)
""" 🚨 DO NOT EDIT !!!!!!
"""
"""🚨 DO NOT EDIT !!!!!!"""
from importlib.metadata import distribution

View File

@ -3,22 +3,26 @@ from pydantic import BaseModel, ConfigDict, Field
class HealthResponse(BaseModel):
"""Ответ для ручки /health"""
model_config = ConfigDict(
json_schema_extra={"example": {"status": "running"}}
)
status: str = Field(default="running", description="Service health check", max_length=7)
model_config = ConfigDict(json_schema_extra={"example": {"status": "running"}})
status: str = Field(
default="running", description="Service health check", max_length=7
)
class InfoResponse(BaseModel):
"""Ответ для ручки /info"""
model_config = ConfigDict(
json_schema_extra={
"example": {
"name": "rest-template",
"description": "Python 'AI gateway' template for developing REST microservices",
"description": (
"Python 'AI gateway' template for developing REST microservices"
),
"type": "REST API",
"version": "0.1.0"
"version": "0.1.0",
}
}
)
@ -26,11 +30,14 @@ class InfoResponse(BaseModel):
name: str = Field(description="Service name", max_length=50)
description: str = Field(description="Service description", max_length=200)
type: str = Field(default="REST API", description="Service type", max_length=20)
version: str = Field(description="Service version", max_length=20, pattern=r"^\d+\.\d+\.\d+")
version: str = Field(
description="Service version", max_length=20, pattern=r"^\d+\.\d+\.\d+"
)
class RateResponse(BaseModel):
"""Ответ для записи рейтинга"""
model_config = ConfigDict(str_strip_whitespace=True)
rating_result: str = Field(description="Rating that was recorded", max_length=50)

View File

@ -5,12 +5,8 @@ from fastapi import HTTPException, status
class JobNotFoundError(HTTPException):
"""Задача не найдена."""
def __init__(self, job_id: str):
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Job {job_id} not found"
status_code=status.HTTP_404_NOT_FOUND, detail=f"Job {job_id} not found"
)

View File

@ -1,2 +1 @@
"""Модели данных для API v1."""

View File

@ -1,12 +1,11 @@
# src/dataloader/api/v1/router.py
from __future__ import annotations
from collections.abc import AsyncGenerator
from http import HTTPStatus
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from dataloader.api.v1.exceptions import JobNotFoundError
from dataloader.api.v1.schemas import (
@ -16,8 +15,6 @@ from dataloader.api.v1.schemas import (
)
from dataloader.api.v1.service import JobsService
from dataloader.context import get_session
from sqlalchemy.ext.asyncio import AsyncSession
router = APIRouter(prefix="/jobs", tags=["jobs"])
@ -40,7 +37,9 @@ async def trigger_job(
return await svc.trigger(payload)
@router.get("/{job_id}/status", response_model=JobStatusResponse, status_code=HTTPStatus.OK)
@router.get(
"/{job_id}/status", response_model=JobStatusResponse, status_code=HTTPStatus.OK
)
async def get_status(
job_id: UUID,
svc: Annotated[JobsService, Depends(get_service)],
@ -54,7 +53,9 @@ async def get_status(
return st
@router.post("/{job_id}/cancel", response_model=JobStatusResponse, status_code=HTTPStatus.OK)
@router.post(
"/{job_id}/cancel", response_model=JobStatusResponse, status_code=HTTPStatus.OK
)
async def cancel_job(
job_id: UUID,
svc: Annotated[JobsService, Depends(get_service)],
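
A hedged client-side sketch for the status and cancel endpoints above. Only the /jobs router prefix and route paths are visible in this hunk, so the /v1 mount prefix, the port (APP_PORT default 8081) and the "status" field in the response JSON are assumptions:

import uuid

import httpx

BASE = "http://localhost:8081/v1"  # hypothetical mount point of the v1 router


async def check_and_cancel(job_id: uuid.UUID) -> None:
    async with httpx.AsyncClient() as client:
        status = await client.get(f"{BASE}/jobs/{job_id}/status")
        status.raise_for_status()
        if status.json().get("status") == "queued":
            # Cancellation flips cancel_requested; workers notice it on heartbeat.
            cancel = await client.post(f"{BASE}/jobs/{job_id}/cancel")
            cancel.raise_for_status()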

View File

@ -1,4 +1,3 @@
# src/dataloader/api/v1/schemas.py
from __future__ import annotations
from datetime import datetime, timezone
@ -12,6 +11,7 @@ class TriggerJobRequest(BaseModel):
"""
Request to enqueue a job.
"""
model_config = ConfigDict(str_strip_whitespace=True)
queue: str = Field(...)
@ -39,6 +39,7 @@ class TriggerJobResponse(BaseModel):
"""
Response to enqueueing a job.
"""
model_config = ConfigDict(str_strip_whitespace=True)
job_id: UUID = Field(...)
@ -49,6 +50,7 @@ class JobStatusResponse(BaseModel):
"""
Current job status.
"""
model_config = ConfigDict(str_strip_whitespace=True)
job_id: UUID = Field(...)
@ -65,6 +67,7 @@ class CancelJobResponse(BaseModel):
"""
Response to a job cancellation request.
"""
model_config = ConfigDict(str_strip_whitespace=True)
job_id: UUID = Field(...)

View File

@ -1,4 +1,3 @@
# src/dataloader/api/v1/service.py
from __future__ import annotations
from datetime import datetime, timezone
@ -13,15 +12,16 @@ from dataloader.api.v1.schemas import (
TriggerJobResponse,
)
from dataloader.api.v1.utils import new_job_id
from dataloader.storage.schemas import CreateJobRequest
from dataloader.storage.repositories import QueueRepository
from dataloader.logger.logger import get_logger
from dataloader.storage.repositories import QueueRepository
from dataloader.storage.schemas import CreateJobRequest
class JobsService:
"""
Business logic for working with the job queue.
"""
def __init__(self, session: AsyncSession):
self._s = session
self._repo = QueueRepository(self._s)

View File

@ -1,4 +1,3 @@
# src/dataloader/api/v1/utils.py
from __future__ import annotations
from uuid import UUID, uuid4

View File

@ -1,6 +1,5 @@
# src/dataloader/config.py
import os
import json
import os
from logging import DEBUG, INFO
from typing import Annotated, Any
@ -32,6 +31,7 @@ class BaseAppSettings(BaseSettings):
"""
Base class for settings.
"""
local: bool = Field(validation_alias="LOCAL", default=False)
debug: bool = Field(validation_alias="DEBUG", default=False)
@ -44,6 +44,7 @@ class AppSettings(BaseAppSettings):
"""
Application settings.
"""
app_host: str = Field(validation_alias="APP_HOST", default="0.0.0.0")
app_port: int = Field(validation_alias="APP_PORT", default=8081)
kube_net_name: str = Field(validation_alias="PROJECT_NAME", default="AIGATEWAY")
@ -54,15 +55,28 @@ class LogSettings(BaseAppSettings):
"""
Logging settings.
"""
private_log_file_path: str = Field(validation_alias="LOG_PATH", default=os.getcwd())
private_log_file_name: str = Field(validation_alias="LOG_FILE_NAME", default="app.log")
private_log_file_name: str = Field(
validation_alias="LOG_FILE_NAME", default="app.log"
)
log_rotation: str = Field(validation_alias="LOG_ROTATION", default="10 MB")
private_metric_file_path: str = Field(validation_alias="METRIC_PATH", default=os.getcwd())
private_metric_file_name: str = Field(validation_alias="METRIC_FILE_NAME", default="app-metric.log")
private_audit_file_path: str = Field(validation_alias="AUDIT_LOG_PATH", default=os.getcwd())
private_audit_file_name: str = Field(validation_alias="AUDIT_LOG_FILE_NAME", default="events.log")
private_metric_file_path: str = Field(
validation_alias="METRIC_PATH", default=os.getcwd()
)
private_metric_file_name: str = Field(
validation_alias="METRIC_FILE_NAME", default="app-metric.log"
)
private_audit_file_path: str = Field(
validation_alias="AUDIT_LOG_PATH", default=os.getcwd()
)
private_audit_file_name: str = Field(
validation_alias="AUDIT_LOG_FILE_NAME", default="events.log"
)
audit_host_ip: str = Field(validation_alias="HOST_IP", default="127.0.0.1")
audit_host_uid: str = Field(validation_alias="HOST_UID", default="63b6dcee-170b-49bf-a65c-3ec967398ccd")
audit_host_uid: str = Field(
validation_alias="HOST_UID", default="63b6dcee-170b-49bf-a65c-3ec967398ccd"
)
@staticmethod
def get_file_abs_path(path_name: str, file_name: str) -> str:
@ -70,15 +84,21 @@ class LogSettings(BaseAppSettings):
@property
def log_file_abs_path(self) -> str:
return self.get_file_abs_path(self.private_log_file_path, self.private_log_file_name)
return self.get_file_abs_path(
self.private_log_file_path, self.private_log_file_name
)
@property
def metric_file_abs_path(self) -> str:
return self.get_file_abs_path(self.private_metric_file_path, self.private_metric_file_name)
return self.get_file_abs_path(
self.private_metric_file_path, self.private_metric_file_name
)
@property
def audit_file_abs_path(self) -> str:
return self.get_file_abs_path(self.private_audit_file_path, self.private_audit_file_name)
return self.get_file_abs_path(
self.private_audit_file_path, self.private_audit_file_name
)
@property
def log_lvl(self) -> int:
@ -89,6 +109,7 @@ class PGSettings(BaseSettings):
"""
Postgres connection settings.
"""
host: str = Field(validation_alias="PG_HOST", default="localhost")
port: int = Field(validation_alias="PG_PORT", default=5432)
user: str = Field(validation_alias="PG_USER", default="postgres")
@ -116,9 +137,12 @@ class WorkerSettings(BaseSettings):
"""
Queue and worker settings.
"""
workers_json: str = Field(validation_alias="WORKERS_JSON", default="[]")
heartbeat_sec: int = Field(validation_alias="DL_HEARTBEAT_SEC", default=10)
default_lease_ttl_sec: int = Field(validation_alias="DL_DEFAULT_LEASE_TTL_SEC", default=60)
default_lease_ttl_sec: int = Field(
validation_alias="DL_DEFAULT_LEASE_TTL_SEC", default=60
)
reaper_period_sec: int = Field(validation_alias="DL_REAPER_PERIOD_SEC", default=10)
claim_backoff_sec: int = Field(validation_alias="DL_CLAIM_BACKOFF_SEC", default=15)
@ -137,6 +161,7 @@ class CertsSettings(BaseSettings):
"""
SSL certificate settings for local development.
"""
ca_bundle_file: str = Field(validation_alias="CA_BUNDLE_FILE", default="")
cert_file: str = Field(validation_alias="CERT_FILE", default="")
key_file: str = Field(validation_alias="KEY_FILE", default="")
@ -146,21 +171,24 @@ class SuperTeneraSettings(BaseAppSettings):
"""
SuperTenera integration settings.
"""
host: Annotated[str, BeforeValidator(strip_slashes)] = Field(
validation_alias="SUPERTENERA_HOST",
default="ci03801737-ift-tenera-giga.delta.sbrf.ru/atlant360bc/"
default="ci03801737-ift-tenera-giga.delta.sbrf.ru/atlant360bc/",
)
port: str = Field(validation_alias="SUPERTENERA_PORT", default="443")
quotes_endpoint: Annotated[str, BeforeValidator(strip_slashes)] = Field(
validation_alias="SUPERTENERA_QUOTES_ENDPOINT",
default="/get_gigaparser_quotes/"
default="/get_gigaparser_quotes/",
)
timeout: int = Field(validation_alias="SUPERTENERA_TIMEOUT", default=20)
@property
def base_url(self) -> str:
"""Возвращает абсолютный URL для SuperTenera."""
domain, raw_path = self.host.split("/", 1) if "/" in self.host else (self.host, "")
domain, raw_path = (
self.host.split("/", 1) if "/" in self.host else (self.host, "")
)
return build_url(self.protocol, domain, self.port, raw_path)
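
The reformatted expression above splits the configured host into a domain and an optional path before handing both to build_url (whose signature is not shown here). A tiny standalone illustration of the split, using the default host from this settings class:

def split_host(host: str) -> tuple[str, str]:
    # "domain/path" -> ("domain", "path"); bare "domain" -> ("domain", "")
    domain, _, raw_path = host.partition("/")
    return domain, raw_path

assert split_host("ci03801737-ift-tenera-giga.delta.sbrf.ru/atlant360bc") == (
    "ci03801737-ift-tenera-giga.delta.sbrf.ru",
    "atlant360bc",
)
assert split_host("example.local") == ("example.local", "")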
@ -168,22 +196,21 @@ class Gmap2BriefSettings(BaseAppSettings):
"""
Gmap2Brief (OPU API) integration settings.
"""
host: Annotated[str, BeforeValidator(strip_slashes)] = Field(
validation_alias="GMAP2BRIEF_HOST",
default="ci02533826-tib-brief.apps.ift-terra000024-edm.ocp.delta.sbrf.ru"
default="ci02533826-tib-brief.apps.ift-terra000024-edm.ocp.delta.sbrf.ru",
)
port: str = Field(validation_alias="GMAP2BRIEF_PORT", default="443")
start_endpoint: Annotated[str, BeforeValidator(strip_slashes)] = Field(
validation_alias="GMAP2BRIEF_START_ENDPOINT",
default="/export/opu/start"
validation_alias="GMAP2BRIEF_START_ENDPOINT", default="/export/opu/start"
)
status_endpoint: Annotated[str, BeforeValidator(strip_slashes)] = Field(
validation_alias="GMAP2BRIEF_STATUS_ENDPOINT",
default="/export/{job_id}/status"
validation_alias="GMAP2BRIEF_STATUS_ENDPOINT", default="/export/{job_id}/status"
)
download_endpoint: Annotated[str, BeforeValidator(strip_slashes)] = Field(
validation_alias="GMAP2BRIEF_DOWNLOAD_ENDPOINT",
default="/export/{job_id}/download"
default="/export/{job_id}/download",
)
poll_interval: int = Field(validation_alias="GMAP2BRIEF_POLL_INTERVAL", default=2)
timeout: int = Field(validation_alias="GMAP2BRIEF_TIMEOUT", default=3600)
@ -198,6 +225,7 @@ class Secrets:
"""
Aggregator of application settings.
"""
app: AppSettings = AppSettings()
log: LogSettings = LogSettings()
pg: PGSettings = PGSettings()

View File

@ -1,8 +1,7 @@
# src/dataloader/context.py
from __future__ import annotations
from typing import AsyncGenerator
from logging import Logger
from typing import AsyncGenerator
from zoneinfo import ZoneInfo
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
@ -15,6 +14,7 @@ class AppContext:
"""
Application context that holds global dependencies (Singleton pattern).
"""
def __init__(self) -> None:
self._engine: AsyncEngine | None = None
self._sessionmaker: async_sessionmaker[AsyncSession] | None = None
@ -75,6 +75,7 @@ class AppContext:
A Logger instance
"""
from .logger.logger import get_logger as get_app_logger
return get_app_logger(name)
def get_context_vars_container(self) -> ContextVarsContainer:

View File

@ -1,2 +1 @@
"""Исключения уровня приложения."""

View File

@ -10,7 +10,10 @@ from typing import TYPE_CHECKING
import httpx
from dataloader.config import APP_CONFIG
from dataloader.interfaces.gmap2_brief.schemas import ExportJobStatus, StartExportResponse
from dataloader.interfaces.gmap2_brief.schemas import (
ExportJobStatus,
StartExportResponse,
)
if TYPE_CHECKING:
from logging import Logger
@ -37,10 +40,11 @@ class Gmap2BriefInterface:
self._ssl_context = None
if APP_CONFIG.app.local and APP_CONFIG.certs.cert_file:
self._ssl_context = ssl.create_default_context(cafile=APP_CONFIG.certs.ca_bundle_file)
self._ssl_context = ssl.create_default_context(
cafile=APP_CONFIG.certs.ca_bundle_file
)
self._ssl_context.load_cert_chain(
certfile=APP_CONFIG.certs.cert_file,
keyfile=APP_CONFIG.certs.key_file
certfile=APP_CONFIG.certs.cert_file, keyfile=APP_CONFIG.certs.key_file
)
async def start_export(self) -> str:
@ -56,9 +60,13 @@ class Gmap2BriefInterface:
url = self.base_url + APP_CONFIG.gmap2brief.start_endpoint
async with httpx.AsyncClient(
cert=(APP_CONFIG.certs.cert_file, APP_CONFIG.certs.key_file) if APP_CONFIG.app.local else None,
cert=(
(APP_CONFIG.certs.cert_file, APP_CONFIG.certs.key_file)
if APP_CONFIG.app.local
else None
),
verify=APP_CONFIG.certs.ca_bundle_file if APP_CONFIG.app.local else True,
timeout=self.timeout
timeout=self.timeout,
) as client:
try:
self.logger.info(f"Starting OPU export: POST {url}")
@ -87,12 +95,18 @@ class Gmap2BriefInterface:
Raises:
Gmap2BriefConnectionError: On request failure
"""
url = self.base_url + APP_CONFIG.gmap2brief.status_endpoint.format(job_id=job_id)
url = self.base_url + APP_CONFIG.gmap2brief.status_endpoint.format(
job_id=job_id
)
async with httpx.AsyncClient(
cert=(APP_CONFIG.certs.cert_file, APP_CONFIG.certs.key_file) if APP_CONFIG.app.local else None,
cert=(
(APP_CONFIG.certs.cert_file, APP_CONFIG.certs.key_file)
if APP_CONFIG.app.local
else None
),
verify=APP_CONFIG.certs.ca_bundle_file if APP_CONFIG.app.local else True,
timeout=self.timeout
timeout=self.timeout,
) as client:
try:
response = await client.get(url)
@ -106,7 +120,9 @@ class Gmap2BriefInterface:
except httpx.RequestError as e:
raise Gmap2BriefConnectionError(f"Request error: {e}") from e
async def wait_for_completion(self, job_id: str, max_wait: int | None = None) -> ExportJobStatus:
async def wait_for_completion(
self, job_id: str, max_wait: int | None = None
) -> ExportJobStatus:
"""
Waits for the export job to complete, polling periodically.
@ -127,14 +143,18 @@ class Gmap2BriefInterface:
status = await self.get_status(job_id)
if status.status == "completed":
self.logger.info(f"Job {job_id} completed, total_rows={status.total_rows}")
self.logger.info(
f"Job {job_id} completed, total_rows={status.total_rows}"
)
return status
elif status.status == "failed":
raise Gmap2BriefConnectionError(f"Job {job_id} failed: {status.error}")
elapsed = asyncio.get_event_loop().time() - start_time
if max_wait and elapsed > max_wait:
raise Gmap2BriefConnectionError(f"Job {job_id} timeout after {elapsed:.1f}s")
raise Gmap2BriefConnectionError(
f"Job {job_id} timeout after {elapsed:.1f}s"
)
self.logger.debug(
f"Job {job_id} status={status.status}, rows={status.total_rows}, elapsed={elapsed:.1f}s"
@ -155,12 +175,18 @@ class Gmap2BriefInterface:
Raises:
Gmap2BriefConnectionError: On download failure
"""
url = self.base_url + APP_CONFIG.gmap2brief.download_endpoint.format(job_id=job_id)
url = self.base_url + APP_CONFIG.gmap2brief.download_endpoint.format(
job_id=job_id
)
async with httpx.AsyncClient(
cert=(APP_CONFIG.certs.cert_file, APP_CONFIG.certs.key_file) if APP_CONFIG.app.local else None,
cert=(
(APP_CONFIG.certs.cert_file, APP_CONFIG.certs.key_file)
if APP_CONFIG.app.local
else None
),
verify=APP_CONFIG.certs.ca_bundle_file if APP_CONFIG.app.local else True,
timeout=self.timeout
timeout=self.timeout,
) as client:
try:
self.logger.info(f"Downloading export: GET {url}")

View File

@ -17,7 +17,11 @@ class ExportJobStatus(BaseModel):
"""Статус задачи экспорта."""
job_id: str = Field(..., description="Идентификатор задачи")
status: Literal["pending", "running", "completed", "failed"] = Field(..., description="Статус задачи")
status: Literal["pending", "running", "completed", "failed"] = Field(
..., description="Статус задачи"
)
total_rows: int = Field(default=0, description="Количество обработанных строк")
error: str | None = Field(default=None, description="Текст ошибки (если есть)")
temp_file_path: str | None = Field(default=None, description="Путь к временному файлу (для completed)")
temp_file_path: str | None = Field(
default=None, description="Путь к временному файлу (для completed)"
)

View File

@ -47,8 +47,12 @@ class SuperTeneraInterface:
self._ssl_context = None
if APP_CONFIG.app.local:
self._ssl_context = ssl.create_default_context(cafile=APP_CONFIG.certs.ca_bundle_file)
self._ssl_context.load_cert_chain(certfile=APP_CONFIG.certs.cert_file, keyfile=APP_CONFIG.certs.key_file)
self._ssl_context = ssl.create_default_context(
cafile=APP_CONFIG.certs.ca_bundle_file
)
self._ssl_context.load_cert_chain(
certfile=APP_CONFIG.certs.cert_file, keyfile=APP_CONFIG.certs.key_file
)
def form_base_headers(self) -> dict:
"""Формирует базовые заголовки для запроса."""
@ -57,7 +61,11 @@ class SuperTeneraInterface:
"request-time": str(datetime.now(tz=self.timezone).isoformat()),
"system-id": APP_CONFIG.app.kube_net_name,
}
return {metakey: metavalue for metakey, metavalue in metadata_pairs.items() if metavalue}
return {
metakey: metavalue
for metakey, metavalue in metadata_pairs.items()
if metavalue
}
async def __aenter__(self) -> Self:
"""Async context manager enter."""
@ -86,7 +94,9 @@ class SuperTeneraInterface:
async with self._session.get(url, **kwargs) as response:
if APP_CONFIG.app.debug:
self.logger.debug(f"Response: {(await response.text(errors='ignore'))[:100]}")
self.logger.debug(
f"Response: {(await response.text(errors='ignore'))[:100]}"
)
return await response.json(encoding=encoding, content_type=content_type)
async def get_quotes_data(self) -> MainData:
@ -103,7 +113,9 @@ class SuperTeneraInterface:
kwargs["ssl"] = self._ssl_context
try:
async with self._session.get(
APP_CONFIG.supertenera.quotes_endpoint, timeout=APP_CONFIG.supertenera.timeout, **kwargs
APP_CONFIG.supertenera.quotes_endpoint,
timeout=APP_CONFIG.supertenera.timeout,
**kwargs,
) as resp:
resp.raise_for_status()
return True
@ -112,7 +124,9 @@ class SuperTeneraInterface:
f"Ошибка подключения к SuperTenera API при проверке системы - {e.status}."
) from e
except TimeoutError as e:
raise SuperTeneraConnectionError("Ошибка Timeout подключения к SuperTenera API при проверке системы.") from e
raise SuperTeneraConnectionError(
"Ошибка Timeout подключения к SuperTenera API при проверке системы."
) from e
def get_async_tenera_interface() -> SuperTeneraInterface:

View File

@ -77,7 +77,9 @@ class InvestingCandlestick(TeneraBaseModel):
value: str = Field(alias="V")
class InvestingTimePoint(RootModel[EmptyTimePoint | InvestingNumeric | InvestingCandlestick]):
class InvestingTimePoint(
RootModel[EmptyTimePoint | InvestingNumeric | InvestingCandlestick]
):
"""
Union model of a time point for the Investing.com source.
@ -340,5 +342,9 @@ class MainData(TeneraBaseModel):
:return: The filtered object
"""
if isinstance(v, dict):
return {key: value for key, value in v.items() if value is not None and not str(key).isdigit()}
return {
key: value
for key, value in v.items()
if value is not None and not str(key).isdigit()
}
return v

View File

@ -1,4 +1,3 @@
# src/dataloader/logger/__init__.py
from .logger import setup_logging, get_logger
from .logger import get_logger, setup_logging
__all__ = ["setup_logging", "get_logger"]

View File

@ -1,9 +1,7 @@
# Request context management for logging
import uuid
from contextvars import ContextVar
from typing import Final
REQUEST_ID_CTX_VAR: Final[ContextVar[str]] = ContextVar("request_id", default="")
DEVICE_ID_CTX_VAR: Final[ContextVar[str]] = ContextVar("device_id", default="")
SESSION_ID_CTX_VAR: Final[ContextVar[str]] = ContextVar("session_id", default="")
@ -66,7 +64,13 @@ class ContextVarsContainer:
def gw_session_id(self, value: str) -> None:
GW_SESSION_ID_CTX_VAR.set(value)
def set_context_vars(self, request_id: str = "", request_time: str = "", system_id: str = "", gw_session_id: str = "") -> None:
def set_context_vars(
self,
request_id: str = "",
request_time: str = "",
system_id: str = "",
gw_session_id: str = "",
) -> None:
if request_id:
self.set_request_id(request_id)
if request_time:

View File

@ -1,17 +1,12 @@
# src/dataloader/logger/logger.py
import sys
import typing
from datetime import tzinfo
from logging import Logger
import logging
import sys
from logging import Logger
from loguru import logger
from .context_vars import ContextVarsContainer
from dataloader.config import APP_CONFIG
# Define filters for the different log types
def metric_only_filter(record: dict) -> bool:
return "metric" in record["extra"]
@ -26,13 +21,12 @@ def regular_log_filter(record: dict) -> bool:
class InterceptHandler(logging.Handler):
def emit(self, record: logging.LogRecord) -> None:
# Get corresponding Loguru level if it exists
try:
level = logger.level(record.levelname).name
except ValueError:
level = record.levelno
# Find caller from where originated the logged message
frame, depth = logging.currentframe(), 2
while frame.f_code.co_filename == logging.__file__:
frame = frame.f_back

View File

@ -1 +1 @@
# Models for logs, metrics and audit events

View File

@ -1 +1 @@
# Functions + masking of args

View File

@ -1,34 +1,28 @@
# uvicorn logging configuration
import logging
import sys
from loguru import logger
class InterceptHandler(logging.Handler):
def emit(self, record: logging.LogRecord) -> None:
# Get corresponding Loguru level if it exists
try:
level = logger.level(record.levelname).name
except ValueError:
level = record.levelno
# Find caller from where originated the logged message
frame, depth = logging.currentframe(), 2
while frame.f_code.co_filename == logging.__file__:
frame = frame.f_back
depth += 1
logger.opt(depth=depth, exception=record.exc_info).log(
level, record.getMessage()
)
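
InterceptHandler above is the standard loguru pattern for routing stdlib logging records into loguru sinks. A minimal standalone sketch of the same idea, independent of uvicorn, including the registration step that the hunk above applies only to the uvicorn loggers:

import logging

from loguru import logger


class InterceptHandler(logging.Handler):
    def emit(self, record: logging.LogRecord) -> None:
        # Map the stdlib level to a loguru level and re-emit via loguru.
        try:
            level = logger.level(record.levelname).name
        except ValueError:
            level = record.levelno
        logger.opt(exception=record.exc_info).log(level, record.getMessage())


logging.basicConfig(handlers=[InterceptHandler()], level=logging.INFO, force=True)
logging.getLogger("some.library").info("this line ends up in the loguru sinks")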
def setup_uvicorn_logging() -> None:
# Set all uvicorn loggers to use InterceptHandler
for logger_name in ["uvicorn", "uvicorn.error", "uvicorn.access"]:
log = logging.getLogger(logger_name)
log.handlers = [InterceptHandler()]
@ -36,7 +30,6 @@ def setup_uvicorn_logging() -> None:
log.propagate = False
# uvicorn logging config
LOGGING_CONFIG = {
"version": 1,
"disable_existing_loggers": False,

View File

@ -1,2 +1 @@
"""Модуль для работы с хранилищем данных."""

View File

@ -1,7 +1,11 @@
# src/dataloader/storage/engine.py
from __future__ import annotations
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from dataloader.config import APP_CONFIG

View File

@ -1,8 +1,8 @@
# src/dataloader/storage/models/__init__.py
"""
ORM models for working with the database.
Organized by domain for scalability.
"""
from __future__ import annotations
from .base import Base

View File

@ -1,4 +1,3 @@
# src/dataloader/storage/models/base.py
from __future__ import annotations
from sqlalchemy.orm import DeclarativeBase
@ -9,4 +8,5 @@ class Base(DeclarativeBase):
Base class for all ORM models of the application.
Uses the SQLAlchemy 2.0+ declarative style.
"""
pass

View File

@ -4,7 +4,7 @@ from __future__ import annotations
from datetime import date, datetime
from sqlalchemy import BigInteger, Date, Integer, Numeric, String, TIMESTAMP, text
from sqlalchemy import TIMESTAMP, BigInteger, Date, Integer, Numeric, String, text
from sqlalchemy.orm import Mapped, mapped_column
from dataloader.storage.models.base import Base
@ -16,29 +16,47 @@ class BriefDigitalCertificateOpu(Base):
__tablename__ = "brief_digital_certificate_opu"
__table_args__ = ({"schema": "opu"},)
object_id: Mapped[str] = mapped_column(String, primary_key=True, server_default=text("'-'"))
object_id: Mapped[str] = mapped_column(
String, primary_key=True, server_default=text("'-'")
)
object_nm: Mapped[str | None] = mapped_column(String)
desk_nm: Mapped[str] = mapped_column(String, primary_key=True, server_default=text("'-'"))
actdate: Mapped[date] = mapped_column(Date, primary_key=True, server_default=text("CURRENT_DATE"))
layer_cd: Mapped[str] = mapped_column(String, primary_key=True, server_default=text("'-'"))
desk_nm: Mapped[str] = mapped_column(
String, primary_key=True, server_default=text("'-'")
)
actdate: Mapped[date] = mapped_column(
Date, primary_key=True, server_default=text("CURRENT_DATE")
)
layer_cd: Mapped[str] = mapped_column(
String, primary_key=True, server_default=text("'-'")
)
layer_nm: Mapped[str | None] = mapped_column(String)
opu_cd: Mapped[str] = mapped_column(String, primary_key=True)
opu_nm_sh: Mapped[str | None] = mapped_column(String)
opu_nm: Mapped[str | None] = mapped_column(String)
opu_lvl: Mapped[int] = mapped_column(Integer, primary_key=True, server_default=text("'-1'"))
opu_prnt_cd: Mapped[str] = mapped_column(String, primary_key=True, server_default=text("'-'"))
opu_lvl: Mapped[int] = mapped_column(
Integer, primary_key=True, server_default=text("'-1'")
)
opu_prnt_cd: Mapped[str] = mapped_column(
String, primary_key=True, server_default=text("'-'")
)
opu_prnt_nm_sh: Mapped[str | None] = mapped_column(String)
opu_prnt_nm: Mapped[str | None] = mapped_column(String)
sum_amountrub_p_usd: Mapped[float | None] = mapped_column(Numeric)
wf_load_id: Mapped[int] = mapped_column(BigInteger, nullable=False, server_default=text("'-1'"))
wf_load_id: Mapped[int] = mapped_column(
BigInteger, nullable=False, server_default=text("'-1'")
)
wf_load_dttm: Mapped[datetime] = mapped_column(
TIMESTAMP(timezone=False),
nullable=False,
server_default=text("CURRENT_TIMESTAMP"),
)
wf_row_id: Mapped[int] = mapped_column(BigInteger, nullable=False, server_default=text("'-1'"))
wf_row_id: Mapped[int] = mapped_column(
BigInteger, nullable=False, server_default=text("'-1'")
)
object_tp: Mapped[str | None] = mapped_column(String)
object_unit: Mapped[str] = mapped_column(String, primary_key=True, server_default=text("'-'"))
object_unit: Mapped[str] = mapped_column(
String, primary_key=True, server_default=text("'-'")
)
measure: Mapped[str | None] = mapped_column(String)
product_nm: Mapped[str | None] = mapped_column(String)
product_prnt_nm: Mapped[str | None] = mapped_column(String)

View File

@ -1,4 +1,3 @@
# src/dataloader/storage/models/queue.py
from __future__ import annotations
from datetime import datetime
@ -10,7 +9,6 @@ from sqlalchemy.orm import Mapped, mapped_column
from .base import Base
dl_status_enum = ENUM(
"queued",
"running",
@ -29,6 +27,7 @@ class DLJob(Base):
Model of the dl_jobs job queue table.
Uses the logical schema name 'queue' to support schema_translate_map.
"""
__tablename__ = "dl_jobs"
__table_args__ = {"schema": "queue"}
@ -40,15 +39,23 @@ class DLJob(Base):
lock_key: Mapped[str] = mapped_column(Text, nullable=False)
partition_key: Mapped[str] = mapped_column(Text, default="", nullable=False)
priority: Mapped[int] = mapped_column(nullable=False, default=100)
available_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
status: Mapped[str] = mapped_column(dl_status_enum, nullable=False, default="queued")
available_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False
)
status: Mapped[str] = mapped_column(
dl_status_enum, nullable=False, default="queued"
)
attempt: Mapped[int] = mapped_column(nullable=False, default=0)
max_attempts: Mapped[int] = mapped_column(nullable=False, default=5)
lease_ttl_sec: Mapped[int] = mapped_column(nullable=False, default=60)
lease_expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True))
lease_expires_at: Mapped[Optional[datetime]] = mapped_column(
DateTime(timezone=True)
)
heartbeat_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True))
cancel_requested: Mapped[bool] = mapped_column(nullable=False, default=False)
progress: Mapped[dict[str, Any]] = mapped_column(JSONB, default=dict, nullable=False)
progress: Mapped[dict[str, Any]] = mapped_column(
JSONB, default=dict, nullable=False
)
error: Mapped[Optional[str]] = mapped_column(Text)
producer: Mapped[Optional[str]] = mapped_column(Text)
consumer_group: Mapped[Optional[str]] = mapped_column(Text)
@ -62,10 +69,13 @@ class DLJobEvent(Base):
Model of the dl_job_events event log table.
Uses the logical schema name 'queue' to support schema_translate_map.
"""
__tablename__ = "dl_job_events"
__table_args__ = {"schema": "queue"}
event_id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
event_id: Mapped[int] = mapped_column(
BigInteger, primary_key=True, autoincrement=True
)
job_id: Mapped[str] = mapped_column(UUID(as_uuid=False), nullable=False)
queue: Mapped[str] = mapped_column(Text, nullable=False)
load_dttm: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)

View File

@ -5,7 +5,15 @@ from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import JSON, TIMESTAMP, BigInteger, ForeignKey, String, UniqueConstraint, func
from sqlalchemy import (
JSON,
TIMESTAMP,
BigInteger,
ForeignKey,
String,
UniqueConstraint,
func,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship
from dataloader.storage.models.base import Base
@ -30,7 +38,9 @@ class Quote(Base):
srce: Mapped[str | None] = mapped_column(String)
ticker: Mapped[str | None] = mapped_column(String)
quote_sect_id: Mapped[int] = mapped_column(
ForeignKey("quotes.quotes_sect.quote_sect_id", ondelete="CASCADE", onupdate="CASCADE"),
ForeignKey(
"quotes.quotes_sect.quote_sect_id", ondelete="CASCADE", onupdate="CASCADE"
),
nullable=False,
)
last_update_dttm: Mapped[datetime | None] = mapped_column(TIMESTAMP(timezone=True))

View File

@ -5,7 +5,17 @@ from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import TIMESTAMP, BigInteger, Boolean, DateTime, Float, ForeignKey, String, UniqueConstraint, func
from sqlalchemy import (
TIMESTAMP,
BigInteger,
Boolean,
DateTime,
Float,
ForeignKey,
String,
UniqueConstraint,
func,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship
from dataloader.storage.models.base import Base

View File

@ -1,16 +1,23 @@
# src/dataloader/storage/notify_listener.py
from __future__ import annotations
import asyncio
import asyncpg
from typing import Callable, Optional
import asyncpg
class PGNotifyListener:
"""
PostgreSQL NOTIFY listener for the 'dl_jobs' channel.
"""
def __init__(self, dsn: str, queue: str, callback: Callable[[], None], stop_event: asyncio.Event):
def __init__(
self,
dsn: str,
queue: str,
callback: Callable[[], None],
stop_event: asyncio.Event,
):
self._dsn = dsn
self._queue = queue
self._callback = callback
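
The listener wraps PostgreSQL LISTEN/NOTIFY via asyncpg. A hedged sketch of the underlying asyncpg calls such a listener typically performs; connection handling is simplified and the channel name comes from the docstring above:

import asyncio

import asyncpg


async def listen_for_jobs(dsn: str, wakeup: asyncio.Event) -> None:
    conn = await asyncpg.connect(dsn)

    def _on_notify(connection, pid, channel, payload) -> None:
        # Any NOTIFY on the channel simply wakes the worker loop.
        wakeup.set()

    await conn.add_listener("dl_jobs", _on_notify)
    try:
        await asyncio.Event().wait()  # keep the connection open until cancelled
    finally:
        await conn.remove_listener("dl_jobs", _on_notify)
        await conn.close()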

View File

@ -1,8 +1,8 @@
# src/dataloader/storage/repositories/__init__.py
"""
Repositories for working with the database.
Organized by domain for scalability.
"""
from __future__ import annotations
from .opu import OpuRepository

View File

@ -69,7 +69,6 @@ class OpuRepository:
if not records:
return 0
# Collect the columns to update (everything except the PK and technical columns)
update_columns = {
c.name
for c in BriefDigitalCertificateOpu.__table__.columns

View File

@ -1,4 +1,3 @@
# src/dataloader/storage/repositories/queue.py
from __future__ import annotations
from datetime import datetime, timedelta, timezone
@ -15,6 +14,7 @@ class QueueRepository:
"""
Repository for working with the job queue and the event log.
"""
def __init__(self, session: AsyncSession):
self.s = session
@ -62,7 +62,9 @@ class QueueRepository:
finished_at=None,
)
self.s.add(row)
await self._append_event(req.job_id, req.queue, "queued", {"task": req.task})
await self._append_event(
req.job_id, req.queue, "queued", {"task": req.task}
)
return req.job_id, "queued"
async def get_status(self, job_id: str) -> Optional[JobStatus]:
@ -118,7 +120,9 @@ class QueueRepository:
await self._append_event(job_id, job.queue, "cancel_requested", None)
return True
async def claim_one(self, queue: str, claim_backoff_sec: int) -> Optional[dict[str, Any]]:
async def claim_one(
self, queue: str, claim_backoff_sec: int
) -> Optional[dict[str, Any]]:
"""
Claims one job from the queue, honouring locks, and marks it running.
@ -150,15 +154,21 @@ class QueueRepository:
job.started_at = job.started_at or datetime.now(timezone.utc)
job.attempt = int(job.attempt) + 1
job.heartbeat_at = datetime.now(timezone.utc)
job.lease_expires_at = datetime.now(timezone.utc) + timedelta(seconds=int(job.lease_ttl_sec))
job.lease_expires_at = datetime.now(timezone.utc) + timedelta(
seconds=int(job.lease_ttl_sec)
)
ok = await self._try_advisory_lock(job.lock_key)
if not ok:
job.status = "queued"
job.available_at = datetime.now(timezone.utc) + timedelta(seconds=claim_backoff_sec)
job.available_at = datetime.now(timezone.utc) + timedelta(
seconds=claim_backoff_sec
)
return None
await self._append_event(job.job_id, job.queue, "picked", {"attempt": job.attempt})
await self._append_event(
job.job_id, job.queue, "picked", {"attempt": job.attempt}
)
return {
"job_id": job.job_id,
@ -192,10 +202,15 @@ class QueueRepository:
q = (
update(DLJob)
.where(DLJob.job_id == job_id, DLJob.status == "running")
.values(heartbeat_at=now, lease_expires_at=now + timedelta(seconds=int(ttl_sec)))
.values(
heartbeat_at=now,
lease_expires_at=now + timedelta(seconds=int(ttl_sec)),
)
)
await self.s.execute(q)
await self._append_event(job_id, await self._resolve_queue(job_id), "heartbeat", {"ttl": ttl_sec})
await self._append_event(
job_id, await self._resolve_queue(job_id), "heartbeat", {"ttl": ttl_sec}
)
return True, cancel_requested
async def finish_ok(self, job_id: str) -> None:
@ -215,7 +230,9 @@ class QueueRepository:
await self._append_event(job_id, job.queue, "succeeded", None)
await self._advisory_unlock(job.lock_key)
async def finish_fail_or_retry(self, job_id: str, err: str, is_canceled: bool = False) -> None:
async def finish_fail_or_retry(
self, job_id: str, err: str, is_canceled: bool = False
) -> None:
"""
Marks the job as failed or canceled, or requeues it with a delay.
@ -239,16 +256,25 @@ class QueueRepository:
can_retry = int(job.attempt) < int(job.max_attempts)
if can_retry:
job.status = "queued"
job.available_at = datetime.now(timezone.utc) + timedelta(seconds=30 * int(job.attempt))
job.available_at = datetime.now(timezone.utc) + timedelta(
seconds=30 * int(job.attempt)
)
job.error = err
job.lease_expires_at = None
await self._append_event(job_id, job.queue, "requeue", {"attempt": job.attempt, "error": err})
await self._append_event(
job_id,
job.queue,
"requeue",
{"attempt": job.attempt, "error": err},
)
else:
job.status = "failed"
job.error = err
job.finished_at = datetime.now(timezone.utc)
job.lease_expires_at = None
await self._append_event(job_id, job.queue, "failed", {"error": err})
await self._append_event(
job_id, job.queue, "failed", {"error": err}
)
await self._advisory_unlock(job.lock_key)
async def requeue_lost(self, now: Optional[datetime] = None) -> list[str]:
@ -293,7 +319,11 @@ class QueueRepository:
Returns:
The DLJob ORM model, or None
"""
r = await self.s.execute(select(DLJob).where(DLJob.job_id == job_id).with_for_update(skip_locked=True))
r = await self.s.execute(
select(DLJob)
.where(DLJob.job_id == job_id)
.with_for_update(skip_locked=True)
)
return r.scalar_one_or_none()
async def _resolve_queue(self, job_id: str) -> str:
@ -310,7 +340,9 @@ class QueueRepository:
v = r.scalar_one_or_none()
return v or ""
async def _append_event(self, job_id: str, queue: str, kind: str, payload: Optional[dict[str, Any]]) -> None:
async def _append_event(
self, job_id: str, queue: str, kind: str, payload: Optional[dict[str, Any]]
) -> None:
"""
Appends a record to the event log.
@ -339,7 +371,9 @@ class QueueRepository:
Returns:
True if the lock was acquired
"""
r = await self.s.execute(select(func.pg_try_advisory_lock(func.hashtext(lock_key))))
r = await self.s.execute(
select(func.pg_try_advisory_lock(func.hashtext(lock_key)))
)
return bool(r.scalar())
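
The helper above takes a session-level advisory lock keyed by hashtext(lock_key), which is what serializes jobs sharing the same lock_key. A hedged sketch of the lock/unlock pair in isolation; the unlock body is cut off in this hunk, so pg_advisory_unlock is assumed as the usual counterpart:

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession


async def try_lock(session: AsyncSession, lock_key: str) -> bool:
    # pg_try_advisory_lock returns immediately: True if acquired, False if held elsewhere.
    r = await session.execute(select(func.pg_try_advisory_lock(func.hashtext(lock_key))))
    return bool(r.scalar())


async def unlock(session: AsyncSession, lock_key: str) -> None:
    # Must run on the same session/connection that acquired the lock.
    await session.execute(select(func.pg_advisory_unlock(func.hashtext(lock_key))))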
async def _advisory_unlock(self, lock_key: str) -> None:

View File

@ -39,7 +39,9 @@ class QuotesRepository:
result = await self.s.execute(stmt)
return result.scalar_one_or_none()
async def get_or_create_section(self, name: str, params: dict | None = None) -> QuoteSection:
async def get_or_create_section(
self, name: str, params: dict | None = None
) -> QuoteSection:
"""
Get an existing section or create a new one.

View File

@ -1,8 +1,8 @@
# src/dataloader/storage/schemas/__init__.py
"""
DTOs (Data Transfer Objects) for the storage layer.
Organized by domain for scalability.
"""
from __future__ import annotations
from .queue import CreateJobRequest, JobStatus

View File

@ -1,4 +1,3 @@
# src/dataloader/storage/schemas/queue.py
from __future__ import annotations
from dataclasses import dataclass
@ -11,6 +10,7 @@ class CreateJobRequest:
"""
DTO for creating a job in the queue.
"""
job_id: str
queue: str
task: str
@ -31,6 +31,7 @@ class JobStatus:
"""
DTO for a job's status.
"""
job_id: str
status: str
attempt: int

View File

@ -1,2 +1 @@
"""Модуль воркеров для обработки задач."""

View File

@ -1,4 +1,3 @@
# src/dataloader/workers/base.py
from __future__ import annotations
import asyncio
@ -9,8 +8,8 @@ from typing import AsyncIterator, Callable, Optional
from dataloader.config import APP_CONFIG
from dataloader.context import APP_CTX
from dataloader.storage.repositories import QueueRepository
from dataloader.storage.notify_listener import PGNotifyListener
from dataloader.storage.repositories import QueueRepository
from dataloader.workers.pipelines.registry import resolve as resolve_pipeline
@ -19,6 +18,7 @@ class WorkerConfig:
"""
Worker configuration.
"""
queue: str
heartbeat_sec: int
claim_backoff_sec: int
@ -28,6 +28,7 @@ class PGWorker:
"""
Base asynchronous worker for the Postgres queue.
"""
def __init__(self, cfg: WorkerConfig, stop_event: asyncio.Event) -> None:
self._cfg = cfg
self._stop = stop_event
@ -46,14 +47,16 @@ class PGWorker:
dsn=APP_CONFIG.pg.url,
queue=self._cfg.queue,
callback=lambda: self._notify_wakeup.set(),
stop_event=self._stop
stop_event=self._stop,
)
try:
await self._listener.start()
except Exception as e:
self._log.warning(f"Failed to start LISTEN/NOTIFY, falling back to polling: {e}")
self._log.warning(
f"Failed to start LISTEN/NOTIFY, falling back to polling: {e}"
)
self._listener = None
try:
while not self._stop.is_set():
claimed = await self._claim_and_execute_once()
@ -69,24 +72,27 @@ class PGWorker:
Waits for new jobs via LISTEN/NOTIFY or with a timeout.
"""
if self._listener:
# Use LISTEN/NOTIFY with a timeout fallback
done, pending = await asyncio.wait(
[asyncio.create_task(self._notify_wakeup.wait()), asyncio.create_task(self._stop.wait())],
[
asyncio.create_task(self._notify_wakeup.wait()),
asyncio.create_task(self._stop.wait()),
],
return_when=asyncio.FIRST_COMPLETED,
timeout=timeout_sec
timeout=timeout_sec,
)
# Cancel the remaining tasks
for task in pending:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
# Clear the event if it was set
if self._notify_wakeup.is_set():
self._notify_wakeup.clear()
else:
# Fall back to a plain timeout
try:
await asyncio.wait_for(self._stop.wait(), timeout=timeout_sec)
except asyncio.TimeoutError:
@ -110,30 +116,42 @@ class PGWorker:
args = row["args"]
try:
canceled = await self._execute_with_heartbeat(job_id, ttl, self._pipeline(task, args))
canceled = await self._execute_with_heartbeat(
job_id, ttl, self._pipeline(task, args)
)
if canceled:
await repo.finish_fail_or_retry(job_id, "canceled by user", is_canceled=True)
await repo.finish_fail_or_retry(
job_id, "canceled by user", is_canceled=True
)
else:
await repo.finish_ok(job_id)
return True
except asyncio.CancelledError:
await repo.finish_fail_or_retry(job_id, "cancelled by shutdown", is_canceled=True)
await repo.finish_fail_or_retry(
job_id, "cancelled by shutdown", is_canceled=True
)
raise
except Exception as e:
await repo.finish_fail_or_retry(job_id, str(e))
return True
async def _execute_with_heartbeat(self, job_id: str, ttl: int, it: AsyncIterator[None]) -> bool:
async def _execute_with_heartbeat(
self, job_id: str, ttl: int, it: AsyncIterator[None]
) -> bool:
"""
Executes the pipeline with heartbeat support.
Returns True if the job was canceled (cancel_requested).
"""
next_hb = datetime.now(timezone.utc) + timedelta(seconds=self._cfg.heartbeat_sec)
next_hb = datetime.now(timezone.utc) + timedelta(
seconds=self._cfg.heartbeat_sec
)
async for _ in it:
now = datetime.now(timezone.utc)
if now >= next_hb:
async with self._sm() as s_hb:
success, cancel_requested = await QueueRepository(s_hb).heartbeat(job_id, ttl)
success, cancel_requested = await QueueRepository(s_hb).heartbeat(
job_id, ttl
)
if cancel_requested:
return True
next_hb = now + timedelta(seconds=self._cfg.heartbeat_sec)
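
_execute_with_heartbeat above drives an async generator and sends a heartbeat whenever enough time has passed between yields, so pipelines cooperate simply by yielding between steps. A minimal sketch of a compatible pipeline; the task name "demo.sleep" is hypothetical:

import asyncio
from typing import Any, AsyncIterator

from dataloader.workers.pipelines.registry import register


@register("demo.sleep")
async def demo_sleep(args: dict[str, Any]) -> AsyncIterator[None]:
    # Each yield returns control to the worker, which refreshes the lease
    # and checks whether cancellation was requested.
    for _ in range(int(args.get("steps", 3))):
        await asyncio.sleep(1)
        yield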

View File

@ -1,4 +1,3 @@
# src/dataloader/workers/manager.py
from __future__ import annotations
import asyncio
@ -16,6 +15,7 @@ class WorkerSpec:
"""
Configuration of a worker set for a queue.
"""
queue: str
concurrency: int
@ -24,6 +24,7 @@ class WorkerManager:
"""
Manages the lifecycle of asynchronous workers.
"""
def __init__(self, specs: list[WorkerSpec]) -> None:
self._log = APP_CTX.get_logger()
self._specs = specs
@ -40,15 +41,22 @@ class WorkerManager:
for spec in self._specs:
for i in range(max(1, spec.concurrency)):
cfg = WorkerConfig(queue=spec.queue, heartbeat_sec=hb, claim_backoff_sec=backoff)
t = asyncio.create_task(PGWorker(cfg, self._stop).run(), name=f"worker:{spec.queue}:{i}")
cfg = WorkerConfig(
queue=spec.queue, heartbeat_sec=hb, claim_backoff_sec=backoff
)
t = asyncio.create_task(
PGWorker(cfg, self._stop).run(), name=f"worker:{spec.queue}:{i}"
)
self._tasks.append(t)
self._reaper_task = asyncio.create_task(self._reaper_loop(), name="reaper")
self._log.info(
"worker_manager.started",
extra={"specs": [spec.__dict__ for spec in self._specs], "total_tasks": len(self._tasks)},
extra={
"specs": [spec.__dict__ for spec in self._specs],
"total_tasks": len(self._tasks),
},
)
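
build_manager_from_env presumably derives the WorkerSpec list from WORKERS_JSON (default "[]" in WorkerSettings). A hedged sketch of starting the manager with an explicit spec list instead, assuming WorkerSpec is a dataclass taking queue and concurrency keyword arguments, start() is a coroutine, and the queue name "opu" is hypothetical:

import asyncio

from dataloader.workers.manager import WorkerManager, WorkerSpec


async def main() -> None:
    manager = WorkerManager([WorkerSpec(queue="opu", concurrency=2)])
    await manager.start()
    try:
        await asyncio.sleep(60)  # let the workers run for a while
    finally:
        await manager.stop()


asyncio.run(main())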
async def stop(self) -> None:

View File

@ -1,6 +1,5 @@
"""Модуль пайплайнов обработки задач."""
# src/dataloader/workers/pipelines/__init__.py
from __future__ import annotations
import importlib

View File

@ -59,7 +59,6 @@ def _parse_jsonl_from_zst(file_path: Path, chunk_size: int = 10000):
APP_CTX.logger.warning(f"Failed to parse JSON line: {e}")
continue
# Process remaining buffer
if buffer.strip():
try:
record = orjson.loads(buffer)
@ -85,11 +84,9 @@ def _convert_record(raw: dict[str, Any]) -> dict[str, Any]:
"""
result = raw.copy()
# Convert actdate from an ISO string to a date
if "actdate" in result and isinstance(result["actdate"], str):
result["actdate"] = datetime.fromisoformat(result["actdate"]).date()
# Convert wf_load_dttm from an ISO string to a datetime
if "wf_load_dttm" in result and isinstance(result["wf_load_dttm"], str):
result["wf_load_dttm"] = datetime.fromisoformat(result["wf_load_dttm"])
@ -118,18 +115,15 @@ async def load_opu(args: dict) -> AsyncIterator[None]:
logger = APP_CTX.logger
logger.info("Starting OPU ETL pipeline")
# Step 1: start the export
interface = get_gmap2brief_interface()
job_id = await interface.start_export()
logger.info(f"OPU export job started: {job_id}")
yield
# Step 2: wait for completion
status = await interface.wait_for_completion(job_id)
logger.info(f"OPU export completed: {status.total_rows} rows")
yield
# Step 3: download the archive
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
archive_path = temp_path / f"opu_export_{job_id}.jsonl.zst"
@ -138,7 +132,6 @@ async def load_opu(args: dict) -> AsyncIterator[None]:
logger.info(f"OPU archive downloaded: {archive_path.stat().st_size:,} bytes")
yield
# Step 4: truncate the table
async with APP_CTX.sessionmaker() as session:
repo = OpuRepository(session)
await repo.truncate()
@ -146,7 +139,6 @@ async def load_opu(args: dict) -> AsyncIterator[None]:
logger.info("OPU table truncated")
yield
# Step 5: load the data in streaming fashion
total_inserted = 0
batch_num = 0
@ -154,17 +146,16 @@ async def load_opu(args: dict) -> AsyncIterator[None]:
for batch in _parse_jsonl_from_zst(archive_path, chunk_size=5000):
batch_num += 1
# Convert the records
converted = [_convert_record(rec) for rec in batch]
# Insert the batch
inserted = await repo.bulk_insert(converted)
await session.commit()
total_inserted += inserted
logger.debug(f"Batch {batch_num}: inserted {inserted} rows (total: {total_inserted})")
logger.debug(
f"Batch {batch_num}: inserted {inserted} rows (total: {total_inserted})"
)
# Heartbeat after each batch
yield
logger.info(f"OPU ETL completed: {total_inserted} rows inserted")

View File

@ -59,7 +59,9 @@ def _parse_ts_to_datetime(ts: str) -> datetime | None:
return None
def _build_value_row(source: str, dt: datetime, point: Any) -> dict[str, Any] | None: # noqa: C901
def _build_value_row(
source: str, dt: datetime, point: Any
) -> dict[str, Any] | None: # noqa: C901
"""Строит строку для `quotes_values` по источнику и типу точки."""
if isinstance(point, int):
return {"dt": dt, "key": point}
@ -82,7 +84,10 @@ def _build_value_row(source: str, dt: datetime, point: Any) -> dict[str, Any] |
if isinstance(deep_inner, InvestingCandlestick):
return {
"dt": dt,
"price_o": _to_float(getattr(deep_inner, "open_", None) or getattr(deep_inner, "open", None)),
"price_o": _to_float(
getattr(deep_inner, "open_", None)
or getattr(deep_inner, "open", None)
),
"price_h": _to_float(deep_inner.high),
"price_l": _to_float(deep_inner.low),
"price_c": _to_float(deep_inner.close),
@ -92,12 +97,16 @@ def _build_value_row(source: str, dt: datetime, point: Any) -> dict[str, Any] |
if isinstance(inner, TradingViewTimePoint | SgxTimePoint):
return {
"dt": dt,
"price_o": _to_float(getattr(inner, "open_", None) or getattr(inner, "open", None)),
"price_o": _to_float(
getattr(inner, "open_", None) or getattr(inner, "open", None)
),
"price_h": _to_float(inner.high),
"price_l": _to_float(inner.low),
"price_c": _to_float(inner.close),
"volume": _to_float(
getattr(inner, "volume", None) or getattr(inner, "interest", None) or getattr(inner, "value", None)
getattr(inner, "volume", None)
or getattr(inner, "interest", None)
or getattr(inner, "value", None)
),
}
@ -132,7 +141,9 @@ def _build_value_row(source: str, dt: datetime, point: Any) -> dict[str, Any] |
"dt": dt,
"value_last": _to_float(deep_inner.last),
"value_previous": _to_float(deep_inner.previous),
"unit": str(deep_inner.unit) if deep_inner.unit is not None else None,
"unit": (
str(deep_inner.unit) if deep_inner.unit is not None else None
),
}
if isinstance(deep_inner, TradingEconomicsStringPercent):
@ -218,7 +229,14 @@ async def load_tenera(args: dict) -> AsyncIterator[None]:
async with APP_CTX.sessionmaker() as session:
repo = QuotesRepository(session)
for source_name in ("cbr", "investing", "sgx", "tradingeconomics", "bloomberg", "trading_view"):
for source_name in (
"cbr",
"investing",
"sgx",
"tradingeconomics",
"bloomberg",
"trading_view",
):
source_data = getattr(data, source_name)
if not source_data:
continue

View File

@ -1,4 +1,3 @@
# src/dataloader/workers/pipelines/noop.py
from __future__ import annotations
import asyncio

View File

@ -1,4 +1,3 @@
# src/dataloader/workers/pipelines/registry.py
from __future__ import annotations
from typing import Any, Callable, Dict, Iterable
@ -6,13 +5,17 @@ from typing import Any, Callable, Dict, Iterable
_Registry: Dict[str, Callable[[dict[str, Any]], Any]] = {}
def register(task: str) -> Callable[[Callable[[dict[str, Any]], Any]], Callable[[dict[str, Any]], Any]]:
def register(
task: str,
) -> Callable[[Callable[[dict[str, Any]], Any]], Callable[[dict[str, Any]], Any]]:
"""
Registers a pipeline handler under a task name.
"""
def _wrap(fn: Callable[[dict[str, Any]], Any]) -> Callable[[dict[str, Any]], Any]:
_Registry[task] = fn
return fn
return _wrap
@ -22,8 +25,8 @@ def resolve(task: str) -> Callable[[dict[str, Any]], Any]:
"""
try:
return _Registry[task]
except KeyError:
raise KeyError(f"pipeline not found: {task}")
except KeyError as err:
raise KeyError(f"pipeline not found: {task}") from err
def tasks() -> Iterable[str]:

View File

@ -1,7 +1,7 @@
# src/dataloader/workers/reaper.py
from __future__ import annotations
from typing import Sequence
from sqlalchemy.ext.asyncio import AsyncSession
from dataloader.storage.repositories import QueueRepository

View File

@ -1,4 +1,3 @@
# tests/conftest.py
from __future__ import annotations
import asyncio
@ -9,23 +8,23 @@ from uuid import uuid4
import pytest
import pytest_asyncio
from dotenv import load_dotenv
from httpx import AsyncClient, ASGITransport
from httpx import ASGITransport, AsyncClient
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
load_dotenv()
from dataloader.api import app_main
from dataloader.config import APP_CONFIG
from dataloader.context import APP_CTX, get_session
from dataloader.storage.models import Base
from dataloader.storage.engine import create_engine, create_sessionmaker
from dataloader.context import get_session
from dataloader.storage.engine import create_engine
load_dotenv()
if sys.platform == "win32":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
pytestmark = pytest.mark.asyncio
@pytest_asyncio.fixture(scope="function")
async def db_engine() -> AsyncGenerator[AsyncEngine, None]:
"""
@ -39,15 +38,15 @@ async def db_engine() -> AsyncGenerator[AsyncEngine, None]:
await engine.dispose()
@pytest_asyncio.fixture(scope="function")
async def db_session(db_engine: AsyncEngine) -> AsyncGenerator[AsyncSession, None]:
"""
Provides a DB session for each test.
Does NOT use a transaction, so that advisory locks work.
"""
sessionmaker = async_sessionmaker(bind=db_engine, expire_on_commit=False, class_=AsyncSession)
sessionmaker = async_sessionmaker(
bind=db_engine, expire_on_commit=False, class_=AsyncSession
)
async with sessionmaker() as session:
yield session
await session.rollback()
@ -69,6 +68,7 @@ async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
"""
HTTP client for testing the API.
"""
async def override_get_session() -> AsyncGenerator[AsyncSession, None]:
yield db_session
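The client fixture above swaps the app's session dependency for the per-test session. A hedged reconstruction of the full fixture wiring, assuming the standard FastAPI dependency_overrides mechanism for the lines the hunk does not show:

import pytest_asyncio
from collections.abc import AsyncGenerator
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from dataloader.api import app_main
from dataloader.context import get_session

@pytest_asyncio.fixture
async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
    async def override_get_session() -> AsyncGenerator[AsyncSession, None]:
        yield db_session

    # Assumed wiring: route get_session to the test session, then clean up.
    app_main.dependency_overrides[get_session] = override_get_session
    async with AsyncClient(transport=ASGITransport(app=app_main), base_url="http://test") as ac:
        yield ac
    app_main.dependency_overrides.clear()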

View File

@ -1 +1 @@
# tests/integration_tests/__init__.py

View File

@ -1,4 +1,3 @@
# tests/integration_tests/test_api_endpoints.py
from __future__ import annotations
import pytest

View File

@ -1,11 +1,11 @@
# tests/unit/test_api_router_not_found.py
from __future__ import annotations
import pytest
from uuid import uuid4, UUID
from uuid import UUID, uuid4
import pytest
from dataloader.api.v1.router import get_status, cancel_job
from dataloader.api.v1.exceptions import JobNotFoundError
from dataloader.api.v1.router import cancel_job, get_status
from dataloader.api.v1.schemas import JobStatusResponse

View File

@ -1,11 +1,11 @@
# tests/unit/test_api_router_success.py
from __future__ import annotations
import pytest
from uuid import uuid4, UUID
from datetime import datetime, timezone
from uuid import UUID, uuid4
from dataloader.api.v1.router import get_status, cancel_job
import pytest
from dataloader.api.v1.router import cancel_job, get_status
from dataloader.api.v1.schemas import JobStatusResponse

View File

@ -1,7 +1,6 @@
# tests/integration_tests/test_queue_repository.py
from __future__ import annotations
from datetime import datetime, timezone, timedelta
from datetime import datetime, timedelta, timezone
from uuid import uuid4
import pytest
@ -10,7 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from dataloader.storage.models import DLJob
from dataloader.storage.repositories import QueueRepository
from dataloader.storage.schemas import CreateJobRequest, JobStatus
from dataloader.storage.schemas import CreateJobRequest
@pytest.mark.integration
@ -445,6 +444,7 @@ class TestQueueRepository:
await repo.claim_one(queue_name, claim_backoff_sec=15)
import asyncio
await asyncio.sleep(2)
requeued = await repo.requeue_lost()
@ -500,7 +500,9 @@ class TestQueueRepository:
assert st is not None
assert st.status == "queued"
row = (await db_session.execute(select(DLJob).where(DLJob.job_id == job_id))).scalar_one()
row = (
await db_session.execute(select(DLJob).where(DLJob.job_id == job_id))
).scalar_one()
assert row.available_at >= before + timedelta(seconds=15)
assert row.available_at <= after + timedelta(seconds=60)
@ -569,7 +571,9 @@ class TestQueueRepository:
await repo.create_or_get(req)
await repo.claim_one(queue_name, claim_backoff_sec=5)
await repo.finish_fail_or_retry(job_id, err="Canceled by test", is_canceled=True)
await repo.finish_fail_or_retry(
job_id, err="Canceled by test", is_canceled=True
)
st = await repo.get_status(job_id)
assert st is not None

View File

@ -1 +1 @@
# tests/unit/__init__.py

View File

@ -1,13 +1,13 @@
# tests/unit/test_api_service.py
from __future__ import annotations
from datetime import datetime, timezone
from uuid import UUID
from unittest.mock import AsyncMock, Mock, patch
from uuid import UUID
import pytest
from dataloader.api.v1.service import JobsService
from dataloader.api.v1.schemas import TriggerJobRequest
from dataloader.api.v1.service import JobsService
from dataloader.storage.schemas import JobStatus
@ -38,15 +38,19 @@ class TestJobsService:
"""
mock_session = AsyncMock()
with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls, \
patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id:
with (
patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id,
):
mock_get_logger.return_value = Mock()
mock_new_job_id.return_value = UUID("12345678-1234-5678-1234-567812345678")
mock_repo = Mock()
mock_repo.create_or_get = AsyncMock(return_value=("12345678-1234-5678-1234-567812345678", "queued"))
mock_repo.create_or_get = AsyncMock(
return_value=("12345678-1234-5678-1234-567812345678", "queued")
)
mock_repo_cls.return_value = mock_repo
service = JobsService(mock_session)
@ -58,7 +62,7 @@ class TestJobsService:
lock_key="lock_1",
priority=100,
max_attempts=5,
lease_ttl_sec=60
lease_ttl_sec=60,
)
response = await service.trigger(req)
@ -74,15 +78,19 @@ class TestJobsService:
"""
mock_session = AsyncMock()
with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls, \
patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id:
with (
patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id,
):
mock_get_logger.return_value = Mock()
mock_new_job_id.return_value = UUID("12345678-1234-5678-1234-567812345678")
mock_repo = Mock()
mock_repo.create_or_get = AsyncMock(return_value=("12345678-1234-5678-1234-567812345678", "queued"))
mock_repo.create_or_get = AsyncMock(
return_value=("12345678-1234-5678-1234-567812345678", "queued")
)
mock_repo_cls.return_value = mock_repo
service = JobsService(mock_session)
@ -95,7 +103,7 @@ class TestJobsService:
lock_key="lock_1",
priority=100,
max_attempts=5,
lease_ttl_sec=60
lease_ttl_sec=60,
)
response = await service.trigger(req)
@ -112,15 +120,19 @@ class TestJobsService:
"""
mock_session = AsyncMock()
with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls, \
patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id:
with (
patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id,
):
mock_get_logger.return_value = Mock()
mock_new_job_id.return_value = UUID("12345678-1234-5678-1234-567812345678")
mock_repo = Mock()
mock_repo.create_or_get = AsyncMock(return_value=("12345678-1234-5678-1234-567812345678", "queued"))
mock_repo.create_or_get = AsyncMock(
return_value=("12345678-1234-5678-1234-567812345678", "queued")
)
mock_repo_cls.return_value = mock_repo
service = JobsService(mock_session)
@ -135,10 +147,10 @@ class TestJobsService:
available_at=future_time,
priority=100,
max_attempts=5,
lease_ttl_sec=60
lease_ttl_sec=60,
)
response = await service.trigger(req)
await service.trigger(req)
call_args = mock_repo.create_or_get.call_args[0][0]
assert call_args.available_at == future_time
@ -150,15 +162,19 @@ class TestJobsService:
"""
mock_session = AsyncMock()
with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls, \
patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id:
with (
patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id,
):
mock_get_logger.return_value = Mock()
mock_new_job_id.return_value = UUID("12345678-1234-5678-1234-567812345678")
mock_repo = Mock()
mock_repo.create_or_get = AsyncMock(return_value=("12345678-1234-5678-1234-567812345678", "queued"))
mock_repo.create_or_get = AsyncMock(
return_value=("12345678-1234-5678-1234-567812345678", "queued")
)
mock_repo_cls.return_value = mock_repo
service = JobsService(mock_session)
@ -173,10 +189,10 @@ class TestJobsService:
consumer_group="test_group",
priority=100,
max_attempts=5,
lease_ttl_sec=60
lease_ttl_sec=60,
)
response = await service.trigger(req)
await service.trigger(req)
call_args = mock_repo.create_or_get.call_args[0][0]
assert call_args.partition_key == "partition_1"
@ -190,8 +206,10 @@ class TestJobsService:
"""
mock_session = AsyncMock()
with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls:
with (
patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
):
mock_get_logger.return_value = Mock()
@ -204,7 +222,7 @@ class TestJobsService:
finished_at=None,
heartbeat_at=datetime(2025, 1, 1, 12, 5, 0, tzinfo=timezone.utc),
error=None,
progress={"step": 1}
progress={"step": 1},
)
mock_repo.get_status = AsyncMock(return_value=mock_status)
mock_repo_cls.return_value = mock_repo
@ -227,8 +245,10 @@ class TestJobsService:
"""
mock_session = AsyncMock()
with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls:
with (
patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
):
mock_get_logger.return_value = Mock()
@ -250,8 +270,10 @@ class TestJobsService:
"""
mock_session = AsyncMock()
with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls:
with (
patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
):
mock_get_logger.return_value = Mock()
@ -265,7 +287,7 @@ class TestJobsService:
finished_at=None,
heartbeat_at=datetime(2025, 1, 1, 12, 5, 0, tzinfo=timezone.utc),
error=None,
progress={}
progress={},
)
mock_repo.get_status = AsyncMock(return_value=mock_status)
mock_repo_cls.return_value = mock_repo
@ -287,8 +309,10 @@ class TestJobsService:
"""
mock_session = AsyncMock()
with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls:
with (
patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
):
mock_get_logger.return_value = Mock()
@ -312,8 +336,10 @@ class TestJobsService:
"""
mock_session = AsyncMock()
with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls:
with (
patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
):
mock_get_logger.return_value = Mock()
@ -326,7 +352,7 @@ class TestJobsService:
finished_at=None,
heartbeat_at=None,
error=None,
progress=None
progress=None,
)
mock_repo.get_status = AsyncMock(return_value=mock_status)
mock_repo_cls.return_value = mock_repo
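A recurring, purely syntactic change in these test files: backslash-continued with statements become parenthesized context-manager groups, which require Python 3.10+ and which newer Black releases format natively. Both forms below are equivalent; the patch targets are taken from the hunks, the bodies are elided:

from unittest.mock import patch

# Before: line continuations
with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \
     patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls:
    ...

# After: one parenthesized group, easier to extend and to auto-format
with (
    patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
    patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
):
    ...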

View File

@ -1,18 +1,18 @@
# tests/unit/test_config.py
from __future__ import annotations
import json
from logging import DEBUG, INFO
from unittest.mock import patch
import pytest
from dataloader.config import (
BaseAppSettings,
AppSettings,
BaseAppSettings,
LogSettings,
PGSettings,
WorkerSettings,
Secrets,
WorkerSettings,
)
@ -81,12 +81,15 @@ class TestAppSettings:
"""
Test loading from environment variables.
"""
with patch.dict("os.environ", {
"APP_HOST": "127.0.0.1",
"APP_PORT": "9000",
"PROJECT_NAME": "TestProject",
"TIMEZONE": "UTC"
}):
with patch.dict(
"os.environ",
{
"APP_HOST": "127.0.0.1",
"APP_PORT": "9000",
"PROJECT_NAME": "TestProject",
"TIMEZONE": "UTC",
},
):
settings = AppSettings()
assert settings.app_host == "127.0.0.1"
@ -136,7 +139,9 @@ class TestLogSettings:
"""
Test of the log_file_abs_path property.
"""
with patch.dict("os.environ", {"LOG_PATH": "/var/log", "LOG_FILE_NAME": "test.log"}):
with patch.dict(
"os.environ", {"LOG_PATH": "/var/log", "LOG_FILE_NAME": "test.log"}
):
settings = LogSettings()
assert "test.log" in settings.log_file_abs_path
@ -146,7 +151,10 @@ class TestLogSettings:
"""
Test of the metric_file_abs_path property.
"""
with patch.dict("os.environ", {"METRIC_PATH": "/var/metrics", "METRIC_FILE_NAME": "metrics.log"}):
with patch.dict(
"os.environ",
{"METRIC_PATH": "/var/metrics", "METRIC_FILE_NAME": "metrics.log"},
):
settings = LogSettings()
assert "metrics.log" in settings.metric_file_abs_path
@ -156,7 +164,10 @@ class TestLogSettings:
"""
Test of the audit_file_abs_path property.
"""
with patch.dict("os.environ", {"AUDIT_LOG_PATH": "/var/audit", "AUDIT_LOG_FILE_NAME": "audit.log"}):
with patch.dict(
"os.environ",
{"AUDIT_LOG_PATH": "/var/audit", "AUDIT_LOG_FILE_NAME": "audit.log"},
):
settings = LogSettings()
assert "audit.log" in settings.audit_file_abs_path
@ -208,29 +219,37 @@ class TestPGSettings:
"""
Test of connection string construction.
"""
with patch.dict("os.environ", {
"PG_HOST": "db.example.com",
"PG_PORT": "5433",
"PG_USER": "testuser",
"PG_PASSWORD": "testpass",
"PG_DATABASE": "testdb"
}):
with patch.dict(
"os.environ",
{
"PG_HOST": "db.example.com",
"PG_PORT": "5433",
"PG_USER": "testuser",
"PG_PASSWORD": "testpass",
"PG_DATABASE": "testdb",
},
):
settings = PGSettings()
expected = "postgresql+asyncpg://testuser:testpass@db.example.com:5433/testdb"
expected = (
"postgresql+asyncpg://testuser:testpass@db.example.com:5433/testdb"
)
assert settings.url == expected
def test_url_property_with_empty_password(self):
"""
Test of the connection string with an empty password.
"""
with patch.dict("os.environ", {
"PG_HOST": "localhost",
"PG_PORT": "5432",
"PG_USER": "postgres",
"PG_PASSWORD": "",
"PG_DATABASE": "testdb"
}):
with patch.dict(
"os.environ",
{
"PG_HOST": "localhost",
"PG_PORT": "5432",
"PG_USER": "postgres",
"PG_PASSWORD": "",
"PG_DATABASE": "testdb",
},
):
settings = PGSettings()
expected = "postgresql+asyncpg://postgres:@localhost:5432/testdb"
@ -240,15 +259,18 @@ class TestPGSettings:
"""
Test loading from environment variables.
"""
with patch.dict("os.environ", {
"PG_HOST": "testhost",
"PG_PORT": "5433",
"PG_USER": "testuser",
"PG_PASSWORD": "testpass",
"PG_DATABASE": "testdb",
"PG_SCHEMA_QUEUE": "queue_schema",
"PG_POOL_SIZE": "20"
}):
with patch.dict(
"os.environ",
{
"PG_HOST": "testhost",
"PG_PORT": "5433",
"PG_USER": "testuser",
"PG_PASSWORD": "testpass",
"PG_DATABASE": "testdb",
"PG_SCHEMA_QUEUE": "queue_schema",
"PG_POOL_SIZE": "20",
},
):
settings = PGSettings()
assert settings.host == "testhost"
@ -292,10 +314,12 @@ class TestWorkerSettings:
"""
Test parsing of valid JSON.
"""
workers_json = json.dumps([
{"queue": "queue1", "concurrency": 2},
{"queue": "queue2", "concurrency": 3}
])
workers_json = json.dumps(
[
{"queue": "queue1", "concurrency": 2},
{"queue": "queue2", "concurrency": 3},
]
)
with patch.dict("os.environ", {"WORKERS_JSON": workers_json}):
settings = WorkerSettings()
@ -311,12 +335,14 @@ class TestWorkerSettings:
"""
Test filtering of non-dict entries from JSON.
"""
workers_json = json.dumps([
{"queue": "queue1", "concurrency": 2},
"invalid_item",
123,
{"queue": "queue2", "concurrency": 3}
])
workers_json = json.dumps(
[
{"queue": "queue1", "concurrency": 2},
"invalid_item",
123,
{"queue": "queue2", "concurrency": 3},
]
)
with patch.dict("os.environ", {"WORKERS_JSON": workers_json}):
settings = WorkerSettings()
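The expectations above pin the DSN produced by PGSettings.url. A hedged stand-in that mirrors only the formatting the tests assert (the real class is a pydantic settings model; field names are assumed from the PG_* variables):

from dataclasses import dataclass

@dataclass
class _PG:
    host: str = "localhost"
    port: int = 5432
    user: str = "postgres"
    password: str = ""
    database: str = "testdb"

    @property
    def url(self) -> str:
        # An empty password yields "user:@host", matching the test expectation above.
        return f"postgresql+asyncpg://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"

assert _PG().url == "postgresql+asyncpg://postgres:@localhost:5432/testdb"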

View File

@ -1,7 +1,7 @@
# tests/unit/test_context.py
from __future__ import annotations
from unittest.mock import AsyncMock, Mock, patch
import pytest
from dataloader.context import AppContext, get_session
@ -71,10 +71,16 @@ class TestAppContext:
mock_engine = Mock()
mock_sm = Mock()
with patch("dataloader.logger.logger.setup_logging") as mock_setup_logging, \
patch("dataloader.storage.engine.create_engine", return_value=mock_engine) as mock_create_engine, \
patch("dataloader.storage.engine.create_sessionmaker", return_value=mock_sm) as mock_create_sm, \
patch("dataloader.context.APP_CONFIG") as mock_config:
with (
patch("dataloader.logger.logger.setup_logging") as mock_setup_logging,
patch(
"dataloader.storage.engine.create_engine", return_value=mock_engine
) as mock_create_engine,
patch(
"dataloader.storage.engine.create_sessionmaker", return_value=mock_sm
) as mock_create_sm,
patch("dataloader.context.APP_CONFIG") as mock_config,
):
mock_config.pg.url = "postgresql://test"
@ -194,7 +200,7 @@ class TestGetSession:
with patch("dataloader.context.APP_CTX") as mock_ctx:
mock_ctx.sessionmaker = mock_sm
async for session in get_session():
async for _session in get_session():
pass
assert mock_exit.call_count == 1
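Renaming the loop variable to _session signals that only the context-manager exit count matters in this test. A hedged reconstruction of the dependency under test, inferred from how it is consumed here rather than copied from the source:

from collections.abc import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession

async def get_session() -> AsyncGenerator[AsyncSession, None]:
    # One session per request; async-with guarantees __aexit__ runs exactly once,
    # which is what the mock_exit assertion counts.
    async with APP_CTX.sessionmaker() as session:  # APP_CTX: module-level app context
        yield session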

View File

@ -1,8 +1,8 @@
# tests/unit/test_notify_listener.py
from __future__ import annotations
import asyncio
from unittest.mock import AsyncMock, Mock, patch
import pytest
from dataloader.storage.notify_listener import PGNotifyListener
@ -25,7 +25,7 @@ class TestPGNotifyListener:
dsn="postgresql://test",
queue="test_queue",
callback=callback,
stop_event=stop_event
stop_event=stop_event,
)
assert listener._dsn == "postgresql://test"
@ -47,14 +47,16 @@ class TestPGNotifyListener:
dsn="postgresql://test",
queue="test_queue",
callback=callback,
stop_event=stop_event
stop_event=stop_event,
)
mock_conn = AsyncMock()
mock_conn.execute = AsyncMock()
mock_conn.add_listener = AsyncMock()
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn):
with patch(
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
):
await listener.start()
assert listener._conn == mock_conn
@ -76,14 +78,16 @@ class TestPGNotifyListener:
dsn="postgresql+asyncpg://test",
queue="test_queue",
callback=callback,
stop_event=stop_event
stop_event=stop_event,
)
mock_conn = AsyncMock()
mock_conn.execute = AsyncMock()
mock_conn.add_listener = AsyncMock()
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn) as mock_connect:
with patch(
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
) as mock_connect:
await listener.start()
mock_connect.assert_called_once_with("postgresql://test")
@ -102,14 +106,16 @@ class TestPGNotifyListener:
dsn="postgresql://test",
queue="test_queue",
callback=callback,
stop_event=stop_event
stop_event=stop_event,
)
mock_conn = AsyncMock()
mock_conn.execute = AsyncMock()
mock_conn.add_listener = AsyncMock()
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn):
with patch(
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
):
await listener.start()
handler = listener._on_notify_handler
@ -131,14 +137,16 @@ class TestPGNotifyListener:
dsn="postgresql://test",
queue="test_queue",
callback=callback,
stop_event=stop_event
stop_event=stop_event,
)
mock_conn = AsyncMock()
mock_conn.execute = AsyncMock()
mock_conn.add_listener = AsyncMock()
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn):
with patch(
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
):
await listener.start()
handler = listener._on_notify_handler
@ -160,14 +168,16 @@ class TestPGNotifyListener:
dsn="postgresql://test",
queue="test_queue",
callback=callback,
stop_event=stop_event
stop_event=stop_event,
)
mock_conn = AsyncMock()
mock_conn.execute = AsyncMock()
mock_conn.add_listener = AsyncMock()
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn):
with patch(
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
):
await listener.start()
handler = listener._on_notify_handler
@ -189,14 +199,16 @@ class TestPGNotifyListener:
dsn="postgresql://test",
queue="test_queue",
callback=callback,
stop_event=stop_event
stop_event=stop_event,
)
mock_conn = AsyncMock()
mock_conn.execute = AsyncMock()
mock_conn.add_listener = AsyncMock()
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn):
with patch(
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
):
await listener.start()
handler = listener._on_notify_handler
@ -218,7 +230,7 @@ class TestPGNotifyListener:
dsn="postgresql://test",
queue="test_queue",
callback=callback,
stop_event=stop_event
stop_event=stop_event,
)
mock_conn = AsyncMock()
@ -227,7 +239,9 @@ class TestPGNotifyListener:
mock_conn.remove_listener = AsyncMock()
mock_conn.close = AsyncMock()
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn):
with patch(
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
):
await listener.start()
assert listener._task is not None
@ -250,7 +264,7 @@ class TestPGNotifyListener:
dsn="postgresql://test",
queue="test_queue",
callback=callback,
stop_event=stop_event
stop_event=stop_event,
)
mock_conn = AsyncMock()
@ -259,7 +273,9 @@ class TestPGNotifyListener:
mock_conn.remove_listener = AsyncMock()
mock_conn.close = AsyncMock()
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn):
with patch(
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
):
await listener.start()
await listener.stop()
@ -280,7 +296,7 @@ class TestPGNotifyListener:
dsn="postgresql://test",
queue="test_queue",
callback=callback,
stop_event=stop_event
stop_event=stop_event,
)
mock_conn = AsyncMock()
@ -289,7 +305,9 @@ class TestPGNotifyListener:
mock_conn.remove_listener = AsyncMock()
mock_conn.close = AsyncMock()
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn):
with patch(
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
):
await listener.start()
stop_event.set()
@ -311,7 +329,7 @@ class TestPGNotifyListener:
dsn="postgresql://test",
queue="test_queue",
callback=callback,
stop_event=stop_event
stop_event=stop_event,
)
mock_conn = AsyncMock()
@ -320,7 +338,9 @@ class TestPGNotifyListener:
mock_conn.remove_listener = AsyncMock(side_effect=Exception("Remove error"))
mock_conn.close = AsyncMock()
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn):
with patch(
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
):
await listener.start()
await listener.stop()
@ -340,7 +360,7 @@ class TestPGNotifyListener:
dsn="postgresql://test",
queue="test_queue",
callback=callback,
stop_event=stop_event
stop_event=stop_event,
)
mock_conn = AsyncMock()
@ -349,7 +369,9 @@ class TestPGNotifyListener:
mock_conn.remove_listener = AsyncMock()
mock_conn.close = AsyncMock(side_effect=Exception("Close error"))
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn):
with patch(
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
):
await listener.start()
await listener.stop()
@ -368,7 +390,7 @@ class TestPGNotifyListener:
dsn="postgresql://test",
queue="test_queue",
callback=callback,
stop_event=stop_event
stop_event=stop_event,
)
await listener.stop()
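Several listener tests hand in a SQLAlchemy-style DSN ("postgresql+asyncpg://...") and assert that plain asyncpg receives the driverless form. A minimal sketch of that normalization, assuming it is the simple prefix rewrite the assertions suggest (the helper name is hypothetical):

def normalize_dsn(dsn: str) -> str:
    # asyncpg.connect() does not understand SQLAlchemy's "+asyncpg" driver suffix.
    return dsn.replace("postgresql+asyncpg://", "postgresql://", 1)

assert normalize_dsn("postgresql+asyncpg://test") == "postgresql://test"
assert normalize_dsn("postgresql://test") == "postgresql://test"  # plain DSNs pass through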

View File

@ -1,9 +1,8 @@
# tests/unit/test_pipeline_registry.py
from __future__ import annotations
import pytest
from dataloader.workers.pipelines.registry import register, resolve, tasks, _Registry
from dataloader.workers.pipelines.registry import _Registry, register, resolve, tasks
@pytest.mark.unit
@ -22,6 +21,7 @@ class TestPipelineRegistry:
"""
Test of pipeline registration.
"""
@register("test.task")
def test_pipeline(args: dict):
return "result"
@ -33,6 +33,7 @@ class TestPipelineRegistry:
"""
Test resolving a registered pipeline.
"""
@register("test.resolve")
def test_pipeline(args: dict):
return "resolved"
@ -54,6 +55,7 @@ class TestPipelineRegistry:
"""
Test listing registered tasks.
"""
@register("task1")
def pipeline1(args: dict):
pass
@ -70,6 +72,7 @@ class TestPipelineRegistry:
"""
Test overwriting an existing pipeline.
"""
@register("overwrite.task")
def first_pipeline(args: dict):
return "first"

View File

@ -1,9 +1,8 @@
# tests/unit/test_workers_base.py
from __future__ import annotations
import asyncio
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from dataloader.workers.base import PGWorker, WorkerConfig
@ -41,9 +40,11 @@ class TestPGWorker:
cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=1)
stop_event = asyncio.Event()
with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
patch("dataloader.workers.base.APP_CONFIG") as mock_cfg, \
patch("dataloader.workers.base.PGNotifyListener") as mock_listener_cls:
with (
patch("dataloader.workers.base.APP_CTX") as mock_ctx,
patch("dataloader.workers.base.APP_CONFIG") as mock_cfg,
patch("dataloader.workers.base.PGNotifyListener") as mock_listener_cls,
):
mock_ctx.get_logger.return_value = Mock()
mock_ctx.sessionmaker = Mock()
@ -64,7 +65,9 @@ class TestPGWorker:
stop_event.set()
return False
with patch.object(worker, "_claim_and_execute_once", side_effect=mock_claim):
with patch.object(
worker, "_claim_and_execute_once", side_effect=mock_claim
):
await worker.run()
assert mock_listener.start.call_count == 1
@ -78,9 +81,11 @@ class TestPGWorker:
cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=1)
stop_event = asyncio.Event()
with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
patch("dataloader.workers.base.APP_CONFIG") as mock_cfg, \
patch("dataloader.workers.base.PGNotifyListener") as mock_listener_cls:
with (
patch("dataloader.workers.base.APP_CTX") as mock_ctx,
patch("dataloader.workers.base.APP_CONFIG") as mock_cfg,
patch("dataloader.workers.base.PGNotifyListener") as mock_listener_cls,
):
mock_logger = Mock()
mock_ctx.get_logger.return_value = mock_logger
@ -101,7 +106,9 @@ class TestPGWorker:
stop_event.set()
return False
with patch.object(worker, "_claim_and_execute_once", side_effect=mock_claim):
with patch.object(
worker, "_claim_and_execute_once", side_effect=mock_claim
):
await worker.run()
assert worker._listener is None
@ -156,8 +163,10 @@ class TestPGWorker:
cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
stop_event = asyncio.Event()
with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
with (
patch("dataloader.workers.base.APP_CTX") as mock_ctx,
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
):
mock_session = AsyncMock()
mock_session.commit = AsyncMock()
@ -185,8 +194,10 @@ class TestPGWorker:
cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
stop_event = asyncio.Event()
with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
with (
patch("dataloader.workers.base.APP_CTX") as mock_ctx,
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
):
mock_session = AsyncMock()
mock_sm = MagicMock()
@ -196,12 +207,14 @@ class TestPGWorker:
mock_ctx.sessionmaker = mock_sm
mock_repo = Mock()
mock_repo.claim_one = AsyncMock(return_value={
"job_id": "test-job-id",
"lease_ttl_sec": 60,
"task": "test.task",
"args": {"key": "value"}
})
mock_repo.claim_one = AsyncMock(
return_value={
"job_id": "test-job-id",
"lease_ttl_sec": 60,
"task": "test.task",
"args": {"key": "value"},
}
)
mock_repo.finish_ok = AsyncMock()
mock_repo_cls.return_value = mock_repo
@ -210,8 +223,10 @@ class TestPGWorker:
async def mock_pipeline(task, args):
yield
with patch.object(worker, "_pipeline", side_effect=mock_pipeline), \
patch.object(worker, "_execute_with_heartbeat", return_value=False):
with (
patch.object(worker, "_pipeline", side_effect=mock_pipeline),
patch.object(worker, "_execute_with_heartbeat", return_value=False),
):
result = await worker._claim_and_execute_once()
assert result is True
@ -225,8 +240,10 @@ class TestPGWorker:
cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
stop_event = asyncio.Event()
with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
with (
patch("dataloader.workers.base.APP_CTX") as mock_ctx,
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
):
mock_session = AsyncMock()
mock_sm = MagicMock()
@ -236,12 +253,14 @@ class TestPGWorker:
mock_ctx.sessionmaker = mock_sm
mock_repo = Mock()
mock_repo.claim_one = AsyncMock(return_value={
"job_id": "test-job-id",
"lease_ttl_sec": 60,
"task": "test.task",
"args": {}
})
mock_repo.claim_one = AsyncMock(
return_value={
"job_id": "test-job-id",
"lease_ttl_sec": 60,
"task": "test.task",
"args": {},
}
)
mock_repo.finish_fail_or_retry = AsyncMock()
mock_repo_cls.return_value = mock_repo
@ -263,8 +282,10 @@ class TestPGWorker:
cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
stop_event = asyncio.Event()
with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
with (
patch("dataloader.workers.base.APP_CTX") as mock_ctx,
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
):
mock_session = AsyncMock()
mock_sm = MagicMock()
@ -274,18 +295,22 @@ class TestPGWorker:
mock_ctx.sessionmaker = mock_sm
mock_repo = Mock()
mock_repo.claim_one = AsyncMock(return_value={
"job_id": "test-job-id",
"lease_ttl_sec": 60,
"task": "test.task",
"args": {}
})
mock_repo.claim_one = AsyncMock(
return_value={
"job_id": "test-job-id",
"lease_ttl_sec": 60,
"task": "test.task",
"args": {},
}
)
mock_repo.finish_fail_or_retry = AsyncMock()
mock_repo_cls.return_value = mock_repo
worker = PGWorker(cfg, stop_event)
with patch.object(worker, "_execute_with_heartbeat", side_effect=ValueError("Test error")):
with patch.object(
worker, "_execute_with_heartbeat", side_effect=ValueError("Test error")
):
result = await worker._claim_and_execute_once()
assert result is True
@ -301,8 +326,10 @@ class TestPGWorker:
cfg = WorkerConfig(queue="test", heartbeat_sec=1, claim_backoff_sec=5)
stop_event = asyncio.Event()
with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
with (
patch("dataloader.workers.base.APP_CTX") as mock_ctx,
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
):
mock_session = AsyncMock()
mock_sm = MagicMock()
@ -323,7 +350,9 @@ class TestPGWorker:
await asyncio.sleep(0.6)
yield
canceled = await worker._execute_with_heartbeat("job-id", 60, slow_pipeline())
canceled = await worker._execute_with_heartbeat(
"job-id", 60, slow_pipeline()
)
assert canceled is False
assert mock_repo.heartbeat.call_count >= 1
@ -336,8 +365,10 @@ class TestPGWorker:
cfg = WorkerConfig(queue="test", heartbeat_sec=1, claim_backoff_sec=5)
stop_event = asyncio.Event()
with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
with (
patch("dataloader.workers.base.APP_CTX") as mock_ctx,
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
):
mock_session = AsyncMock()
mock_sm = MagicMock()
@ -358,7 +389,9 @@ class TestPGWorker:
await asyncio.sleep(0.6)
yield
canceled = await worker._execute_with_heartbeat("job-id", 60, slow_pipeline())
canceled = await worker._execute_with_heartbeat(
"job-id", 60, slow_pipeline()
)
assert canceled is True
@ -370,8 +403,10 @@ class TestPGWorker:
cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
stop_event = asyncio.Event()
with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
patch("dataloader.workers.base.resolve_pipeline") as mock_resolve:
with (
patch("dataloader.workers.base.APP_CTX") as mock_ctx,
patch("dataloader.workers.base.resolve_pipeline") as mock_resolve,
):
mock_ctx.get_logger.return_value = Mock()
mock_ctx.sessionmaker = Mock()
@ -397,8 +432,10 @@ class TestPGWorker:
cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
stop_event = asyncio.Event()
with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
patch("dataloader.workers.base.resolve_pipeline") as mock_resolve:
with (
patch("dataloader.workers.base.APP_CTX") as mock_ctx,
patch("dataloader.workers.base.resolve_pipeline") as mock_resolve,
):
mock_ctx.get_logger.return_value = Mock()
mock_ctx.sessionmaker = Mock()
@ -424,8 +461,10 @@ class TestPGWorker:
cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
stop_event = asyncio.Event()
with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
patch("dataloader.workers.base.resolve_pipeline") as mock_resolve:
with (
patch("dataloader.workers.base.APP_CTX") as mock_ctx,
patch("dataloader.workers.base.resolve_pipeline") as mock_resolve,
):
mock_ctx.get_logger.return_value = Mock()
mock_ctx.sessionmaker = Mock()
@ -450,8 +489,10 @@ class TestPGWorker:
cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
stop_event = asyncio.Event()
with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
with (
patch("dataloader.workers.base.APP_CTX") as mock_ctx,
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
):
mock_session = AsyncMock()
mock_sm = MagicMock()
@ -461,12 +502,14 @@ class TestPGWorker:
mock_ctx.sessionmaker = mock_sm
mock_repo = Mock()
mock_repo.claim_one = AsyncMock(return_value={
"job_id": "test-job-id",
"lease_ttl_sec": 60,
"task": "test.task",
"args": {}
})
mock_repo.claim_one = AsyncMock(
return_value={
"job_id": "test-job-id",
"lease_ttl_sec": 60,
"task": "test.task",
"args": {},
}
)
mock_repo.finish_fail_or_retry = AsyncMock()
mock_repo_cls.return_value = mock_repo
@ -484,19 +527,16 @@ class TestPGWorker:
assert "cancelled by shutdown" in args[1]
assert kwargs.get("is_canceled") is True
@pytest.mark.asyncio
async def test_execute_with_heartbeat_raises_cancelled_when_stop_set(self):
cfg = WorkerConfig(queue="test", heartbeat_sec=1000, claim_backoff_sec=5)
stop_event = asyncio.Event()
stop_event.set()
with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
with (
patch("dataloader.workers.base.APP_CTX") as mock_ctx,
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
):
mock_ctx.get_logger.return_value = Mock()
mock_ctx.sessionmaker = Mock()
@ -509,4 +549,3 @@ class TestPGWorker:
with pytest.raises(asyncio.CancelledError):
await worker._execute_with_heartbeat("job-id", 60, one_yield())
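These worker tests drive pipelines written as async generators: each yield is a checkpoint at which the worker may heartbeat or observe cancellation. A simplified, hypothetical consumption loop in that spirit (the real method also persists heartbeats and checks the DB cancel flag):

import asyncio

async def drive(pipeline, stop_event: asyncio.Event) -> bool:
    # Each yield from the pipeline is a checkpoint; shutdown surfaces as CancelledError.
    async for _ in pipeline:
        if stop_event.is_set():
            raise asyncio.CancelledError
    return False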

View File

@ -1,8 +1,8 @@
# tests/unit/test_workers_manager.py
from __future__ import annotations
import asyncio
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from dataloader.workers.manager import WorkerManager, WorkerSpec, build_manager_from_env
@ -39,9 +39,11 @@ class TestWorkerManager:
WorkerSpec(queue="queue2", concurrency=1),
]
with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, \
patch("dataloader.workers.manager.PGWorker") as mock_worker_cls:
with (
patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
patch("dataloader.workers.manager.PGWorker") as mock_worker_cls,
):
mock_ctx.get_logger.return_value = Mock()
mock_cfg.worker.heartbeat_sec = 10
@ -67,9 +69,11 @@ class TestWorkerManager:
"""
specs = [WorkerSpec(queue="test", concurrency=0)]
with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, \
patch("dataloader.workers.manager.PGWorker") as mock_worker_cls:
with (
patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
patch("dataloader.workers.manager.PGWorker") as mock_worker_cls,
):
mock_ctx.get_logger.return_value = Mock()
mock_cfg.worker.heartbeat_sec = 10
@ -93,9 +97,11 @@ class TestWorkerManager:
"""
specs = [WorkerSpec(queue="test", concurrency=2)]
with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, \
patch("dataloader.workers.manager.PGWorker") as mock_worker_cls:
with (
patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
patch("dataloader.workers.manager.PGWorker") as mock_worker_cls,
):
mock_ctx.get_logger.return_value = Mock()
mock_cfg.worker.heartbeat_sec = 10
@ -108,8 +114,6 @@ class TestWorkerManager:
manager = WorkerManager(specs)
await manager.start()
initial_task_count = len(manager._tasks)
await manager.stop()
assert manager._stop.is_set()
@ -123,10 +127,12 @@ class TestWorkerManager:
"""
specs = [WorkerSpec(queue="test", concurrency=1)]
with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, \
patch("dataloader.workers.manager.PGWorker") as mock_worker_cls, \
patch("dataloader.workers.manager.requeue_lost") as mock_requeue:
with (
patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
patch("dataloader.workers.manager.PGWorker") as mock_worker_cls,
patch("dataloader.workers.manager.requeue_lost") as mock_requeue,
):
mock_logger = Mock()
mock_ctx.get_logger.return_value = mock_logger
@ -161,10 +167,12 @@ class TestWorkerManager:
"""
specs = [WorkerSpec(queue="test", concurrency=1)]
with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, \
patch("dataloader.workers.manager.PGWorker") as mock_worker_cls, \
patch("dataloader.workers.manager.requeue_lost") as mock_requeue:
with (
patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
patch("dataloader.workers.manager.PGWorker") as mock_worker_cls,
patch("dataloader.workers.manager.requeue_lost") as mock_requeue,
):
mock_logger = Mock()
mock_ctx.get_logger.return_value = mock_logger
@ -203,8 +211,10 @@ class TestBuildManagerFromEnv:
"""
Test building the manager from configuration.
"""
with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg:
with (
patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
):
mock_ctx.get_logger.return_value = Mock()
mock_cfg.worker.parsed_workers.return_value = [
@ -224,8 +234,10 @@ class TestBuildManagerFromEnv:
"""
Test that empty queue names are skipped.
"""
with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg:
with (
patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
):
mock_ctx.get_logger.return_value = Mock()
mock_cfg.worker.parsed_workers.return_value = [
@ -243,8 +255,10 @@ class TestBuildManagerFromEnv:
"""
Test handling of missing fields via default values.
"""
with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg:
with (
patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
):
mock_ctx.get_logger.return_value = Mock()
mock_cfg.worker.parsed_workers.return_value = [
@ -262,8 +276,10 @@ class TestBuildManagerFromEnv:
"""
Test that concurrency is always at least 1.
"""
with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg:
with (
patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
):
mock_ctx.get_logger.return_value = Mock()
mock_cfg.worker.parsed_workers.return_value = [

View File

@ -1,8 +1,8 @@
# tests/unit/test_workers_reaper.py
from __future__ import annotations
from unittest.mock import AsyncMock, Mock, patch
import pytest
from unittest.mock import AsyncMock, patch, Mock
from dataloader.workers.reaper import requeue_lost