refactor: refactor code
This commit is contained in:
parent
c907e1d4da
commit
bde0bb0e6f
|
|
@ -34,3 +34,37 @@ dataloader = "dataloader.__main__:main"
|
|||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
|
||||
[tool.black]
|
||||
line-length = 88
|
||||
target-version = ["py311"]
|
||||
skip-string-normalization = false
|
||||
preview = true
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
line_length = 88
|
||||
multi_line_output = 3
|
||||
include_trailing_comma = true
|
||||
force_grid_wrap = 0
|
||||
combine_as_imports = true
|
||||
known_first_party = ["dataloader"]
|
||||
src_paths = ["src"]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 88
|
||||
target-version = "py311"
|
||||
fix = true
|
||||
lint.select = ["E", "F", "W", "I", "N", "B"]
|
||||
lint.ignore = [
|
||||
"E501",
|
||||
]
|
||||
exclude = [
|
||||
".git",
|
||||
"__pycache__",
|
||||
"build",
|
||||
"dist",
|
||||
".venv",
|
||||
"venv",
|
||||
]
|
||||
|
|
@ -5,7 +5,6 @@
|
|||
|
||||
from . import api
|
||||
|
||||
|
||||
__all__ = [
|
||||
"api",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,15 +1,17 @@
|
|||
import uvicorn
|
||||
|
||||
|
||||
from dataloader.api import app_main
|
||||
from dataloader.config import APP_CONFIG
|
||||
from dataloader.logger.uvicorn_logging_config import LOGGING_CONFIG, setup_uvicorn_logging
|
||||
from dataloader.logger.uvicorn_logging_config import (
|
||||
LOGGING_CONFIG,
|
||||
setup_uvicorn_logging,
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
# Инициализируем логирование uvicorn перед запуском
|
||||
|
||||
setup_uvicorn_logging()
|
||||
|
||||
|
||||
uvicorn.run(
|
||||
app_main,
|
||||
host=APP_CONFIG.app.app_host,
|
||||
|
|
|
|||
|
|
@ -1,20 +1,19 @@
|
|||
# src/dataloader/api/__init__.py
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
import contextlib
|
||||
import typing as tp
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from dataloader.context import APP_CTX
|
||||
from dataloader.workers.manager import WorkerManager, build_manager_from_env
|
||||
from dataloader.workers.pipelines import load_all as load_pipelines
|
||||
|
||||
from .metric_router import router as metric_router
|
||||
from .middleware import log_requests
|
||||
from .os_router import router as service_router
|
||||
from .v1 import router as v1_router
|
||||
from dataloader.context import APP_CTX
|
||||
from dataloader.workers.manager import build_manager_from_env, WorkerManager
|
||||
from dataloader.workers.pipelines import load_all as load_pipelines
|
||||
|
||||
|
||||
_manager: WorkerManager | None = None
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
""" 🚨 НЕ РЕДАКТИРОВАТЬ !!!!!!
|
||||
"""
|
||||
"""🚨 НЕ РЕДАКТИРОВАТЬ !!!!!!"""
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Header, status
|
||||
|
||||
from dataloader.context import APP_CTX
|
||||
|
||||
from . import schemas
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -17,8 +18,7 @@ logger = APP_CTX.get_logger()
|
|||
response_model=schemas.RateResponse,
|
||||
)
|
||||
async def like(
|
||||
# pylint: disable=C0103,W0613
|
||||
header_Request_Id: str = Header(uuid.uuid4(), alias="Request-Id")
|
||||
header_request_id: str = Header(uuid.uuid4(), alias="Request-Id")
|
||||
) -> dict[str, str]:
|
||||
logger.metric(
|
||||
metric_name="dataloader_likes_total",
|
||||
|
|
@ -33,8 +33,7 @@ async def like(
|
|||
response_model=schemas.RateResponse,
|
||||
)
|
||||
async def dislike(
|
||||
# pylint: disable=C0103,W0613
|
||||
header_Request_Id: str = Header(uuid.uuid4(), alias="Request-Id")
|
||||
header_request_id: str = Header(uuid.uuid4(), alias="Request-Id")
|
||||
) -> dict[str, str]:
|
||||
logger.metric(
|
||||
metric_name="dataloader_dislikes_total",
|
||||
|
|
|
|||
|
|
@ -38,8 +38,14 @@ async def log_requests(request: Request, call_next) -> any:
|
|||
logger = APP_CTX.get_logger()
|
||||
request_path = request.url.path
|
||||
|
||||
allowed_headers_to_log = ((k, request.headers.get(k)) for k in HEADERS_WHITE_LIST_TO_LOG)
|
||||
headers_to_log = {header_name: header_value for header_name, header_value in allowed_headers_to_log if header_value}
|
||||
allowed_headers_to_log = (
|
||||
(k, request.headers.get(k)) for k in HEADERS_WHITE_LIST_TO_LOG
|
||||
)
|
||||
headers_to_log = {
|
||||
header_name: header_value
|
||||
for header_name, header_value in allowed_headers_to_log
|
||||
if header_value
|
||||
}
|
||||
|
||||
APP_CTX.get_context_vars_container().set_context_vars(
|
||||
request_id=headers_to_log.get("Request-Id", ""),
|
||||
|
|
@ -50,7 +56,9 @@ async def log_requests(request: Request, call_next) -> any:
|
|||
|
||||
if request_path in NON_LOGGED_ENDPOINTS:
|
||||
response = await call_next(request)
|
||||
logger.debug(f"Processed request for {request_path} with code {response.status_code}")
|
||||
logger.debug(
|
||||
f"Processed request for {request_path} with code {response.status_code}"
|
||||
)
|
||||
elif headers_to_log.get("Request-Id", None):
|
||||
raw_request_body = await request.body()
|
||||
request_body_decoded = _get_decoded_body(raw_request_body, "request", logger)
|
||||
|
|
@ -80,7 +88,7 @@ async def log_requests(request: Request, call_next) -> any:
|
|||
event_params=json.dumps(
|
||||
request_body_decoded,
|
||||
ensure_ascii=False,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
response = await call_next(request)
|
||||
|
|
@ -88,12 +96,16 @@ async def log_requests(request: Request, call_next) -> any:
|
|||
response_body = [chunk async for chunk in response.body_iterator]
|
||||
response.body_iterator = iterate_in_threadpool(iter(response_body))
|
||||
|
||||
headers_to_log["Response-Time"] = datetime.now(APP_CTX.get_pytz_timezone()).isoformat()
|
||||
headers_to_log["Response-Time"] = datetime.now(
|
||||
APP_CTX.get_pytz_timezone()
|
||||
).isoformat()
|
||||
for header in headers_to_log:
|
||||
response.headers[header] = headers_to_log[header]
|
||||
|
||||
response_body_extracted = response_body[0] if len(response_body) > 0 else b""
|
||||
decoded_response_body = _get_decoded_body(response_body_extracted, "response", logger)
|
||||
decoded_response_body = _get_decoded_body(
|
||||
response_body_extracted, "response", logger
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Outgoing response to client system",
|
||||
|
|
@ -115,11 +127,13 @@ async def log_requests(request: Request, call_next) -> any:
|
|||
event_params=json.dumps(
|
||||
decoded_response_body,
|
||||
ensure_ascii=False,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
processing_time_ms = int(round((time.time() - start_time), 3) * 1000)
|
||||
logger.info(f"Request processing time for {request_path}: {processing_time_ms} ms")
|
||||
logger.info(
|
||||
f"Request processing time for {request_path}: {processing_time_ms} ms"
|
||||
)
|
||||
logger.metric(
|
||||
metric_name="dataloader_process_duration_ms",
|
||||
metric_value=processing_time_ms,
|
||||
|
|
@ -138,7 +152,9 @@ async def log_requests(request: Request, call_next) -> any:
|
|||
else:
|
||||
logger.info(f"Incoming {request.method}-request with no id for {request_path}")
|
||||
response = await call_next(request)
|
||||
logger.info(f"Request with no id for {request_path} processing time: {time.time() - start_time:.3f} s")
|
||||
logger.info(
|
||||
f"Request with no id for {request_path} processing time: {time.time() - start_time:.3f} s"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,4 @@
|
|||
# Инфраструктурные endpoint'ы (/health, /status)
|
||||
""" 🚨 НЕ РЕДАКТИРОВАТЬ !!!!!!
|
||||
"""
|
||||
"""🚨 НЕ РЕДАКТИРОВАТЬ !!!!!!"""
|
||||
|
||||
from importlib.metadata import distribution
|
||||
|
||||
|
|
|
|||
|
|
@ -3,22 +3,26 @@ from pydantic import BaseModel, ConfigDict, Field
|
|||
|
||||
class HealthResponse(BaseModel):
|
||||
"""Ответ для ручки /health"""
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={"example": {"status": "running"}}
|
||||
)
|
||||
|
||||
status: str = Field(default="running", description="Service health check", max_length=7)
|
||||
model_config = ConfigDict(json_schema_extra={"example": {"status": "running"}})
|
||||
|
||||
status: str = Field(
|
||||
default="running", description="Service health check", max_length=7
|
||||
)
|
||||
|
||||
|
||||
class InfoResponse(BaseModel):
|
||||
"""Ответ для ручки /info"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"name": "rest-template",
|
||||
"description": "Python 'AI gateway' template for developing REST microservices",
|
||||
"description": (
|
||||
"Python 'AI gateway' template for developing REST microservices"
|
||||
),
|
||||
"type": "REST API",
|
||||
"version": "0.1.0"
|
||||
"version": "0.1.0",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
|
@ -26,11 +30,14 @@ class InfoResponse(BaseModel):
|
|||
name: str = Field(description="Service name", max_length=50)
|
||||
description: str = Field(description="Service description", max_length=200)
|
||||
type: str = Field(default="REST API", description="Service type", max_length=20)
|
||||
version: str = Field(description="Service version", max_length=20, pattern=r"^\d+\.\d+\.\d+")
|
||||
version: str = Field(
|
||||
description="Service version", max_length=20, pattern=r"^\d+\.\d+\.\d+"
|
||||
)
|
||||
|
||||
|
||||
class RateResponse(BaseModel):
|
||||
"""Ответ для записи рейтинга"""
|
||||
|
||||
model_config = ConfigDict(str_strip_whitespace=True)
|
||||
|
||||
rating_result: str = Field(description="Rating that was recorded", max_length=50)
|
||||
|
|
|
|||
|
|
@ -5,12 +5,8 @@ from fastapi import HTTPException, status
|
|||
|
||||
class JobNotFoundError(HTTPException):
|
||||
"""Задача не найдена."""
|
||||
|
||||
|
||||
def __init__(self, job_id: str):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Job {job_id} not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=f"Job {job_id} not found"
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,2 +1 @@
|
|||
"""Модели данных для API v1."""
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,11 @@
|
|||
# src/dataloader/api/v1/router.py
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from http import HTTPStatus
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from dataloader.api.v1.exceptions import JobNotFoundError
|
||||
from dataloader.api.v1.schemas import (
|
||||
|
|
@ -16,8 +15,6 @@ from dataloader.api.v1.schemas import (
|
|||
)
|
||||
from dataloader.api.v1.service import JobsService
|
||||
from dataloader.context import get_session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
router = APIRouter(prefix="/jobs", tags=["jobs"])
|
||||
|
||||
|
|
@ -40,7 +37,9 @@ async def trigger_job(
|
|||
return await svc.trigger(payload)
|
||||
|
||||
|
||||
@router.get("/{job_id}/status", response_model=JobStatusResponse, status_code=HTTPStatus.OK)
|
||||
@router.get(
|
||||
"/{job_id}/status", response_model=JobStatusResponse, status_code=HTTPStatus.OK
|
||||
)
|
||||
async def get_status(
|
||||
job_id: UUID,
|
||||
svc: Annotated[JobsService, Depends(get_service)],
|
||||
|
|
@ -54,7 +53,9 @@ async def get_status(
|
|||
return st
|
||||
|
||||
|
||||
@router.post("/{job_id}/cancel", response_model=JobStatusResponse, status_code=HTTPStatus.OK)
|
||||
@router.post(
|
||||
"/{job_id}/cancel", response_model=JobStatusResponse, status_code=HTTPStatus.OK
|
||||
)
|
||||
async def cancel_job(
|
||||
job_id: UUID,
|
||||
svc: Annotated[JobsService, Depends(get_service)],
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
# src/dataloader/api/v1/schemas.py
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
|
@ -12,6 +11,7 @@ class TriggerJobRequest(BaseModel):
|
|||
"""
|
||||
Запрос на постановку задачи в очередь.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(str_strip_whitespace=True)
|
||||
|
||||
queue: str = Field(...)
|
||||
|
|
@ -39,6 +39,7 @@ class TriggerJobResponse(BaseModel):
|
|||
"""
|
||||
Ответ на постановку задачи.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(str_strip_whitespace=True)
|
||||
|
||||
job_id: UUID = Field(...)
|
||||
|
|
@ -49,6 +50,7 @@ class JobStatusResponse(BaseModel):
|
|||
"""
|
||||
Текущий статус задачи.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(str_strip_whitespace=True)
|
||||
|
||||
job_id: UUID = Field(...)
|
||||
|
|
@ -65,6 +67,7 @@ class CancelJobResponse(BaseModel):
|
|||
"""
|
||||
Ответ на запрос отмены задачи.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(str_strip_whitespace=True)
|
||||
|
||||
job_id: UUID = Field(...)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
# src/dataloader/api/v1/service.py
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
|
@ -13,15 +12,16 @@ from dataloader.api.v1.schemas import (
|
|||
TriggerJobResponse,
|
||||
)
|
||||
from dataloader.api.v1.utils import new_job_id
|
||||
from dataloader.storage.schemas import CreateJobRequest
|
||||
from dataloader.storage.repositories import QueueRepository
|
||||
from dataloader.logger.logger import get_logger
|
||||
from dataloader.storage.repositories import QueueRepository
|
||||
from dataloader.storage.schemas import CreateJobRequest
|
||||
|
||||
|
||||
class JobsService:
|
||||
"""
|
||||
Бизнес-логика работы с очередью задач.
|
||||
"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self._s = session
|
||||
self._repo = QueueRepository(self._s)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
# src/dataloader/api/v1/utils.py
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID, uuid4
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
# src/dataloader/config.py
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
from logging import DEBUG, INFO
|
||||
from typing import Annotated, Any
|
||||
|
||||
|
|
@ -32,6 +31,7 @@ class BaseAppSettings(BaseSettings):
|
|||
"""
|
||||
Базовый класс для настроек.
|
||||
"""
|
||||
|
||||
local: bool = Field(validation_alias="LOCAL", default=False)
|
||||
debug: bool = Field(validation_alias="DEBUG", default=False)
|
||||
|
||||
|
|
@ -44,6 +44,7 @@ class AppSettings(BaseAppSettings):
|
|||
"""
|
||||
Настройки приложения.
|
||||
"""
|
||||
|
||||
app_host: str = Field(validation_alias="APP_HOST", default="0.0.0.0")
|
||||
app_port: int = Field(validation_alias="APP_PORT", default=8081)
|
||||
kube_net_name: str = Field(validation_alias="PROJECT_NAME", default="AIGATEWAY")
|
||||
|
|
@ -54,15 +55,28 @@ class LogSettings(BaseAppSettings):
|
|||
"""
|
||||
Настройки логирования.
|
||||
"""
|
||||
|
||||
private_log_file_path: str = Field(validation_alias="LOG_PATH", default=os.getcwd())
|
||||
private_log_file_name: str = Field(validation_alias="LOG_FILE_NAME", default="app.log")
|
||||
private_log_file_name: str = Field(
|
||||
validation_alias="LOG_FILE_NAME", default="app.log"
|
||||
)
|
||||
log_rotation: str = Field(validation_alias="LOG_ROTATION", default="10 MB")
|
||||
private_metric_file_path: str = Field(validation_alias="METRIC_PATH", default=os.getcwd())
|
||||
private_metric_file_name: str = Field(validation_alias="METRIC_FILE_NAME", default="app-metric.log")
|
||||
private_audit_file_path: str = Field(validation_alias="AUDIT_LOG_PATH", default=os.getcwd())
|
||||
private_audit_file_name: str = Field(validation_alias="AUDIT_LOG_FILE_NAME", default="events.log")
|
||||
private_metric_file_path: str = Field(
|
||||
validation_alias="METRIC_PATH", default=os.getcwd()
|
||||
)
|
||||
private_metric_file_name: str = Field(
|
||||
validation_alias="METRIC_FILE_NAME", default="app-metric.log"
|
||||
)
|
||||
private_audit_file_path: str = Field(
|
||||
validation_alias="AUDIT_LOG_PATH", default=os.getcwd()
|
||||
)
|
||||
private_audit_file_name: str = Field(
|
||||
validation_alias="AUDIT_LOG_FILE_NAME", default="events.log"
|
||||
)
|
||||
audit_host_ip: str = Field(validation_alias="HOST_IP", default="127.0.0.1")
|
||||
audit_host_uid: str = Field(validation_alias="HOST_UID", default="63b6dcee-170b-49bf-a65c-3ec967398ccd")
|
||||
audit_host_uid: str = Field(
|
||||
validation_alias="HOST_UID", default="63b6dcee-170b-49bf-a65c-3ec967398ccd"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_file_abs_path(path_name: str, file_name: str) -> str:
|
||||
|
|
@ -70,15 +84,21 @@ class LogSettings(BaseAppSettings):
|
|||
|
||||
@property
|
||||
def log_file_abs_path(self) -> str:
|
||||
return self.get_file_abs_path(self.private_log_file_path, self.private_log_file_name)
|
||||
return self.get_file_abs_path(
|
||||
self.private_log_file_path, self.private_log_file_name
|
||||
)
|
||||
|
||||
@property
|
||||
def metric_file_abs_path(self) -> str:
|
||||
return self.get_file_abs_path(self.private_metric_file_path, self.private_metric_file_name)
|
||||
return self.get_file_abs_path(
|
||||
self.private_metric_file_path, self.private_metric_file_name
|
||||
)
|
||||
|
||||
@property
|
||||
def audit_file_abs_path(self) -> str:
|
||||
return self.get_file_abs_path(self.private_audit_file_path, self.private_audit_file_name)
|
||||
return self.get_file_abs_path(
|
||||
self.private_audit_file_path, self.private_audit_file_name
|
||||
)
|
||||
|
||||
@property
|
||||
def log_lvl(self) -> int:
|
||||
|
|
@ -89,6 +109,7 @@ class PGSettings(BaseSettings):
|
|||
"""
|
||||
Настройки подключения к Postgres.
|
||||
"""
|
||||
|
||||
host: str = Field(validation_alias="PG_HOST", default="localhost")
|
||||
port: int = Field(validation_alias="PG_PORT", default=5432)
|
||||
user: str = Field(validation_alias="PG_USER", default="postgres")
|
||||
|
|
@ -116,9 +137,12 @@ class WorkerSettings(BaseSettings):
|
|||
"""
|
||||
Настройки очереди и воркеров.
|
||||
"""
|
||||
|
||||
workers_json: str = Field(validation_alias="WORKERS_JSON", default="[]")
|
||||
heartbeat_sec: int = Field(validation_alias="DL_HEARTBEAT_SEC", default=10)
|
||||
default_lease_ttl_sec: int = Field(validation_alias="DL_DEFAULT_LEASE_TTL_SEC", default=60)
|
||||
default_lease_ttl_sec: int = Field(
|
||||
validation_alias="DL_DEFAULT_LEASE_TTL_SEC", default=60
|
||||
)
|
||||
reaper_period_sec: int = Field(validation_alias="DL_REAPER_PERIOD_SEC", default=10)
|
||||
claim_backoff_sec: int = Field(validation_alias="DL_CLAIM_BACKOFF_SEC", default=15)
|
||||
|
||||
|
|
@ -137,6 +161,7 @@ class CertsSettings(BaseSettings):
|
|||
"""
|
||||
Настройки SSL сертификатов для локальной разработки.
|
||||
"""
|
||||
|
||||
ca_bundle_file: str = Field(validation_alias="CA_BUNDLE_FILE", default="")
|
||||
cert_file: str = Field(validation_alias="CERT_FILE", default="")
|
||||
key_file: str = Field(validation_alias="KEY_FILE", default="")
|
||||
|
|
@ -146,21 +171,24 @@ class SuperTeneraSettings(BaseAppSettings):
|
|||
"""
|
||||
Настройки интеграции с SuperTenera.
|
||||
"""
|
||||
|
||||
host: Annotated[str, BeforeValidator(strip_slashes)] = Field(
|
||||
validation_alias="SUPERTENERA_HOST",
|
||||
default="ci03801737-ift-tenera-giga.delta.sbrf.ru/atlant360bc/"
|
||||
default="ci03801737-ift-tenera-giga.delta.sbrf.ru/atlant360bc/",
|
||||
)
|
||||
port: str = Field(validation_alias="SUPERTENERA_PORT", default="443")
|
||||
quotes_endpoint: Annotated[str, BeforeValidator(strip_slashes)] = Field(
|
||||
validation_alias="SUPERTENERA_QUOTES_ENDPOINT",
|
||||
default="/get_gigaparser_quotes/"
|
||||
default="/get_gigaparser_quotes/",
|
||||
)
|
||||
timeout: int = Field(validation_alias="SUPERTENERA_TIMEOUT", default=20)
|
||||
|
||||
@property
|
||||
def base_url(self) -> str:
|
||||
"""Возвращает абсолютный URL для SuperTenera."""
|
||||
domain, raw_path = self.host.split("/", 1) if "/" in self.host else (self.host, "")
|
||||
domain, raw_path = (
|
||||
self.host.split("/", 1) if "/" in self.host else (self.host, "")
|
||||
)
|
||||
return build_url(self.protocol, domain, self.port, raw_path)
|
||||
|
||||
|
||||
|
|
@ -168,22 +196,21 @@ class Gmap2BriefSettings(BaseAppSettings):
|
|||
"""
|
||||
Настройки интеграции с Gmap2Brief (OPU API).
|
||||
"""
|
||||
|
||||
host: Annotated[str, BeforeValidator(strip_slashes)] = Field(
|
||||
validation_alias="GMAP2BRIEF_HOST",
|
||||
default="ci02533826-tib-brief.apps.ift-terra000024-edm.ocp.delta.sbrf.ru"
|
||||
default="ci02533826-tib-brief.apps.ift-terra000024-edm.ocp.delta.sbrf.ru",
|
||||
)
|
||||
port: str = Field(validation_alias="GMAP2BRIEF_PORT", default="443")
|
||||
start_endpoint: Annotated[str, BeforeValidator(strip_slashes)] = Field(
|
||||
validation_alias="GMAP2BRIEF_START_ENDPOINT",
|
||||
default="/export/opu/start"
|
||||
validation_alias="GMAP2BRIEF_START_ENDPOINT", default="/export/opu/start"
|
||||
)
|
||||
status_endpoint: Annotated[str, BeforeValidator(strip_slashes)] = Field(
|
||||
validation_alias="GMAP2BRIEF_STATUS_ENDPOINT",
|
||||
default="/export/{job_id}/status"
|
||||
validation_alias="GMAP2BRIEF_STATUS_ENDPOINT", default="/export/{job_id}/status"
|
||||
)
|
||||
download_endpoint: Annotated[str, BeforeValidator(strip_slashes)] = Field(
|
||||
validation_alias="GMAP2BRIEF_DOWNLOAD_ENDPOINT",
|
||||
default="/export/{job_id}/download"
|
||||
default="/export/{job_id}/download",
|
||||
)
|
||||
poll_interval: int = Field(validation_alias="GMAP2BRIEF_POLL_INTERVAL", default=2)
|
||||
timeout: int = Field(validation_alias="GMAP2BRIEF_TIMEOUT", default=3600)
|
||||
|
|
@ -198,6 +225,7 @@ class Secrets:
|
|||
"""
|
||||
Агрегатор настроек приложения.
|
||||
"""
|
||||
|
||||
app: AppSettings = AppSettings()
|
||||
log: LogSettings = LogSettings()
|
||||
pg: PGSettings = PGSettings()
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
# src/dataloader/context.py
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from logging import Logger
|
||||
from typing import AsyncGenerator
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
|
||||
|
|
@ -15,6 +14,7 @@ class AppContext:
|
|||
"""
|
||||
Контекст приложения, хранящий глобальные зависимости (Singleton pattern).
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._engine: AsyncEngine | None = None
|
||||
self._sessionmaker: async_sessionmaker[AsyncSession] | None = None
|
||||
|
|
@ -75,6 +75,7 @@ class AppContext:
|
|||
Экземпляр Logger
|
||||
"""
|
||||
from .logger.logger import get_logger as get_app_logger
|
||||
|
||||
return get_app_logger(name)
|
||||
|
||||
def get_context_vars_container(self) -> ContextVarsContainer:
|
||||
|
|
|
|||
|
|
@ -1,2 +1 @@
|
|||
"""Исключения уровня приложения."""
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,10 @@ from typing import TYPE_CHECKING
|
|||
import httpx
|
||||
|
||||
from dataloader.config import APP_CONFIG
|
||||
from dataloader.interfaces.gmap2_brief.schemas import ExportJobStatus, StartExportResponse
|
||||
from dataloader.interfaces.gmap2_brief.schemas import (
|
||||
ExportJobStatus,
|
||||
StartExportResponse,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
|
@ -37,10 +40,11 @@ class Gmap2BriefInterface:
|
|||
|
||||
self._ssl_context = None
|
||||
if APP_CONFIG.app.local and APP_CONFIG.certs.cert_file:
|
||||
self._ssl_context = ssl.create_default_context(cafile=APP_CONFIG.certs.ca_bundle_file)
|
||||
self._ssl_context = ssl.create_default_context(
|
||||
cafile=APP_CONFIG.certs.ca_bundle_file
|
||||
)
|
||||
self._ssl_context.load_cert_chain(
|
||||
certfile=APP_CONFIG.certs.cert_file,
|
||||
keyfile=APP_CONFIG.certs.key_file
|
||||
certfile=APP_CONFIG.certs.cert_file, keyfile=APP_CONFIG.certs.key_file
|
||||
)
|
||||
|
||||
async def start_export(self) -> str:
|
||||
|
|
@ -56,9 +60,13 @@ class Gmap2BriefInterface:
|
|||
url = self.base_url + APP_CONFIG.gmap2brief.start_endpoint
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
cert=(APP_CONFIG.certs.cert_file, APP_CONFIG.certs.key_file) if APP_CONFIG.app.local else None,
|
||||
cert=(
|
||||
(APP_CONFIG.certs.cert_file, APP_CONFIG.certs.key_file)
|
||||
if APP_CONFIG.app.local
|
||||
else None
|
||||
),
|
||||
verify=APP_CONFIG.certs.ca_bundle_file if APP_CONFIG.app.local else True,
|
||||
timeout=self.timeout
|
||||
timeout=self.timeout,
|
||||
) as client:
|
||||
try:
|
||||
self.logger.info(f"Starting OPU export: POST {url}")
|
||||
|
|
@ -87,12 +95,18 @@ class Gmap2BriefInterface:
|
|||
Raises:
|
||||
Gmap2BriefConnectionError: При ошибке запроса
|
||||
"""
|
||||
url = self.base_url + APP_CONFIG.gmap2brief.status_endpoint.format(job_id=job_id)
|
||||
url = self.base_url + APP_CONFIG.gmap2brief.status_endpoint.format(
|
||||
job_id=job_id
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
cert=(APP_CONFIG.certs.cert_file, APP_CONFIG.certs.key_file) if APP_CONFIG.app.local else None,
|
||||
cert=(
|
||||
(APP_CONFIG.certs.cert_file, APP_CONFIG.certs.key_file)
|
||||
if APP_CONFIG.app.local
|
||||
else None
|
||||
),
|
||||
verify=APP_CONFIG.certs.ca_bundle_file if APP_CONFIG.app.local else True,
|
||||
timeout=self.timeout
|
||||
timeout=self.timeout,
|
||||
) as client:
|
||||
try:
|
||||
response = await client.get(url)
|
||||
|
|
@ -106,7 +120,9 @@ class Gmap2BriefInterface:
|
|||
except httpx.RequestError as e:
|
||||
raise Gmap2BriefConnectionError(f"Request error: {e}") from e
|
||||
|
||||
async def wait_for_completion(self, job_id: str, max_wait: int | None = None) -> ExportJobStatus:
|
||||
async def wait_for_completion(
|
||||
self, job_id: str, max_wait: int | None = None
|
||||
) -> ExportJobStatus:
|
||||
"""
|
||||
Ждет завершения задачи экспорта с периодическим polling.
|
||||
|
||||
|
|
@ -127,14 +143,18 @@ class Gmap2BriefInterface:
|
|||
status = await self.get_status(job_id)
|
||||
|
||||
if status.status == "completed":
|
||||
self.logger.info(f"Job {job_id} completed, total_rows={status.total_rows}")
|
||||
self.logger.info(
|
||||
f"Job {job_id} completed, total_rows={status.total_rows}"
|
||||
)
|
||||
return status
|
||||
elif status.status == "failed":
|
||||
raise Gmap2BriefConnectionError(f"Job {job_id} failed: {status.error}")
|
||||
|
||||
elapsed = asyncio.get_event_loop().time() - start_time
|
||||
if max_wait and elapsed > max_wait:
|
||||
raise Gmap2BriefConnectionError(f"Job {job_id} timeout after {elapsed:.1f}s")
|
||||
raise Gmap2BriefConnectionError(
|
||||
f"Job {job_id} timeout after {elapsed:.1f}s"
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
f"Job {job_id} status={status.status}, rows={status.total_rows}, elapsed={elapsed:.1f}s"
|
||||
|
|
@ -155,12 +175,18 @@ class Gmap2BriefInterface:
|
|||
Raises:
|
||||
Gmap2BriefConnectionError: При ошибке скачивания
|
||||
"""
|
||||
url = self.base_url + APP_CONFIG.gmap2brief.download_endpoint.format(job_id=job_id)
|
||||
url = self.base_url + APP_CONFIG.gmap2brief.download_endpoint.format(
|
||||
job_id=job_id
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
cert=(APP_CONFIG.certs.cert_file, APP_CONFIG.certs.key_file) if APP_CONFIG.app.local else None,
|
||||
cert=(
|
||||
(APP_CONFIG.certs.cert_file, APP_CONFIG.certs.key_file)
|
||||
if APP_CONFIG.app.local
|
||||
else None
|
||||
),
|
||||
verify=APP_CONFIG.certs.ca_bundle_file if APP_CONFIG.app.local else True,
|
||||
timeout=self.timeout
|
||||
timeout=self.timeout,
|
||||
) as client:
|
||||
try:
|
||||
self.logger.info(f"Downloading export: GET {url}")
|
||||
|
|
|
|||
|
|
@ -17,7 +17,11 @@ class ExportJobStatus(BaseModel):
|
|||
"""Статус задачи экспорта."""
|
||||
|
||||
job_id: str = Field(..., description="Идентификатор задачи")
|
||||
status: Literal["pending", "running", "completed", "failed"] = Field(..., description="Статус задачи")
|
||||
status: Literal["pending", "running", "completed", "failed"] = Field(
|
||||
..., description="Статус задачи"
|
||||
)
|
||||
total_rows: int = Field(default=0, description="Количество обработанных строк")
|
||||
error: str | None = Field(default=None, description="Текст ошибки (если есть)")
|
||||
temp_file_path: str | None = Field(default=None, description="Путь к временному файлу (для completed)")
|
||||
temp_file_path: str | None = Field(
|
||||
default=None, description="Путь к временному файлу (для completed)"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -47,8 +47,12 @@ class SuperTeneraInterface:
|
|||
|
||||
self._ssl_context = None
|
||||
if APP_CONFIG.app.local:
|
||||
self._ssl_context = ssl.create_default_context(cafile=APP_CONFIG.certs.ca_bundle_file)
|
||||
self._ssl_context.load_cert_chain(certfile=APP_CONFIG.certs.cert_file, keyfile=APP_CONFIG.certs.key_file)
|
||||
self._ssl_context = ssl.create_default_context(
|
||||
cafile=APP_CONFIG.certs.ca_bundle_file
|
||||
)
|
||||
self._ssl_context.load_cert_chain(
|
||||
certfile=APP_CONFIG.certs.cert_file, keyfile=APP_CONFIG.certs.key_file
|
||||
)
|
||||
|
||||
def form_base_headers(self) -> dict:
|
||||
"""Формирует базовые заголовки для запроса."""
|
||||
|
|
@ -57,7 +61,11 @@ class SuperTeneraInterface:
|
|||
"request-time": str(datetime.now(tz=self.timezone).isoformat()),
|
||||
"system-id": APP_CONFIG.app.kube_net_name,
|
||||
}
|
||||
return {metakey: metavalue for metakey, metavalue in metadata_pairs.items() if metavalue}
|
||||
return {
|
||||
metakey: metavalue
|
||||
for metakey, metavalue in metadata_pairs.items()
|
||||
if metavalue
|
||||
}
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
"""Async context manager enter."""
|
||||
|
|
@ -86,7 +94,9 @@ class SuperTeneraInterface:
|
|||
|
||||
async with self._session.get(url, **kwargs) as response:
|
||||
if APP_CONFIG.app.debug:
|
||||
self.logger.debug(f"Response: {(await response.text(errors='ignore'))[:100]}")
|
||||
self.logger.debug(
|
||||
f"Response: {(await response.text(errors='ignore'))[:100]}"
|
||||
)
|
||||
return await response.json(encoding=encoding, content_type=content_type)
|
||||
|
||||
async def get_quotes_data(self) -> MainData:
|
||||
|
|
@ -103,7 +113,9 @@ class SuperTeneraInterface:
|
|||
kwargs["ssl"] = self._ssl_context
|
||||
try:
|
||||
async with self._session.get(
|
||||
APP_CONFIG.supertenera.quotes_endpoint, timeout=APP_CONFIG.supertenera.timeout, **kwargs
|
||||
APP_CONFIG.supertenera.quotes_endpoint,
|
||||
timeout=APP_CONFIG.supertenera.timeout,
|
||||
**kwargs,
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
return True
|
||||
|
|
@ -112,7 +124,9 @@ class SuperTeneraInterface:
|
|||
f"Ошибка подключения к SuperTenera API при проверке системы - {e.status}."
|
||||
) from e
|
||||
except TimeoutError as e:
|
||||
raise SuperTeneraConnectionError("Ошибка Timeout подключения к SuperTenera API при проверке системы.") from e
|
||||
raise SuperTeneraConnectionError(
|
||||
"Ошибка Timeout подключения к SuperTenera API при проверке системы."
|
||||
) from e
|
||||
|
||||
|
||||
def get_async_tenera_interface() -> SuperTeneraInterface:
|
||||
|
|
|
|||
|
|
@ -77,7 +77,9 @@ class InvestingCandlestick(TeneraBaseModel):
|
|||
value: str = Field(alias="V")
|
||||
|
||||
|
||||
class InvestingTimePoint(RootModel[EmptyTimePoint | InvestingNumeric | InvestingCandlestick]):
|
||||
class InvestingTimePoint(
|
||||
RootModel[EmptyTimePoint | InvestingNumeric | InvestingCandlestick]
|
||||
):
|
||||
"""
|
||||
Union-модель точки времени для источника Investing.com.
|
||||
|
||||
|
|
@ -340,5 +342,9 @@ class MainData(TeneraBaseModel):
|
|||
:return: Отфильтрованный объект
|
||||
"""
|
||||
if isinstance(v, dict):
|
||||
return {key: value for key, value in v.items() if value is not None and not str(key).isdigit()}
|
||||
return {
|
||||
key: value
|
||||
for key, value in v.items()
|
||||
if value is not None and not str(key).isdigit()
|
||||
}
|
||||
return v
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
# src/dataloader/logger/__init__.py
|
||||
from .logger import setup_logging, get_logger
|
||||
from .logger import get_logger, setup_logging
|
||||
|
||||
__all__ = ["setup_logging", "get_logger"]
|
||||
|
|
|
|||
|
|
@ -1,9 +1,7 @@
|
|||
# Управление контекстом запросов для логирования
|
||||
import uuid
|
||||
from contextvars import ContextVar
|
||||
from typing import Final
|
||||
|
||||
|
||||
REQUEST_ID_CTX_VAR: Final[ContextVar[str]] = ContextVar("request_id", default="")
|
||||
DEVICE_ID_CTX_VAR: Final[ContextVar[str]] = ContextVar("device_id", default="")
|
||||
SESSION_ID_CTX_VAR: Final[ContextVar[str]] = ContextVar("session_id", default="")
|
||||
|
|
@ -66,7 +64,13 @@ class ContextVarsContainer:
|
|||
def gw_session_id(self, value: str) -> None:
|
||||
GW_SESSION_ID_CTX_VAR.set(value)
|
||||
|
||||
def set_context_vars(self, request_id: str = "", request_time: str = "", system_id: str = "", gw_session_id: str = "") -> None:
|
||||
def set_context_vars(
|
||||
self,
|
||||
request_id: str = "",
|
||||
request_time: str = "",
|
||||
system_id: str = "",
|
||||
gw_session_id: str = "",
|
||||
) -> None:
|
||||
if request_id:
|
||||
self.set_request_id(request_id)
|
||||
if request_time:
|
||||
|
|
|
|||
|
|
@ -1,17 +1,12 @@
|
|||
# src/dataloader/logger/logger.py
|
||||
import sys
|
||||
import typing
|
||||
from datetime import tzinfo
|
||||
from logging import Logger
|
||||
import logging
|
||||
import sys
|
||||
from logging import Logger
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from .context_vars import ContextVarsContainer
|
||||
from dataloader.config import APP_CONFIG
|
||||
|
||||
|
||||
# Определяем фильтры для разных типов логов
|
||||
def metric_only_filter(record: dict) -> bool:
|
||||
return "metric" in record["extra"]
|
||||
|
||||
|
|
@ -26,13 +21,12 @@ def regular_log_filter(record: dict) -> bool:
|
|||
|
||||
class InterceptHandler(logging.Handler):
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
# Get corresponding Loguru level if it exists
|
||||
|
||||
try:
|
||||
level = logger.level(record.levelname).name
|
||||
except ValueError:
|
||||
level = record.levelno
|
||||
|
||||
# Find caller from where originated the logged message
|
||||
frame, depth = logging.currentframe(), 2
|
||||
while frame.f_code.co_filename == logging.__file__:
|
||||
frame = frame.f_back
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
# Модели логов, метрик, событий аудита
|
||||
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
# Функции + маскирование args
|
||||
|
||||
|
|
|
|||
|
|
@ -1,34 +1,28 @@
|
|||
# Конфигурация логирования uvicorn
|
||||
import logging
|
||||
import sys
|
||||
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class InterceptHandler(logging.Handler):
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
# Get corresponding Loguru level if it exists
|
||||
|
||||
try:
|
||||
level = logger.level(record.levelname).name
|
||||
except ValueError:
|
||||
level = record.levelno
|
||||
|
||||
|
||||
# Find caller from where originated the logged message
|
||||
frame, depth = logging.currentframe(), 2
|
||||
while frame.f_code.co_filename == logging.__file__:
|
||||
frame = frame.f_back
|
||||
depth += 1
|
||||
|
||||
|
||||
logger.opt(depth=depth, exception=record.exc_info).log(
|
||||
level, record.getMessage()
|
||||
)
|
||||
|
||||
|
||||
def setup_uvicorn_logging() -> None:
|
||||
# Set all uvicorn loggers to use InterceptHandler
|
||||
|
||||
for logger_name in ["uvicorn", "uvicorn.error", "uvicorn.access"]:
|
||||
log = logging.getLogger(logger_name)
|
||||
log.handlers = [InterceptHandler()]
|
||||
|
|
@ -36,7 +30,6 @@ def setup_uvicorn_logging() -> None:
|
|||
log.propagate = False
|
||||
|
||||
|
||||
# uvicorn logging config
|
||||
LOGGING_CONFIG = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
|
|
|
|||
|
|
@ -1,2 +1 @@
|
|||
"""Модуль для работы с хранилищем данных."""
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,11 @@
|
|||
# src/dataloader/storage/engine.py
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
|
||||
from dataloader.config import APP_CONFIG
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
# src/dataloader/storage/models/__init__.py
|
||||
"""
|
||||
ORM модели для работы с базой данных.
|
||||
Организованы по доменам для масштабируемости.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .base import Base
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
# src/dataloader/storage/models/base.py
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
|
@ -9,4 +8,5 @@ class Base(DeclarativeBase):
|
|||
Базовый класс для всех ORM моделей приложения.
|
||||
Используется SQLAlchemy 2.0+ declarative style.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||
|
||||
from datetime import date, datetime
|
||||
|
||||
from sqlalchemy import BigInteger, Date, Integer, Numeric, String, TIMESTAMP, text
|
||||
from sqlalchemy import TIMESTAMP, BigInteger, Date, Integer, Numeric, String, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from dataloader.storage.models.base import Base
|
||||
|
|
@ -16,29 +16,47 @@ class BriefDigitalCertificateOpu(Base):
|
|||
__tablename__ = "brief_digital_certificate_opu"
|
||||
__table_args__ = ({"schema": "opu"},)
|
||||
|
||||
object_id: Mapped[str] = mapped_column(String, primary_key=True, server_default=text("'-'"))
|
||||
object_id: Mapped[str] = mapped_column(
|
||||
String, primary_key=True, server_default=text("'-'")
|
||||
)
|
||||
object_nm: Mapped[str | None] = mapped_column(String)
|
||||
desk_nm: Mapped[str] = mapped_column(String, primary_key=True, server_default=text("'-'"))
|
||||
actdate: Mapped[date] = mapped_column(Date, primary_key=True, server_default=text("CURRENT_DATE"))
|
||||
layer_cd: Mapped[str] = mapped_column(String, primary_key=True, server_default=text("'-'"))
|
||||
desk_nm: Mapped[str] = mapped_column(
|
||||
String, primary_key=True, server_default=text("'-'")
|
||||
)
|
||||
actdate: Mapped[date] = mapped_column(
|
||||
Date, primary_key=True, server_default=text("CURRENT_DATE")
|
||||
)
|
||||
layer_cd: Mapped[str] = mapped_column(
|
||||
String, primary_key=True, server_default=text("'-'")
|
||||
)
|
||||
layer_nm: Mapped[str | None] = mapped_column(String)
|
||||
opu_cd: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
opu_nm_sh: Mapped[str | None] = mapped_column(String)
|
||||
opu_nm: Mapped[str | None] = mapped_column(String)
|
||||
opu_lvl: Mapped[int] = mapped_column(Integer, primary_key=True, server_default=text("'-1'"))
|
||||
opu_prnt_cd: Mapped[str] = mapped_column(String, primary_key=True, server_default=text("'-'"))
|
||||
opu_lvl: Mapped[int] = mapped_column(
|
||||
Integer, primary_key=True, server_default=text("'-1'")
|
||||
)
|
||||
opu_prnt_cd: Mapped[str] = mapped_column(
|
||||
String, primary_key=True, server_default=text("'-'")
|
||||
)
|
||||
opu_prnt_nm_sh: Mapped[str | None] = mapped_column(String)
|
||||
opu_prnt_nm: Mapped[str | None] = mapped_column(String)
|
||||
sum_amountrub_p_usd: Mapped[float | None] = mapped_column(Numeric)
|
||||
wf_load_id: Mapped[int] = mapped_column(BigInteger, nullable=False, server_default=text("'-1'"))
|
||||
wf_load_id: Mapped[int] = mapped_column(
|
||||
BigInteger, nullable=False, server_default=text("'-1'")
|
||||
)
|
||||
wf_load_dttm: Mapped[datetime] = mapped_column(
|
||||
TIMESTAMP(timezone=False),
|
||||
nullable=False,
|
||||
server_default=text("CURRENT_TIMESTAMP"),
|
||||
)
|
||||
wf_row_id: Mapped[int] = mapped_column(BigInteger, nullable=False, server_default=text("'-1'"))
|
||||
wf_row_id: Mapped[int] = mapped_column(
|
||||
BigInteger, nullable=False, server_default=text("'-1'")
|
||||
)
|
||||
object_tp: Mapped[str | None] = mapped_column(String)
|
||||
object_unit: Mapped[str] = mapped_column(String, primary_key=True, server_default=text("'-'"))
|
||||
object_unit: Mapped[str] = mapped_column(
|
||||
String, primary_key=True, server_default=text("'-'")
|
||||
)
|
||||
measure: Mapped[str | None] = mapped_column(String)
|
||||
product_nm: Mapped[str | None] = mapped_column(String)
|
||||
product_prnt_nm: Mapped[str | None] = mapped_column(String)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
# src/dataloader/storage/models/queue.py
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
|
@ -10,7 +9,6 @@ from sqlalchemy.orm import Mapped, mapped_column
|
|||
|
||||
from .base import Base
|
||||
|
||||
|
||||
dl_status_enum = ENUM(
|
||||
"queued",
|
||||
"running",
|
||||
|
|
@ -29,6 +27,7 @@ class DLJob(Base):
|
|||
Модель таблицы очереди задач dl_jobs.
|
||||
Использует логическое имя схемы 'queue' для поддержки schema_translate_map.
|
||||
"""
|
||||
|
||||
__tablename__ = "dl_jobs"
|
||||
__table_args__ = {"schema": "queue"}
|
||||
|
||||
|
|
@ -40,15 +39,23 @@ class DLJob(Base):
|
|||
lock_key: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
partition_key: Mapped[str] = mapped_column(Text, default="", nullable=False)
|
||||
priority: Mapped[int] = mapped_column(nullable=False, default=100)
|
||||
available_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
status: Mapped[str] = mapped_column(dl_status_enum, nullable=False, default="queued")
|
||||
available_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False
|
||||
)
|
||||
status: Mapped[str] = mapped_column(
|
||||
dl_status_enum, nullable=False, default="queued"
|
||||
)
|
||||
attempt: Mapped[int] = mapped_column(nullable=False, default=0)
|
||||
max_attempts: Mapped[int] = mapped_column(nullable=False, default=5)
|
||||
lease_ttl_sec: Mapped[int] = mapped_column(nullable=False, default=60)
|
||||
lease_expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True))
|
||||
lease_expires_at: Mapped[Optional[datetime]] = mapped_column(
|
||||
DateTime(timezone=True)
|
||||
)
|
||||
heartbeat_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True))
|
||||
cancel_requested: Mapped[bool] = mapped_column(nullable=False, default=False)
|
||||
progress: Mapped[dict[str, Any]] = mapped_column(JSONB, default=dict, nullable=False)
|
||||
progress: Mapped[dict[str, Any]] = mapped_column(
|
||||
JSONB, default=dict, nullable=False
|
||||
)
|
||||
error: Mapped[Optional[str]] = mapped_column(Text)
|
||||
producer: Mapped[Optional[str]] = mapped_column(Text)
|
||||
consumer_group: Mapped[Optional[str]] = mapped_column(Text)
|
||||
|
|
@ -62,10 +69,13 @@ class DLJobEvent(Base):
|
|||
Модель таблицы журнала событий dl_job_events.
|
||||
Использует логическое имя схемы 'queue' для поддержки schema_translate_map.
|
||||
"""
|
||||
|
||||
__tablename__ = "dl_job_events"
|
||||
__table_args__ = {"schema": "queue"}
|
||||
|
||||
event_id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
|
||||
event_id: Mapped[int] = mapped_column(
|
||||
BigInteger, primary_key=True, autoincrement=True
|
||||
)
|
||||
job_id: Mapped[str] = mapped_column(UUID(as_uuid=False), nullable=False)
|
||||
queue: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
load_dttm: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,15 @@ from __future__ import annotations
|
|||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import JSON, TIMESTAMP, BigInteger, ForeignKey, String, UniqueConstraint, func
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
TIMESTAMP,
|
||||
BigInteger,
|
||||
ForeignKey,
|
||||
String,
|
||||
UniqueConstraint,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from dataloader.storage.models.base import Base
|
||||
|
|
@ -30,7 +38,9 @@ class Quote(Base):
|
|||
srce: Mapped[str | None] = mapped_column(String)
|
||||
ticker: Mapped[str | None] = mapped_column(String)
|
||||
quote_sect_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("quotes.quotes_sect.quote_sect_id", ondelete="CASCADE", onupdate="CASCADE"),
|
||||
ForeignKey(
|
||||
"quotes.quotes_sect.quote_sect_id", ondelete="CASCADE", onupdate="CASCADE"
|
||||
),
|
||||
nullable=False,
|
||||
)
|
||||
last_update_dttm: Mapped[datetime | None] = mapped_column(TIMESTAMP(timezone=True))
|
||||
|
|
|
|||
|
|
@ -5,7 +5,17 @@ from __future__ import annotations
|
|||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import TIMESTAMP, BigInteger, Boolean, DateTime, Float, ForeignKey, String, UniqueConstraint, func
|
||||
from sqlalchemy import (
|
||||
TIMESTAMP,
|
||||
BigInteger,
|
||||
Boolean,
|
||||
DateTime,
|
||||
Float,
|
||||
ForeignKey,
|
||||
String,
|
||||
UniqueConstraint,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from dataloader.storage.models.base import Base
|
||||
|
|
|
|||
|
|
@ -1,16 +1,23 @@
|
|||
# src/dataloader/storage/notify_listener.py
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import asyncpg
|
||||
from typing import Callable, Optional
|
||||
|
||||
import asyncpg
|
||||
|
||||
|
||||
class PGNotifyListener:
|
||||
"""
|
||||
Прослушиватель PostgreSQL NOTIFY для канала 'dl_jobs'.
|
||||
"""
|
||||
def __init__(self, dsn: str, queue: str, callback: Callable[[], None], stop_event: asyncio.Event):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dsn: str,
|
||||
queue: str,
|
||||
callback: Callable[[], None],
|
||||
stop_event: asyncio.Event,
|
||||
):
|
||||
self._dsn = dsn
|
||||
self._queue = queue
|
||||
self._callback = callback
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
# src/dataloader/storage/repositories/__init__.py
|
||||
"""
|
||||
Репозитории для работы с базой данных.
|
||||
Организованы по доменам для масштабируемости.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .opu import OpuRepository
|
||||
|
|
|
|||
|
|
@ -69,7 +69,6 @@ class OpuRepository:
|
|||
if not records:
|
||||
return 0
|
||||
|
||||
# Получаем колонки для обновления (все кроме PK и технических)
|
||||
update_columns = {
|
||||
c.name
|
||||
for c in BriefDigitalCertificateOpu.__table__.columns
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
# src/dataloader/storage/repositories/queue.py
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
|
@ -15,6 +14,7 @@ class QueueRepository:
|
|||
"""
|
||||
Репозиторий для работы с очередью задач и журналом событий.
|
||||
"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.s = session
|
||||
|
||||
|
|
@ -62,7 +62,9 @@ class QueueRepository:
|
|||
finished_at=None,
|
||||
)
|
||||
self.s.add(row)
|
||||
await self._append_event(req.job_id, req.queue, "queued", {"task": req.task})
|
||||
await self._append_event(
|
||||
req.job_id, req.queue, "queued", {"task": req.task}
|
||||
)
|
||||
return req.job_id, "queued"
|
||||
|
||||
async def get_status(self, job_id: str) -> Optional[JobStatus]:
|
||||
|
|
@ -118,7 +120,9 @@ class QueueRepository:
|
|||
await self._append_event(job_id, job.queue, "cancel_requested", None)
|
||||
return True
|
||||
|
||||
async def claim_one(self, queue: str, claim_backoff_sec: int) -> Optional[dict[str, Any]]:
|
||||
async def claim_one(
|
||||
self, queue: str, claim_backoff_sec: int
|
||||
) -> Optional[dict[str, Any]]:
|
||||
"""
|
||||
Захватывает одну задачу из очереди с учётом блокировок и выставляет running.
|
||||
|
||||
|
|
@ -150,15 +154,21 @@ class QueueRepository:
|
|||
job.started_at = job.started_at or datetime.now(timezone.utc)
|
||||
job.attempt = int(job.attempt) + 1
|
||||
job.heartbeat_at = datetime.now(timezone.utc)
|
||||
job.lease_expires_at = datetime.now(timezone.utc) + timedelta(seconds=int(job.lease_ttl_sec))
|
||||
job.lease_expires_at = datetime.now(timezone.utc) + timedelta(
|
||||
seconds=int(job.lease_ttl_sec)
|
||||
)
|
||||
|
||||
ok = await self._try_advisory_lock(job.lock_key)
|
||||
if not ok:
|
||||
job.status = "queued"
|
||||
job.available_at = datetime.now(timezone.utc) + timedelta(seconds=claim_backoff_sec)
|
||||
job.available_at = datetime.now(timezone.utc) + timedelta(
|
||||
seconds=claim_backoff_sec
|
||||
)
|
||||
return None
|
||||
|
||||
await self._append_event(job.job_id, job.queue, "picked", {"attempt": job.attempt})
|
||||
await self._append_event(
|
||||
job.job_id, job.queue, "picked", {"attempt": job.attempt}
|
||||
)
|
||||
|
||||
return {
|
||||
"job_id": job.job_id,
|
||||
|
|
@ -192,10 +202,15 @@ class QueueRepository:
|
|||
q = (
|
||||
update(DLJob)
|
||||
.where(DLJob.job_id == job_id, DLJob.status == "running")
|
||||
.values(heartbeat_at=now, lease_expires_at=now + timedelta(seconds=int(ttl_sec)))
|
||||
.values(
|
||||
heartbeat_at=now,
|
||||
lease_expires_at=now + timedelta(seconds=int(ttl_sec)),
|
||||
)
|
||||
)
|
||||
await self.s.execute(q)
|
||||
await self._append_event(job_id, await self._resolve_queue(job_id), "heartbeat", {"ttl": ttl_sec})
|
||||
await self._append_event(
|
||||
job_id, await self._resolve_queue(job_id), "heartbeat", {"ttl": ttl_sec}
|
||||
)
|
||||
return True, cancel_requested
|
||||
|
||||
async def finish_ok(self, job_id: str) -> None:
|
||||
|
|
@ -215,7 +230,9 @@ class QueueRepository:
|
|||
await self._append_event(job_id, job.queue, "succeeded", None)
|
||||
await self._advisory_unlock(job.lock_key)
|
||||
|
||||
async def finish_fail_or_retry(self, job_id: str, err: str, is_canceled: bool = False) -> None:
|
||||
async def finish_fail_or_retry(
|
||||
self, job_id: str, err: str, is_canceled: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Помечает задачу как failed, canceled или возвращает в очередь с задержкой.
|
||||
|
||||
|
|
@ -239,16 +256,25 @@ class QueueRepository:
|
|||
can_retry = int(job.attempt) < int(job.max_attempts)
|
||||
if can_retry:
|
||||
job.status = "queued"
|
||||
job.available_at = datetime.now(timezone.utc) + timedelta(seconds=30 * int(job.attempt))
|
||||
job.available_at = datetime.now(timezone.utc) + timedelta(
|
||||
seconds=30 * int(job.attempt)
|
||||
)
|
||||
job.error = err
|
||||
job.lease_expires_at = None
|
||||
await self._append_event(job_id, job.queue, "requeue", {"attempt": job.attempt, "error": err})
|
||||
await self._append_event(
|
||||
job_id,
|
||||
job.queue,
|
||||
"requeue",
|
||||
{"attempt": job.attempt, "error": err},
|
||||
)
|
||||
else:
|
||||
job.status = "failed"
|
||||
job.error = err
|
||||
job.finished_at = datetime.now(timezone.utc)
|
||||
job.lease_expires_at = None
|
||||
await self._append_event(job_id, job.queue, "failed", {"error": err})
|
||||
await self._append_event(
|
||||
job_id, job.queue, "failed", {"error": err}
|
||||
)
|
||||
await self._advisory_unlock(job.lock_key)
|
||||
|
||||
async def requeue_lost(self, now: Optional[datetime] = None) -> list[str]:
|
||||
|
|
@ -293,7 +319,11 @@ class QueueRepository:
|
|||
Возвращает:
|
||||
ORM модель DLJob или None
|
||||
"""
|
||||
r = await self.s.execute(select(DLJob).where(DLJob.job_id == job_id).with_for_update(skip_locked=True))
|
||||
r = await self.s.execute(
|
||||
select(DLJob)
|
||||
.where(DLJob.job_id == job_id)
|
||||
.with_for_update(skip_locked=True)
|
||||
)
|
||||
return r.scalar_one_or_none()
|
||||
|
||||
async def _resolve_queue(self, job_id: str) -> str:
|
||||
|
|
@ -310,7 +340,9 @@ class QueueRepository:
|
|||
v = r.scalar_one_or_none()
|
||||
return v or ""
|
||||
|
||||
async def _append_event(self, job_id: str, queue: str, kind: str, payload: Optional[dict[str, Any]]) -> None:
|
||||
async def _append_event(
|
||||
self, job_id: str, queue: str, kind: str, payload: Optional[dict[str, Any]]
|
||||
) -> None:
|
||||
"""
|
||||
Добавляет запись в журнал событий.
|
||||
|
||||
|
|
@ -339,7 +371,9 @@ class QueueRepository:
|
|||
Возвращает:
|
||||
True, если блокировка получена
|
||||
"""
|
||||
r = await self.s.execute(select(func.pg_try_advisory_lock(func.hashtext(lock_key))))
|
||||
r = await self.s.execute(
|
||||
select(func.pg_try_advisory_lock(func.hashtext(lock_key)))
|
||||
)
|
||||
return bool(r.scalar())
|
||||
|
||||
async def _advisory_unlock(self, lock_key: str) -> None:
|
||||
|
|
|
|||
|
|
@ -39,7 +39,9 @@ class QuotesRepository:
|
|||
result = await self.s.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_or_create_section(self, name: str, params: dict | None = None) -> QuoteSection:
|
||||
async def get_or_create_section(
|
||||
self, name: str, params: dict | None = None
|
||||
) -> QuoteSection:
|
||||
"""
|
||||
Получить существующую секцию или создать новую.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
# src/dataloader/storage/schemas/__init__.py
|
||||
"""
|
||||
DTO (Data Transfer Objects) для слоя хранилища.
|
||||
Организованы по доменам для масштабируемости.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .queue import CreateJobRequest, JobStatus
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
# src/dataloader/storage/schemas/queue.py
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
|
@ -11,6 +10,7 @@ class CreateJobRequest:
|
|||
"""
|
||||
DTO для создания задачи в очереди.
|
||||
"""
|
||||
|
||||
job_id: str
|
||||
queue: str
|
||||
task: str
|
||||
|
|
@ -31,6 +31,7 @@ class JobStatus:
|
|||
"""
|
||||
DTO для статуса задачи.
|
||||
"""
|
||||
|
||||
job_id: str
|
||||
status: str
|
||||
attempt: int
|
||||
|
|
|
|||
|
|
@ -1,2 +1 @@
|
|||
"""Модуль воркеров для обработки задач."""
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
# src/dataloader/workers/base.py
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
|
@ -9,8 +8,8 @@ from typing import AsyncIterator, Callable, Optional
|
|||
|
||||
from dataloader.config import APP_CONFIG
|
||||
from dataloader.context import APP_CTX
|
||||
from dataloader.storage.repositories import QueueRepository
|
||||
from dataloader.storage.notify_listener import PGNotifyListener
|
||||
from dataloader.storage.repositories import QueueRepository
|
||||
from dataloader.workers.pipelines.registry import resolve as resolve_pipeline
|
||||
|
||||
|
||||
|
|
@ -19,6 +18,7 @@ class WorkerConfig:
|
|||
"""
|
||||
Конфигурация воркера.
|
||||
"""
|
||||
|
||||
queue: str
|
||||
heartbeat_sec: int
|
||||
claim_backoff_sec: int
|
||||
|
|
@ -28,6 +28,7 @@ class PGWorker:
|
|||
"""
|
||||
Базовый асинхронный воркер очереди Postgres.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg: WorkerConfig, stop_event: asyncio.Event) -> None:
|
||||
self._cfg = cfg
|
||||
self._stop = stop_event
|
||||
|
|
@ -46,14 +47,16 @@ class PGWorker:
|
|||
dsn=APP_CONFIG.pg.url,
|
||||
queue=self._cfg.queue,
|
||||
callback=lambda: self._notify_wakeup.set(),
|
||||
stop_event=self._stop
|
||||
stop_event=self._stop,
|
||||
)
|
||||
try:
|
||||
await self._listener.start()
|
||||
except Exception as e:
|
||||
self._log.warning(f"Failed to start LISTEN/NOTIFY, falling back to polling: {e}")
|
||||
self._log.warning(
|
||||
f"Failed to start LISTEN/NOTIFY, falling back to polling: {e}"
|
||||
)
|
||||
self._listener = None
|
||||
|
||||
|
||||
try:
|
||||
while not self._stop.is_set():
|
||||
claimed = await self._claim_and_execute_once()
|
||||
|
|
@ -69,24 +72,27 @@ class PGWorker:
|
|||
Ожидание появления задач через LISTEN/NOTIFY или с тайм-аутом.
|
||||
"""
|
||||
if self._listener:
|
||||
# Используем LISTEN/NOTIFY с fallback на таймаут
|
||||
|
||||
done, pending = await asyncio.wait(
|
||||
[asyncio.create_task(self._notify_wakeup.wait()), asyncio.create_task(self._stop.wait())],
|
||||
[
|
||||
asyncio.create_task(self._notify_wakeup.wait()),
|
||||
asyncio.create_task(self._stop.wait()),
|
||||
],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
timeout=timeout_sec
|
||||
timeout=timeout_sec,
|
||||
)
|
||||
# Отменяем оставшиеся задачи
|
||||
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
# Очищаем событие, если оно было установлено
|
||||
|
||||
if self._notify_wakeup.is_set():
|
||||
self._notify_wakeup.clear()
|
||||
else:
|
||||
# Fallback на простой таймаут
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(self._stop.wait(), timeout=timeout_sec)
|
||||
except asyncio.TimeoutError:
|
||||
|
|
@ -110,30 +116,42 @@ class PGWorker:
|
|||
args = row["args"]
|
||||
|
||||
try:
|
||||
canceled = await self._execute_with_heartbeat(job_id, ttl, self._pipeline(task, args))
|
||||
canceled = await self._execute_with_heartbeat(
|
||||
job_id, ttl, self._pipeline(task, args)
|
||||
)
|
||||
if canceled:
|
||||
await repo.finish_fail_or_retry(job_id, "canceled by user", is_canceled=True)
|
||||
await repo.finish_fail_or_retry(
|
||||
job_id, "canceled by user", is_canceled=True
|
||||
)
|
||||
else:
|
||||
await repo.finish_ok(job_id)
|
||||
return True
|
||||
except asyncio.CancelledError:
|
||||
await repo.finish_fail_or_retry(job_id, "cancelled by shutdown", is_canceled=True)
|
||||
await repo.finish_fail_or_retry(
|
||||
job_id, "cancelled by shutdown", is_canceled=True
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
await repo.finish_fail_or_retry(job_id, str(e))
|
||||
return True
|
||||
|
||||
async def _execute_with_heartbeat(self, job_id: str, ttl: int, it: AsyncIterator[None]) -> bool:
|
||||
async def _execute_with_heartbeat(
|
||||
self, job_id: str, ttl: int, it: AsyncIterator[None]
|
||||
) -> bool:
|
||||
"""
|
||||
Исполняет конвейер с поддержкой heartbeat.
|
||||
Возвращает True, если задача была отменена (cancel_requested).
|
||||
"""
|
||||
next_hb = datetime.now(timezone.utc) + timedelta(seconds=self._cfg.heartbeat_sec)
|
||||
next_hb = datetime.now(timezone.utc) + timedelta(
|
||||
seconds=self._cfg.heartbeat_sec
|
||||
)
|
||||
async for _ in it:
|
||||
now = datetime.now(timezone.utc)
|
||||
if now >= next_hb:
|
||||
async with self._sm() as s_hb:
|
||||
success, cancel_requested = await QueueRepository(s_hb).heartbeat(job_id, ttl)
|
||||
success, cancel_requested = await QueueRepository(s_hb).heartbeat(
|
||||
job_id, ttl
|
||||
)
|
||||
if cancel_requested:
|
||||
return True
|
||||
next_hb = now + timedelta(seconds=self._cfg.heartbeat_sec)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
# src/dataloader/workers/manager.py
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
|
@ -16,6 +15,7 @@ class WorkerSpec:
|
|||
"""
|
||||
Конфигурация набора воркеров для очереди.
|
||||
"""
|
||||
|
||||
queue: str
|
||||
concurrency: int
|
||||
|
||||
|
|
@ -24,6 +24,7 @@ class WorkerManager:
|
|||
"""
|
||||
Управляет жизненным циклом асинхронных воркеров.
|
||||
"""
|
||||
|
||||
def __init__(self, specs: list[WorkerSpec]) -> None:
|
||||
self._log = APP_CTX.get_logger()
|
||||
self._specs = specs
|
||||
|
|
@ -40,15 +41,22 @@ class WorkerManager:
|
|||
|
||||
for spec in self._specs:
|
||||
for i in range(max(1, spec.concurrency)):
|
||||
cfg = WorkerConfig(queue=spec.queue, heartbeat_sec=hb, claim_backoff_sec=backoff)
|
||||
t = asyncio.create_task(PGWorker(cfg, self._stop).run(), name=f"worker:{spec.queue}:{i}")
|
||||
cfg = WorkerConfig(
|
||||
queue=spec.queue, heartbeat_sec=hb, claim_backoff_sec=backoff
|
||||
)
|
||||
t = asyncio.create_task(
|
||||
PGWorker(cfg, self._stop).run(), name=f"worker:{spec.queue}:{i}"
|
||||
)
|
||||
self._tasks.append(t)
|
||||
|
||||
self._reaper_task = asyncio.create_task(self._reaper_loop(), name="reaper")
|
||||
|
||||
self._log.info(
|
||||
"worker_manager.started",
|
||||
extra={"specs": [spec.__dict__ for spec in self._specs], "total_tasks": len(self._tasks)},
|
||||
extra={
|
||||
"specs": [spec.__dict__ for spec in self._specs],
|
||||
"total_tasks": len(self._tasks),
|
||||
},
|
||||
)
|
||||
|
||||
async def stop(self) -> None:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
"""Модуль пайплайнов обработки задач."""
|
||||
|
||||
# src/dataloader/workers/pipelines/__init__.py
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
|
|
|
|||
|
|
@ -59,7 +59,6 @@ def _parse_jsonl_from_zst(file_path: Path, chunk_size: int = 10000):
|
|||
APP_CTX.logger.warning(f"Failed to parse JSON line: {e}")
|
||||
continue
|
||||
|
||||
# Process remaining buffer
|
||||
if buffer.strip():
|
||||
try:
|
||||
record = orjson.loads(buffer)
|
||||
|
|
@ -85,11 +84,9 @@ def _convert_record(raw: dict[str, Any]) -> dict[str, Any]:
|
|||
"""
|
||||
result = raw.copy()
|
||||
|
||||
# Преобразуем actdate из ISO строки в date
|
||||
if "actdate" in result and isinstance(result["actdate"], str):
|
||||
result["actdate"] = datetime.fromisoformat(result["actdate"]).date()
|
||||
|
||||
# Преобразуем wf_load_dttm из ISO строки в datetime
|
||||
if "wf_load_dttm" in result and isinstance(result["wf_load_dttm"], str):
|
||||
result["wf_load_dttm"] = datetime.fromisoformat(result["wf_load_dttm"])
|
||||
|
||||
|
|
@ -118,18 +115,15 @@ async def load_opu(args: dict) -> AsyncIterator[None]:
|
|||
logger = APP_CTX.logger
|
||||
logger.info("Starting OPU ETL pipeline")
|
||||
|
||||
# Шаг 1: Запуск экспорта
|
||||
interface = get_gmap2brief_interface()
|
||||
job_id = await interface.start_export()
|
||||
logger.info(f"OPU export job started: {job_id}")
|
||||
yield
|
||||
|
||||
# Шаг 2: Ожидание завершения
|
||||
status = await interface.wait_for_completion(job_id)
|
||||
logger.info(f"OPU export completed: {status.total_rows} rows")
|
||||
yield
|
||||
|
||||
# Шаг 3: Скачивание архива
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_path = Path(temp_dir)
|
||||
archive_path = temp_path / f"opu_export_{job_id}.jsonl.zst"
|
||||
|
|
@ -138,7 +132,6 @@ async def load_opu(args: dict) -> AsyncIterator[None]:
|
|||
logger.info(f"OPU archive downloaded: {archive_path.stat().st_size:,} bytes")
|
||||
yield
|
||||
|
||||
# Шаг 4: Truncate таблицы
|
||||
async with APP_CTX.sessionmaker() as session:
|
||||
repo = OpuRepository(session)
|
||||
await repo.truncate()
|
||||
|
|
@ -146,7 +139,6 @@ async def load_opu(args: dict) -> AsyncIterator[None]:
|
|||
logger.info("OPU table truncated")
|
||||
yield
|
||||
|
||||
# Шаг 5: Загрузка данных стримингово
|
||||
total_inserted = 0
|
||||
batch_num = 0
|
||||
|
||||
|
|
@ -154,17 +146,16 @@ async def load_opu(args: dict) -> AsyncIterator[None]:
|
|||
for batch in _parse_jsonl_from_zst(archive_path, chunk_size=5000):
|
||||
batch_num += 1
|
||||
|
||||
# Конвертируем записи
|
||||
converted = [_convert_record(rec) for rec in batch]
|
||||
|
||||
# Вставляем батч
|
||||
inserted = await repo.bulk_insert(converted)
|
||||
await session.commit()
|
||||
|
||||
total_inserted += inserted
|
||||
logger.debug(f"Batch {batch_num}: inserted {inserted} rows (total: {total_inserted})")
|
||||
logger.debug(
|
||||
f"Batch {batch_num}: inserted {inserted} rows (total: {total_inserted})"
|
||||
)
|
||||
|
||||
# Heartbeat после каждого батча
|
||||
yield
|
||||
|
||||
logger.info(f"OPU ETL completed: {total_inserted} rows inserted")
|
||||
|
|
|
|||
|
|
@ -59,7 +59,9 @@ def _parse_ts_to_datetime(ts: str) -> datetime | None:
|
|||
return None
|
||||
|
||||
|
||||
def _build_value_row(source: str, dt: datetime, point: Any) -> dict[str, Any] | None: # noqa: C901
|
||||
def _build_value_row(
|
||||
source: str, dt: datetime, point: Any
|
||||
) -> dict[str, Any] | None: # noqa: C901
|
||||
"""Строит строку для `quotes_values` по источнику и типу точки."""
|
||||
if isinstance(point, int):
|
||||
return {"dt": dt, "key": point}
|
||||
|
|
@ -82,7 +84,10 @@ def _build_value_row(source: str, dt: datetime, point: Any) -> dict[str, Any] |
|
|||
if isinstance(deep_inner, InvestingCandlestick):
|
||||
return {
|
||||
"dt": dt,
|
||||
"price_o": _to_float(getattr(deep_inner, "open_", None) or getattr(deep_inner, "open", None)),
|
||||
"price_o": _to_float(
|
||||
getattr(deep_inner, "open_", None)
|
||||
or getattr(deep_inner, "open", None)
|
||||
),
|
||||
"price_h": _to_float(deep_inner.high),
|
||||
"price_l": _to_float(deep_inner.low),
|
||||
"price_c": _to_float(deep_inner.close),
|
||||
|
|
@ -92,12 +97,16 @@ def _build_value_row(source: str, dt: datetime, point: Any) -> dict[str, Any] |
|
|||
if isinstance(inner, TradingViewTimePoint | SgxTimePoint):
|
||||
return {
|
||||
"dt": dt,
|
||||
"price_o": _to_float(getattr(inner, "open_", None) or getattr(inner, "open", None)),
|
||||
"price_o": _to_float(
|
||||
getattr(inner, "open_", None) or getattr(inner, "open", None)
|
||||
),
|
||||
"price_h": _to_float(inner.high),
|
||||
"price_l": _to_float(inner.low),
|
||||
"price_c": _to_float(inner.close),
|
||||
"volume": _to_float(
|
||||
getattr(inner, "volume", None) or getattr(inner, "interest", None) or getattr(inner, "value", None)
|
||||
getattr(inner, "volume", None)
|
||||
or getattr(inner, "interest", None)
|
||||
or getattr(inner, "value", None)
|
||||
),
|
||||
}
|
||||
|
||||
|
|
@ -132,7 +141,9 @@ def _build_value_row(source: str, dt: datetime, point: Any) -> dict[str, Any] |
|
|||
"dt": dt,
|
||||
"value_last": _to_float(deep_inner.last),
|
||||
"value_previous": _to_float(deep_inner.previous),
|
||||
"unit": str(deep_inner.unit) if deep_inner.unit is not None else None,
|
||||
"unit": (
|
||||
str(deep_inner.unit) if deep_inner.unit is not None else None
|
||||
),
|
||||
}
|
||||
|
||||
if isinstance(deep_inner, TradingEconomicsStringPercent):
|
||||
|
|
@ -218,7 +229,14 @@ async def load_tenera(args: dict) -> AsyncIterator[None]:
|
|||
async with APP_CTX.sessionmaker() as session:
|
||||
repo = QuotesRepository(session)
|
||||
|
||||
for source_name in ("cbr", "investing", "sgx", "tradingeconomics", "bloomberg", "trading_view"):
|
||||
for source_name in (
|
||||
"cbr",
|
||||
"investing",
|
||||
"sgx",
|
||||
"tradingeconomics",
|
||||
"bloomberg",
|
||||
"trading_view",
|
||||
):
|
||||
source_data = getattr(data, source_name)
|
||||
if not source_data:
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
# src/dataloader/workers/pipelines/noop.py
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
# src/dataloader/workers/pipelines/registry.py
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Dict, Iterable
|
||||
|
|
@ -6,13 +5,17 @@ from typing import Any, Callable, Dict, Iterable
|
|||
_Registry: Dict[str, Callable[[dict[str, Any]], Any]] = {}
|
||||
|
||||
|
||||
def register(task: str) -> Callable[[Callable[[dict[str, Any]], Any]], Callable[[dict[str, Any]], Any]]:
|
||||
def register(
|
||||
task: str,
|
||||
) -> Callable[[Callable[[dict[str, Any]], Any]], Callable[[dict[str, Any]], Any]]:
|
||||
"""
|
||||
Регистрирует обработчик пайплайна под именем задачи.
|
||||
"""
|
||||
|
||||
def _wrap(fn: Callable[[dict[str, Any]], Any]) -> Callable[[dict[str, Any]], Any]:
|
||||
_Registry[task] = fn
|
||||
return fn
|
||||
|
||||
return _wrap
|
||||
|
||||
|
||||
|
|
@ -22,8 +25,8 @@ def resolve(task: str) -> Callable[[dict[str, Any]], Any]:
|
|||
"""
|
||||
try:
|
||||
return _Registry[task]
|
||||
except KeyError:
|
||||
raise KeyError(f"pipeline not found: {task}")
|
||||
except KeyError as err:
|
||||
raise KeyError(f"pipeline not found: {task}") from err
|
||||
|
||||
|
||||
def tasks() -> Iterable[str]:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
# src/dataloader/workers/reaper.py
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from dataloader.storage.repositories import QueueRepository
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
# tests/conftest.py
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
|
@ -9,23 +8,23 @@ from uuid import uuid4
|
|||
import pytest
|
||||
import pytest_asyncio
|
||||
from dotenv import load_dotenv
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
|
||||
|
||||
load_dotenv()
|
||||
|
||||
from dataloader.api import app_main
|
||||
from dataloader.config import APP_CONFIG
|
||||
from dataloader.context import APP_CTX, get_session
|
||||
from dataloader.storage.models import Base
|
||||
from dataloader.storage.engine import create_engine, create_sessionmaker
|
||||
from dataloader.context import get_session
|
||||
from dataloader.storage.engine import create_engine
|
||||
|
||||
load_dotenv()
|
||||
|
||||
if sys.platform == "win32":
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def db_engine() -> AsyncGenerator[AsyncEngine, None]:
|
||||
"""
|
||||
|
|
@ -39,15 +38,15 @@ async def db_engine() -> AsyncGenerator[AsyncEngine, None]:
|
|||
await engine.dispose()
|
||||
|
||||
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def db_session(db_engine: AsyncEngine) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
Предоставляет сессию БД для каждого теста.
|
||||
НЕ использует транзакцию, чтобы работали advisory locks.
|
||||
"""
|
||||
sessionmaker = async_sessionmaker(bind=db_engine, expire_on_commit=False, class_=AsyncSession)
|
||||
sessionmaker = async_sessionmaker(
|
||||
bind=db_engine, expire_on_commit=False, class_=AsyncSession
|
||||
)
|
||||
async with sessionmaker() as session:
|
||||
yield session
|
||||
await session.rollback()
|
||||
|
|
@ -69,6 +68,7 @@ async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
|
|||
"""
|
||||
HTTP клиент для тестирования API.
|
||||
"""
|
||||
|
||||
async def override_get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
yield db_session
|
||||
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
# tests/integration_tests/__init__.py
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
# tests/integration_tests/test_api_endpoints.py
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
# tests/unit/test_api_router_not_found.py
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from uuid import uuid4, UUID
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from dataloader.api.v1.router import get_status, cancel_job
|
||||
from dataloader.api.v1.exceptions import JobNotFoundError
|
||||
from dataloader.api.v1.router import cancel_job, get_status
|
||||
from dataloader.api.v1.schemas import JobStatusResponse
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
# tests/unit/test_api_router_success.py
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from uuid import uuid4, UUID
|
||||
from datetime import datetime, timezone
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from dataloader.api.v1.router import get_status, cancel_job
|
||||
import pytest
|
||||
|
||||
from dataloader.api.v1.router import cancel_job, get_status
|
||||
from dataloader.api.v1.schemas import JobStatusResponse
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
# tests/integration_tests/test_queue_repository.py
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
|
@ -10,7 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from dataloader.storage.models import DLJob
|
||||
from dataloader.storage.repositories import QueueRepository
|
||||
from dataloader.storage.schemas import CreateJobRequest, JobStatus
|
||||
from dataloader.storage.schemas import CreateJobRequest
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
|
|
@ -445,6 +444,7 @@ class TestQueueRepository:
|
|||
await repo.claim_one(queue_name, claim_backoff_sec=15)
|
||||
|
||||
import asyncio
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
requeued = await repo.requeue_lost()
|
||||
|
|
@ -500,7 +500,9 @@ class TestQueueRepository:
|
|||
assert st is not None
|
||||
assert st.status == "queued"
|
||||
|
||||
row = (await db_session.execute(select(DLJob).where(DLJob.job_id == job_id))).scalar_one()
|
||||
row = (
|
||||
await db_session.execute(select(DLJob).where(DLJob.job_id == job_id))
|
||||
).scalar_one()
|
||||
assert row.available_at >= before + timedelta(seconds=15)
|
||||
assert row.available_at <= after + timedelta(seconds=60)
|
||||
|
||||
|
|
@ -569,7 +571,9 @@ class TestQueueRepository:
|
|||
await repo.create_or_get(req)
|
||||
await repo.claim_one(queue_name, claim_backoff_sec=5)
|
||||
|
||||
await repo.finish_fail_or_retry(job_id, err="Canceled by test", is_canceled=True)
|
||||
await repo.finish_fail_or_retry(
|
||||
job_id, err="Canceled by test", is_canceled=True
|
||||
)
|
||||
|
||||
st = await repo.get_status(job_id)
|
||||
assert st is not None
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
# tests/unit/__init__.py
|
||||
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
# tests/unit/test_api_service.py
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from uuid import UUID
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
|
||||
from dataloader.api.v1.service import JobsService
|
||||
from dataloader.api.v1.schemas import TriggerJobRequest
|
||||
from dataloader.api.v1.service import JobsService
|
||||
from dataloader.storage.schemas import JobStatus
|
||||
|
||||
|
||||
|
|
@ -38,15 +38,19 @@ class TestJobsService:
|
|||
"""
|
||||
mock_session = AsyncMock()
|
||||
|
||||
with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \
|
||||
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls, \
|
||||
patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id:
|
||||
with (
|
||||
patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
|
||||
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
|
||||
patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id,
|
||||
):
|
||||
|
||||
mock_get_logger.return_value = Mock()
|
||||
mock_new_job_id.return_value = UUID("12345678-1234-5678-1234-567812345678")
|
||||
|
||||
mock_repo = Mock()
|
||||
mock_repo.create_or_get = AsyncMock(return_value=("12345678-1234-5678-1234-567812345678", "queued"))
|
||||
mock_repo.create_or_get = AsyncMock(
|
||||
return_value=("12345678-1234-5678-1234-567812345678", "queued")
|
||||
)
|
||||
mock_repo_cls.return_value = mock_repo
|
||||
|
||||
service = JobsService(mock_session)
|
||||
|
|
@ -58,7 +62,7 @@ class TestJobsService:
|
|||
lock_key="lock_1",
|
||||
priority=100,
|
||||
max_attempts=5,
|
||||
lease_ttl_sec=60
|
||||
lease_ttl_sec=60,
|
||||
)
|
||||
|
||||
response = await service.trigger(req)
|
||||
|
|
@ -74,15 +78,19 @@ class TestJobsService:
|
|||
"""
|
||||
mock_session = AsyncMock()
|
||||
|
||||
with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \
|
||||
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls, \
|
||||
patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id:
|
||||
with (
|
||||
patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
|
||||
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
|
||||
patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id,
|
||||
):
|
||||
|
||||
mock_get_logger.return_value = Mock()
|
||||
mock_new_job_id.return_value = UUID("12345678-1234-5678-1234-567812345678")
|
||||
|
||||
mock_repo = Mock()
|
||||
mock_repo.create_or_get = AsyncMock(return_value=("12345678-1234-5678-1234-567812345678", "queued"))
|
||||
mock_repo.create_or_get = AsyncMock(
|
||||
return_value=("12345678-1234-5678-1234-567812345678", "queued")
|
||||
)
|
||||
mock_repo_cls.return_value = mock_repo
|
||||
|
||||
service = JobsService(mock_session)
|
||||
|
|
@ -95,7 +103,7 @@ class TestJobsService:
|
|||
lock_key="lock_1",
|
||||
priority=100,
|
||||
max_attempts=5,
|
||||
lease_ttl_sec=60
|
||||
lease_ttl_sec=60,
|
||||
)
|
||||
|
||||
response = await service.trigger(req)
|
||||
|
|
@ -112,15 +120,19 @@ class TestJobsService:
|
|||
"""
|
||||
mock_session = AsyncMock()
|
||||
|
||||
with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \
|
||||
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls, \
|
||||
patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id:
|
||||
with (
|
||||
patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
|
||||
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
|
||||
patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id,
|
||||
):
|
||||
|
||||
mock_get_logger.return_value = Mock()
|
||||
mock_new_job_id.return_value = UUID("12345678-1234-5678-1234-567812345678")
|
||||
|
||||
mock_repo = Mock()
|
||||
mock_repo.create_or_get = AsyncMock(return_value=("12345678-1234-5678-1234-567812345678", "queued"))
|
||||
mock_repo.create_or_get = AsyncMock(
|
||||
return_value=("12345678-1234-5678-1234-567812345678", "queued")
|
||||
)
|
||||
mock_repo_cls.return_value = mock_repo
|
||||
|
||||
service = JobsService(mock_session)
|
||||
|
|
@ -135,10 +147,10 @@ class TestJobsService:
|
|||
available_at=future_time,
|
||||
priority=100,
|
||||
max_attempts=5,
|
||||
lease_ttl_sec=60
|
||||
lease_ttl_sec=60,
|
||||
)
|
||||
|
||||
response = await service.trigger(req)
|
||||
await service.trigger(req)
|
||||
|
||||
call_args = mock_repo.create_or_get.call_args[0][0]
|
||||
assert call_args.available_at == future_time
|
||||
|
|
@ -150,15 +162,19 @@ class TestJobsService:
|
|||
"""
|
||||
mock_session = AsyncMock()
|
||||
|
||||
with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \
|
||||
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls, \
|
||||
patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id:
|
||||
with (
|
||||
patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
|
||||
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
|
||||
patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id,
|
||||
):
|
||||
|
||||
mock_get_logger.return_value = Mock()
|
||||
mock_new_job_id.return_value = UUID("12345678-1234-5678-1234-567812345678")
|
||||
|
||||
mock_repo = Mock()
|
||||
mock_repo.create_or_get = AsyncMock(return_value=("12345678-1234-5678-1234-567812345678", "queued"))
|
||||
mock_repo.create_or_get = AsyncMock(
|
||||
return_value=("12345678-1234-5678-1234-567812345678", "queued")
|
||||
)
|
||||
mock_repo_cls.return_value = mock_repo
|
||||
|
||||
service = JobsService(mock_session)
|
||||
|
|
@ -173,10 +189,10 @@ class TestJobsService:
|
|||
consumer_group="test_group",
|
||||
priority=100,
|
||||
max_attempts=5,
|
||||
lease_ttl_sec=60
|
||||
lease_ttl_sec=60,
|
||||
)
|
||||
|
||||
response = await service.trigger(req)
|
||||
await service.trigger(req)
|
||||
|
||||
call_args = mock_repo.create_or_get.call_args[0][0]
|
||||
assert call_args.partition_key == "partition_1"
|
||||
|
|
@ -190,8 +206,10 @@ class TestJobsService:
|
|||
"""
|
||||
mock_session = AsyncMock()
|
||||
|
||||
with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \
|
||||
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls:
|
||||
with (
|
||||
patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
|
||||
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
|
||||
):
|
||||
|
||||
mock_get_logger.return_value = Mock()
|
||||
|
||||
|
|
@ -204,7 +222,7 @@ class TestJobsService:
|
|||
finished_at=None,
|
||||
heartbeat_at=datetime(2025, 1, 1, 12, 5, 0, tzinfo=timezone.utc),
|
||||
error=None,
|
||||
progress={"step": 1}
|
||||
progress={"step": 1},
|
||||
)
|
||||
mock_repo.get_status = AsyncMock(return_value=mock_status)
|
||||
mock_repo_cls.return_value = mock_repo
|
||||
|
|
@ -227,8 +245,10 @@ class TestJobsService:
|
|||
"""
|
||||
mock_session = AsyncMock()
|
||||
|
||||
with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \
|
||||
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls:
|
||||
with (
|
||||
patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
|
||||
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
|
||||
):
|
||||
|
||||
mock_get_logger.return_value = Mock()
|
||||
|
||||
|
|
@ -250,8 +270,10 @@ class TestJobsService:
|
|||
"""
|
||||
mock_session = AsyncMock()
|
||||
|
||||
with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \
|
||||
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls:
|
||||
with (
|
||||
patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
|
||||
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
|
||||
):
|
||||
|
||||
mock_get_logger.return_value = Mock()
|
||||
|
||||
|
|
@ -265,7 +287,7 @@ class TestJobsService:
|
|||
finished_at=None,
|
||||
heartbeat_at=datetime(2025, 1, 1, 12, 5, 0, tzinfo=timezone.utc),
|
||||
error=None,
|
||||
progress={}
|
||||
progress={},
|
||||
)
|
||||
mock_repo.get_status = AsyncMock(return_value=mock_status)
|
||||
mock_repo_cls.return_value = mock_repo
|
||||
|
|
@ -287,8 +309,10 @@ class TestJobsService:
|
|||
"""
|
||||
mock_session = AsyncMock()
|
||||
|
||||
with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \
|
||||
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls:
|
||||
with (
|
||||
patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
|
||||
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
|
||||
):
|
||||
|
||||
mock_get_logger.return_value = Mock()
|
||||
|
||||
|
|
@ -312,8 +336,10 @@ class TestJobsService:
|
|||
"""
|
||||
mock_session = AsyncMock()
|
||||
|
||||
with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \
|
||||
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls:
|
||||
with (
|
||||
patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
|
||||
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
|
||||
):
|
||||
|
||||
mock_get_logger.return_value = Mock()
|
||||
|
||||
|
|
@ -326,7 +352,7 @@ class TestJobsService:
|
|||
finished_at=None,
|
||||
heartbeat_at=None,
|
||||
error=None,
|
||||
progress=None
|
||||
progress=None,
|
||||
)
|
||||
mock_repo.get_status = AsyncMock(return_value=mock_status)
|
||||
mock_repo_cls.return_value = mock_repo
|
||||
|
|
|
|||
|
|
@ -1,18 +1,18 @@
|
|||
# tests/unit/test_config.py
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from logging import DEBUG, INFO
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from dataloader.config import (
|
||||
BaseAppSettings,
|
||||
AppSettings,
|
||||
BaseAppSettings,
|
||||
LogSettings,
|
||||
PGSettings,
|
||||
WorkerSettings,
|
||||
Secrets,
|
||||
WorkerSettings,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -81,12 +81,15 @@ class TestAppSettings:
|
|||
"""
|
||||
Тест загрузки из переменных окружения.
|
||||
"""
|
||||
with patch.dict("os.environ", {
|
||||
"APP_HOST": "127.0.0.1",
|
||||
"APP_PORT": "9000",
|
||||
"PROJECT_NAME": "TestProject",
|
||||
"TIMEZONE": "UTC"
|
||||
}):
|
||||
with patch.dict(
|
||||
"os.environ",
|
||||
{
|
||||
"APP_HOST": "127.0.0.1",
|
||||
"APP_PORT": "9000",
|
||||
"PROJECT_NAME": "TestProject",
|
||||
"TIMEZONE": "UTC",
|
||||
},
|
||||
):
|
||||
settings = AppSettings()
|
||||
|
||||
assert settings.app_host == "127.0.0.1"
|
||||
|
|
@ -136,7 +139,9 @@ class TestLogSettings:
|
|||
"""
|
||||
Тест свойства log_file_abs_path.
|
||||
"""
|
||||
with patch.dict("os.environ", {"LOG_PATH": "/var/log", "LOG_FILE_NAME": "test.log"}):
|
||||
with patch.dict(
|
||||
"os.environ", {"LOG_PATH": "/var/log", "LOG_FILE_NAME": "test.log"}
|
||||
):
|
||||
settings = LogSettings()
|
||||
|
||||
assert "test.log" in settings.log_file_abs_path
|
||||
|
|
@ -146,7 +151,10 @@ class TestLogSettings:
|
|||
"""
|
||||
Тест свойства metric_file_abs_path.
|
||||
"""
|
||||
with patch.dict("os.environ", {"METRIC_PATH": "/var/metrics", "METRIC_FILE_NAME": "metrics.log"}):
|
||||
with patch.dict(
|
||||
"os.environ",
|
||||
{"METRIC_PATH": "/var/metrics", "METRIC_FILE_NAME": "metrics.log"},
|
||||
):
|
||||
settings = LogSettings()
|
||||
|
||||
assert "metrics.log" in settings.metric_file_abs_path
|
||||
|
|
@ -156,7 +164,10 @@ class TestLogSettings:
|
|||
"""
|
||||
Тест свойства audit_file_abs_path.
|
||||
"""
|
||||
with patch.dict("os.environ", {"AUDIT_LOG_PATH": "/var/audit", "AUDIT_LOG_FILE_NAME": "audit.log"}):
|
||||
with patch.dict(
|
||||
"os.environ",
|
||||
{"AUDIT_LOG_PATH": "/var/audit", "AUDIT_LOG_FILE_NAME": "audit.log"},
|
||||
):
|
||||
settings = LogSettings()
|
||||
|
||||
assert "audit.log" in settings.audit_file_abs_path
|
||||
|
|
@ -208,29 +219,37 @@ class TestPGSettings:
|
|||
"""
|
||||
Тест формирования строки подключения.
|
||||
"""
|
||||
with patch.dict("os.environ", {
|
||||
"PG_HOST": "db.example.com",
|
||||
"PG_PORT": "5433",
|
||||
"PG_USER": "testuser",
|
||||
"PG_PASSWORD": "testpass",
|
||||
"PG_DATABASE": "testdb"
|
||||
}):
|
||||
with patch.dict(
|
||||
"os.environ",
|
||||
{
|
||||
"PG_HOST": "db.example.com",
|
||||
"PG_PORT": "5433",
|
||||
"PG_USER": "testuser",
|
||||
"PG_PASSWORD": "testpass",
|
||||
"PG_DATABASE": "testdb",
|
||||
},
|
||||
):
|
||||
settings = PGSettings()
|
||||
|
||||
expected = "postgresql+asyncpg://testuser:testpass@db.example.com:5433/testdb"
|
||||
expected = (
|
||||
"postgresql+asyncpg://testuser:testpass@db.example.com:5433/testdb"
|
||||
)
|
||||
assert settings.url == expected
|
||||
|
||||
def test_url_property_with_empty_password(self):
|
||||
"""
|
||||
Тест строки подключения с пустым паролем.
|
||||
"""
|
||||
with patch.dict("os.environ", {
|
||||
"PG_HOST": "localhost",
|
||||
"PG_PORT": "5432",
|
||||
"PG_USER": "postgres",
|
||||
"PG_PASSWORD": "",
|
||||
"PG_DATABASE": "testdb"
|
||||
}):
|
||||
with patch.dict(
|
||||
"os.environ",
|
||||
{
|
||||
"PG_HOST": "localhost",
|
||||
"PG_PORT": "5432",
|
||||
"PG_USER": "postgres",
|
||||
"PG_PASSWORD": "",
|
||||
"PG_DATABASE": "testdb",
|
||||
},
|
||||
):
|
||||
settings = PGSettings()
|
||||
|
||||
expected = "postgresql+asyncpg://postgres:@localhost:5432/testdb"
|
||||
|
|
@ -240,15 +259,18 @@ class TestPGSettings:
|
|||
"""
|
||||
Тест загрузки из переменных окружения.
|
||||
"""
|
||||
with patch.dict("os.environ", {
|
||||
"PG_HOST": "testhost",
|
||||
"PG_PORT": "5433",
|
||||
"PG_USER": "testuser",
|
||||
"PG_PASSWORD": "testpass",
|
||||
"PG_DATABASE": "testdb",
|
||||
"PG_SCHEMA_QUEUE": "queue_schema",
|
||||
"PG_POOL_SIZE": "20"
|
||||
}):
|
||||
with patch.dict(
|
||||
"os.environ",
|
||||
{
|
||||
"PG_HOST": "testhost",
|
||||
"PG_PORT": "5433",
|
||||
"PG_USER": "testuser",
|
||||
"PG_PASSWORD": "testpass",
|
||||
"PG_DATABASE": "testdb",
|
||||
"PG_SCHEMA_QUEUE": "queue_schema",
|
||||
"PG_POOL_SIZE": "20",
|
||||
},
|
||||
):
|
||||
settings = PGSettings()
|
||||
|
||||
assert settings.host == "testhost"
|
||||
|
|
@ -292,10 +314,12 @@ class TestWorkerSettings:
|
|||
"""
|
||||
Тест парсинга валидного JSON.
|
||||
"""
|
||||
workers_json = json.dumps([
|
||||
{"queue": "queue1", "concurrency": 2},
|
||||
{"queue": "queue2", "concurrency": 3}
|
||||
])
|
||||
workers_json = json.dumps(
|
||||
[
|
||||
{"queue": "queue1", "concurrency": 2},
|
||||
{"queue": "queue2", "concurrency": 3},
|
||||
]
|
||||
)
|
||||
with patch.dict("os.environ", {"WORKERS_JSON": workers_json}):
|
||||
settings = WorkerSettings()
|
||||
|
||||
|
|
@ -311,12 +335,14 @@ class TestWorkerSettings:
|
|||
"""
|
||||
Тест фильтрации не-словарей из JSON.
|
||||
"""
|
||||
workers_json = json.dumps([
|
||||
{"queue": "queue1", "concurrency": 2},
|
||||
"invalid_item",
|
||||
123,
|
||||
{"queue": "queue2", "concurrency": 3}
|
||||
])
|
||||
workers_json = json.dumps(
|
||||
[
|
||||
{"queue": "queue1", "concurrency": 2},
|
||||
"invalid_item",
|
||||
123,
|
||||
{"queue": "queue2", "concurrency": 3},
|
||||
]
|
||||
)
|
||||
with patch.dict("os.environ", {"WORKERS_JSON": workers_json}):
|
||||
settings = WorkerSettings()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
# tests/unit/test_context.py
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from dataloader.context import AppContext, get_session
|
||||
|
|
@ -71,10 +71,16 @@ class TestAppContext:
|
|||
mock_engine = Mock()
|
||||
mock_sm = Mock()
|
||||
|
||||
with patch("dataloader.logger.logger.setup_logging") as mock_setup_logging, \
|
||||
patch("dataloader.storage.engine.create_engine", return_value=mock_engine) as mock_create_engine, \
|
||||
patch("dataloader.storage.engine.create_sessionmaker", return_value=mock_sm) as mock_create_sm, \
|
||||
patch("dataloader.context.APP_CONFIG") as mock_config:
|
||||
with (
|
||||
patch("dataloader.logger.logger.setup_logging") as mock_setup_logging,
|
||||
patch(
|
||||
"dataloader.storage.engine.create_engine", return_value=mock_engine
|
||||
) as mock_create_engine,
|
||||
patch(
|
||||
"dataloader.storage.engine.create_sessionmaker", return_value=mock_sm
|
||||
) as mock_create_sm,
|
||||
patch("dataloader.context.APP_CONFIG") as mock_config,
|
||||
):
|
||||
|
||||
mock_config.pg.url = "postgresql://test"
|
||||
|
||||
|
|
@ -194,7 +200,7 @@ class TestGetSession:
|
|||
with patch("dataloader.context.APP_CTX") as mock_ctx:
|
||||
mock_ctx.sessionmaker = mock_sm
|
||||
|
||||
async for session in get_session():
|
||||
async for _session in get_session():
|
||||
pass
|
||||
|
||||
assert mock_exit.call_count == 1
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
# tests/unit/test_notify_listener.py
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from dataloader.storage.notify_listener import PGNotifyListener
|
||||
|
|
@ -25,7 +25,7 @@ class TestPGNotifyListener:
|
|||
dsn="postgresql://test",
|
||||
queue="test_queue",
|
||||
callback=callback,
|
||||
stop_event=stop_event
|
||||
stop_event=stop_event,
|
||||
)
|
||||
|
||||
assert listener._dsn == "postgresql://test"
|
||||
|
|
@ -47,14 +47,16 @@ class TestPGNotifyListener:
|
|||
dsn="postgresql://test",
|
||||
queue="test_queue",
|
||||
callback=callback,
|
||||
stop_event=stop_event
|
||||
stop_event=stop_event,
|
||||
)
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.execute = AsyncMock()
|
||||
mock_conn.add_listener = AsyncMock()
|
||||
|
||||
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn):
|
||||
with patch(
|
||||
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
|
||||
):
|
||||
await listener.start()
|
||||
|
||||
assert listener._conn == mock_conn
|
||||
|
|
@ -76,14 +78,16 @@ class TestPGNotifyListener:
|
|||
dsn="postgresql+asyncpg://test",
|
||||
queue="test_queue",
|
||||
callback=callback,
|
||||
stop_event=stop_event
|
||||
stop_event=stop_event,
|
||||
)
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.execute = AsyncMock()
|
||||
mock_conn.add_listener = AsyncMock()
|
||||
|
||||
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn) as mock_connect:
|
||||
with patch(
|
||||
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
|
||||
) as mock_connect:
|
||||
await listener.start()
|
||||
|
||||
mock_connect.assert_called_once_with("postgresql://test")
|
||||
|
|
@ -102,14 +106,16 @@ class TestPGNotifyListener:
|
|||
dsn="postgresql://test",
|
||||
queue="test_queue",
|
||||
callback=callback,
|
||||
stop_event=stop_event
|
||||
stop_event=stop_event,
|
||||
)
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.execute = AsyncMock()
|
||||
mock_conn.add_listener = AsyncMock()
|
||||
|
||||
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn):
|
||||
with patch(
|
||||
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
|
||||
):
|
||||
await listener.start()
|
||||
|
||||
handler = listener._on_notify_handler
|
||||
|
|
@ -131,14 +137,16 @@ class TestPGNotifyListener:
|
|||
dsn="postgresql://test",
|
||||
queue="test_queue",
|
||||
callback=callback,
|
||||
stop_event=stop_event
|
||||
stop_event=stop_event,
|
||||
)
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.execute = AsyncMock()
|
||||
mock_conn.add_listener = AsyncMock()
|
||||
|
||||
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn):
|
||||
with patch(
|
||||
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
|
||||
):
|
||||
await listener.start()
|
||||
|
||||
handler = listener._on_notify_handler
|
||||
|
|
@ -160,14 +168,16 @@ class TestPGNotifyListener:
|
|||
dsn="postgresql://test",
|
||||
queue="test_queue",
|
||||
callback=callback,
|
||||
stop_event=stop_event
|
||||
stop_event=stop_event,
|
||||
)
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.execute = AsyncMock()
|
||||
mock_conn.add_listener = AsyncMock()
|
||||
|
||||
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn):
|
||||
with patch(
|
||||
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
|
||||
):
|
||||
await listener.start()
|
||||
|
||||
handler = listener._on_notify_handler
|
||||
|
|
@ -189,14 +199,16 @@ class TestPGNotifyListener:
|
|||
dsn="postgresql://test",
|
||||
queue="test_queue",
|
||||
callback=callback,
|
||||
stop_event=stop_event
|
||||
stop_event=stop_event,
|
||||
)
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.execute = AsyncMock()
|
||||
mock_conn.add_listener = AsyncMock()
|
||||
|
||||
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn):
|
||||
with patch(
|
||||
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
|
||||
):
|
||||
await listener.start()
|
||||
|
||||
handler = listener._on_notify_handler
|
||||
|
|
@ -218,7 +230,7 @@ class TestPGNotifyListener:
|
|||
dsn="postgresql://test",
|
||||
queue="test_queue",
|
||||
callback=callback,
|
||||
stop_event=stop_event
|
||||
stop_event=stop_event,
|
||||
)
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
|
|
@ -227,7 +239,9 @@ class TestPGNotifyListener:
|
|||
mock_conn.remove_listener = AsyncMock()
|
||||
mock_conn.close = AsyncMock()
|
||||
|
||||
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn):
|
||||
with patch(
|
||||
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
|
||||
):
|
||||
await listener.start()
|
||||
|
||||
assert listener._task is not None
|
||||
|
|
@ -250,7 +264,7 @@ class TestPGNotifyListener:
|
|||
dsn="postgresql://test",
|
||||
queue="test_queue",
|
||||
callback=callback,
|
||||
stop_event=stop_event
|
||||
stop_event=stop_event,
|
||||
)
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
|
|
@ -259,7 +273,9 @@ class TestPGNotifyListener:
|
|||
mock_conn.remove_listener = AsyncMock()
|
||||
mock_conn.close = AsyncMock()
|
||||
|
||||
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn):
|
||||
with patch(
|
||||
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
|
||||
):
|
||||
await listener.start()
|
||||
|
||||
await listener.stop()
|
||||
|
|
@ -280,7 +296,7 @@ class TestPGNotifyListener:
|
|||
dsn="postgresql://test",
|
||||
queue="test_queue",
|
||||
callback=callback,
|
||||
stop_event=stop_event
|
||||
stop_event=stop_event,
|
||||
)
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
|
|
@ -289,7 +305,9 @@ class TestPGNotifyListener:
|
|||
mock_conn.remove_listener = AsyncMock()
|
||||
mock_conn.close = AsyncMock()
|
||||
|
||||
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn):
|
||||
with patch(
|
||||
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
|
||||
):
|
||||
await listener.start()
|
||||
|
||||
stop_event.set()
|
||||
|
|
@ -311,7 +329,7 @@ class TestPGNotifyListener:
|
|||
dsn="postgresql://test",
|
||||
queue="test_queue",
|
||||
callback=callback,
|
||||
stop_event=stop_event
|
||||
stop_event=stop_event,
|
||||
)
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
|
|
@ -320,7 +338,9 @@ class TestPGNotifyListener:
|
|||
mock_conn.remove_listener = AsyncMock(side_effect=Exception("Remove error"))
|
||||
mock_conn.close = AsyncMock()
|
||||
|
||||
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn):
|
||||
with patch(
|
||||
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
|
||||
):
|
||||
await listener.start()
|
||||
|
||||
await listener.stop()
|
||||
|
|
@ -340,7 +360,7 @@ class TestPGNotifyListener:
|
|||
dsn="postgresql://test",
|
||||
queue="test_queue",
|
||||
callback=callback,
|
||||
stop_event=stop_event
|
||||
stop_event=stop_event,
|
||||
)
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
|
|
@ -349,7 +369,9 @@ class TestPGNotifyListener:
|
|||
mock_conn.remove_listener = AsyncMock()
|
||||
mock_conn.close = AsyncMock(side_effect=Exception("Close error"))
|
||||
|
||||
with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn):
|
||||
with patch(
|
||||
"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
|
||||
):
|
||||
await listener.start()
|
||||
|
||||
await listener.stop()
|
||||
|
|
@ -368,7 +390,7 @@ class TestPGNotifyListener:
|
|||
dsn="postgresql://test",
|
||||
queue="test_queue",
|
||||
callback=callback,
|
||||
stop_event=stop_event
|
||||
stop_event=stop_event,
|
||||
)
|
||||
|
||||
await listener.stop()
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
# tests/unit/test_pipeline_registry.py
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from dataloader.workers.pipelines.registry import register, resolve, tasks, _Registry
|
||||
from dataloader.workers.pipelines.registry import _Registry, register, resolve, tasks
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
|
|
@ -22,6 +21,7 @@ class TestPipelineRegistry:
|
|||
"""
|
||||
Тест регистрации пайплайна.
|
||||
"""
|
||||
|
||||
@register("test.task")
|
||||
def test_pipeline(args: dict):
|
||||
return "result"
|
||||
|
|
@ -33,6 +33,7 @@ class TestPipelineRegistry:
|
|||
"""
|
||||
Тест получения зарегистрированного пайплайна.
|
||||
"""
|
||||
|
||||
@register("test.resolve")
|
||||
def test_pipeline(args: dict):
|
||||
return "resolved"
|
||||
|
|
@ -54,6 +55,7 @@ class TestPipelineRegistry:
|
|||
"""
|
||||
Тест получения списка зарегистрированных задач.
|
||||
"""
|
||||
|
||||
@register("task1")
|
||||
def pipeline1(args: dict):
|
||||
pass
|
||||
|
|
@ -70,6 +72,7 @@ class TestPipelineRegistry:
|
|||
"""
|
||||
Тест перезаписи существующего пайплайна.
|
||||
"""
|
||||
|
||||
@register("overwrite.task")
|
||||
def first_pipeline(args: dict):
|
||||
return "first"
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
# tests/unit/test_workers_base.py
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from dataloader.workers.base import PGWorker, WorkerConfig
|
||||
|
|
@ -41,9 +40,11 @@ class TestPGWorker:
|
|||
cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=1)
|
||||
stop_event = asyncio.Event()
|
||||
|
||||
with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
|
||||
patch("dataloader.workers.base.APP_CONFIG") as mock_cfg, \
|
||||
patch("dataloader.workers.base.PGNotifyListener") as mock_listener_cls:
|
||||
with (
|
||||
patch("dataloader.workers.base.APP_CTX") as mock_ctx,
|
||||
patch("dataloader.workers.base.APP_CONFIG") as mock_cfg,
|
||||
patch("dataloader.workers.base.PGNotifyListener") as mock_listener_cls,
|
||||
):
|
||||
|
||||
mock_ctx.get_logger.return_value = Mock()
|
||||
mock_ctx.sessionmaker = Mock()
|
||||
|
|
@ -64,7 +65,9 @@ class TestPGWorker:
|
|||
stop_event.set()
|
||||
return False
|
||||
|
||||
with patch.object(worker, "_claim_and_execute_once", side_effect=mock_claim):
|
||||
with patch.object(
|
||||
worker, "_claim_and_execute_once", side_effect=mock_claim
|
||||
):
|
||||
await worker.run()
|
||||
|
||||
assert mock_listener.start.call_count == 1
|
||||
|
|
@ -78,9 +81,11 @@ class TestPGWorker:
|
|||
cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=1)
|
||||
stop_event = asyncio.Event()
|
||||
|
||||
with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
|
||||
patch("dataloader.workers.base.APP_CONFIG") as mock_cfg, \
|
||||
patch("dataloader.workers.base.PGNotifyListener") as mock_listener_cls:
|
||||
with (
|
||||
patch("dataloader.workers.base.APP_CTX") as mock_ctx,
|
||||
patch("dataloader.workers.base.APP_CONFIG") as mock_cfg,
|
||||
patch("dataloader.workers.base.PGNotifyListener") as mock_listener_cls,
|
||||
):
|
||||
|
||||
mock_logger = Mock()
|
||||
mock_ctx.get_logger.return_value = mock_logger
|
||||
|
|
@ -101,7 +106,9 @@ class TestPGWorker:
|
|||
stop_event.set()
|
||||
return False
|
||||
|
||||
with patch.object(worker, "_claim_and_execute_once", side_effect=mock_claim):
|
||||
with patch.object(
|
||||
worker, "_claim_and_execute_once", side_effect=mock_claim
|
||||
):
|
||||
await worker.run()
|
||||
|
||||
assert worker._listener is None
|
||||
|
|
@ -156,8 +163,10 @@ class TestPGWorker:
|
|||
cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
|
||||
stop_event = asyncio.Event()
|
||||
|
||||
with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
|
||||
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
|
||||
with (
|
||||
patch("dataloader.workers.base.APP_CTX") as mock_ctx,
|
||||
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
|
||||
):
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
|
@ -185,8 +194,10 @@ class TestPGWorker:
|
|||
cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
|
||||
stop_event = asyncio.Event()
|
||||
|
||||
with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
|
||||
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
|
||||
with (
|
||||
patch("dataloader.workers.base.APP_CTX") as mock_ctx,
|
||||
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
|
||||
):
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_sm = MagicMock()
|
||||
|
|
@ -196,12 +207,14 @@ class TestPGWorker:
|
|||
mock_ctx.sessionmaker = mock_sm
|
||||
|
||||
mock_repo = Mock()
|
||||
mock_repo.claim_one = AsyncMock(return_value={
|
||||
"job_id": "test-job-id",
|
||||
"lease_ttl_sec": 60,
|
||||
"task": "test.task",
|
||||
"args": {"key": "value"}
|
||||
})
|
||||
mock_repo.claim_one = AsyncMock(
|
||||
return_value={
|
||||
"job_id": "test-job-id",
|
||||
"lease_ttl_sec": 60,
|
||||
"task": "test.task",
|
||||
"args": {"key": "value"},
|
||||
}
|
||||
)
|
||||
mock_repo.finish_ok = AsyncMock()
|
||||
mock_repo_cls.return_value = mock_repo
|
||||
|
||||
|
|
@ -210,8 +223,10 @@ class TestPGWorker:
|
|||
async def mock_pipeline(task, args):
|
||||
yield
|
||||
|
||||
with patch.object(worker, "_pipeline", side_effect=mock_pipeline), \
|
||||
patch.object(worker, "_execute_with_heartbeat", return_value=False):
|
||||
with (
|
||||
patch.object(worker, "_pipeline", side_effect=mock_pipeline),
|
||||
patch.object(worker, "_execute_with_heartbeat", return_value=False),
|
||||
):
|
||||
result = await worker._claim_and_execute_once()
|
||||
|
||||
assert result is True
|
||||
|
|
@ -225,8 +240,10 @@ class TestPGWorker:
|
|||
cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
|
||||
stop_event = asyncio.Event()
|
||||
|
||||
with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
|
||||
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
|
||||
with (
|
||||
patch("dataloader.workers.base.APP_CTX") as mock_ctx,
|
||||
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
|
||||
):
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_sm = MagicMock()
|
||||
|
|
@ -236,12 +253,14 @@ class TestPGWorker:
|
|||
mock_ctx.sessionmaker = mock_sm
|
||||
|
||||
mock_repo = Mock()
|
||||
mock_repo.claim_one = AsyncMock(return_value={
|
||||
"job_id": "test-job-id",
|
||||
"lease_ttl_sec": 60,
|
||||
"task": "test.task",
|
||||
"args": {}
|
||||
})
|
||||
mock_repo.claim_one = AsyncMock(
|
||||
return_value={
|
||||
"job_id": "test-job-id",
|
||||
"lease_ttl_sec": 60,
|
||||
"task": "test.task",
|
||||
"args": {},
|
||||
}
|
||||
)
|
||||
mock_repo.finish_fail_or_retry = AsyncMock()
|
||||
mock_repo_cls.return_value = mock_repo
|
||||
|
||||
|
|
@ -263,8 +282,10 @@ class TestPGWorker:
|
|||
cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
|
||||
stop_event = asyncio.Event()
|
||||
|
||||
with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
|
||||
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
|
||||
with (
|
||||
patch("dataloader.workers.base.APP_CTX") as mock_ctx,
|
||||
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
|
||||
):
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_sm = MagicMock()
|
||||
|
|
@ -274,18 +295,22 @@ class TestPGWorker:
|
|||
mock_ctx.sessionmaker = mock_sm
|
||||
|
||||
mock_repo = Mock()
|
||||
mock_repo.claim_one = AsyncMock(return_value={
|
||||
"job_id": "test-job-id",
|
||||
"lease_ttl_sec": 60,
|
||||
"task": "test.task",
|
||||
"args": {}
|
||||
})
|
||||
mock_repo.claim_one = AsyncMock(
|
||||
return_value={
|
||||
"job_id": "test-job-id",
|
||||
"lease_ttl_sec": 60,
|
||||
"task": "test.task",
|
||||
"args": {},
|
||||
}
|
||||
)
|
||||
mock_repo.finish_fail_or_retry = AsyncMock()
|
||||
mock_repo_cls.return_value = mock_repo
|
||||
|
||||
worker = PGWorker(cfg, stop_event)
|
||||
|
||||
with patch.object(worker, "_execute_with_heartbeat", side_effect=ValueError("Test error")):
|
||||
with patch.object(
|
||||
worker, "_execute_with_heartbeat", side_effect=ValueError("Test error")
|
||||
):
|
||||
result = await worker._claim_and_execute_once()
|
||||
|
||||
assert result is True
|
||||
|
|
@ -301,8 +326,10 @@ class TestPGWorker:
|
|||
cfg = WorkerConfig(queue="test", heartbeat_sec=1, claim_backoff_sec=5)
|
||||
stop_event = asyncio.Event()
|
||||
|
||||
with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
|
||||
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
|
||||
with (
|
||||
patch("dataloader.workers.base.APP_CTX") as mock_ctx,
|
||||
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
|
||||
):
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_sm = MagicMock()
|
||||
|
|
@ -323,7 +350,9 @@ class TestPGWorker:
|
|||
await asyncio.sleep(0.6)
|
||||
yield
|
||||
|
||||
canceled = await worker._execute_with_heartbeat("job-id", 60, slow_pipeline())
|
||||
canceled = await worker._execute_with_heartbeat(
|
||||
"job-id", 60, slow_pipeline()
|
||||
)
|
||||
|
||||
assert canceled is False
|
||||
assert mock_repo.heartbeat.call_count >= 1
|
||||
|
|
@ -336,8 +365,10 @@ class TestPGWorker:
|
|||
cfg = WorkerConfig(queue="test", heartbeat_sec=1, claim_backoff_sec=5)
|
||||
stop_event = asyncio.Event()
|
||||
|
||||
with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
|
||||
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
|
||||
with (
|
||||
patch("dataloader.workers.base.APP_CTX") as mock_ctx,
|
||||
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
|
||||
):
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_sm = MagicMock()
|
||||
|
|
@ -358,7 +389,9 @@ class TestPGWorker:
|
|||
await asyncio.sleep(0.6)
|
||||
yield
|
||||
|
||||
canceled = await worker._execute_with_heartbeat("job-id", 60, slow_pipeline())
|
||||
canceled = await worker._execute_with_heartbeat(
|
||||
"job-id", 60, slow_pipeline()
|
||||
)
|
||||
|
||||
assert canceled is True
|
||||
|
||||
|
|
@ -370,8 +403,10 @@ class TestPGWorker:
|
|||
cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
|
||||
stop_event = asyncio.Event()
|
||||
|
||||
with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
|
||||
patch("dataloader.workers.base.resolve_pipeline") as mock_resolve:
|
||||
with (
|
||||
patch("dataloader.workers.base.APP_CTX") as mock_ctx,
|
||||
patch("dataloader.workers.base.resolve_pipeline") as mock_resolve,
|
||||
):
|
||||
|
||||
mock_ctx.get_logger.return_value = Mock()
|
||||
mock_ctx.sessionmaker = Mock()
|
||||
|
|
@ -397,8 +432,10 @@ class TestPGWorker:
|
|||
cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
|
||||
stop_event = asyncio.Event()
|
||||
|
||||
with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
|
||||
patch("dataloader.workers.base.resolve_pipeline") as mock_resolve:
|
||||
with (
|
||||
patch("dataloader.workers.base.APP_CTX") as mock_ctx,
|
||||
patch("dataloader.workers.base.resolve_pipeline") as mock_resolve,
|
||||
):
|
||||
|
||||
mock_ctx.get_logger.return_value = Mock()
|
||||
mock_ctx.sessionmaker = Mock()
|
||||
|
|
@ -424,8 +461,10 @@ class TestPGWorker:
|
|||
cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
|
||||
stop_event = asyncio.Event()
|
||||
|
||||
with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
|
||||
patch("dataloader.workers.base.resolve_pipeline") as mock_resolve:
|
||||
with (
|
||||
patch("dataloader.workers.base.APP_CTX") as mock_ctx,
|
||||
patch("dataloader.workers.base.resolve_pipeline") as mock_resolve,
|
||||
):
|
||||
|
||||
mock_ctx.get_logger.return_value = Mock()
|
||||
mock_ctx.sessionmaker = Mock()
|
||||
|
|
@ -450,8 +489,10 @@ class TestPGWorker:
|
|||
cfg = WorkerConfig(queue="test", heartbeat_sec=10, claim_backoff_sec=5)
|
||||
stop_event = asyncio.Event()
|
||||
|
||||
with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
|
||||
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
|
||||
with (
|
||||
patch("dataloader.workers.base.APP_CTX") as mock_ctx,
|
||||
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
|
||||
):
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_sm = MagicMock()
|
||||
|
|
@ -461,12 +502,14 @@ class TestPGWorker:
|
|||
mock_ctx.sessionmaker = mock_sm
|
||||
|
||||
mock_repo = Mock()
|
||||
mock_repo.claim_one = AsyncMock(return_value={
|
||||
"job_id": "test-job-id",
|
||||
"lease_ttl_sec": 60,
|
||||
"task": "test.task",
|
||||
"args": {}
|
||||
})
|
||||
mock_repo.claim_one = AsyncMock(
|
||||
return_value={
|
||||
"job_id": "test-job-id",
|
||||
"lease_ttl_sec": 60,
|
||||
"task": "test.task",
|
||||
"args": {},
|
||||
}
|
||||
)
|
||||
mock_repo.finish_fail_or_retry = AsyncMock()
|
||||
mock_repo_cls.return_value = mock_repo
|
||||
|
||||
|
|
@ -484,19 +527,16 @@ class TestPGWorker:
|
|||
assert "cancelled by shutdown" in args[1]
|
||||
assert kwargs.get("is_canceled") is True
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_with_heartbeat_raises_cancelled_when_stop_set(self):
|
||||
cfg = WorkerConfig(queue="test", heartbeat_sec=1000, claim_backoff_sec=5)
|
||||
stop_event = asyncio.Event()
|
||||
stop_event.set()
|
||||
|
||||
with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
|
||||
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
|
||||
with (
|
||||
patch("dataloader.workers.base.APP_CTX") as mock_ctx,
|
||||
patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
|
||||
):
|
||||
|
||||
mock_ctx.get_logger.return_value = Mock()
|
||||
mock_ctx.sessionmaker = Mock()
|
||||
|
|
@ -509,4 +549,3 @@ class TestPGWorker:
|
|||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await worker._execute_with_heartbeat("job-id", 60, one_yield())
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
# tests/unit/test_workers_manager.py
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from dataloader.workers.manager import WorkerManager, WorkerSpec, build_manager_from_env
|
||||
|
|
@ -39,9 +39,11 @@ class TestWorkerManager:
|
|||
WorkerSpec(queue="queue2", concurrency=1),
|
||||
]
|
||||
|
||||
with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
|
||||
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, \
|
||||
patch("dataloader.workers.manager.PGWorker") as mock_worker_cls:
|
||||
with (
|
||||
patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
|
||||
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
|
||||
patch("dataloader.workers.manager.PGWorker") as mock_worker_cls,
|
||||
):
|
||||
|
||||
mock_ctx.get_logger.return_value = Mock()
|
||||
mock_cfg.worker.heartbeat_sec = 10
|
||||
|
|
@ -67,9 +69,11 @@ class TestWorkerManager:
|
|||
"""
|
||||
specs = [WorkerSpec(queue="test", concurrency=0)]
|
||||
|
||||
with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
|
||||
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, \
|
||||
patch("dataloader.workers.manager.PGWorker") as mock_worker_cls:
|
||||
with (
|
||||
patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
|
||||
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
|
||||
patch("dataloader.workers.manager.PGWorker") as mock_worker_cls,
|
||||
):
|
||||
|
||||
mock_ctx.get_logger.return_value = Mock()
|
||||
mock_cfg.worker.heartbeat_sec = 10
|
||||
|
|
@ -93,9 +97,11 @@ class TestWorkerManager:
|
|||
"""
|
||||
specs = [WorkerSpec(queue="test", concurrency=2)]
|
||||
|
||||
with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
|
||||
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, \
|
||||
patch("dataloader.workers.manager.PGWorker") as mock_worker_cls:
|
||||
with (
|
||||
patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
|
||||
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
|
||||
patch("dataloader.workers.manager.PGWorker") as mock_worker_cls,
|
||||
):
|
||||
|
||||
mock_ctx.get_logger.return_value = Mock()
|
||||
mock_cfg.worker.heartbeat_sec = 10
|
||||
|
|
@ -108,8 +114,6 @@ class TestWorkerManager:
|
|||
manager = WorkerManager(specs)
|
||||
await manager.start()
|
||||
|
||||
initial_task_count = len(manager._tasks)
|
||||
|
||||
await manager.stop()
|
||||
|
||||
assert manager._stop.is_set()
|
||||
|
|
@ -123,10 +127,12 @@ class TestWorkerManager:
|
|||
"""
|
||||
specs = [WorkerSpec(queue="test", concurrency=1)]
|
||||
|
||||
with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
|
||||
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, \
|
||||
patch("dataloader.workers.manager.PGWorker") as mock_worker_cls, \
|
||||
patch("dataloader.workers.manager.requeue_lost") as mock_requeue:
|
||||
with (
|
||||
patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
|
||||
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
|
||||
patch("dataloader.workers.manager.PGWorker") as mock_worker_cls,
|
||||
patch("dataloader.workers.manager.requeue_lost") as mock_requeue,
|
||||
):
|
||||
|
||||
mock_logger = Mock()
|
||||
mock_ctx.get_logger.return_value = mock_logger
|
||||
|
|
@ -161,10 +167,12 @@ class TestWorkerManager:
|
|||
"""
|
||||
specs = [WorkerSpec(queue="test", concurrency=1)]
|
||||
|
||||
with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
|
||||
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, \
|
||||
patch("dataloader.workers.manager.PGWorker") as mock_worker_cls, \
|
||||
patch("dataloader.workers.manager.requeue_lost") as mock_requeue:
|
||||
with (
|
||||
patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
|
||||
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
|
||||
patch("dataloader.workers.manager.PGWorker") as mock_worker_cls,
|
||||
patch("dataloader.workers.manager.requeue_lost") as mock_requeue,
|
||||
):
|
||||
|
||||
mock_logger = Mock()
|
||||
mock_ctx.get_logger.return_value = mock_logger
|
||||
|
|
@ -203,8 +211,10 @@ class TestBuildManagerFromEnv:
|
|||
"""
|
||||
Тест создания менеджера из конфигурации.
|
||||
"""
|
||||
with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
|
||||
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg:
|
||||
with (
|
||||
patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
|
||||
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
|
||||
):
|
||||
|
||||
mock_ctx.get_logger.return_value = Mock()
|
||||
mock_cfg.worker.parsed_workers.return_value = [
|
||||
|
|
@ -224,8 +234,10 @@ class TestBuildManagerFromEnv:
|
|||
"""
|
||||
Тест, что пустые имена очередей пропускаются.
|
||||
"""
|
||||
with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
|
||||
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg:
|
||||
with (
|
||||
patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
|
||||
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
|
||||
):
|
||||
|
||||
mock_ctx.get_logger.return_value = Mock()
|
||||
mock_cfg.worker.parsed_workers.return_value = [
|
||||
|
|
@ -243,8 +255,10 @@ class TestBuildManagerFromEnv:
|
|||
"""
|
||||
Тест обработки отсутствующих полей с дефолтными значениями.
|
||||
"""
|
||||
with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
|
||||
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg:
|
||||
with (
|
||||
patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
|
||||
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
|
||||
):
|
||||
|
||||
mock_ctx.get_logger.return_value = Mock()
|
||||
mock_cfg.worker.parsed_workers.return_value = [
|
||||
|
|
@ -262,8 +276,10 @@ class TestBuildManagerFromEnv:
|
|||
"""
|
||||
Тест, что concurrency всегда минимум 1.
|
||||
"""
|
||||
with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
|
||||
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg:
|
||||
with (
|
||||
patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
|
||||
patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
|
||||
):
|
||||
|
||||
mock_ctx.get_logger.return_value = Mock()
|
||||
mock_cfg.worker.parsed_workers.return_value = [
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
# tests/unit/test_workers_reaper.py
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, Mock
|
||||
|
||||
from dataloader.workers.reaper import requeue_lost
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue