# ai-email-assistant/src/app/middleware.py
#
# 176 lines
# 6.1 KiB
# Python

import time
import random
import hashlib
from typing import Dict, Optional
from fastapi import Request
from fastapi.responses import JSONResponse
from collections import defaultdict, deque
import logging
from src.app.config import settings
logger = logging.getLogger(__name__)
class RateLimitTracker:
    """In-memory, per-client bookkeeping used for rate limiting.

    For every client key the tracker keeps:

    * ``requests``       -- a deque of request timestamps (``time.time()``),
      oldest first, used as a sliding window;
    * ``violations``     -- how many limit violations have been recorded;
    * ``last_violation`` -- timestamp of the most recent violation.

    All state lives in plain dictionaries on the instance. Read-only queries
    use ``.get`` instead of indexing so that merely *asking* about a client
    does not create a permanent ``defaultdict`` entry for it.
    """

    def __init__(self):
        self.requests: Dict[str, deque] = defaultdict(deque)
        self.violations: Dict[str, int] = defaultdict(int)
        self.last_violation: Dict[str, float] = defaultdict(float)

    def cleanup_old_requests(self, client_key: str, window_seconds: int = 60):
        """Drop timestamps of ``client_key`` older than ``window_seconds``.

        If no timestamps remain afterwards the client's entry is removed
        entirely so idle clients do not accumulate empty deques.
        """
        timestamps = self.requests.get(client_key)
        if timestamps is None:
            return
        cutoff = time.time() - window_seconds
        # Timestamps are appended in arrival order, so expired ones sit at
        # the left end of the deque.
        while timestamps and timestamps[0] < cutoff:
            timestamps.popleft()
        if not timestamps:
            del self.requests[client_key]

    def add_request(self, client_key: str):
        """Record a request made by ``client_key`` at the current time."""
        self.requests[client_key].append(time.time())

    def get_request_count(self, client_key: str, window_seconds: int = 60) -> int:
        """Return how many requests ``client_key`` made in the last window.

        Expired timestamps are pruned as a side effect.
        """
        self.cleanup_old_requests(client_key, window_seconds)
        return len(self.requests.get(client_key, ()))

    def record_violation(self, client_key: str):
        """Increment the violation counter and stamp the violation time."""
        self.violations[client_key] += 1
        self.last_violation[client_key] = time.time()

    def get_backoff_time(self, client_key: str) -> int:
        """Return the backoff delay, in whole seconds, for ``client_key``.

        The delay is ``settings.rate_limit_backoff_base ** violations``
        capped at ``settings.rate_limit_max_backoff``, plus a random jitter
        of up to ``delay * settings.rate_limit_jitter``. Returns 0 when the
        client has no recorded violations.

        The jitter is re-drawn on every call, so two calls for the same
        client may return slightly different values.
        """
        violations = self.violations.get(client_key, 0)
        if violations == 0:
            return 0
        max_delay = settings.rate_limit_max_backoff
        try:
            base_delay = settings.rate_limit_backoff_base**violations
        except OverflowError:
            # A float base overflows for very large violation counts; the
            # result would be capped at max_delay anyway.
            base_delay = max_delay
        delay = min(base_delay, max_delay)
        jitter = delay * settings.rate_limit_jitter * random.random()
        return int(delay + jitter)

    def should_reset_violations(self, client_key: str, reset_after: int = 3600) -> bool:
        """Return True if the last violation is more than ``reset_after`` seconds old.

        Returns False when no violation has been recorded for the client.
        """
        last = self.last_violation.get(client_key)
        if last is None:
            return False
        return time.time() - last > reset_after

    def reset_violations(self, client_key: str):
        """Forget all violation state recorded for ``client_key``."""
        # Remove the entries rather than storing zeros so that reset clients
        # do not linger in the dictionaries.
        self.violations.pop(client_key, None)
        self.last_violation.pop(client_key, None)
class RateLimitMiddleware:
    """Callable HTTP middleware that rate-limits the protected API paths.

    An instance is invoked as ``await middleware(request, call_next)``.
    Requests whose path is not protected are passed straight through.
    Protected requests are checked, in order, against:

    1. an active backoff period from a previous violation,
    2. a short-term burst limit,
    3. the per-window request limit.

    A failed check yields a 429 JSON response with a ``Retry-After`` header.
    Accepted requests are recorded and the response is annotated with
    ``X-RateLimit-*`` headers.
    """

    def __init__(self):
        self.tracker = RateLimitTracker()
        # Path prefix -> limits applied to requests under that prefix.
        self.protected_paths = {
            "/api/v1/generate_email": {
                "limit": settings.rate_limit_per_minute,
                "window": 60,
                "burst": settings.rate_limit_burst,
            }
        }

    def get_client_key(self, request: Request) -> str:
        """Derive a short, stable identifier for the requesting client.

        The key is the first 16 hex characters of the MD5 digest of
        ``"<client ip>:<user-agent>"``. MD5 is used only for bucketing here,
        not for any security purpose.
        """
        # NOTE(review): the User-Agent header is client-controlled, so a
        # client can obtain a fresh bucket by changing it -- confirm this
        # keying is intended.
        client_ip = request.client.host if request.client else "unknown"
        user_agent = request.headers.get("user-agent", "")
        key_data = f"{client_ip}:{user_agent}"
        return hashlib.md5(key_data.encode()).hexdigest()[:16]

    def is_protected_path(self, path: str) -> Optional[Dict]:
        """Return the limit config for ``path``, or None if it is not protected.

        Matching is by prefix against the keys of ``self.protected_paths``.
        """
        for protected_path, config in self.protected_paths.items():
            if path.startswith(protected_path):
                return config
        return None

    def check_burst_limit(
        self, client_key: str, burst_limit: int, burst_window: float = 10
    ) -> bool:
        """Return True if the client is still below its burst limit.

        Counts the client's recorded requests made within the last
        ``burst_window`` seconds and compares that count to ``burst_limit``.
        """
        now = time.time()
        # .get() avoids creating a defaultdict entry for unknown clients.
        timestamps = self.tracker.requests.get(client_key, ())
        recent = sum(1 for req_time in timestamps if now - req_time < burst_window)
        return recent < burst_limit

    def _remaining_backoff(self, client_key: str) -> int:
        """Return the whole seconds of backoff the client must still wait.

        Returns 0 when the client has no violations or when the backoff
        delay has already elapsed since its last violation.
        """
        backoff_time = self.tracker.get_backoff_time(client_key)
        if backoff_time <= 0:
            return 0
        last_violation = self.tracker.last_violation.get(client_key)
        if last_violation is None:
            # No timestamp to measure from; stay conservative and block.
            return backoff_time
        remaining = backoff_time - (time.time() - last_violation)
        if remaining <= 0:
            return 0
        # Round up so a client honouring Retry-After never retries early.
        return min(backoff_time, int(remaining) + 1)

    @staticmethod
    def _too_many_requests(content: Dict, retry_after: int):
        """Build a 429 JSON response carrying a ``Retry-After`` header."""
        return JSONResponse(
            status_code=429,
            content=content,
            headers={"Retry-After": str(retry_after)},
        )

    async def __call__(self, request: Request, call_next):
        """Apply rate limiting to ``request`` and forward it if allowed.

        Args:
            request: The incoming request.
            call_next: Awaitable callable that produces the downstream response.

        Returns:
            The downstream response (with ``X-RateLimit-*`` headers added for
            protected paths), or a 429 ``JSONResponse`` when a limit is hit.
        """
        path_config = self.is_protected_path(request.url.path)
        if not path_config:
            return await call_next(request)

        client_key = self.get_client_key(request)
        if self.tracker.should_reset_violations(client_key):
            self.tracker.reset_violations(client_key)

        # Only reject while the backoff period is actually running. Checking
        # merely for "has violations" would lock the client out until the
        # violation reset, far longer than the advertised Retry-After.
        backoff_time = self._remaining_backoff(client_key)
        if backoff_time > 0:
            logger.warning(
                "Rate limit backoff active for %s: %ss", client_key, backoff_time
            )
            return self._too_many_requests(
                {
                    "error": "Rate limit exceeded",
                    "code": "RATE_LIMIT_BACKOFF",
                    "retry_after": backoff_time,
                    "strategy": "exponential_backoff",
                },
                backoff_time,
            )

        limit = path_config["limit"]
        window = path_config["window"]
        request_count = self.tracker.get_request_count(client_key, window)

        if not self.check_burst_limit(client_key, path_config["burst"]):
            self.tracker.record_violation(client_key)
            backoff_time = self.tracker.get_backoff_time(client_key)
            logger.warning("Burst limit exceeded for %s", client_key)
            return self._too_many_requests(
                {
                    "error": "Burst limit exceeded",
                    "code": "BURST_LIMIT_EXCEEDED",
                    "retry_after": backoff_time,
                },
                backoff_time,
            )

        if request_count >= limit:
            self.tracker.record_violation(client_key)
            backoff_time = self.tracker.get_backoff_time(client_key)
            logger.warning(
                "Rate limit exceeded for %s: %s/%s", client_key, request_count, limit
            )
            return self._too_many_requests(
                {
                    "error": "Rate limit exceeded",
                    "code": "RATE_LIMIT_EXCEEDED",
                    "limit": limit,
                    "window": window,
                    "current_usage": request_count,
                    "retry_after": backoff_time,
                },
                backoff_time,
            )

        # Count the request before handing off so that concurrent requests
        # from the same client see it.
        self.tracker.add_request(client_key)
        response = await call_next(request)

        remaining = limit - request_count - 1
        response.headers["X-RateLimit-Limit"] = str(limit)
        response.headers["X-RateLimit-Remaining"] = str(max(0, remaining))
        # Reported as "now + window".
        response.headers["X-RateLimit-Reset"] = str(int(time.time() + window))
        return response