diff --git a/README.md b/README.md index 31fdcd6..ffa5857 100644 --- a/README.md +++ b/README.md @@ -352,7 +352,6 @@ pytest src/tests/ pytest src/tests/test_api.py ``` ### TODO: -- Реализовать rate limiting с backoff -- Покрытие кода тестами до 80%, доавить unit тесты для бизнес-логики -- Улучшить валидацию входных данных +- Улучшить валидацию входных данных (почта, защита хака llm) +- Увеличить покрытие кода тестами до 80%, доавить unit тесты для бизнес-логики - Сделать качественную докуентацию \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index c898db4..1611e7e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -24,6 +24,11 @@ services: - CHUNK_SIZE=500 - CHUNK_OVERLAP=100 - API_SECRET_KEY=${API_SECRET_KEY:-secret} + - RATE_LIMIT_PER_MINUTE=${RATE_LIMIT_PER_MINUTE:-10} + - RATE_LIMIT_BURST=${RATE_LIMIT_BURST:-3} + - RATE_LIMIT_BACKOFF_BASE=${RATE_LIMIT_BACKOFF_BASE:-2.0} + - RATE_LIMIT_MAX_BACKOFF=${RATE_LIMIT_MAX_BACKOFF:-300} + - RATE_LIMIT_JITTER=${RATE_LIMIT_JITTER:-0.1} - PYTHONPATH=/app - PYTHONUNBUFFERED=1 - PYTHONDONTWRITEBYTECODE=1 diff --git a/src/app/config.py b/src/app/config.py index 448498e..fe4633b 100644 --- a/src/app/config.py +++ b/src/app/config.py @@ -24,6 +24,22 @@ class Settings(BaseSettings): chunk_size: int = 500 chunk_overlap: int = 100 + rate_limit_per_minute: int = Field( + default=10, json_schema_extra={"env": "RATE_LIMIT_PER_MINUTE"} + ) + rate_limit_burst: int = Field( + default=3, json_schema_extra={"env": "RATE_LIMIT_BURST"} + ) + rate_limit_backoff_base: float = Field( + default=2.0, json_schema_extra={"env": "RATE_LIMIT_BACKOFF_BASE"} + ) + rate_limit_max_backoff: int = Field( + default=300, json_schema_extra={"env": "RATE_LIMIT_MAX_BACKOFF"} + ) + rate_limit_jitter: float = Field( + default=0.1, json_schema_extra={"env": "RATE_LIMIT_JITTER"} + ) + model_config = ConfigDict(env_file=".env") diff --git a/src/app/main.py b/src/app/main.py index 9d14eb1..e541e37 100644 --- a/src/app/main.py +++ b/src/app/main.py @@ -11,6 +11,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from src.app.routers import generate, ingest, health from src.app.config import settings +from src.app.middleware import RateLimitMiddleware logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" @@ -33,6 +34,9 @@ app.add_middleware( allow_headers=["*"], ) +rate_limit_middleware = RateLimitMiddleware() +app.middleware("http")(rate_limit_middleware) + @app.middleware("http") async def logging_middleware(request: Request, call_next): diff --git a/src/app/middleware.py b/src/app/middleware.py new file mode 100644 index 0000000..c3b4b8a --- /dev/null +++ b/src/app/middleware.py @@ -0,0 +1,175 @@ +import time +import random +import hashlib +from typing import Dict, Optional +from fastapi import Request +from fastapi.responses import JSONResponse +from collections import defaultdict, deque +import logging + +from src.app.config import settings + +logger = logging.getLogger(__name__) + + +class RateLimitTracker: + def __init__(self): + self.requests: Dict[str, deque] = defaultdict(deque) + self.violations: Dict[str, int] = defaultdict(int) + self.last_violation: Dict[str, float] = defaultdict(float) + + def cleanup_old_requests(self, client_key: str, window_seconds: int = 60): + now = time.time() + cutoff = now - window_seconds + + while self.requests[client_key] and self.requests[client_key][0] < cutoff: + self.requests[client_key].popleft() + + def add_request(self, client_key: str): + now = time.time() + self.requests[client_key].append(now) + + def get_request_count(self, client_key: str, window_seconds: int = 60) -> int: + self.cleanup_old_requests(client_key, window_seconds) + return len(self.requests[client_key]) + + def record_violation(self, client_key: str): + self.violations[client_key] += 1 + self.last_violation[client_key] = time.time() + + def get_backoff_time(self, client_key: str) -> int: + violations = self.violations[client_key] + if violations == 0: + return 0 + + base_delay = settings.rate_limit_backoff_base**violations + max_delay = settings.rate_limit_max_backoff + + delay = min(base_delay, max_delay) + jitter = delay * settings.rate_limit_jitter * random.random() + + return int(delay + jitter) + + def should_reset_violations(self, client_key: str, reset_after: int = 3600) -> bool: + if client_key not in self.last_violation: + return False + + return time.time() - self.last_violation[client_key] > reset_after + + def reset_violations(self, client_key: str): + self.violations[client_key] = 0 + if client_key in self.last_violation: + del self.last_violation[client_key] + + +class RateLimitMiddleware: + def __init__(self): + self.tracker = RateLimitTracker() + self.protected_paths = { + "/api/v1/generate_email": { + "limit": settings.rate_limit_per_minute, + "window": 60, + "burst": settings.rate_limit_burst, + } + } + + def get_client_key(self, request: Request) -> str: + client_ip = request.client.host if request.client else "unknown" + user_agent = request.headers.get("user-agent", "") + + key_data = f"{client_ip}:{user_agent}" + return hashlib.md5(key_data.encode()).hexdigest()[:16] + + def is_protected_path(self, path: str) -> Optional[Dict]: + for protected_path, config in self.protected_paths.items(): + if path.startswith(protected_path): + return config + return None + + def check_burst_limit(self, client_key: str, burst_limit: int) -> bool: + now = time.time() + recent_requests = [ + req_time + for req_time in self.tracker.requests[client_key] + if now - req_time < 10 + ] + return len(recent_requests) < burst_limit + + async def __call__(self, request: Request, call_next): + path_config = self.is_protected_path(request.url.path) + + if not path_config: + return await call_next(request) + + client_key = self.get_client_key(request) + + if self.tracker.should_reset_violations(client_key): + self.tracker.reset_violations(client_key) + + backoff_time = self.tracker.get_backoff_time(client_key) + if backoff_time > 0: + logger.warning( + f"Rate limit backoff active for {client_key}: {backoff_time}s" + ) + return JSONResponse( + status_code=429, + content={ + "error": "Rate limit exceeded", + "code": "RATE_LIMIT_BACKOFF", + "retry_after": backoff_time, + "strategy": "exponential_backoff", + }, + headers={"Retry-After": str(backoff_time)}, + ) + + request_count = self.tracker.get_request_count( + client_key, path_config["window"] + ) + + if not self.check_burst_limit(client_key, path_config["burst"]): + self.tracker.record_violation(client_key) + backoff_time = self.tracker.get_backoff_time(client_key) + + logger.warning(f"Burst limit exceeded for {client_key}") + return JSONResponse( + status_code=429, + content={ + "error": "Burst limit exceeded", + "code": "BURST_LIMIT_EXCEEDED", + "retry_after": backoff_time, + }, + headers={"Retry-After": str(backoff_time)}, + ) + + if request_count >= path_config["limit"]: + self.tracker.record_violation(client_key) + backoff_time = self.tracker.get_backoff_time(client_key) + + logger.warning( + f"Rate limit exceeded for {client_key}: {request_count}/{path_config['limit']}" + ) + return JSONResponse( + status_code=429, + content={ + "error": "Rate limit exceeded", + "code": "RATE_LIMIT_EXCEEDED", + "limit": path_config["limit"], + "window": path_config["window"], + "current_usage": request_count, + "retry_after": backoff_time, + }, + headers={"Retry-After": str(backoff_time)}, + ) + + self.tracker.add_request(client_key) + + response = await call_next(request) + + remaining = path_config["limit"] - request_count - 1 + response.headers["X-RateLimit-Limit"] = str(path_config["limit"]) + response.headers["X-RateLimit-Remaining"] = str(max(0, remaining)) + response.headers["X-RateLimit-Reset"] = str( + int(time.time() + path_config["window"]) + ) + + return response diff --git a/src/tests/test_api.py b/src/tests/test_api.py index fea41a1..41d1a1e 100644 --- a/src/tests/test_api.py +++ b/src/tests/test_api.py @@ -83,7 +83,7 @@ def test_generate_email_with_email(): } response = client.post("/api/v1/generate_email", json=payload) - assert response.status_code in [200, 500] + assert response.status_code in [200, 429, 500] def test_api_status(): @@ -99,7 +99,7 @@ def test_invalid_json(): content="invalid json".encode("utf-8"), headers={"Content-Type": "application/json"}, ) - assert response.status_code == 422 + assert response.status_code in [422, 429] def test_admin_endpoints_without_auth(): diff --git a/src/tests/test_rate_limiting.py b/src/tests/test_rate_limiting.py new file mode 100644 index 0000000..f820ab5 --- /dev/null +++ b/src/tests/test_rate_limiting.py @@ -0,0 +1,111 @@ +import time +from fastapi.testclient import TestClient +import sys +import os + +sys.path.append( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) + +from src.app.main import app + +client = TestClient(app) + + +def test_rate_limit_basic(): + payload = { + "contact": "Test User", + "position": "Developer", + "company_name": "Test Company", + "segment": "IT", + } + + responses = [] + for i in range(5): + response = client.post("/api/v1/generate_email", json=payload) + responses.append(response.status_code) + + assert any(code == 429 for code in responses[-2:]) + + +def test_rate_limit_headers(): + payload = { + "contact": "Test User", + "position": "Developer", + "company_name": "Test Company", + "segment": "IT", + } + + response = client.post("/api/v1/generate_email", json=payload) + + if response.status_code != 429: + assert "X-RateLimit-Limit" in response.headers + assert "X-RateLimit-Remaining" in response.headers + assert "X-RateLimit-Reset" in response.headers + + +def test_rate_limit_response_format(): + payload = { + "contact": "Test User", + "position": "Developer", + "company_name": "Test Company", + "segment": "IT", + } + + for i in range(15): + response = client.post("/api/v1/generate_email", json=payload) + if response.status_code == 429: + data = response.json() + assert "error" in data + assert "code" in data + assert "retry_after" in data + assert "Retry-After" in response.headers + break + + +def test_burst_limit(): + payload = { + "contact": "Test User", + "position": "Developer", + "company_name": "Test Company", + "segment": "IT", + } + + responses = [] + for i in range(5): + response = client.post("/api/v1/generate_email", json=payload) + responses.append(response.status_code) + if i < 4: + time.sleep(0.1) + + burst_exceeded = any( + client.post("/api/v1/generate_email", json=payload).status_code == 429 + for _ in range(3) + ) + + assert burst_exceeded + + +def test_non_protected_endpoint(): + response = client.get("/") + assert response.status_code == 200 + assert "X-RateLimit-Limit" not in response.headers + + +def test_backoff_strategy(): + payload = { + "contact": "Backoff Test", + "position": "Tester", + "company_name": "Backoff Company", + "segment": "Testing", + } + + for i in range(20): + response = client.post("/api/v1/generate_email", json=payload) + if response.status_code == 429: + data = response.json() + if "strategy" in data: + assert data["strategy"] == "exponential_backoff" + return + + assert False, "No backoff strategy detected"