176 lines
6.1 KiB
Python
176 lines
6.1 KiB
Python
import time
|
|
import random
|
|
import hashlib
|
|
from typing import Dict, Optional
|
|
from fastapi import Request
|
|
from fastapi.responses import JSONResponse
|
|
from collections import defaultdict, deque
|
|
import logging
|
|
|
|
from src.app.config import settings
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class RateLimitTracker:
    """In-memory, process-local request log with escalating backoff on violations.

    Attributes:
        requests: Per-client deque of request timestamps (``time.time()``),
            appended in arrival order so the oldest is at the left.
        violations: Per-client count of recorded limit violations.
        last_violation: Per-client ``time.time()`` of the most recent violation.
        backoff_until: Per-client ``time.time()`` deadline before which the
            client is in backoff; set by :meth:`record_violation`.
    """

    def __init__(
        self,
        *,
        backoff_base: Optional[float] = None,
        max_backoff: Optional[float] = None,
        jitter: Optional[float] = None,
    ):
        """Create an empty tracker.

        Args:
            backoff_base: Base of the exponential delay
                (``backoff_base ** violations`` seconds). Defaults to
                ``settings.rate_limit_backoff_base``.
            max_backoff: Cap on the exponential delay in seconds, applied
                before jitter. Defaults to ``settings.rate_limit_max_backoff``.
            jitter: Fraction of the capped delay added as random jitter
                (0 disables it). Defaults to ``settings.rate_limit_jitter``.
        """
        self.requests: Dict[str, deque] = defaultdict(deque)
        self.violations: Dict[str, int] = defaultdict(int)
        self.last_violation: Dict[str, float] = defaultdict(float)
        self.backoff_until: Dict[str, float] = {}
        # None defers to the matching `settings` value, read on each use.
        self._backoff_base = backoff_base
        self._max_backoff = max_backoff
        self._jitter = jitter

    def cleanup_old_requests(self, client_key: str, window_seconds: int = 60):
        """Drop timestamps older than ``window_seconds`` for ``client_key``.

        A client whose log becomes empty is removed from ``requests`` so that
        lookups do not leave empty deques behind.
        """
        timestamps = self.requests.get(client_key)
        if timestamps is None:
            return
        cutoff = time.time() - window_seconds
        while timestamps and timestamps[0] < cutoff:
            timestamps.popleft()
        if not timestamps:
            del self.requests[client_key]

    def add_request(self, client_key: str):
        """Log a request from ``client_key`` at the current time."""
        self.requests[client_key].append(time.time())

    def get_request_count(self, client_key: str, window_seconds: int = 60) -> int:
        """Return how many requests ``client_key`` made in the last ``window_seconds``."""
        self.cleanup_old_requests(client_key, window_seconds)
        # .get() so that querying an unknown client does not create an entry.
        return len(self.requests.get(client_key, ()))

    def record_violation(self, client_key: str):
        """Count a violation and start (or extend) the client's backoff.

        The jittered delay is computed once here, so every later
        :meth:`get_backoff_time` call reports a consistent remaining time.
        """
        now = time.time()
        count = self.violations[client_key] + 1
        self.violations[client_key] = count
        self.last_violation[client_key] = now
        self.backoff_until[client_key] = now + self._compute_delay(count)

    def _compute_delay(self, violations: int) -> float:
        """Return the jittered backoff delay in seconds for ``violations`` violations."""
        base = (
            settings.rate_limit_backoff_base
            if self._backoff_base is None
            else self._backoff_base
        )
        cap = (
            settings.rate_limit_max_backoff
            if self._max_backoff is None
            else self._max_backoff
        )
        jitter = settings.rate_limit_jitter if self._jitter is None else self._jitter
        try:
            delay = min(base**violations, cap)
        except OverflowError:
            # A float base raised to a large violation count overflows; the
            # cap would have applied anyway.
            delay = cap
        return delay + delay * jitter * random.random()

    def get_backoff_time(self, client_key: str) -> int:
        """Return the whole seconds (rounded up) left in ``client_key``'s backoff.

        Returns 0 when no backoff is active or it has already elapsed.
        Immediately after :meth:`record_violation` this is the full jittered
        delay, which makes it usable as a ``Retry-After`` value.
        """
        import math  # only this method needs it

        deadline = self.backoff_until.get(client_key)
        if deadline is None:
            return 0
        remaining = deadline - time.time()
        if remaining <= 0:
            # Backoff is over. The violation count is kept so that repeat
            # offences still escalate until should_reset_violations() clears it.
            del self.backoff_until[client_key]
            return 0
        # Round up so a client honouring Retry-After never retries too early.
        return math.ceil(remaining)

    def should_reset_violations(self, client_key: str, reset_after: int = 3600) -> bool:
        """Return True if ``client_key``'s last violation is over ``reset_after`` seconds old.

        Returns False for a client with no recorded violation.
        """
        last = self.last_violation.get(client_key)
        if last is None:
            return False
        return time.time() - last > reset_after

    def reset_violations(self, client_key: str):
        """Forget all violation and backoff state for ``client_key``."""
        self.violations.pop(client_key, None)
        self.last_violation.pop(client_key, None)
        self.backoff_until.pop(client_key, None)
|
|
|
|
|
|
class RateLimitMiddleware:
    """HTTP middleware callable ``(request, call_next)`` that rate-limits selected paths.

    Requests to a path covered by ``protected_paths`` are checked in this order:
    an active backoff from earlier violations, a short-window burst limit, and
    the per-window request limit. A failed check returns HTTP 429 with a
    ``Retry-After`` header. Allowed responses get ``X-RateLimit-*`` headers.
    """

    def __init__(self, *, include_user_agent: bool = False):
        """Set up the request tracker and the per-path limit configuration.

        Args:
            include_user_agent: Also key clients by their ``User-Agent`` header.
                Off by default because the header is client-controlled: a
                client that changes it per request would get a fresh quota
                every time and bypass the limit.
        """
        self.tracker = RateLimitTracker()
        self.include_user_agent = include_user_agent
        # Per path: at most "limit" requests per "window" seconds, and fewer
        # than "burst" requests within any "burst_window" seconds.
        self.protected_paths = {
            "/api/v1/generate_email": {
                "limit": settings.rate_limit_per_minute,
                "window": 60,
                "burst": settings.rate_limit_burst,
                "burst_window": 10,
            }
        }

    def get_client_key(self, request: Request) -> str:
        """Return a short, stable identifier for the client that sent ``request``.

        The key is derived from the peer address (plus the User-Agent when
        ``include_user_agent`` is set). NOTE(review): behind a reverse proxy the
        peer address is presumably the proxy's; confirm the deployment topology.
        """
        client_ip = request.client.host if request.client else "unknown"
        key_data = client_ip
        if self.include_user_agent:
            key_data = f"{client_ip}:{request.headers.get('user-agent', '')}"
        # MD5 only shortens the key into a compact dict key; it is not a
        # security control here.
        return hashlib.md5(key_data.encode()).hexdigest()[:16]

    def is_protected_path(self, path: str) -> Optional[Dict]:
        """Return the limit config covering ``path``, or ``None`` if it is unprotected.

        An entry matches the exact path and anything beneath it (``<entry>/...``),
        but not a different path that merely starts with the same text
        (e.g. ``/api/v1/generate_emails``).
        """
        for prefix, config in self.protected_paths.items():
            base = prefix.rstrip("/")
            if path == base or path.startswith(base + "/"):
                return config
        return None

    def check_burst_limit(
        self, client_key: str, burst_limit: int, window_seconds: float = 10
    ) -> bool:
        """Return True if another request from ``client_key`` fits the burst limit.

        The request fits if fewer than ``burst_limit`` requests were logged in
        the last ``window_seconds`` seconds.
        """
        cutoff = time.time() - window_seconds
        # .get() avoids inserting an empty entry for clients with no history.
        timestamps = self.tracker.requests.get(client_key, ())
        recent = sum(1 for ts in timestamps if ts > cutoff)
        return recent < burst_limit

    @staticmethod
    def _too_many_requests(content: Dict, retry_after: int) -> JSONResponse:
        """Build a 429 JSON response whose ``Retry-After`` header matches ``retry_after``."""
        return JSONResponse(
            status_code=429,
            content=content,
            headers={"Retry-After": str(retry_after)},
        )

    async def __call__(self, request: Request, call_next):
        """Apply the configured limits to ``request`` and forward it if it is allowed."""
        path_config = self.is_protected_path(request.url.path)
        if path_config is None:
            return await call_next(request)

        tracker = self.tracker
        client_key = self.get_client_key(request)

        # Forget old violations once the client has been quiet long enough.
        if tracker.should_reset_violations(client_key):
            tracker.reset_violations(client_key)

        backoff_time = tracker.get_backoff_time(client_key)
        if backoff_time > 0:
            logger.warning(
                "Rate limit backoff active for %s: %ss", client_key, backoff_time
            )
            return self._too_many_requests(
                {
                    "error": "Rate limit exceeded",
                    "code": "RATE_LIMIT_BACKOFF",
                    "retry_after": backoff_time,
                    "strategy": "exponential_backoff",
                },
                backoff_time,
            )

        limit = path_config["limit"]
        window = path_config["window"]
        request_count = tracker.get_request_count(client_key, window)

        if not self.check_burst_limit(
            client_key, path_config["burst"], path_config.get("burst_window", 10)
        ):
            tracker.record_violation(client_key)
            backoff_time = tracker.get_backoff_time(client_key)
            logger.warning("Burst limit exceeded for %s", client_key)
            return self._too_many_requests(
                {
                    "error": "Burst limit exceeded",
                    "code": "BURST_LIMIT_EXCEEDED",
                    "retry_after": backoff_time,
                },
                backoff_time,
            )

        if request_count >= limit:
            tracker.record_violation(client_key)
            backoff_time = tracker.get_backoff_time(client_key)
            logger.warning(
                "Rate limit exceeded for %s: %s/%s", client_key, request_count, limit
            )
            return self._too_many_requests(
                {
                    "error": "Rate limit exceeded",
                    "code": "RATE_LIMIT_EXCEEDED",
                    "limit": limit,
                    "window": window,
                    "current_usage": request_count,
                    "retry_after": backoff_time,
                },
                backoff_time,
            )

        # No await separates the checks above from this append, so requests
        # handled by the same event loop cannot interleave between them.
        tracker.add_request(client_key)
        response = await call_next(request)

        # In a sliding window, a slot frees up when the oldest logged request
        # ages out. Fall back to "now" if the log was pruned in the meantime.
        timestamps = tracker.requests.get(client_key)
        oldest = timestamps[0] if timestamps else time.time()
        remaining = limit - request_count - 1  # request_count excludes this request
        response.headers["X-RateLimit-Limit"] = str(limit)
        response.headers["X-RateLimit-Remaining"] = str(max(0, remaining))
        response.headers["X-RateLimit-Reset"] = str(int(oldest + window))
        return response
|