fix: whitelist
This commit is contained in:
parent
3746303214
commit
a16e8d263b
|
|
@ -36,7 +36,7 @@ async def verify_api_key(
|
|||
) -> str:
|
||||
"""Verify API key."""
|
||||
if x_api_key != context.settings.security.api_key:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
raise HTTPException(status_code=401, detail=f"Invalid API key, {x_api_key=}, {context.settings.security.api_key=}")
|
||||
return x_api_key
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
"""Whitelist endpoints."""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, Depends
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from hubgw.api.deps import get_whitelist_service, verify_api_key
|
||||
from hubgw.services.whitelist_service import WhitelistService
|
||||
from hubgw.schemas.whitelist import (
|
||||
WhitelistAddRequest, WhitelistRemoveRequest, WhitelistCheckRequest,
|
||||
WhitelistEntry, WhitelistCheckResponse, WhitelistListResponse
|
||||
WhitelistEntry, WhitelistCheckResponse, WhitelistListResponse, WhitelistQuery
|
||||
)
|
||||
from hubgw.core.errors import AppError, create_http_exception
|
||||
|
||||
|
|
@ -58,8 +58,22 @@ async def list_players(
|
|||
service: Annotated[WhitelistService, Depends(get_whitelist_service)],
|
||||
_: Annotated[str, Depends(verify_api_key)]
|
||||
):
|
||||
"""List all whitelisted players."""
|
||||
"""List all whitelisted players with optional filters and pagination."""
|
||||
try:
|
||||
return await service.list_players()
|
||||
except AppError as e:
|
||||
raise create_http_exception(e)
|
||||
|
||||
|
||||
@router.get("/count")
|
||||
async def get_count(
|
||||
service: Annotated[WhitelistService, Depends(get_whitelist_service)],
|
||||
_: Annotated[str, Depends(verify_api_key)]
|
||||
):
|
||||
"""Get total count of whitelisted players."""
|
||||
try:
|
||||
count = await service.repo.count()
|
||||
return {"total": count}
|
||||
except AppError as e:
|
||||
raise create_http_exception(e)
|
||||
|
||||
|
|
@ -1,49 +1,60 @@
|
|||
"""Application configuration using Pydantic Settings."""
|
||||
|
||||
from pydantic import BaseModel, Field, computed_field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field, computed_field
|
||||
from pydantic_settings import BaseSettings
|
||||
from typing import Optional
|
||||
|
||||
load_dotenv()
|
||||
|
||||
class DatabaseSettings(BaseModel):
|
||||
|
||||
class DatabaseSettings(BaseSettings):
|
||||
"""Database configuration settings."""
|
||||
|
||||
host: str = Field(
|
||||
default="localhost",
|
||||
validation_alias="DATABASE__HOST",
|
||||
description="Database host"
|
||||
)
|
||||
port: int = Field(
|
||||
default=5432,
|
||||
validation_alias="DATABASE__PORT",
|
||||
ge=1,
|
||||
le=65535,
|
||||
description="Database port"
|
||||
)
|
||||
user: str = Field(
|
||||
default="user",
|
||||
validation_alias="DATABASE__USER",
|
||||
description="Database user"
|
||||
)
|
||||
password: str = Field(
|
||||
default="pass",
|
||||
validation_alias="DATABASE__PASSWORD",
|
||||
description="Database password"
|
||||
)
|
||||
database: str = Field(
|
||||
default="hubgw",
|
||||
validation_alias="DATABASE__DATABASE",
|
||||
description="Database name"
|
||||
)
|
||||
pool_size: int = Field(
|
||||
default=10,
|
||||
validation_alias="DATABASE__POOL_SIZE",
|
||||
ge=1,
|
||||
le=100,
|
||||
description="Database connection pool size"
|
||||
)
|
||||
max_overflow: int = Field(
|
||||
default=10,
|
||||
validation_alias="DATABASE__MAX_OVERFLOW",
|
||||
ge=0,
|
||||
le=100,
|
||||
description="Maximum number of overflow connections"
|
||||
)
|
||||
echo: bool = Field(
|
||||
default=False,
|
||||
validation_alias="DATABASE__ECHO",
|
||||
description="Enable SQLAlchemy query logging"
|
||||
)
|
||||
|
||||
|
|
@ -57,67 +68,57 @@ class DatabaseSettings(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class SecuritySettings(BaseModel):
|
||||
class SecuritySettings(BaseSettings):
|
||||
"""Security configuration settings."""
|
||||
|
||||
api_key: str = Field(
|
||||
default="your-api-key",
|
||||
validation_alias="SECURITY__API_KEY",
|
||||
min_length=8,
|
||||
description="API key for authentication"
|
||||
)
|
||||
rate_limit_per_min: Optional[int] = Field(
|
||||
default=None,
|
||||
validation_alias="SECURITY__RATE_LIMIT_PER_MIN",
|
||||
ge=1,
|
||||
description="Rate limit per minute (None = disabled)"
|
||||
)
|
||||
|
||||
|
||||
class AppSettings(BaseModel):
|
||||
class AppSettings(BaseSettings):
|
||||
"""Application settings."""
|
||||
|
||||
env: str = Field(
|
||||
default="dev",
|
||||
validation_alias="APP__ENV",
|
||||
description="Application environment (dev/prod/test)"
|
||||
)
|
||||
host: str = Field(
|
||||
default="0.0.0.0",
|
||||
validation_alias="APP__HOST",
|
||||
description="Application host"
|
||||
)
|
||||
port: int = Field(
|
||||
default=8080,
|
||||
validation_alias="APP__PORT",
|
||||
ge=1,
|
||||
le=65535,
|
||||
description="Application port"
|
||||
)
|
||||
log_level: str = Field(
|
||||
default="INFO",
|
||||
validation_alias="APP__LOG_LEVEL",
|
||||
description="Logging level"
|
||||
)
|
||||
|
||||
|
||||
class Secrets(BaseSettings):
|
||||
class Secrets():
|
||||
"""Main configuration container with all settings."""
|
||||
|
||||
app: AppSettings = Field(
|
||||
default_factory=AppSettings,
|
||||
description="Application settings"
|
||||
)
|
||||
database: DatabaseSettings = Field(
|
||||
default_factory=DatabaseSettings,
|
||||
description="Database settings"
|
||||
)
|
||||
security: SecuritySettings = Field(
|
||||
default_factory=SecuritySettings,
|
||||
description="Security settings"
|
||||
)
|
||||
app: AppSettings = AppSettings()
|
||||
database: DatabaseSettings = DatabaseSettings()
|
||||
security: SecuritySettings = SecuritySettings()
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
env_nested_delimiter="__",
|
||||
case_sensitive=True,
|
||||
extra="ignore"
|
||||
)
|
||||
|
||||
|
||||
APP_CONFIG = Secrets()
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
"""Whitelist repository."""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, insert, delete, func, update
|
||||
from sqlalchemy import select, delete, func, and_
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
from datetime import datetime
|
||||
|
||||
from hubgw.models.whitelist import WhitelistEntry
|
||||
from hubgw.schemas.whitelist import WhitelistAddRequest, WhitelistQuery
|
||||
from hubgw.schemas.whitelist import WhitelistAddRequest, WhitelistCheckRequest, WhitelistQuery, WhitelistRemoveRequest
|
||||
|
||||
|
||||
class WhitelistRepository:
|
||||
|
|
@ -32,7 +32,7 @@ class WhitelistRepository:
|
|||
await self.session.refresh(entry)
|
||||
return entry
|
||||
|
||||
async def get_by_id(self, entry_id: int) -> Optional[WhitelistEntry]:
|
||||
async def get_by_id(self, entry_id: UUID) -> Optional[WhitelistEntry]:
|
||||
"""Get whitelist entry by id."""
|
||||
stmt = select(WhitelistEntry).where(WhitelistEntry.id == entry_id)
|
||||
result = await self.session.execute(stmt)
|
||||
|
|
@ -50,7 +50,7 @@ class WhitelistRepository:
|
|||
await self.session.refresh(entry)
|
||||
return entry
|
||||
|
||||
async def delete_by_id(self, entry_id: int) -> bool:
|
||||
async def delete_by_id(self, entry_id: UUID) -> bool:
|
||||
"""Delete whitelist entry by id."""
|
||||
stmt = delete(WhitelistEntry).where(WhitelistEntry.id == entry_id)
|
||||
result = await self.session.execute(stmt)
|
||||
|
|
@ -62,7 +62,6 @@ class WhitelistRepository:
|
|||
stmt = select(WhitelistEntry)
|
||||
count_stmt = select(func.count(WhitelistEntry.id))
|
||||
|
||||
# Apply filters
|
||||
if query.player_name:
|
||||
stmt = stmt.where(WhitelistEntry.player_name.ilike(f"%{query.player_name}%"))
|
||||
count_stmt = count_stmt.where(WhitelistEntry.player_name.ilike(f"%{query.player_name}%"))
|
||||
|
|
@ -76,15 +75,51 @@ class WhitelistRepository:
|
|||
stmt = stmt.where(WhitelistEntry.is_active == query.is_active)
|
||||
count_stmt = count_stmt.where(WhitelistEntry.is_active == query.is_active)
|
||||
|
||||
# Get total count
|
||||
count_result = await self.session.execute(count_stmt)
|
||||
total = count_result.scalar()
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply pagination
|
||||
offset = (query.page - 1) * query.size
|
||||
stmt = stmt.offset(offset).limit(query.size).order_by(WhitelistEntry.player_name)
|
||||
stmt = stmt.offset(offset).limit(query.size).order_by(WhitelistEntry.added_at.desc())
|
||||
|
||||
result = await self.session.execute(stmt)
|
||||
entries = list(result.scalars().all())
|
||||
|
||||
return entries, total
|
||||
|
||||
|
||||
async def check(self, request: WhitelistCheckRequest) -> Optional[WhitelistEntry]:
|
||||
"""Check if player is whitelisted."""
|
||||
stmt = select(WhitelistEntry).where(
|
||||
and_(
|
||||
WhitelistEntry.player_name == request.player_name,
|
||||
WhitelistEntry.is_active == True
|
||||
)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def remove(self, request: WhitelistRemoveRequest) -> bool:
|
||||
"""Remove player from whitelist."""
|
||||
stmt = delete(WhitelistEntry).where(WhitelistEntry.player_name == request.player_name)
|
||||
result = await self.session.execute(stmt)
|
||||
await self.session.commit()
|
||||
return result.rowcount > 0
|
||||
|
||||
async def list_all(self) -> List[WhitelistEntry]:
|
||||
"""List all whitelisted players."""
|
||||
stmt = select(WhitelistEntry).order_by(WhitelistEntry.added_at.desc())
|
||||
result = await self.session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count(self) -> int:
|
||||
"""Count all whitelisted players."""
|
||||
stmt = select(func.count(WhitelistEntry.id))
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one()
|
||||
|
||||
async def delete_by_player_name(self, player_name: str) -> bool:
|
||||
"""Delete whitelist entry by player name."""
|
||||
stmt = delete(WhitelistEntry).where(WhitelistEntry.player_name == player_name)
|
||||
result = await self.session.execute(stmt)
|
||||
await self.session.commit()
|
||||
return result.rowcount > 0
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
"""Whitelist schemas."""
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
from hubgw.schemas.common import BaseSchema, PaginationParams
|
||||
|
||||
|
||||
class WhitelistAddRequest(BaseModel):
|
||||
|
|
@ -38,9 +37,10 @@ class WhitelistCheckResponse(BaseModel):
|
|||
player_uuid: Optional[str] = None
|
||||
|
||||
|
||||
class WhitelistEntry(BaseSchema):
|
||||
class WhitelistEntry(BaseModel):
|
||||
"""Whitelist entry schema."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
player_name: str
|
||||
player_uuid: Optional[str] = None
|
||||
|
|
@ -58,8 +58,10 @@ class WhitelistListResponse(BaseModel):
|
|||
total: int
|
||||
|
||||
|
||||
class WhitelistQuery(PaginationParams):
|
||||
"""Whitelist query schema."""
|
||||
class WhitelistQuery(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
page: int = 1
|
||||
size: int = 10
|
||||
|
||||
player_name: Optional[str] = None
|
||||
player_uuid: Optional[str] = None
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from typing import List
|
|||
from hubgw.repositories.whitelist_repo import WhitelistRepository
|
||||
from hubgw.schemas.whitelist import (
|
||||
WhitelistAddRequest, WhitelistRemoveRequest, WhitelistCheckRequest,
|
||||
WhitelistEntry, WhitelistCheckResponse, WhitelistListResponse
|
||||
WhitelistEntry as SchemaWhitelistEntry, WhitelistCheckResponse, WhitelistListResponse, WhitelistQuery
|
||||
)
|
||||
from hubgw.core.errors import AlreadyExistsError, NotFoundError
|
||||
|
||||
|
|
@ -17,23 +17,30 @@ class WhitelistService:
|
|||
def __init__(self, session: AsyncSession):
|
||||
self.repo = WhitelistRepository(session)
|
||||
|
||||
async def add_player(self, request: WhitelistAddRequest) -> WhitelistEntry:
|
||||
"""Add player to whitelist with business logic."""
|
||||
# Check if player is already whitelisted
|
||||
existing = await self.repo.check(WhitelistCheckRequest(player_name=request.player_name))
|
||||
async def add_player(self, request: WhitelistAddRequest) -> SchemaWhitelistEntry:
|
||||
existing = await self.repo.get_by_player_name(request.player_name)
|
||||
if existing:
|
||||
raise AlreadyExistsError(f"Player '{request.player_name}' is already whitelisted")
|
||||
|
||||
return await self.repo.add(request)
|
||||
if existing.is_active:
|
||||
raise AlreadyExistsError(f"Player '{request.player_name}' is already whitelisted")
|
||||
else:
|
||||
existing.player_uuid = request.player_uuid
|
||||
existing.added_by = request.added_by
|
||||
existing.added_at = request.added_at
|
||||
existing.expires_at = request.expires_at
|
||||
existing.is_active = request.is_active
|
||||
existing.reason = request.reason
|
||||
updated_entry = await self.repo.update(existing)
|
||||
return SchemaWhitelistEntry.model_validate(updated_entry)
|
||||
|
||||
created_entry = await self.repo.create(request)
|
||||
return SchemaWhitelistEntry.model_validate(created_entry)
|
||||
|
||||
async def remove_player(self, request: WhitelistRemoveRequest) -> None:
|
||||
"""Remove player from whitelist with business logic."""
|
||||
success = await self.repo.remove(request)
|
||||
success = await self.repo.delete_by_player_name(request.player_name)
|
||||
if not success:
|
||||
raise NotFoundError(f"Player '{request.player_name}' not found in whitelist")
|
||||
|
||||
async def check_player(self, request: WhitelistCheckRequest) -> WhitelistCheckResponse:
|
||||
"""Check if player is whitelisted."""
|
||||
entry = await self.repo.check(request)
|
||||
|
||||
return WhitelistCheckResponse(
|
||||
|
|
@ -42,22 +49,22 @@ class WhitelistService:
|
|||
)
|
||||
|
||||
async def list_players(self) -> WhitelistListResponse:
|
||||
"""List all whitelisted players."""
|
||||
entries = await self.repo.list_all()
|
||||
total = await self.repo.count()
|
||||
|
||||
total = len(entries)
|
||||
|
||||
entry_list = [
|
||||
WhitelistEntry(
|
||||
id=entry.id,
|
||||
player_name=entry.player_name,
|
||||
player_uuid=entry.player_uuid,
|
||||
added_by=entry.added_by,
|
||||
added_at=entry.added_at,
|
||||
reason=entry.reason,
|
||||
created_at=entry.created_at,
|
||||
updated_at=entry.updated_at
|
||||
)
|
||||
SchemaWhitelistEntry.model_validate(entry)
|
||||
for entry in entries
|
||||
]
|
||||
|
||||
return WhitelistListResponse(entries=entry_list, total=total)
|
||||
|
||||
async def query_players(self, query: WhitelistQuery) -> WhitelistListResponse:
|
||||
entries, total = await self.repo.query(query)
|
||||
|
||||
entry_list = [
|
||||
SchemaWhitelistEntry.model_validate(entry)
|
||||
for entry in entries
|
||||
]
|
||||
|
||||
return WhitelistListResponse(entries=entry_list, total=total)
|
||||
|
|
|
|||
Loading…
Reference in New Issue