dataloader/src/dataloader/storage/repositories.py


# src/dataloader/storage/repositories.py
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any, Optional
from sqlalchemy import BigInteger, Text, select, func, update, DateTime
from sqlalchemy.dialects.postgresql import JSONB, ENUM, UUID
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped, mapped_column
from dataloader.storage.db import Base
dl_status_enum = ENUM(
"queued",
"running",
"succeeded",
"failed",
"canceled",
"lost",
name="dl_status",
create_type=False,
native_enum=True,
)
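# With create_type=False SQLAlchemy will not emit the ENUM type itself, so the
# dl_status type is expected to exist in the database already (e.g. created by a
# migration). A minimal sketch of the assumed DDL, matching the labels above:
#
#     CREATE TYPE dl_status AS ENUM
#         ('queued', 'running', 'succeeded', 'failed', 'canceled', 'lost');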
class DLJob(Base):
"""
Модель очереди dl_jobs.
"""
__tablename__ = "dl_jobs"
job_id: Mapped[str] = mapped_column(UUID(as_uuid=False), primary_key=True)
queue: Mapped[str] = mapped_column(Text, nullable=False)
task: Mapped[str] = mapped_column(Text, nullable=False)
args: Mapped[dict[str, Any]] = mapped_column(JSONB, default=dict, nullable=False)
idempotency_key: Mapped[Optional[str]] = mapped_column(Text, unique=True)
lock_key: Mapped[str] = mapped_column(Text, nullable=False)
partition_key: Mapped[str] = mapped_column(Text, default="", nullable=False)
priority: Mapped[int] = mapped_column(nullable=False, default=100)
available_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
status: Mapped[str] = mapped_column(dl_status_enum, nullable=False, default="queued")
attempt: Mapped[int] = mapped_column(nullable=False, default=0)
max_attempts: Mapped[int] = mapped_column(nullable=False, default=5)
lease_ttl_sec: Mapped[int] = mapped_column(nullable=False, default=60)
lease_expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True))
heartbeat_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True))
cancel_requested: Mapped[bool] = mapped_column(nullable=False, default=False)
progress: Mapped[dict[str, Any]] = mapped_column(JSONB, default=dict, nullable=False)
error: Mapped[Optional[str]] = mapped_column(Text)
producer: Mapped[Optional[str]] = mapped_column(Text)
consumer_group: Mapped[Optional[str]] = mapped_column(Text)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
started_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True))
finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True))
class DLJobEvent(Base):
"""
Модель журнала событий dl_job_events.
"""
__tablename__ = "dl_job_events"
event_id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
job_id: Mapped[str] = mapped_column(UUID(as_uuid=False), nullable=False)
queue: Mapped[str] = mapped_column(Text, nullable=False)
ts: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
kind: Mapped[str] = mapped_column(Text, nullable=False)
payload: Mapped[Optional[dict[str, Any]]] = mapped_column(JSONB)
@dataclass(frozen=True)
class CreateJobRequest:
"""
Параметры постановки задачи.
"""
job_id: str
queue: str
task: str
args: dict[str, Any]
idempotency_key: Optional[str]
lock_key: str
partition_key: str
priority: int
available_at: datetime
max_attempts: int
lease_ttl_sec: int
producer: Optional[str]
consumer_group: Optional[str]
@dataclass(frozen=True)
class JobStatus:
"""
Снимок статуса задачи.
"""
job_id: str
status: str
attempt: int
started_at: Optional[datetime]
finished_at: Optional[datetime]
heartbeat_at: Optional[datetime]
error: Optional[str]
progress: dict[str, Any]
class QueueRepository:
"""
Репозиторий очереди и событий с полнотой ORM.
"""
def __init__(self, session: AsyncSession):
self.s = session
async def create_or_get(self, req: CreateJobRequest) -> tuple[str, str]:
"""
Идемпотентно создаёт запись в очереди и возвращает (job_id, status).
"""
if req.idempotency_key:
q = select(DLJob).where(DLJob.idempotency_key == req.idempotency_key)
r = await self.s.execute(q)
ex = r.scalar_one_or_none()
if ex:
return ex.job_id, ex.status
row = DLJob(
job_id=req.job_id,
queue=req.queue,
task=req.task,
args=req.args or {},
idempotency_key=req.idempotency_key,
lock_key=req.lock_key,
partition_key=req.partition_key or "",
priority=req.priority,
available_at=req.available_at,
status="queued",
attempt=0,
max_attempts=req.max_attempts,
lease_ttl_sec=req.lease_ttl_sec,
lease_expires_at=None,
heartbeat_at=None,
cancel_requested=False,
progress={},
error=None,
producer=req.producer,
consumer_group=req.consumer_group,
created_at=datetime.now(timezone.utc),
started_at=None,
finished_at=None,
)
self.s.add(row)
await self._append_event(req.job_id, req.queue, "queued", {"task": req.task})
await self.s.commit()
return req.job_id, "queued"
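    # Usage sketch (illustrative, not part of this module): enqueuing a job through
    # create_or_get. The `async_session_factory` name, the `uuid` import and the
    # example values are assumptions.
    #
    #     import uuid
    #
    #     async with async_session_factory() as session:
    #         repo = QueueRepository(session)
    #         job_id, status = await repo.create_or_get(CreateJobRequest(
    #             job_id=str(uuid.uuid4()),
    #             queue="default",
    #             task="load_table",
    #             args={"source": "s3://bucket/key"},
    #             idempotency_key=None,
    #             lock_key="load_table:daily",
    #             partition_key="",
    #             priority=100,
    #             available_at=datetime.now(timezone.utc),
    #             max_attempts=5,
    #             lease_ttl_sec=60,
    #             producer="api",
    #             consumer_group=None,
    #         ))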
async def get_status(self, job_id: str) -> Optional[JobStatus]:
"""
Возвращает статус задачи.
"""
q = select(
DLJob.job_id,
DLJob.status,
DLJob.attempt,
DLJob.started_at,
DLJob.finished_at,
DLJob.heartbeat_at,
DLJob.error,
DLJob.progress,
).where(DLJob.job_id == job_id)
r = await self.s.execute(q)
m = r.first()
if not m:
return None
return JobStatus(
job_id=m.job_id,
status=m.status,
attempt=m.attempt,
started_at=m.started_at,
finished_at=m.finished_at,
heartbeat_at=m.heartbeat_at,
error=m.error,
progress=m.progress or {},
)
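    # Polling sketch (illustrative; the terminal-state set is an assumption based on
    # the dl_status labels): callers can poll get_status until the job settles.
    #
    #     TERMINAL = {"succeeded", "failed", "canceled"}
    #     st = await repo.get_status(job_id)
    #     if st is not None and st.status in TERMINAL:
    #         print(st.status, st.error)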
    async def cancel(self, job_id: str) -> bool:
        """
        Sets the cancellation flag on a job; returns False if the job does not exist.
        """
        q = update(DLJob).where(DLJob.job_id == job_id).values(cancel_requested=True)
        res = await self.s.execute(q)
        if res.rowcount == 0:
            return False
        await self._append_event(job_id, await self._resolve_queue(job_id), "cancel_requested", None)
        await self.s.commit()
        return True
async def claim_one(self, queue: str) -> Optional[dict[str, Any]]:
"""
Захватывает одну задачу из очереди с учётом блокировок и выставляет running.
"""
async with self.s.begin():
q = (
select(DLJob)
.where(
DLJob.status == "queued",
DLJob.queue == queue,
DLJob.available_at <= func.now(),
)
.order_by(DLJob.priority.asc(), DLJob.created_at.asc())
.with_for_update(skip_locked=True)
.limit(1)
)
r = await self.s.execute(q)
job: Optional[DLJob] = r.scalar_one_or_none()
if not job:
return None
job.status = "running"
job.started_at = job.started_at or datetime.now(timezone.utc)
job.attempt = int(job.attempt) + 1
job.heartbeat_at = datetime.now(timezone.utc)
job.lease_expires_at = datetime.now(timezone.utc) + timedelta(seconds=int(job.lease_ttl_sec))
ok = await self._try_advisory_lock(job.lock_key)
if not ok:
job.status = "queued"
job.available_at = datetime.now(timezone.utc) + timedelta(seconds=15)
return None
await self._append_event(job.job_id, job.queue, "picked", {"attempt": job.attempt})
return {
"job_id": job.job_id,
"queue": job.queue,
"task": job.task,
"args": job.args or {},
"lock_key": job.lock_key,
"partition_key": job.partition_key or "",
"lease_ttl_sec": int(job.lease_ttl_sec),
"attempt": int(job.attempt),
}
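    # Worker-loop sketch (illustrative; `async_session_factory`, `dispatch` and the
    # poll interval are assumptions, not part of this module). It ties together
    # claim_one, finish_ok and finish_fail_or_retry; a long-running task would also
    # call heartbeat(job_id, lease_ttl_sec) periodically to keep the lease alive.
    #
    #     async def worker(queue: str) -> None:
    #         while True:
    #             async with async_session_factory() as session:
    #                 repo = QueueRepository(session)
    #                 job = await repo.claim_one(queue)
    #                 if job is None:
    #                     await asyncio.sleep(1.0)
    #                     continue
    #                 try:
    #                     await dispatch(job["task"], job["args"])  # run the actual task
    #                     await repo.finish_ok(job["job_id"])
    #                 except Exception as exc:
    #                     await repo.finish_fail_or_retry(job["job_id"], str(exc))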
async def heartbeat(self, job_id: str, ttl_sec: int) -> None:
"""
Обновляет heartbeat и продлевает lease.
"""
now = datetime.now(timezone.utc)
q = (
update(DLJob)
.where(DLJob.job_id == job_id, DLJob.status == "running")
.values(heartbeat_at=now, lease_expires_at=now + timedelta(seconds=int(ttl_sec)))
)
await self.s.execute(q)
await self._append_event(job_id, await self._resolve_queue(job_id), "heartbeat", {"ttl": ttl_sec})
await self.s.commit()
async def finish_ok(self, job_id: str) -> None:
"""
Помечает задачу как выполненную успешно и снимает advisory-lock.
"""
job = await self._get(job_id)
if not job:
return
job.status = "succeeded"
job.finished_at = datetime.now(timezone.utc)
job.lease_expires_at = None
await self._append_event(job_id, job.queue, "succeeded", None)
await self._advisory_unlock(job.lock_key)
await self.s.commit()
async def finish_fail_or_retry(self, job_id: str, err: str) -> None:
"""
Помечает задачу как failed или возвращает в очередь с задержкой.
"""
job = await self._get(job_id)
if not job:
return
can_retry = int(job.attempt) < int(job.max_attempts)
if can_retry:
job.status = "queued"
job.available_at = datetime.now(timezone.utc) + timedelta(seconds=30 * int(job.attempt))
job.error = err
job.lease_expires_at = None
await self._append_event(job_id, job.queue, "requeue", {"attempt": job.attempt, "error": err})
else:
job.status = "failed"
job.error = err
job.finished_at = datetime.now(timezone.utc)
job.lease_expires_at = None
await self._append_event(job_id, job.queue, "failed", {"error": err})
await self._advisory_unlock(job.lock_key)
await self.s.commit()
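    # Retry delay grows linearly with the attempt counter: 30 s after the first
    # failed attempt, 60 s after the second, 90 s after the third, and so on,
    # until attempt reaches max_attempts and the job is marked failed.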
async def requeue_lost(self, now: Optional[datetime] = None) -> list[str]:
"""
Возвращает протухшие running-задачи в очередь.
"""
now = now or datetime.now(timezone.utc)
async with self.s.begin():
q = (
select(DLJob)
.where(
DLJob.status == "running",
DLJob.lease_expires_at.is_not(None),
DLJob.lease_expires_at < now,
)
.with_for_update(skip_locked=True)
)
r = await self.s.execute(q)
rows = list(r.scalars().all())
ids: list[str] = []
for job in rows:
job.status = "queued"
job.available_at = now
job.lease_expires_at = None
ids.append(job.job_id)
await self._append_event(job.job_id, job.queue, "requeue_lost", None)
return ids
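    # Reaper sketch (illustrative; `async_session_factory`, `log` and the interval
    # are assumptions): a background task can periodically reclaim lost jobs.
    #
    #     async def reaper() -> None:
    #         while True:
    #             async with async_session_factory() as session:
    #                 lost = await QueueRepository(session).requeue_lost()
    #                 if lost:
    #                     log.warning("requeued lost jobs: %s", lost)
    #             await asyncio.sleep(30)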
async def _get(self, job_id: str) -> Optional[DLJob]:
"""
Возвращает ORM-объект задачи.
"""
r = await self.s.execute(select(DLJob).where(DLJob.job_id == job_id).with_for_update(skip_locked=True))
return r.scalar_one_or_none()
async def _resolve_queue(self, job_id: str) -> str:
"""
Возвращает имя очереди для события.
"""
r = await self.s.execute(select(DLJob.queue).where(DLJob.job_id == job_id))
v = r.scalar_one_or_none()
return v or ""
async def _append_event(self, job_id: str, queue: str, kind: str, payload: Optional[dict[str, Any]]) -> None:
"""
Добавляет запись в журнал событий.
"""
ev = DLJobEvent(
job_id=job_id,
queue=queue or "",
ts=datetime.now(timezone.utc),
kind=kind,
payload=payload or None,
)
self.s.add(ev)
async def _try_advisory_lock(self, lock_key: str) -> bool:
"""
Пытается получить advisory-lock в Postgres.
"""
r = await self.s.execute(select(func.pg_try_advisory_lock(func.hashtext(lock_key))))
return bool(r.scalar())
async def _advisory_unlock(self, lock_key: str) -> None:
"""
Снимает advisory-lock в Postgres.
"""
await self.s.execute(select(func.pg_advisory_unlock(func.hashtext(lock_key))))
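    # Note: pg_try_advisory_lock/pg_advisory_unlock take session-level locks, so the
    # lock belongs to the database connection that acquired it and the unlock must
    # run on that same connection. hashtext() maps lock_key to a 32-bit integer, so
    # distinct keys can collide; colliding keys would simply serialize against each
    # other.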