"""Sliding-window rate limiting with exponential backoff for protected API paths.

All counters live in plain in-memory dictionaries on a ``RateLimitTracker``
instance, so limits are enforced per middleware instance; nothing here shares
state between processes or hosts.
"""

import hashlib
import logging
import math
import random
import time
from collections import defaultdict, deque
from typing import Dict, Optional

from fastapi import Request
from fastapi.responses import JSONResponse

from src.app.config import settings

logger = logging.getLogger(__name__)

# Width of the short look-back window used for burst detection, in seconds.
_BURST_WINDOW_SECONDS = 10
# Seconds without a new violation after which a client's violation count resets.
_VIOLATION_RESET_SECONDS = 3600
# Minimum seconds between full sweeps of idle client state.
_PRUNE_INTERVAL_SECONDS = 60


class RateLimitTracker:
    """Per-client request timestamps, violation counts and backoff deadlines."""

    def __init__(self):
        # client_key -> time.time() stamps of admitted requests, oldest first.
        self.requests: Dict[str, deque] = defaultdict(deque)
        # client_key -> number of limit violations since the last reset.
        self.violations: Dict[str, int] = defaultdict(int)
        # client_key -> time.time() of the most recent violation.
        self.last_violation: Dict[str, float] = defaultdict(float)
        # client_key -> time.time() until which the client is rejected outright.
        self.blocked_until: Dict[str, float] = {}

    def cleanup_old_requests(self, client_key: str, window_seconds: int = 60):
        """Drop timestamps older than ``window_seconds`` for ``client_key``.

        The client's entry is deleted once it has no timestamps left, so idle
        clients do not leave empty deques behind.
        """
        timestamps = self.requests.get(client_key)
        if timestamps is None:
            return
        cutoff = time.time() - window_seconds
        while timestamps and timestamps[0] < cutoff:
            timestamps.popleft()
        if not timestamps:
            del self.requests[client_key]

    def add_request(self, client_key: str):
        """Record an admitted request for ``client_key`` at the current time."""
        self.requests[client_key].append(time.time())

    def get_request_count(self, client_key: str, window_seconds: int = 60) -> int:
        """Return how many requests ``client_key`` made in the last ``window_seconds``."""
        self.cleanup_old_requests(client_key, window_seconds)
        # .get() so that counting never inserts a key for an unseen client.
        return len(self.requests.get(client_key, ()))

    def record_violation(self, client_key: str):
        """Increment the client's violation count and stamp the violation time."""
        self.violations[client_key] += 1
        self.last_violation[client_key] = time.time()

    def get_backoff_time(self, client_key: str) -> int:
        """Return the backoff delay in seconds for the client's current violation count.

        The delay is ``settings.rate_limit_backoff_base ** violations`` capped at
        ``settings.rate_limit_max_backoff``, plus random jitter of up to
        ``settings.rate_limit_jitter`` times that delay, truncated to an int.
        Returns 0 when the client has no recorded violations. Every call draws
        fresh jitter.
        """
        violations = self.violations.get(client_key, 0)
        if violations == 0:
            return 0
        delay = min(
            settings.rate_limit_backoff_base**violations,
            settings.rate_limit_max_backoff,
        )
        jitter = delay * settings.rate_limit_jitter * random.random()
        return int(delay + jitter)

    def start_backoff(self, client_key: str) -> int:
        """Record a violation and block the client for the resulting delay.

        Returns:
            The delay in seconds. It is also the value to advertise in
            ``Retry-After``, so the header and the enforced block agree.
        """
        self.record_violation(client_key)
        delay = self.get_backoff_time(client_key)
        self.blocked_until[client_key] = time.time() + delay
        return delay

    def get_remaining_backoff(self, client_key: str) -> int:
        """Return the seconds (rounded up) left on the client's block, or 0 if none."""
        until = self.blocked_until.get(client_key)
        if until is None:
            return 0
        remaining = until - time.time()
        if remaining <= 0:
            del self.blocked_until[client_key]
            return 0
        return math.ceil(remaining)

    def should_reset_violations(
        self, client_key: str, reset_after: int = _VIOLATION_RESET_SECONDS
    ) -> bool:
        """Return True if the client's last violation is older than ``reset_after`` seconds."""
        if client_key not in self.last_violation:
            return False
        return time.time() - self.last_violation[client_key] > reset_after

    def reset_violations(self, client_key: str):
        """Forget the client's violation count and timestamp."""
        self.violations.pop(client_key, None)
        self.last_violation.pop(client_key, None)

    def prune(
        self, window_seconds: int = 60, reset_after: int = _VIOLATION_RESET_SECONDS
    ):
        """Discard per-client state that no longer affects any decision.

        Removes request timestamps older than ``window_seconds`` (and the
        emptied deques), expired backoff deadlines, and violation histories
        older than ``reset_after``. Without this sweep, every distinct client
        key ever seen would stay in memory.
        """
        now = time.time()
        for client_key in list(self.requests):
            self.cleanup_old_requests(client_key, window_seconds)
        for client_key, until in list(self.blocked_until.items()):
            if until <= now:
                del self.blocked_until[client_key]
        for client_key in list(self.last_violation):
            if self.should_reset_violations(client_key, reset_after):
                self.reset_violations(client_key)
        # record_violation always sets both maps; a counter with no timestamp
        # was created by a bare read and carries no history.
        for client_key in list(self.violations):
            if client_key not in self.last_violation:
                del self.violations[client_key]


class RateLimitMiddleware:
    """Callable HTTP middleware enforcing per-client limits on protected paths.

    For a request whose path starts with a key of ``protected_paths``, it
    responds 429 in three cases: a backoff is still active, the burst limit is
    hit, or the per-window limit is hit. Otherwise it forwards the request and
    adds ``X-RateLimit-*`` headers to the response.
    """

    def __init__(self):
        self.tracker = RateLimitTracker()
        self.protected_paths = {
            "/api/v1/generate_email": {
                "limit": settings.rate_limit_per_minute,
                "window": 60,
                "burst": settings.rate_limit_burst,
            }
        }
        # time.time() of the last idle-state sweep (see _maybe_prune).
        self._last_prune = time.time()

    def get_client_key(self, request: Request) -> str:
        """Return a 16-hex-char key derived from the client IP and User-Agent.

        The MD5 hash only compacts the key; it is not a security boundary.

        NOTE(review): the User-Agent header is client-controlled, so a client
        can get a fresh key (and a fresh quota) just by changing it. Consider
        keying on the IP alone. Behind a reverse proxy, ``request.client.host``
        is presumably the proxy's address; confirm against the deployment.
        """
        client_ip = request.client.host if request.client else "unknown"
        user_agent = request.headers.get("user-agent", "")
        key_data = f"{client_ip}:{user_agent}"
        return hashlib.md5(key_data.encode()).hexdigest()[:16]

    def is_protected_path(self, path: str) -> Optional[Dict]:
        """Return the limit config of the first protected prefix of ``path``, else None."""
        for protected_path, config in self.protected_paths.items():
            if path.startswith(protected_path):
                return config
        return None

    def check_burst_limit(
        self,
        client_key: str,
        burst_limit: int,
        window_seconds: float = _BURST_WINDOW_SECONDS,
    ) -> bool:
        """Return True if the client made fewer than ``burst_limit`` requests
        in the last ``window_seconds`` seconds."""
        now = time.time()
        # .get() so that a read never inserts a key for an unseen client.
        recent = sum(
            1
            for req_time in self.tracker.requests.get(client_key, ())
            if now - req_time < window_seconds
        )
        return recent < burst_limit

    def _maybe_prune(self) -> None:
        """Sweep idle client state from the tracker at most once per prune interval."""
        now = time.time()
        if now - self._last_prune < _PRUNE_INTERVAL_SECONDS:
            return
        self._last_prune = now
        # Keep timestamps for the longest window any check looks back over.
        longest_window = max(
            [config["window"] for config in self.protected_paths.values()]
            + [_BURST_WINDOW_SECONDS]
        )
        self.tracker.prune(longest_window)

    async def __call__(self, request: Request, call_next):
        """Apply backoff, burst and window limits, then forward the request.

        Returns a 429 ``JSONResponse`` with a ``Retry-After`` header when a
        limit applies. Otherwise returns the downstream response with
        ``X-RateLimit-Limit``, ``X-RateLimit-Remaining`` and
        ``X-RateLimit-Reset`` headers added.
        """
        path_config = self.is_protected_path(request.url.path)
        if not path_config:
            return await call_next(request)

        self._maybe_prune()
        client_key = self.get_client_key(request)

        if self.tracker.should_reset_violations(client_key):
            self.tracker.reset_violations(client_key)

        # Reject only while a previously imposed backoff has not yet elapsed.
        # The block is time-bounded, so Retry-After reflects the real wait.
        backoff_time = self.tracker.get_remaining_backoff(client_key)
        if backoff_time > 0:
            logger.warning(
                "Rate limit backoff active for %s: %ss", client_key, backoff_time
            )
            return JSONResponse(
                status_code=429,
                content={
                    "error": "Rate limit exceeded",
                    "code": "RATE_LIMIT_BACKOFF",
                    "retry_after": backoff_time,
                    "strategy": "exponential_backoff",
                },
                headers={"Retry-After": str(backoff_time)},
            )

        limit = path_config["limit"]
        window = path_config["window"]
        request_count = self.tracker.get_request_count(client_key, window)

        if not self.check_burst_limit(client_key, path_config["burst"]):
            backoff_time = self.tracker.start_backoff(client_key)
            logger.warning("Burst limit exceeded for %s", client_key)
            return JSONResponse(
                status_code=429,
                content={
                    "error": "Burst limit exceeded",
                    "code": "BURST_LIMIT_EXCEEDED",
                    "retry_after": backoff_time,
                },
                headers={"Retry-After": str(backoff_time)},
            )

        if request_count >= limit:
            backoff_time = self.tracker.start_backoff(client_key)
            logger.warning(
                "Rate limit exceeded for %s: %s/%s", client_key, request_count, limit
            )
            return JSONResponse(
                status_code=429,
                content={
                    "error": "Rate limit exceeded",
                    "code": "RATE_LIMIT_EXCEEDED",
                    "limit": limit,
                    "window": window,
                    "current_usage": request_count,
                    "retry_after": backoff_time,
                },
                headers={"Retry-After": str(backoff_time)},
            )

        self.tracker.add_request(client_key)
        response = await call_next(request)

        # request_count excludes the request just admitted, hence the extra -1.
        remaining = limit - request_count - 1
        response.headers["X-RateLimit-Limit"] = str(limit)
        response.headers["X-RateLimit-Remaining"] = str(max(0, remaining))
        response.headers["X-RateLimit-Reset"] = str(int(time.time() + window))
        return response