Add rate limiting with backoff
CI / lint_and_test (push) Successful in 2m34s
Details
CI / lint_and_test (push) Successful in 2m34s
Details
This commit is contained in:
parent
d32beef6af
commit
0aa8964c6c
|
@ -352,7 +352,6 @@ pytest src/tests/
|
|||
pytest src/tests/test_api.py
|
||||
```
|
||||
### TODO:
|
||||
- Реализовать rate limiting с backoff
|
||||
- Покрытие кода тестами до 80%, добавить unit-тесты для бизнес-логики
|
||||
- Улучшить валидацию входных данных
|
||||
- Улучшить валидацию входных данных (почта, защита хака llm)
|
||||
- Увеличить покрытие кода тестами до 80%, добавить unit-тесты для бизнес-логики
|
||||
- Сделать качественную документацию
|
|
@ -24,6 +24,11 @@ services:
|
|||
- CHUNK_SIZE=500
|
||||
- CHUNK_OVERLAP=100
|
||||
- API_SECRET_KEY=${API_SECRET_KEY:-secret}
|
||||
- RATE_LIMIT_PER_MINUTE=${RATE_LIMIT_PER_MINUTE:-10}
|
||||
- RATE_LIMIT_BURST=${RATE_LIMIT_BURST:-3}
|
||||
- RATE_LIMIT_BACKOFF_BASE=${RATE_LIMIT_BACKOFF_BASE:-2.0}
|
||||
- RATE_LIMIT_MAX_BACKOFF=${RATE_LIMIT_MAX_BACKOFF:-300}
|
||||
- RATE_LIMIT_JITTER=${RATE_LIMIT_JITTER:-0.1}
|
||||
- PYTHONPATH=/app
|
||||
- PYTHONUNBUFFERED=1
|
||||
- PYTHONDONTWRITEBYTECODE=1
|
||||
|
|
|
@ -24,6 +24,22 @@ class Settings(BaseSettings):
|
|||
chunk_size: int = 500
|
||||
chunk_overlap: int = 100
|
||||
|
||||
rate_limit_per_minute: int = Field(
|
||||
default=10, json_schema_extra={"env": "RATE_LIMIT_PER_MINUTE"}
|
||||
)
|
||||
rate_limit_burst: int = Field(
|
||||
default=3, json_schema_extra={"env": "RATE_LIMIT_BURST"}
|
||||
)
|
||||
rate_limit_backoff_base: float = Field(
|
||||
default=2.0, json_schema_extra={"env": "RATE_LIMIT_BACKOFF_BASE"}
|
||||
)
|
||||
rate_limit_max_backoff: int = Field(
|
||||
default=300, json_schema_extra={"env": "RATE_LIMIT_MAX_BACKOFF"}
|
||||
)
|
||||
rate_limit_jitter: float = Field(
|
||||
default=0.1, json_schema_extra={"env": "RATE_LIMIT_JITTER"}
|
||||
)
|
||||
|
||||
model_config = ConfigDict(env_file=".env")
|
||||
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|||
|
||||
from src.app.routers import generate, ingest, health
|
||||
from src.app.config import settings
|
||||
from src.app.middleware import RateLimitMiddleware
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
|
@ -33,6 +34,9 @@ app.add_middleware(
|
|||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
rate_limit_middleware = RateLimitMiddleware()
|
||||
app.middleware("http")(rate_limit_middleware)
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def logging_middleware(request: Request, call_next):
|
||||
|
|
|
@ -0,0 +1,175 @@
|
|||
import time
|
||||
import random
|
||||
import hashlib
|
||||
from typing import Dict, Optional
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from collections import defaultdict, deque
|
||||
import logging
|
||||
|
||||
from src.app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RateLimitTracker:
    """Per-client, in-memory state for sliding-window rate limiting with backoff.

    For every client key the tracker keeps the timestamps of accepted
    requests, a violation counter, the time of the last violation and the
    moment until which the client is penalised. All of it lives in plain
    dictionaries on the instance.

    Args:
        backoff_base: Base of the exponential penalty. ``None`` (default)
            reads ``settings.rate_limit_backoff_base`` each time it is needed.
        max_backoff: Upper bound, in seconds, of the penalty before jitter.
            ``None`` reads ``settings.rate_limit_max_backoff``.
        jitter: Maximum random fraction added on top of the penalty.
            ``None`` reads ``settings.rate_limit_jitter``.
    """

    def __init__(self, backoff_base=None, max_backoff=None, jitter=None):
        self.requests: Dict[str, deque] = defaultdict(deque)
        self.violations: Dict[str, int] = defaultdict(int)
        self.last_violation: Dict[str, float] = defaultdict(float)
        # Epoch timestamp until which each penalised client stays blocked.
        self.blocked_until: Dict[str, float] = {}
        self._backoff_base = backoff_base
        self._max_backoff = max_backoff
        self._jitter = jitter

    @staticmethod
    def _setting(override, name: str):
        """Return *override* unless it is ``None``, else the global setting *name*."""
        return getattr(settings, name) if override is None else override

    def _penalty_seconds(self, violations: int) -> float:
        """Return the jittered penalty, in seconds, for the given violation count."""
        base = self._setting(self._backoff_base, "rate_limit_backoff_base")
        max_delay = self._setting(self._max_backoff, "rate_limit_max_backoff")
        jitter = self._setting(self._jitter, "rate_limit_jitter")
        try:
            delay = min(base ** violations, max_delay)
        except OverflowError:
            # A float base raised to a large count overflows; the cap applies anyway.
            delay = max_delay
        return delay + delay * jitter * random.random()

    def cleanup_old_requests(self, client_key: str, window_seconds: int = 60):
        """Drop timestamps of *client_key* that are older than *window_seconds*."""
        timestamps = self.requests.get(client_key)
        if timestamps is None:
            return

        cutoff = time.time() - window_seconds
        while timestamps and timestamps[0] < cutoff:
            timestamps.popleft()

        # Forget idle clients so the mapping does not grow without bound.
        if not timestamps:
            del self.requests[client_key]

    def add_request(self, client_key: str):
        """Record an accepted request for *client_key* at the current time."""
        self.requests[client_key].append(time.time())

    def get_request_count(self, client_key: str, window_seconds: int = 60) -> int:
        """Return how many requests *client_key* made in the last *window_seconds*."""
        self.cleanup_old_requests(client_key, window_seconds)
        # ``.get`` avoids re-creating an empty deque for unknown clients.
        return len(self.requests.get(client_key, ()))

    def record_violation(self, client_key: str):
        """Count one more violation and start the matching penalty period."""
        now = time.time()
        self.violations[client_key] += 1
        self.last_violation[client_key] = now
        # The penalty (including jitter) is fixed once per violation, so the
        # deadline does not move between subsequent checks.
        self.blocked_until[client_key] = now + self._penalty_seconds(
            self.violations[client_key]
        )

    def get_backoff_time(self, client_key: str) -> int:
        """Return the whole seconds *client_key* must still wait, or 0.

        The value is the time remaining until the penalty started by the last
        recorded violation expires, rounded up so that a caller honouring it
        never retries too early.
        """
        import math

        if self.violations.get(client_key, 0) <= 0:
            return 0

        remaining = self.blocked_until.get(client_key, 0.0) - time.time()
        return math.ceil(remaining) if remaining > 0 else 0

    def should_reset_violations(self, client_key: str, reset_after: int = 3600) -> bool:
        """Return True when the last violation is more than *reset_after* seconds old."""
        if client_key not in self.last_violation:
            return False

        return time.time() - self.last_violation[client_key] > reset_after

    def reset_violations(self, client_key: str):
        """Forget every violation-related record of *client_key*."""
        self.violations.pop(client_key, None)
        self.last_violation.pop(client_key, None)
        self.blocked_until.pop(client_key, None)
|
||||
|
||||
|
||||
class RateLimitMiddleware:
    """HTTP middleware that throttles requests to the protected paths.

    A request to a protected path is rejected with HTTP 429 when the tracker
    reports a non-zero backoff for the client, when the client exceeds the
    short burst limit, or when it exceeds the per-window limit. Accepted
    responses carry ``X-RateLimit-*`` headers. Requests to any other path are
    passed through untouched.
    """

    def __init__(self):
        self.tracker = RateLimitTracker()
        # Path prefix -> limits applied to every request under that prefix.
        self.protected_paths = {
            "/api/v1/generate_email": {
                "limit": settings.rate_limit_per_minute,
                "window": 60,
                "burst": settings.rate_limit_burst,
            }
        }

    def get_client_key(self, request: "Request") -> str:
        """Return a short, stable identifier derived from client IP and User-Agent."""
        client_ip = request.client.host if request.client else "unknown"
        user_agent = request.headers.get("user-agent", "")

        # NOTE(review): User-Agent is client-controlled, so rotating it yields
        # a fresh key and a fresh quota. Kept for compatibility; consider
        # keying on the IP alone if that bypass matters.
        key_data = f"{client_ip}:{user_agent}"
        # SHA-256 instead of MD5: the key is only an in-memory bucket id, and
        # MD5 is rejected outright on FIPS-restricted Python builds.
        return hashlib.sha256(key_data.encode()).hexdigest()[:16]

    def is_protected_path(self, path: str) -> Optional[Dict]:
        """Return the limits of the first protected prefix matching *path*, else None."""
        return next(
            (
                config
                for prefix, config in self.protected_paths.items()
                if path.startswith(prefix)
            ),
            None,
        )

    def check_burst_limit(
        self, client_key: str, burst_limit: int, window_seconds: float = 10
    ) -> bool:
        """Return True while the client made fewer than *burst_limit* recent requests.

        Args:
            client_key: Identifier returned by ``get_client_key``.
            burst_limit: Number of requests allowed inside the burst window.
            window_seconds: Length of the burst window; defaults to 10 seconds.
        """
        now = time.time()
        # ``.get`` keeps unknown clients from being inserted into the mapping.
        timestamps = self.tracker.requests.get(client_key, ())
        recent = sum(1 for stamp in timestamps if now - stamp < window_seconds)
        return recent < burst_limit

    @staticmethod
    def _reject(content: Dict) -> "JSONResponse":
        """Build the 429 response; ``Retry-After`` mirrors ``content['retry_after']``."""
        return JSONResponse(
            status_code=429,
            content=content,
            headers={"Retry-After": str(content["retry_after"])},
        )

    async def __call__(self, request: "Request", call_next):
        """Apply the rate-limit checks to *request* and forward it when allowed."""
        path_config = self.is_protected_path(request.url.path)
        if not path_config:
            return await call_next(request)

        client_key = self.get_client_key(request)
        limit = path_config["limit"]
        window = path_config["window"]

        if self.tracker.should_reset_violations(client_key):
            self.tracker.reset_violations(client_key)

        backoff_time = self.tracker.get_backoff_time(client_key)
        if backoff_time > 0:
            logger.warning(
                "Rate limit backoff active for %s: %ss", client_key, backoff_time
            )
            return self._reject(
                {
                    "error": "Rate limit exceeded",
                    "code": "RATE_LIMIT_BACKOFF",
                    "retry_after": backoff_time,
                    "strategy": "exponential_backoff",
                }
            )

        request_count = self.tracker.get_request_count(client_key, window)

        if not self.check_burst_limit(client_key, path_config["burst"]):
            self.tracker.record_violation(client_key)
            backoff_time = self.tracker.get_backoff_time(client_key)

            logger.warning("Burst limit exceeded for %s", client_key)
            return self._reject(
                {
                    "error": "Burst limit exceeded",
                    "code": "BURST_LIMIT_EXCEEDED",
                    "retry_after": backoff_time,
                }
            )

        if request_count >= limit:
            self.tracker.record_violation(client_key)
            backoff_time = self.tracker.get_backoff_time(client_key)

            logger.warning(
                "Rate limit exceeded for %s: %s/%s", client_key, request_count, limit
            )
            return self._reject(
                {
                    "error": "Rate limit exceeded",
                    "code": "RATE_LIMIT_EXCEEDED",
                    "limit": limit,
                    "window": window,
                    "current_usage": request_count,
                    "retry_after": backoff_time,
                }
            )

        self.tracker.add_request(client_key)
        # Capacity starts to free up when the oldest counted request leaves the
        # window. Read it before awaiting the handler, while the entry that was
        # just added guarantees the deque is non-empty.
        reset_at = int(self.tracker.requests[client_key][0] + window)

        response = await call_next(request)

        remaining = limit - request_count - 1
        response.headers["X-RateLimit-Limit"] = str(limit)
        response.headers["X-RateLimit-Remaining"] = str(max(0, remaining))
        response.headers["X-RateLimit-Reset"] = str(reset_at)

        return response
|
|
@ -83,7 +83,7 @@ def test_generate_email_with_email():
|
|||
}
|
||||
|
||||
response = client.post("/api/v1/generate_email", json=payload)
|
||||
assert response.status_code in [200, 500]
|
||||
assert response.status_code in [200, 429, 500]
|
||||
|
||||
|
||||
def test_api_status():
|
||||
|
@ -99,7 +99,7 @@ def test_invalid_json():
|
|||
content="invalid json".encode("utf-8"),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
assert response.status_code in [422, 429]
|
||||
|
||||
|
||||
def test_admin_endpoints_without_auth():
|
||||
|
|
|
@ -0,0 +1,111 @@
|
|||
import time
|
||||
from fastapi.testclient import TestClient
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.append(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
)
|
||||
|
||||
from src.app.main import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_rate_limit_basic():
    """Send five requests in a row; one of the last two must be answered with 429."""
    body = dict(
        contact="Test User",
        position="Developer",
        company_name="Test Company",
        segment="IT",
    )

    status_codes = [
        client.post("/api/v1/generate_email", json=body).status_code
        for _ in range(5)
    ]

    assert 429 in status_codes[-2:]
|
||||
|
||||
|
||||
def test_rate_limit_headers():
    """A response that is not a 429 must carry the three X-RateLimit-* headers."""
    body = dict(
        contact="Test User",
        position="Developer",
        company_name="Test Company",
        segment="IT",
    )

    reply = client.post("/api/v1/generate_email", json=body)

    # A throttled reply is not inspected by this test.
    if reply.status_code == 429:
        return

    for header in ("X-RateLimit-Limit", "X-RateLimit-Remaining", "X-RateLimit-Reset"):
        assert header in reply.headers
|
||||
|
||||
|
||||
def test_rate_limit_response_format():
    """The first 429 reply must expose error, code, retry_after and Retry-After.

    Up to 15 requests are sent; the test fails if none of them is throttled,
    because in that case nothing would have been verified.
    """
    payload = {
        "contact": "Test User",
        "position": "Developer",
        "company_name": "Test Company",
        "segment": "IT",
    }

    for _ in range(15):
        response = client.post("/api/v1/generate_email", json=payload)
        if response.status_code != 429:
            continue

        data = response.json()
        assert "error" in data
        assert "code" in data
        assert "retry_after" in data
        assert "Retry-After" in response.headers
        break
    else:
        # Without this branch the test would pass without checking anything.
        raise AssertionError("No 429 response received within 15 requests")
|
||||
|
||||
|
||||
def test_burst_limit():
    """After five closely spaced requests, one of the next three must get a 429."""
    body = dict(
        contact="Test User",
        position="Developer",
        company_name="Test Company",
        segment="IT",
    )

    # Warm-up: five requests, pausing 0.1 s after each one except the last.
    for attempt in range(5):
        client.post("/api/v1/generate_email", json=body)
        if attempt != 4:
            time.sleep(0.1)

    # Stop at the first throttled reply, sending at most three more requests.
    burst_exceeded = False
    for _ in range(3):
        if client.post("/api/v1/generate_email", json=body).status_code == 429:
            burst_exceeded = True
            break

    assert burst_exceeded
|
||||
|
||||
|
||||
def test_non_protected_endpoint():
    """GET / must answer 200 and must not include an X-RateLimit-Limit header."""
    root_reply = client.get("/")
    status, headers = root_reply.status_code, root_reply.headers

    assert status == 200
    assert "X-RateLimit-Limit" not in headers
|
||||
|
||||
|
||||
def test_backoff_strategy():
    """Within 20 requests a 429 reply must announce the exponential backoff strategy.

    The test passes on the first 429 body that contains a ``strategy`` key
    equal to ``exponential_backoff`` and fails when no such reply is seen.
    """
    payload = {
        "contact": "Backoff Test",
        "position": "Tester",
        "company_name": "Backoff Company",
        "segment": "Testing",
    }

    for _ in range(20):
        response = client.post("/api/v1/generate_email", json=payload)
        if response.status_code != 429:
            continue

        data = response.json()
        if "strategy" in data:
            assert data["strategy"] == "exponential_backoff"
            return

    # Raised explicitly: ``assert False`` is removed when Python runs with -O.
    raise AssertionError("No backoff strategy detected")
|
Loading…
Reference in New Issue