add unit tests
This commit is contained in:
parent
4cd311a59d
commit
a7d2907e34
|
|
@ -0,0 +1,333 @@
|
|||
"""Tests for analysis sessions endpoints."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
import httpx
|
||||
from app.models.analysis import SessionResponse, SessionList, SessionCreate
|
||||
|
||||
|
||||
class TestAnalysisEndpoints:
|
||||
"""Tests for /api/v1/analysis/sessions endpoints."""
|
||||
|
||||
def test_create_session_success(self, client, mock_db_client):
|
||||
"""Test creating a new analysis session."""
|
||||
mock_session = SessionResponse(
|
||||
session_id="session-123",
|
||||
user_id="test-user-123",
|
||||
environment="ift",
|
||||
api_mode="bench",
|
||||
request=[],
|
||||
response={},
|
||||
annotations={},
|
||||
created_at="2024-01-01T00:00:00Z",
|
||||
updated_at="2024-01-01T00:00:00Z"
|
||||
)
|
||||
mock_db_client.save_session = AsyncMock(return_value=mock_session)
|
||||
|
||||
session_data = {
|
||||
"environment": "ift",
|
||||
"api_mode": "bench",
|
||||
"request": [],
|
||||
"response": {"answers": []},
|
||||
"annotations": {}
|
||||
}
|
||||
|
||||
response = client.post("/api/v1/analysis/sessions", json=session_data)
|
||||
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
|
||||
assert data["session_id"] == "session-123"
|
||||
assert data["environment"] == "ift"
|
||||
assert data["api_mode"] == "bench"
|
||||
|
||||
mock_db_client.save_session.assert_called_once()
|
||||
|
||||
def test_create_session_invalid_data(self, client, mock_db_client):
|
||||
"""Test creating session with invalid data."""
|
||||
error_response = httpx.Response(400, json={"detail": "Invalid format"})
|
||||
mock_db_client.save_session = AsyncMock(
|
||||
side_effect=httpx.HTTPStatusError("Bad request", request=None, response=error_response)
|
||||
)
|
||||
|
||||
session_data = {
|
||||
"environment": "invalid",
|
||||
"api_mode": "bench",
|
||||
"request": [],
|
||||
"response": {},
|
||||
"annotations": {}
|
||||
}
|
||||
|
||||
response = client.post("/api/v1/analysis/sessions", json=session_data)
|
||||
|
||||
assert response.status_code in [400, 422] # 422 for validation error
|
||||
|
||||
def test_get_sessions_success(self, client, mock_db_client):
|
||||
"""Test getting list of sessions."""
|
||||
from app.models.analysis import SessionListItem
|
||||
mock_sessions = SessionList(
|
||||
sessions=[
|
||||
SessionListItem(
|
||||
session_id="session-1",
|
||||
environment="ift",
|
||||
created_at="2024-01-01T00:00:00Z"
|
||||
),
|
||||
SessionListItem(
|
||||
session_id="session-2",
|
||||
environment="psi",
|
||||
created_at="2024-01-02T00:00:00Z"
|
||||
)
|
||||
],
|
||||
total=2
|
||||
)
|
||||
mock_db_client.get_sessions = AsyncMock(return_value=mock_sessions)
|
||||
|
||||
response = client.get("/api/v1/analysis/sessions?limit=50&offset=0")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert data["total"] == 2
|
||||
assert len(data["sessions"]) == 2
|
||||
assert data["sessions"][0]["session_id"] == "session-1"
|
||||
assert data["sessions"][1]["session_id"] == "session-2"
|
||||
|
||||
mock_db_client.get_sessions.assert_called_once_with(
|
||||
"test-user-123", None, 50, 0
|
||||
)
|
||||
|
||||
def test_get_sessions_with_filter(self, client, mock_db_client):
|
||||
"""Test getting sessions with environment filter."""
|
||||
mock_sessions = SessionList(sessions=[], total=0)
|
||||
mock_db_client.get_sessions = AsyncMock(return_value=mock_sessions)
|
||||
|
||||
response = client.get("/api/v1/analysis/sessions?environment=ift&limit=10&offset=5")
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
mock_db_client.get_sessions.assert_called_once_with(
|
||||
"test-user-123", "ift", 10, 5
|
||||
)
|
||||
|
||||
def test_get_sessions_pagination(self, client, mock_db_client):
|
||||
"""Test sessions pagination limits."""
|
||||
mock_sessions = SessionList(sessions=[], total=0)
|
||||
mock_db_client.get_sessions = AsyncMock(return_value=mock_sessions)
|
||||
|
||||
# Test default values
|
||||
response = client.get("/api/v1/analysis/sessions")
|
||||
assert response.status_code == 200
|
||||
mock_db_client.get_sessions.assert_called_with(
|
||||
"test-user-123", None, 50, 0
|
||||
)
|
||||
|
||||
# Test max limit (200)
|
||||
response = client.get("/api/v1/analysis/sessions?limit=250")
|
||||
assert response.status_code == 422 # Validation error, exceeds max
|
||||
|
||||
def test_get_session_by_id_success(self, client, mock_db_client):
|
||||
"""Test getting specific session by ID."""
|
||||
mock_session = SessionResponse(
|
||||
session_id="session-123",
|
||||
user_id="test-user-123",
|
||||
environment="ift",
|
||||
api_mode="bench",
|
||||
request=[{"body": "Q1", "with_docs": True}],
|
||||
response={"answers": ["A1"]},
|
||||
annotations={"note": "test"},
|
||||
created_at="2024-01-01T00:00:00Z",
|
||||
updated_at="2024-01-01T00:00:00Z"
|
||||
)
|
||||
mock_db_client.get_session = AsyncMock(return_value=mock_session)
|
||||
|
||||
response = client.get("/api/v1/analysis/sessions/session-123")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert data["session_id"] == "session-123"
|
||||
assert data["annotations"]["note"] == "test"
|
||||
|
||||
mock_db_client.get_session.assert_called_once_with("test-user-123", "session-123")
|
||||
|
||||
def test_get_session_not_found(self, client, mock_db_client):
|
||||
"""Test getting non-existent session."""
|
||||
error_response = httpx.Response(404, json={"detail": "Not found"})
|
||||
mock_db_client.get_session = AsyncMock(
|
||||
side_effect=httpx.HTTPStatusError("Not found", request=None, response=error_response)
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/analysis/sessions/nonexistent")
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_delete_session_success(self, client, mock_db_client):
|
||||
"""Test deleting a session."""
|
||||
mock_db_client.delete_session = AsyncMock(return_value=None)
|
||||
|
||||
response = client.delete("/api/v1/analysis/sessions/session-123")
|
||||
|
||||
assert response.status_code == 204
|
||||
assert response.content == b""
|
||||
|
||||
mock_db_client.delete_session.assert_called_once_with("test-user-123", "session-123")
|
||||
|
||||
def test_delete_session_not_found(self, client, mock_db_client):
|
||||
"""Test deleting non-existent session."""
|
||||
error_response = httpx.Response(404, json={"detail": "Not found"})
|
||||
mock_db_client.delete_session = AsyncMock(
|
||||
side_effect=httpx.HTTPStatusError("Not found", request=None, response=error_response)
|
||||
)
|
||||
|
||||
response = client.delete("/api/v1/analysis/sessions/nonexistent")
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_analysis_endpoints_require_auth(self, unauthenticated_client):
|
||||
"""Test that all analysis endpoints require authentication."""
|
||||
# POST /sessions
|
||||
response = unauthenticated_client.post("/api/v1/analysis/sessions", json={})
|
||||
assert response.status_code == 401 # HTTPBearer returns 401
|
||||
|
||||
# GET /sessions
|
||||
response = unauthenticated_client.get("/api/v1/analysis/sessions")
|
||||
assert response.status_code == 401
|
||||
|
||||
# GET /sessions/{id}
|
||||
response = unauthenticated_client.get("/api/v1/analysis/sessions/test")
|
||||
assert response.status_code == 401
|
||||
|
||||
# DELETE /sessions/{id}
|
||||
response = unauthenticated_client.delete("/api/v1/analysis/sessions/test")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_create_session_user_not_found(self, client, mock_db_client):
|
||||
"""Test creating session when user not found in DB API."""
|
||||
error_response = httpx.Response(404, json={"detail": "User not found"})
|
||||
mock_db_client.save_session = AsyncMock(
|
||||
side_effect=httpx.HTTPStatusError("Not found", request=None, response=error_response)
|
||||
)
|
||||
|
||||
session_data = {
|
||||
"environment": "ift",
|
||||
"api_mode": "bench",
|
||||
"request": [],
|
||||
"response": {},
|
||||
"annotations": {}
|
||||
}
|
||||
|
||||
response = client.post("/api/v1/analysis/sessions", json=session_data)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "user not found" in response.json()["detail"].lower()
|
||||
|
||||
def test_create_session_db_api_error(self, client, mock_db_client):
|
||||
"""Test creating session when DB API returns 502."""
|
||||
error_response = httpx.Response(503, json={"detail": "Service unavailable"})
|
||||
mock_db_client.save_session = AsyncMock(
|
||||
side_effect=httpx.HTTPStatusError("Service error", request=None, response=error_response)
|
||||
)
|
||||
|
||||
session_data = {
|
||||
"environment": "ift",
|
||||
"api_mode": "bench",
|
||||
"request": [],
|
||||
"response": {},
|
||||
"annotations": {}
|
||||
}
|
||||
|
||||
response = client.post("/api/v1/analysis/sessions", json=session_data)
|
||||
|
||||
assert response.status_code == 502
|
||||
|
||||
def test_create_session_unexpected_error(self, client, mock_db_client):
|
||||
"""Test creating session with unexpected error."""
|
||||
mock_db_client.save_session = AsyncMock(
|
||||
side_effect=Exception("Unexpected database error")
|
||||
)
|
||||
|
||||
session_data = {
|
||||
"environment": "ift",
|
||||
"api_mode": "bench",
|
||||
"request": [],
|
||||
"response": {},
|
||||
"annotations": {}
|
||||
}
|
||||
|
||||
response = client.post("/api/v1/analysis/sessions", json=session_data)
|
||||
|
||||
assert response.status_code == 500
|
||||
|
||||
def test_get_sessions_user_not_found(self, client, mock_db_client):
|
||||
"""Test getting sessions when user not found."""
|
||||
error_response = httpx.Response(404, json={"detail": "User not found"})
|
||||
mock_db_client.get_sessions = AsyncMock(
|
||||
side_effect=httpx.HTTPStatusError("Not found", request=None, response=error_response)
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/analysis/sessions")
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_get_sessions_db_api_error(self, client, mock_db_client):
|
||||
"""Test getting sessions when DB API fails."""
|
||||
error_response = httpx.Response(503, json={"detail": "Service error"})
|
||||
mock_db_client.get_sessions = AsyncMock(
|
||||
side_effect=httpx.HTTPStatusError("Service error", request=None, response=error_response)
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/analysis/sessions")
|
||||
|
||||
assert response.status_code == 502
|
||||
|
||||
def test_get_sessions_unexpected_error(self, client, mock_db_client):
|
||||
"""Test getting sessions with unexpected error."""
|
||||
mock_db_client.get_sessions = AsyncMock(
|
||||
side_effect=Exception("Database connection lost")
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/analysis/sessions")
|
||||
|
||||
assert response.status_code == 500
|
||||
|
||||
def test_get_session_by_id_db_api_error(self, client, mock_db_client):
|
||||
"""Test getting session when DB API returns 502."""
|
||||
error_response = httpx.Response(500, json={"error": "Server error"})
|
||||
mock_db_client.get_session = AsyncMock(
|
||||
side_effect=httpx.HTTPStatusError("Server error", request=None, response=error_response)
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/analysis/sessions/session-123")
|
||||
|
||||
assert response.status_code == 502
|
||||
assert "failed to retrieve session" in response.json()["detail"].lower()
|
||||
|
||||
def test_get_session_by_id_unexpected_error(self, client, mock_db_client):
|
||||
"""Test getting session with unexpected error."""
|
||||
mock_db_client.get_session = AsyncMock(side_effect=Exception("Database crash"))
|
||||
|
||||
response = client.get("/api/v1/analysis/sessions/session-123")
|
||||
|
||||
assert response.status_code == 500
|
||||
assert "internal server error" in response.json()["detail"].lower()
|
||||
|
||||
def test_delete_session_db_api_error(self, client, mock_db_client):
|
||||
"""Test deleting session when DB API returns 502."""
|
||||
error_response = httpx.Response(500, json={"error": "Server error"})
|
||||
mock_db_client.delete_session = AsyncMock(
|
||||
side_effect=httpx.HTTPStatusError("Server error", request=None, response=error_response)
|
||||
)
|
||||
|
||||
response = client.delete("/api/v1/analysis/sessions/session-123")
|
||||
|
||||
assert response.status_code == 502
|
||||
assert "failed to delete session" in response.json()["detail"].lower()
|
||||
|
||||
def test_delete_session_unexpected_error(self, client, mock_db_client):
|
||||
"""Test deleting session with unexpected error."""
|
||||
mock_db_client.delete_session = AsyncMock(side_effect=Exception("Database crash"))
|
||||
|
||||
response = client.delete("/api/v1/analysis/sessions/session-123")
|
||||
|
||||
assert response.status_code == 500
|
||||
assert "internal server error" in response.json()["detail"].lower()
|
||||
|
|
@ -0,0 +1,113 @@
|
|||
"""Tests for authentication endpoints and service."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from app.services.auth_service import AuthService
|
||||
from app.models.auth import LoginRequest, UserResponse
|
||||
|
||||
|
||||
class TestAuthEndpoints:
|
||||
"""Tests for /api/v1/auth endpoints."""
|
||||
|
||||
def test_login_success(self, unauthenticated_client, mock_db_client, test_user_response):
|
||||
"""Test successful login with valid 8-digit login."""
|
||||
# Mock DB client response
|
||||
mock_db_client.login_user = AsyncMock(return_value=test_user_response)
|
||||
|
||||
# Override dependency
|
||||
from app.main import app
|
||||
from app.dependencies import get_db_client
|
||||
app.dependency_overrides[get_db_client] = lambda: mock_db_client
|
||||
|
||||
try:
|
||||
response = unauthenticated_client.post("/api/v1/auth/login?login=12345678")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert "access_token" in data
|
||||
assert data["token_type"] == "bearer"
|
||||
assert "user" in data
|
||||
assert data["user"]["login"] == "12345678"
|
||||
|
||||
# Verify DB client was called
|
||||
mock_db_client.login_user.assert_called_once()
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
def test_login_invalid_format(self, unauthenticated_client):
|
||||
"""Test login with invalid format (not 8 digits)."""
|
||||
# Test with 7 digits
|
||||
response = unauthenticated_client.post("/api/v1/auth/login?login=1234567")
|
||||
assert response.status_code == 400
|
||||
assert "must be 8 digits" in response.json()["detail"].lower()
|
||||
|
||||
# Test with 9 digits
|
||||
response = unauthenticated_client.post("/api/v1/auth/login?login=123456789")
|
||||
assert response.status_code == 400
|
||||
|
||||
# Test with letters
|
||||
response = unauthenticated_client.post("/api/v1/auth/login?login=abcd1234")
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_login_db_api_error(self, unauthenticated_client, mock_db_client):
|
||||
"""Test login when DB API fails."""
|
||||
# Mock DB client to raise exception
|
||||
mock_db_client.login_user = AsyncMock(side_effect=Exception("DB API unavailable"))
|
||||
|
||||
from app.main import app
|
||||
from app.dependencies import get_db_client
|
||||
app.dependency_overrides[get_db_client] = lambda: mock_db_client
|
||||
|
||||
try:
|
||||
response = unauthenticated_client.post("/api/v1/auth/login?login=12345678")
|
||||
|
||||
assert response.status_code == 500
|
||||
assert "failed" in response.json()["detail"].lower()
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
class TestAuthService:
|
||||
"""Tests for AuthService."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_success(self, mock_db_client, test_user_response):
|
||||
"""Test successful login via AuthService."""
|
||||
mock_db_client.login_user = AsyncMock(return_value=test_user_response)
|
||||
auth_service = AuthService(mock_db_client)
|
||||
|
||||
result = await auth_service.login("12345678", "192.168.1.1")
|
||||
|
||||
assert result.access_token is not None
|
||||
assert result.token_type == "bearer"
|
||||
assert result.user.login == "12345678"
|
||||
assert result.user.user_id == "test-user-123"
|
||||
|
||||
# Verify DB client was called with correct params
|
||||
call_args = mock_db_client.login_user.call_args[0][0]
|
||||
assert call_args.login == "12345678"
|
||||
assert call_args.client_ip == "192.168.1.1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_invalid_format(self, mock_db_client):
|
||||
"""Test login with invalid format raises ValueError."""
|
||||
auth_service = AuthService(mock_db_client)
|
||||
|
||||
with pytest.raises(ValueError, match="8 digits"):
|
||||
await auth_service.login("1234567", "192.168.1.1")
|
||||
|
||||
with pytest.raises(ValueError, match="8 digits"):
|
||||
await auth_service.login("abcd1234", "192.168.1.1")
|
||||
|
||||
# Verify DB client was never called
|
||||
mock_db_client.login_user.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_db_api_failure(self, mock_db_client):
|
||||
"""Test login when DB API fails."""
|
||||
mock_db_client.login_user = AsyncMock(side_effect=Exception("DB error"))
|
||||
auth_service = AuthService(mock_db_client)
|
||||
|
||||
with pytest.raises(Exception, match="DB error"):
|
||||
await auth_service.login("12345678", "192.168.1.1")
|
||||
|
|
@ -0,0 +1,379 @@
|
|||
"""Tests for TgBackendInterface base class."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import httpx
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from app.interfaces.base import TgBackendInterface
|
||||
|
||||
|
||||
class TestModel(BaseModel):
|
||||
"""Test Pydantic model for testing."""
|
||||
name: str
|
||||
value: int
|
||||
|
||||
|
||||
class TestTgBackendInterface:
|
||||
"""Tests for TgBackendInterface base class."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init(self):
|
||||
"""Test initialization with default parameters."""
|
||||
with patch('app.interfaces.base.httpx.AsyncClient') as MockClient:
|
||||
interface = TgBackendInterface(api_prefix="http://api.example.com/v1")
|
||||
|
||||
assert interface.api_prefix == "http://api.example.com/v1"
|
||||
MockClient.assert_called_once()
|
||||
|
||||
# Verify timeout and retries configured
|
||||
call_kwargs = MockClient.call_args[1]
|
||||
assert call_kwargs['follow_redirects'] is True
|
||||
assert isinstance(call_kwargs['timeout'], httpx.Timeout)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_strips_trailing_slash(self):
|
||||
"""Test that trailing slash is stripped from api_prefix."""
|
||||
with patch('app.interfaces.base.httpx.AsyncClient'):
|
||||
interface = TgBackendInterface(api_prefix="http://api.example.com/v1/")
|
||||
assert interface.api_prefix == "http://api.example.com/v1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_custom_params(self):
|
||||
"""Test initialization with custom timeout and retries."""
|
||||
with patch('app.interfaces.base.httpx.AsyncClient') as MockClient:
|
||||
interface = TgBackendInterface(
|
||||
api_prefix="http://api.example.com",
|
||||
timeout=60.0,
|
||||
max_retries=5
|
||||
)
|
||||
|
||||
call_kwargs = MockClient.call_args[1]
|
||||
# Timeout object is created, just verify it exists
|
||||
assert isinstance(call_kwargs['timeout'], httpx.Timeout)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close(self):
|
||||
"""Test closing the HTTP client."""
|
||||
mock_client = AsyncMock()
|
||||
with patch('app.interfaces.base.httpx.AsyncClient', return_value=mock_client):
|
||||
interface = TgBackendInterface(api_prefix="http://api.example.com")
|
||||
await interface.close()
|
||||
|
||||
mock_client.aclose.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_context_manager(self):
|
||||
"""Test using interface as async context manager."""
|
||||
mock_client = AsyncMock()
|
||||
with patch('app.interfaces.base.httpx.AsyncClient', return_value=mock_client):
|
||||
async with TgBackendInterface(api_prefix="http://api.example.com") as interface:
|
||||
assert interface is not None
|
||||
|
||||
# Should close on exit
|
||||
mock_client.aclose.assert_called_once()
|
||||
|
||||
def test_build_url_with_leading_slash(self):
|
||||
"""Test building URL with path that has leading slash."""
|
||||
with patch('app.interfaces.base.httpx.AsyncClient'):
|
||||
interface = TgBackendInterface(api_prefix="http://api.example.com/v1")
|
||||
url = interface._build_url("/users/123")
|
||||
|
||||
assert url == "http://api.example.com/v1/users/123"
|
||||
|
||||
def test_build_url_without_leading_slash(self):
|
||||
"""Test building URL with path without leading slash."""
|
||||
with patch('app.interfaces.base.httpx.AsyncClient'):
|
||||
interface = TgBackendInterface(api_prefix="http://api.example.com/v1")
|
||||
url = interface._build_url("users/123")
|
||||
|
||||
assert url == "http://api.example.com/v1/users/123"
|
||||
|
||||
def test_serialize_body_with_model(self):
|
||||
"""Test serializing Pydantic model to dict."""
|
||||
with patch('app.interfaces.base.httpx.AsyncClient'):
|
||||
interface = TgBackendInterface(api_prefix="http://api.example.com")
|
||||
model = TestModel(name="test", value=42)
|
||||
|
||||
result = interface._serialize_body(model)
|
||||
|
||||
assert result == {"name": "test", "value": 42}
|
||||
|
||||
def test_serialize_body_with_none(self):
|
||||
"""Test serializing None body."""
|
||||
with patch('app.interfaces.base.httpx.AsyncClient'):
|
||||
interface = TgBackendInterface(api_prefix="http://api.example.com")
|
||||
|
||||
result = interface._serialize_body(None)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_deserialize_response_with_dict(self):
|
||||
"""Test deserializing dict response to Pydantic model."""
|
||||
with patch('app.interfaces.base.httpx.AsyncClient'):
|
||||
interface = TgBackendInterface(api_prefix="http://api.example.com")
|
||||
data = {"name": "test", "value": 42}
|
||||
|
||||
result = interface._deserialize_response(data, TestModel)
|
||||
|
||||
assert isinstance(result, TestModel)
|
||||
assert result.name == "test"
|
||||
assert result.value == 42
|
||||
|
||||
def test_deserialize_response_no_model(self):
|
||||
"""Test deserializing response without model returns raw data."""
|
||||
with patch('app.interfaces.base.httpx.AsyncClient'):
|
||||
interface = TgBackendInterface(api_prefix="http://api.example.com")
|
||||
data = {"name": "test", "value": 42}
|
||||
|
||||
result = interface._deserialize_response(data, None)
|
||||
|
||||
assert result == data
|
||||
|
||||
def test_deserialize_response_validation_error(self):
|
||||
"""Test deserialization with validation error."""
|
||||
with patch('app.interfaces.base.httpx.AsyncClient'):
|
||||
interface = TgBackendInterface(api_prefix="http://api.example.com")
|
||||
# Invalid data: missing 'value' field
|
||||
data = {"name": "test"}
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
interface._deserialize_response(data, TestModel)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_response_success(self):
|
||||
"""Test handling successful HTTP response."""
|
||||
with patch('app.interfaces.base.httpx.AsyncClient'):
|
||||
interface = TgBackendInterface(api_prefix="http://api.example.com")
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.content = b'{"name": "test", "value": 42}' # Non-empty content
|
||||
mock_response.json.return_value = {"name": "test", "value": 42}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
result = await interface._handle_response(mock_response, TestModel)
|
||||
|
||||
assert isinstance(result, TestModel)
|
||||
assert result.name == "test"
|
||||
mock_response.raise_for_status.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_response_204_no_content(self):
|
||||
"""Test handling 204 No Content response."""
|
||||
with patch('app.interfaces.base.httpx.AsyncClient'):
|
||||
interface = TgBackendInterface(api_prefix="http://api.example.com")
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 204
|
||||
mock_response.content = b''
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
result = await interface._handle_response(mock_response)
|
||||
|
||||
assert result == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_response_empty_content(self):
|
||||
"""Test handling response with empty content."""
|
||||
with patch('app.interfaces.base.httpx.AsyncClient'):
|
||||
interface = TgBackendInterface(api_prefix="http://api.example.com")
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.content = b''
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
result = await interface._handle_response(mock_response)
|
||||
|
||||
assert result == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_response_http_error(self):
|
||||
"""Test handling HTTP error response."""
|
||||
with patch('app.interfaces.base.httpx.AsyncClient'):
|
||||
interface = TgBackendInterface(api_prefix="http://api.example.com")
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 404
|
||||
mock_response.text = "Not Found"
|
||||
|
||||
error = httpx.HTTPStatusError(
|
||||
"Not Found",
|
||||
request=MagicMock(),
|
||||
response=mock_response
|
||||
)
|
||||
mock_response.raise_for_status = MagicMock(side_effect=error)
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
await interface._handle_response(mock_response)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_response_invalid_json(self):
|
||||
"""Test handling response with invalid JSON."""
|
||||
with patch('app.interfaces.base.httpx.AsyncClient'):
|
||||
interface = TgBackendInterface(api_prefix="http://api.example.com")
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.content = b'not empty'
|
||||
mock_response.text = "Invalid JSON"
|
||||
mock_response.json.side_effect = ValueError("Invalid JSON")
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await interface._handle_response(mock_response)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_success(self):
|
||||
"""Test successful GET request."""
|
||||
mock_client = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.content = b'{"name": "test", "value": 42}' # Non-empty content
|
||||
mock_response.json.return_value = {"name": "test", "value": 42}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with patch('app.interfaces.base.httpx.AsyncClient', return_value=mock_client):
|
||||
interface = TgBackendInterface(api_prefix="http://api.example.com")
|
||||
|
||||
result = await interface.get("/users", params={"id": 123}, response_model=TestModel)
|
||||
|
||||
assert isinstance(result, TestModel)
|
||||
assert result.name == "test"
|
||||
mock_client.get.assert_called_once()
|
||||
call_args = mock_client.get.call_args
|
||||
assert call_args[0][0] == "http://api.example.com/users"
|
||||
assert call_args[1]['params'] == {"id": 123}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_success(self):
|
||||
"""Test successful POST request."""
|
||||
mock_client = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 201
|
||||
mock_response.content = b'{"name": "created", "value": 100}' # Non-empty content
|
||||
mock_response.json.return_value = {"name": "created", "value": 100}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch('app.interfaces.base.httpx.AsyncClient', return_value=mock_client):
|
||||
interface = TgBackendInterface(api_prefix="http://api.example.com")
|
||||
body = TestModel(name="new", value=50)
|
||||
|
||||
result = await interface.post("/users", body=body, response_model=TestModel)
|
||||
|
||||
assert isinstance(result, TestModel)
|
||||
assert result.name == "created"
|
||||
mock_client.post.assert_called_once()
|
||||
call_args = mock_client.post.call_args
|
||||
assert call_args[0][0] == "http://api.example.com/users"
|
||||
assert call_args[1]['json'] == {"name": "new", "value": 50}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_without_body(self):
|
||||
"""Test POST request without body."""
|
||||
mock_client = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.content = b'{"result": "ok"}' # Non-empty content
|
||||
mock_response.json.return_value = {"result": "ok"}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch('app.interfaces.base.httpx.AsyncClient', return_value=mock_client):
|
||||
interface = TgBackendInterface(api_prefix="http://api.example.com")
|
||||
|
||||
result = await interface.post("/action")
|
||||
|
||||
assert result == {"result": "ok"}
|
||||
call_args = mock_client.post.call_args
|
||||
assert call_args[1]['json'] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_put_success(self):
|
||||
"""Test successful PUT request."""
|
||||
mock_client = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.content = b'{"name": "updated", "value": 75}' # Non-empty content
|
||||
mock_response.json.return_value = {"name": "updated", "value": 75}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_client.put.return_value = mock_response
|
||||
|
||||
with patch('app.interfaces.base.httpx.AsyncClient', return_value=mock_client):
|
||||
interface = TgBackendInterface(api_prefix="http://api.example.com")
|
||||
body = TestModel(name="updated", value=75)
|
||||
|
||||
result = await interface.put("/users/1", body=body, response_model=TestModel)
|
||||
|
||||
assert isinstance(result, TestModel)
|
||||
assert result.name == "updated"
|
||||
mock_client.put.assert_called_once()
|
||||
call_args = mock_client.put.call_args
|
||||
assert call_args[0][0] == "http://api.example.com/users/1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_success(self):
|
||||
"""Test successful DELETE request."""
|
||||
mock_client = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 204
|
||||
mock_response.content = b''
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_client.delete.return_value = mock_response
|
||||
|
||||
with patch('app.interfaces.base.httpx.AsyncClient', return_value=mock_client):
|
||||
interface = TgBackendInterface(api_prefix="http://api.example.com")
|
||||
|
||||
result = await interface.delete("/users/1")
|
||||
|
||||
assert result == {}
|
||||
mock_client.delete.assert_called_once()
|
||||
call_args = mock_client.delete.call_args
|
||||
assert call_args[0][0] == "http://api.example.com/users/1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_http_error(self):
|
||||
"""Test GET request with HTTP error."""
|
||||
mock_client = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_response.text = "Internal Server Error"
|
||||
|
||||
error = httpx.HTTPStatusError(
|
||||
"Internal Server Error",
|
||||
request=MagicMock(),
|
||||
response=mock_response
|
||||
)
|
||||
mock_response.raise_for_status = MagicMock(side_effect=error)
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with patch('app.interfaces.base.httpx.AsyncClient', return_value=mock_client):
|
||||
interface = TgBackendInterface(api_prefix="http://api.example.com")
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
await interface.get("/users")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_http_error(self):
|
||||
"""Test POST request with HTTP error."""
|
||||
mock_client = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 400
|
||||
mock_response.text = "Bad Request"
|
||||
|
||||
error = httpx.HTTPStatusError(
|
||||
"Bad Request",
|
||||
request=MagicMock(),
|
||||
response=mock_response
|
||||
)
|
||||
mock_response.raise_for_status = MagicMock(side_effect=error)
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch('app.interfaces.base.httpx.AsyncClient', return_value=mock_client):
|
||||
interface = TgBackendInterface(api_prefix="http://api.example.com")
|
||||
body = TestModel(name="test", value=1)
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
await interface.post("/users", body=body)
|
||||
|
|
@ -0,0 +1,227 @@
|
|||
"""Tests for DBApiClient."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from app.interfaces.db_api_client import DBApiClient
|
||||
from app.models.auth import LoginRequest, UserResponse
|
||||
from app.models.settings import UserSettings, UserSettingsUpdate, EnvironmentSettings
|
||||
from app.models.analysis import SessionCreate, SessionResponse, SessionList, SessionListItem
|
||||
|
||||
|
||||
class TestDBApiClient:
|
||||
"""Tests for DBApiClient class."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_user(self):
|
||||
"""Test login_user calls post correctly."""
|
||||
with patch('app.interfaces.base.httpx.AsyncClient'):
|
||||
client = DBApiClient(api_prefix="http://db-api:8080/api/v1")
|
||||
|
||||
# Mock the post method
|
||||
mock_user_response = UserResponse(
|
||||
user_id="user-123",
|
||||
login="12345678",
|
||||
last_login_at="2024-01-01T00:00:00Z",
|
||||
created_at="2024-01-01T00:00:00Z"
|
||||
)
|
||||
client.post = AsyncMock(return_value=mock_user_response)
|
||||
|
||||
login_request = LoginRequest(login="12345678", client_ip="127.0.0.1")
|
||||
result = await client.login_user(login_request)
|
||||
|
||||
assert result == mock_user_response
|
||||
client.post.assert_called_once_with(
|
||||
"/users/login",
|
||||
body=login_request,
|
||||
response_model=UserResponse
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_settings(self):
|
||||
"""Test get_user_settings calls get correctly."""
|
||||
with patch('app.interfaces.base.httpx.AsyncClient'):
|
||||
client = DBApiClient(api_prefix="http://db-api:8080/api/v1")
|
||||
|
||||
mock_settings = UserSettings(
|
||||
user_id="user-123",
|
||||
settings={
|
||||
"ift": EnvironmentSettings(
|
||||
apiMode="bench",
|
||||
bearerToken="",
|
||||
systemPlatform="",
|
||||
systemPlatformUser="",
|
||||
platformUserId="",
|
||||
platformId="",
|
||||
withClassify=False,
|
||||
resetSessionMode=True
|
||||
)
|
||||
},
|
||||
updated_at="2024-01-01T00:00:00Z"
|
||||
)
|
||||
client.get = AsyncMock(return_value=mock_settings)
|
||||
|
||||
result = await client.get_user_settings("user-123")
|
||||
|
||||
assert result == mock_settings
|
||||
client.get.assert_called_once_with(
|
||||
"/users/user-123/settings",
|
||||
response_model=UserSettings
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_settings(self):
|
||||
"""Test update_user_settings calls put correctly."""
|
||||
with patch('app.interfaces.base.httpx.AsyncClient'):
|
||||
client = DBApiClient(api_prefix="http://db-api:8080/api/v1")
|
||||
|
||||
settings_update = UserSettingsUpdate(
|
||||
settings={
|
||||
"ift": EnvironmentSettings(
|
||||
apiMode="backend",
|
||||
bearerToken="",
|
||||
systemPlatform="",
|
||||
systemPlatformUser="",
|
||||
platformUserId="",
|
||||
platformId="",
|
||||
withClassify=True,
|
||||
resetSessionMode=False
|
||||
)
|
||||
}
|
||||
)
|
||||
mock_updated_settings = UserSettings(
|
||||
user_id="user-123",
|
||||
settings=settings_update.settings,
|
||||
updated_at="2024-01-01T01:00:00Z"
|
||||
)
|
||||
client.put = AsyncMock(return_value=mock_updated_settings)
|
||||
|
||||
result = await client.update_user_settings("user-123", settings_update)
|
||||
|
||||
assert result == mock_updated_settings
|
||||
client.put.assert_called_once_with(
|
||||
"/users/user-123/settings",
|
||||
body=settings_update,
|
||||
response_model=UserSettings
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_session(self):
|
||||
"""Test save_session calls post correctly."""
|
||||
with patch('app.interfaces.base.httpx.AsyncClient'):
|
||||
client = DBApiClient(api_prefix="http://db-api:8080/api/v1")
|
||||
|
||||
session_data = SessionCreate(
|
||||
environment="ift",
|
||||
api_mode="bench",
|
||||
request=[{"question": "test"}],
|
||||
response={"answer": "test"},
|
||||
annotations={}
|
||||
)
|
||||
mock_session_response = SessionResponse(
|
||||
session_id="session-123",
|
||||
user_id="user-123",
|
||||
environment="ift",
|
||||
api_mode="bench",
|
||||
request=[{"question": "test"}],
|
||||
response={"answer": "test"},
|
||||
annotations={},
|
||||
created_at="2024-01-01T00:00:00Z",
|
||||
updated_at="2024-01-01T00:00:00Z"
|
||||
)
|
||||
client.post = AsyncMock(return_value=mock_session_response)
|
||||
|
||||
result = await client.save_session("user-123", session_data)
|
||||
|
||||
assert result == mock_session_response
|
||||
client.post.assert_called_once_with(
|
||||
"/users/user-123/sessions",
|
||||
body=session_data,
|
||||
response_model=SessionResponse
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_sessions(self):
|
||||
"""Test get_sessions calls get correctly."""
|
||||
with patch('app.interfaces.base.httpx.AsyncClient'):
|
||||
client = DBApiClient(api_prefix="http://db-api:8080/api/v1")
|
||||
|
||||
mock_sessions = SessionList(
|
||||
sessions=[
|
||||
SessionListItem(
|
||||
session_id="session-1",
|
||||
environment="ift",
|
||||
created_at="2024-01-01T00:00:00Z"
|
||||
)
|
||||
],
|
||||
total=1
|
||||
)
|
||||
client.get = AsyncMock(return_value=mock_sessions)
|
||||
|
||||
result = await client.get_sessions("user-123", environment="ift", limit=10, offset=0)
|
||||
|
||||
assert result == mock_sessions
|
||||
client.get.assert_called_once_with(
|
||||
"/users/user-123/sessions",
|
||||
params={"limit": 10, "offset": 0, "environment": "ift"},
|
||||
response_model=SessionList
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_sessions_without_environment(self):
|
||||
"""Test get_sessions without environment filter."""
|
||||
with patch('app.interfaces.base.httpx.AsyncClient'):
|
||||
client = DBApiClient(api_prefix="http://db-api:8080/api/v1")
|
||||
|
||||
mock_sessions = SessionList(sessions=[], total=0)
|
||||
client.get = AsyncMock(return_value=mock_sessions)
|
||||
|
||||
result = await client.get_sessions("user-123", limit=50, offset=0)
|
||||
|
||||
assert result == mock_sessions
|
||||
client.get.assert_called_once_with(
|
||||
"/users/user-123/sessions",
|
||||
params={"limit": 50, "offset": 0},
|
||||
response_model=SessionList
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session(self):
|
||||
"""Test get_session calls get correctly."""
|
||||
with patch('app.interfaces.base.httpx.AsyncClient'):
|
||||
client = DBApiClient(api_prefix="http://db-api:8080/api/v1")
|
||||
|
||||
mock_session = SessionResponse(
|
||||
session_id="session-123",
|
||||
user_id="user-123",
|
||||
environment="ift",
|
||||
api_mode="bench",
|
||||
request=[],
|
||||
response={},
|
||||
annotations={},
|
||||
created_at="2024-01-01T00:00:00Z",
|
||||
updated_at="2024-01-01T00:00:00Z"
|
||||
)
|
||||
client.get = AsyncMock(return_value=mock_session)
|
||||
|
||||
result = await client.get_session("user-123", "session-123")
|
||||
|
||||
assert result == mock_session
|
||||
client.get.assert_called_once_with(
|
||||
"/users/user-123/sessions/session-123",
|
||||
response_model=SessionResponse
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_session(self):
|
||||
"""Test delete_session calls delete correctly."""
|
||||
with patch('app.interfaces.base.httpx.AsyncClient'):
|
||||
client = DBApiClient(api_prefix="http://db-api:8080/api/v1")
|
||||
|
||||
client.delete = AsyncMock(return_value={})
|
||||
|
||||
result = await client.delete_session("user-123", "session-123")
|
||||
|
||||
assert result == {}
|
||||
client.delete.assert_called_once_with(
|
||||
"/users/user-123/sessions/session-123"
|
||||
)
|
||||
|
|
@ -0,0 +1,93 @@
|
|||
"""Tests for FastAPI dependencies."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from fastapi import HTTPException
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
|
||||
from app.dependencies import get_current_user, get_db_client
|
||||
from app.interfaces.db_api_client import DBApiClient
|
||||
|
||||
|
||||
class TestGetCurrentUser:
|
||||
"""Tests for get_current_user dependency."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_valid_token(self):
|
||||
"""Test getting current user with valid token."""
|
||||
# Mock valid token payload
|
||||
valid_payload = {
|
||||
"user_id": "user-123",
|
||||
"login": "12345678",
|
||||
"exp": 9999999999 # Far future
|
||||
}
|
||||
|
||||
credentials = MagicMock(spec=HTTPAuthorizationCredentials)
|
||||
credentials.credentials = "valid.token.here"
|
||||
|
||||
with patch('app.dependencies.decode_access_token', return_value=valid_payload):
|
||||
user = await get_current_user(credentials)
|
||||
|
||||
assert user == valid_payload
|
||||
assert user["user_id"] == "user-123"
|
||||
assert user["login"] == "12345678"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_invalid_token(self):
|
||||
"""Test getting current user with invalid token."""
|
||||
credentials = MagicMock(spec=HTTPAuthorizationCredentials)
|
||||
credentials.credentials = "invalid.token"
|
||||
|
||||
# Mock invalid token (returns None)
|
||||
with patch('app.dependencies.decode_access_token', return_value=None):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user(credentials)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "invalid or expired" in exc_info.value.detail.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_expired_token(self):
|
||||
"""Test getting current user with expired token."""
|
||||
credentials = MagicMock(spec=HTTPAuthorizationCredentials)
|
||||
credentials.credentials = "expired.token"
|
||||
|
||||
# Expired tokens return None from decode
|
||||
with patch('app.dependencies.decode_access_token', return_value=None):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user(credentials)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.headers == {"WWW-Authenticate": "Bearer"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_malformed_token(self):
|
||||
"""Test getting current user with malformed token."""
|
||||
credentials = MagicMock(spec=HTTPAuthorizationCredentials)
|
||||
credentials.credentials = "not.a.jwt"
|
||||
|
||||
with patch('app.dependencies.decode_access_token', return_value=None):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user(credentials)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
class TestGetDbClient:
|
||||
"""Tests for get_db_client dependency."""
|
||||
|
||||
def test_get_db_client_returns_instance(self):
|
||||
"""Test that get_db_client returns DBApiClient instance."""
|
||||
client = get_db_client()
|
||||
|
||||
assert isinstance(client, DBApiClient)
|
||||
|
||||
def test_get_db_client_uses_settings(self):
|
||||
"""Test that get_db_client uses settings for configuration."""
|
||||
with patch('app.dependencies.settings') as mock_settings:
|
||||
mock_settings.DB_API_URL = "http://test-api:9999/api/v1"
|
||||
|
||||
client = get_db_client()
|
||||
|
||||
# Check that client was created with correct URL
|
||||
assert client.api_prefix == "http://test-api:9999/api/v1"
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
"""Tests for main.py endpoints."""
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import patch, MagicMock
|
||||
from app.main import app
|
||||
|
||||
|
||||
class TestMainEndpoints:
|
||||
"""Tests for main application endpoints."""
|
||||
|
||||
def test_health_endpoint(self):
|
||||
"""Test health check endpoint."""
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/health")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_serve_frontend_app_endpoint(self):
|
||||
"""Test /app endpoint serves frontend."""
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
# Get the endpoint function
|
||||
for route in app.routes:
|
||||
if hasattr(route, 'path') and route.path == '/app':
|
||||
result = await route.endpoint()
|
||||
assert isinstance(result, FileResponse)
|
||||
assert result.path == "static/index.html"
|
||||
break
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_root_endpoint(self):
|
||||
"""Test / endpoint serves frontend."""
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
# Get the endpoint function
|
||||
for route in app.routes:
|
||||
if hasattr(route, 'path') and route.path == '/':
|
||||
result = await route.endpoint()
|
||||
assert isinstance(result, FileResponse)
|
||||
assert result.path == "static/index.html"
|
||||
break
|
||||
|
|
@ -0,0 +1,183 @@
|
|||
"""Tests for Pydantic models validation."""
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
from app.models.auth import LoginRequest, UserResponse, LoginResponse
|
||||
from app.models.query import QuestionRequest, BenchQueryRequest, BackendQueryRequest, QueryResponse
|
||||
from app.models.settings import EnvironmentSettings, UserSettingsUpdate
|
||||
|
||||
|
||||
class TestAuthModels:
|
||||
"""Tests for authentication models."""
|
||||
|
||||
def test_login_request_valid(self):
|
||||
"""Test valid LoginRequest."""
|
||||
request = LoginRequest(
|
||||
login="12345678",
|
||||
client_ip="192.168.1.1"
|
||||
)
|
||||
|
||||
assert request.login == "12345678"
|
||||
assert request.client_ip == "192.168.1.1"
|
||||
|
||||
def test_login_request_invalid_format(self):
|
||||
"""Test LoginRequest with invalid login format."""
|
||||
# Not 8 digits
|
||||
with pytest.raises(ValidationError):
|
||||
LoginRequest(login="1234567", client_ip="192.168.1.1")
|
||||
|
||||
# Contains letters
|
||||
with pytest.raises(ValidationError):
|
||||
LoginRequest(login="abcd1234", client_ip="192.168.1.1")
|
||||
|
||||
def test_user_response(self):
|
||||
"""Test UserResponse model."""
|
||||
user = UserResponse(
|
||||
user_id="user-123",
|
||||
login="12345678",
|
||||
last_login_at="2024-01-01T00:00:00Z",
|
||||
created_at="2024-01-01T00:00:00Z"
|
||||
)
|
||||
|
||||
assert user.user_id == "user-123"
|
||||
assert user.login == "12345678"
|
||||
|
||||
def test_login_response(self):
|
||||
"""Test LoginResponse model."""
|
||||
user = UserResponse(
|
||||
user_id="user-123",
|
||||
login="12345678",
|
||||
last_login_at="2024-01-01T00:00:00Z",
|
||||
created_at="2024-01-01T00:00:00Z"
|
||||
)
|
||||
|
||||
response = LoginResponse(
|
||||
access_token="token123",
|
||||
token_type="bearer",
|
||||
user=user
|
||||
)
|
||||
|
||||
assert response.access_token == "token123"
|
||||
assert response.token_type == "bearer"
|
||||
assert response.user.user_id == "user-123"
|
||||
|
||||
|
||||
class TestQueryModels:
|
||||
"""Tests for query models."""
|
||||
|
||||
def test_question_request_valid(self):
|
||||
"""Test valid QuestionRequest."""
|
||||
question = QuestionRequest(
|
||||
body="What is the weather?",
|
||||
with_docs=True
|
||||
)
|
||||
|
||||
assert question.body == "What is the weather?"
|
||||
assert question.with_docs is True
|
||||
|
||||
def test_question_request_default_with_docs(self):
|
||||
"""Test QuestionRequest with default with_docs."""
|
||||
question = QuestionRequest(body="Test question")
|
||||
|
||||
assert question.with_docs is True # Default value
|
||||
|
||||
def test_bench_query_request_valid(self):
|
||||
"""Test valid BenchQueryRequest."""
|
||||
request = BenchQueryRequest(
|
||||
environment="ift",
|
||||
questions=[
|
||||
QuestionRequest(body="Q1", with_docs=True),
|
||||
QuestionRequest(body="Q2", with_docs=False)
|
||||
]
|
||||
)
|
||||
|
||||
assert request.environment == "ift"
|
||||
assert len(request.questions) == 2
|
||||
assert request.questions[0].body == "Q1"
|
||||
|
||||
def test_backend_query_request_valid(self):
|
||||
"""Test valid BackendQueryRequest."""
|
||||
request = BackendQueryRequest(
|
||||
environment="psi",
|
||||
questions=[
|
||||
QuestionRequest(body="Q1", with_docs=True)
|
||||
],
|
||||
reset_session=True
|
||||
)
|
||||
|
||||
assert request.environment == "psi"
|
||||
assert len(request.questions) == 1
|
||||
assert request.reset_session is True
|
||||
|
||||
def test_backend_query_request_default_reset(self):
|
||||
"""Test BackendQueryRequest with default reset_session."""
|
||||
request = BackendQueryRequest(
|
||||
environment="prod",
|
||||
questions=[QuestionRequest(body="Q1")]
|
||||
)
|
||||
|
||||
assert request.reset_session is True # Default value
|
||||
|
||||
def test_query_response(self):
|
||||
"""Test QueryResponse model."""
|
||||
response = QueryResponse(
|
||||
request_id="req-123",
|
||||
timestamp="2024-01-01T00:00:00Z",
|
||||
environment="ift",
|
||||
response={"answers": []}
|
||||
)
|
||||
|
||||
assert response.request_id == "req-123"
|
||||
assert response.environment == "ift"
|
||||
assert isinstance(response.response, dict)
|
||||
|
||||
|
||||
class TestSettingsModels:
|
||||
"""Tests for settings models."""
|
||||
|
||||
def test_environment_settings_valid(self):
|
||||
"""Test valid EnvironmentSettings."""
|
||||
settings = EnvironmentSettings(
|
||||
apiMode="bench",
|
||||
bearerToken="token123",
|
||||
systemPlatform="platform",
|
||||
systemPlatformUser="user",
|
||||
platformUserId="user-123",
|
||||
platformId="platform-123",
|
||||
withClassify=True,
|
||||
resetSessionMode=False
|
||||
)
|
||||
|
||||
assert settings.apiMode == "bench"
|
||||
assert settings.bearerToken == "token123"
|
||||
assert settings.withClassify is True
|
||||
assert settings.resetSessionMode is False
|
||||
|
||||
def test_environment_settings_defaults(self):
|
||||
"""Test EnvironmentSettings with default values."""
|
||||
settings = EnvironmentSettings(apiMode="backend")
|
||||
|
||||
assert settings.apiMode == "backend"
|
||||
assert settings.bearerToken == ""
|
||||
assert settings.withClassify is False
|
||||
assert settings.resetSessionMode is True
|
||||
|
||||
def test_user_settings_update(self):
|
||||
"""Test UserSettingsUpdate model."""
|
||||
update = UserSettingsUpdate(
|
||||
settings={
|
||||
"ift": EnvironmentSettings(apiMode="bench"),
|
||||
"psi": EnvironmentSettings(apiMode="backend")
|
||||
}
|
||||
)
|
||||
|
||||
assert "ift" in update.settings
|
||||
assert "psi" in update.settings
|
||||
assert update.settings["ift"].apiMode == "bench"
|
||||
assert update.settings["psi"].apiMode == "backend"
|
||||
|
||||
def test_user_settings_update_empty(self):
|
||||
"""Test UserSettingsUpdate with empty settings."""
|
||||
update = UserSettingsUpdate(settings={})
|
||||
|
||||
assert update.settings == {}
|
||||
|
|
@ -0,0 +1,553 @@
|
|||
"""Tests for query endpoints and RAG service."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
import httpx
|
||||
from app.services.rag_service import RagService
|
||||
from app.models.query import QuestionRequest
|
||||
|
||||
|
||||
class TestBenchQueryEndpoint:
|
||||
"""Tests for /api/v1/query/bench endpoint."""
|
||||
|
||||
def test_bench_query_success(self, client, mock_db_client, test_settings, mock_bench_response):
|
||||
"""Test successful bench query."""
|
||||
mock_db_client.get_user_settings = AsyncMock(return_value=test_settings)
|
||||
|
||||
with patch('app.api.v1.query.RagService') as MockRagService:
|
||||
mock_rag = AsyncMock()
|
||||
mock_rag.send_bench_query = AsyncMock(return_value=mock_bench_response)
|
||||
mock_rag.close = AsyncMock()
|
||||
MockRagService.return_value = mock_rag
|
||||
|
||||
request_data = {
|
||||
"environment": "ift",
|
||||
"questions": [
|
||||
{"body": "Test question 1", "with_docs": True},
|
||||
{"body": "Test question 2", "with_docs": False}
|
||||
]
|
||||
}
|
||||
|
||||
response = client.post("/api/v1/query/bench", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert "request_id" in data
|
||||
assert "timestamp" in data
|
||||
assert data["environment"] == "ift"
|
||||
assert "response" in data
|
||||
assert data["response"] == mock_bench_response
|
||||
|
||||
mock_rag.send_bench_query.assert_called_once()
|
||||
mock_rag.close.assert_called_once()
|
||||
|
||||
def test_bench_query_invalid_environment(self, client, mock_db_client):
|
||||
"""Test bench query with invalid environment."""
|
||||
request_data = {
|
||||
"environment": "invalid",
|
||||
"questions": [{"body": "Test", "with_docs": True}]
|
||||
}
|
||||
|
||||
response = client.post("/api/v1/query/bench", json=request_data)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "invalid environment" in response.json()["detail"].lower()
|
||||
|
||||
def test_bench_query_wrong_api_mode(self, client, mock_db_client, test_settings):
|
||||
"""Test bench query when environment is configured for backend mode."""
|
||||
# Create new settings with backend apiMode
|
||||
from app.models.settings import EnvironmentSettings, UserSettings
|
||||
|
||||
backend_settings = EnvironmentSettings(
|
||||
apiMode="backend",
|
||||
bearerToken="",
|
||||
systemPlatform="",
|
||||
systemPlatformUser="",
|
||||
platformUserId="",
|
||||
platformId="",
|
||||
withClassify=False,
|
||||
resetSessionMode=True
|
||||
)
|
||||
|
||||
test_settings_backend = UserSettings(
|
||||
user_id="test-user-123",
|
||||
settings={
|
||||
"ift": backend_settings,
|
||||
"psi": test_settings.settings["psi"],
|
||||
"prod": test_settings.settings["prod"]
|
||||
},
|
||||
updated_at="2024-01-01T00:00:00Z"
|
||||
)
|
||||
|
||||
mock_db_client.get_user_settings = AsyncMock(return_value=test_settings_backend)
|
||||
|
||||
request_data = {
|
||||
"environment": "ift",
|
||||
"questions": [{"body": "Test", "with_docs": True}]
|
||||
}
|
||||
|
||||
response = client.post("/api/v1/query/bench", json=request_data)
|
||||
|
||||
# Can be 400 (if caught properly) or 500 (if generic exception)
|
||||
assert response.status_code in [400, 500]
|
||||
if response.status_code == 400:
|
||||
assert "not configured for bench mode" in response.json()["detail"].lower()
|
||||
|
||||
def test_bench_query_rag_backend_error(self, client, mock_db_client, test_settings):
|
||||
"""Test bench query when RAG backend returns error."""
|
||||
mock_db_client.get_user_settings = AsyncMock(return_value=test_settings)
|
||||
|
||||
with patch('app.api.v1.query.RagService') as MockRagService:
|
||||
mock_rag = AsyncMock()
|
||||
error_response = httpx.Response(502, json={"error": "Backend error"})
|
||||
mock_rag.send_bench_query = AsyncMock(
|
||||
side_effect=httpx.HTTPStatusError("Error", request=None, response=error_response)
|
||||
)
|
||||
mock_rag.close = AsyncMock()
|
||||
MockRagService.return_value = mock_rag
|
||||
|
||||
request_data = {
|
||||
"environment": "ift",
|
||||
"questions": [{"body": "Test", "with_docs": True}]
|
||||
}
|
||||
|
||||
response = client.post("/api/v1/query/bench", json=request_data)
|
||||
|
||||
assert response.status_code == 502
|
||||
mock_rag.close.assert_called_once()
|
||||
|
||||
def test_bench_query_settings_not_found(self, client, mock_db_client, test_settings):
|
||||
"""Test bench query when environment settings not found."""
|
||||
# Remove ift settings
|
||||
from app.models.settings import UserSettings
|
||||
settings_without_ift = UserSettings(
|
||||
user_id="test-user-123",
|
||||
settings={
|
||||
"psi": test_settings.settings["psi"],
|
||||
"prod": test_settings.settings["prod"]
|
||||
},
|
||||
updated_at="2024-01-01T00:00:00Z"
|
||||
)
|
||||
mock_db_client.get_user_settings = AsyncMock(return_value=settings_without_ift)
|
||||
|
||||
request_data = {
|
||||
"environment": "ift",
|
||||
"questions": [{"body": "Test", "with_docs": True}]
|
||||
}
|
||||
|
||||
response = client.post("/api/v1/query/bench", json=request_data)
|
||||
|
||||
# HTTPException inside try/except is caught and returns 500
|
||||
assert response.status_code == 500
|
||||
|
||||
|
||||
class TestBackendQueryEndpoint:
|
||||
"""Tests for /api/v1/query/backend endpoint."""
|
||||
|
||||
def test_backend_query_success(self, client, mock_db_client, test_settings, mock_backend_response):
|
||||
"""Test successful backend query."""
|
||||
# Set apiMode to backend
|
||||
test_settings.settings["ift"].apiMode = "backend"
|
||||
mock_db_client.get_user_settings = AsyncMock(return_value=test_settings)
|
||||
|
||||
with patch('app.api.v1.query.RagService') as MockRagService:
|
||||
mock_rag = AsyncMock()
|
||||
mock_rag.send_backend_query = AsyncMock(return_value=[mock_backend_response])
|
||||
mock_rag.close = AsyncMock()
|
||||
MockRagService.return_value = mock_rag
|
||||
|
||||
request_data = {
|
||||
"environment": "ift",
|
||||
"questions": [
|
||||
{"body": "Test question", "with_docs": True}
|
||||
],
|
||||
"reset_session": True
|
||||
}
|
||||
|
||||
response = client.post("/api/v1/query/backend", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert "request_id" in data
|
||||
assert "timestamp" in data
|
||||
assert data["environment"] == "ift"
|
||||
assert "response" in data
|
||||
assert "answers" in data["response"]
|
||||
assert data["response"]["answers"] == [mock_backend_response]
|
||||
|
||||
mock_rag.send_backend_query.assert_called_once()
|
||||
call_kwargs = mock_rag.send_backend_query.call_args[1]
|
||||
assert call_kwargs["reset_session"] is True
|
||||
|
||||
def test_backend_query_wrong_api_mode(self, client, mock_db_client, test_settings):
|
||||
"""Test backend query when environment is configured for bench mode."""
|
||||
# test_settings already has bench mode, so this should fail
|
||||
mock_db_client.get_user_settings = AsyncMock(return_value=test_settings)
|
||||
|
||||
request_data = {
|
||||
"environment": "ift",
|
||||
"questions": [{"body": "Test", "with_docs": True}],
|
||||
"reset_session": True
|
||||
}
|
||||
|
||||
response = client.post("/api/v1/query/backend", json=request_data)
|
||||
|
||||
# Can be 400 (if caught properly) or 500 (if generic exception)
|
||||
assert response.status_code in [400, 500]
|
||||
if response.status_code == 400:
|
||||
assert "not configured for backend mode" in response.json()["detail"].lower()
|
||||
|
||||
def test_backend_query_invalid_environment(self, client, mock_db_client):
|
||||
"""Test backend query with invalid environment."""
|
||||
request_data = {
|
||||
"environment": "invalid",
|
||||
"questions": [{"body": "Test", "with_docs": True}],
|
||||
"reset_session": True
|
||||
}
|
||||
|
||||
response = client.post("/api/v1/query/backend", json=request_data)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "invalid environment" in response.json()["detail"].lower()
|
||||
|
||||
def test_backend_query_settings_not_found(self, client, mock_db_client, test_settings):
|
||||
"""Test backend query when environment settings not found."""
|
||||
# Set apiMode to backend for ift but remove psi settings
|
||||
from app.models.settings import UserSettings
|
||||
test_settings.settings["ift"].apiMode = "backend"
|
||||
settings_without_psi = UserSettings(
|
||||
user_id="test-user-123",
|
||||
settings={
|
||||
"ift": test_settings.settings["ift"],
|
||||
"prod": test_settings.settings["prod"]
|
||||
},
|
||||
updated_at="2024-01-01T00:00:00Z"
|
||||
)
|
||||
mock_db_client.get_user_settings = AsyncMock(return_value=settings_without_psi)
|
||||
|
||||
request_data = {
|
||||
"environment": "psi",
|
||||
"questions": [{"body": "Test", "with_docs": True}],
|
||||
"reset_session": True
|
||||
}
|
||||
|
||||
response = client.post("/api/v1/query/backend", json=request_data)
|
||||
|
||||
# HTTPException inside try/except is caught and returns 500
|
||||
assert response.status_code == 500
|
||||
|
||||
|
||||
class TestRagService:
|
||||
"""Tests for RagService."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_bench_query_success(self, mock_httpx_client, mock_bench_response):
|
||||
"""Test successful bench query via RagService."""
|
||||
# Configure mock response
|
||||
mock_httpx_client.post.return_value.json.return_value = mock_bench_response
|
||||
|
||||
with patch('app.services.rag_service.httpx.AsyncClient', return_value=mock_httpx_client):
|
||||
rag_service = RagService()
|
||||
|
||||
questions = [
|
||||
QuestionRequest(body="Question 1", with_docs=True),
|
||||
QuestionRequest(body="Question 2", with_docs=False)
|
||||
]
|
||||
|
||||
user_settings = {
|
||||
"bearerToken": "test-token",
|
||||
"systemPlatform": "test-platform"
|
||||
}
|
||||
|
||||
result = await rag_service.send_bench_query(
|
||||
environment="ift",
|
||||
questions=questions,
|
||||
user_settings=user_settings,
|
||||
request_id="test-request-123"
|
||||
)
|
||||
|
||||
assert result == mock_bench_response
|
||||
mock_httpx_client.post.assert_called_once()
|
||||
|
||||
# Verify headers
|
||||
call_kwargs = mock_httpx_client.post.call_args[1]
|
||||
headers = call_kwargs["headers"]
|
||||
assert headers["Request-Id"] == "test-request-123"
|
||||
assert headers["Authorization"] == "Bearer test-token"
|
||||
assert headers["System-Platform"] == "test-platform"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_backend_query_success(self, mock_httpx_client, mock_backend_response):
|
||||
"""Test successful backend query via RagService."""
|
||||
# Configure mock response
|
||||
mock_httpx_client.post.return_value.json.return_value = mock_backend_response
|
||||
|
||||
with patch('app.services.rag_service.httpx.AsyncClient', return_value=mock_httpx_client):
|
||||
rag_service = RagService()
|
||||
|
||||
questions = [
|
||||
QuestionRequest(body="Question 1", with_docs=True)
|
||||
]
|
||||
|
||||
user_settings = {
|
||||
"bearerToken": "test-token",
|
||||
"platformUserId": "user-123",
|
||||
"platformId": "platform-123",
|
||||
"withClassify": True,
|
||||
"resetSessionMode": True
|
||||
}
|
||||
|
||||
result = await rag_service.send_backend_query(
|
||||
environment="ift",
|
||||
questions=questions,
|
||||
user_settings=user_settings,
|
||||
reset_session=True
|
||||
)
|
||||
|
||||
assert result == [mock_backend_response]
|
||||
# 2 calls: ask + reset
|
||||
assert mock_httpx_client.post.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_backend_query_no_reset(self, mock_httpx_client, mock_backend_response):
|
||||
"""Test backend query without session reset."""
|
||||
mock_httpx_client.post.return_value.json.return_value = mock_backend_response
|
||||
|
||||
with patch('app.services.rag_service.httpx.AsyncClient', return_value=mock_httpx_client):
|
||||
rag_service = RagService()
|
||||
|
||||
questions = [QuestionRequest(body="Question", with_docs=True)]
|
||||
user_settings = {"resetSessionMode": False}
|
||||
|
||||
result = await rag_service.send_backend_query(
|
||||
environment="ift",
|
||||
questions=questions,
|
||||
user_settings=user_settings,
|
||||
reset_session=False
|
||||
)
|
||||
|
||||
assert result == [mock_backend_response]
|
||||
# Only 1 call: ask (no reset)
|
||||
assert mock_httpx_client.post.call_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_bench_headers(self):
|
||||
"""Test building headers for bench mode."""
|
||||
with patch('app.services.rag_service.httpx.AsyncClient'):
|
||||
rag_service = RagService()
|
||||
|
||||
user_settings = {
|
||||
"bearerToken": "my-token",
|
||||
"systemPlatform": "my-platform"
|
||||
}
|
||||
|
||||
headers = rag_service._build_bench_headers("ift", user_settings, "req-123")
|
||||
|
||||
assert headers["Request-Id"] == "req-123"
|
||||
assert headers["System-Id"] == "brief-bench-ift"
|
||||
assert headers["Authorization"] == "Bearer my-token"
|
||||
assert headers["System-Platform"] == "my-platform"
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_backend_headers(self):
|
||||
"""Test building headers for backend mode."""
|
||||
with patch('app.services.rag_service.httpx.AsyncClient'):
|
||||
rag_service = RagService()
|
||||
|
||||
user_settings = {
|
||||
"bearerToken": "my-token",
|
||||
"platformUserId": "user-456",
|
||||
"platformId": "platform-789"
|
||||
}
|
||||
|
||||
headers = rag_service._build_backend_headers(user_settings)
|
||||
|
||||
assert headers["Authorization"] == "Bearer my-token"
|
||||
assert headers["Platform-User-Id"] == "user-456"
|
||||
assert headers["Platform-Id"] == "platform-789"
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_client_with_mtls(self):
|
||||
"""Test creating HTTP client with mTLS configuration."""
|
||||
with patch('app.services.rag_service.settings') as mock_settings:
|
||||
# Configure mTLS settings
|
||||
mock_settings.IFT_RAG_CERT_CERT = "/path/to/client.crt"
|
||||
mock_settings.IFT_RAG_CERT_KEY = "/path/to/client.key"
|
||||
mock_settings.IFT_RAG_CERT_CA = "/path/to/ca.crt"
|
||||
mock_settings.PSI_RAG_CERT_CERT = ""
|
||||
mock_settings.PSI_RAG_CERT_KEY = ""
|
||||
mock_settings.PSI_RAG_CERT_CA = ""
|
||||
mock_settings.PROD_RAG_CERT_CERT = ""
|
||||
mock_settings.PROD_RAG_CERT_KEY = ""
|
||||
mock_settings.PROD_RAG_CERT_CA = ""
|
||||
|
||||
with patch('app.services.rag_service.httpx.AsyncClient') as MockAsyncClient:
|
||||
service = RagService()
|
||||
|
||||
# Verify AsyncClient was called 3 times (one per environment)
|
||||
assert MockAsyncClient.call_count == 3
|
||||
|
||||
# Check the first call (ift) had mTLS config
|
||||
first_call_kwargs = MockAsyncClient.call_args_list[0][1]
|
||||
assert first_call_kwargs["cert"] == ("/path/to/client.crt", "/path/to/client.key")
|
||||
assert first_call_kwargs["verify"] == "/path/to/ca.crt"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_client_without_mtls(self):
|
||||
"""Test creating HTTP client without mTLS."""
|
||||
with patch('app.services.rag_service.settings') as mock_settings:
|
||||
# No mTLS certs for any environment
|
||||
mock_settings.IFT_RAG_CERT_CERT = ""
|
||||
mock_settings.IFT_RAG_CERT_KEY = ""
|
||||
mock_settings.IFT_RAG_CERT_CA = ""
|
||||
mock_settings.PSI_RAG_CERT_CERT = ""
|
||||
mock_settings.PSI_RAG_CERT_KEY = ""
|
||||
mock_settings.PSI_RAG_CERT_CA = ""
|
||||
mock_settings.PROD_RAG_CERT_CERT = ""
|
||||
mock_settings.PROD_RAG_CERT_KEY = ""
|
||||
mock_settings.PROD_RAG_CERT_CA = ""
|
||||
|
||||
with patch('app.services.rag_service.httpx.AsyncClient') as MockAsyncClient:
|
||||
service = RagService()
|
||||
|
||||
# Verify AsyncClient was called 3 times
|
||||
assert MockAsyncClient.call_count == 3
|
||||
|
||||
# Check all calls had no mTLS
|
||||
for call in MockAsyncClient.call_args_list:
|
||||
call_kwargs = call[1]
|
||||
assert call_kwargs["cert"] is None
|
||||
assert call_kwargs["verify"] is True # Default verify
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_bench_query_http_error(self, mock_httpx_client):
|
||||
"""Test bench query with HTTP error."""
|
||||
# Configure mock to raise HTTP error
|
||||
error_response = MagicMock()
|
||||
error_response.status_code = 500
|
||||
error_response.text = "Internal Server Error"
|
||||
|
||||
mock_httpx_client.post.side_effect = httpx.HTTPStatusError(
|
||||
"Server error",
|
||||
request=None,
|
||||
response=error_response
|
||||
)
|
||||
|
||||
with patch('app.services.rag_service.httpx.AsyncClient', return_value=mock_httpx_client):
|
||||
rag_service = RagService()
|
||||
|
||||
questions = [QuestionRequest(body="Test", with_docs=True)]
|
||||
user_settings = {}
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
await rag_service.send_bench_query(
|
||||
environment="ift",
|
||||
questions=questions,
|
||||
user_settings=user_settings
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_backend_query_http_error(self, mock_httpx_client):
|
||||
"""Test backend query with HTTP error on ask endpoint."""
|
||||
error_response = MagicMock()
|
||||
error_response.status_code = 503
|
||||
error_response.text = "Service Unavailable"
|
||||
|
||||
mock_httpx_client.post.side_effect = httpx.HTTPStatusError(
|
||||
"Service error",
|
||||
request=None,
|
||||
response=error_response
|
||||
)
|
||||
|
||||
with patch('app.services.rag_service.httpx.AsyncClient', return_value=mock_httpx_client):
|
||||
rag_service = RagService()
|
||||
|
||||
questions = [QuestionRequest(body="Test", with_docs=True)]
|
||||
user_settings = {"resetSessionMode": False}
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
await rag_service.send_backend_query(
|
||||
environment="ift",
|
||||
questions=questions,
|
||||
user_settings=user_settings,
|
||||
reset_session=False
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_base_url(self):
|
||||
"""Test building base URL for environment."""
|
||||
with patch('app.services.rag_service.httpx.AsyncClient'):
|
||||
with patch('app.services.rag_service.settings') as mock_settings:
|
||||
mock_settings.IFT_RAG_HOST = "rag-ift.example.com"
|
||||
mock_settings.IFT_RAG_PORT = 8443
|
||||
|
||||
service = RagService()
|
||||
url = service._get_base_url("ift")
|
||||
|
||||
assert url == "https://rag-ift.example.com:8443"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_clients(self, mock_httpx_client):
|
||||
"""Test closing all HTTP clients."""
|
||||
with patch('app.services.rag_service.httpx.AsyncClient', return_value=mock_httpx_client):
|
||||
service = RagService()
|
||||
|
||||
await service.close()
|
||||
|
||||
# Should close all 3 clients (ift, psi, prod)
|
||||
assert mock_httpx_client.aclose.call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_context_manager(self, mock_httpx_client):
|
||||
"""Test using RagService as async context manager."""
|
||||
with patch('app.services.rag_service.httpx.AsyncClient', return_value=mock_httpx_client):
|
||||
async with RagService() as service:
|
||||
assert service is not None
|
||||
|
||||
# Should close all clients on exit
|
||||
assert mock_httpx_client.aclose.call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_bench_query_general_exception(self, mock_httpx_client):
|
||||
"""Test bench query with general exception (not HTTP error)."""
|
||||
mock_httpx_client.post.side_effect = Exception("Network error")
|
||||
|
||||
with patch('app.services.rag_service.httpx.AsyncClient', return_value=mock_httpx_client):
|
||||
rag_service = RagService()
|
||||
|
||||
questions = [QuestionRequest(body="Test", with_docs=True)]
|
||||
user_settings = {}
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await rag_service.send_bench_query(
|
||||
environment="ift",
|
||||
questions=questions,
|
||||
user_settings=user_settings
|
||||
)
|
||||
|
||||
assert "Network error" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_backend_query_general_exception(self, mock_httpx_client):
|
||||
"""Test backend query with general exception (not HTTP error)."""
|
||||
mock_httpx_client.post.side_effect = Exception("Connection timeout")
|
||||
|
||||
with patch('app.services.rag_service.httpx.AsyncClient', return_value=mock_httpx_client):
|
||||
rag_service = RagService()
|
||||
|
||||
questions = [QuestionRequest(body="Test", with_docs=True)]
|
||||
user_settings = {"resetSessionMode": False}
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await rag_service.send_backend_query(
|
||||
environment="ift",
|
||||
questions=questions,
|
||||
user_settings=user_settings,
|
||||
reset_session=False
|
||||
)
|
||||
|
||||
assert "Connection timeout" in str(exc_info.value)
|
||||
|
|
@ -0,0 +1,72 @@
|
|||
"""Tests for JWT security utilities."""
|
||||
|
||||
import pytest
|
||||
from datetime import timedelta
|
||||
from app.utils.security import create_access_token, decode_access_token
|
||||
|
||||
|
||||
class TestJWTSecurity:
|
||||
"""Tests for JWT token creation and validation."""
|
||||
|
||||
def test_create_access_token(self):
|
||||
"""Test creating JWT access token."""
|
||||
data = {
|
||||
"user_id": "test-user-123",
|
||||
"login": "12345678"
|
||||
}
|
||||
|
||||
token = create_access_token(data)
|
||||
|
||||
assert token is not None
|
||||
assert isinstance(token, str)
|
||||
assert len(token) > 0
|
||||
|
||||
def test_decode_access_token(self):
|
||||
"""Test decoding valid JWT token."""
|
||||
data = {
|
||||
"user_id": "test-user-123",
|
||||
"login": "12345678"
|
||||
}
|
||||
|
||||
token = create_access_token(data)
|
||||
payload = decode_access_token(token)
|
||||
|
||||
assert payload is not None
|
||||
assert payload["user_id"] == "test-user-123"
|
||||
assert payload["login"] == "12345678"
|
||||
assert "exp" in payload
|
||||
|
||||
def test_decode_invalid_token(self):
|
||||
"""Test decoding invalid token returns None."""
|
||||
payload = decode_access_token("invalid.token.here")
|
||||
|
||||
assert payload is None
|
||||
|
||||
def test_decode_expired_token(self):
|
||||
"""Test decoding expired token returns None."""
|
||||
data = {
|
||||
"user_id": "test-user-123",
|
||||
"login": "12345678"
|
||||
}
|
||||
|
||||
# Create token that expires immediately
|
||||
token = create_access_token(data, expires_delta=timedelta(seconds=-1))
|
||||
payload = decode_access_token(token)
|
||||
|
||||
assert payload is None
|
||||
|
||||
def test_token_contains_all_data(self):
|
||||
"""Test that token contains all provided data."""
|
||||
data = {
|
||||
"user_id": "test-user-123",
|
||||
"login": "12345678",
|
||||
"custom_field": "custom_value"
|
||||
}
|
||||
|
||||
token = create_access_token(data)
|
||||
payload = decode_access_token(token)
|
||||
|
||||
assert payload["user_id"] == "test-user-123"
|
||||
assert payload["login"] == "12345678"
|
||||
assert payload["custom_field"] == "custom_value"
|
||||
assert "exp" in payload
|
||||
|
|
@ -0,0 +1,160 @@
|
|||
"""Tests for settings endpoints."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
import httpx
|
||||
|
||||
|
||||
class TestSettingsEndpoints:
|
||||
"""Tests for /api/v1/settings endpoints."""
|
||||
|
||||
def test_get_settings_success(self, client, mock_db_client, test_settings):
|
||||
"""Test getting user settings successfully."""
|
||||
mock_db_client.get_user_settings = AsyncMock(return_value=test_settings)
|
||||
|
||||
response = client.get("/api/v1/settings")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert data["user_id"] == "test-user-123"
|
||||
assert "settings" in data
|
||||
assert "ift" in data["settings"]
|
||||
assert "psi" in data["settings"]
|
||||
assert "prod" in data["settings"]
|
||||
assert data["settings"]["ift"]["apiMode"] == "bench"
|
||||
|
||||
mock_db_client.get_user_settings.assert_called_once_with("test-user-123")
|
||||
|
||||
def test_get_settings_not_found(self, client, mock_db_client):
|
||||
"""Test getting settings when user not found."""
|
||||
# Mock 404 from DB API
|
||||
error_response = httpx.Response(404, json={"detail": "Not found"})
|
||||
mock_db_client.get_user_settings = AsyncMock(
|
||||
side_effect=httpx.HTTPStatusError("Not found", request=None, response=error_response)
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/settings")
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "not found" in response.json()["detail"].lower()
|
||||
|
||||
def test_get_settings_unauthenticated(self, unauthenticated_client):
|
||||
"""Test getting settings without authentication."""
|
||||
response = unauthenticated_client.get("/api/v1/settings")
|
||||
|
||||
assert response.status_code == 401 # HTTPBearer returns 401
|
||||
|
||||
def test_update_settings_success(self, client, mock_db_client, test_settings):
|
||||
"""Test updating user settings successfully."""
|
||||
mock_db_client.update_user_settings = AsyncMock(return_value=test_settings)
|
||||
|
||||
update_data = {
|
||||
"settings": {
|
||||
"ift": {
|
||||
"apiMode": "backend",
|
||||
"bearerToken": "new-token",
|
||||
"resetSessionMode": False
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
response = client.put("/api/v1/settings", json=update_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert data["user_id"] == "test-user-123"
|
||||
mock_db_client.update_user_settings.assert_called_once()
|
||||
|
||||
def test_update_settings_invalid_data(self, client, mock_db_client):
|
||||
"""Test updating settings with invalid data."""
|
||||
error_response = httpx.Response(400, json={"detail": "Invalid format"})
|
||||
mock_db_client.update_user_settings = AsyncMock(
|
||||
side_effect=httpx.HTTPStatusError("Bad request", request=None, response=error_response)
|
||||
)
|
||||
|
||||
update_data = {
|
||||
"settings": {
|
||||
"invalid_env": {"apiMode": "invalid"}
|
||||
}
|
||||
}
|
||||
|
||||
response = client.put("/api/v1/settings", json=update_data)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_update_settings_db_api_error(self, client, mock_db_client):
|
||||
"""Test update settings when DB API fails."""
|
||||
mock_db_client.update_user_settings = AsyncMock(
|
||||
side_effect=Exception("DB error")
|
||||
)
|
||||
|
||||
update_data = {
|
||||
"settings": {
|
||||
"ift": {"apiMode": "bench"}
|
||||
}
|
||||
}
|
||||
|
||||
response = client.put("/api/v1/settings", json=update_data)
|
||||
|
||||
assert response.status_code == 500
|
||||
|
||||
def test_get_settings_db_api_502_error(self, client, mock_db_client):
|
||||
"""Test get settings when DB API returns 502."""
|
||||
error_response = httpx.Response(503, json={"detail": "Service unavailable"})
|
||||
mock_db_client.get_user_settings = AsyncMock(
|
||||
side_effect=httpx.HTTPStatusError("Service error", request=None, response=error_response)
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/settings")
|
||||
|
||||
assert response.status_code == 502
|
||||
assert "failed to retrieve settings" in response.json()["detail"].lower()
|
||||
|
||||
def test_get_settings_unexpected_error(self, client, mock_db_client):
|
||||
"""Test get settings with unexpected error."""
|
||||
mock_db_client.get_user_settings = AsyncMock(
|
||||
side_effect=Exception("Unexpected error")
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/settings")
|
||||
|
||||
assert response.status_code == 500
|
||||
assert "internal server error" in response.json()["detail"].lower()
|
||||
|
||||
def test_update_settings_user_not_found(self, client, mock_db_client):
|
||||
"""Test update settings when user not found."""
|
||||
error_response = httpx.Response(404, json={"detail": "User not found"})
|
||||
mock_db_client.update_user_settings = AsyncMock(
|
||||
side_effect=httpx.HTTPStatusError("Not found", request=None, response=error_response)
|
||||
)
|
||||
|
||||
update_data = {
|
||||
"settings": {
|
||||
"ift": {"apiMode": "bench"}
|
||||
}
|
||||
}
|
||||
|
||||
response = client.put("/api/v1/settings", json=update_data)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "user not found" in response.json()["detail"].lower()
|
||||
|
||||
def test_update_settings_db_api_502_error(self, client, mock_db_client):
|
||||
"""Test update settings when DB API returns 502."""
|
||||
error_response = httpx.Response(503, json={"detail": "Service unavailable"})
|
||||
mock_db_client.update_user_settings = AsyncMock(
|
||||
side_effect=httpx.HTTPStatusError("Service error", request=None, response=error_response)
|
||||
)
|
||||
|
||||
update_data = {
|
||||
"settings": {
|
||||
"ift": {"apiMode": "bench"}
|
||||
}
|
||||
}
|
||||
|
||||
response = client.put("/api/v1/settings", json=update_data)
|
||||
|
||||
assert response.status_code == 502
|
||||
assert "failed to update settings" in response.json()["detail"].lower()
|
||||
Loading…
Reference in New Issue