# src/dataloader/storage/repositories.py
from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any, Optional

from sqlalchemy import BigInteger, String, Text, select, func, update, DateTime
from sqlalchemy.dialects.postgresql import JSONB, ENUM, UUID
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped, mapped_column

from dataloader.storage.db import Base
# Job lifecycle states for the dl_jobs queue.
# The PostgreSQL enum type "dl_status" is managed outside the ORM
# (create_type=False), so metadata.create_all() will not emit CREATE TYPE.
dl_status_enum = ENUM(
    "queued",
    "running",
    "succeeded",
    "failed",
    "canceled",
    "lost",
    name="dl_status",
    create_type=False,
    native_enum=True,
)
class DLJob(Base):
    """
    ORM model for the dl_jobs queue table.

    One row per job. Workers claim queued rows, hold a time-bounded
    lease renewed by heartbeats, and finish in a terminal status;
    jobs with expired leases are reclaimed by requeue_lost.
    """

    __tablename__ = "dl_jobs"

    # UUID primary key, stored as its string form (as_uuid=False).
    job_id: Mapped[str] = mapped_column(UUID(as_uuid=False), primary_key=True)
    # Logical queue the job belongs to.
    queue: Mapped[str] = mapped_column(Text, nullable=False)
    # Task name identifying the handler to run.
    task: Mapped[str] = mapped_column(Text, nullable=False)
    # JSON-serializable task arguments.
    args: Mapped[dict[str, Any]] = mapped_column(JSONB, default=dict, nullable=False)
    # Optional dedupe key (unique); used by QueueRepository.create_or_get.
    idempotency_key: Mapped[Optional[str]] = mapped_column(Text, unique=True)
    # Key of the Postgres advisory lock taken while the job runs.
    lock_key: Mapped[str] = mapped_column(Text, nullable=False)
    # Partitioning key; empty string when unused.
    partition_key: Mapped[str] = mapped_column(Text, default="", nullable=False)
    # Claim ordering: lower values are claimed first.
    priority: Mapped[int] = mapped_column(nullable=False, default=100)
    # Earliest time the job may be claimed; also used for retry backoff.
    available_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    # Lifecycle state (dl_status enum value).
    status: Mapped[str] = mapped_column(dl_status_enum, nullable=False, default="queued")
    # Number of claim attempts made so far (incremented on claim).
    attempt: Mapped[int] = mapped_column(nullable=False, default=0)
    # Retry limit; once attempt reaches this the job is marked failed.
    max_attempts: Mapped[int] = mapped_column(nullable=False, default=5)
    # Lease duration granted on claim, in seconds.
    lease_ttl_sec: Mapped[int] = mapped_column(nullable=False, default=60)
    # When the current lease expires; NULL when not running.
    lease_expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True))
    # Last heartbeat from the worker holding the lease.
    heartbeat_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True))
    # Cooperative-cancel flag set by QueueRepository.cancel.
    cancel_requested: Mapped[bool] = mapped_column(nullable=False, default=False)
    # Free-form progress payload reported by the worker.
    progress: Mapped[dict[str, Any]] = mapped_column(JSONB, default=dict, nullable=False)
    # Last recorded error message (set on retry/failure).
    error: Mapped[Optional[str]] = mapped_column(Text)
    # Optional identifier of the producing service.
    producer: Mapped[Optional[str]] = mapped_column(Text)
    # Optional consumer-group label.
    consumer_group: Mapped[Optional[str]] = mapped_column(Text)
    # Lifecycle timestamps; all timezone-aware.
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    started_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True))
    finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True))
class DLJobEvent(Base):
    """
    ORM model for the dl_job_events journal table.

    Append-only log of job lifecycle events ("queued", "picked",
    "heartbeat", "succeeded", "failed", "requeue", ...).
    """

    __tablename__ = "dl_job_events"

    # Auto-incrementing surrogate key.
    event_id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
    # Job the event belongs to (UUID stored as text).
    job_id: Mapped[str] = mapped_column(UUID(as_uuid=False), nullable=False)
    # Queue name copied from the job; "" when it could not be resolved.
    queue: Mapped[str] = mapped_column(Text, nullable=False)
    # Event timestamp (timezone-aware).
    ts: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    # Event kind, e.g. "queued", "picked", "heartbeat".
    kind: Mapped[str] = mapped_column(Text, nullable=False)
    # Optional JSON payload with event details.
    payload: Mapped[Optional[dict[str, Any]]] = mapped_column(JSONB)
@dataclass(frozen=True)
class CreateJobRequest:
    """
    Parameters for enqueueing a job via QueueRepository.create_or_get.
    """

    job_id: str                      # caller-supplied UUID string
    queue: str                       # target queue name
    task: str                        # task handler identifier
    args: dict[str, Any]             # JSON-serializable task arguments
    idempotency_key: Optional[str]   # dedupe key; None disables dedupe
    lock_key: str                    # advisory-lock key used while running
    partition_key: str               # partitioning key ("" when unused)
    priority: int                    # lower value = claimed earlier
    available_at: datetime           # earliest claim time (tz-aware)
    max_attempts: int                # retry limit
    lease_ttl_sec: int               # lease duration in seconds
    producer: Optional[str]          # producing service, if known
    consumer_group: Optional[str]    # consumer-group label, if any
@dataclass(frozen=True)
class JobStatus:
    """
    Read-only snapshot of a job's current state.
    """

    job_id: str
    status: str                        # dl_status enum value
    attempt: int                       # claim attempts made so far
    started_at: Optional[datetime]
    finished_at: Optional[datetime]
    heartbeat_at: Optional[datetime]   # last worker heartbeat, if any
    error: Optional[str]               # last recorded error, if any
    progress: dict[str, Any]           # worker-reported progress ({} if none)
class QueueRepository:
    """
    ORM-based repository for the job queue (dl_jobs) and its event
    journal (dl_job_events).

    All operations run on the AsyncSession passed to the constructor.
    Methods manage their own transactions: they either commit explicitly
    or open a transaction via ``session.begin()`` (claim_one,
    requeue_lost) — do not wrap calls in an externally managed
    transaction on the same session.
    """

    def __init__(self, session: AsyncSession):
        # The session is owned by the caller; this class drives its
        # transactions (commit / begin) itself.
        self.s = session

    async def create_or_get(self, req: CreateJobRequest) -> tuple[str, str]:
        """
        Idempotently create a queue entry and return ``(job_id, status)``.

        If ``req.idempotency_key`` is set and a job with that key already
        exists, returns the existing job's id and status without inserting.
        Otherwise inserts a new row in status "queued", logs a "queued"
        event, commits, and returns ``(req.job_id, "queued")``.
        """
        if req.idempotency_key:
            q = select(DLJob).where(DLJob.idempotency_key == req.idempotency_key)
            r = await self.s.execute(q)
            ex = r.scalar_one_or_none()
            if ex:
                return ex.job_id, ex.status

        # NOTE(review): a concurrent insert with the same idempotency_key
        # between the lookup above and the commit below will surface as a
        # unique-constraint IntegrityError — confirm callers handle it.
        row = DLJob(
            job_id=req.job_id,
            queue=req.queue,
            task=req.task,
            args=req.args or {},
            idempotency_key=req.idempotency_key,
            lock_key=req.lock_key,
            partition_key=req.partition_key or "",
            priority=req.priority,
            available_at=req.available_at,
            status="queued",
            attempt=0,
            max_attempts=req.max_attempts,
            lease_ttl_sec=req.lease_ttl_sec,
            lease_expires_at=None,
            heartbeat_at=None,
            cancel_requested=False,
            progress={},
            error=None,
            producer=req.producer,
            consumer_group=req.consumer_group,
            created_at=datetime.now(timezone.utc),
            started_at=None,
            finished_at=None,
        )
        self.s.add(row)
        await self._append_event(req.job_id, req.queue, "queued", {"task": req.task})
        await self.s.commit()
        return req.job_id, "queued"

    async def get_status(self, job_id: str) -> Optional[JobStatus]:
        """
        Return a JobStatus snapshot for ``job_id``, or None if unknown.
        """
        q = select(
            DLJob.job_id,
            DLJob.status,
            DLJob.attempt,
            DLJob.started_at,
            DLJob.finished_at,
            DLJob.heartbeat_at,
            DLJob.error,
            DLJob.progress,
        ).where(DLJob.job_id == job_id)
        r = await self.s.execute(q)
        m = r.first()
        if not m:
            return None
        return JobStatus(
            job_id=m.job_id,
            status=m.status,
            attempt=m.attempt,
            started_at=m.started_at,
            finished_at=m.finished_at,
            heartbeat_at=m.heartbeat_at,
            error=m.error,
            progress=m.progress or {},
        )

    async def cancel(self, job_id: str) -> bool:
        """
        Set the cooperative-cancel flag for a job and log the request.

        Always returns True — when the job does not exist the UPDATE
        simply matches no rows. Workers are expected to observe
        ``cancel_requested`` and stop on their own.
        """
        q = update(DLJob).where(DLJob.job_id == job_id).values(cancel_requested=True)
        await self.s.execute(q)
        await self._append_event(job_id, await self._resolve_queue(job_id), "cancel_requested", None)
        await self.s.commit()
        return True

    async def claim_one(self, queue: str) -> Optional[dict[str, Any]]:
        """
        Claim one due job from ``queue`` and mark it running.

        Selects the lowest-priority-value, oldest queued job whose
        ``available_at`` has passed, using FOR UPDATE SKIP LOCKED so
        concurrent workers never pick the same row. On success bumps
        ``attempt``, sets heartbeat and lease, takes the job's advisory
        lock, logs a "picked" event, and returns a plain dict describing
        the job. If the advisory lock is held elsewhere, the job is
        pushed back 15 seconds and None is returned. Returns None when
        nothing is claimable.
        """
        async with self.s.begin():
            q = (
                select(DLJob)
                .where(
                    DLJob.status == "queued",
                    DLJob.queue == queue,
                    DLJob.available_at <= func.now(),
                )
                .order_by(DLJob.priority.asc(), DLJob.created_at.asc())
                .with_for_update(skip_locked=True)
                .limit(1)
            )
            r = await self.s.execute(q)
            job: Optional[DLJob] = r.scalar_one_or_none()
            if not job:
                return None

            job.status = "running"
            # Preserve the first start time across retries.
            job.started_at = job.started_at or datetime.now(timezone.utc)
            job.attempt = int(job.attempt) + 1
            job.heartbeat_at = datetime.now(timezone.utc)
            job.lease_expires_at = datetime.now(timezone.utc) + timedelta(seconds=int(job.lease_ttl_sec))

            # Session-level advisory lock serializes jobs sharing a lock_key.
            ok = await self._try_advisory_lock(job.lock_key)
            if not ok:
                # Lock busy in another session: requeue with a short delay.
                job.status = "queued"
                job.available_at = datetime.now(timezone.utc) + timedelta(seconds=15)
                return None

            await self._append_event(job.job_id, job.queue, "picked", {"attempt": job.attempt})

            return {
                "job_id": job.job_id,
                "queue": job.queue,
                "task": job.task,
                "args": job.args or {},
                "lock_key": job.lock_key,
                "partition_key": job.partition_key or "",
                "lease_ttl_sec": int(job.lease_ttl_sec),
                "attempt": int(job.attempt),
            }

    async def heartbeat(self, job_id: str, ttl_sec: int) -> None:
        """
        Refresh the worker heartbeat and extend the lease by ``ttl_sec``.

        Only affects jobs currently in status "running"; also logs a
        "heartbeat" event and commits.
        """
        now = datetime.now(timezone.utc)
        q = (
            update(DLJob)
            .where(DLJob.job_id == job_id, DLJob.status == "running")
            .values(heartbeat_at=now, lease_expires_at=now + timedelta(seconds=int(ttl_sec)))
        )
        await self.s.execute(q)
        await self._append_event(job_id, await self._resolve_queue(job_id), "heartbeat", {"ttl": ttl_sec})
        await self.s.commit()

    async def finish_ok(self, job_id: str) -> None:
        """
        Mark a job as succeeded and release its advisory lock.

        No-op when the job is missing or its row is currently locked by
        another transaction (_get uses SKIP LOCKED).
        """
        job = await self._get(job_id)
        if not job:
            return
        job.status = "succeeded"
        job.finished_at = datetime.now(timezone.utc)
        job.lease_expires_at = None
        await self._append_event(job_id, job.queue, "succeeded", None)
        await self._advisory_unlock(job.lock_key)
        await self.s.commit()

    async def finish_fail_or_retry(self, job_id: str, err: str) -> None:
        """
        Record a failed attempt: requeue with backoff, or mark failed.

        While ``attempt < max_attempts`` the job returns to "queued"
        with a linear ``30s * attempt`` delay; otherwise it becomes
        terminal "failed". In both cases the error message is stored
        and the advisory lock is released.
        """
        job = await self._get(job_id)
        if not job:
            return
        can_retry = int(job.attempt) < int(job.max_attempts)
        if can_retry:
            job.status = "queued"
            # Linear backoff: 30 seconds per attempt already made.
            job.available_at = datetime.now(timezone.utc) + timedelta(seconds=30 * int(job.attempt))
            job.error = err
            job.lease_expires_at = None
            await self._append_event(job_id, job.queue, "requeue", {"attempt": job.attempt, "error": err})
        else:
            job.status = "failed"
            job.error = err
            job.finished_at = datetime.now(timezone.utc)
            job.lease_expires_at = None
            await self._append_event(job_id, job.queue, "failed", {"error": err})
        await self._advisory_unlock(job.lock_key)
        await self.s.commit()

    async def requeue_lost(self, now: Optional[datetime] = None) -> list[str]:
        """
        Return expired "running" jobs to the queue.

        Finds running rows whose ``lease_expires_at`` is in the past,
        resets them to "queued" and immediately available, logs a
        "requeue_lost" event for each, and returns the affected job ids.
        ``now`` defaults to the current UTC time.
        """
        now = now or datetime.now(timezone.utc)
        async with self.s.begin():
            q = (
                select(DLJob)
                .where(
                    DLJob.status == "running",
                    DLJob.lease_expires_at.is_not(None),
                    DLJob.lease_expires_at < now,
                )
                .with_for_update(skip_locked=True)
            )
            r = await self.s.execute(q)
            rows = list(r.scalars().all())
            ids: list[str] = []
            for job in rows:
                job.status = "queued"
                job.available_at = now
                job.lease_expires_at = None
                ids.append(job.job_id)
                await self._append_event(job.job_id, job.queue, "requeue_lost", None)
            return ids

    async def _get(self, job_id: str) -> Optional[DLJob]:
        """
        Load one job row with FOR UPDATE SKIP LOCKED.

        Returns None when the row does not exist or is currently locked
        by another transaction.
        """
        r = await self.s.execute(select(DLJob).where(DLJob.job_id == job_id).with_for_update(skip_locked=True))
        return r.scalar_one_or_none()

    async def _resolve_queue(self, job_id: str) -> str:
        """
        Resolve the queue name for event logging; "" when unknown.
        """
        r = await self.s.execute(select(DLJob.queue).where(DLJob.job_id == job_id))
        v = r.scalar_one_or_none()
        return v or ""

    async def _append_event(self, job_id: str, queue: str, kind: str, payload: Optional[dict[str, Any]]) -> None:
        """
        Stage a row for the dl_job_events journal (caller commits).
        """
        ev = DLJobEvent(
            job_id=job_id,
            queue=queue or "",
            ts=datetime.now(timezone.utc),
            kind=kind,
            # Empty payloads are normalized to NULL.
            payload=payload or None,
        )
        self.s.add(ev)

    async def _try_advisory_lock(self, lock_key: str) -> bool:
        """
        Try to take a Postgres advisory lock keyed by hashtext(lock_key).

        Non-blocking: returns True when acquired, False when another
        session already holds it.
        """
        r = await self.s.execute(select(func.pg_try_advisory_lock(func.hashtext(lock_key))))
        return bool(r.scalar())

    async def _advisory_unlock(self, lock_key: str) -> None:
        """
        Release the advisory lock keyed by hashtext(lock_key).
        """
        await self.s.execute(select(func.pg_advisory_unlock(func.hashtext(lock_key))))