brief-rags-bench/tests/test_dependencies.py

94 lines
3.5 KiB
Python

"""Tests for FastAPI dependencies."""
import pytest
from unittest.mock import MagicMock, patch
from fastapi import HTTPException
from fastapi.security import HTTPAuthorizationCredentials
from app.dependencies import get_current_user, get_db_client
from app.interfaces.db_api_client import DBApiClient
class TestGetCurrentUser:
    """Tests for the ``get_current_user`` dependency.

    ``app.dependencies.decode_access_token`` is patched in every test, so the
    dependency is exercised independently of real JWT handling. A dict return
    value simulates a token the decoder accepted; ``None`` simulates a token
    the decoder rejected (the expired and malformed cases below are only
    distinguished by the token string passed in, not by real decoding).
    """

    def _bearer(self, token):
        """Return spec'd credentials whose ``credentials`` attribute is *token*."""
        creds = MagicMock(spec=HTTPAuthorizationCredentials)
        creds.credentials = token
        return creds

    async def _assert_rejected(self, token, detail_fragment=None):
        """Assert that *token*, when the decoder returns None, yields a Bearer 401.

        Args:
            token: Raw token string placed in the credentials.
            detail_fragment: If given, a lowercase substring that must appear
                in the exception's ``detail``.
        """
        with patch("app.dependencies.decode_access_token", return_value=None):
            with pytest.raises(HTTPException) as exc_info:
                await get_current_user(self._bearer(token))
        # Inspect the exception only after leaving pytest.raises: statements
        # placed after the awaited call inside that block would never execute.
        error = exc_info.value
        assert error.status_code == 401
        # RFC 6750 requires a WWW-Authenticate challenge on bearer 401s, so
        # every rejection path is checked for it, not just one.
        assert error.headers == {"WWW-Authenticate": "Bearer"}
        if detail_fragment is not None:
            assert detail_fragment in error.detail.lower()

    @pytest.mark.asyncio
    async def test_get_current_user_valid_token(self):
        """A token the decoder accepts yields the decoded payload unchanged."""
        valid_payload = {
            "user_id": "user-123",
            "login": "12345678",
            "exp": 9999999999,  # Far future
        }
        with patch("app.dependencies.decode_access_token", return_value=valid_payload):
            user = await get_current_user(self._bearer("valid.token.here"))
        assert user == valid_payload
        assert user["user_id"] == "user-123"
        assert user["login"] == "12345678"

    @pytest.mark.asyncio
    async def test_get_current_user_invalid_token(self):
        """An invalid token (decoder returns None) is rejected with a Bearer 401."""
        await self._assert_rejected("invalid.token", "invalid or expired")

    @pytest.mark.asyncio
    async def test_get_current_user_expired_token(self):
        """An expired token (simulated by the decoder returning None) is rejected."""
        await self._assert_rejected("expired.token", "invalid or expired")

    @pytest.mark.asyncio
    async def test_get_current_user_malformed_token(self):
        """A malformed token (simulated by the decoder returning None) is rejected."""
        await self._assert_rejected("not.a.jwt")
class TestGetDbClient:
    """Tests for the ``get_db_client`` dependency factory."""

    def test_get_db_client_returns_instance(self):
        """get_db_client hands back a DBApiClient."""
        assert isinstance(get_db_client(), DBApiClient)

    def test_get_db_client_uses_settings(self):
        """get_db_client configures its client from ``settings.DB_API_URL``."""
        expected_url = "http://test-api:9999/api/v1"
        with patch('app.dependencies.settings') as fake_settings:
            fake_settings.DB_API_URL = expected_url
            # The client's prefix must reflect the patched setting.
            assert get_db_client().api_prefix == expected_url