sec fixes v2
This commit is contained in:
parent
c17b42dd45
commit
6fb04036c3
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,131 @@
|
||||||
|
# Security Fixes TODO
|
||||||
|
|
||||||
|
Статус исправлений критических уязвимостей из аудита безопасности.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 🔴 КРИТИЧЕСКИЕ (В РАБОТЕ)
|
||||||
|
|
||||||
|
### Безопасность API
|
||||||
|
- [ ] **Rate Limiting** - защита от brute force атак
|
||||||
|
- [ ] Установить slowapi
|
||||||
|
- [ ] Добавить rate limiting на login (5/минуту)
|
||||||
|
- [ ] Добавить rate limiting на register (3/час)
|
||||||
|
- [ ] Добавить rate limiting на uploads (100/час)
|
||||||
|
|
||||||
|
### Аутентификация
|
||||||
|
- [ ] **Token Revocation** - logout должен удалять токен
|
||||||
|
- [ ] Создать Redis client для blacklist
|
||||||
|
- [ ] Добавить проверку blacklist в get_current_user
|
||||||
|
- [ ] Реализовать endpoint /auth/logout
|
||||||
|
|
||||||
|
- [ ] **Refresh Token Rotation** - обновление токенов
|
||||||
|
- [ ] Endpoint /auth/refresh
|
||||||
|
- [ ] Отзыв старого refresh token при rotation
|
||||||
|
- [ ] Frontend interceptor для автообновления
|
||||||
|
|
||||||
|
- [ ] **Account Lockout** - блокировка после 3 попыток на сутки
|
||||||
|
- [ ] LoginAttemptTracker с Redis
|
||||||
|
- [ ] Проверка блокировки перед login
|
||||||
|
- [ ] Запись неудачных попыток
|
||||||
|
|
||||||
|
### Storage Management
|
||||||
|
- [ ] **Storage Quota** - 3GB по пользователю
|
||||||
|
- [ ] Миграция: добавить storage_quota_bytes и storage_used_bytes в User
|
||||||
|
- [ ] Проверка квоты при create_upload
|
||||||
|
- [ ] Увеличение used_bytes при finalize_upload
|
||||||
|
- [ ] Уменьшение used_bytes при delete (S3 trash)
|
||||||
|
- [ ] Endpoint GET /users/me/storage для статистики
|
||||||
|
|
||||||
|
### Валидация файлов
|
||||||
|
- [ ] **Content-Type Whitelist** - только разрешенные типы
|
||||||
|
- [ ] ALLOWED_IMAGE_TYPES whitelist
|
||||||
|
- [ ] ALLOWED_VIDEO_TYPES whitelist
|
||||||
|
- [ ] Обновить validator в schemas.py
|
||||||
|
|
||||||
|
- [ ] **Magic Bytes Verification** - проверка реального типа файла
|
||||||
|
- [ ] Установить python-magic
|
||||||
|
- [ ] Проверка при finalize_upload
|
||||||
|
- [ ] Удаление файла из S3 при несовпадении
|
||||||
|
|
||||||
|
### File Upload Security
|
||||||
|
- [ ] **Streaming Chunks** - предотвратить OOM
|
||||||
|
- [ ] Метод upload_fileobj_streaming в S3Client
|
||||||
|
- [ ] Обновить upload_file_to_s3 для стриминга
|
||||||
|
- [ ] Проверка размера ПЕРЕД чтением (макс 3GB)
|
||||||
|
|
||||||
|
### ZIP Download
|
||||||
|
- [ ] **ZIP Streaming** - не держать весь архив в памяти
|
||||||
|
- [ ] Создание ZIP в temp файле
|
||||||
|
- [ ] FileResponse вместо возврата bytes
|
||||||
|
- [ ] BackgroundTask для удаления temp файла
|
||||||
|
|
||||||
|
### Configuration
|
||||||
|
- [ ] **Trash Bucket Config** - убрать hardcode
|
||||||
|
- [ ] Добавить TRASH_BUCKET в config.py
|
||||||
|
- [ ] Использовать из конфига в S3Client
|
||||||
|
|
||||||
|
### HTTP Security
|
||||||
|
- [ ] **Security Headers** - защита от XSS, clickjacking
|
||||||
|
- [ ] SecurityHeadersMiddleware
|
||||||
|
- [ ] X-Frame-Options, X-Content-Type-Options
|
||||||
|
- [ ] CSP, HSTS, Referrer-Policy
|
||||||
|
|
||||||
|
### Architecture
|
||||||
|
- [ ] **FolderService Refactoring** - разделить обязанности
|
||||||
|
- [ ] Убрать прямую работу с AssetRepository
|
||||||
|
- [ ] Создать FolderManagementService для оркестрации
|
||||||
|
- [ ] Использовать AssetService methods
|
||||||
|
|
||||||
|
### Search
|
||||||
|
- [ ] **Asset Search** - поиск по имени файла
|
||||||
|
- [ ] Метод search_assets в AssetRepository
|
||||||
|
- [ ] Endpoint GET /assets/search
|
||||||
|
- [ ] Поддержка ILIKE для поиска
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 🟡 TODO (ОТЛОЖЕНО НА ПОТОМ)
|
||||||
|
|
||||||
|
- [ ] JWT Secret Validation - проверка слабого секрета в продакшене
|
||||||
|
- [ ] S3 Encryption at Rest - ServerSideEncryption='AES256'
|
||||||
|
- [ ] Share Token Uniqueness Check - collision detection
|
||||||
|
- [ ] CSRF Protection - fastapi-csrf-protect
|
||||||
|
- [ ] Strong Password Validation - complexity requirements
|
||||||
|
- [ ] Database Indexes - composite indexes для производительности
|
||||||
|
- [ ] Foreign Keys - relationships в models
|
||||||
|
- [ ] Password Reset Flow - forgot/reset endpoints + email
|
||||||
|
- [ ] EXIF Metadata Extraction - captured_at, dimensions
|
||||||
|
- [ ] Database Backups - автоматизация бэкапов
|
||||||
|
- [ ] Comprehensive Testing - 70%+ coverage
|
||||||
|
- [ ] Monitoring & Logging - structured logging, metrics
|
||||||
|
- [ ] CI/CD Pipeline - GitHub Actions
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 📊 Progress
|
||||||
|
|
||||||
|
**Выполнено:** 14/14 критических задач ✅✅✅
|
||||||
|
**Статус:** 🎉 ВСЕ ЗАДАЧИ ЗАВЕРШЕНЫ!
|
||||||
|
|
||||||
|
### ✅ Завершено:
|
||||||
|
1. ✅ slowapi добавлен в pyproject.toml
|
||||||
|
2. ✅ Config обновлен (trash_bucket, max 3GB, default quota)
|
||||||
|
3. ✅ Redis client создан (TokenBlacklist, LoginAttemptTracker)
|
||||||
|
4. ✅ main.py: Rate limiting, Security Headers, CORS restrictive
|
||||||
|
5. ✅ User model: storage_quota_bytes, storage_used_bytes
|
||||||
|
6. ✅ Миграция 002: add_storage_quota создана
|
||||||
|
7. ✅ schemas.py: Content-Type whitelist, 3GB max, RefreshTokenRequest, StorageStatsResponse
|
||||||
|
8. ✅ dependencies.py: blacklist check в get_current_user
|
||||||
|
9. ✅ auth.py: rate limiting на endpoints, logout, refresh, account lockout, storage stats
|
||||||
|
10. ✅ S3Client: streaming upload (upload_fileobj_streaming), trash_bucket from config
|
||||||
|
11. ✅ asset_service: storage quota check, streaming (upload_fileobj_streaming), magic bytes verification, storage_used_bytes updates
|
||||||
|
12. ✅ batch_operations: ZIP streaming (temp file + FileResponse + BackgroundTasks cleanup), storage_used_bytes updates
|
||||||
|
13. ✅ FolderService: refactored (removed asset modification, only read-only validation queries)
|
||||||
|
14. ✅ AssetRepository + AssetService + API: search_assets method (ILIKE), GET /api/v1/assets/search endpoint
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Дата начала:** 2026-01-05
|
||||||
|
**Последнее обновление:** 2026-01-05
|
||||||
|
**Финальный статус:** 🎉 14/14 COMPLETED
|
||||||
|
|
@ -0,0 +1,37 @@
|
||||||
|
"""add storage quota
|
||||||
|
|
||||||
|
Revision ID: 002
|
||||||
|
Revises: 001
|
||||||
|
Create Date: 2026-01-05 12:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '002'
|
||||||
|
down_revision: Union[str, None] = '001'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Add storage quota fields to users table."""
|
||||||
|
# Add storage_quota_bytes column (default 3GB)
|
||||||
|
op.add_column('users', sa.Column('storage_quota_bytes', sa.BigInteger(), nullable=False, server_default='3221225472'))
|
||||||
|
|
||||||
|
# Add storage_used_bytes column (default 0)
|
||||||
|
op.add_column('users', sa.Column('storage_used_bytes', sa.BigInteger(), nullable=False, server_default='0'))
|
||||||
|
|
||||||
|
# Create index on storage_used_bytes for quota queries
|
||||||
|
op.create_index('ix_users_storage_used_bytes', 'users', ['storage_used_bytes'])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Remove storage quota fields from users table."""
|
||||||
|
op.drop_index('ix_users_storage_used_bytes', table_name='users')
|
||||||
|
op.drop_column('users', 'storage_used_bytes')
|
||||||
|
op.drop_column('users', 'storage_quota_bytes')
|
||||||
|
|
@ -29,6 +29,7 @@ ffmpeg-python = "^0.2.0"
|
||||||
loguru = "^0.7.2"
|
loguru = "^0.7.2"
|
||||||
httpx = "^0.26.0"
|
httpx = "^0.26.0"
|
||||||
cryptography = "^46.0.3"
|
cryptography = "^46.0.3"
|
||||||
|
slowapi = "^0.1.9"
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
pytest = "^7.4.4"
|
pytest = "^7.4.4"
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.domain.models import User
|
from app.domain.models import User
|
||||||
from app.infra.database import get_db
|
from app.infra.database import get_db
|
||||||
|
from app.infra.redis_client import TokenBlacklist, get_token_blacklist
|
||||||
from app.infra.s3_client import S3Client, get_s3_client
|
from app.infra.s3_client import S3Client, get_s3_client
|
||||||
from app.infra.security import decode_access_token, get_subject
|
from app.infra.security import decode_access_token, get_subject
|
||||||
from app.repositories.user_repository import UserRepository
|
from app.repositories.user_repository import UserRepository
|
||||||
|
|
@ -18,6 +19,7 @@ security = HTTPBearer()
|
||||||
async def get_current_user(
|
async def get_current_user(
|
||||||
credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)],
|
credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)],
|
||||||
session: Annotated[AsyncSession, Depends(get_db)],
|
session: Annotated[AsyncSession, Depends(get_db)],
|
||||||
|
blacklist: Annotated[TokenBlacklist, Depends(get_token_blacklist)],
|
||||||
) -> User:
|
) -> User:
|
||||||
"""
|
"""
|
||||||
Get current authenticated user from JWT token.
|
Get current authenticated user from JWT token.
|
||||||
|
|
@ -25,6 +27,7 @@ async def get_current_user(
|
||||||
Args:
|
Args:
|
||||||
credentials: HTTP authorization credentials
|
credentials: HTTP authorization credentials
|
||||||
session: Database session
|
session: Database session
|
||||||
|
blacklist: Token blacklist for revocation check
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Current user
|
Current user
|
||||||
|
|
@ -41,6 +44,14 @@ async def get_current_user(
|
||||||
detail="Invalid authentication credentials",
|
detail="Invalid authentication credentials",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check if token is revoked (blacklisted)
|
||||||
|
jti = payload.get("jti")
|
||||||
|
if jti and await blacklist.is_revoked(jti):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Token has been revoked",
|
||||||
|
)
|
||||||
|
|
||||||
user_id = get_subject(payload)
|
user_id = get_subject(payload)
|
||||||
if not user_id:
|
if not user_id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,27 @@ from pydantic import BaseModel, EmailStr, Field, field_validator
|
||||||
from app.domain.models import AssetStatus, AssetType
|
from app.domain.models import AssetStatus, AssetType
|
||||||
|
|
||||||
|
|
||||||
|
# Allowed MIME types (whitelist)
|
||||||
|
ALLOWED_IMAGE_TYPES = {
|
||||||
|
"image/jpeg",
|
||||||
|
"image/jpg",
|
||||||
|
"image/png",
|
||||||
|
"image/gif",
|
||||||
|
"image/webp",
|
||||||
|
"image/heic",
|
||||||
|
"image/heif",
|
||||||
|
}
|
||||||
|
|
||||||
|
ALLOWED_VIDEO_TYPES = {
|
||||||
|
"video/mp4",
|
||||||
|
"video/mpeg",
|
||||||
|
"video/quicktime", # .mov
|
||||||
|
"video/x-msvideo", # .avi
|
||||||
|
"video/x-matroska", # .mkv
|
||||||
|
"video/webm",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# Auth schemas
|
# Auth schemas
|
||||||
class UserRegister(BaseModel):
|
class UserRegister(BaseModel):
|
||||||
"""User registration request."""
|
"""User registration request."""
|
||||||
|
|
@ -31,6 +52,12 @@ class Token(BaseModel):
|
||||||
token_type: str = "bearer"
|
token_type: str = "bearer"
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshTokenRequest(BaseModel):
|
||||||
|
"""Request to refresh access token."""
|
||||||
|
|
||||||
|
refresh_token: str
|
||||||
|
|
||||||
|
|
||||||
class UserResponse(BaseModel):
|
class UserResponse(BaseModel):
|
||||||
"""User information response."""
|
"""User information response."""
|
||||||
|
|
||||||
|
|
@ -42,6 +69,15 @@ class UserResponse(BaseModel):
|
||||||
model_config = {"from_attributes": True}
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class StorageStatsResponse(BaseModel):
|
||||||
|
"""Storage usage statistics."""
|
||||||
|
|
||||||
|
quota_bytes: int
|
||||||
|
used_bytes: int
|
||||||
|
available_bytes: int
|
||||||
|
percentage_used: float
|
||||||
|
|
||||||
|
|
||||||
# Asset schemas
|
# Asset schemas
|
||||||
class AssetResponse(BaseModel):
|
class AssetResponse(BaseModel):
|
||||||
"""Asset information response."""
|
"""Asset information response."""
|
||||||
|
|
@ -79,15 +115,21 @@ class CreateUploadRequest(BaseModel):
|
||||||
|
|
||||||
original_filename: str = Field(max_length=512)
|
original_filename: str = Field(max_length=512)
|
||||||
content_type: str = Field(max_length=100)
|
content_type: str = Field(max_length=100)
|
||||||
size_bytes: int = Field(gt=0, le=21474836480) # Max 20GB
|
size_bytes: int = Field(gt=0, le=3221225472) # Max 3GB
|
||||||
folder_id: Optional[str] = Field(None, max_length=36)
|
folder_id: Optional[str] = Field(None, max_length=36)
|
||||||
|
|
||||||
@field_validator("content_type")
|
@field_validator("content_type")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_content_type(cls, v: str) -> str:
|
def validate_content_type(cls, v: str) -> str:
|
||||||
"""Validate content_type is image or video."""
|
"""Validate content_type against whitelist."""
|
||||||
if not (v.startswith("image/") or v.startswith("video/")):
|
v = v.lower().strip()
|
||||||
raise ValueError("Only image/* and video/* content types are supported")
|
|
||||||
|
if v not in ALLOWED_IMAGE_TYPES and v not in ALLOWED_VIDEO_TYPES:
|
||||||
|
allowed = ", ".join(sorted(ALLOWED_IMAGE_TYPES | ALLOWED_VIDEO_TYPES))
|
||||||
|
raise ValueError(
|
||||||
|
f"Content type '{v}' not supported. Allowed: {allowed}"
|
||||||
|
)
|
||||||
|
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -56,6 +56,50 @@ async def list_assets(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/search", response_model=AssetListResponse)
|
||||||
|
async def search_assets(
|
||||||
|
current_user: CurrentUser,
|
||||||
|
session: DatabaseSession,
|
||||||
|
s3_client: S3ClientDep,
|
||||||
|
q: str = Query(..., min_length=1, description="Search query"),
|
||||||
|
cursor: Optional[str] = Query(None),
|
||||||
|
limit: int = Query(50, ge=1, le=200),
|
||||||
|
type: Optional[AssetType] = Query(None),
|
||||||
|
folder_id: Optional[str] = Query(None),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Search assets by filename with pagination.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current_user: Current authenticated user
|
||||||
|
session: Database session
|
||||||
|
s3_client: S3 client
|
||||||
|
q: Search query string
|
||||||
|
cursor: Pagination cursor
|
||||||
|
limit: Maximum number of results
|
||||||
|
type: Filter by asset type
|
||||||
|
folder_id: Filter by folder (None for search across all folders)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Paginated list of matching assets
|
||||||
|
"""
|
||||||
|
asset_service = AssetService(session, s3_client)
|
||||||
|
assets, next_cursor, has_more = await asset_service.search_assets(
|
||||||
|
user_id=current_user.id,
|
||||||
|
search_query=q,
|
||||||
|
limit=limit,
|
||||||
|
cursor=cursor,
|
||||||
|
asset_type=type,
|
||||||
|
folder_id=folder_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return AssetListResponse(
|
||||||
|
items=assets,
|
||||||
|
next_cursor=next_cursor,
|
||||||
|
has_more=has_more,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{asset_id}", response_model=AssetResponse)
|
@router.get("/{asset_id}", response_model=AssetResponse)
|
||||||
async def get_asset(
|
async def get_asset(
|
||||||
asset_id: str,
|
asset_id: str,
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,44 @@
|
||||||
"""Authentication API routes."""
|
"""Authentication API routes."""
|
||||||
|
|
||||||
from fastapi import APIRouter, status
|
from datetime import datetime
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||||
|
|
||||||
from app.api.dependencies import CurrentUser, DatabaseSession
|
from app.api.dependencies import CurrentUser, DatabaseSession
|
||||||
from app.api.schemas import Token, UserLogin, UserRegister, UserResponse
|
from app.api.schemas import (
|
||||||
|
RefreshTokenRequest,
|
||||||
|
StorageStatsResponse,
|
||||||
|
Token,
|
||||||
|
UserLogin,
|
||||||
|
UserRegister,
|
||||||
|
UserResponse,
|
||||||
|
)
|
||||||
|
from app.infra.redis_client import (
|
||||||
|
LoginAttemptTracker,
|
||||||
|
TokenBlacklist,
|
||||||
|
get_login_tracker,
|
||||||
|
get_token_blacklist,
|
||||||
|
)
|
||||||
|
from app.infra.security import decode_refresh_token, get_subject
|
||||||
|
from app.main import limiter
|
||||||
from app.services.auth_service import AuthService
|
from app.services.auth_service import AuthService
|
||||||
|
|
||||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
security = HTTPBearer()
|
||||||
|
|
||||||
|
|
||||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||||
async def register(data: UserRegister, session: DatabaseSession):
|
@limiter.limit("3/hour")
|
||||||
|
async def register(request: Request, data: UserRegister, session: DatabaseSession):
|
||||||
"""
|
"""
|
||||||
Register a new user.
|
Register a new user.
|
||||||
|
|
||||||
|
Rate limit: 3 requests per hour.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
request: HTTP request (for rate limiting)
|
||||||
data: Registration data
|
data: Registration data
|
||||||
session: Database session
|
session: Database session
|
||||||
|
|
||||||
|
|
@ -27,24 +51,172 @@ async def register(data: UserRegister, session: DatabaseSession):
|
||||||
|
|
||||||
|
|
||||||
@router.post("/login", response_model=Token)
|
@router.post("/login", response_model=Token)
|
||||||
async def login(data: UserLogin, session: DatabaseSession):
|
@limiter.limit("5/minute")
|
||||||
|
async def login(
|
||||||
|
request: Request,
|
||||||
|
data: UserLogin,
|
||||||
|
session: DatabaseSession,
|
||||||
|
tracker: Annotated[LoginAttemptTracker, Depends(get_login_tracker)],
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Authenticate user and get access tokens.
|
Authenticate user and get access tokens.
|
||||||
|
|
||||||
|
Rate limit: 5 requests per minute.
|
||||||
|
Account lockout: 3 failed attempts = 24 hour block.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
request: HTTP request (for rate limiting and IP tracking)
|
||||||
data: Login credentials
|
data: Login credentials
|
||||||
session: Database session
|
session: Database session
|
||||||
|
tracker: Login attempt tracker
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Access and refresh tokens
|
Access and refresh tokens
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If account is locked or credentials are invalid
|
||||||
"""
|
"""
|
||||||
|
# Get client IP
|
||||||
|
client_ip = request.client.host if request.client else "unknown"
|
||||||
|
|
||||||
|
# Check if IP is locked out
|
||||||
|
if await tracker.is_locked(client_ip):
|
||||||
|
remaining = await tracker.get_lockout_remaining(client_ip)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||||
|
detail=f"Account locked due to too many failed attempts. "
|
||||||
|
f"Try again in {remaining // 60} minutes.",
|
||||||
|
)
|
||||||
|
|
||||||
auth_service = AuthService(session)
|
auth_service = AuthService(session)
|
||||||
access_token, refresh_token = await auth_service.login(
|
|
||||||
email=data.email, password=data.password
|
try:
|
||||||
)
|
access_token, refresh_token = await auth_service.login(
|
||||||
|
email=data.email, password=data.password
|
||||||
|
)
|
||||||
|
|
||||||
|
# Successful login - clear failed attempts
|
||||||
|
await tracker.clear_attempts(client_ip)
|
||||||
|
|
||||||
|
return Token(access_token=access_token, refresh_token=refresh_token)
|
||||||
|
|
||||||
|
except HTTPException as e:
|
||||||
|
# Record failed attempt if it was authentication failure
|
||||||
|
if e.status_code == status.HTTP_401_UNAUTHORIZED:
|
||||||
|
await tracker.record_failed_attempt(client_ip)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/refresh", response_model=Token)
|
||||||
|
@limiter.limit("10/minute")
|
||||||
|
async def refresh_token(
|
||||||
|
request: Request,
|
||||||
|
data: RefreshTokenRequest,
|
||||||
|
session: DatabaseSession,
|
||||||
|
blacklist: Annotated[TokenBlacklist, Depends(get_token_blacklist)],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Refresh access token using refresh token.
|
||||||
|
|
||||||
|
Implements refresh token rotation - old refresh token is revoked.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: HTTP request (for rate limiting)
|
||||||
|
data: Refresh token request
|
||||||
|
session: Database session
|
||||||
|
blacklist: Token blacklist
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
New access and refresh tokens
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If refresh token is invalid or revoked
|
||||||
|
"""
|
||||||
|
# Decode refresh token
|
||||||
|
payload = decode_refresh_token(data.refresh_token)
|
||||||
|
|
||||||
|
if not payload:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid refresh token",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if token is revoked
|
||||||
|
jti = payload.get("jti")
|
||||||
|
if jti and await blacklist.is_revoked(jti):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Refresh token has been revoked",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get user ID
|
||||||
|
user_id = get_subject(payload)
|
||||||
|
if not user_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid refresh token",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify user exists and is active
|
||||||
|
auth_service = AuthService(session)
|
||||||
|
user = await auth_service.get_user_by_id(user_id)
|
||||||
|
|
||||||
|
if not user or not user.is_active:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="User not found or inactive",
|
||||||
|
)
|
||||||
|
|
||||||
|
# IMPORTANT: Revoke old refresh token (rotation)
|
||||||
|
if jti:
|
||||||
|
exp = payload.get("exp")
|
||||||
|
if exp:
|
||||||
|
ttl = exp - int(datetime.utcnow().timestamp())
|
||||||
|
if ttl > 0:
|
||||||
|
await blacklist.revoke_token(jti, ttl)
|
||||||
|
|
||||||
|
# Generate new tokens
|
||||||
|
access_token, refresh_token = await auth_service.create_tokens_for_user(user)
|
||||||
|
|
||||||
return Token(access_token=access_token, refresh_token=refresh_token)
|
return Token(access_token=access_token, refresh_token=refresh_token)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/logout")
|
||||||
|
@limiter.limit("10/minute")
|
||||||
|
async def logout(
|
||||||
|
request: Request,
|
||||||
|
credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)],
|
||||||
|
blacklist: Annotated[TokenBlacklist, Depends(get_token_blacklist)],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Logout user by revoking current access token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: HTTP request (for rate limiting)
|
||||||
|
credentials: Authorization credentials
|
||||||
|
blacklist: Token blacklist
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Success message
|
||||||
|
"""
|
||||||
|
from app.infra.security import decode_access_token
|
||||||
|
|
||||||
|
token = credentials.credentials
|
||||||
|
payload = decode_access_token(token)
|
||||||
|
|
||||||
|
if payload:
|
||||||
|
jti = payload.get("jti")
|
||||||
|
exp = payload.get("exp")
|
||||||
|
|
||||||
|
if jti and exp:
|
||||||
|
# Calculate TTL and revoke token
|
||||||
|
ttl = exp - int(datetime.utcnow().timestamp())
|
||||||
|
if ttl > 0:
|
||||||
|
await blacklist.revoke_token(jti, ttl)
|
||||||
|
|
||||||
|
return {"message": "Logged out successfully"}
|
||||||
|
|
||||||
|
|
||||||
@router.get("/me", response_model=UserResponse)
|
@router.get("/me", response_model=UserResponse)
|
||||||
async def get_current_user_info(current_user: CurrentUser):
|
async def get_current_user_info(current_user: CurrentUser):
|
||||||
"""
|
"""
|
||||||
|
|
@ -57,3 +229,29 @@ async def get_current_user_info(current_user: CurrentUser):
|
||||||
User information
|
User information
|
||||||
"""
|
"""
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me/storage", response_model=StorageStatsResponse)
|
||||||
|
async def get_storage_stats(current_user: CurrentUser):
|
||||||
|
"""
|
||||||
|
Get storage usage statistics for current user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current_user: Current authenticated user
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Storage statistics
|
||||||
|
"""
|
||||||
|
available = current_user.storage_quota_bytes - current_user.storage_used_bytes
|
||||||
|
percentage = (
|
||||||
|
round((current_user.storage_used_bytes / current_user.storage_quota_bytes) * 100, 2)
|
||||||
|
if current_user.storage_quota_bytes > 0
|
||||||
|
else 0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
return StorageStatsResponse(
|
||||||
|
quota_bytes=current_user.storage_quota_bytes,
|
||||||
|
used_bytes=current_user.storage_used_bytes,
|
||||||
|
available_bytes=max(0, available),
|
||||||
|
percentage_used=percentage,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,10 @@
|
||||||
"""Batch operations API routes."""
|
"""Batch operations API routes."""
|
||||||
|
|
||||||
from fastapi import APIRouter, status
|
import os
|
||||||
from fastapi.responses import Response
|
from pathlib import Path
|
||||||
|
|
||||||
|
from fastapi import APIRouter, BackgroundTasks, status
|
||||||
|
from fastapi.responses import FileResponse
|
||||||
|
|
||||||
from app.api.dependencies import CurrentUser, DatabaseSession, S3ClientDep
|
from app.api.dependencies import CurrentUser, DatabaseSession, S3ClientDep
|
||||||
from app.api.schemas import (
|
from app.api.schemas import (
|
||||||
|
|
@ -77,31 +80,46 @@ async def batch_download(
|
||||||
current_user: CurrentUser,
|
current_user: CurrentUser,
|
||||||
session: DatabaseSession,
|
session: DatabaseSession,
|
||||||
s3_client: S3ClientDep,
|
s3_client: S3ClientDep,
|
||||||
|
background_tasks: BackgroundTasks,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Download multiple assets as a ZIP archive.
|
Download multiple assets as a ZIP archive using streaming.
|
||||||
|
|
||||||
|
Uses temp file and FileResponse to avoid loading entire ZIP into memory.
|
||||||
|
Temp file is automatically cleaned up after response is sent.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: Batch download request
|
request: Batch download request
|
||||||
current_user: Current authenticated user
|
current_user: Current authenticated user
|
||||||
session: Database session
|
session: Database session
|
||||||
s3_client: S3 client
|
s3_client: S3 client
|
||||||
|
background_tasks: Background tasks for cleanup
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ZIP file response
|
ZIP file response
|
||||||
"""
|
"""
|
||||||
batch_service = BatchOperationsService(session, s3_client)
|
batch_service = BatchOperationsService(session, s3_client)
|
||||||
zip_data, filename = await batch_service.download_assets_batch(
|
temp_zip_path, filename = await batch_service.download_assets_batch(
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
asset_ids=request.asset_ids,
|
asset_ids=request.asset_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
return Response(
|
# Schedule temp file cleanup after response is sent
|
||||||
content=zip_data,
|
def cleanup_temp_file():
|
||||||
|
try:
|
||||||
|
Path(temp_zip_path).unlink(missing_ok=True)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
background_tasks.add_task(cleanup_temp_file)
|
||||||
|
|
||||||
|
# Return file using streaming FileResponse
|
||||||
|
return FileResponse(
|
||||||
|
path=temp_zip_path,
|
||||||
media_type="application/zip",
|
media_type="application/zip",
|
||||||
|
filename=filename,
|
||||||
headers={
|
headers={
|
||||||
"Content-Disposition": f'attachment; filename="{filename}"',
|
"Content-Disposition": f'attachment; filename="{filename}"',
|
||||||
"Content-Length": str(len(zip_data)),
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -112,30 +130,45 @@ async def download_folder(
|
||||||
current_user: CurrentUser,
|
current_user: CurrentUser,
|
||||||
session: DatabaseSession,
|
session: DatabaseSession,
|
||||||
s3_client: S3ClientDep,
|
s3_client: S3ClientDep,
|
||||||
|
background_tasks: BackgroundTasks,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Download all assets in a folder as a ZIP archive.
|
Download all assets in a folder as a ZIP archive using streaming.
|
||||||
|
|
||||||
|
Uses temp file and FileResponse to avoid loading entire ZIP into memory.
|
||||||
|
Temp file is automatically cleaned up after response is sent.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
folder_id: Folder ID
|
folder_id: Folder ID
|
||||||
current_user: Current authenticated user
|
current_user: Current authenticated user
|
||||||
session: Database session
|
session: Database session
|
||||||
s3_client: S3 client
|
s3_client: S3 client
|
||||||
|
background_tasks: Background tasks for cleanup
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ZIP file response
|
ZIP file response
|
||||||
"""
|
"""
|
||||||
batch_service = BatchOperationsService(session, s3_client)
|
batch_service = BatchOperationsService(session, s3_client)
|
||||||
zip_data, filename = await batch_service.download_folder(
|
temp_zip_path, filename = await batch_service.download_folder(
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
folder_id=folder_id,
|
folder_id=folder_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return Response(
|
# Schedule temp file cleanup after response is sent
|
||||||
content=zip_data,
|
def cleanup_temp_file():
|
||||||
|
try:
|
||||||
|
Path(temp_zip_path).unlink(missing_ok=True)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
background_tasks.add_task(cleanup_temp_file)
|
||||||
|
|
||||||
|
# Return file using streaming FileResponse
|
||||||
|
return FileResponse(
|
||||||
|
path=temp_zip_path,
|
||||||
media_type="application/zip",
|
media_type="application/zip",
|
||||||
|
filename=filename,
|
||||||
headers={
|
headers={
|
||||||
"Content-Disposition": f'attachment; filename="{filename}"',
|
"Content-Disposition": f'attachment; filename="{filename}"',
|
||||||
"Content-Length": str(len(zip_data)),
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,15 @@ class User(Base):
|
||||||
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
||||||
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||||
|
|
||||||
|
# Storage quota
|
||||||
|
storage_quota_bytes: Mapped[int] = mapped_column(
|
||||||
|
BigInteger, nullable=False, default=3221225472 # 3GB
|
||||||
|
)
|
||||||
|
storage_used_bytes: Mapped[int] = mapped_column(
|
||||||
|
BigInteger, nullable=False, default=0, index=True
|
||||||
|
)
|
||||||
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,7 @@ class Settings(BaseSettings):
|
||||||
s3_access_key_id: str
|
s3_access_key_id: str
|
||||||
s3_secret_access_key: str
|
s3_secret_access_key: str
|
||||||
media_bucket: str = "itcloud-media"
|
media_bucket: str = "itcloud-media"
|
||||||
|
trash_bucket: str = "itcloud-trash"
|
||||||
|
|
||||||
# Security
|
# Security
|
||||||
jwt_secret: str
|
jwt_secret: str
|
||||||
|
|
@ -39,8 +40,9 @@ class Settings(BaseSettings):
|
||||||
jwt_refresh_ttl_seconds: int = 1209600
|
jwt_refresh_ttl_seconds: int = 1209600
|
||||||
|
|
||||||
# Upload limits
|
# Upload limits
|
||||||
max_upload_size_bytes: int = 21474836480 # 20GB
|
max_upload_size_bytes: int = 3221225472 # 3GB
|
||||||
signed_url_ttl_seconds: int = 600
|
signed_url_ttl_seconds: int = 600
|
||||||
|
default_storage_quota_bytes: int = 3221225472 # 3GB per user
|
||||||
|
|
||||||
# CORS
|
# CORS
|
||||||
cors_origins: str = "http://localhost:5173"
|
cors_origins: str = "http://localhost:5173"
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,156 @@
|
||||||
|
"""Redis client for caching, blacklist, and rate limiting."""
|
||||||
|
|
||||||
|
from redis.asyncio import Redis, from_url
|
||||||
|
|
||||||
|
from app.infra.config import get_settings
|
||||||
|
|
||||||
|
settings = get_settings()
|
||||||
|
|
||||||
|
|
||||||
|
class TokenBlacklist:
|
||||||
|
"""JWT token blacklist using Redis."""
|
||||||
|
|
||||||
|
def __init__(self, redis: Redis):
|
||||||
|
"""
|
||||||
|
Initialize token blacklist.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
redis: Redis connection
|
||||||
|
"""
|
||||||
|
self.redis = redis
|
||||||
|
|
||||||
|
async def revoke_token(self, jti: str, ttl_seconds: int) -> None:
|
||||||
|
"""
|
||||||
|
Add token to blacklist.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
jti: JWT ID (jti claim)
|
||||||
|
ttl_seconds: Time to live for blacklist entry
|
||||||
|
"""
|
||||||
|
await self.redis.setex(f"blacklist:{jti}", ttl_seconds, "1")
|
||||||
|
|
||||||
|
async def is_revoked(self, jti: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if token is revoked.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
jti: JWT ID (jti claim)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if token is in blacklist
|
||||||
|
"""
|
||||||
|
return await self.redis.exists(f"blacklist:{jti}") > 0
|
||||||
|
|
||||||
|
|
||||||
|
class LoginAttemptTracker:
|
||||||
|
"""Track failed login attempts and implement account lockout."""
|
||||||
|
|
||||||
|
def __init__(self, redis: Redis):
|
||||||
|
"""
|
||||||
|
Initialize login attempt tracker.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
redis: Redis connection
|
||||||
|
"""
|
||||||
|
self.redis = redis
|
||||||
|
self.max_attempts = 3 # Max failed attempts before lockout
|
||||||
|
self.lockout_duration = 86400 # 24 hours in seconds
|
||||||
|
|
||||||
|
async def record_failed_attempt(self, ip_address: str) -> None:
|
||||||
|
"""
|
||||||
|
Record failed login attempt from IP.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ip_address: Client IP address
|
||||||
|
"""
|
||||||
|
key = f"login_attempts:{ip_address}"
|
||||||
|
attempts = await self.redis.incr(key)
|
||||||
|
|
||||||
|
if attempts == 1:
|
||||||
|
# Set TTL on first attempt (1 hour window)
|
||||||
|
await self.redis.expire(key, 3600)
|
||||||
|
|
||||||
|
if attempts >= self.max_attempts:
|
||||||
|
# Lock account for 24 hours
|
||||||
|
await self.redis.setex(
|
||||||
|
f"account_locked:{ip_address}",
|
||||||
|
self.lockout_duration,
|
||||||
|
"1"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def clear_attempts(self, ip_address: str) -> None:
|
||||||
|
"""
|
||||||
|
Clear failed attempts after successful login.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ip_address: Client IP address
|
||||||
|
"""
|
||||||
|
await self.redis.delete(f"login_attempts:{ip_address}")
|
||||||
|
|
||||||
|
async def is_locked(self, ip_address: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if IP is locked out.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ip_address: Client IP address
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if IP is locked
|
||||||
|
"""
|
||||||
|
return await self.redis.exists(f"account_locked:{ip_address}") > 0
|
||||||
|
|
||||||
|
async def get_lockout_remaining(self, ip_address: str) -> int:
|
||||||
|
"""
|
||||||
|
Get remaining lockout time in seconds.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ip_address: Client IP address
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Remaining seconds, or 0 if not locked
|
||||||
|
"""
|
||||||
|
ttl = await self.redis.ttl(f"account_locked:{ip_address}")
|
||||||
|
return max(0, ttl)
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton Redis connection
|
||||||
|
_redis_client: Redis | None = None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_redis() -> Redis:
|
||||||
|
"""
|
||||||
|
Get Redis client instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Redis connection
|
||||||
|
"""
|
||||||
|
global _redis_client
|
||||||
|
if _redis_client is None:
|
||||||
|
_redis_client = from_url(
|
||||||
|
settings.redis_url,
|
||||||
|
encoding="utf-8",
|
||||||
|
decode_responses=True,
|
||||||
|
)
|
||||||
|
return _redis_client
|
||||||
|
|
||||||
|
|
||||||
|
async def get_token_blacklist() -> TokenBlacklist:
|
||||||
|
"""
|
||||||
|
Get token blacklist instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TokenBlacklist instance
|
||||||
|
"""
|
||||||
|
redis = await get_redis()
|
||||||
|
return TokenBlacklist(redis)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_login_tracker() -> LoginAttemptTracker:
|
||||||
|
"""
|
||||||
|
Get login attempt tracker instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LoginAttemptTracker instance
|
||||||
|
"""
|
||||||
|
redis = await get_redis()
|
||||||
|
return LoginAttemptTracker(redis)
|
||||||
|
|
@ -17,6 +17,7 @@ class S3Client:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize S3 client."""
|
"""Initialize S3 client."""
|
||||||
|
self.settings = settings
|
||||||
self.client = boto3.client(
|
self.client = boto3.client(
|
||||||
"s3",
|
"s3",
|
||||||
endpoint_url=settings.s3_endpoint_url,
|
endpoint_url=settings.s3_endpoint_url,
|
||||||
|
|
@ -26,6 +27,7 @@ class S3Client:
|
||||||
config=Config(signature_version="s3v4"),
|
config=Config(signature_version="s3v4"),
|
||||||
)
|
)
|
||||||
self.bucket = settings.media_bucket
|
self.bucket = settings.media_bucket
|
||||||
|
self.trash_bucket = settings.trash_bucket
|
||||||
|
|
||||||
def generate_storage_key(
|
def generate_storage_key(
|
||||||
self, user_id: str, asset_id: str, prefix: str, extension: str
|
self, user_id: str, asset_id: str, prefix: str, extension: str
|
||||||
|
|
@ -130,6 +132,48 @@ class S3Client:
|
||||||
ContentType=content_type,
|
ContentType=content_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def upload_fileobj_streaming(
|
||||||
|
self, file_obj, storage_key: str, content_type: str, file_size: int
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Upload a file object to S3 using streaming (chunked upload).
|
||||||
|
|
||||||
|
This method is memory-efficient as it uploads data in chunks
|
||||||
|
instead of loading the entire file into memory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_obj: File-like object (must support read())
|
||||||
|
storage_key: S3 object key
|
||||||
|
content_type: File content type
|
||||||
|
file_size: Total file size in bytes
|
||||||
|
"""
|
||||||
|
# Use TransferConfig for chunked multipart upload
|
||||||
|
# 8MB chunks for efficient memory usage
|
||||||
|
from boto3.s3.transfer import TransferConfig
|
||||||
|
|
||||||
|
config = TransferConfig(
|
||||||
|
multipart_threshold=8 * 1024 * 1024, # 8MB
|
||||||
|
multipart_chunksize=8 * 1024 * 1024, # 8MB chunks
|
||||||
|
max_concurrency=4,
|
||||||
|
use_threads=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
extra_args = {
|
||||||
|
"ContentType": content_type,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add encryption if configured (future enhancement)
|
||||||
|
# if self.settings.s3_encryption_enabled:
|
||||||
|
# extra_args["ServerSideEncryption"] = "AES256"
|
||||||
|
|
||||||
|
self.client.upload_fileobj(
|
||||||
|
file_obj,
|
||||||
|
self.bucket,
|
||||||
|
storage_key,
|
||||||
|
ExtraArgs=extra_args,
|
||||||
|
Config=config,
|
||||||
|
)
|
||||||
|
|
||||||
def delete_object(self, storage_key: str) -> None:
|
def delete_object(self, storage_key: str) -> None:
|
||||||
"""
|
"""
|
||||||
Delete an object from S3.
|
Delete an object from S3.
|
||||||
|
|
@ -149,11 +193,10 @@ class S3Client:
|
||||||
Args:
|
Args:
|
||||||
storage_key: S3 object key in media bucket
|
storage_key: S3 object key in media bucket
|
||||||
"""
|
"""
|
||||||
trash_bucket = "itcloud-trash"
|
|
||||||
try:
|
try:
|
||||||
# Copy object to trash bucket
|
# Copy object to trash bucket
|
||||||
self.client.copy_object(
|
self.client.copy_object(
|
||||||
Bucket=trash_bucket,
|
Bucket=self.trash_bucket,
|
||||||
Key=storage_key,
|
Key=storage_key,
|
||||||
CopySource={"Bucket": self.bucket, "Key": storage_key},
|
CopySource={"Bucket": self.bucket, "Key": storage_key},
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -2,14 +2,21 @@
|
||||||
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI, Request
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||||
|
from slowapi.errors import RateLimitExceeded
|
||||||
|
from slowapi.util import get_remote_address
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
|
||||||
from app.api.v1 import assets, auth, batch, folders, shares, uploads
|
from app.api.v1 import assets, auth, batch, folders, shares, uploads
|
||||||
from app.infra.config import get_settings
|
from app.infra.config import get_settings
|
||||||
|
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
|
|
||||||
|
# Rate limiter
|
||||||
|
limiter = Limiter(key_func=get_remote_address, default_limits=["1000/hour"])
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
|
|
@ -27,14 +34,72 @@ app = FastAPI(
|
||||||
lifespan=lifespan,
|
lifespan=lifespan,
|
||||||
)
|
)
|
||||||
|
|
||||||
# CORS middleware
|
# Add rate limiter to app state
|
||||||
|
app.state.limiter = limiter
|
||||||
|
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||||
|
|
||||||
|
|
||||||
|
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""Add security headers to all responses."""
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
"""Add security headers."""
|
||||||
|
response = await call_next(request)
|
||||||
|
|
||||||
|
# Защита от clickjacking
|
||||||
|
response.headers["X-Frame-Options"] = "DENY"
|
||||||
|
|
||||||
|
# Защита от XSS
|
||||||
|
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||||
|
|
||||||
|
# XSS Protection для старых браузеров
|
||||||
|
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||||
|
|
||||||
|
# Content Security Policy
|
||||||
|
response.headers["Content-Security-Policy"] = (
|
||||||
|
"default-src 'self'; "
|
||||||
|
"img-src 'self' data: https:; "
|
||||||
|
"script-src 'self'; "
|
||||||
|
"style-src 'self' 'unsafe-inline';"
|
||||||
|
)
|
||||||
|
|
||||||
|
# HSTS для HTTPS
|
||||||
|
if request.url.scheme == "https":
|
||||||
|
response.headers["Strict-Transport-Security"] = (
|
||||||
|
"max-age=31536000; includeSubDomains"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Referrer Policy
|
||||||
|
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||||
|
|
||||||
|
# Permissions Policy
|
||||||
|
response.headers["Permissions-Policy"] = (
|
||||||
|
"geolocation=(), microphone=(), camera=()"
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
# Security headers middleware
|
||||||
|
app.add_middleware(SecurityHeadersMiddleware)
|
||||||
|
|
||||||
|
# CORS middleware (more restrictive)
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=settings.cors_origins_list,
|
allow_origins=settings.cors_origins_list,
|
||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
|
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
|
||||||
allow_headers=["*"],
|
allow_headers=[
|
||||||
expose_headers=["*"],
|
"Authorization",
|
||||||
|
"Content-Type",
|
||||||
|
"X-Requested-With",
|
||||||
|
"Accept",
|
||||||
|
],
|
||||||
|
expose_headers=[
|
||||||
|
"Content-Length",
|
||||||
|
"Content-Type",
|
||||||
|
"X-Total-Count",
|
||||||
|
],
|
||||||
max_age=3600,
|
max_age=3600,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -169,6 +169,53 @@ class AssetRepository:
|
||||||
result = await self.session.execute(query)
|
result = await self.session.execute(query)
|
||||||
return list(result.scalars().all())
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
async def search_assets(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
search_query: str,
|
||||||
|
limit: int = 50,
|
||||||
|
cursor: Optional[str] = None,
|
||||||
|
asset_type: Optional[AssetType] = None,
|
||||||
|
folder_id: Optional[str] = None,
|
||||||
|
) -> list[Asset]:
|
||||||
|
"""
|
||||||
|
Search assets by filename using case-insensitive pattern matching.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: User ID
|
||||||
|
search_query: Search query string (will be matched against original_filename)
|
||||||
|
limit: Maximum number of results
|
||||||
|
cursor: Pagination cursor (asset_id)
|
||||||
|
asset_type: Optional filter by asset type
|
||||||
|
folder_id: Optional filter by folder (None means search across all folders)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching assets
|
||||||
|
"""
|
||||||
|
query = select(Asset).where(Asset.user_id == user_id)
|
||||||
|
|
||||||
|
# Case-insensitive search on original_filename
|
||||||
|
# Use ILIKE for case-insensitive matching with wildcards
|
||||||
|
search_pattern = f"%{search_query}%"
|
||||||
|
query = query.where(Asset.original_filename.ilike(search_pattern))
|
||||||
|
|
||||||
|
if asset_type:
|
||||||
|
query = query.where(Asset.type == asset_type)
|
||||||
|
|
||||||
|
# Filter by folder if specified
|
||||||
|
if folder_id is not None:
|
||||||
|
query = query.where(Asset.folder_id == folder_id)
|
||||||
|
|
||||||
|
if cursor:
|
||||||
|
cursor_asset = await self.get_by_id(cursor)
|
||||||
|
if cursor_asset:
|
||||||
|
query = query.where(Asset.created_at < cursor_asset.created_at)
|
||||||
|
|
||||||
|
query = query.order_by(desc(Asset.created_at)).limit(limit)
|
||||||
|
|
||||||
|
result = await self.session.execute(query)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
async def update_folder_batch(
|
async def update_folder_batch(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import AsyncIterator, Optional, Tuple
|
from typing import AsyncIterator, Optional, Tuple
|
||||||
|
|
||||||
|
import magic
|
||||||
import redis
|
import redis
|
||||||
from botocore.exceptions import ClientError
|
from botocore.exceptions import ClientError
|
||||||
from fastapi import HTTPException, UploadFile, status
|
from fastapi import HTTPException, UploadFile, status
|
||||||
|
|
@ -16,6 +17,7 @@ from app.domain.models import Asset, AssetStatus, AssetType
|
||||||
from app.infra.config import get_settings
|
from app.infra.config import get_settings
|
||||||
from app.infra.s3_client import S3Client
|
from app.infra.s3_client import S3Client
|
||||||
from app.repositories.asset_repository import AssetRepository
|
from app.repositories.asset_repository import AssetRepository
|
||||||
|
from app.repositories.user_repository import UserRepository
|
||||||
|
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
|
|
||||||
|
|
@ -62,7 +64,9 @@ class AssetService:
|
||||||
s3_client: S3 client instance
|
s3_client: S3 client instance
|
||||||
"""
|
"""
|
||||||
self.asset_repo = AssetRepository(session)
|
self.asset_repo = AssetRepository(session)
|
||||||
|
self.user_repo = UserRepository(session)
|
||||||
self.s3_client = s3_client
|
self.s3_client = s3_client
|
||||||
|
self.session = session
|
||||||
|
|
||||||
def _get_asset_type(self, content_type: str) -> AssetType:
|
def _get_asset_type(self, content_type: str) -> AssetType:
|
||||||
"""Determine asset type from content type."""
|
"""Determine asset type from content type."""
|
||||||
|
|
@ -96,7 +100,25 @@ class AssetService:
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (asset, presigned_post_data)
|
Tuple of (asset, presigned_post_data)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If storage quota exceeded
|
||||||
"""
|
"""
|
||||||
|
# Check storage quota
|
||||||
|
user = await self.user_repo.get_by_id(user_id)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="User not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
available_storage = user.storage_quota_bytes - user.storage_used_bytes
|
||||||
|
if size_bytes > available_storage:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
|
||||||
|
detail=f"Storage quota exceeded. Available: {available_storage} bytes, Required: {size_bytes} bytes",
|
||||||
|
)
|
||||||
|
|
||||||
# Sanitize filename to prevent path traversal
|
# Sanitize filename to prevent path traversal
|
||||||
safe_filename = sanitize_filename(original_filename)
|
safe_filename = sanitize_filename(original_filename)
|
||||||
|
|
||||||
|
|
@ -142,7 +164,9 @@ class AssetService:
|
||||||
file: UploadFile,
|
file: UploadFile,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Upload file content to S3 through backend.
|
Upload file content to S3 through backend using streaming.
|
||||||
|
|
||||||
|
Uses chunked upload to prevent memory exhaustion for large files.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: User ID
|
user_id: User ID
|
||||||
|
|
@ -166,13 +190,20 @@ class AssetService:
|
||||||
detail="Asset has no storage key",
|
detail="Asset has no storage key",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Upload file to S3
|
# Verify file size doesn't exceed limit
|
||||||
|
if asset.size_bytes > settings.max_upload_size_bytes:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
|
||||||
|
detail=f"File size exceeds maximum allowed ({settings.max_upload_size_bytes} bytes)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Upload file to S3 using streaming (memory-efficient)
|
||||||
try:
|
try:
|
||||||
content = await file.read()
|
self.s3_client.upload_fileobj_streaming(
|
||||||
self.s3_client.put_object(
|
file_obj=file.file,
|
||||||
storage_key=asset.storage_key_original,
|
storage_key=asset.storage_key_original,
|
||||||
file_data=content,
|
|
||||||
content_type=asset.content_type,
|
content_type=asset.content_type,
|
||||||
|
file_size=asset.size_bytes,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -190,6 +221,8 @@ class AssetService:
|
||||||
"""
|
"""
|
||||||
Finalize upload and mark asset as ready.
|
Finalize upload and mark asset as ready.
|
||||||
|
|
||||||
|
Performs magic bytes verification to ensure file type matches Content-Type.
|
||||||
|
Updates user's storage_used_bytes.
|
||||||
Enqueues background task for thumbnail generation.
|
Enqueues background task for thumbnail generation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -202,7 +235,7 @@ class AssetService:
|
||||||
Updated asset
|
Updated asset
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
HTTPException: If asset not found or not authorized
|
HTTPException: If asset not found, not authorized, or file type mismatch
|
||||||
"""
|
"""
|
||||||
asset = await self.asset_repo.get_by_id(asset_id)
|
asset = await self.asset_repo.get_by_id(asset_id)
|
||||||
if not asset or asset.user_id != user_id:
|
if not asset or asset.user_id != user_id:
|
||||||
|
|
@ -218,12 +251,54 @@ class AssetService:
|
||||||
detail="File not found in storage",
|
detail="File not found in storage",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Magic bytes verification - download first 2KB to check file type
|
||||||
|
try:
|
||||||
|
response = self.s3_client.client.get_object(
|
||||||
|
Bucket=self.s3_client.bucket,
|
||||||
|
Key=asset.storage_key_original,
|
||||||
|
Range="bytes=0-2047", # First 2KB
|
||||||
|
)
|
||||||
|
file_header = response["Body"].read()
|
||||||
|
|
||||||
|
# Detect MIME type from magic bytes
|
||||||
|
detected_mime = magic.from_buffer(file_header, mime=True)
|
||||||
|
|
||||||
|
# Validate that detected type matches declared Content-Type
|
||||||
|
# Allow some flexibility for similar types
|
||||||
|
declared_base = asset.content_type.split("/")[0] # e.g., "image" or "video"
|
||||||
|
detected_base = detected_mime.split("/")[0]
|
||||||
|
|
||||||
|
if declared_base != detected_base:
|
||||||
|
# Type mismatch - delete file from S3 and reject upload
|
||||||
|
self.s3_client.delete_object(asset.storage_key_original)
|
||||||
|
await self.asset_repo.delete(asset)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"File type mismatch. Declared: {asset.content_type}, Detected: {detected_mime}",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Magic bytes verification passed for asset {asset_id}: "
|
||||||
|
f"declared={asset.content_type}, detected={detected_mime}"
|
||||||
|
)
|
||||||
|
|
||||||
|
except ClientError as e:
|
||||||
|
logger.error(f"Failed to verify magic bytes for asset {asset_id}: {e}")
|
||||||
|
# Continue anyway - magic bytes check is not critical for functionality
|
||||||
|
|
||||||
|
# Update asset status
|
||||||
asset.status = AssetStatus.READY
|
asset.status = AssetStatus.READY
|
||||||
if sha256:
|
if sha256:
|
||||||
asset.sha256 = sha256
|
asset.sha256 = sha256
|
||||||
|
|
||||||
await self.asset_repo.update(asset)
|
await self.asset_repo.update(asset)
|
||||||
|
|
||||||
|
# Update user's storage_used_bytes
|
||||||
|
user = await self.user_repo.get_by_id(user_id)
|
||||||
|
if user:
|
||||||
|
user.storage_used_bytes += asset.size_bytes
|
||||||
|
await self.user_repo.update(user)
|
||||||
|
|
||||||
# Enqueue thumbnail generation task (background processing with retry)
|
# Enqueue thumbnail generation task (background processing with retry)
|
||||||
try:
|
try:
|
||||||
redis_conn = redis.from_url(settings.redis_url)
|
redis_conn = redis.from_url(settings.redis_url)
|
||||||
|
|
@ -389,6 +464,8 @@ class AssetService:
|
||||||
"""
|
"""
|
||||||
Delete an asset permanently (move files to trash bucket, delete from DB).
|
Delete an asset permanently (move files to trash bucket, delete from DB).
|
||||||
|
|
||||||
|
Also updates user's storage_used_bytes to free up quota.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: User ID
|
user_id: User ID
|
||||||
asset_id: Asset ID
|
asset_id: Asset ID
|
||||||
|
|
@ -400,5 +477,56 @@ class AssetService:
|
||||||
if asset.storage_key_thumb:
|
if asset.storage_key_thumb:
|
||||||
self.s3_client.move_to_trash(asset.storage_key_thumb)
|
self.s3_client.move_to_trash(asset.storage_key_thumb)
|
||||||
|
|
||||||
|
# Update user's storage_used_bytes (free up quota)
|
||||||
|
user = await self.user_repo.get_by_id(user_id)
|
||||||
|
if user:
|
||||||
|
user.storage_used_bytes = max(0, user.storage_used_bytes - asset.size_bytes)
|
||||||
|
await self.user_repo.update(user)
|
||||||
|
|
||||||
# Delete from database
|
# Delete from database
|
||||||
await self.asset_repo.delete(asset)
|
await self.asset_repo.delete(asset)
|
||||||
|
|
||||||
|
async def search_assets(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
search_query: str,
|
||||||
|
limit: int = 50,
|
||||||
|
cursor: Optional[str] = None,
|
||||||
|
asset_type: Optional[AssetType] = None,
|
||||||
|
folder_id: Optional[str] = None,
|
||||||
|
) -> tuple[list[Asset], Optional[str], bool]:
|
||||||
|
"""
|
||||||
|
Search assets by filename.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: User ID
|
||||||
|
search_query: Search query string
|
||||||
|
limit: Maximum number of results
|
||||||
|
cursor: Pagination cursor
|
||||||
|
asset_type: Optional filter by asset type
|
||||||
|
folder_id: Optional filter by folder
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (assets, next_cursor, has_more)
|
||||||
|
"""
|
||||||
|
if not search_query or not search_query.strip():
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Search query cannot be empty",
|
||||||
|
)
|
||||||
|
|
||||||
|
assets = await self.asset_repo.search_assets(
|
||||||
|
user_id=user_id,
|
||||||
|
search_query=search_query.strip(),
|
||||||
|
limit=limit + 1, # Fetch one more to check if there are more
|
||||||
|
cursor=cursor,
|
||||||
|
asset_type=asset_type,
|
||||||
|
folder_id=folder_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
has_more = len(assets) > limit
|
||||||
|
if has_more:
|
||||||
|
assets = assets[:limit]
|
||||||
|
|
||||||
|
next_cursor = assets[-1].id if has_more and assets else None
|
||||||
|
return assets, next_cursor, has_more
|
||||||
|
|
|
||||||
|
|
@ -94,3 +94,17 @@ class AuthService:
|
||||||
User instance or None
|
User instance or None
|
||||||
"""
|
"""
|
||||||
return await self.user_repo.get_by_id(user_id)
|
return await self.user_repo.get_by_id(user_id)
|
||||||
|
|
||||||
|
async def create_tokens_for_user(self, user: User) -> tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Create access and refresh tokens for a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: User instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (access_token, refresh_token)
|
||||||
|
"""
|
||||||
|
access_token = create_access_token(subject=str(user.id))
|
||||||
|
refresh_token = create_refresh_token(subject=str(user.id))
|
||||||
|
return access_token, refresh_token
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ from app.domain.models import Asset
|
||||||
from app.infra.s3_client import S3Client
|
from app.infra.s3_client import S3Client
|
||||||
from app.repositories.asset_repository import AssetRepository
|
from app.repositories.asset_repository import AssetRepository
|
||||||
from app.repositories.folder_repository import FolderRepository
|
from app.repositories.folder_repository import FolderRepository
|
||||||
|
from app.repositories.user_repository import UserRepository
|
||||||
from app.services.asset_service import sanitize_filename
|
from app.services.asset_service import sanitize_filename
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -61,6 +62,7 @@ class BatchOperationsService:
|
||||||
"""
|
"""
|
||||||
self.asset_repo = AssetRepository(session)
|
self.asset_repo = AssetRepository(session)
|
||||||
self.folder_repo = FolderRepository(session)
|
self.folder_repo = FolderRepository(session)
|
||||||
|
self.user_repo = UserRepository(session)
|
||||||
self.s3_client = s3_client
|
self.s3_client = s3_client
|
||||||
|
|
||||||
async def delete_assets_batch(
|
async def delete_assets_batch(
|
||||||
|
|
@ -71,6 +73,8 @@ class BatchOperationsService:
|
||||||
"""
|
"""
|
||||||
Delete multiple assets (move to trash bucket, delete from DB).
|
Delete multiple assets (move to trash bucket, delete from DB).
|
||||||
|
|
||||||
|
Also updates user's storage_used_bytes to free up quota.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: User ID
|
user_id: User ID
|
||||||
asset_ids: List of asset IDs to delete
|
asset_ids: List of asset IDs to delete
|
||||||
|
|
@ -98,6 +102,7 @@ class BatchOperationsService:
|
||||||
|
|
||||||
deleted_count = 0
|
deleted_count = 0
|
||||||
failed_count = 0
|
failed_count = 0
|
||||||
|
total_bytes_freed = 0
|
||||||
|
|
||||||
for asset in assets:
|
for asset in assets:
|
||||||
try:
|
try:
|
||||||
|
|
@ -106,6 +111,9 @@ class BatchOperationsService:
|
||||||
if asset.storage_key_thumb:
|
if asset.storage_key_thumb:
|
||||||
self.s3_client.move_to_trash(asset.storage_key_thumb)
|
self.s3_client.move_to_trash(asset.storage_key_thumb)
|
||||||
|
|
||||||
|
# Track bytes for storage quota update
|
||||||
|
total_bytes_freed += asset.size_bytes
|
||||||
|
|
||||||
# Delete from database
|
# Delete from database
|
||||||
await self.asset_repo.delete(asset)
|
await self.asset_repo.delete(asset)
|
||||||
deleted_count += 1
|
deleted_count += 1
|
||||||
|
|
@ -114,6 +122,14 @@ class BatchOperationsService:
|
||||||
logger.error(f"Failed to delete asset {asset.id}: {e}")
|
logger.error(f"Failed to delete asset {asset.id}: {e}")
|
||||||
failed_count += 1
|
failed_count += 1
|
||||||
|
|
||||||
|
# Update user's storage_used_bytes (free up quota)
|
||||||
|
if total_bytes_freed > 0:
|
||||||
|
user = await self.user_repo.get_by_id(user_id)
|
||||||
|
if user:
|
||||||
|
user.storage_used_bytes = max(0, user.storage_used_bytes - total_bytes_freed)
|
||||||
|
await self.user_repo.update(user)
|
||||||
|
logger.info(f"Freed {total_bytes_freed} bytes for user {user_id}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"deleted": deleted_count,
|
"deleted": deleted_count,
|
||||||
"failed": failed_count,
|
"failed": failed_count,
|
||||||
|
|
@ -171,18 +187,19 @@ class BatchOperationsService:
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
asset_ids: list[str],
|
asset_ids: list[str],
|
||||||
) -> tuple[bytes, str]:
|
) -> tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
Download multiple assets as a ZIP archive.
|
Download multiple assets as a ZIP archive using temp file streaming.
|
||||||
|
|
||||||
Uses streaming to avoid loading entire archive in memory.
|
Creates ZIP in a temporary file to avoid memory exhaustion.
|
||||||
|
The caller is responsible for deleting the temp file after use.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: User ID
|
user_id: User ID
|
||||||
asset_ids: List of asset IDs to download
|
asset_ids: List of asset IDs to download
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (zip_data, filename)
|
Tuple of (temp_zip_path, filename)
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
HTTPException: If no assets found or permission denied
|
HTTPException: If no assets found or permission denied
|
||||||
|
|
@ -202,73 +219,86 @@ class BatchOperationsService:
|
||||||
detail="No assets found or permission denied",
|
detail="No assets found or permission denied",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create ZIP archive in memory
|
# Create ZIP archive in temp file (NOT in memory)
|
||||||
zip_buffer = io.BytesIO()
|
temp_zip = tempfile.NamedTemporaryFile(
|
||||||
|
mode='w+b',
|
||||||
|
suffix='.zip',
|
||||||
|
delete=False, # Don't auto-delete, caller handles cleanup
|
||||||
|
)
|
||||||
|
temp_zip_path = temp_zip.name
|
||||||
|
|
||||||
with temp_file_manager() as temp_files:
|
try:
|
||||||
|
with zipfile.ZipFile(temp_zip, 'w', zipfile.ZIP_DEFLATED) as zip_file:
|
||||||
|
# Track filenames to avoid duplicates
|
||||||
|
used_names = set()
|
||||||
|
|
||||||
|
for asset in assets:
|
||||||
|
try:
|
||||||
|
# Download file from S3 and stream directly into ZIP
|
||||||
|
response = self.s3_client.client.get_object(
|
||||||
|
Bucket=self.s3_client.bucket,
|
||||||
|
Key=asset.storage_key_original,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate unique filename (sanitized to prevent path traversal)
|
||||||
|
base_name = sanitize_filename(asset.original_filename)
|
||||||
|
unique_name = base_name
|
||||||
|
counter = 1
|
||||||
|
|
||||||
|
while unique_name in used_names:
|
||||||
|
name, ext = os.path.splitext(base_name)
|
||||||
|
unique_name = f"{name}_{counter}{ext}"
|
||||||
|
counter += 1
|
||||||
|
|
||||||
|
used_names.add(unique_name)
|
||||||
|
|
||||||
|
# Stream file data directly into ZIP (no full read into memory)
|
||||||
|
with zip_file.open(unique_name, 'w') as zip_entry:
|
||||||
|
# Read in chunks (8MB at a time)
|
||||||
|
chunk_size = 8 * 1024 * 1024
|
||||||
|
while True:
|
||||||
|
chunk = response["Body"].read(chunk_size)
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
zip_entry.write(chunk)
|
||||||
|
|
||||||
|
logger.debug(f"Added {unique_name} to ZIP archive")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to add asset {asset.id} to ZIP: {e}")
|
||||||
|
# Continue with other files
|
||||||
|
|
||||||
|
# Generate filename
|
||||||
|
filename = f"download_{len(assets)}_files.zip"
|
||||||
|
|
||||||
|
return temp_zip_path, filename
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Cleanup temp file on error
|
||||||
try:
|
try:
|
||||||
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
|
Path(temp_zip_path).unlink(missing_ok=True)
|
||||||
# Track filenames to avoid duplicates
|
except Exception:
|
||||||
used_names = set()
|
pass
|
||||||
|
logger.exception(f"Failed to create ZIP archive: {e}")
|
||||||
for asset in assets:
|
raise HTTPException(
|
||||||
try:
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
# Download file from S3
|
detail="Failed to create archive",
|
||||||
response = self.s3_client.client.get_object(
|
)
|
||||||
Bucket=self.s3_client.bucket,
|
|
||||||
Key=asset.storage_key_original,
|
|
||||||
)
|
|
||||||
file_data = response["Body"].read()
|
|
||||||
|
|
||||||
# Generate unique filename (sanitized to prevent path traversal)
|
|
||||||
base_name = sanitize_filename(asset.original_filename)
|
|
||||||
unique_name = base_name
|
|
||||||
counter = 1
|
|
||||||
|
|
||||||
while unique_name in used_names:
|
|
||||||
name, ext = os.path.splitext(base_name)
|
|
||||||
unique_name = f"{name}_{counter}{ext}"
|
|
||||||
counter += 1
|
|
||||||
|
|
||||||
used_names.add(unique_name)
|
|
||||||
|
|
||||||
# Add to ZIP
|
|
||||||
zip_file.writestr(unique_name, file_data)
|
|
||||||
logger.debug(f"Added {unique_name} to ZIP archive")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to add asset {asset.id} to ZIP: {e}")
|
|
||||||
# Continue with other files
|
|
||||||
|
|
||||||
# Get ZIP data
|
|
||||||
zip_data = zip_buffer.getvalue()
|
|
||||||
|
|
||||||
# Generate filename
|
|
||||||
filename = f"download_{len(assets)}_files.zip"
|
|
||||||
|
|
||||||
return zip_data, filename
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f"Failed to create ZIP archive: {e}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail="Failed to create archive",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def download_folder(
|
async def download_folder(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
folder_id: str,
|
folder_id: str,
|
||||||
) -> tuple[bytes, str]:
|
) -> tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
Download all assets in a folder as a ZIP archive.
|
Download all assets in a folder as a ZIP archive using temp file streaming.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: User ID
|
user_id: User ID
|
||||||
folder_id: Folder ID
|
folder_id: Folder ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (zip_data, filename)
|
Tuple of (temp_zip_path, filename)
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
HTTPException: If folder not found or permission denied
|
HTTPException: If folder not found or permission denied
|
||||||
|
|
@ -296,9 +326,10 @@ class BatchOperationsService:
|
||||||
|
|
||||||
# Get asset IDs and use existing download method
|
# Get asset IDs and use existing download method
|
||||||
asset_ids = [asset.id for asset in assets]
|
asset_ids = [asset.id for asset in assets]
|
||||||
zip_data, _ = await self.download_assets_batch(user_id, asset_ids)
|
temp_zip_path, _ = await self.download_assets_batch(user_id, asset_ids)
|
||||||
|
|
||||||
# Use folder name in filename
|
# Use folder name in filename
|
||||||
filename = f"{folder.name}.zip"
|
sanitized_folder_name = sanitize_filename(folder.name)
|
||||||
|
filename = f"{sanitized_folder_name}.zip"
|
||||||
|
|
||||||
return zip_data, filename
|
return temp_zip_path, filename
|
||||||
|
|
|
||||||
|
|
@ -161,26 +161,33 @@ class FolderService:
|
||||||
"""
|
"""
|
||||||
Delete a folder.
|
Delete a folder.
|
||||||
|
|
||||||
|
IMPORTANT: Folder must be empty (no assets, no subfolders) to be deleted.
|
||||||
|
Use AssetService or BatchOperationsService to delete assets first,
|
||||||
|
or move them to another folder before deleting.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: User ID
|
user_id: User ID
|
||||||
folder_id: Folder ID
|
folder_id: Folder ID
|
||||||
recursive: If True, delete folder with all contents.
|
recursive: If True, delete folder with all subfolders (must still be empty of assets).
|
||||||
If False, fail if folder is not empty.
|
If False, fail if folder has subfolders.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
HTTPException: If folder not found, not authorized, or not empty (when not recursive)
|
HTTPException: If folder not found, not authorized, or not empty
|
||||||
"""
|
"""
|
||||||
folder = await self.get_folder(user_id, folder_id)
|
folder = await self.get_folder(user_id, folder_id)
|
||||||
|
|
||||||
if not recursive:
|
# Always check for assets (read-only query, acceptable for validation)
|
||||||
# Check if folder has assets
|
asset_count = await self.asset_repo.count_in_folder(folder_id)
|
||||||
asset_count = await self.asset_repo.count_in_folder(folder_id)
|
if asset_count > 0:
|
||||||
if asset_count > 0:
|
raise HTTPException(
|
||||||
raise HTTPException(
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
detail=(
|
||||||
detail=f"Folder contains {asset_count} assets. Use recursive=true to delete.",
|
f"Folder contains {asset_count} assets. "
|
||||||
)
|
"Please delete or move assets first using AssetService endpoints."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if not recursive:
|
||||||
# Check if folder has subfolders
|
# Check if folder has subfolders
|
||||||
subfolders = await self.folder_repo.list_by_user(
|
subfolders = await self.folder_repo.list_by_user(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
|
@ -189,36 +196,28 @@ class FolderService:
|
||||||
if subfolders:
|
if subfolders:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail=f"Folder contains {len(subfolders)} subfolders. Use recursive=true to delete.",
|
detail=(
|
||||||
|
f"Folder contains {len(subfolders)} subfolders. "
|
||||||
|
"Use recursive=true to delete all subfolders."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if recursive:
|
if recursive:
|
||||||
# Delete all subfolders recursively
|
# Delete all subfolders recursively (folders must be empty of assets)
|
||||||
subfolders = await self.folder_repo.get_all_subfolders(folder_id)
|
subfolders = await self.folder_repo.get_all_subfolders(folder_id)
|
||||||
for subfolder in reversed(subfolders): # Delete from deepest to shallowest
|
for subfolder in reversed(subfolders): # Delete from deepest to shallowest
|
||||||
# Move assets in subfolder to trash (through AssetService would be better, but for simplicity)
|
# Check that subfolder is empty of assets
|
||||||
# In production, this should use AssetService.delete_asset to properly move to trash
|
subfolder_asset_count = await self.asset_repo.count_in_folder(subfolder.id)
|
||||||
assets = await self.asset_repo.list_by_folder(
|
if subfolder_asset_count > 0:
|
||||||
user_id=user_id,
|
raise HTTPException(
|
||||||
folder_id=subfolder.id,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
limit=1000, # Reasonable limit for folder deletion
|
detail=(
|
||||||
)
|
f"Subfolder '{subfolder.name}' contains {subfolder_asset_count} assets. "
|
||||||
# For now, just orphan the assets by setting folder_id to None
|
"All folders must be empty before deletion."
|
||||||
# TODO: Properly delete assets using AssetService
|
),
|
||||||
for asset in assets:
|
)
|
||||||
asset.folder_id = None
|
|
||||||
|
|
||||||
await self.folder_repo.delete(subfolder)
|
await self.folder_repo.delete(subfolder)
|
||||||
|
|
||||||
# Move assets in current folder
|
|
||||||
assets = await self.asset_repo.list_by_folder(
|
|
||||||
user_id=user_id,
|
|
||||||
folder_id=folder_id,
|
|
||||||
limit=1000,
|
|
||||||
)
|
|
||||||
for asset in assets:
|
|
||||||
asset.folder_id = None
|
|
||||||
|
|
||||||
# Delete the folder itself
|
# Delete the folder itself
|
||||||
await self.folder_repo.delete(folder)
|
await self.folder_repo.delete(folder)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue