fix: whitelist

This commit is contained in:
itqop 2025-10-18 13:42:07 +03:00
parent 3746303214
commit a16e8d263b
7 changed files with 128 additions and 69 deletions

View File

@ -36,7 +36,7 @@ async def verify_api_key(
) -> str:
"""Verify API key."""
if x_api_key != context.settings.security.api_key:
raise HTTPException(status_code=401, detail="Invalid API key")
raise HTTPException(status_code=401, detail=f"Invalid API key, {x_api_key=}, {context.settings.security.api_key=}")
return x_api_key

View File

@ -1,13 +1,13 @@
"""Whitelist endpoints."""
from fastapi import APIRouter, Depends, HTTPException
from typing import Annotated
from fastapi import APIRouter, Depends
from typing import Annotated, Optional
from hubgw.api.deps import get_whitelist_service, verify_api_key
from hubgw.services.whitelist_service import WhitelistService
from hubgw.schemas.whitelist import (
WhitelistAddRequest, WhitelistRemoveRequest, WhitelistCheckRequest,
WhitelistEntry, WhitelistCheckResponse, WhitelistListResponse
WhitelistEntry, WhitelistCheckResponse, WhitelistListResponse, WhitelistQuery
)
from hubgw.core.errors import AppError, create_http_exception
@ -58,8 +58,22 @@ async def list_players(
service: Annotated[WhitelistService, Depends(get_whitelist_service)],
_: Annotated[str, Depends(verify_api_key)]
):
"""List all whitelisted players."""
"""List all whitelisted players with optional filters and pagination."""
try:
return await service.list_players()
except AppError as e:
raise create_http_exception(e)
@router.get("/count")
async def get_count(
service: Annotated[WhitelistService, Depends(get_whitelist_service)],
_: Annotated[str, Depends(verify_api_key)]
):
"""Get total count of whitelisted players."""
try:
count = await service.repo.count()
return {"total": count}
except AppError as e:
raise create_http_exception(e)

View File

@ -1,49 +1,60 @@
"""Application configuration using Pydantic Settings."""
from pydantic import BaseModel, Field, computed_field
from pydantic_settings import BaseSettings, SettingsConfigDict
from dotenv import load_dotenv
from pydantic import Field, computed_field
from pydantic_settings import BaseSettings
from typing import Optional
load_dotenv()
class DatabaseSettings(BaseModel):
class DatabaseSettings(BaseSettings):
"""Database configuration settings."""
host: str = Field(
default="localhost",
validation_alias="DATABASE__HOST",
description="Database host"
)
port: int = Field(
default=5432,
validation_alias="DATABASE__PORT",
ge=1,
le=65535,
description="Database port"
)
user: str = Field(
default="user",
validation_alias="DATABASE__USER",
description="Database user"
)
password: str = Field(
default="pass",
validation_alias="DATABASE__PASSWORD",
description="Database password"
)
database: str = Field(
default="hubgw",
validation_alias="DATABASE__DATABASE",
description="Database name"
)
pool_size: int = Field(
default=10,
validation_alias="DATABASE__POOL_SIZE",
ge=1,
le=100,
description="Database connection pool size"
)
max_overflow: int = Field(
default=10,
validation_alias="DATABASE__MAX_OVERFLOW",
ge=0,
le=100,
description="Maximum number of overflow connections"
)
echo: bool = Field(
default=False,
validation_alias="DATABASE__ECHO",
description="Enable SQLAlchemy query logging"
)
@ -57,67 +68,57 @@ class DatabaseSettings(BaseModel):
)
class SecuritySettings(BaseModel):
class SecuritySettings(BaseSettings):
"""Security configuration settings."""
api_key: str = Field(
default="your-api-key",
validation_alias="SECURITY__API_KEY",
min_length=8,
description="API key for authentication"
)
rate_limit_per_min: Optional[int] = Field(
default=None,
validation_alias="SECURITY__RATE_LIMIT_PER_MIN",
ge=1,
description="Rate limit per minute (None = disabled)"
)
class AppSettings(BaseModel):
class AppSettings(BaseSettings):
"""Application settings."""
env: str = Field(
default="dev",
validation_alias="APP__ENV",
description="Application environment (dev/prod/test)"
)
host: str = Field(
default="0.0.0.0",
validation_alias="APP__HOST",
description="Application host"
)
port: int = Field(
default=8080,
validation_alias="APP__PORT",
ge=1,
le=65535,
description="Application port"
)
log_level: str = Field(
default="INFO",
validation_alias="APP__LOG_LEVEL",
description="Logging level"
)
class Secrets(BaseSettings):
class Secrets():
"""Main configuration container with all settings."""
app: AppSettings = Field(
default_factory=AppSettings,
description="Application settings"
)
database: DatabaseSettings = Field(
default_factory=DatabaseSettings,
description="Database settings"
)
security: SecuritySettings = Field(
default_factory=SecuritySettings,
description="Security settings"
)
app: AppSettings = AppSettings()
database: DatabaseSettings = DatabaseSettings()
security: SecuritySettings = SecuritySettings()
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
env_nested_delimiter="__",
case_sensitive=True,
extra="ignore"
)
APP_CONFIG = Secrets()

View File

@ -1,13 +1,13 @@
"""Whitelist repository."""
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, insert, delete, func, update
from sqlalchemy import select, delete, func, and_
from typing import List, Optional
from uuid import UUID
from datetime import datetime
from hubgw.models.whitelist import WhitelistEntry
from hubgw.schemas.whitelist import WhitelistAddRequest, WhitelistQuery
from hubgw.schemas.whitelist import WhitelistAddRequest, WhitelistCheckRequest, WhitelistQuery, WhitelistRemoveRequest
class WhitelistRepository:
@ -32,7 +32,7 @@ class WhitelistRepository:
await self.session.refresh(entry)
return entry
async def get_by_id(self, entry_id: int) -> Optional[WhitelistEntry]:
async def get_by_id(self, entry_id: UUID) -> Optional[WhitelistEntry]:
"""Get whitelist entry by id."""
stmt = select(WhitelistEntry).where(WhitelistEntry.id == entry_id)
result = await self.session.execute(stmt)
@ -50,7 +50,7 @@ class WhitelistRepository:
await self.session.refresh(entry)
return entry
async def delete_by_id(self, entry_id: int) -> bool:
async def delete_by_id(self, entry_id: UUID) -> bool:
"""Delete whitelist entry by id."""
stmt = delete(WhitelistEntry).where(WhitelistEntry.id == entry_id)
result = await self.session.execute(stmt)
@ -62,7 +62,6 @@ class WhitelistRepository:
stmt = select(WhitelistEntry)
count_stmt = select(func.count(WhitelistEntry.id))
# Apply filters
if query.player_name:
stmt = stmt.where(WhitelistEntry.player_name.ilike(f"%{query.player_name}%"))
count_stmt = count_stmt.where(WhitelistEntry.player_name.ilike(f"%{query.player_name}%"))
@ -76,15 +75,51 @@ class WhitelistRepository:
stmt = stmt.where(WhitelistEntry.is_active == query.is_active)
count_stmt = count_stmt.where(WhitelistEntry.is_active == query.is_active)
# Get total count
count_result = await self.session.execute(count_stmt)
total = count_result.scalar()
total = count_result.scalar_one()
# Apply pagination
offset = (query.page - 1) * query.size
stmt = stmt.offset(offset).limit(query.size).order_by(WhitelistEntry.player_name)
stmt = stmt.offset(offset).limit(query.size).order_by(WhitelistEntry.added_at.desc())
result = await self.session.execute(stmt)
entries = list(result.scalars().all())
return entries, total
async def check(self, request: WhitelistCheckRequest) -> Optional[WhitelistEntry]:
"""Check if player is whitelisted."""
stmt = select(WhitelistEntry).where(
and_(
WhitelistEntry.player_name == request.player_name,
WhitelistEntry.is_active == True
)
)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
async def remove(self, request: WhitelistRemoveRequest) -> bool:
"""Remove player from whitelist."""
stmt = delete(WhitelistEntry).where(WhitelistEntry.player_name == request.player_name)
result = await self.session.execute(stmt)
await self.session.commit()
return result.rowcount > 0
async def list_all(self) -> List[WhitelistEntry]:
"""List all whitelisted players."""
stmt = select(WhitelistEntry).order_by(WhitelistEntry.added_at.desc())
result = await self.session.execute(stmt)
return list(result.scalars().all())
async def count(self) -> int:
"""Count all whitelisted players."""
stmt = select(func.count(WhitelistEntry.id))
result = await self.session.execute(stmt)
return result.scalar_one()
async def delete_by_player_name(self, player_name: str) -> bool:
"""Delete whitelist entry by player name."""
stmt = delete(WhitelistEntry).where(WhitelistEntry.player_name == player_name)
result = await self.session.execute(stmt)
await self.session.commit()
return result.rowcount > 0

View File

@ -1,10 +1,9 @@
"""Whitelist schemas."""
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from typing import Optional
from datetime import datetime
from uuid import UUID
from hubgw.schemas.common import BaseSchema, PaginationParams
class WhitelistAddRequest(BaseModel):
@ -38,9 +37,10 @@ class WhitelistCheckResponse(BaseModel):
player_uuid: Optional[str] = None
class WhitelistEntry(BaseSchema):
class WhitelistEntry(BaseModel):
"""Whitelist entry schema."""
model_config = ConfigDict(from_attributes=True)
id: UUID
player_name: str
player_uuid: Optional[str] = None
@ -58,8 +58,10 @@ class WhitelistListResponse(BaseModel):
total: int
class WhitelistQuery(PaginationParams):
"""Whitelist query schema."""
class WhitelistQuery(BaseModel):
model_config = ConfigDict(from_attributes=True)
page: int = 1
size: int = 10
player_name: Optional[str] = None
player_uuid: Optional[str] = None

View File

@ -6,7 +6,7 @@ from typing import List
from hubgw.repositories.whitelist_repo import WhitelistRepository
from hubgw.schemas.whitelist import (
WhitelistAddRequest, WhitelistRemoveRequest, WhitelistCheckRequest,
WhitelistEntry, WhitelistCheckResponse, WhitelistListResponse
WhitelistEntry as SchemaWhitelistEntry, WhitelistCheckResponse, WhitelistListResponse, WhitelistQuery
)
from hubgw.core.errors import AlreadyExistsError, NotFoundError
@ -17,23 +17,30 @@ class WhitelistService:
def __init__(self, session: AsyncSession):
self.repo = WhitelistRepository(session)
async def add_player(self, request: WhitelistAddRequest) -> WhitelistEntry:
"""Add player to whitelist with business logic."""
# Check if player is already whitelisted
existing = await self.repo.check(WhitelistCheckRequest(player_name=request.player_name))
async def add_player(self, request: WhitelistAddRequest) -> SchemaWhitelistEntry:
existing = await self.repo.get_by_player_name(request.player_name)
if existing:
raise AlreadyExistsError(f"Player '{request.player_name}' is already whitelisted")
return await self.repo.add(request)
if existing.is_active:
raise AlreadyExistsError(f"Player '{request.player_name}' is already whitelisted")
else:
existing.player_uuid = request.player_uuid
existing.added_by = request.added_by
existing.added_at = request.added_at
existing.expires_at = request.expires_at
existing.is_active = request.is_active
existing.reason = request.reason
updated_entry = await self.repo.update(existing)
return SchemaWhitelistEntry.model_validate(updated_entry)
created_entry = await self.repo.create(request)
return SchemaWhitelistEntry.model_validate(created_entry)
async def remove_player(self, request: WhitelistRemoveRequest) -> None:
"""Remove player from whitelist with business logic."""
success = await self.repo.remove(request)
success = await self.repo.delete_by_player_name(request.player_name)
if not success:
raise NotFoundError(f"Player '{request.player_name}' not found in whitelist")
async def check_player(self, request: WhitelistCheckRequest) -> WhitelistCheckResponse:
"""Check if player is whitelisted."""
entry = await self.repo.check(request)
return WhitelistCheckResponse(
@ -42,22 +49,22 @@ class WhitelistService:
)
async def list_players(self) -> WhitelistListResponse:
"""List all whitelisted players."""
entries = await self.repo.list_all()
total = await self.repo.count()
total = len(entries)
entry_list = [
WhitelistEntry(
id=entry.id,
player_name=entry.player_name,
player_uuid=entry.player_uuid,
added_by=entry.added_by,
added_at=entry.added_at,
reason=entry.reason,
created_at=entry.created_at,
updated_at=entry.updated_at
)
SchemaWhitelistEntry.model_validate(entry)
for entry in entries
]
return WhitelistListResponse(entries=entry_list, total=total)
async def query_players(self, query: WhitelistQuery) -> WhitelistListResponse:
entries, total = await self.repo.query(query)
entry_list = [
SchemaWhitelistEntry.model_validate(entry)
for entry in entries
]
return WhitelistListResponse(entries=entry_list, total=total)

0
test.py Normal file
View File