fix [asswprds
This commit is contained in:
parent
c34aef8dd7
commit
843ff2c569
|
|
@ -2,7 +2,10 @@
|
|||
"permissions": {
|
||||
"allow": [
|
||||
"Bash(dir:*)",
|
||||
"Bash(grep:*)"
|
||||
"Bash(grep:*)",
|
||||
"Bash(.venvScriptspython -m pytest tests/test_security.py -v --tb=short)",
|
||||
"Bash(cmd.exe /c \"cd c:\\\\Users\\\\leonk\\\\Documents\\\\code\\\\itcloud\\\\backend && ..\\\\.venv\\\\Scripts\\\\python -m pytest tests/test_security.py -v --tb=short\")",
|
||||
"Bash(../\".venv/Scripts/python.exe\" -m pytest tests/test_security.py -v --tb=short)"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -3,6 +3,7 @@ name = "itcloud-backend"
|
|||
version = "0.1.0"
|
||||
description = "Cloud photo and video storage backend"
|
||||
authors = ["ITCloud Team"]
|
||||
package-mode = false
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.11"
|
||||
|
|
@ -14,7 +15,8 @@ pydantic = "^2.5.3"
|
|||
pydantic-settings = "^2.1.0"
|
||||
email-validator = "^2.1.0"
|
||||
python-jose = {extras = ["cryptography"], version = "^3.3.0"}
|
||||
passlib = {extras = ["bcrypt"], version = "^1.7.4"}
|
||||
passlib = {extras = ["argon2"], version = "^1.7.4"}
|
||||
argon2-cffi = "^23.1.0"
|
||||
python-multipart = "^0.0.6"
|
||||
aiosqlite = "^0.19.0"
|
||||
asyncpg = "^0.29.0"
|
||||
|
|
@ -26,6 +28,7 @@ pillow = "^10.2.0"
|
|||
python-magic = "^0.4.27"
|
||||
loguru = "^0.7.2"
|
||||
httpx = "^0.26.0"
|
||||
cryptography = "^46.0.3"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pytest = "^7.4.4"
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from app.domain.models import User
|
||||
from app.infra.database import get_db
|
||||
from app.infra.s3_client import S3Client, get_s3_client
|
||||
from app.infra.security import decode_token
|
||||
from app.infra.security import decode_access_token, get_subject
|
||||
from app.repositories.user_repository import UserRepository
|
||||
|
||||
security = HTTPBearer()
|
||||
|
|
@ -33,15 +33,15 @@ async def get_current_user(
|
|||
HTTPException: If token is invalid or user not found
|
||||
"""
|
||||
token = credentials.credentials
|
||||
payload = decode_token(token)
|
||||
payload = decode_access_token(token)
|
||||
|
||||
if not payload or payload.get("type") != "access":
|
||||
if not payload:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authentication credentials",
|
||||
)
|
||||
|
||||
user_id = payload.get("sub")
|
||||
user_id = get_subject(payload)
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
|
|||
|
|
@ -1,110 +1,255 @@
|
|||
"""Security utilities for authentication and authorization."""
|
||||
"""
|
||||
Security utilities for authentication and authorization.
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
Design goals:
|
||||
- Strong password hashing with Argon2id (no bcrypt 72-byte issues, no pre-hash needed)
|
||||
- Clear token types (access/refresh) with strict validation
|
||||
- Optional issuer/audience support (safe defaults)
|
||||
- Token IDs (jti) for future revocation support
|
||||
- Settings fetched via dependency-friendly getter (no module-level hard binding)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
from typing import Any, Mapping, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from app.infra.config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
# Password hashing context
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
# ---- Password hashing ----
|
||||
|
||||
_pwd_context = CryptContext(
|
||||
schemes=["argon2"],
|
||||
deprecated="auto",
|
||||
# You can tune params via passlib config if needed.
|
||||
# Argon2 parameters (time_cost/memory_cost/parallelism) are handled by passlib's argon2 backend.
|
||||
)
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""
|
||||
Hash a password using bcrypt with SHA256 pre-hashing.
|
||||
Hash a password using Argon2id.
|
||||
|
||||
Args:
|
||||
password: Plain text password (any length supported)
|
||||
password: Plain text password.
|
||||
|
||||
Returns:
|
||||
Hashed password
|
||||
|
||||
Note:
|
||||
Uses SHA256 pre-hashing to support passwords of any length,
|
||||
avoiding bcrypt's 72-byte limitation.
|
||||
Encoded password hash.
|
||||
"""
|
||||
# Pre-hash with SHA256 to support unlimited password length
|
||||
# Use base64 encoding for compact representation (43 chars < 72 bytes)
|
||||
password_bytes = hashlib.sha256(password.encode('utf-8')).digest()
|
||||
password_hash = base64.b64encode(password_bytes).decode('ascii')
|
||||
return pwd_context.hash(password_hash)
|
||||
if not isinstance(password, str) or not password:
|
||||
raise ValueError("password must be a non-empty string")
|
||||
return _pwd_context.hash(password)
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""
|
||||
Verify a password against its hash.
|
||||
Verify a password against its stored hash.
|
||||
|
||||
Args:
|
||||
plain_password: Plain text password
|
||||
hashed_password: Hashed password to verify against
|
||||
plain_password: Plain text password.
|
||||
hashed_password: Stored password hash.
|
||||
|
||||
Returns:
|
||||
True if password matches, False otherwise
|
||||
"""
|
||||
# Apply same SHA256 pre-hashing as hash_password
|
||||
password_bytes = hashlib.sha256(plain_password.encode('utf-8')).digest()
|
||||
password_hash = base64.b64encode(password_bytes).decode('ascii')
|
||||
return pwd_context.verify(password_hash, hashed_password)
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""
|
||||
Create a JWT access token.
|
||||
|
||||
Args:
|
||||
data: Data to encode in the token
|
||||
expires_delta: Optional expiration time delta
|
||||
|
||||
Returns:
|
||||
Encoded JWT token
|
||||
"""
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(seconds=settings.jwt_access_ttl_seconds)
|
||||
|
||||
to_encode.update({"exp": expire, "type": "access"})
|
||||
encoded_jwt = jwt.encode(to_encode, settings.jwt_secret, algorithm=settings.jwt_algorithm)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def create_refresh_token(data: dict) -> str:
|
||||
"""
|
||||
Create a JWT refresh token.
|
||||
|
||||
Args:
|
||||
data: Data to encode in the token
|
||||
|
||||
Returns:
|
||||
Encoded JWT token
|
||||
"""
|
||||
to_encode = data.copy()
|
||||
expire = datetime.now(timezone.utc) + timedelta(seconds=settings.jwt_refresh_ttl_seconds)
|
||||
to_encode.update({"exp": expire, "type": "refresh"})
|
||||
encoded_jwt = jwt.encode(to_encode, settings.jwt_secret, algorithm=settings.jwt_algorithm)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def decode_token(token: str) -> Optional[dict]:
|
||||
"""
|
||||
Decode and verify a JWT token.
|
||||
|
||||
Args:
|
||||
token: JWT token to decode
|
||||
|
||||
Returns:
|
||||
Decoded token payload or None if invalid
|
||||
True if matches, False otherwise.
|
||||
"""
|
||||
if not plain_password or not hashed_password:
|
||||
return False
|
||||
try:
|
||||
payload = jwt.decode(token, settings.jwt_secret, algorithms=[settings.jwt_algorithm])
|
||||
return _pwd_context.verify(plain_password, hashed_password)
|
||||
except Exception:
|
||||
# Any parsing/format issues should not crash auth flow.
|
||||
return False
|
||||
|
||||
|
||||
# ---- JWT tokens ----
|
||||
|
||||
TokenType = str # "access" | "refresh"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class TokenClaims:
|
||||
"""
|
||||
Standardized JWT claims we issue and validate.
|
||||
"""
|
||||
|
||||
sub: str # subject (e.g., user_id)
|
||||
typ: TokenType # "access" or "refresh"
|
||||
exp: datetime
|
||||
iat: datetime
|
||||
jti: str
|
||||
|
||||
# Optional hardening
|
||||
iss: Optional[str] = None
|
||||
aud: Optional[str] = None
|
||||
|
||||
# Extra custom claims may be included in the encoded token, but are not represented here.
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _encode_jwt(claims: Mapping[str, Any]) -> str:
|
||||
settings = get_settings()
|
||||
return jwt.encode(
|
||||
dict(claims),
|
||||
settings.jwt_secret,
|
||||
algorithm=settings.jwt_algorithm,
|
||||
)
|
||||
|
||||
|
||||
def _decode_jwt(token: str) -> Optional[dict[str, Any]]:
|
||||
"""
|
||||
Decode and verify JWT signature + exp using configured algorithm/secret.
|
||||
Does NOT by itself enforce token type (access/refresh); use decode_* functions.
|
||||
"""
|
||||
if not token:
|
||||
return None
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
# Optional issuer/audience validation: only enforced if present in settings
|
||||
options: dict[str, Any] = {
|
||||
"verify_signature": True,
|
||||
"verify_exp": True,
|
||||
"verify_nbf": True,
|
||||
"verify_iat": False, # iat isn't always required; we set it, but don't hard-fail on clients w/ clock drift
|
||||
"require_exp": True,
|
||||
}
|
||||
|
||||
kwargs: dict[str, Any] = {"algorithms": [settings.jwt_algorithm]}
|
||||
|
||||
# If your settings provide these, they will be enforced.
|
||||
iss = getattr(settings, "jwt_issuer", None)
|
||||
aud = getattr(settings, "jwt_audience", None)
|
||||
if iss:
|
||||
kwargs["issuer"] = iss
|
||||
if aud:
|
||||
kwargs["audience"] = aud
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, settings.jwt_secret, options=options, **kwargs)
|
||||
if not isinstance(payload, dict):
|
||||
return None
|
||||
return payload
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
|
||||
def _build_claims(
|
||||
*,
|
||||
subject: str,
|
||||
token_type: TokenType,
|
||||
ttl: timedelta,
|
||||
extra: Optional[Mapping[str, Any]] = None,
|
||||
) -> dict[str, Any]:
|
||||
if not subject:
|
||||
raise ValueError("subject must be a non-empty string")
|
||||
if token_type not in {"access", "refresh"}:
|
||||
raise ValueError("token_type must be 'access' or 'refresh'")
|
||||
|
||||
now = _utcnow()
|
||||
exp = now + ttl
|
||||
|
||||
settings = get_settings()
|
||||
iss = getattr(settings, "jwt_issuer", None)
|
||||
aud = getattr(settings, "jwt_audience", None)
|
||||
|
||||
claims: dict[str, Any] = {
|
||||
"sub": subject,
|
||||
"typ": token_type, # use "typ" (common) to avoid confusion with header "typ"
|
||||
"iat": int(now.timestamp()),
|
||||
"exp": int(exp.timestamp()),
|
||||
"jti": uuid4().hex,
|
||||
}
|
||||
|
||||
if iss:
|
||||
claims["iss"] = iss
|
||||
if aud:
|
||||
claims["aud"] = aud
|
||||
|
||||
if extra:
|
||||
# Prevent overriding reserved claims
|
||||
reserved = {"sub", "typ", "iat", "exp", "jti", "iss", "aud", "nbf"}
|
||||
for k, v in extra.items():
|
||||
if k in reserved:
|
||||
continue
|
||||
claims[k] = v
|
||||
|
||||
return claims
|
||||
|
||||
|
||||
def create_access_token(*, subject: str, extra: Optional[Mapping[str, Any]] = None) -> str:
|
||||
"""
|
||||
Create a signed JWT access token.
|
||||
|
||||
Args:
|
||||
subject: User identifier (user_id) as string.
|
||||
extra: Optional additional non-reserved claims (e.g., roles).
|
||||
|
||||
Returns:
|
||||
JWT string.
|
||||
"""
|
||||
settings = get_settings()
|
||||
ttl = timedelta(seconds=settings.jwt_access_ttl_seconds)
|
||||
claims = _build_claims(subject=subject, token_type="access", ttl=ttl, extra=extra)
|
||||
return _encode_jwt(claims)
|
||||
|
||||
|
||||
def create_refresh_token(*, subject: str, extra: Optional[Mapping[str, Any]] = None) -> str:
|
||||
"""
|
||||
Create a signed JWT refresh token.
|
||||
|
||||
Args:
|
||||
subject: User identifier (user_id) as string.
|
||||
extra: Optional additional non-reserved claims.
|
||||
|
||||
Returns:
|
||||
JWT string.
|
||||
"""
|
||||
settings = get_settings()
|
||||
ttl = timedelta(seconds=settings.jwt_refresh_ttl_seconds)
|
||||
claims = _build_claims(subject=subject, token_type="refresh", ttl=ttl, extra=extra)
|
||||
return _encode_jwt(claims)
|
||||
|
||||
|
||||
def decode_access_token(token: str) -> Optional[dict[str, Any]]:
|
||||
"""
|
||||
Decode and validate an access token (signature/exp + typ=access).
|
||||
"""
|
||||
payload = _decode_jwt(token)
|
||||
if not payload:
|
||||
return None
|
||||
if payload.get("typ") != "access":
|
||||
return None
|
||||
if not payload.get("sub"):
|
||||
return None
|
||||
return payload
|
||||
|
||||
|
||||
def decode_refresh_token(token: str) -> Optional[dict[str, Any]]:
|
||||
"""
|
||||
Decode and validate a refresh token (signature/exp + typ=refresh).
|
||||
"""
|
||||
payload = _decode_jwt(token)
|
||||
if not payload:
|
||||
return None
|
||||
if payload.get("typ") != "refresh":
|
||||
return None
|
||||
if not payload.get("sub"):
|
||||
return None
|
||||
return payload
|
||||
|
||||
|
||||
def get_subject(payload: Mapping[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
Extract subject (user_id) from decoded payload.
|
||||
"""
|
||||
sub = payload.get("sub")
|
||||
return str(sub) if sub else None
|
||||
|
|
|
|||
|
|
@ -79,8 +79,8 @@ class AuthService:
|
|||
detail="User account is inactive",
|
||||
)
|
||||
|
||||
access_token = create_access_token({"sub": user.id})
|
||||
refresh_token = create_refresh_token({"sub": user.id})
|
||||
access_token = create_access_token(subject=str(user.id))
|
||||
refresh_token = create_refresh_token(subject=str(user.id))
|
||||
return access_token, refresh_token
|
||||
|
||||
async def get_user_by_id(self, user_id: str) -> Optional[User]:
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
"""Tests package."""
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
"""Pytest configuration."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# Add src directory to Python path
|
||||
backend_dir = Path(__file__).parent.parent
|
||||
src_dir = backend_dir / "src"
|
||||
sys.path.insert(0, str(src_dir))
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def setup_test_env():
|
||||
"""Set up test environment variables."""
|
||||
os.environ["JWT_SECRET"] = "test-secret-key-for-testing-only"
|
||||
os.environ["S3_ACCESS_KEY_ID"] = "test-access-key"
|
||||
os.environ["S3_SECRET_ACCESS_KEY"] = "test-secret-key"
|
||||
os.environ["APP_ENV"] = "dev"
|
||||
|
|
@ -0,0 +1,264 @@
|
|||
"""Tests for security module."""
|
||||
|
||||
import pytest
|
||||
from app.infra.security import (
|
||||
hash_password,
|
||||
verify_password,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_access_token,
|
||||
decode_refresh_token,
|
||||
get_subject,
|
||||
)
|
||||
|
||||
|
||||
class TestPasswordHashing:
|
||||
"""Test password hashing and verification."""
|
||||
|
||||
def test_hash_password_creates_hash(self):
|
||||
"""Test that hash_password creates a valid Argon2 hash."""
|
||||
password = "test_password_123"
|
||||
hashed = hash_password(password)
|
||||
|
||||
# Argon2 hashes start with $argon2
|
||||
assert hashed.startswith("$argon2")
|
||||
# Argon2 hashes are longer than bcrypt (typically 90+ characters)
|
||||
assert len(hashed) > 80
|
||||
|
||||
def test_verify_password_correct(self):
|
||||
"""Test that verify_password returns True for correct password."""
|
||||
password = "my_secure_password"
|
||||
hashed = hash_password(password)
|
||||
|
||||
assert verify_password(password, hashed) is True
|
||||
|
||||
def test_verify_password_incorrect(self):
|
||||
"""Test that verify_password returns False for incorrect password."""
|
||||
password = "my_secure_password"
|
||||
hashed = hash_password(password)
|
||||
|
||||
assert verify_password("wrong_password", hashed) is False
|
||||
|
||||
def test_short_password(self):
|
||||
"""Test that short passwords work correctly."""
|
||||
password = "abc"
|
||||
hashed = hash_password(password)
|
||||
|
||||
assert verify_password(password, hashed) is True
|
||||
assert verify_password("xyz", hashed) is False
|
||||
|
||||
def test_long_password_under_72_bytes(self):
|
||||
"""Test passwords under 72 bytes."""
|
||||
# 50 character password (50 bytes in ASCII)
|
||||
password = "a" * 50
|
||||
hashed = hash_password(password)
|
||||
|
||||
assert verify_password(password, hashed) is True
|
||||
assert verify_password("a" * 49, hashed) is False
|
||||
|
||||
def test_long_password_over_72_bytes(self):
|
||||
"""Test that passwords over 72 bytes work (this is the critical test)."""
|
||||
# 100 character password (100 bytes in ASCII)
|
||||
password = "a" * 100
|
||||
hashed = hash_password(password)
|
||||
|
||||
# Should work without ValueError
|
||||
assert verify_password(password, hashed) is True
|
||||
|
||||
# Different password should fail
|
||||
assert verify_password("a" * 99, hashed) is False
|
||||
assert verify_password("a" * 101, hashed) is False
|
||||
|
||||
def test_very_long_password(self):
|
||||
"""Test extremely long passwords (200+ bytes)."""
|
||||
password = "x" * 200
|
||||
hashed = hash_password(password)
|
||||
|
||||
assert verify_password(password, hashed) is True
|
||||
assert verify_password("x" * 199, hashed) is False
|
||||
|
||||
def test_unicode_password(self):
|
||||
"""Test passwords with unicode characters."""
|
||||
password = "пароль_с_юникодом_🔒"
|
||||
hashed = hash_password(password)
|
||||
|
||||
assert verify_password(password, hashed) is True
|
||||
assert verify_password("пароль_с_юникодом_🔓", hashed) is False
|
||||
|
||||
def test_long_unicode_password(self):
|
||||
"""Test long unicode password (each Cyrillic char is 2 bytes in UTF-8)."""
|
||||
# 50 Cyrillic characters = 100 bytes in UTF-8
|
||||
password = "п" * 50
|
||||
hashed = hash_password(password)
|
||||
|
||||
assert verify_password(password, hashed) is True
|
||||
assert verify_password("п" * 49, hashed) is False
|
||||
|
||||
def test_same_password_different_hashes(self):
|
||||
"""Test that same password produces different hashes (salt)."""
|
||||
password = "same_password"
|
||||
hash1 = hash_password(password)
|
||||
hash2 = hash_password(password)
|
||||
|
||||
# Hashes should be different (Argon2 uses random salt)
|
||||
assert hash1 != hash2
|
||||
|
||||
# But both should verify correctly
|
||||
assert verify_password(password, hash1) is True
|
||||
assert verify_password(password, hash2) is True
|
||||
|
||||
def test_empty_password_raises_error(self):
|
||||
"""Test that empty password raises ValueError."""
|
||||
with pytest.raises(ValueError, match="password must be a non-empty string"):
|
||||
hash_password("")
|
||||
|
||||
def test_none_password_raises_error(self):
|
||||
"""Test that None password raises ValueError."""
|
||||
with pytest.raises(ValueError, match="password must be a non-empty string"):
|
||||
hash_password(None)
|
||||
|
||||
def test_verify_empty_password_returns_false(self):
|
||||
"""Test that verifying with empty password/hash returns False."""
|
||||
hashed = hash_password("validpassword")
|
||||
assert verify_password("", hashed) is False
|
||||
assert verify_password("validpassword", "") is False
|
||||
|
||||
def test_password_with_special_chars(self):
|
||||
"""Test password with special characters."""
|
||||
password = "P@ssw0rd!#$%^&*()_+-=[]{}|;:',.<>?/~`"
|
||||
hashed = hash_password(password)
|
||||
|
||||
assert verify_password(password, hashed) is True
|
||||
|
||||
def test_password_with_spaces(self):
|
||||
"""Test password with spaces."""
|
||||
password = "password with spaces"
|
||||
hashed = hash_password(password)
|
||||
|
||||
assert verify_password(password, hashed) is True
|
||||
assert verify_password("passwordwithspaces", hashed) is False
|
||||
|
||||
def test_realistic_long_password(self):
|
||||
"""Test realistic long password scenario."""
|
||||
# Simulate a password manager generated password
|
||||
password = "Xy8#mK9$pL2@nQ7!wE6%rT5^yU4&iO3*aS2(dF1)"
|
||||
hashed = hash_password(password)
|
||||
|
||||
assert verify_password(password, hashed) is True
|
||||
|
||||
# One character different should fail
|
||||
wrong = "Xy8#mK9$pL2@nQ7!wE6%rT5^yU4&iO3*aS2(dF2)"
|
||||
assert verify_password(wrong, hashed) is False
|
||||
|
||||
|
||||
class TestJWTTokens:
|
||||
"""Test JWT token creation and validation."""
|
||||
|
||||
def test_create_access_token(self):
|
||||
"""Test access token creation."""
|
||||
user_id = "user123"
|
||||
token = create_access_token(subject=user_id)
|
||||
|
||||
assert token is not None
|
||||
assert isinstance(token, str)
|
||||
assert len(token) > 50
|
||||
|
||||
def test_create_refresh_token(self):
|
||||
"""Test refresh token creation."""
|
||||
user_id = "user456"
|
||||
token = create_refresh_token(subject=user_id)
|
||||
|
||||
assert token is not None
|
||||
assert isinstance(token, str)
|
||||
assert len(token) > 50
|
||||
|
||||
def test_decode_access_token(self):
|
||||
"""Test decoding valid access token."""
|
||||
user_id = "user789"
|
||||
token = create_access_token(subject=user_id)
|
||||
payload = decode_access_token(token)
|
||||
|
||||
assert payload is not None
|
||||
assert payload["sub"] == user_id
|
||||
assert payload["typ"] == "access"
|
||||
assert "exp" in payload
|
||||
assert "iat" in payload
|
||||
assert "jti" in payload
|
||||
|
||||
def test_decode_refresh_token(self):
|
||||
"""Test decoding valid refresh token."""
|
||||
user_id = "user999"
|
||||
token = create_refresh_token(subject=user_id)
|
||||
payload = decode_refresh_token(token)
|
||||
|
||||
assert payload is not None
|
||||
assert payload["sub"] == user_id
|
||||
assert payload["typ"] == "refresh"
|
||||
assert "exp" in payload
|
||||
assert "iat" in payload
|
||||
assert "jti" in payload
|
||||
|
||||
def test_decode_access_token_rejects_refresh(self):
|
||||
"""Test that decode_access_token rejects refresh tokens."""
|
||||
user_id = "user111"
|
||||
refresh_token = create_refresh_token(subject=user_id)
|
||||
payload = decode_access_token(refresh_token)
|
||||
|
||||
assert payload is None
|
||||
|
||||
def test_decode_refresh_token_rejects_access(self):
|
||||
"""Test that decode_refresh_token rejects access tokens."""
|
||||
user_id = "user222"
|
||||
access_token = create_access_token(subject=user_id)
|
||||
payload = decode_refresh_token(access_token)
|
||||
|
||||
assert payload is None
|
||||
|
||||
def test_decode_invalid_token(self):
|
||||
"""Test decoding invalid token."""
|
||||
payload = decode_access_token("invalid.token.here")
|
||||
|
||||
assert payload is None
|
||||
|
||||
def test_decode_empty_token(self):
|
||||
"""Test decoding empty token."""
|
||||
payload = decode_access_token("")
|
||||
|
||||
assert payload is None
|
||||
|
||||
def test_get_subject(self):
|
||||
"""Test extracting subject from payload."""
|
||||
user_id = "user333"
|
||||
token = create_access_token(subject=user_id)
|
||||
payload = decode_access_token(token)
|
||||
|
||||
subject = get_subject(payload)
|
||||
assert subject == user_id
|
||||
|
||||
def test_get_subject_from_empty_payload(self):
|
||||
"""Test extracting subject from empty payload."""
|
||||
subject = get_subject({})
|
||||
assert subject is None
|
||||
|
||||
def test_create_access_token_with_extra_claims(self):
|
||||
"""Test creating access token with extra claims."""
|
||||
user_id = "user444"
|
||||
extra = {"role": "admin", "email": "test@example.com"}
|
||||
token = create_access_token(subject=user_id, extra=extra)
|
||||
payload = decode_access_token(token)
|
||||
|
||||
assert payload is not None
|
||||
assert payload["sub"] == user_id
|
||||
assert payload.get("role") == "admin"
|
||||
assert payload.get("email") == "test@example.com"
|
||||
|
||||
def test_unique_jti_per_token(self):
|
||||
"""Test that each token has unique jti."""
|
||||
user_id = "user555"
|
||||
token1 = create_access_token(subject=user_id)
|
||||
token2 = create_access_token(subject=user_id)
|
||||
|
||||
payload1 = decode_access_token(token1)
|
||||
payload2 = decode_access_token(token2)
|
||||
|
||||
assert payload1["jti"] != payload2["jti"]
|
||||
Loading…
Reference in New Issue