# 124 lines · 3.2 KiB · Python
"""Share repository for database operations."""
|
|
|
|
import secrets
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Optional
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.domain.models import Share
|
|
|
|
|
|
class ShareRepository:
    """Repository for share database operations.

    Wraps an ``AsyncSession`` and provides create / lookup / revoke helpers
    for ``Share`` rows. Methods ``flush`` but never ``commit``; presumably the
    caller that owns the session is responsible for committing.
    """

    def __init__(self, session: AsyncSession):
        """
        Initialize share repository.

        Args:
            session: Async database session used for all queries and writes
        """
        self.session = session

    def _generate_token(self) -> str:
        """Generate a secure random, URL-safe token (32 random bytes)."""
        return secrets.token_urlsafe(32)

    async def create(
        self,
        owner_user_id: str,
        asset_id: Optional[str] = None,
        album_id: Optional[str] = None,
        expires_in_seconds: Optional[int] = None,
        password_hash: Optional[str] = None,
    ) -> Share:
        """
        Create a new share link.

        Args:
            owner_user_id: Owner user ID
            asset_id: Optional asset ID
            album_id: Optional album ID
            expires_in_seconds: Lifetime of the link in seconds, counted from
                now (UTC). ``None`` means the link never expires; ``0`` or a
                negative value produces a link that is already expired.
            password_hash: Optional password hash, stored exactly as given
                (no hashing is performed here)

        Returns:
            Created share instance, flushed and refreshed from the database
        """
        # NOTE(review): neither asset_id nor album_id is required here; confirm
        # whether callers must supply exactly one of them.
        token = self._generate_token()

        # Compare against None explicitly: a falsy 0 must not silently become
        # a never-expiring (fail-open) link.
        expires_at: Optional[datetime] = None
        if expires_in_seconds is not None:
            expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in_seconds)

        share = Share(
            owner_user_id=owner_user_id,
            asset_id=asset_id,
            album_id=album_id,
            token=token,
            expires_at=expires_at,
            password_hash=password_hash,
        )
        self.session.add(share)
        # Flush to send the INSERT, then refresh so database-populated columns
        # are loaded onto the instance.
        await self.session.flush()
        await self.session.refresh(share)
        return share

    async def get_by_id(self, share_id: str) -> Optional[Share]:
        """
        Get share by ID.

        Args:
            share_id: Share ID

        Returns:
            Share instance or None if not found
        """
        result = await self.session.execute(select(Share).where(Share.id == share_id))
        return result.scalar_one_or_none()

    async def get_by_token(self, token: str) -> Optional[Share]:
        """
        Get share by token.

        Note that this does not check validity; use ``is_valid`` on the result.

        Args:
            token: Share token

        Returns:
            Share instance or None if not found
        """
        result = await self.session.execute(select(Share).where(Share.token == token))
        return result.scalar_one_or_none()

    async def revoke(self, share: Share) -> Share:
        """
        Revoke a share link.

        Revoking an already-revoked share keeps its original revocation time.

        Args:
            share: Share to revoke

        Returns:
            Updated share, flushed and refreshed from the database
        """
        # Don't overwrite an existing timestamp: repeated revokes should not
        # rewrite when the share was actually revoked.
        if share.revoked_at is None:
            share.revoked_at = datetime.now(timezone.utc)
        await self.session.flush()
        await self.session.refresh(share)
        return share

    def is_valid(self, share: Share) -> bool:
        """
        Check if a share is valid (not revoked and not expired).

        A share is considered expired from the instant ``expires_at`` is
        reached onward.

        Args:
            share: Share to check

        Returns:
            True if valid, False otherwise
        """
        if share.revoked_at is not None:
            return False

        expires_at = share.expires_at
        if expires_at is None:
            return True

        # Some backends (e.g. SQLite, which stores no UTC offset) return naive
        # datetimes; comparing naive with aware raises TypeError. This
        # repository only writes UTC, so interpret naive values as UTC.
        if expires_at.tzinfo is None:
            expires_at = expires_at.replace(tzinfo=timezone.utc)

        return datetime.now(timezone.utc) < expires_at
|