256 lines
6.8 KiB
Python
256 lines
6.8 KiB
Python
"""
Security utilities for authentication and authorization.

Design goals:
- Strong password hashing with Argon2id (no bcrypt 72-byte input limit, so no
  pre-hashing is needed).
- Explicit token types (access/refresh), checked strictly on decode.
- Optional issuer/audience claims, added and enforced only when configured.
- A unique token ID (jti) on every token, for future revocation support.
- Settings read through get_settings() at call time rather than bound at
  module import time.
"""
|
|
|
|
from __future__ import annotations

import logging
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any, Mapping, Optional
from uuid import uuid4

from jose import JWTError, jwt
from passlib.context import CryptContext

from app.infra.config import get_settings
|
|
|
|
|
|
# ---- Password hashing ----
|
|
|
|
# Shared passlib context used by hash_password() and verify_password().
# "argon2" is the only configured scheme. deprecated="auto" is passed through to
# passlib (which, per its docs, treats every scheme other than the default as
# deprecated). Argon2 cost parameters (time_cost/memory_cost/parallelism) are
# left at passlib's defaults; tune them here via passlib config if needed.
_pwd_context = CryptContext(schemes=["argon2"], deprecated="auto")
|
|
|
|
|
|
def hash_password(password: str) -> str:
    """
    Hash a plain-text password with the module's argon2 passlib context.

    Args:
        password: Plain text password; must be a non-empty ``str``.

    Returns:
        The encoded hash string produced by passlib.

    Raises:
        ValueError: If ``password`` is not a string or is empty.
    """
    if isinstance(password, str) and password:
        return _pwd_context.hash(password)
    raise ValueError("password must be a non-empty string")
|
|
|
|
|
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """
    Verify a password against its stored hash.

    Args:
        plain_password: Plain text password.
        hashed_password: Stored password hash.

    Returns:
        True if the password matches the hash, False otherwise. Also returns
        False when either argument is empty/falsy, or when passlib raises while
        verifying (e.g. an unrecognized or malformed hash string).
    """
    if not plain_password or not hashed_password:
        return False
    try:
        return _pwd_context.verify(plain_password, hashed_password)
    except Exception as err:
        # Deliberate best-effort: a bad stored hash must not crash the auth flow.
        # Log the exception type (never the secret or the hash) so that a
        # configuration problem, such as a missing hashing backend, does not
        # silently reject every login.
        logging.getLogger(__name__).warning(
            "password verification raised %s; treating as mismatch",
            type(err).__name__,
        )
        return False
|
|
|
|
|
|
# ---- JWT tokens ----
|
|
|
|
TokenType = str # "access" | "refresh"
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class TokenClaims:
|
|
"""
|
|
Standardized JWT claims we issue and validate.
|
|
"""
|
|
|
|
sub: str # subject (e.g., user_id)
|
|
typ: TokenType # "access" or "refresh"
|
|
exp: datetime
|
|
iat: datetime
|
|
jti: str
|
|
|
|
# Optional hardening
|
|
iss: Optional[str] = None
|
|
aud: Optional[str] = None
|
|
|
|
# Extra custom claims may be included in the encoded token, but are not represented here.
|
|
|
|
|
|
def _utcnow() -> datetime:
|
|
return datetime.now(timezone.utc)
|
|
|
|
|
|
def _encode_jwt(claims: Mapping[str, Any]) -> str:
    """
    Sign a claims mapping into a compact JWT string.

    The secret and algorithm are read from get_settings() on each call. The
    mapping is copied into a plain dict before being passed to ``jwt.encode``.
    """
    cfg = get_settings()
    payload = dict(claims)
    return jwt.encode(payload, cfg.jwt_secret, algorithm=cfg.jwt_algorithm)
|
|
|
|
|
|
def _decode_jwt(token: str) -> Optional[dict[str, Any]]:
|
|
"""
|
|
Decode and verify JWT signature + exp using configured algorithm/secret.
|
|
Does NOT by itself enforce token type (access/refresh); use decode_* functions.
|
|
"""
|
|
if not token:
|
|
return None
|
|
|
|
settings = get_settings()
|
|
|
|
# Optional issuer/audience validation: only enforced if present in settings
|
|
options: dict[str, Any] = {
|
|
"verify_signature": True,
|
|
"verify_exp": True,
|
|
"verify_nbf": True,
|
|
"verify_iat": False, # iat isn't always required; we set it, but don't hard-fail on clients w/ clock drift
|
|
"require_exp": True,
|
|
}
|
|
|
|
kwargs: dict[str, Any] = {"algorithms": [settings.jwt_algorithm]}
|
|
|
|
# If your settings provide these, they will be enforced.
|
|
iss = getattr(settings, "jwt_issuer", None)
|
|
aud = getattr(settings, "jwt_audience", None)
|
|
if iss:
|
|
kwargs["issuer"] = iss
|
|
if aud:
|
|
kwargs["audience"] = aud
|
|
|
|
try:
|
|
payload = jwt.decode(token, settings.jwt_secret, options=options, **kwargs)
|
|
if not isinstance(payload, dict):
|
|
return None
|
|
return payload
|
|
except JWTError:
|
|
return None
|
|
|
|
|
|
def _build_claims(
|
|
*,
|
|
subject: str,
|
|
token_type: TokenType,
|
|
ttl: timedelta,
|
|
extra: Optional[Mapping[str, Any]] = None,
|
|
) -> dict[str, Any]:
|
|
if not subject:
|
|
raise ValueError("subject must be a non-empty string")
|
|
if token_type not in {"access", "refresh"}:
|
|
raise ValueError("token_type must be 'access' or 'refresh'")
|
|
|
|
now = _utcnow()
|
|
exp = now + ttl
|
|
|
|
settings = get_settings()
|
|
iss = getattr(settings, "jwt_issuer", None)
|
|
aud = getattr(settings, "jwt_audience", None)
|
|
|
|
claims: dict[str, Any] = {
|
|
"sub": subject,
|
|
"typ": token_type, # use "typ" (common) to avoid confusion with header "typ"
|
|
"iat": int(now.timestamp()),
|
|
"exp": int(exp.timestamp()),
|
|
"jti": uuid4().hex,
|
|
}
|
|
|
|
if iss:
|
|
claims["iss"] = iss
|
|
if aud:
|
|
claims["aud"] = aud
|
|
|
|
if extra:
|
|
# Prevent overriding reserved claims
|
|
reserved = {"sub", "typ", "iat", "exp", "jti", "iss", "aud", "nbf"}
|
|
for k, v in extra.items():
|
|
if k in reserved:
|
|
continue
|
|
claims[k] = v
|
|
|
|
return claims
|
|
|
|
|
|
def create_access_token(*, subject: str, extra: Optional[Mapping[str, Any]] = None) -> str:
    """
    Issue a signed access JWT for ``subject``.

    Args:
        subject: User identifier (user_id) as string.
        extra: Optional additional non-reserved claims (e.g., roles).

    Returns:
        The encoded JWT, whose lifetime is ``jwt_access_ttl_seconds`` from settings.
    """
    lifetime = timedelta(seconds=get_settings().jwt_access_ttl_seconds)
    payload = _build_claims(subject=subject, token_type="access", ttl=lifetime, extra=extra)
    return _encode_jwt(payload)
|
|
|
|
|
|
def create_refresh_token(*, subject: str, extra: Optional[Mapping[str, Any]] = None) -> str:
    """
    Issue a signed refresh JWT for ``subject``.

    Args:
        subject: User identifier (user_id) as string.
        extra: Optional additional non-reserved claims.

    Returns:
        The encoded JWT, whose lifetime is ``jwt_refresh_ttl_seconds`` from settings.
    """
    settings = get_settings()
    return _encode_jwt(
        _build_claims(
            subject=subject,
            token_type="refresh",
            ttl=timedelta(seconds=settings.jwt_refresh_ttl_seconds),
            extra=extra,
        )
    )
|
|
|
|
|
|
def decode_access_token(token: str) -> Optional[dict[str, Any]]:
    """
    Decode an access token and confirm it really is one.

    Args:
        token: Encoded JWT.

    Returns:
        The payload when every check in _decode_jwt passes, the ``typ`` claim
        equals "access" and ``sub`` is truthy; otherwise None.
    """
    claims = _decode_jwt(token)
    if not claims:
        return None
    is_access = claims.get("typ") == "access"
    has_subject = bool(claims.get("sub"))
    return claims if (is_access and has_subject) else None
|
|
|
|
|
|
def decode_refresh_token(token: str) -> Optional[dict[str, Any]]:
    """
    Decode a refresh token and confirm it really is one.

    Args:
        token: Encoded JWT.

    Returns:
        The payload when every check in _decode_jwt passes, the ``typ`` claim
        equals "refresh" and ``sub`` is truthy; otherwise None.
    """
    payload = _decode_jwt(token)
    if payload and payload.get("typ") == "refresh" and payload.get("sub"):
        return payload
    return None
|
|
|
|
|
|
def get_subject(payload: Mapping[str, Any]) -> Optional[str]:
    """
    Return the ``sub`` claim (user_id) of a decoded payload as a string.

    Returns None when ``sub`` is absent or falsy.
    """
    subject = payload.get("sub")
    if not subject:
        return None
    return str(subject)
|