itcloud/backend/src/app/infra/security.py

256 lines
6.8 KiB
Python

"""
Security utilities for authentication and authorization.
Design goals:
- Strong password hashing with Argon2id (no bcrypt 72-byte issues, no pre-hash needed)
- Clear token types (access/refresh) with strict validation
- Optional issuer/audience support (safe defaults)
- Token IDs (jti) for future revocation support
- Settings fetched via dependency-friendly getter (no module-level hard binding)
"""
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any, Mapping, Optional
from uuid import uuid4
from jose import JWTError, jwt
from passlib.context import CryptContext
from app.infra.config import get_settings
# ---- Password hashing ----
# Module-wide passlib context; Argon2 is the only configured scheme.
# deprecated="auto" treats every scheme other than the default as deprecated.
# With a single scheme that is a no-op today, but it lets a future scheme
# migration flag old hashes for rehashing. Argon2 cost parameters
# (time_cost / memory_cost / parallelism) are left at passlib's defaults and
# can be tuned through CryptContext settings if needed.
_pwd_context = CryptContext(schemes=["argon2"], deprecated="auto")
def hash_password(password: str) -> str:
    """
    Produce a storable hash for a plain-text password.

    The hash is computed by the module's passlib context, which is configured
    with the Argon2 scheme.

    Args:
        password: Plain text password; must be a non-empty ``str``.

    Returns:
        The encoded hash string produced by passlib.

    Raises:
        ValueError: If ``password`` is not a string or is empty.
    """
    is_usable = isinstance(password, str) and len(password) > 0
    if not is_usable:
        raise ValueError("password must be a non-empty string")
    return _pwd_context.hash(password)
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """
    Verify a password against its stored hash.

    Args:
        plain_password: Plain text password.
        hashed_password: Stored password hash.

    Returns:
        True if the password matches the hash, False otherwise. Empty inputs
        and any error raised while verifying also yield False.
    """
    if not plain_password or not hashed_password:
        return False
    try:
        return _pwd_context.verify(plain_password, hashed_password)
    except Exception as exc:
        # Parsing/format issues must not crash the auth flow, but they must
        # not disappear silently either. Otherwise a malformed stored hash
        # or a misconfigured hashing backend looks exactly like "wrong
        # password" on every attempt. Only the exception type is logged:
        # never the password or the hash.
        import logging

        logging.getLogger(__name__).warning(
            "password verification raised %s; treating as mismatch",
            type(exc).__name__,
        )
        return False
# ---- JWT tokens ----
TokenType = str # "access" | "refresh"
@dataclass(frozen=True, slots=True)
class TokenClaims:
"""
Standardized JWT claims we issue and validate.
"""
sub: str # subject (e.g., user_id)
typ: TokenType # "access" or "refresh"
exp: datetime
iat: datetime
jti: str
# Optional hardening
iss: Optional[str] = None
aud: Optional[str] = None
# Extra custom claims may be included in the encoded token, but are not represented here.
def _utcnow() -> datetime:
return datetime.now(timezone.utc)
def _encode_jwt(claims: Mapping[str, Any]) -> str:
    """
    Sign a claim set and return it as a compact JWT string.

    Uses the secret and algorithm from the current settings. ``claims`` is
    copied into a plain dict before encoding.
    """
    cfg = get_settings()
    payload = dict(claims)
    return jwt.encode(payload, cfg.jwt_secret, algorithm=cfg.jwt_algorithm)
def _decode_jwt(token: str) -> Optional[dict[str, Any]]:
"""
Decode and verify JWT signature + exp using configured algorithm/secret.
Does NOT by itself enforce token type (access/refresh); use decode_* functions.
"""
if not token:
return None
settings = get_settings()
# Optional issuer/audience validation: only enforced if present in settings
options: dict[str, Any] = {
"verify_signature": True,
"verify_exp": True,
"verify_nbf": True,
"verify_iat": False, # iat isn't always required; we set it, but don't hard-fail on clients w/ clock drift
"require_exp": True,
}
kwargs: dict[str, Any] = {"algorithms": [settings.jwt_algorithm]}
# If your settings provide these, they will be enforced.
iss = getattr(settings, "jwt_issuer", None)
aud = getattr(settings, "jwt_audience", None)
if iss:
kwargs["issuer"] = iss
if aud:
kwargs["audience"] = aud
try:
payload = jwt.decode(token, settings.jwt_secret, options=options, **kwargs)
if not isinstance(payload, dict):
return None
return payload
except JWTError:
return None
def _build_claims(
*,
subject: str,
token_type: TokenType,
ttl: timedelta,
extra: Optional[Mapping[str, Any]] = None,
) -> dict[str, Any]:
if not subject:
raise ValueError("subject must be a non-empty string")
if token_type not in {"access", "refresh"}:
raise ValueError("token_type must be 'access' or 'refresh'")
now = _utcnow()
exp = now + ttl
settings = get_settings()
iss = getattr(settings, "jwt_issuer", None)
aud = getattr(settings, "jwt_audience", None)
claims: dict[str, Any] = {
"sub": subject,
"typ": token_type, # use "typ" (common) to avoid confusion with header "typ"
"iat": int(now.timestamp()),
"exp": int(exp.timestamp()),
"jti": uuid4().hex,
}
if iss:
claims["iss"] = iss
if aud:
claims["aud"] = aud
if extra:
# Prevent overriding reserved claims
reserved = {"sub", "typ", "iat", "exp", "jti", "iss", "aud", "nbf"}
for k, v in extra.items():
if k in reserved:
continue
claims[k] = v
return claims
def create_access_token(*, subject: str, extra: Optional[Mapping[str, Any]] = None) -> str:
    """
    Issue a signed, short-lived access JWT for ``subject``.

    The lifetime comes from ``settings.jwt_access_ttl_seconds``.

    Args:
        subject: User identifier (user_id) as string.
        extra: Optional additional non-reserved claims (e.g., roles).

    Returns:
        The encoded JWT string.
    """
    lifetime = timedelta(seconds=get_settings().jwt_access_ttl_seconds)
    return _encode_jwt(
        _build_claims(subject=subject, token_type="access", ttl=lifetime, extra=extra)
    )
def create_refresh_token(*, subject: str, extra: Optional[Mapping[str, Any]] = None) -> str:
    """
    Issue a signed refresh JWT for ``subject``.

    The lifetime comes from ``settings.jwt_refresh_ttl_seconds``.

    Args:
        subject: User identifier (user_id) as string.
        extra: Optional additional non-reserved claims.

    Returns:
        The encoded JWT string.
    """
    lifetime = timedelta(seconds=get_settings().jwt_refresh_ttl_seconds)
    return _encode_jwt(
        _build_claims(subject=subject, token_type="refresh", ttl=lifetime, extra=extra)
    )
def decode_access_token(token: str) -> Optional[dict[str, Any]]:
    """
    Return the claims of a valid access token, or None.

    The token must pass signature/expiry checks, carry ``typ == "access"``,
    and have a truthy ``sub`` claim.
    """
    claims = _decode_jwt(token)
    if claims and claims.get("typ") == "access" and claims.get("sub"):
        return claims
    return None
def decode_refresh_token(token: str) -> Optional[dict[str, Any]]:
    """
    Return the claims of a valid refresh token, or None.

    The token must pass signature/expiry checks, carry ``typ == "refresh"``,
    and have a truthy ``sub`` claim.
    """
    claims = _decode_jwt(token)
    if claims and claims.get("typ") == "refresh" and claims.get("sub"):
        return claims
    return None
def get_subject(payload: Mapping[str, Any]) -> Optional[str]:
    """
    Read the ``sub`` claim from a decoded payload.

    Returns:
        The subject converted to ``str``, or None when the claim is missing
        or falsy.
    """
    subject = payload.get("sub")
    if not subject:
        return None
    return str(subject)