refactor: apply Black/isort/Ruff formatting and add their configs to pyproject.toml
parent c907e1d4da
commit bde0bb0e6f
@@ -34,3 +34,37 @@ dataloader = "dataloader.__main__:main"
 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
+
+[tool.black]
+line-length = 88
+target-version = ["py311"]
+skip-string-normalization = false
+preview = true
+
+[tool.isort]
+profile = "black"
+line_length = 88
+multi_line_output = 3
+include_trailing_comma = true
+force_grid_wrap = 0
+combine_as_imports = true
+known_first_party = ["dataloader"]
+src_paths = ["src"]
+
+[tool.ruff]
+line-length = 88
+target-version = "py311"
+fix = true
+lint.select = ["E", "F", "W", "I", "N", "B"]
+lint.ignore = [
+    "E501",
+]
+exclude = [
+    ".git",
+    "__pycache__",
+    "build",
+    "dist",
+    ".venv",
+    "venv",
+]
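Note: the block above pins Black, isort, and Ruff to a shared line length of 88 and a Python 3.11 target, with isort following Black's profile. As a rough illustration only (it assumes black and isort are importable in the project's Poetry environment; the snippet is not part of the commit), the same settings can be exercised programmatically:

# Illustrative sketch, not part of the diff: format a snippet with the same
# line length and import style the new pyproject.toml sections configure.
import black
import isort

SRC = "import sys,os\nx=[1,2,3]\n"

formatted = black.format_str(
    SRC,
    mode=black.Mode(line_length=88, target_versions={black.TargetVersion.PY311}),
)
print(isort.code(formatted, profile="black", line_length=88))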
@@ -5,7 +5,6 @@
 
 from . import api
 
-
 __all__ = [
     "api",
 ]
@@ -1,13 +1,15 @@
 import uvicorn
 
 from dataloader.api import app_main
 from dataloader.config import APP_CONFIG
-from dataloader.logger.uvicorn_logging_config import LOGGING_CONFIG, setup_uvicorn_logging
+from dataloader.logger.uvicorn_logging_config import (
+    LOGGING_CONFIG,
+    setup_uvicorn_logging,
+)
 
 
 def main() -> None:
-    # Инициализируем логирование uvicorn перед запуском
     setup_uvicorn_logging()
 
     uvicorn.run(
@@ -1,20 +1,19 @@
-# src/dataloader/api/__init__.py
 from __future__ import annotations
 
-from collections.abc import AsyncGenerator
 import contextlib
 import typing as tp
+from collections.abc import AsyncGenerator
 
 from fastapi import FastAPI
 
+from dataloader.context import APP_CTX
+from dataloader.workers.manager import WorkerManager, build_manager_from_env
+from dataloader.workers.pipelines import load_all as load_pipelines
+
 from .metric_router import router as metric_router
 from .middleware import log_requests
 from .os_router import router as service_router
 from .v1 import router as v1_router
-from dataloader.context import APP_CTX
-from dataloader.workers.manager import build_manager_from_env, WorkerManager
-from dataloader.workers.pipelines import load_all as load_pipelines
 
 
 _manager: WorkerManager | None = None
@@ -1,10 +1,11 @@
-""" 🚨 НЕ РЕДАКТИРОВАТЬ !!!!!!
-"""
+"""🚨 НЕ РЕДАКТИРОВАТЬ !!!!!!"""
+
 import uuid
 
 from fastapi import APIRouter, Header, status
 
 from dataloader.context import APP_CTX
 
 from . import schemas
 
 router = APIRouter()
@@ -17,8 +18,7 @@ logger = APP_CTX.get_logger()
     response_model=schemas.RateResponse,
 )
 async def like(
-    # pylint: disable=C0103,W0613
-    header_Request_Id: str = Header(uuid.uuid4(), alias="Request-Id")
+    header_request_id: str = Header(uuid.uuid4(), alias="Request-Id")
 ) -> dict[str, str]:
     logger.metric(
         metric_name="dataloader_likes_total",
@@ -33,8 +33,7 @@ async def like(
     response_model=schemas.RateResponse,
 )
 async def dislike(
-    # pylint: disable=C0103,W0613
-    header_Request_Id: str = Header(uuid.uuid4(), alias="Request-Id")
+    header_request_id: str = Header(uuid.uuid4(), alias="Request-Id")
 ) -> dict[str, str]:
     logger.metric(
         metric_name="dataloader_dislikes_total",
@@ -38,8 +38,14 @@ async def log_requests(request: Request, call_next) -> any:
     logger = APP_CTX.get_logger()
     request_path = request.url.path
 
-    allowed_headers_to_log = ((k, request.headers.get(k)) for k in HEADERS_WHITE_LIST_TO_LOG)
-    headers_to_log = {header_name: header_value for header_name, header_value in allowed_headers_to_log if header_value}
+    allowed_headers_to_log = (
+        (k, request.headers.get(k)) for k in HEADERS_WHITE_LIST_TO_LOG
+    )
+    headers_to_log = {
+        header_name: header_value
+        for header_name, header_value in allowed_headers_to_log
+        if header_value
+    }
 
     APP_CTX.get_context_vars_container().set_context_vars(
         request_id=headers_to_log.get("Request-Id", ""),
@@ -50,7 +56,9 @@ async def log_requests(request: Request, call_next) -> any:
 
     if request_path in NON_LOGGED_ENDPOINTS:
         response = await call_next(request)
-        logger.debug(f"Processed request for {request_path} with code {response.status_code}")
+        logger.debug(
+            f"Processed request for {request_path} with code {response.status_code}"
+        )
     elif headers_to_log.get("Request-Id", None):
         raw_request_body = await request.body()
         request_body_decoded = _get_decoded_body(raw_request_body, "request", logger)
@@ -80,7 +88,7 @@ async def log_requests(request: Request, call_next) -> any:
             event_params=json.dumps(
                 request_body_decoded,
                 ensure_ascii=False,
-            )
+            ),
         )
 
         response = await call_next(request)
@@ -88,12 +96,16 @@ async def log_requests(request: Request, call_next) -> any:
         response_body = [chunk async for chunk in response.body_iterator]
         response.body_iterator = iterate_in_threadpool(iter(response_body))
 
-        headers_to_log["Response-Time"] = datetime.now(APP_CTX.get_pytz_timezone()).isoformat()
+        headers_to_log["Response-Time"] = datetime.now(
+            APP_CTX.get_pytz_timezone()
+        ).isoformat()
         for header in headers_to_log:
             response.headers[header] = headers_to_log[header]
 
         response_body_extracted = response_body[0] if len(response_body) > 0 else b""
-        decoded_response_body = _get_decoded_body(response_body_extracted, "response", logger)
+        decoded_response_body = _get_decoded_body(
+            response_body_extracted, "response", logger
+        )
 
         logger.info(
             "Outgoing response to client system",
@@ -115,11 +127,13 @@ async def log_requests(request: Request, call_next) -> any:
             event_params=json.dumps(
                 decoded_response_body,
                 ensure_ascii=False,
-            )
+            ),
         )
 
         processing_time_ms = int(round((time.time() - start_time), 3) * 1000)
-        logger.info(f"Request processing time for {request_path}: {processing_time_ms} ms")
+        logger.info(
+            f"Request processing time for {request_path}: {processing_time_ms} ms"
+        )
         logger.metric(
             metric_name="dataloader_process_duration_ms",
             metric_value=processing_time_ms,
@@ -138,7 +152,9 @@ async def log_requests(request: Request, call_next) -> any:
     else:
         logger.info(f"Incoming {request.method}-request with no id for {request_path}")
         response = await call_next(request)
-        logger.info(f"Request with no id for {request_path} processing time: {time.time() - start_time:.3f} s")
+        logger.info(
+            f"Request with no id for {request_path} processing time: {time.time() - start_time:.3f} s"
+        )
 
     return response
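For orientation only (this sketch is not part of the commit): log_requests above is the HTTP middleware that dataloader.api imports, so attaching it to an application would, under that assumption, look roughly like this:

# Hedged sketch, assuming dataloader.api.middleware exposes log_requests as shown above.
from fastapi import FastAPI

from dataloader.api.middleware import log_requests

app = FastAPI()
app.middleware("http")(log_requests)  # every request/response passes through the header/body logging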
@@ -1,6 +1,4 @@
-# Инфраструктурные endpoint'ы (/health, /status)
-""" 🚨 НЕ РЕДАКТИРОВАТЬ !!!!!!
-"""
+"""🚨 НЕ РЕДАКТИРОВАТЬ !!!!!!"""
 
 from importlib.metadata import distribution
 
@@ -3,22 +3,26 @@ from pydantic import BaseModel, ConfigDict, Field
 
 class HealthResponse(BaseModel):
     """Ответ для ручки /health"""
-    model_config = ConfigDict(
-        json_schema_extra={"example": {"status": "running"}}
-    )
+
+    model_config = ConfigDict(json_schema_extra={"example": {"status": "running"}})
 
-    status: str = Field(default="running", description="Service health check", max_length=7)
+    status: str = Field(
+        default="running", description="Service health check", max_length=7
+    )
 
 
 class InfoResponse(BaseModel):
     """Ответ для ручки /info"""
+
     model_config = ConfigDict(
         json_schema_extra={
             "example": {
                 "name": "rest-template",
-                "description": "Python 'AI gateway' template for developing REST microservices",
+                "description": (
+                    "Python 'AI gateway' template for developing REST microservices"
+                ),
                 "type": "REST API",
-                "version": "0.1.0"
+                "version": "0.1.0",
             }
         }
     )
@@ -26,11 +30,14 @@ class InfoResponse(BaseModel):
     name: str = Field(description="Service name", max_length=50)
     description: str = Field(description="Service description", max_length=200)
     type: str = Field(default="REST API", description="Service type", max_length=20)
-    version: str = Field(description="Service version", max_length=20, pattern=r"^\d+\.\d+\.\d+")
+    version: str = Field(
+        description="Service version", max_length=20, pattern=r"^\d+\.\d+\.\d+"
+    )
 
 
 class RateResponse(BaseModel):
     """Ответ для записи рейтинга"""
+
     model_config = ConfigDict(str_strip_whitespace=True)
 
     rating_result: str = Field(description="Rating that was recorded", max_length=50)
@@ -8,9 +8,5 @@ class JobNotFoundError(HTTPException):
 
     def __init__(self, job_id: str):
         super().__init__(
-            status_code=status.HTTP_404_NOT_FOUND,
-            detail=f"Job {job_id} not found"
+            status_code=status.HTTP_404_NOT_FOUND, detail=f"Job {job_id} not found"
         )
-
-
-
@@ -1,2 +1 @@
 """Модели данных для API v1."""
-
@@ -1,12 +1,11 @@
-# src/dataloader/api/v1/router.py
 from __future__ import annotations
 
-from collections.abc import AsyncGenerator
 from http import HTTPStatus
 from typing import Annotated
 from uuid import UUID
 
 from fastapi import APIRouter, Depends
+from sqlalchemy.ext.asyncio import AsyncSession
 
 from dataloader.api.v1.exceptions import JobNotFoundError
 from dataloader.api.v1.schemas import (
@@ -16,8 +15,6 @@ from dataloader.api.v1.schemas import (
 )
 from dataloader.api.v1.service import JobsService
 from dataloader.context import get_session
-from sqlalchemy.ext.asyncio import AsyncSession
-
 
 
 router = APIRouter(prefix="/jobs", tags=["jobs"])
@@ -40,7 +37,9 @@ async def trigger_job(
     return await svc.trigger(payload)
 
 
-@router.get("/{job_id}/status", response_model=JobStatusResponse, status_code=HTTPStatus.OK)
+@router.get(
+    "/{job_id}/status", response_model=JobStatusResponse, status_code=HTTPStatus.OK
+)
 async def get_status(
     job_id: UUID,
     svc: Annotated[JobsService, Depends(get_service)],
@@ -54,7 +53,9 @@ async def get_status(
     return st
 
 
-@router.post("/{job_id}/cancel", response_model=JobStatusResponse, status_code=HTTPStatus.OK)
+@router.post(
+    "/{job_id}/cancel", response_model=JobStatusResponse, status_code=HTTPStatus.OK
+)
 async def cancel_job(
     job_id: UUID,
     svc: Annotated[JobsService, Depends(get_service)],
@@ -1,4 +1,3 @@
-# src/dataloader/api/v1/schemas.py
 from __future__ import annotations
 
 from datetime import datetime, timezone
@@ -12,6 +11,7 @@ class TriggerJobRequest(BaseModel):
     """
     Запрос на постановку задачи в очередь.
     """
+
     model_config = ConfigDict(str_strip_whitespace=True)
 
     queue: str = Field(...)
@@ -39,6 +39,7 @@ class TriggerJobResponse(BaseModel):
     """
     Ответ на постановку задачи.
     """
+
     model_config = ConfigDict(str_strip_whitespace=True)
 
     job_id: UUID = Field(...)
@@ -49,6 +50,7 @@ class JobStatusResponse(BaseModel):
     """
     Текущий статус задачи.
     """
+
     model_config = ConfigDict(str_strip_whitespace=True)
 
     job_id: UUID = Field(...)
@@ -65,6 +67,7 @@ class CancelJobResponse(BaseModel):
     """
     Ответ на запрос отмены задачи.
    """
+
     model_config = ConfigDict(str_strip_whitespace=True)
 
     job_id: UUID = Field(...)
@@ -1,4 +1,3 @@
-# src/dataloader/api/v1/service.py
 from __future__ import annotations
 
 from datetime import datetime, timezone
@@ -13,15 +12,16 @@ from dataloader.api.v1.schemas import (
     TriggerJobResponse,
 )
 from dataloader.api.v1.utils import new_job_id
-from dataloader.storage.schemas import CreateJobRequest
-from dataloader.storage.repositories import QueueRepository
 from dataloader.logger.logger import get_logger
+from dataloader.storage.repositories import QueueRepository
+from dataloader.storage.schemas import CreateJobRequest
 
 
 class JobsService:
     """
     Бизнес-логика работы с очередью задач.
     """
+
     def __init__(self, session: AsyncSession):
         self._s = session
         self._repo = QueueRepository(self._s)
@@ -1,4 +1,3 @@
-# src/dataloader/api/v1/utils.py
 from __future__ import annotations
 
 from uuid import UUID, uuid4
@@ -1,6 +1,5 @@
-# src/dataloader/config.py
-import os
 import json
+import os
 from logging import DEBUG, INFO
 from typing import Annotated, Any
 
@@ -32,6 +31,7 @@ class BaseAppSettings(BaseSettings):
     """
     Базовый класс для настроек.
     """
+
     local: bool = Field(validation_alias="LOCAL", default=False)
     debug: bool = Field(validation_alias="DEBUG", default=False)
 
@@ -44,6 +44,7 @@ class AppSettings(BaseAppSettings):
     """
     Настройки приложения.
     """
+
     app_host: str = Field(validation_alias="APP_HOST", default="0.0.0.0")
     app_port: int = Field(validation_alias="APP_PORT", default=8081)
     kube_net_name: str = Field(validation_alias="PROJECT_NAME", default="AIGATEWAY")
@@ -54,15 +55,28 @@ class LogSettings(BaseAppSettings):
     """
     Настройки логирования.
     """
+
     private_log_file_path: str = Field(validation_alias="LOG_PATH", default=os.getcwd())
-    private_log_file_name: str = Field(validation_alias="LOG_FILE_NAME", default="app.log")
+    private_log_file_name: str = Field(
+        validation_alias="LOG_FILE_NAME", default="app.log"
+    )
     log_rotation: str = Field(validation_alias="LOG_ROTATION", default="10 MB")
-    private_metric_file_path: str = Field(validation_alias="METRIC_PATH", default=os.getcwd())
-    private_metric_file_name: str = Field(validation_alias="METRIC_FILE_NAME", default="app-metric.log")
-    private_audit_file_path: str = Field(validation_alias="AUDIT_LOG_PATH", default=os.getcwd())
-    private_audit_file_name: str = Field(validation_alias="AUDIT_LOG_FILE_NAME", default="events.log")
+    private_metric_file_path: str = Field(
+        validation_alias="METRIC_PATH", default=os.getcwd()
+    )
+    private_metric_file_name: str = Field(
+        validation_alias="METRIC_FILE_NAME", default="app-metric.log"
+    )
+    private_audit_file_path: str = Field(
+        validation_alias="AUDIT_LOG_PATH", default=os.getcwd()
+    )
+    private_audit_file_name: str = Field(
+        validation_alias="AUDIT_LOG_FILE_NAME", default="events.log"
+    )
     audit_host_ip: str = Field(validation_alias="HOST_IP", default="127.0.0.1")
-    audit_host_uid: str = Field(validation_alias="HOST_UID", default="63b6dcee-170b-49bf-a65c-3ec967398ccd")
+    audit_host_uid: str = Field(
+        validation_alias="HOST_UID", default="63b6dcee-170b-49bf-a65c-3ec967398ccd"
+    )
 
     @staticmethod
     def get_file_abs_path(path_name: str, file_name: str) -> str:
@@ -70,15 +84,21 @@
 
     @property
     def log_file_abs_path(self) -> str:
-        return self.get_file_abs_path(self.private_log_file_path, self.private_log_file_name)
+        return self.get_file_abs_path(
+            self.private_log_file_path, self.private_log_file_name
+        )
 
     @property
     def metric_file_abs_path(self) -> str:
-        return self.get_file_abs_path(self.private_metric_file_path, self.private_metric_file_name)
+        return self.get_file_abs_path(
+            self.private_metric_file_path, self.private_metric_file_name
+        )
 
     @property
     def audit_file_abs_path(self) -> str:
-        return self.get_file_abs_path(self.private_audit_file_path, self.private_audit_file_name)
+        return self.get_file_abs_path(
+            self.private_audit_file_path, self.private_audit_file_name
+        )
 
     @property
     def log_lvl(self) -> int:
@@ -89,6 +109,7 @@ class PGSettings(BaseSettings):
     """
     Настройки подключения к Postgres.
     """
+
     host: str = Field(validation_alias="PG_HOST", default="localhost")
     port: int = Field(validation_alias="PG_PORT", default=5432)
     user: str = Field(validation_alias="PG_USER", default="postgres")
@@ -116,9 +137,12 @@ class WorkerSettings(BaseSettings):
     """
     Настройки очереди и воркеров.
     """
+
     workers_json: str = Field(validation_alias="WORKERS_JSON", default="[]")
     heartbeat_sec: int = Field(validation_alias="DL_HEARTBEAT_SEC", default=10)
-    default_lease_ttl_sec: int = Field(validation_alias="DL_DEFAULT_LEASE_TTL_SEC", default=60)
+    default_lease_ttl_sec: int = Field(
+        validation_alias="DL_DEFAULT_LEASE_TTL_SEC", default=60
+    )
     reaper_period_sec: int = Field(validation_alias="DL_REAPER_PERIOD_SEC", default=10)
     claim_backoff_sec: int = Field(validation_alias="DL_CLAIM_BACKOFF_SEC", default=15)
 
@@ -137,6 +161,7 @@ class CertsSettings(BaseSettings):
     """
     Настройки SSL сертификатов для локальной разработки.
     """
+
     ca_bundle_file: str = Field(validation_alias="CA_BUNDLE_FILE", default="")
     cert_file: str = Field(validation_alias="CERT_FILE", default="")
     key_file: str = Field(validation_alias="KEY_FILE", default="")
@@ -146,21 +171,24 @@ class SuperTeneraSettings(BaseAppSettings):
     """
     Настройки интеграции с SuperTenera.
     """
+
     host: Annotated[str, BeforeValidator(strip_slashes)] = Field(
         validation_alias="SUPERTENERA_HOST",
-        default="ci03801737-ift-tenera-giga.delta.sbrf.ru/atlant360bc/"
+        default="ci03801737-ift-tenera-giga.delta.sbrf.ru/atlant360bc/",
     )
     port: str = Field(validation_alias="SUPERTENERA_PORT", default="443")
     quotes_endpoint: Annotated[str, BeforeValidator(strip_slashes)] = Field(
         validation_alias="SUPERTENERA_QUOTES_ENDPOINT",
-        default="/get_gigaparser_quotes/"
+        default="/get_gigaparser_quotes/",
     )
     timeout: int = Field(validation_alias="SUPERTENERA_TIMEOUT", default=20)
 
     @property
     def base_url(self) -> str:
         """Возвращает абсолютный URL для SuperTenera."""
-        domain, raw_path = self.host.split("/", 1) if "/" in self.host else (self.host, "")
+        domain, raw_path = (
+            self.host.split("/", 1) if "/" in self.host else (self.host, "")
+        )
         return build_url(self.protocol, domain, self.port, raw_path)
@@ -168,22 +196,21 @@ class Gmap2BriefSettings(BaseAppSettings):
     """
     Настройки интеграции с Gmap2Brief (OPU API).
     """
+
     host: Annotated[str, BeforeValidator(strip_slashes)] = Field(
         validation_alias="GMAP2BRIEF_HOST",
-        default="ci02533826-tib-brief.apps.ift-terra000024-edm.ocp.delta.sbrf.ru"
+        default="ci02533826-tib-brief.apps.ift-terra000024-edm.ocp.delta.sbrf.ru",
     )
     port: str = Field(validation_alias="GMAP2BRIEF_PORT", default="443")
     start_endpoint: Annotated[str, BeforeValidator(strip_slashes)] = Field(
-        validation_alias="GMAP2BRIEF_START_ENDPOINT",
-        default="/export/opu/start"
+        validation_alias="GMAP2BRIEF_START_ENDPOINT", default="/export/opu/start"
     )
     status_endpoint: Annotated[str, BeforeValidator(strip_slashes)] = Field(
-        validation_alias="GMAP2BRIEF_STATUS_ENDPOINT",
-        default="/export/{job_id}/status"
+        validation_alias="GMAP2BRIEF_STATUS_ENDPOINT", default="/export/{job_id}/status"
     )
     download_endpoint: Annotated[str, BeforeValidator(strip_slashes)] = Field(
         validation_alias="GMAP2BRIEF_DOWNLOAD_ENDPOINT",
-        default="/export/{job_id}/download"
+        default="/export/{job_id}/download",
     )
     poll_interval: int = Field(validation_alias="GMAP2BRIEF_POLL_INTERVAL", default=2)
     timeout: int = Field(validation_alias="GMAP2BRIEF_TIMEOUT", default=3600)
@@ -198,6 +225,7 @@ class Secrets:
     """
     Агрегатор настроек приложения.
     """
+
     app: AppSettings = AppSettings()
     log: LogSettings = LogSettings()
     pg: PGSettings = PGSettings()
@@ -1,8 +1,7 @@
-# src/dataloader/context.py
 from __future__ import annotations
 
-from typing import AsyncGenerator
 from logging import Logger
+from typing import AsyncGenerator
 from zoneinfo import ZoneInfo
 
 from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
@@ -15,6 +14,7 @@ class AppContext:
     """
     Контекст приложения, хранящий глобальные зависимости (Singleton pattern).
     """
+
     def __init__(self) -> None:
         self._engine: AsyncEngine | None = None
         self._sessionmaker: async_sessionmaker[AsyncSession] | None = None
@@ -75,6 +75,7 @@
         Экземпляр Logger
         """
         from .logger.logger import get_logger as get_app_logger
+
         return get_app_logger(name)
 
     def get_context_vars_container(self) -> ContextVarsContainer:
@@ -1,2 +1 @@
 """Исключения уровня приложения."""
-
@@ -10,7 +10,10 @@ from typing import TYPE_CHECKING
 import httpx
 
 from dataloader.config import APP_CONFIG
-from dataloader.interfaces.gmap2_brief.schemas import ExportJobStatus, StartExportResponse
+from dataloader.interfaces.gmap2_brief.schemas import (
+    ExportJobStatus,
+    StartExportResponse,
+)
 
 if TYPE_CHECKING:
     from logging import Logger
@@ -37,10 +40,11 @@ class Gmap2BriefInterface:
 
         self._ssl_context = None
         if APP_CONFIG.app.local and APP_CONFIG.certs.cert_file:
-            self._ssl_context = ssl.create_default_context(cafile=APP_CONFIG.certs.ca_bundle_file)
+            self._ssl_context = ssl.create_default_context(
+                cafile=APP_CONFIG.certs.ca_bundle_file
+            )
             self._ssl_context.load_cert_chain(
-                certfile=APP_CONFIG.certs.cert_file,
-                keyfile=APP_CONFIG.certs.key_file
+                certfile=APP_CONFIG.certs.cert_file, keyfile=APP_CONFIG.certs.key_file
             )
 
     async def start_export(self) -> str:
@@ -56,9 +60,13 @@
         url = self.base_url + APP_CONFIG.gmap2brief.start_endpoint
 
         async with httpx.AsyncClient(
-            cert=(APP_CONFIG.certs.cert_file, APP_CONFIG.certs.key_file) if APP_CONFIG.app.local else None,
+            cert=(
+                (APP_CONFIG.certs.cert_file, APP_CONFIG.certs.key_file)
+                if APP_CONFIG.app.local
+                else None
+            ),
             verify=APP_CONFIG.certs.ca_bundle_file if APP_CONFIG.app.local else True,
-            timeout=self.timeout
+            timeout=self.timeout,
         ) as client:
             try:
                 self.logger.info(f"Starting OPU export: POST {url}")
@@ -87,12 +95,18 @@
         Raises:
             Gmap2BriefConnectionError: При ошибке запроса
         """
-        url = self.base_url + APP_CONFIG.gmap2brief.status_endpoint.format(job_id=job_id)
+        url = self.base_url + APP_CONFIG.gmap2brief.status_endpoint.format(
+            job_id=job_id
+        )
 
         async with httpx.AsyncClient(
-            cert=(APP_CONFIG.certs.cert_file, APP_CONFIG.certs.key_file) if APP_CONFIG.app.local else None,
+            cert=(
+                (APP_CONFIG.certs.cert_file, APP_CONFIG.certs.key_file)
+                if APP_CONFIG.app.local
+                else None
+            ),
             verify=APP_CONFIG.certs.ca_bundle_file if APP_CONFIG.app.local else True,
-            timeout=self.timeout
+            timeout=self.timeout,
         ) as client:
             try:
                 response = await client.get(url)
@@ -106,7 +120,9 @@
             except httpx.RequestError as e:
                 raise Gmap2BriefConnectionError(f"Request error: {e}") from e
 
-    async def wait_for_completion(self, job_id: str, max_wait: int | None = None) -> ExportJobStatus:
+    async def wait_for_completion(
+        self, job_id: str, max_wait: int | None = None
+    ) -> ExportJobStatus:
         """
         Ждет завершения задачи экспорта с периодическим polling.
 
@@ -127,14 +143,18 @@
             status = await self.get_status(job_id)
 
             if status.status == "completed":
-                self.logger.info(f"Job {job_id} completed, total_rows={status.total_rows}")
+                self.logger.info(
+                    f"Job {job_id} completed, total_rows={status.total_rows}"
+                )
                 return status
             elif status.status == "failed":
                 raise Gmap2BriefConnectionError(f"Job {job_id} failed: {status.error}")
 
             elapsed = asyncio.get_event_loop().time() - start_time
             if max_wait and elapsed > max_wait:
-                raise Gmap2BriefConnectionError(f"Job {job_id} timeout after {elapsed:.1f}s")
+                raise Gmap2BriefConnectionError(
+                    f"Job {job_id} timeout after {elapsed:.1f}s"
+                )
 
             self.logger.debug(
                 f"Job {job_id} status={status.status}, rows={status.total_rows}, elapsed={elapsed:.1f}s"
@@ -155,12 +175,18 @@
         Raises:
             Gmap2BriefConnectionError: При ошибке скачивания
         """
-        url = self.base_url + APP_CONFIG.gmap2brief.download_endpoint.format(job_id=job_id)
+        url = self.base_url + APP_CONFIG.gmap2brief.download_endpoint.format(
+            job_id=job_id
+        )
 
         async with httpx.AsyncClient(
-            cert=(APP_CONFIG.certs.cert_file, APP_CONFIG.certs.key_file) if APP_CONFIG.app.local else None,
+            cert=(
+                (APP_CONFIG.certs.cert_file, APP_CONFIG.certs.key_file)
+                if APP_CONFIG.app.local
+                else None
+            ),
             verify=APP_CONFIG.certs.ca_bundle_file if APP_CONFIG.app.local else True,
-            timeout=self.timeout
+            timeout=self.timeout,
         ) as client:
             try:
                 self.logger.info(f"Downloading export: GET {url}")
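Taken together, the methods reformatted above form a start/poll/download flow. A hypothetical usage sketch, not part of the commit (the no-argument constructor and the interface module path are assumptions; only the method names and settings come from these hunks):

# Hypothetical sketch: drive the OPU export client end to end.
import asyncio

from dataloader.interfaces.gmap2_brief.interface import Gmap2BriefInterface  # assumed module name


async def run_export() -> None:
    client = Gmap2BriefInterface()        # assumed no-arg construction
    job_id = await client.start_export()  # POST .../export/opu/start
    status = await client.wait_for_completion(job_id, max_wait=3600)
    print(status.total_rows, status.temp_file_path)


asyncio.run(run_export())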
@@ -17,7 +17,11 @@ class ExportJobStatus(BaseModel):
     """Статус задачи экспорта."""
 
     job_id: str = Field(..., description="Идентификатор задачи")
-    status: Literal["pending", "running", "completed", "failed"] = Field(..., description="Статус задачи")
+    status: Literal["pending", "running", "completed", "failed"] = Field(
+        ..., description="Статус задачи"
+    )
     total_rows: int = Field(default=0, description="Количество обработанных строк")
     error: str | None = Field(default=None, description="Текст ошибки (если есть)")
-    temp_file_path: str | None = Field(default=None, description="Путь к временному файлу (для completed)")
+    temp_file_path: str | None = Field(
+        default=None, description="Путь к временному файлу (для completed)"
+    )
@@ -47,8 +47,12 @@ class SuperTeneraInterface:
 
         self._ssl_context = None
         if APP_CONFIG.app.local:
-            self._ssl_context = ssl.create_default_context(cafile=APP_CONFIG.certs.ca_bundle_file)
-            self._ssl_context.load_cert_chain(certfile=APP_CONFIG.certs.cert_file, keyfile=APP_CONFIG.certs.key_file)
+            self._ssl_context = ssl.create_default_context(
+                cafile=APP_CONFIG.certs.ca_bundle_file
+            )
+            self._ssl_context.load_cert_chain(
+                certfile=APP_CONFIG.certs.cert_file, keyfile=APP_CONFIG.certs.key_file
+            )
 
     def form_base_headers(self) -> dict:
         """Формирует базовые заголовки для запроса."""
@@ -57,7 +61,11 @@
             "request-time": str(datetime.now(tz=self.timezone).isoformat()),
             "system-id": APP_CONFIG.app.kube_net_name,
         }
-        return {metakey: metavalue for metakey, metavalue in metadata_pairs.items() if metavalue}
+        return {
+            metakey: metavalue
+            for metakey, metavalue in metadata_pairs.items()
+            if metavalue
+        }
 
     async def __aenter__(self) -> Self:
         """Async context manager enter."""
@@ -86,7 +94,9 @@
 
         async with self._session.get(url, **kwargs) as response:
             if APP_CONFIG.app.debug:
-                self.logger.debug(f"Response: {(await response.text(errors='ignore'))[:100]}")
+                self.logger.debug(
+                    f"Response: {(await response.text(errors='ignore'))[:100]}"
+                )
             return await response.json(encoding=encoding, content_type=content_type)
 
     async def get_quotes_data(self) -> MainData:
@@ -103,7 +113,9 @@
             kwargs["ssl"] = self._ssl_context
         try:
             async with self._session.get(
-                APP_CONFIG.supertenera.quotes_endpoint, timeout=APP_CONFIG.supertenera.timeout, **kwargs
+                APP_CONFIG.supertenera.quotes_endpoint,
+                timeout=APP_CONFIG.supertenera.timeout,
+                **kwargs,
             ) as resp:
                 resp.raise_for_status()
                 return True
@@ -112,7 +124,9 @@
                 f"Ошибка подключения к SuperTenera API при проверке системы - {e.status}."
             ) from e
         except TimeoutError as e:
-            raise SuperTeneraConnectionError("Ошибка Timeout подключения к SuperTenera API при проверке системы.") from e
+            raise SuperTeneraConnectionError(
+                "Ошибка Timeout подключения к SuperTenera API при проверке системы."
+            ) from e
 
 
 def get_async_tenera_interface() -> SuperTeneraInterface:
@@ -77,7 +77,9 @@ class InvestingCandlestick(TeneraBaseModel):
     value: str = Field(alias="V")
 
 
-class InvestingTimePoint(RootModel[EmptyTimePoint | InvestingNumeric | InvestingCandlestick]):
+class InvestingTimePoint(
+    RootModel[EmptyTimePoint | InvestingNumeric | InvestingCandlestick]
+):
     """
     Union-модель точки времени для источника Investing.com.
 
@@ -340,5 +342,9 @@ class MainData(TeneraBaseModel):
         :return: Отфильтрованный объект
         """
         if isinstance(v, dict):
-            return {key: value for key, value in v.items() if value is not None and not str(key).isdigit()}
+            return {
+                key: value
+                for key, value in v.items()
+                if value is not None and not str(key).isdigit()
+            }
         return v
@@ -1,4 +1,3 @@
-# src/dataloader/logger/__init__.py
-from .logger import setup_logging, get_logger
+from .logger import get_logger, setup_logging
 
 __all__ = ["setup_logging", "get_logger"]
@@ -1,9 +1,7 @@
-# Управление контекстом запросов для логирования
 import uuid
 from contextvars import ContextVar
 from typing import Final
 
-
 REQUEST_ID_CTX_VAR: Final[ContextVar[str]] = ContextVar("request_id", default="")
 DEVICE_ID_CTX_VAR: Final[ContextVar[str]] = ContextVar("device_id", default="")
 SESSION_ID_CTX_VAR: Final[ContextVar[str]] = ContextVar("session_id", default="")
@@ -66,7 +64,13 @@ class ContextVarsContainer:
     def gw_session_id(self, value: str) -> None:
         GW_SESSION_ID_CTX_VAR.set(value)
 
-    def set_context_vars(self, request_id: str = "", request_time: str = "", system_id: str = "", gw_session_id: str = "") -> None:
+    def set_context_vars(
+        self,
+        request_id: str = "",
+        request_time: str = "",
+        system_id: str = "",
+        gw_session_id: str = "",
+    ) -> None:
         if request_id:
             self.set_request_id(request_id)
         if request_time:
@@ -1,17 +1,12 @@
-# src/dataloader/logger/logger.py
-import sys
-import typing
-from datetime import tzinfo
-from logging import Logger
 import logging
+import sys
+from logging import Logger
 
 from loguru import logger
 
-from .context_vars import ContextVarsContainer
 from dataloader.config import APP_CONFIG
 
 
-# Определяем фильтры для разных типов логов
 def metric_only_filter(record: dict) -> bool:
     return "metric" in record["extra"]
@@ -26,13 +21,12 @@ def regular_log_filter(record: dict) -> bool:
 
 class InterceptHandler(logging.Handler):
     def emit(self, record: logging.LogRecord) -> None:
-        # Get corresponding Loguru level if it exists
         try:
             level = logger.level(record.levelname).name
         except ValueError:
             level = record.levelno
 
-        # Find caller from where originated the logged message
         frame, depth = logging.currentframe(), 2
         while frame.f_code.co_filename == logging.__file__:
             frame = frame.f_back
@@ -1 +1 @@
-# Модели логов, метрик, событий аудита
+
@@ -1 +1 @@
-# Функции + маскирование args
+
@@ -1,34 +1,28 @@
-# Конфигурация логирования uvicorn
 import logging
-import sys
 
 from loguru import logger
 
 
 class InterceptHandler(logging.Handler):
     def emit(self, record: logging.LogRecord) -> None:
-        # Get corresponding Loguru level if it exists
         try:
             level = logger.level(record.levelname).name
         except ValueError:
             level = record.levelno
 
-        # Find caller from where originated the logged message
         frame, depth = logging.currentframe(), 2
         while frame.f_code.co_filename == logging.__file__:
             frame = frame.f_back
             depth += 1
 
         logger.opt(depth=depth, exception=record.exc_info).log(
             level, record.getMessage()
         )
 
 
 def setup_uvicorn_logging() -> None:
-    # Set all uvicorn loggers to use InterceptHandler
     for logger_name in ["uvicorn", "uvicorn.error", "uvicorn.access"]:
         log = logging.getLogger(logger_name)
         log.handlers = [InterceptHandler()]
@@ -36,7 +30,6 @@ def setup_uvicorn_logging() -> None:
         log.propagate = False
 
 
-# uvicorn logging config
 LOGGING_CONFIG = {
     "version": 1,
     "disable_existing_loggers": False,
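The handler above is the standard loguru interception pattern: stdlib records emitted by the uvicorn loggers are re-routed through loguru at the matching level and call depth. A small sanity-check sketch, not part of the commit (it assumes only loguru is installed and this module is importable):

# Minimal sketch: after setup, stdlib uvicorn loggers are routed through loguru.
import logging

from dataloader.logger.uvicorn_logging_config import setup_uvicorn_logging

setup_uvicorn_logging()
logging.getLogger("uvicorn.error").warning("handled by loguru now")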
@@ -1,2 +1 @@
 """Модуль для работы с хранилищем данных."""
-
@@ -1,7 +1,11 @@
-# src/dataloader/storage/engine.py
 from __future__ import annotations
 
-from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
+from sqlalchemy.ext.asyncio import (
+    AsyncEngine,
+    AsyncSession,
+    async_sessionmaker,
+    create_async_engine,
+)
 
 from dataloader.config import APP_CONFIG
 
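For context only, a sketch under assumptions rather than code from this repository: the SQLAlchemy imports wrapped above are typically combined into an async engine plus a session factory, which matches the engine and sessionmaker slots AppContext holds in context.py.

# Illustrative only: typical pairing of the SQLAlchemy imports reformatted above.
# The connection URL is a placeholder, not taken from the project settings.
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

engine = create_async_engine("postgresql+asyncpg://user:pass@localhost:5432/dataloader")
session_factory = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)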
@@ -1,8 +1,8 @@
-# src/dataloader/storage/models/__init__.py
 """
 ORM модели для работы с базой данных.
 Организованы по доменам для масштабируемости.
 """
+
 from __future__ import annotations
 
 from .base import Base
@@ -1,4 +1,3 @@
-# src/dataloader/storage/models/base.py
 from __future__ import annotations
 
 from sqlalchemy.orm import DeclarativeBase
@@ -9,4 +8,5 @@ class Base(DeclarativeBase):
     Базовый класс для всех ORM моделей приложения.
     Используется SQLAlchemy 2.0+ declarative style.
     """
+
     pass
@@ -4,7 +4,7 @@ from __future__ import annotations
 
 from datetime import date, datetime
 
-from sqlalchemy import BigInteger, Date, Integer, Numeric, String, TIMESTAMP, text
+from sqlalchemy import TIMESTAMP, BigInteger, Date, Integer, Numeric, String, text
 from sqlalchemy.orm import Mapped, mapped_column
 
 from dataloader.storage.models.base import Base
@@ -16,29 +16,47 @@ class BriefDigitalCertificateOpu(Base):
     __tablename__ = "brief_digital_certificate_opu"
     __table_args__ = ({"schema": "opu"},)
 
-    object_id: Mapped[str] = mapped_column(String, primary_key=True, server_default=text("'-'"))
+    object_id: Mapped[str] = mapped_column(
+        String, primary_key=True, server_default=text("'-'")
+    )
     object_nm: Mapped[str | None] = mapped_column(String)
-    desk_nm: Mapped[str] = mapped_column(String, primary_key=True, server_default=text("'-'"))
-    actdate: Mapped[date] = mapped_column(Date, primary_key=True, server_default=text("CURRENT_DATE"))
-    layer_cd: Mapped[str] = mapped_column(String, primary_key=True, server_default=text("'-'"))
+    desk_nm: Mapped[str] = mapped_column(
+        String, primary_key=True, server_default=text("'-'")
+    )
+    actdate: Mapped[date] = mapped_column(
+        Date, primary_key=True, server_default=text("CURRENT_DATE")
+    )
+    layer_cd: Mapped[str] = mapped_column(
+        String, primary_key=True, server_default=text("'-'")
+    )
     layer_nm: Mapped[str | None] = mapped_column(String)
     opu_cd: Mapped[str] = mapped_column(String, primary_key=True)
     opu_nm_sh: Mapped[str | None] = mapped_column(String)
     opu_nm: Mapped[str | None] = mapped_column(String)
-    opu_lvl: Mapped[int] = mapped_column(Integer, primary_key=True, server_default=text("'-1'"))
-    opu_prnt_cd: Mapped[str] = mapped_column(String, primary_key=True, server_default=text("'-'"))
+    opu_lvl: Mapped[int] = mapped_column(
+        Integer, primary_key=True, server_default=text("'-1'")
+    )
+    opu_prnt_cd: Mapped[str] = mapped_column(
+        String, primary_key=True, server_default=text("'-'")
+    )
     opu_prnt_nm_sh: Mapped[str | None] = mapped_column(String)
     opu_prnt_nm: Mapped[str | None] = mapped_column(String)
     sum_amountrub_p_usd: Mapped[float | None] = mapped_column(Numeric)
-    wf_load_id: Mapped[int] = mapped_column(BigInteger, nullable=False, server_default=text("'-1'"))
+    wf_load_id: Mapped[int] = mapped_column(
+        BigInteger, nullable=False, server_default=text("'-1'")
+    )
     wf_load_dttm: Mapped[datetime] = mapped_column(
         TIMESTAMP(timezone=False),
         nullable=False,
         server_default=text("CURRENT_TIMESTAMP"),
     )
-    wf_row_id: Mapped[int] = mapped_column(BigInteger, nullable=False, server_default=text("'-1'"))
+    wf_row_id: Mapped[int] = mapped_column(
+        BigInteger, nullable=False, server_default=text("'-1'")
+    )
     object_tp: Mapped[str | None] = mapped_column(String)
-    object_unit: Mapped[str] = mapped_column(String, primary_key=True, server_default=text("'-'"))
+    object_unit: Mapped[str] = mapped_column(
+        String, primary_key=True, server_default=text("'-'")
+    )
     measure: Mapped[str | None] = mapped_column(String)
     product_nm: Mapped[str | None] = mapped_column(String)
     product_prnt_nm: Mapped[str | None] = mapped_column(String)
@@ -1,4 +1,3 @@
-# src/dataloader/storage/models/queue.py
 from __future__ import annotations

 from datetime import datetime

@@ -10,7 +9,6 @@ from sqlalchemy.orm import Mapped, mapped_column

 from .base import Base

-
 dl_status_enum = ENUM(
     "queued",
     "running",

@@ -29,6 +27,7 @@ class DLJob(Base):
     Model for the dl_jobs task queue table.
     Uses the logical schema name 'queue' to support schema_translate_map.
     """

     __tablename__ = "dl_jobs"
     __table_args__ = {"schema": "queue"}

@@ -40,15 +39,23 @@ class DLJob(Base):
     lock_key: Mapped[str] = mapped_column(Text, nullable=False)
     partition_key: Mapped[str] = mapped_column(Text, default="", nullable=False)
     priority: Mapped[int] = mapped_column(nullable=False, default=100)
-    available_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
-    status: Mapped[str] = mapped_column(dl_status_enum, nullable=False, default="queued")
+    available_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), nullable=False
+    )
+    status: Mapped[str] = mapped_column(
+        dl_status_enum, nullable=False, default="queued"
+    )
     attempt: Mapped[int] = mapped_column(nullable=False, default=0)
     max_attempts: Mapped[int] = mapped_column(nullable=False, default=5)
     lease_ttl_sec: Mapped[int] = mapped_column(nullable=False, default=60)
-    lease_expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True))
+    lease_expires_at: Mapped[Optional[datetime]] = mapped_column(
+        DateTime(timezone=True)
+    )
     heartbeat_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True))
     cancel_requested: Mapped[bool] = mapped_column(nullable=False, default=False)
-    progress: Mapped[dict[str, Any]] = mapped_column(JSONB, default=dict, nullable=False)
+    progress: Mapped[dict[str, Any]] = mapped_column(
+        JSONB, default=dict, nullable=False
+    )
     error: Mapped[Optional[str]] = mapped_column(Text)
     producer: Mapped[Optional[str]] = mapped_column(Text)
     consumer_group: Mapped[Optional[str]] = mapped_column(Text)

@@ -62,10 +69,13 @@ class DLJobEvent(Base):
     Model for the dl_job_events event log table.
     Uses the logical schema name 'queue' to support schema_translate_map.
     """

     __tablename__ = "dl_job_events"
     __table_args__ = {"schema": "queue"}

-    event_id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
+    event_id: Mapped[int] = mapped_column(
+        BigInteger, primary_key=True, autoincrement=True
+    )
     job_id: Mapped[str] = mapped_column(UUID(as_uuid=False), nullable=False)
     queue: Mapped[str] = mapped_column(Text, nullable=False)
     load_dttm: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)

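The DLJob and DLJobEvent models above declare a logical "queue" schema so that the engine can remap it per deployment via SQLAlchemy's schema_translate_map. A minimal sketch of that remapping; the DSN and the physical schema name are illustrative assumptions, not values from the commit:

    from sqlalchemy.ext.asyncio import create_async_engine

    engine = create_async_engine("postgresql+asyncpg://user:pass@localhost/db")
    # Remap the logical "queue" schema to the physical schema used in this environment.
    engine = engine.execution_options(schema_translate_map={"queue": "dataloader_queue"})
    # Tables declared with __table_args__ = {"schema": "queue"} are now emitted as dataloader_queue.*
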
@@ -5,7 +5,15 @@ from __future__ import annotations
 from datetime import datetime
 from typing import TYPE_CHECKING

-from sqlalchemy import JSON, TIMESTAMP, BigInteger, ForeignKey, String, UniqueConstraint, func
+from sqlalchemy import (
+    JSON,
+    TIMESTAMP,
+    BigInteger,
+    ForeignKey,
+    String,
+    UniqueConstraint,
+    func,
+)
 from sqlalchemy.orm import Mapped, mapped_column, relationship

 from dataloader.storage.models.base import Base

@@ -30,7 +38,9 @@ class Quote(Base):
     srce: Mapped[str | None] = mapped_column(String)
     ticker: Mapped[str | None] = mapped_column(String)
     quote_sect_id: Mapped[int] = mapped_column(
-        ForeignKey("quotes.quotes_sect.quote_sect_id", ondelete="CASCADE", onupdate="CASCADE"),
+        ForeignKey(
+            "quotes.quotes_sect.quote_sect_id", ondelete="CASCADE", onupdate="CASCADE"
+        ),
         nullable=False,
     )
     last_update_dttm: Mapped[datetime | None] = mapped_column(TIMESTAMP(timezone=True))

@@ -5,7 +5,17 @@ from __future__ import annotations
 from datetime import datetime
 from typing import TYPE_CHECKING

-from sqlalchemy import TIMESTAMP, BigInteger, Boolean, DateTime, Float, ForeignKey, String, UniqueConstraint, func
+from sqlalchemy import (
+    TIMESTAMP,
+    BigInteger,
+    Boolean,
+    DateTime,
+    Float,
+    ForeignKey,
+    String,
+    UniqueConstraint,
+    func,
+)
 from sqlalchemy.orm import Mapped, mapped_column, relationship

 from dataloader.storage.models.base import Base

@@ -1,16 +1,23 @@
-# src/dataloader/storage/notify_listener.py
 from __future__ import annotations

 import asyncio
-import asyncpg
 from typing import Callable, Optional

+import asyncpg
+

 class PGNotifyListener:
     """
     PostgreSQL NOTIFY listener for the 'dl_jobs' channel.
     """
-    def __init__(self, dsn: str, queue: str, callback: Callable[[], None], stop_event: asyncio.Event):
+
+    def __init__(
+        self,
+        dsn: str,
+        queue: str,
+        callback: Callable[[], None],
+        stop_event: asyncio.Event,
+    ):
         self._dsn = dsn
         self._queue = queue
         self._callback = callback

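A PGNotifyListener like the one above presumably sits on asyncpg's LISTEN support. A minimal sketch of that API; the DSN, channel name and helper name are illustrative assumptions:

    import asyncio
    import asyncpg

    async def listen(dsn: str, channel: str, on_notify) -> None:
        conn = await asyncpg.connect(dsn)
        # asyncpg invokes the listener with (connection, pid, channel, payload)
        await conn.add_listener(channel, lambda *_: on_notify())
        try:
            await asyncio.Event().wait()  # keep the connection alive until cancelled
        finally:
            await conn.close()
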
@@ -1,8 +1,8 @@
-# src/dataloader/storage/repositories/__init__.py
 """
 Repositories for working with the database.
 Organized by domain for scalability.
 """

 from __future__ import annotations

 from .opu import OpuRepository

@@ -69,7 +69,6 @@ class OpuRepository:
         if not records:
             return 0

-        # Collect the columns to update (everything except the PK and technical columns)
         update_columns = {
             c.name
             for c in BriefDigitalCertificateOpu.__table__.columns

@@ -1,4 +1,3 @@
-# src/dataloader/storage/repositories/queue.py
 from __future__ import annotations

 from datetime import datetime, timedelta, timezone

@@ -15,6 +14,7 @@ class QueueRepository:
     """
     Repository for working with the job queue and the event log.
     """

     def __init__(self, session: AsyncSession):
         self.s = session

@@ -62,7 +62,9 @@ class QueueRepository:
             finished_at=None,
         )
         self.s.add(row)
-        await self._append_event(req.job_id, req.queue, "queued", {"task": req.task})
+        await self._append_event(
+            req.job_id, req.queue, "queued", {"task": req.task}
+        )
         return req.job_id, "queued"

     async def get_status(self, job_id: str) -> Optional[JobStatus]:

@@ -118,7 +120,9 @@ class QueueRepository:
         await self._append_event(job_id, job.queue, "cancel_requested", None)
         return True

-    async def claim_one(self, queue: str, claim_backoff_sec: int) -> Optional[dict[str, Any]]:
+    async def claim_one(
+        self, queue: str, claim_backoff_sec: int
+    ) -> Optional[dict[str, Any]]:
         """
         Claims a single job from the queue, taking locks into account, and sets it to running.

@@ -150,15 +154,21 @@ class QueueRepository:
         job.started_at = job.started_at or datetime.now(timezone.utc)
         job.attempt = int(job.attempt) + 1
         job.heartbeat_at = datetime.now(timezone.utc)
-        job.lease_expires_at = datetime.now(timezone.utc) + timedelta(seconds=int(job.lease_ttl_sec))
+        job.lease_expires_at = datetime.now(timezone.utc) + timedelta(
+            seconds=int(job.lease_ttl_sec)
+        )

         ok = await self._try_advisory_lock(job.lock_key)
         if not ok:
             job.status = "queued"
-            job.available_at = datetime.now(timezone.utc) + timedelta(seconds=claim_backoff_sec)
+            job.available_at = datetime.now(timezone.utc) + timedelta(
+                seconds=claim_backoff_sec
+            )
             return None

-        await self._append_event(job.job_id, job.queue, "picked", {"attempt": job.attempt})
+        await self._append_event(
+            job.job_id, job.queue, "picked", {"attempt": job.attempt}
+        )

         return {
             "job_id": job.job_id,

@@ -192,10 +202,15 @@ class QueueRepository:
         q = (
             update(DLJob)
             .where(DLJob.job_id == job_id, DLJob.status == "running")
-            .values(heartbeat_at=now, lease_expires_at=now + timedelta(seconds=int(ttl_sec)))
+            .values(
+                heartbeat_at=now,
+                lease_expires_at=now + timedelta(seconds=int(ttl_sec)),
+            )
         )
         await self.s.execute(q)
-        await self._append_event(job_id, await self._resolve_queue(job_id), "heartbeat", {"ttl": ttl_sec})
+        await self._append_event(
+            job_id, await self._resolve_queue(job_id), "heartbeat", {"ttl": ttl_sec}
+        )
         return True, cancel_requested

     async def finish_ok(self, job_id: str) -> None:

@@ -215,7 +230,9 @@ class QueueRepository:
         await self._append_event(job_id, job.queue, "succeeded", None)
         await self._advisory_unlock(job.lock_key)

-    async def finish_fail_or_retry(self, job_id: str, err: str, is_canceled: bool = False) -> None:
+    async def finish_fail_or_retry(
+        self, job_id: str, err: str, is_canceled: bool = False
+    ) -> None:
         """
         Marks the job as failed or canceled, or returns it to the queue with a delay.

@@ -239,16 +256,25 @@ class QueueRepository:
         can_retry = int(job.attempt) < int(job.max_attempts)
         if can_retry:
             job.status = "queued"
-            job.available_at = datetime.now(timezone.utc) + timedelta(seconds=30 * int(job.attempt))
+            job.available_at = datetime.now(timezone.utc) + timedelta(
+                seconds=30 * int(job.attempt)
+            )
             job.error = err
             job.lease_expires_at = None
-            await self._append_event(job_id, job.queue, "requeue", {"attempt": job.attempt, "error": err})
+            await self._append_event(
+                job_id,
+                job.queue,
+                "requeue",
+                {"attempt": job.attempt, "error": err},
+            )
         else:
             job.status = "failed"
             job.error = err
             job.finished_at = datetime.now(timezone.utc)
             job.lease_expires_at = None
-            await self._append_event(job_id, job.queue, "failed", {"error": err})
+            await self._append_event(
+                job_id, job.queue, "failed", {"error": err}
+            )
         await self._advisory_unlock(job.lock_key)

     async def requeue_lost(self, now: Optional[datetime] = None) -> list[str]:

@@ -293,7 +319,11 @@ class QueueRepository:
         Returns:
             DLJob ORM model or None
         """
-        r = await self.s.execute(select(DLJob).where(DLJob.job_id == job_id).with_for_update(skip_locked=True))
+        r = await self.s.execute(
+            select(DLJob)
+            .where(DLJob.job_id == job_id)
+            .with_for_update(skip_locked=True)
+        )
         return r.scalar_one_or_none()

     async def _resolve_queue(self, job_id: str) -> str:

@@ -310,7 +340,9 @@ class QueueRepository:
         v = r.scalar_one_or_none()
         return v or ""

-    async def _append_event(self, job_id: str, queue: str, kind: str, payload: Optional[dict[str, Any]]) -> None:
+    async def _append_event(
+        self, job_id: str, queue: str, kind: str, payload: Optional[dict[str, Any]]
+    ) -> None:
         """
         Appends a record to the event log.

@@ -339,7 +371,9 @@ class QueueRepository:
         Returns:
             True if the lock was acquired
         """
-        r = await self.s.execute(select(func.pg_try_advisory_lock(func.hashtext(lock_key))))
+        r = await self.s.execute(
+            select(func.pg_try_advisory_lock(func.hashtext(lock_key)))
+        )
         return bool(r.scalar())

     async def _advisory_unlock(self, lock_key: str) -> None:

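The repository above serializes jobs that share a lock_key with Postgres advisory locks and, on retry, pushes available_at forward by 30 seconds per attempt (30s, 60s, 90s, ... up to max_attempts). A small sketch of the advisory-lock call it issues, assuming an AsyncSession is already at hand; only the wrapper function is made up:

    from sqlalchemy import func, select

    async def try_lock(session, lock_key: str) -> bool:
        # hashtext() turns the text key into the integer that pg_try_advisory_lock() expects.
        r = await session.execute(
            select(func.pg_try_advisory_lock(func.hashtext(lock_key)))
        )
        return bool(r.scalar())
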
@@ -39,7 +39,9 @@ class QuotesRepository:
         result = await self.s.execute(stmt)
         return result.scalar_one_or_none()

-    async def get_or_create_section(self, name: str, params: dict | None = None) -> QuoteSection:
+    async def get_or_create_section(
+        self, name: str, params: dict | None = None
+    ) -> QuoteSection:
         """
         Get an existing section or create a new one.

@@ -1,8 +1,8 @@
-# src/dataloader/storage/schemas/__init__.py
 """
 DTOs (Data Transfer Objects) for the storage layer.
 Organized by domain for scalability.
 """

 from __future__ import annotations

 from .queue import CreateJobRequest, JobStatus

@@ -1,4 +1,3 @@
-# src/dataloader/storage/schemas/queue.py
 from __future__ import annotations

 from dataclasses import dataclass

@@ -11,6 +10,7 @@ class CreateJobRequest:
     """
     DTO for creating a job in the queue.
     """

     job_id: str
     queue: str
     task: str

@@ -31,6 +31,7 @@ class JobStatus:
     """
     DTO for a job's status.
     """

     job_id: str
     status: str
     attempt: int

@@ -1,2 +1 @@
 """Worker module for job processing."""
-

@@ -1,4 +1,3 @@
-# src/dataloader/workers/base.py
 from __future__ import annotations

 import asyncio

@@ -9,8 +8,8 @@ from typing import AsyncIterator, Callable, Optional

 from dataloader.config import APP_CONFIG
 from dataloader.context import APP_CTX
-from dataloader.storage.repositories import QueueRepository
 from dataloader.storage.notify_listener import PGNotifyListener
+from dataloader.storage.repositories import QueueRepository
 from dataloader.workers.pipelines.registry import resolve as resolve_pipeline

@@ -19,6 +18,7 @@ class WorkerConfig:
     """
     Worker configuration.
     """

     queue: str
     heartbeat_sec: int
     claim_backoff_sec: int

@@ -28,6 +28,7 @@ class PGWorker:
     """
     Base asynchronous Postgres queue worker.
     """

     def __init__(self, cfg: WorkerConfig, stop_event: asyncio.Event) -> None:
         self._cfg = cfg
         self._stop = stop_event

@@ -46,12 +47,14 @@ class PGWorker:
             dsn=APP_CONFIG.pg.url,
             queue=self._cfg.queue,
             callback=lambda: self._notify_wakeup.set(),
-            stop_event=self._stop
+            stop_event=self._stop,
         )
         try:
             await self._listener.start()
         except Exception as e:
-            self._log.warning(f"Failed to start LISTEN/NOTIFY, falling back to polling: {e}")
+            self._log.warning(
+                f"Failed to start LISTEN/NOTIFY, falling back to polling: {e}"
+            )
             self._listener = None

         try:

@@ -69,24 +72,27 @@ class PGWorker:
         Waits for new jobs via LISTEN/NOTIFY or with a timeout.
         """
         if self._listener:
-            # Use LISTEN/NOTIFY with a timeout fallback
             done, pending = await asyncio.wait(
-                [asyncio.create_task(self._notify_wakeup.wait()), asyncio.create_task(self._stop.wait())],
+                [
+                    asyncio.create_task(self._notify_wakeup.wait()),
+                    asyncio.create_task(self._stop.wait()),
+                ],
                 return_when=asyncio.FIRST_COMPLETED,
-                timeout=timeout_sec
+                timeout=timeout_sec,
             )
-            # Cancel the remaining tasks
             for task in pending:
                 task.cancel()
                 try:
                     await task
                 except asyncio.CancelledError:
                     pass
-            # Clear the event if it was set
             if self._notify_wakeup.is_set():
                 self._notify_wakeup.clear()
         else:
-            # Fall back to a plain timeout
             try:
                 await asyncio.wait_for(self._stop.wait(), timeout=timeout_sec)
             except asyncio.TimeoutError:

@@ -110,30 +116,42 @@ class PGWorker:
         args = row["args"]

         try:
-            canceled = await self._execute_with_heartbeat(job_id, ttl, self._pipeline(task, args))
+            canceled = await self._execute_with_heartbeat(
+                job_id, ttl, self._pipeline(task, args)
+            )
             if canceled:
-                await repo.finish_fail_or_retry(job_id, "canceled by user", is_canceled=True)
+                await repo.finish_fail_or_retry(
+                    job_id, "canceled by user", is_canceled=True
+                )
             else:
                 await repo.finish_ok(job_id)
             return True
         except asyncio.CancelledError:
-            await repo.finish_fail_or_retry(job_id, "cancelled by shutdown", is_canceled=True)
+            await repo.finish_fail_or_retry(
+                job_id, "cancelled by shutdown", is_canceled=True
+            )
             raise
         except Exception as e:
             await repo.finish_fail_or_retry(job_id, str(e))
             return True

-    async def _execute_with_heartbeat(self, job_id: str, ttl: int, it: AsyncIterator[None]) -> bool:
+    async def _execute_with_heartbeat(
+        self, job_id: str, ttl: int, it: AsyncIterator[None]
+    ) -> bool:
         """
         Runs the pipeline with heartbeat support.
         Returns True if the job was canceled (cancel_requested).
         """
-        next_hb = datetime.now(timezone.utc) + timedelta(seconds=self._cfg.heartbeat_sec)
+        next_hb = datetime.now(timezone.utc) + timedelta(
+            seconds=self._cfg.heartbeat_sec
+        )
         async for _ in it:
             now = datetime.now(timezone.utc)
             if now >= next_hb:
                 async with self._sm() as s_hb:
-                    success, cancel_requested = await QueueRepository(s_hb).heartbeat(job_id, ttl)
+                    success, cancel_requested = await QueueRepository(s_hb).heartbeat(
+                        job_id, ttl
+                    )
                     if cancel_requested:
                         return True
                     next_hb = now + timedelta(seconds=self._cfg.heartbeat_sec)

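The worker treats each pipeline as an async iterator: every yield is a checkpoint where the lease can be refreshed and a cancel request observed. A stripped-down sketch of that contract; everything except the yield protocol itself is illustrative:

    import asyncio
    from datetime import datetime, timedelta, timezone
    from typing import AsyncIterator

    async def demo_pipeline(args: dict) -> AsyncIterator[None]:
        for _ in range(3):
            await asyncio.sleep(0.1)  # one unit of work
            yield                     # heartbeat checkpoint

    async def drive(it: AsyncIterator[None], heartbeat_sec: int = 5) -> None:
        next_hb = datetime.now(timezone.utc) + timedelta(seconds=heartbeat_sec)
        async for _ in it:
            now = datetime.now(timezone.utc)
            if now >= next_hb:
                # a real worker calls QueueRepository.heartbeat(job_id, ttl) here
                next_hb = now + timedelta(seconds=heartbeat_sec)
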
@@ -1,4 +1,3 @@
-# src/dataloader/workers/manager.py
 from __future__ import annotations

 import asyncio

@@ -16,6 +15,7 @@ class WorkerSpec:
     """
     Configuration of a set of workers for a queue.
     """

     queue: str
     concurrency: int

@@ -24,6 +24,7 @@ class WorkerManager:
     """
     Manages the lifecycle of the asynchronous workers.
     """

     def __init__(self, specs: list[WorkerSpec]) -> None:
         self._log = APP_CTX.get_logger()
         self._specs = specs

@@ -40,15 +41,22 @@ class WorkerManager:

         for spec in self._specs:
             for i in range(max(1, spec.concurrency)):
-                cfg = WorkerConfig(queue=spec.queue, heartbeat_sec=hb, claim_backoff_sec=backoff)
-                t = asyncio.create_task(PGWorker(cfg, self._stop).run(), name=f"worker:{spec.queue}:{i}")
+                cfg = WorkerConfig(
+                    queue=spec.queue, heartbeat_sec=hb, claim_backoff_sec=backoff
+                )
+                t = asyncio.create_task(
+                    PGWorker(cfg, self._stop).run(), name=f"worker:{spec.queue}:{i}"
+                )
                 self._tasks.append(t)

         self._reaper_task = asyncio.create_task(self._reaper_loop(), name="reaper")

         self._log.info(
             "worker_manager.started",
-            extra={"specs": [spec.__dict__ for spec in self._specs], "total_tasks": len(self._tasks)},
+            extra={
+                "specs": [spec.__dict__ for spec in self._specs],
+                "total_tasks": len(self._tasks),
+            },
         )

     async def stop(self) -> None:

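The stop() body is not shown in this hunk; a plausible minimal counterpart to the start() above would set the shared stop event and wait for the named tasks. This is an assumption for illustration, not the repository's actual method:

    async def stop(self) -> None:
        # Assumed shape: ask workers to finish, then wait for them and the reaper task.
        self._stop.set()
        await asyncio.gather(*self._tasks, self._reaper_task, return_exceptions=True)
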
@@ -1,6 +1,5 @@
 """Module with job-processing pipelines."""

-# src/dataloader/workers/pipelines/__init__.py
 from __future__ import annotations

 import importlib

@@ -59,7 +59,6 @@ def _parse_jsonl_from_zst(file_path: Path, chunk_size: int = 10000):
                 APP_CTX.logger.warning(f"Failed to parse JSON line: {e}")
                 continue

-    # Process remaining buffer
     if buffer.strip():
         try:
             record = orjson.loads(buffer)

@@ -85,11 +84,9 @@ def _convert_record(raw: dict[str, Any]) -> dict[str, Any]:
     """
     result = raw.copy()

-    # Convert actdate from an ISO string to a date
     if "actdate" in result and isinstance(result["actdate"], str):
         result["actdate"] = datetime.fromisoformat(result["actdate"]).date()

-    # Convert wf_load_dttm from an ISO string to a datetime
     if "wf_load_dttm" in result and isinstance(result["wf_load_dttm"], str):
         result["wf_load_dttm"] = datetime.fromisoformat(result["wf_load_dttm"])

@@ -118,18 +115,15 @@ async def load_opu(args: dict) -> AsyncIterator[None]:
     logger = APP_CTX.logger
     logger.info("Starting OPU ETL pipeline")

-    # Step 1: start the export
     interface = get_gmap2brief_interface()
     job_id = await interface.start_export()
     logger.info(f"OPU export job started: {job_id}")
     yield

-    # Step 2: wait for completion
     status = await interface.wait_for_completion(job_id)
     logger.info(f"OPU export completed: {status.total_rows} rows")
     yield

-    # Step 3: download the archive
     with tempfile.TemporaryDirectory() as temp_dir:
         temp_path = Path(temp_dir)
         archive_path = temp_path / f"opu_export_{job_id}.jsonl.zst"

@@ -138,7 +132,6 @@ async def load_opu(args: dict) -> AsyncIterator[None]:
         logger.info(f"OPU archive downloaded: {archive_path.stat().st_size:,} bytes")
         yield

-        # Step 4: truncate the table
         async with APP_CTX.sessionmaker() as session:
             repo = OpuRepository(session)
             await repo.truncate()

@@ -146,7 +139,6 @@ async def load_opu(args: dict) -> AsyncIterator[None]:
         logger.info("OPU table truncated")
         yield

-        # Step 5: load the data in a streaming fashion
         total_inserted = 0
         batch_num = 0

@@ -154,17 +146,16 @@ async def load_opu(args: dict) -> AsyncIterator[None]:
             for batch in _parse_jsonl_from_zst(archive_path, chunk_size=5000):
                 batch_num += 1

-                # Convert the records
                 converted = [_convert_record(rec) for rec in batch]

-                # Insert the batch
                 inserted = await repo.bulk_insert(converted)
                 await session.commit()

                 total_inserted += inserted
-                logger.debug(f"Batch {batch_num}: inserted {inserted} rows (total: {total_inserted})")
+                logger.debug(
+                    f"Batch {batch_num}: inserted {inserted} rows (total: {total_inserted})"
+                )

-                # Heartbeat after each batch
                 yield

         logger.info(f"OPU ETL completed: {total_inserted} rows inserted")

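_parse_jsonl_from_zst streams records out of the .jsonl.zst archive in fixed-size batches before they are converted and bulk-inserted. A self-contained sketch of that idea with the zstandard and orjson packages; the helper name and batching details are assumptions, not the file's exact code:

    from pathlib import Path
    from typing import Any, Iterator

    import orjson
    import zstandard as zstd

    def iter_jsonl_zst(path: Path, chunk_size: int = 5000) -> Iterator[list[dict[str, Any]]]:
        batch: list[dict[str, Any]] = []
        buffer = b""
        with path.open("rb") as fh:
            reader = zstd.ZstdDecompressor().stream_reader(fh)
            while chunk := reader.read(1 << 20):
                buffer += chunk
                *lines, buffer = buffer.split(b"\n")
                for line in lines:
                    if line.strip():
                        batch.append(orjson.loads(line))
                    if len(batch) >= chunk_size:
                        yield batch
                        batch = []
        if buffer.strip():
            batch.append(orjson.loads(buffer))  # process the remaining buffer
        if batch:
            yield batch
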
@@ -59,7 +59,9 @@ def _parse_ts_to_datetime(ts: str) -> datetime | None:
     return None


-def _build_value_row(source: str, dt: datetime, point: Any) -> dict[str, Any] | None:  # noqa: C901
+def _build_value_row(
+    source: str, dt: datetime, point: Any
+) -> dict[str, Any] | None:  # noqa: C901
     """Builds a row for `quotes_values` from the source and the point type."""
     if isinstance(point, int):
         return {"dt": dt, "key": point}

@@ -82,7 +84,10 @@ def _build_value_row(source: str, dt: datetime, point: Any) -> dict[str, Any] |
     if isinstance(deep_inner, InvestingCandlestick):
         return {
             "dt": dt,
-            "price_o": _to_float(getattr(deep_inner, "open_", None) or getattr(deep_inner, "open", None)),
+            "price_o": _to_float(
+                getattr(deep_inner, "open_", None)
+                or getattr(deep_inner, "open", None)
+            ),
             "price_h": _to_float(deep_inner.high),
             "price_l": _to_float(deep_inner.low),
             "price_c": _to_float(deep_inner.close),

@@ -92,12 +97,16 @@ def _build_value_row(source: str, dt: datetime, point: Any) -> dict[str, Any] |
     if isinstance(inner, TradingViewTimePoint | SgxTimePoint):
         return {
             "dt": dt,
-            "price_o": _to_float(getattr(inner, "open_", None) or getattr(inner, "open", None)),
+            "price_o": _to_float(
+                getattr(inner, "open_", None) or getattr(inner, "open", None)
+            ),
             "price_h": _to_float(inner.high),
             "price_l": _to_float(inner.low),
             "price_c": _to_float(inner.close),
             "volume": _to_float(
-                getattr(inner, "volume", None) or getattr(inner, "interest", None) or getattr(inner, "value", None)
+                getattr(inner, "volume", None)
+                or getattr(inner, "interest", None)
+                or getattr(inner, "value", None)
             ),
         }

@@ -132,7 +141,9 @@ def _build_value_row(source: str, dt: datetime, point: Any) -> dict[str, Any] |
             "dt": dt,
             "value_last": _to_float(deep_inner.last),
             "value_previous": _to_float(deep_inner.previous),
-            "unit": str(deep_inner.unit) if deep_inner.unit is not None else None,
+            "unit": (
+                str(deep_inner.unit) if deep_inner.unit is not None else None
+            ),
         }

     if isinstance(deep_inner, TradingEconomicsStringPercent):

@@ -218,7 +229,14 @@ async def load_tenera(args: dict) -> AsyncIterator[None]:
     async with APP_CTX.sessionmaker() as session:
         repo = QuotesRepository(session)

-        for source_name in ("cbr", "investing", "sgx", "tradingeconomics", "bloomberg", "trading_view"):
+        for source_name in (
+            "cbr",
+            "investing",
+            "sgx",
+            "tradingeconomics",
+            "bloomberg",
+            "trading_view",
+        ):
             source_data = getattr(data, source_name)
             if not source_data:
                 continue

@@ -1,4 +1,3 @@
-# src/dataloader/workers/pipelines/noop.py
 from __future__ import annotations

 import asyncio

@@ -1,4 +1,3 @@
-# src/dataloader/workers/pipelines/registry.py
 from __future__ import annotations

 from typing import Any, Callable, Dict, Iterable

@@ -6,13 +5,17 @@ from typing import Any, Callable, Dict, Iterable
 _Registry: Dict[str, Callable[[dict[str, Any]], Any]] = {}


-def register(task: str) -> Callable[[Callable[[dict[str, Any]], Any]], Callable[[dict[str, Any]], Any]]:
+def register(
+    task: str,
+) -> Callable[[Callable[[dict[str, Any]], Any]], Callable[[dict[str, Any]], Any]]:
     """
     Registers a pipeline handler under the given task name.
     """

     def _wrap(fn: Callable[[dict[str, Any]], Any]) -> Callable[[dict[str, Any]], Any]:
         _Registry[task] = fn
         return fn

     return _wrap


@@ -22,8 +25,8 @@ def resolve(task: str) -> Callable[[dict[str, Any]], Any]:
     """
     try:
         return _Registry[task]
-    except KeyError:
-        raise KeyError(f"pipeline not found: {task}")
+    except KeyError as err:
+        raise KeyError(f"pipeline not found: {task}") from err


 def tasks() -> Iterable[str]:

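Usage of the registry above, as implied by the diff; the task name is made up for the example:

    from typing import Any, AsyncIterator

    from dataloader.workers.pipelines.registry import register, resolve

    @register("demo.noop")
    async def demo(args: dict[str, Any]) -> AsyncIterator[None]:
        yield  # one checkpoint, then done

    handler = resolve("demo.noop")  # returns the registered callable
    try:
        resolve("missing.task")
    except KeyError as exc:
        print(exc)  # "pipeline not found: missing.task", chained from the lookup error
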
@@ -1,7 +1,7 @@
-# src/dataloader/workers/reaper.py
 from __future__ import annotations

 from typing import Sequence

 from sqlalchemy.ext.asyncio import AsyncSession

 from dataloader.storage.repositories import QueueRepository

@@ -1,4 +1,3 @@
-# tests/conftest.py
 from __future__ import annotations

 import asyncio

@@ -9,23 +8,23 @@ from uuid import uuid4
 import pytest
 import pytest_asyncio
 from dotenv import load_dotenv
-from httpx import AsyncClient, ASGITransport
+from httpx import ASGITransport, AsyncClient
 from sqlalchemy import text
 from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker

-load_dotenv()
-
 from dataloader.api import app_main
 from dataloader.config import APP_CONFIG
-from dataloader.context import APP_CTX, get_session
-from dataloader.storage.models import Base
-from dataloader.storage.engine import create_engine, create_sessionmaker
+from dataloader.context import get_session
+from dataloader.storage.engine import create_engine
+
+load_dotenv()

 if sys.platform == "win32":
     asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

 pytestmark = pytest.mark.asyncio


 @pytest_asyncio.fixture(scope="function")
 async def db_engine() -> AsyncGenerator[AsyncEngine, None]:
     """

@@ -39,15 +38,15 @@ async def db_engine() -> AsyncGenerator[AsyncEngine, None]:
     await engine.dispose()


 @pytest_asyncio.fixture(scope="function")
 async def db_session(db_engine: AsyncEngine) -> AsyncGenerator[AsyncSession, None]:
     """
     Provides a DB session for each test.
     Does NOT use a transaction so that advisory locks keep working.
     """
-    sessionmaker = async_sessionmaker(bind=db_engine, expire_on_commit=False, class_=AsyncSession)
+    sessionmaker = async_sessionmaker(
+        bind=db_engine, expire_on_commit=False, class_=AsyncSession
+    )
     async with sessionmaker() as session:
         yield session
         await session.rollback()

@@ -69,6 +68,7 @@ async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
     """
     HTTP client for testing the API.
     """

     async def override_get_session() -> AsyncGenerator[AsyncSession, None]:
         yield db_session

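The client fixture above swaps the application's get_session dependency for the per-test session and talks to the app in-process through httpx. A minimal sketch of that pattern, assuming app_main is the FastAPI application object; the fixture wiring is simplified:

    from httpx import ASGITransport, AsyncClient

    from dataloader.api import app_main
    from dataloader.context import get_session

    async def make_client(db_session) -> AsyncClient:
        async def override_get_session():
            yield db_session

        # FastAPI resolves get_session to the test session for every request
        app_main.dependency_overrides[get_session] = override_get_session
        return AsyncClient(transport=ASGITransport(app=app_main), base_url="http://test")
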
@@ -1 +1 @@
-# tests/integration_tests/__init__.py

@@ -1,4 +1,3 @@
-# tests/integration_tests/test_api_endpoints.py
 from __future__ import annotations

 import pytest

@@ -1,11 +1,11 @@
-# tests/unit/test_api_router_not_found.py
 from __future__ import annotations

-import pytest
-from uuid import uuid4, UUID
+from uuid import UUID, uuid4
+
+import pytest

-from dataloader.api.v1.router import get_status, cancel_job
 from dataloader.api.v1.exceptions import JobNotFoundError
+from dataloader.api.v1.router import cancel_job, get_status
 from dataloader.api.v1.schemas import JobStatusResponse

@@ -1,11 +1,11 @@
-# tests/unit/test_api_router_success.py
 from __future__ import annotations

-import pytest
-from uuid import uuid4, UUID
 from datetime import datetime, timezone
+from uuid import UUID, uuid4
+
+import pytest

-from dataloader.api.v1.router import get_status, cancel_job
+from dataloader.api.v1.router import cancel_job, get_status
 from dataloader.api.v1.schemas import JobStatusResponse

@@ -1,7 +1,6 @@
-# tests/integration_tests/test_queue_repository.py
 from __future__ import annotations

-from datetime import datetime, timezone, timedelta
+from datetime import datetime, timedelta, timezone
 from uuid import uuid4

 import pytest

@@ -10,7 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession

 from dataloader.storage.models import DLJob
 from dataloader.storage.repositories import QueueRepository
-from dataloader.storage.schemas import CreateJobRequest, JobStatus
+from dataloader.storage.schemas import CreateJobRequest


 @pytest.mark.integration

@@ -445,6 +444,7 @@ class TestQueueRepository:
         await repo.claim_one(queue_name, claim_backoff_sec=15)

         import asyncio

         await asyncio.sleep(2)

         requeued = await repo.requeue_lost()

@@ -500,7 +500,9 @@ class TestQueueRepository:
         assert st is not None
         assert st.status == "queued"

-        row = (await db_session.execute(select(DLJob).where(DLJob.job_id == job_id))).scalar_one()
+        row = (
+            await db_session.execute(select(DLJob).where(DLJob.job_id == job_id))
+        ).scalar_one()
         assert row.available_at >= before + timedelta(seconds=15)
         assert row.available_at <= after + timedelta(seconds=60)

@@ -569,7 +571,9 @@ class TestQueueRepository:
         await repo.create_or_get(req)
         await repo.claim_one(queue_name, claim_backoff_sec=5)

-        await repo.finish_fail_or_retry(job_id, err="Canceled by test", is_canceled=True)
+        await repo.finish_fail_or_retry(
+            job_id, err="Canceled by test", is_canceled=True
+        )

         st = await repo.get_status(job_id)
         assert st is not None

@@ -1 +1 @@
-# tests/unit/__init__.py

@@ -1,13 +1,13 @@
-# tests/unit/test_api_service.py
 from __future__ import annotations

 from datetime import datetime, timezone
-from uuid import UUID
 from unittest.mock import AsyncMock, Mock, patch
+from uuid import UUID

 import pytest

-from dataloader.api.v1.service import JobsService
 from dataloader.api.v1.schemas import TriggerJobRequest
+from dataloader.api.v1.service import JobsService
 from dataloader.storage.schemas import JobStatus

@@ -38,15 +38,19 @@ class TestJobsService:
         """
         mock_session = AsyncMock()

-        with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \
-             patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls, \
-             patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id:
+        with (
+            patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
+            patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
+            patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id,
+        ):

             mock_get_logger.return_value = Mock()
             mock_new_job_id.return_value = UUID("12345678-1234-5678-1234-567812345678")

             mock_repo = Mock()
-            mock_repo.create_or_get = AsyncMock(return_value=("12345678-1234-5678-1234-567812345678", "queued"))
+            mock_repo.create_or_get = AsyncMock(
+                return_value=("12345678-1234-5678-1234-567812345678", "queued")
+            )
             mock_repo_cls.return_value = mock_repo

             service = JobsService(mock_session)

@@ -58,7 +62,7 @@ class TestJobsService:
                 lock_key="lock_1",
                 priority=100,
                 max_attempts=5,
-                lease_ttl_sec=60
+                lease_ttl_sec=60,
             )

             response = await service.trigger(req)

@@ -74,15 +78,19 @@ class TestJobsService:
         """
         mock_session = AsyncMock()

-        with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \
-             patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls, \
-             patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id:
+        with (
+            patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
+            patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
+            patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id,
+        ):

             mock_get_logger.return_value = Mock()
             mock_new_job_id.return_value = UUID("12345678-1234-5678-1234-567812345678")

             mock_repo = Mock()
-            mock_repo.create_or_get = AsyncMock(return_value=("12345678-1234-5678-1234-567812345678", "queued"))
+            mock_repo.create_or_get = AsyncMock(
+                return_value=("12345678-1234-5678-1234-567812345678", "queued")
+            )
             mock_repo_cls.return_value = mock_repo

             service = JobsService(mock_session)

@@ -95,7 +103,7 @@ class TestJobsService:
                 lock_key="lock_1",
                 priority=100,
                 max_attempts=5,
-                lease_ttl_sec=60
+                lease_ttl_sec=60,
             )

             response = await service.trigger(req)

@@ -112,15 +120,19 @@ class TestJobsService:
         """
         mock_session = AsyncMock()

-        with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \
-             patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls, \
-             patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id:
+        with (
+            patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
+            patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
+            patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id,
+        ):

             mock_get_logger.return_value = Mock()
             mock_new_job_id.return_value = UUID("12345678-1234-5678-1234-567812345678")

             mock_repo = Mock()
-            mock_repo.create_or_get = AsyncMock(return_value=("12345678-1234-5678-1234-567812345678", "queued"))
+            mock_repo.create_or_get = AsyncMock(
+                return_value=("12345678-1234-5678-1234-567812345678", "queued")
+            )
             mock_repo_cls.return_value = mock_repo

             service = JobsService(mock_session)

@@ -135,10 +147,10 @@ class TestJobsService:
                 available_at=future_time,
                 priority=100,
                 max_attempts=5,
-                lease_ttl_sec=60
+                lease_ttl_sec=60,
             )

-            response = await service.trigger(req)
+            await service.trigger(req)

             call_args = mock_repo.create_or_get.call_args[0][0]
             assert call_args.available_at == future_time

@@ -150,15 +162,19 @@ class TestJobsService:
         """
         mock_session = AsyncMock()

-        with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \
-             patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls, \
-             patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id:
+        with (
+            patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
+            patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
+            patch("dataloader.api.v1.service.new_job_id") as mock_new_job_id,
+        ):

             mock_get_logger.return_value = Mock()
             mock_new_job_id.return_value = UUID("12345678-1234-5678-1234-567812345678")

             mock_repo = Mock()
-            mock_repo.create_or_get = AsyncMock(return_value=("12345678-1234-5678-1234-567812345678", "queued"))
+            mock_repo.create_or_get = AsyncMock(
+                return_value=("12345678-1234-5678-1234-567812345678", "queued")
+            )
             mock_repo_cls.return_value = mock_repo

             service = JobsService(mock_session)

@@ -173,10 +189,10 @@ class TestJobsService:
                 consumer_group="test_group",
                 priority=100,
                 max_attempts=5,
-                lease_ttl_sec=60
+                lease_ttl_sec=60,
             )

-            response = await service.trigger(req)
+            await service.trigger(req)

             call_args = mock_repo.create_or_get.call_args[0][0]
             assert call_args.partition_key == "partition_1"

@@ -190,8 +206,10 @@ class TestJobsService:
         """
         mock_session = AsyncMock()

-        with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \
-             patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls:
+        with (
+            patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
+            patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
+        ):

             mock_get_logger.return_value = Mock()

@@ -204,7 +222,7 @@ class TestJobsService:
                 finished_at=None,
                 heartbeat_at=datetime(2025, 1, 1, 12, 5, 0, tzinfo=timezone.utc),
                 error=None,
-                progress={"step": 1}
+                progress={"step": 1},
             )
             mock_repo.get_status = AsyncMock(return_value=mock_status)
             mock_repo_cls.return_value = mock_repo

@@ -227,8 +245,10 @@ class TestJobsService:
         """
         mock_session = AsyncMock()

-        with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \
-             patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls:
+        with (
+            patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
+            patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
+        ):

             mock_get_logger.return_value = Mock()

@@ -250,8 +270,10 @@ class TestJobsService:
         """
         mock_session = AsyncMock()

-        with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \
-             patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls:
+        with (
+            patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
+            patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
+        ):

             mock_get_logger.return_value = Mock()

@@ -265,7 +287,7 @@ class TestJobsService:
                 finished_at=None,
                 heartbeat_at=datetime(2025, 1, 1, 12, 5, 0, tzinfo=timezone.utc),
                 error=None,
-                progress={}
+                progress={},
             )
             mock_repo.get_status = AsyncMock(return_value=mock_status)
             mock_repo_cls.return_value = mock_repo

@@ -287,8 +309,10 @@ class TestJobsService:
         """
         mock_session = AsyncMock()

-        with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \
-             patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls:
+        with (
+            patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
+            patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
+        ):

             mock_get_logger.return_value = Mock()

@@ -312,8 +336,10 @@ class TestJobsService:
         """
         mock_session = AsyncMock()

-        with patch("dataloader.api.v1.service.get_logger") as mock_get_logger, \
|
with (
|
||||||
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls:
|
patch("dataloader.api.v1.service.get_logger") as mock_get_logger,
|
||||||
|
patch("dataloader.api.v1.service.QueueRepository") as mock_repo_cls,
|
||||||
|
):
|
||||||
|
|
||||||
mock_get_logger.return_value = Mock()
|
mock_get_logger.return_value = Mock()
|
||||||
|
|
||||||
|
|
@ -326,7 +352,7 @@ class TestJobsService:
|
||||||
finished_at=None,
|
finished_at=None,
|
||||||
heartbeat_at=None,
|
heartbeat_at=None,
|
||||||
error=None,
|
error=None,
|
||||||
progress=None
|
progress=None,
|
||||||
)
|
)
|
||||||
mock_repo.get_status = AsyncMock(return_value=mock_status)
|
mock_repo.get_status = AsyncMock(return_value=mock_status)
|
||||||
mock_repo_cls.return_value = mock_repo
|
mock_repo_cls.return_value = mock_repo
|
||||||
|
|
|
||||||
|
|
@@ -1,18 +1,18 @@
-# tests/unit/test_config.py
 from dataloader.config import (
-BaseAppSettings,
 AppSettings,
+BaseAppSettings,
 LogSettings,
 PGSettings,
-WorkerSettings,
 Secrets,
+WorkerSettings,
 )
@@ -81,12 +81,15 @@ class TestAppSettings:
-with patch.dict("os.environ", {
+with patch.dict(
+"os.environ",
+{
-"TIMEZONE": "UTC"
-}):
+"TIMEZONE": "UTC",
+},
+):
@@ -136,7 +139,9 @@ class TestLogSettings:
-with patch.dict("os.environ", {"LOG_PATH": "/var/log", "LOG_FILE_NAME": "test.log"}):
+with patch.dict(
+"os.environ", {"LOG_PATH": "/var/log", "LOG_FILE_NAME": "test.log"}
+):
@@ -146,7 +151,10 @@ class TestLogSettings:
-with patch.dict("os.environ", {"METRIC_PATH": "/var/metrics", "METRIC_FILE_NAME": "metrics.log"}):
+with patch.dict(
+"os.environ",
+{"METRIC_PATH": "/var/metrics", "METRIC_FILE_NAME": "metrics.log"},
+):
@@ -156,7 +164,10 @@ class TestLogSettings:
-with patch.dict("os.environ", {"AUDIT_LOG_PATH": "/var/audit", "AUDIT_LOG_FILE_NAME": "audit.log"}):
+with patch.dict(
+"os.environ",
+{"AUDIT_LOG_PATH": "/var/audit", "AUDIT_LOG_FILE_NAME": "audit.log"},
+):
@@ -208,29 +219,37 @@ class TestPGSettings:
-with patch.dict("os.environ", {
+with patch.dict(
+"os.environ",
+{
-"PG_DATABASE": "testdb"
-}):
+"PG_DATABASE": "testdb",
+},
+):
-expected = "postgresql+asyncpg://testuser:testpass@db.example.com:5433/testdb"
+expected = (
+"postgresql+asyncpg://testuser:testpass@db.example.com:5433/testdb"
+)
-with patch.dict("os.environ", {
+with patch.dict(
+"os.environ",
+{
-"PG_DATABASE": "testdb"
-}):
+"PG_DATABASE": "testdb",
+},
+):
@@ -240,15 +259,18 @@ class TestPGSettings:
-with patch.dict("os.environ", {
+with patch.dict(
+"os.environ",
+{
-"PG_POOL_SIZE": "20"
-}):
+"PG_POOL_SIZE": "20",
+},
+):
@@ -292,10 +314,12 @@ class TestWorkerSettings:
-workers_json = json.dumps([
+workers_json = json.dumps(
+[
-{"queue": "queue2", "concurrency": 3}
-])
+{"queue": "queue2", "concurrency": 3},
+]
+)
@@ -311,12 +335,14 @@ class TestWorkerSettings:
-workers_json = json.dumps([
+workers_json = json.dumps(
+[
-{"queue": "queue2", "concurrency": 3}
-])
+{"queue": "queue2", "concurrency": 3},
+]
+)
@@ -1,7 +1,7 @@
-# tests/unit/test_context.py
@@ -71,10 +71,16 @@ class TestAppContext:
-with patch("dataloader.logger.logger.setup_logging") as mock_setup_logging, \
-patch("dataloader.storage.engine.create_engine", return_value=mock_engine) as mock_create_engine, \
-patch("dataloader.storage.engine.create_sessionmaker", return_value=mock_sm) as mock_create_sm, \
-patch("dataloader.context.APP_CONFIG") as mock_config:
+with (
+patch("dataloader.logger.logger.setup_logging") as mock_setup_logging,
+patch(
+"dataloader.storage.engine.create_engine", return_value=mock_engine
+) as mock_create_engine,
+patch(
+"dataloader.storage.engine.create_sessionmaker", return_value=mock_sm
+) as mock_create_sm,
+patch("dataloader.context.APP_CONFIG") as mock_config,
+):
@@ -194,7 +200,7 @@ class TestGetSession:
-async for session in get_session():
+async for _session in get_session():
@@ -1,8 +1,8 @@
-# tests/unit/test_notify_listener.py
@@ -25,7 +25,7 @@ class TestPGNotifyListener:
-stop_event=stop_event
+stop_event=stop_event,
@@ -47,14 +47,16 @@ class TestPGNotifyListener:
-stop_event=stop_event
+stop_event=stop_event,
-with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn):
+with patch(
+"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
+):
@@ -76,14 +78,16 @@ class TestPGNotifyListener:
-stop_event=stop_event
+stop_event=stop_event,
-with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn) as mock_connect:
+with patch(
+"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
+) as mock_connect:
@@ -102,14 +106,16 @@ class TestPGNotifyListener:
-stop_event=stop_event
+stop_event=stop_event,
-with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn):
+with patch(
+"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
+):
@@ -131,14 +137,16 @@ class TestPGNotifyListener:
-stop_event=stop_event
+stop_event=stop_event,
-with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn):
+with patch(
+"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
+):
@@ -160,14 +168,16 @@ class TestPGNotifyListener:
-stop_event=stop_event
+stop_event=stop_event,
-with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn):
+with patch(
+"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
+):
@@ -189,14 +199,16 @@ class TestPGNotifyListener:
-stop_event=stop_event
+stop_event=stop_event,
-with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn):
+with patch(
+"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
+):
@@ -218,7 +230,7 @@ class TestPGNotifyListener:
-stop_event=stop_event
+stop_event=stop_event,
@@ -227,7 +239,9 @@ class TestPGNotifyListener:
-with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn):
+with patch(
+"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
+):
@@ -250,7 +264,7 @@ class TestPGNotifyListener:
-stop_event=stop_event
+stop_event=stop_event,
@@ -259,7 +273,9 @@ class TestPGNotifyListener:
-with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn):
+with patch(
+"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
+):
@@ -280,7 +296,7 @@ class TestPGNotifyListener:
-stop_event=stop_event
+stop_event=stop_event,
@@ -289,7 +305,9 @@ class TestPGNotifyListener:
-with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn):
+with patch(
+"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
+):
@@ -311,7 +329,7 @@ class TestPGNotifyListener:
-stop_event=stop_event
+stop_event=stop_event,
@@ -320,7 +338,9 @@ class TestPGNotifyListener:
-with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn):
+with patch(
+"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
+):
@@ -340,7 +360,7 @@ class TestPGNotifyListener:
-stop_event=stop_event
+stop_event=stop_event,
@@ -349,7 +369,9 @@ class TestPGNotifyListener:
-with patch("dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn):
+with patch(
+"dataloader.storage.notify_listener.asyncpg.connect", return_value=mock_conn
+):
@@ -368,7 +390,7 @@ class TestPGNotifyListener:
-stop_event=stop_event
+stop_event=stop_event,
@@ -1,9 +1,8 @@
-# tests/unit/test_pipeline_registry.py
-from dataloader.workers.pipelines.registry import register, resolve, tasks, _Registry
+from dataloader.workers.pipelines.registry import _Registry, register, resolve, tasks
@@ -22,6 +21,7 @@ class TestPipelineRegistry:
 """
+
 @register("test.task")
@@ -33,6 +33,7 @@ class TestPipelineRegistry:
 """
+
 @register("test.resolve")
@@ -54,6 +55,7 @@ class TestPipelineRegistry:
 """
+
 @register("task1")
@@ -70,6 +72,7 @@ class TestPipelineRegistry:
 """
+
 @register("overwrite.task")
@@ -1,9 +1,8 @@
-# tests/unit/test_workers_base.py
 from __future__ import annotations

 import asyncio
-from datetime import datetime, timezone
 from unittest.mock import AsyncMock, MagicMock, Mock, patch
@@ -41,9 +40,11 @@ class TestPGWorker:
-with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
-patch("dataloader.workers.base.APP_CONFIG") as mock_cfg, \
-patch("dataloader.workers.base.PGNotifyListener") as mock_listener_cls:
+with (
+patch("dataloader.workers.base.APP_CTX") as mock_ctx,
+patch("dataloader.workers.base.APP_CONFIG") as mock_cfg,
+patch("dataloader.workers.base.PGNotifyListener") as mock_listener_cls,
+):
@@ -64,7 +65,9 @@ class TestPGWorker:
-with patch.object(worker, "_claim_and_execute_once", side_effect=mock_claim):
+with patch.object(
+worker, "_claim_and_execute_once", side_effect=mock_claim
+):
@@ -78,9 +81,11 @@ class TestPGWorker:
-with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
-patch("dataloader.workers.base.APP_CONFIG") as mock_cfg, \
-patch("dataloader.workers.base.PGNotifyListener") as mock_listener_cls:
+with (
+patch("dataloader.workers.base.APP_CTX") as mock_ctx,
+patch("dataloader.workers.base.APP_CONFIG") as mock_cfg,
+patch("dataloader.workers.base.PGNotifyListener") as mock_listener_cls,
+):
@@ -101,7 +106,9 @@ class TestPGWorker:
-with patch.object(worker, "_claim_and_execute_once", side_effect=mock_claim):
+with patch.object(
+worker, "_claim_and_execute_once", side_effect=mock_claim
+):
@@ -156,8 +163,10 @@ class TestPGWorker:
-with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
-patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
+with (
+patch("dataloader.workers.base.APP_CTX") as mock_ctx,
+patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
+):
@@ -185,8 +194,10 @@ class TestPGWorker:
-with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
-patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
+with (
+patch("dataloader.workers.base.APP_CTX") as mock_ctx,
+patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
+):
@@ -196,12 +207,14 @@ class TestPGWorker:
-mock_repo.claim_one = AsyncMock(return_value={
+mock_repo.claim_one = AsyncMock(
+return_value={
-"args": {"key": "value"}
-})
+"args": {"key": "value"},
+}
+)
@@ -210,8 +223,10 @@ class TestPGWorker:
-with patch.object(worker, "_pipeline", side_effect=mock_pipeline), \
-patch.object(worker, "_execute_with_heartbeat", return_value=False):
+with (
+patch.object(worker, "_pipeline", side_effect=mock_pipeline),
+patch.object(worker, "_execute_with_heartbeat", return_value=False),
+):
@@ -225,8 +240,10 @@ class TestPGWorker:
-with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
-patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
+with (
+patch("dataloader.workers.base.APP_CTX") as mock_ctx,
+patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
+):
@@ -236,12 +253,14 @@ class TestPGWorker:
-mock_repo.claim_one = AsyncMock(return_value={
+mock_repo.claim_one = AsyncMock(
+return_value={
-"args": {}
-})
+"args": {},
+}
+)
@@ -263,8 +282,10 @@ class TestPGWorker:
-with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
-patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
+with (
+patch("dataloader.workers.base.APP_CTX") as mock_ctx,
+patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
+):
@@ -274,18 +295,22 @@ class TestPGWorker:
-mock_repo.claim_one = AsyncMock(return_value={
+mock_repo.claim_one = AsyncMock(
+return_value={
-"args": {}
-})
+"args": {},
+}
+)
-with patch.object(worker, "_execute_with_heartbeat", side_effect=ValueError("Test error")):
+with patch.object(
+worker, "_execute_with_heartbeat", side_effect=ValueError("Test error")
+):
@@ -301,8 +326,10 @@ class TestPGWorker:
-with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
-patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
+with (
+patch("dataloader.workers.base.APP_CTX") as mock_ctx,
+patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
+):
@@ -323,7 +350,9 @@ class TestPGWorker:
-canceled = await worker._execute_with_heartbeat("job-id", 60, slow_pipeline())
+canceled = await worker._execute_with_heartbeat(
+"job-id", 60, slow_pipeline()
+)
@@ -336,8 +365,10 @@ class TestPGWorker:
-with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
-patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
+with (
+patch("dataloader.workers.base.APP_CTX") as mock_ctx,
+patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
+):
@@ -358,7 +389,9 @@ class TestPGWorker:
-canceled = await worker._execute_with_heartbeat("job-id", 60, slow_pipeline())
+canceled = await worker._execute_with_heartbeat(
+"job-id", 60, slow_pipeline()
+)
@@ -370,8 +403,10 @@ class TestPGWorker:
-with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
-patch("dataloader.workers.base.resolve_pipeline") as mock_resolve:
+with (
+patch("dataloader.workers.base.APP_CTX") as mock_ctx,
+patch("dataloader.workers.base.resolve_pipeline") as mock_resolve,
+):
@@ -397,8 +432,10 @@ class TestPGWorker:
-with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
-patch("dataloader.workers.base.resolve_pipeline") as mock_resolve:
+with (
+patch("dataloader.workers.base.APP_CTX") as mock_ctx,
+patch("dataloader.workers.base.resolve_pipeline") as mock_resolve,
+):
@@ -424,8 +461,10 @@ class TestPGWorker:
-with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
-patch("dataloader.workers.base.resolve_pipeline") as mock_resolve:
+with (
+patch("dataloader.workers.base.APP_CTX") as mock_ctx,
+patch("dataloader.workers.base.resolve_pipeline") as mock_resolve,
+):
@@ -450,8 +489,10 @@ class TestPGWorker:
-with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
-patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
+with (
+patch("dataloader.workers.base.APP_CTX") as mock_ctx,
+patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
+):
@@ -461,12 +502,14 @@ class TestPGWorker:
-mock_repo.claim_one = AsyncMock(return_value={
+mock_repo.claim_one = AsyncMock(
+return_value={
-"args": {}
-})
+"args": {},
+}
+)
@@ -484,19 +527,16 @@ class TestPGWorker:
 assert "cancelled by shutdown" in args[1]
 assert kwargs.get("is_canceled") is True
-
-
-
 @pytest.mark.asyncio
-with patch("dataloader.workers.base.APP_CTX") as mock_ctx, \
-patch("dataloader.workers.base.QueueRepository") as mock_repo_cls:
+with (
+patch("dataloader.workers.base.APP_CTX") as mock_ctx,
+patch("dataloader.workers.base.QueueRepository") as mock_repo_cls,
+):
@@ -509,4 +549,3 @@ class TestPGWorker:
 with pytest.raises(asyncio.CancelledError):
 await worker._execute_with_heartbeat("job-id", 60, one_yield())
-
@@ -1,8 +1,8 @@
-# tests/unit/test_workers_manager.py
@@ -39,9 +39,11 @@ class TestWorkerManager:
-with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
-patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, \
-patch("dataloader.workers.manager.PGWorker") as mock_worker_cls:
+with (
+patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
+patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
+patch("dataloader.workers.manager.PGWorker") as mock_worker_cls,
+):
@@ -67,9 +69,11 @@ class TestWorkerManager:
-with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
-patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, \
-patch("dataloader.workers.manager.PGWorker") as mock_worker_cls:
+with (
+patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
+patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
+patch("dataloader.workers.manager.PGWorker") as mock_worker_cls,
+):
@@ -93,9 +97,11 @@ class TestWorkerManager:
-with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
-patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, \
-patch("dataloader.workers.manager.PGWorker") as mock_worker_cls:
+with (
+patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
+patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
+patch("dataloader.workers.manager.PGWorker") as mock_worker_cls,
+):
@@ -108,8 +114,6 @@ class TestWorkerManager:
 await manager.start()

-initial_task_count = len(manager._tasks)
-
 await manager.stop()
@@ -123,10 +127,12 @@ class TestWorkerManager:
-with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
-patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, \
-patch("dataloader.workers.manager.PGWorker") as mock_worker_cls, \
-patch("dataloader.workers.manager.requeue_lost") as mock_requeue:
+with (
+patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
+patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
+patch("dataloader.workers.manager.PGWorker") as mock_worker_cls,
+patch("dataloader.workers.manager.requeue_lost") as mock_requeue,
+):
@@ -161,10 +167,12 @@ class TestWorkerManager:
-with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
-patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg, \
-patch("dataloader.workers.manager.PGWorker") as mock_worker_cls, \
-patch("dataloader.workers.manager.requeue_lost") as mock_requeue:
+with (
+patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
+patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
+patch("dataloader.workers.manager.PGWorker") as mock_worker_cls,
+patch("dataloader.workers.manager.requeue_lost") as mock_requeue,
+):
@@ -203,8 +211,10 @@ class TestBuildManagerFromEnv:
-with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
-patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg:
+with (
+patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
+patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
+):
@@ -224,8 +234,10 @@ class TestBuildManagerFromEnv:
-with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
-patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg:
+with (
+patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
+patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
+):
@@ -243,8 +255,10 @@ class TestBuildManagerFromEnv:
-with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
-patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg:
+with (
+patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
+patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
+):
@@ -262,8 +276,10 @@ class TestBuildManagerFromEnv:
-with patch("dataloader.workers.manager.APP_CTX") as mock_ctx, \
-patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg:
+with (
+patch("dataloader.workers.manager.APP_CTX") as mock_ctx,
+patch("dataloader.workers.manager.APP_CONFIG") as mock_cfg,
+):
@@ -1,8 +1,8 @@
-# tests/unit/test_workers_reaper.py
 from __future__ import annotations

+from unittest.mock import AsyncMock, Mock, patch

 import pytest
-from unittest.mock import AsyncMock, patch, Mock

 from dataloader.workers.reaper import requeue_lost