"""
Security utilities for authentication and authorization.

Design goals:
- Strong password hashing with Argon2id (no bcrypt 72-byte limit, no pre-hash needed)
- Clear token types (access/refresh) with strict validation
- Optional issuer/audience support (enforced only when configured)
- Token IDs (jti) for future revocation support
- Settings fetched via a dependency-friendly getter (no module-level hard binding)
"""

from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any, Mapping, Optional
from uuid import uuid4

from jose import JWTError, jwt
from passlib.context import CryptContext

from app.infra.config import get_settings


# ---- Password hashing ----

_pwd_context = CryptContext(
    schemes=["argon2"],
    deprecated="auto",
    # Argon2 parameters (time_cost/memory_cost/parallelism) are handled by
    # passlib's argon2 backend and can be tuned through passlib's per-scheme
    # CryptContext options if needed.
)


def hash_password(password: str) -> str:
    """
    Hash a password with the context's argon2 scheme.

    Args:
        password: Plain text password.

    Returns:
        Encoded password hash (as produced by passlib).

    Raises:
        ValueError: If ``password`` is not a non-empty string.
    """
    if not isinstance(password, str) or not password:
        raise ValueError("password must be a non-empty string")
    return _pwd_context.hash(password)


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """
    Verify a password against its stored hash.

    Args:
        plain_password: Plain text password.
        hashed_password: Stored password hash.

    Returns:
        True if the password matches, False otherwise (including for empty
        inputs and unrecognized/malformed hashes).
    """
    if not plain_password or not hashed_password:
        return False
    try:
        return _pwd_context.verify(plain_password, hashed_password)
    except (ValueError, TypeError):
        # Unknown/malformed hashes and bad input types are a failed match and
        # must not crash the auth flow. Anything else (e.g. a missing argon2
        # backend) is a deployment problem and is allowed to propagate rather
        # than silently rejecting every login.
        return False


# ---- JWT tokens ----

TokenType = str  # "access" | "refresh"

# Claims this module always sets itself; ``extra`` may never override them.
_RESERVED_CLAIMS = frozenset({"sub", "typ", "iat", "exp", "jti", "iss", "aud", "nbf"})


@dataclass(frozen=True, slots=True)
class TokenClaims:
    """
    Standardized JWT claims we issue and validate.
    """

    sub: str  # subject (e.g., user_id)
    typ: TokenType  # "access" or "refresh"
    exp: datetime
    iat: datetime
    jti: str

    # Optional hardening
    iss: Optional[str] = None
    aud: Optional[str] = None

    # Extra custom claims may be included in the encoded token, but are not represented here.


def _utcnow() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


def _issuer_and_audience(settings: Any) -> tuple[Optional[str], Optional[str]]:
    """Read the optional ``jwt_issuer`` / ``jwt_audience`` settings (None when absent)."""
    return (
        getattr(settings, "jwt_issuer", None),
        getattr(settings, "jwt_audience", None),
    )


def _encode_jwt(claims: Mapping[str, Any]) -> str:
    """Sign ``claims`` with the configured secret and algorithm."""
    settings = get_settings()
    return jwt.encode(
        dict(claims),
        settings.jwt_secret,
        algorithm=settings.jwt_algorithm,
    )


def _decode_jwt(token: str) -> Optional[dict[str, Any]]:
    """
    Decode and verify JWT signature + exp using the configured algorithm/secret.

    Issuer and audience are enforced only when the corresponding settings are
    present. Does NOT by itself enforce token type (access/refresh); use the
    decode_* functions for that.

    Returns:
        The claims dict, or None if the token is empty, not a str/bytes value,
        or fails any verification.
    """
    if not token or not isinstance(token, (str, bytes)):
        # Non-string tokens would otherwise raise inside jose rather than
        # producing a JWTError we can map to None.
        return None

    settings = get_settings()
    iss, aud = _issuer_and_audience(settings)

    options: dict[str, Any] = {
        "verify_signature": True,
        "verify_exp": True,
        "verify_nbf": True,
        # python-jose's iat check only validates the claim's type; we always
        # set iat but do not make it a hard requirement.
        "verify_iat": False,
        "require_exp": True,
    }
    kwargs: dict[str, Any] = {"algorithms": [settings.jwt_algorithm]}

    if iss:
        # jose rejects a token whose "iss" is missing or different.
        kwargs["issuer"] = iss
    if aud:
        kwargs["audience"] = aud
        # jose skips the audience check entirely when the token has no "aud"
        # claim, so require the claim explicitly once an audience is configured.
        options["require_aud"] = True

    try:
        payload = jwt.decode(token, settings.jwt_secret, options=options, **kwargs)
    except JWTError:
        return None
    return payload if isinstance(payload, dict) else None


def _build_claims(
    *,
    subject: str,
    token_type: TokenType,
    ttl: timedelta,
    extra: Optional[Mapping[str, Any]] = None,
) -> dict[str, Any]:
    """
    Build the claim set for a new token.

    Args:
        subject: Non-empty string subject (jose rejects non-string "sub" on decode).
        token_type: "access" or "refresh".
        ttl: Lifetime added to the current time to form "exp".
        extra: Additional claims; keys in the reserved set are silently dropped.

    Raises:
        ValueError: On an empty/non-string subject or an unknown token type.
    """
    if not isinstance(subject, str) or not subject:
        raise ValueError("subject must be a non-empty string")
    if token_type not in {"access", "refresh"}:
        raise ValueError("token_type must be 'access' or 'refresh'")

    now = _utcnow()
    exp = now + ttl
    iss, aud = _issuer_and_audience(get_settings())

    claims: dict[str, Any] = {
        "sub": subject,
        # Payload "typ" claim; it is independent of the JOSE header's "typ".
        "typ": token_type,
        "iat": int(now.timestamp()),
        "exp": int(exp.timestamp()),
        "jti": uuid4().hex,
    }
    if iss:
        claims["iss"] = iss
    if aud:
        claims["aud"] = aud

    if extra:
        # Reserved claims are controlled by this module and cannot be overridden.
        claims.update({k: v for k, v in extra.items() if k not in _RESERVED_CLAIMS})

    return claims


def create_access_token(*, subject: str, extra: Optional[Mapping[str, Any]] = None) -> str:
    """
    Create a signed JWT access token.

    Args:
        subject: User identifier (user_id) as string.
        extra: Optional additional non-reserved claims (e.g., roles).

    Returns:
        JWT string.
    """
    settings = get_settings()
    ttl = timedelta(seconds=settings.jwt_access_ttl_seconds)
    claims = _build_claims(subject=subject, token_type="access", ttl=ttl, extra=extra)
    return _encode_jwt(claims)


def create_refresh_token(*, subject: str, extra: Optional[Mapping[str, Any]] = None) -> str:
    """
    Create a signed JWT refresh token.

    Args:
        subject: User identifier (user_id) as string.
        extra: Optional additional non-reserved claims.

    Returns:
        JWT string.
    """
    settings = get_settings()
    ttl = timedelta(seconds=settings.jwt_refresh_ttl_seconds)
    claims = _build_claims(subject=subject, token_type="refresh", ttl=ttl, extra=extra)
    return _encode_jwt(claims)


def _decode_typed_token(token: str, expected_type: TokenType) -> Optional[dict[str, Any]]:
    """Decode ``token`` and accept it only if typ == expected_type and sub is set."""
    payload = _decode_jwt(token)
    if not payload:
        return None
    if payload.get("typ") != expected_type:
        return None
    if not payload.get("sub"):
        return None
    return payload


def decode_access_token(token: str) -> Optional[dict[str, Any]]:
    """
    Decode and validate an access token (signature/exp + typ=access + sub present).
    """
    return _decode_typed_token(token, "access")


def decode_refresh_token(token: str) -> Optional[dict[str, Any]]:
    """
    Decode and validate a refresh token (signature/exp + typ=refresh + sub present).
    """
    return _decode_typed_token(token, "refresh")


def get_subject(payload: Mapping[str, Any]) -> Optional[str]:
    """
    Extract the subject (user_id) from a decoded payload as a string, or None if unset.
    """
    sub = payload.get("sub")
    return str(sub) if sub else None