Add rate limiting with backoff
CI / lint_and_test (push) Successful in 2m34s Details

This commit is contained in:
itqop 2025-07-19 20:40:25 +03:00
parent d32beef6af
commit 0aa8964c6c
7 changed files with 315 additions and 5 deletions

View File

@ -352,7 +352,6 @@ pytest src/tests/
pytest src/tests/test_api.py pytest src/tests/test_api.py
``` ```
### TODO: ### TODO:
- Реализовать rate limiting с backoff - Улучшить валидацию входных данных (почта, защита хака llm)
- Покрытие кода тестами до 80%, доавить unit тесты для бизнес-логики - Увеличить покрытие кода тестами до 80%, доавить unit тесты для бизнес-логики
- Улучшить валидацию входных данных
- Сделать качественную докуентацию - Сделать качественную докуентацию

View File

@ -24,6 +24,11 @@ services:
- CHUNK_SIZE=500 - CHUNK_SIZE=500
- CHUNK_OVERLAP=100 - CHUNK_OVERLAP=100
- API_SECRET_KEY=${API_SECRET_KEY:-secret} - API_SECRET_KEY=${API_SECRET_KEY:-secret}
- RATE_LIMIT_PER_MINUTE=${RATE_LIMIT_PER_MINUTE:-10}
- RATE_LIMIT_BURST=${RATE_LIMIT_BURST:-3}
- RATE_LIMIT_BACKOFF_BASE=${RATE_LIMIT_BACKOFF_BASE:-2.0}
- RATE_LIMIT_MAX_BACKOFF=${RATE_LIMIT_MAX_BACKOFF:-300}
- RATE_LIMIT_JITTER=${RATE_LIMIT_JITTER:-0.1}
- PYTHONPATH=/app - PYTHONPATH=/app
- PYTHONUNBUFFERED=1 - PYTHONUNBUFFERED=1
- PYTHONDONTWRITEBYTECODE=1 - PYTHONDONTWRITEBYTECODE=1

View File

@ -24,6 +24,22 @@ class Settings(BaseSettings):
chunk_size: int = 500 chunk_size: int = 500
chunk_overlap: int = 100 chunk_overlap: int = 100
rate_limit_per_minute: int = Field(
default=10, json_schema_extra={"env": "RATE_LIMIT_PER_MINUTE"}
)
rate_limit_burst: int = Field(
default=3, json_schema_extra={"env": "RATE_LIMIT_BURST"}
)
rate_limit_backoff_base: float = Field(
default=2.0, json_schema_extra={"env": "RATE_LIMIT_BACKOFF_BASE"}
)
rate_limit_max_backoff: int = Field(
default=300, json_schema_extra={"env": "RATE_LIMIT_MAX_BACKOFF"}
)
rate_limit_jitter: float = Field(
default=0.1, json_schema_extra={"env": "RATE_LIMIT_JITTER"}
)
model_config = ConfigDict(env_file=".env") model_config = ConfigDict(env_file=".env")

View File

@ -11,6 +11,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.app.routers import generate, ingest, health from src.app.routers import generate, ingest, health
from src.app.config import settings from src.app.config import settings
from src.app.middleware import RateLimitMiddleware
logging.basicConfig( logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
@ -33,6 +34,9 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
rate_limit_middleware = RateLimitMiddleware()
app.middleware("http")(rate_limit_middleware)
@app.middleware("http") @app.middleware("http")
async def logging_middleware(request: Request, call_next): async def logging_middleware(request: Request, call_next):

175
src/app/middleware.py Normal file
View File

@ -0,0 +1,175 @@
import time
import random
import hashlib
from typing import Dict, Optional
from fastapi import Request
from fastapi.responses import JSONResponse
from collections import defaultdict, deque
import logging
from src.app.config import settings
logger = logging.getLogger(__name__)
class RateLimitTracker:
def __init__(self):
self.requests: Dict[str, deque] = defaultdict(deque)
self.violations: Dict[str, int] = defaultdict(int)
self.last_violation: Dict[str, float] = defaultdict(float)
def cleanup_old_requests(self, client_key: str, window_seconds: int = 60):
now = time.time()
cutoff = now - window_seconds
while self.requests[client_key] and self.requests[client_key][0] < cutoff:
self.requests[client_key].popleft()
def add_request(self, client_key: str):
now = time.time()
self.requests[client_key].append(now)
def get_request_count(self, client_key: str, window_seconds: int = 60) -> int:
self.cleanup_old_requests(client_key, window_seconds)
return len(self.requests[client_key])
def record_violation(self, client_key: str):
self.violations[client_key] += 1
self.last_violation[client_key] = time.time()
def get_backoff_time(self, client_key: str) -> int:
violations = self.violations[client_key]
if violations == 0:
return 0
base_delay = settings.rate_limit_backoff_base**violations
max_delay = settings.rate_limit_max_backoff
delay = min(base_delay, max_delay)
jitter = delay * settings.rate_limit_jitter * random.random()
return int(delay + jitter)
def should_reset_violations(self, client_key: str, reset_after: int = 3600) -> bool:
if client_key not in self.last_violation:
return False
return time.time() - self.last_violation[client_key] > reset_after
def reset_violations(self, client_key: str):
self.violations[client_key] = 0
if client_key in self.last_violation:
del self.last_violation[client_key]
class RateLimitMiddleware:
def __init__(self):
self.tracker = RateLimitTracker()
self.protected_paths = {
"/api/v1/generate_email": {
"limit": settings.rate_limit_per_minute,
"window": 60,
"burst": settings.rate_limit_burst,
}
}
def get_client_key(self, request: Request) -> str:
client_ip = request.client.host if request.client else "unknown"
user_agent = request.headers.get("user-agent", "")
key_data = f"{client_ip}:{user_agent}"
return hashlib.md5(key_data.encode()).hexdigest()[:16]
def is_protected_path(self, path: str) -> Optional[Dict]:
for protected_path, config in self.protected_paths.items():
if path.startswith(protected_path):
return config
return None
def check_burst_limit(self, client_key: str, burst_limit: int) -> bool:
now = time.time()
recent_requests = [
req_time
for req_time in self.tracker.requests[client_key]
if now - req_time < 10
]
return len(recent_requests) < burst_limit
async def __call__(self, request: Request, call_next):
path_config = self.is_protected_path(request.url.path)
if not path_config:
return await call_next(request)
client_key = self.get_client_key(request)
if self.tracker.should_reset_violations(client_key):
self.tracker.reset_violations(client_key)
backoff_time = self.tracker.get_backoff_time(client_key)
if backoff_time > 0:
logger.warning(
f"Rate limit backoff active for {client_key}: {backoff_time}s"
)
return JSONResponse(
status_code=429,
content={
"error": "Rate limit exceeded",
"code": "RATE_LIMIT_BACKOFF",
"retry_after": backoff_time,
"strategy": "exponential_backoff",
},
headers={"Retry-After": str(backoff_time)},
)
request_count = self.tracker.get_request_count(
client_key, path_config["window"]
)
if not self.check_burst_limit(client_key, path_config["burst"]):
self.tracker.record_violation(client_key)
backoff_time = self.tracker.get_backoff_time(client_key)
logger.warning(f"Burst limit exceeded for {client_key}")
return JSONResponse(
status_code=429,
content={
"error": "Burst limit exceeded",
"code": "BURST_LIMIT_EXCEEDED",
"retry_after": backoff_time,
},
headers={"Retry-After": str(backoff_time)},
)
if request_count >= path_config["limit"]:
self.tracker.record_violation(client_key)
backoff_time = self.tracker.get_backoff_time(client_key)
logger.warning(
f"Rate limit exceeded for {client_key}: {request_count}/{path_config['limit']}"
)
return JSONResponse(
status_code=429,
content={
"error": "Rate limit exceeded",
"code": "RATE_LIMIT_EXCEEDED",
"limit": path_config["limit"],
"window": path_config["window"],
"current_usage": request_count,
"retry_after": backoff_time,
},
headers={"Retry-After": str(backoff_time)},
)
self.tracker.add_request(client_key)
response = await call_next(request)
remaining = path_config["limit"] - request_count - 1
response.headers["X-RateLimit-Limit"] = str(path_config["limit"])
response.headers["X-RateLimit-Remaining"] = str(max(0, remaining))
response.headers["X-RateLimit-Reset"] = str(
int(time.time() + path_config["window"])
)
return response

View File

@ -83,7 +83,7 @@ def test_generate_email_with_email():
} }
response = client.post("/api/v1/generate_email", json=payload) response = client.post("/api/v1/generate_email", json=payload)
assert response.status_code in [200, 500] assert response.status_code in [200, 429, 500]
def test_api_status(): def test_api_status():
@ -99,7 +99,7 @@ def test_invalid_json():
content="invalid json".encode("utf-8"), content="invalid json".encode("utf-8"),
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
) )
assert response.status_code == 422 assert response.status_code in [422, 429]
def test_admin_endpoints_without_auth(): def test_admin_endpoints_without_auth():

View File

@ -0,0 +1,111 @@
import time
from fastapi.testclient import TestClient
import sys
import os
sys.path.append(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
from src.app.main import app
client = TestClient(app)
def test_rate_limit_basic():
payload = {
"contact": "Test User",
"position": "Developer",
"company_name": "Test Company",
"segment": "IT",
}
responses = []
for i in range(5):
response = client.post("/api/v1/generate_email", json=payload)
responses.append(response.status_code)
assert any(code == 429 for code in responses[-2:])
def test_rate_limit_headers():
payload = {
"contact": "Test User",
"position": "Developer",
"company_name": "Test Company",
"segment": "IT",
}
response = client.post("/api/v1/generate_email", json=payload)
if response.status_code != 429:
assert "X-RateLimit-Limit" in response.headers
assert "X-RateLimit-Remaining" in response.headers
assert "X-RateLimit-Reset" in response.headers
def test_rate_limit_response_format():
payload = {
"contact": "Test User",
"position": "Developer",
"company_name": "Test Company",
"segment": "IT",
}
for i in range(15):
response = client.post("/api/v1/generate_email", json=payload)
if response.status_code == 429:
data = response.json()
assert "error" in data
assert "code" in data
assert "retry_after" in data
assert "Retry-After" in response.headers
break
def test_burst_limit():
payload = {
"contact": "Test User",
"position": "Developer",
"company_name": "Test Company",
"segment": "IT",
}
responses = []
for i in range(5):
response = client.post("/api/v1/generate_email", json=payload)
responses.append(response.status_code)
if i < 4:
time.sleep(0.1)
burst_exceeded = any(
client.post("/api/v1/generate_email", json=payload).status_code == 429
for _ in range(3)
)
assert burst_exceeded
def test_non_protected_endpoint():
response = client.get("/")
assert response.status_code == 200
assert "X-RateLimit-Limit" not in response.headers
def test_backoff_strategy():
payload = {
"contact": "Backoff Test",
"position": "Tester",
"company_name": "Backoff Company",
"segment": "Testing",
}
for i in range(20):
response = client.post("/api/v1/generate_email", json=payload)
if response.status_code == 429:
data = response.json()
if "strategy" in data:
assert data["strategy"] == "exponential_backoff"
return
assert False, "No backoff strategy detected"