Add rate limiting with backoff
CI / lint_and_test (push) Successful in 2m34s
Details
CI / lint_and_test (push) Successful in 2m34s
Details
This commit is contained in:
parent
d32beef6af
commit
0aa8964c6c
|
@ -352,7 +352,6 @@ pytest src/tests/
|
||||||
pytest src/tests/test_api.py
|
pytest src/tests/test_api.py
|
||||||
```
|
```
|
||||||
### TODO:
|
### TODO:
|
||||||
- Реализовать rate limiting с backoff
|
- Улучшить валидацию входных данных (почта, защита хака llm)
|
||||||
- Покрытие кода тестами до 80%, добавить unit тесты для бизнес-логики
|
- Увеличить покрытие кода тестами до 80%, добавить unit тесты для бизнес-логики
|
||||||
- Улучшить валидацию входных данных
|
|
||||||
- Сделать качественную документацию
|
- Сделать качественную документацию
|
|
@ -24,6 +24,11 @@ services:
|
||||||
- CHUNK_SIZE=500
|
- CHUNK_SIZE=500
|
||||||
- CHUNK_OVERLAP=100
|
- CHUNK_OVERLAP=100
|
||||||
- API_SECRET_KEY=${API_SECRET_KEY:-secret}
|
- API_SECRET_KEY=${API_SECRET_KEY:-secret}
|
||||||
|
- RATE_LIMIT_PER_MINUTE=${RATE_LIMIT_PER_MINUTE:-10}
|
||||||
|
- RATE_LIMIT_BURST=${RATE_LIMIT_BURST:-3}
|
||||||
|
- RATE_LIMIT_BACKOFF_BASE=${RATE_LIMIT_BACKOFF_BASE:-2.0}
|
||||||
|
- RATE_LIMIT_MAX_BACKOFF=${RATE_LIMIT_MAX_BACKOFF:-300}
|
||||||
|
- RATE_LIMIT_JITTER=${RATE_LIMIT_JITTER:-0.1}
|
||||||
- PYTHONPATH=/app
|
- PYTHONPATH=/app
|
||||||
- PYTHONUNBUFFERED=1
|
- PYTHONUNBUFFERED=1
|
||||||
- PYTHONDONTWRITEBYTECODE=1
|
- PYTHONDONTWRITEBYTECODE=1
|
||||||
|
|
|
@ -24,6 +24,22 @@ class Settings(BaseSettings):
|
||||||
chunk_size: int = 500
|
chunk_size: int = 500
|
||||||
chunk_overlap: int = 100
|
chunk_overlap: int = 100
|
||||||
|
|
||||||
|
rate_limit_per_minute: int = Field(
|
||||||
|
default=10, json_schema_extra={"env": "RATE_LIMIT_PER_MINUTE"}
|
||||||
|
)
|
||||||
|
rate_limit_burst: int = Field(
|
||||||
|
default=3, json_schema_extra={"env": "RATE_LIMIT_BURST"}
|
||||||
|
)
|
||||||
|
rate_limit_backoff_base: float = Field(
|
||||||
|
default=2.0, json_schema_extra={"env": "RATE_LIMIT_BACKOFF_BASE"}
|
||||||
|
)
|
||||||
|
rate_limit_max_backoff: int = Field(
|
||||||
|
default=300, json_schema_extra={"env": "RATE_LIMIT_MAX_BACKOFF"}
|
||||||
|
)
|
||||||
|
rate_limit_jitter: float = Field(
|
||||||
|
default=0.1, json_schema_extra={"env": "RATE_LIMIT_JITTER"}
|
||||||
|
)
|
||||||
|
|
||||||
model_config = ConfigDict(env_file=".env")
|
model_config = ConfigDict(env_file=".env")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -11,6 +11,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
from src.app.routers import generate, ingest, health
|
from src.app.routers import generate, ingest, health
|
||||||
from src.app.config import settings
|
from src.app.config import settings
|
||||||
|
from src.app.middleware import RateLimitMiddleware
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||||
|
@ -33,6 +34,9 @@ app.add_middleware(
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
rate_limit_middleware = RateLimitMiddleware()
|
||||||
|
app.middleware("http")(rate_limit_middleware)
|
||||||
|
|
||||||
|
|
||||||
@app.middleware("http")
|
@app.middleware("http")
|
||||||
async def logging_middleware(request: Request, call_next):
|
async def logging_middleware(request: Request, call_next):
|
||||||
|
|
|
@ -0,0 +1,175 @@
|
||||||
|
import time
|
||||||
|
import random
|
||||||
|
import hashlib
|
||||||
|
from typing import Dict, Optional
|
||||||
|
from fastapi import Request
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from collections import defaultdict, deque
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from src.app.config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimitTracker:
    """Per-client, in-memory bookkeeping for rate limiting and backoff.

    For every client key the tracker keeps the timestamps of accepted
    requests (for sliding-window counts), the number of recorded violations
    and the time of the most recent violation (for exponential backoff).

    Args:
        backoff_base: Base of the exponential backoff. When ``None`` the
            value of ``settings.rate_limit_backoff_base`` is read on each use.
        max_backoff: Cap, in seconds, applied to the delay before jitter is
            added. When ``None``, ``settings.rate_limit_max_backoff`` is used.
        jitter: Fraction of the delay that may be added as random jitter.
            When ``None``, ``settings.rate_limit_jitter`` is used.
    """

    def __init__(
        self,
        backoff_base: Optional[float] = None,
        max_backoff: Optional[float] = None,
        jitter: Optional[float] = None,
    ):
        self.requests: Dict[str, deque] = defaultdict(deque)
        self.violations: Dict[str, int] = defaultdict(int)
        self.last_violation: Dict[str, float] = defaultdict(float)
        # Random factor in [0, 1) drawn once per violation, so the end of a
        # backoff period does not move every time it is queried.
        self._jitter_rolls: Dict[str, float] = {}
        self._backoff_base = backoff_base
        self._max_backoff = max_backoff
        self._jitter = jitter

    @staticmethod
    def _resolve(override, setting_name: str):
        """Return *override* if given, else the named attribute of ``settings``."""
        return getattr(settings, setting_name) if override is None else override

    def cleanup_old_requests(self, client_key: str, window_seconds: int = 60):
        """Drop timestamps of *client_key* older than *window_seconds*.

        A client whose window becomes empty is removed from ``requests``
        entirely, so idle clients do not accumulate in memory.
        """
        timestamps = self.requests.get(client_key)
        if timestamps is None:
            return

        cutoff = time.time() - window_seconds
        while timestamps and timestamps[0] < cutoff:
            timestamps.popleft()

        if not timestamps:
            del self.requests[client_key]

    def add_request(self, client_key: str):
        """Record a request for *client_key* at the current time."""
        self.requests[client_key].append(time.time())

    def get_request_count(self, client_key: str, window_seconds: int = 60) -> int:
        """Return how many recorded requests fall inside the last *window_seconds*."""
        self.cleanup_old_requests(client_key, window_seconds)
        # .get() avoids creating an empty entry just by asking for the count.
        return len(self.requests.get(client_key, ()))

    def record_violation(self, client_key: str):
        """Increment the violation counter and restart the backoff period."""
        self.violations[client_key] += 1
        self.last_violation[client_key] = time.time()
        self._jitter_rolls[client_key] = random.random()

    def get_backoff_time(self, client_key: str) -> int:
        """Return the whole seconds *client_key* must still wait, or 0.

        The backoff length is ``backoff_base ** violations`` capped at
        ``max_backoff``, plus up to ``jitter * delay`` of random jitter. The
        time already elapsed since the last violation is subtracted, so the
        result reaches 0 once the advertised delay has passed. The remainder
        is rounded up so a client is never told to retry too early.
        """
        import math

        violations = self.violations.get(client_key, 0)
        if violations <= 0:
            return 0

        backoff_base = self._resolve(self._backoff_base, "rate_limit_backoff_base")
        max_delay = self._resolve(self._max_backoff, "rate_limit_max_backoff")
        jitter_fraction = self._resolve(self._jitter, "rate_limit_jitter")

        try:
            base_delay = backoff_base**violations
        except OverflowError:
            # A float base overflows for large violation counts; the result
            # would be capped to max_delay anyway.
            base_delay = max_delay
        delay = min(base_delay, max_delay)

        roll = self._jitter_rolls.get(client_key)
        if roll is None:
            # Violations set without record_violation(): fall back to a
            # fresh roll, as there is no stored one.
            roll = random.random()
        delay += delay * jitter_fraction * roll

        last = self.last_violation.get(client_key)
        elapsed = 0.0 if last is None else max(0.0, time.time() - last)
        remaining = delay - elapsed
        if remaining <= 0:
            return 0
        return math.ceil(remaining)

    def should_reset_violations(self, client_key: str, reset_after: int = 3600) -> bool:
        """Return True if the last violation is more than *reset_after* seconds old.

        Returns False when no violation time is recorded for *client_key*.
        """
        last = self.last_violation.get(client_key)
        if last is None:
            return False
        return time.time() - last > reset_after

    def reset_violations(self, client_key: str):
        """Forget all violation state of *client_key*."""
        self.violations.pop(client_key, None)
        self.last_violation.pop(client_key, None)
        self._jitter_rolls.pop(client_key, None)
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimitMiddleware:
    """HTTP middleware callable that rate-limits requests to protected paths.

    Requests whose path does not start with a key of ``protected_paths`` are
    passed through untouched. For protected paths the middleware checks, in
    order: an active backoff, the burst limit and the per-window limit. A
    failed check yields a 429 JSON response with a ``Retry-After`` header.
    Accepted requests are recorded and the downstream response is annotated
    with ``X-RateLimit-*`` headers.
    """

    # Look-back, in seconds, used by the burst check.
    BURST_WINDOW_SECONDS = 10

    def __init__(self):
        self.tracker = RateLimitTracker()
        # Path prefix -> limits applied to it.
        self.protected_paths = {
            "/api/v1/generate_email": {
                "limit": settings.rate_limit_per_minute,
                "window": 60,
                "burst": settings.rate_limit_burst,
            }
        }

    def get_client_key(self, request: Request) -> str:
        """Return a 16-hex-char key derived from the client IP and User-Agent."""
        client_ip = request.client.host if request.client else "unknown"
        user_agent = request.headers.get("user-agent", "")

        # NOTE(review): the User-Agent header is client-controlled, so a
        # client that varies it gets a fresh key each time. Consider keying
        # on the IP alone if that matters for this deployment.
        key_data = f"{client_ip}:{user_agent}"
        return hashlib.md5(key_data.encode()).hexdigest()[:16]

    def is_protected_path(self, path: str) -> Optional[Dict]:
        """Return the limits of the first protected prefix matching *path*, else None."""
        for protected_path, config in self.protected_paths.items():
            if path.startswith(protected_path):
                return config
        return None

    def check_burst_limit(self, client_key: str, burst_limit: int) -> bool:
        """Return True while fewer than *burst_limit* requests were recorded
        in the last ``BURST_WINDOW_SECONDS`` seconds."""
        now = time.time()
        recent = sum(
            1
            for req_time in self.tracker.requests[client_key]
            if now - req_time < self.BURST_WINDOW_SECONDS
        )
        return recent < burst_limit

    @staticmethod
    def _too_many_requests(content: Dict, backoff_time: int) -> JSONResponse:
        """Build the 429 response shared by every rejection branch."""
        return JSONResponse(
            status_code=429,
            content=content,
            headers={"Retry-After": str(backoff_time)},
        )

    async def __call__(self, request: Request, call_next):
        """Apply the rate-limit checks to *request* or forward it downstream."""
        path_config = self.is_protected_path(request.url.path)
        if not path_config:
            return await call_next(request)

        tracker = self.tracker
        limit = path_config["limit"]
        window = path_config["window"]
        client_key = self.get_client_key(request)

        if tracker.should_reset_violations(client_key):
            tracker.reset_violations(client_key)

        # 1. The client is still serving a backoff from an earlier violation.
        backoff_time = tracker.get_backoff_time(client_key)
        if backoff_time > 0:
            logger.warning(
                "Rate limit backoff active for %s: %ss", client_key, backoff_time
            )
            return self._too_many_requests(
                {
                    "error": "Rate limit exceeded",
                    "code": "RATE_LIMIT_BACKOFF",
                    "retry_after": backoff_time,
                    "strategy": "exponential_backoff",
                },
                backoff_time,
            )

        # Counting also prunes timestamps that left the window.
        request_count = tracker.get_request_count(client_key, window)

        # 2. Too many requests in the short burst window.
        if not self.check_burst_limit(client_key, path_config["burst"]):
            tracker.record_violation(client_key)
            backoff_time = tracker.get_backoff_time(client_key)

            logger.warning("Burst limit exceeded for %s", client_key)
            return self._too_many_requests(
                {
                    "error": "Burst limit exceeded",
                    "code": "BURST_LIMIT_EXCEEDED",
                    "retry_after": backoff_time,
                },
                backoff_time,
            )

        # 3. Too many requests in the full window.
        if request_count >= limit:
            tracker.record_violation(client_key)
            backoff_time = tracker.get_backoff_time(client_key)

            logger.warning(
                "Rate limit exceeded for %s: %s/%s", client_key, request_count, limit
            )
            return self._too_many_requests(
                {
                    "error": "Rate limit exceeded",
                    "code": "RATE_LIMIT_EXCEEDED",
                    "limit": limit,
                    "window": window,
                    "current_usage": request_count,
                    "retry_after": backoff_time,
                },
                backoff_time,
            )

        tracker.add_request(client_key)

        response = await call_next(request)

        # The window frees capacity when its oldest recorded request expires,
        # not a full window after the current request.
        timestamps = tracker.requests.get(client_key)
        window_start = timestamps[0] if timestamps else time.time()

        remaining = limit - request_count - 1
        response.headers["X-RateLimit-Limit"] = str(limit)
        response.headers["X-RateLimit-Remaining"] = str(max(0, remaining))
        response.headers["X-RateLimit-Reset"] = str(int(window_start + window))

        return response
|
|
@ -83,7 +83,7 @@ def test_generate_email_with_email():
|
||||||
}
|
}
|
||||||
|
|
||||||
response = client.post("/api/v1/generate_email", json=payload)
|
response = client.post("/api/v1/generate_email", json=payload)
|
||||||
assert response.status_code in [200, 500]
|
assert response.status_code in [200, 429, 500]
|
||||||
|
|
||||||
|
|
||||||
def test_api_status():
|
def test_api_status():
|
||||||
|
@ -99,7 +99,7 @@ def test_invalid_json():
|
||||||
content="invalid json".encode("utf-8"),
|
content="invalid json".encode("utf-8"),
|
||||||
headers={"Content-Type": "application/json"},
|
headers={"Content-Type": "application/json"},
|
||||||
)
|
)
|
||||||
assert response.status_code == 422
|
assert response.status_code in [422, 429]
|
||||||
|
|
||||||
|
|
||||||
def test_admin_endpoints_without_auth():
|
def test_admin_endpoints_without_auth():
|
||||||
|
|
|
@ -0,0 +1,111 @@
|
||||||
|
import time
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
sys.path.append(
|
||||||
|
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
)
|
||||||
|
|
||||||
|
from src.app.main import app
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_rate_limit_basic():
    """Send five back-to-back requests and expect a 429 among the last two."""
    request_body = {
        "contact": "Test User",
        "position": "Developer",
        "company_name": "Test Company",
        "segment": "IT",
    }

    status_codes = [
        client.post("/api/v1/generate_email", json=request_body).status_code
        for _ in range(5)
    ]

    assert 429 in status_codes[-2:]
|
||||||
|
|
||||||
|
|
||||||
|
def test_rate_limit_headers():
    """Check that a non-429 response carries the X-RateLimit-* headers.

    NOTE(review): when the request itself is rejected with 429 nothing is
    asserted, so the outcome depends on limiter state left by earlier tests.
    """
    request_body = {
        "contact": "Test User",
        "position": "Developer",
        "company_name": "Test Company",
        "segment": "IT",
    }

    response = client.post("/api/v1/generate_email", json=request_body)
    if response.status_code == 429:
        return

    for header_name in (
        "X-RateLimit-Limit",
        "X-RateLimit-Remaining",
        "X-RateLimit-Reset",
    ):
        assert header_name in response.headers
|
||||||
|
|
||||||
|
|
||||||
|
def test_rate_limit_response_format():
    """Verify the body and headers of the first 429 response.

    Sends up to 15 requests; the first one answered with 429 must expose
    ``error``, ``code`` and ``retry_after`` in its JSON body and a
    ``Retry-After`` header. The test fails if no request is rejected, so it
    cannot pass without checking anything.
    """
    payload = {
        "contact": "Test User",
        "position": "Developer",
        "company_name": "Test Company",
        "segment": "IT",
    }

    for _ in range(15):
        response = client.post("/api/v1/generate_email", json=payload)
        if response.status_code != 429:
            continue

        data = response.json()
        assert "error" in data
        assert "code" in data
        assert "retry_after" in data
        assert "Retry-After" in response.headers
        return

    raise AssertionError("No 429 response observed within 15 requests")
|
||||||
|
|
||||||
|
|
||||||
|
def test_burst_limit():
    """Issue five closely spaced requests, then expect a 429 within three more."""
    request_body = {
        "contact": "Test User",
        "position": "Developer",
        "company_name": "Test Company",
        "segment": "IT",
    }

    warmup_requests = 5
    for attempt in range(warmup_requests):
        client.post("/api/v1/generate_email", json=request_body)
        # Pause between warm-up requests, but not after the last one.
        if attempt != warmup_requests - 1:
            time.sleep(0.1)

    burst_exceeded = False
    for _ in range(3):
        follow_up = client.post("/api/v1/generate_email", json=request_body)
        if follow_up.status_code == 429:
            burst_exceeded = True
            break

    assert burst_exceeded
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_protected_endpoint():
    """The root path is served with 200 and without rate-limit headers."""
    root_response = client.get("/")

    assert root_response.status_code == 200
    assert "X-RateLimit-Limit" not in root_response.headers
|
||||||
|
|
||||||
|
|
||||||
|
def test_backoff_strategy():
    """Expect a 429 that advertises the exponential backoff strategy.

    Sends up to 20 requests and succeeds on the first 429 whose JSON body
    contains ``strategy``; its value must be ``"exponential_backoff"``.
    """
    payload = {
        "contact": "Backoff Test",
        "position": "Tester",
        "company_name": "Backoff Company",
        "segment": "Testing",
    }

    for _ in range(20):
        response = client.post("/api/v1/generate_email", json=payload)
        if response.status_code != 429:
            continue

        data = response.json()
        if "strategy" in data:
            assert data["strategy"] == "exponential_backoff"
            return

    # Raise explicitly: ``assert False`` is removed when Python runs with -O.
    raise AssertionError("No backoff strategy detected")
|
Loading…
Reference in New Issue