554 lines
21 KiB
Python
554 lines
21 KiB
Python
"""Tests for query endpoints and RAG service."""
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock, patch, MagicMock
|
|
import httpx
|
|
from app.services.rag_service import RagService
|
|
from app.models.query import QuestionRequest
|
|
|
|
|
|
class TestBenchQueryEndpoint:
    """Tests for /api/v1/query/bench endpoint."""

    def test_bench_query_success(self, client, mock_db_client, test_settings, mock_bench_response):
        """Test successful bench query.

        ``RagService`` is patched in the endpoint module, so this verifies the
        endpoint's response envelope and that the service is used and closed.
        """
        mock_db_client.get_user_settings = AsyncMock(return_value=test_settings)

        with patch('app.api.v1.query.RagService') as MockRagService:
            mock_rag = AsyncMock()
            mock_rag.send_bench_query = AsyncMock(return_value=mock_bench_response)
            mock_rag.close = AsyncMock()
            MockRagService.return_value = mock_rag

            request_data = {
                "environment": "ift",
                "questions": [
                    {"body": "Test question 1", "with_docs": True},
                    {"body": "Test question 2", "with_docs": False},
                ],
            }

            response = client.post("/api/v1/query/bench", json=request_data)

            assert response.status_code == 200
            data = response.json()

            assert "request_id" in data
            assert "timestamp" in data
            assert data["environment"] == "ift"
            assert "response" in data
            assert data["response"] == mock_bench_response

            mock_rag.send_bench_query.assert_called_once()
            mock_rag.close.assert_called_once()

    def test_bench_query_invalid_environment(self, client, mock_db_client):
        """Test bench query with invalid environment."""
        request_data = {
            "environment": "invalid",
            "questions": [{"body": "Test", "with_docs": True}],
        }

        response = client.post("/api/v1/query/bench", json=request_data)

        assert response.status_code == 400
        assert "invalid environment" in response.json()["detail"].lower()

    def test_bench_query_wrong_api_mode(self, client, mock_db_client, test_settings):
        """Test bench query when environment is configured for backend mode."""
        from app.models.settings import EnvironmentSettings, UserSettings

        backend_settings = EnvironmentSettings(
            apiMode="backend",
            bearerToken="",
            systemPlatform="",
            systemPlatformUser="",
            platformUserId="",
            platformId="",
            withClassify=False,
            resetSessionMode=True,
        )

        # Only "ift" is switched to backend mode; the other environments are
        # taken unchanged from the fixture.
        test_settings_backend = UserSettings(
            user_id="test-user-123",
            settings={
                "ift": backend_settings,
                "psi": test_settings.settings["psi"],
                "prod": test_settings.settings["prod"],
            },
            updated_at="2024-01-01T00:00:00Z",
        )

        mock_db_client.get_user_settings = AsyncMock(return_value=test_settings_backend)

        request_data = {
            "environment": "ift",
            "questions": [{"body": "Test", "with_docs": True}],
        }

        response = client.post("/api/v1/query/bench", json=request_data)

        # NOTE(review): accepts 400 or 500 — presumably the endpoint's error
        # handling can rewrap the mode-mismatch error as a 500. Pin a single
        # status once the handler's behavior is confirmed.
        assert response.status_code in [400, 500]
        if response.status_code == 400:
            assert "not configured for bench mode" in response.json()["detail"].lower()

    def test_bench_query_rag_backend_error(self, client, mock_db_client, test_settings):
        """Test bench query when RAG backend returns error."""
        mock_db_client.get_user_settings = AsyncMock(return_value=test_settings)

        with patch('app.api.v1.query.RagService') as MockRagService:
            mock_rag = AsyncMock()

            # Attach a real request so the response and exception behave like
            # ones produced by httpx itself: httpx raises RuntimeError when
            # ``Response.request`` is read on a response built without one.
            fake_request = httpx.Request("POST", "https://rag.test/bench")
            error_response = httpx.Response(
                502, json={"error": "Backend error"}, request=fake_request
            )
            mock_rag.send_bench_query = AsyncMock(
                side_effect=httpx.HTTPStatusError(
                    "Error", request=fake_request, response=error_response
                )
            )
            mock_rag.close = AsyncMock()
            MockRagService.return_value = mock_rag

            request_data = {
                "environment": "ift",
                "questions": [{"body": "Test", "with_docs": True}],
            }

            response = client.post("/api/v1/query/bench", json=request_data)

            # The upstream status is expected to be propagated, and the service
            # must still be closed on the error path.
            assert response.status_code == 502
            mock_rag.close.assert_called_once()

    def test_bench_query_settings_not_found(self, client, mock_db_client, test_settings):
        """Test bench query when environment settings not found."""
        from app.models.settings import UserSettings

        # Stored settings deliberately omit the requested "ift" environment.
        settings_without_ift = UserSettings(
            user_id="test-user-123",
            settings={
                "psi": test_settings.settings["psi"],
                "prod": test_settings.settings["prod"],
            },
            updated_at="2024-01-01T00:00:00Z",
        )
        mock_db_client.get_user_settings = AsyncMock(return_value=settings_without_ift)

        request_data = {
            "environment": "ift",
            "questions": [{"body": "Test", "with_docs": True}],
        }

        response = client.post("/api/v1/query/bench", json=request_data)

        assert response.status_code == 500
|
|
|
|
|
|
class TestBackendQueryEndpoint:
    """Tests for /api/v1/query/backend endpoint."""

    def test_backend_query_success(self, client, mock_db_client, test_settings, mock_backend_response):
        """Test successful backend query.

        The ``test_settings`` fixture is deep-copied before "ift" is switched to
        backend mode, so the switch cannot leak into other tests that share the
        fixture object.
        """
        import copy

        backend_user_settings = copy.deepcopy(test_settings)
        backend_user_settings.settings["ift"].apiMode = "backend"
        mock_db_client.get_user_settings = AsyncMock(return_value=backend_user_settings)

        with patch('app.api.v1.query.RagService') as MockRagService:
            mock_rag = AsyncMock()
            mock_rag.send_backend_query = AsyncMock(return_value=[mock_backend_response])
            mock_rag.close = AsyncMock()
            MockRagService.return_value = mock_rag

            request_data = {
                "environment": "ift",
                "questions": [
                    {"body": "Test question", "with_docs": True},
                ],
                "reset_session": True,
            }

            response = client.post("/api/v1/query/backend", json=request_data)

            assert response.status_code == 200
            data = response.json()

            assert "request_id" in data
            assert "timestamp" in data
            assert data["environment"] == "ift"
            assert "response" in data
            assert "answers" in data["response"]
            assert data["response"]["answers"] == [mock_backend_response]

            # The request's reset_session flag must be forwarded as a keyword.
            mock_rag.send_backend_query.assert_called_once()
            call_kwargs = mock_rag.send_backend_query.call_args.kwargs
            assert call_kwargs["reset_session"] is True

    def test_backend_query_wrong_api_mode(self, client, mock_db_client, test_settings):
        """Test backend query when environment is configured for bench mode.

        Relies on the unmodified ``test_settings`` fixture configuring "ift" for
        bench mode (the bench success test uses it that way).
        """
        mock_db_client.get_user_settings = AsyncMock(return_value=test_settings)

        request_data = {
            "environment": "ift",
            "questions": [{"body": "Test", "with_docs": True}],
            "reset_session": True,
        }

        response = client.post("/api/v1/query/backend", json=request_data)

        # NOTE(review): accepts 400 or 500 — presumably the endpoint's error
        # handling can rewrap the mode-mismatch error as a 500. Pin a single
        # status once the handler's behavior is confirmed.
        assert response.status_code in [400, 500]
        if response.status_code == 400:
            assert "not configured for backend mode" in response.json()["detail"].lower()

    def test_backend_query_invalid_environment(self, client, mock_db_client):
        """Test backend query with invalid environment."""
        request_data = {
            "environment": "invalid",
            "questions": [{"body": "Test", "with_docs": True}],
            "reset_session": True,
        }

        response = client.post("/api/v1/query/backend", json=request_data)

        assert response.status_code == 400
        assert "invalid environment" in response.json()["detail"].lower()

    def test_backend_query_settings_not_found(self, client, mock_db_client, test_settings):
        """Test backend query when environment settings not found."""
        import copy

        from app.models.settings import UserSettings

        # Switch a private copy of "ift" to backend mode instead of mutating
        # the shared fixture object.
        ift_backend = copy.deepcopy(test_settings.settings["ift"])
        ift_backend.apiMode = "backend"

        # Stored settings deliberately omit the requested "psi" environment.
        settings_without_psi = UserSettings(
            user_id="test-user-123",
            settings={
                "ift": ift_backend,
                "prod": test_settings.settings["prod"],
            },
            updated_at="2024-01-01T00:00:00Z",
        )
        mock_db_client.get_user_settings = AsyncMock(return_value=settings_without_psi)

        request_data = {
            "environment": "psi",
            "questions": [{"body": "Test", "with_docs": True}],
            "reset_session": True,
        }

        response = client.post("/api/v1/query/backend", json=request_data)

        assert response.status_code == 500
|
|
|
|
|
|
class TestRagService:
    """Tests for RagService.

    ``httpx.AsyncClient`` is patched in ``app.services.rag_service`` so no real
    HTTP clients are created; every service call runs inside that patch.
    """

    @pytest.mark.asyncio
    async def test_send_bench_query_success(self, mock_httpx_client, mock_bench_response):
        """Test successful bench query via RagService."""
        mock_httpx_client.post.return_value.json.return_value = mock_bench_response

        with patch('app.services.rag_service.httpx.AsyncClient', return_value=mock_httpx_client):
            rag_service = RagService()

            questions = [
                QuestionRequest(body="Question 1", with_docs=True),
                QuestionRequest(body="Question 2", with_docs=False),
            ]
            user_settings = {
                "bearerToken": "test-token",
                "systemPlatform": "test-platform",
            }

            result = await rag_service.send_bench_query(
                environment="ift",
                questions=questions,
                user_settings=user_settings,
                request_id="test-request-123",
            )

            assert result == mock_bench_response
            mock_httpx_client.post.assert_called_once()

            # The single POST must carry the request id and the user's
            # credentials as headers.
            headers = mock_httpx_client.post.call_args.kwargs["headers"]
            assert headers["Request-Id"] == "test-request-123"
            assert headers["Authorization"] == "Bearer test-token"
            assert headers["System-Platform"] == "test-platform"

    @pytest.mark.asyncio
    async def test_send_backend_query_success(self, mock_httpx_client, mock_backend_response):
        """Test successful backend query via RagService."""
        mock_httpx_client.post.return_value.json.return_value = mock_backend_response

        with patch('app.services.rag_service.httpx.AsyncClient', return_value=mock_httpx_client):
            rag_service = RagService()

            questions = [
                QuestionRequest(body="Question 1", with_docs=True),
            ]
            user_settings = {
                "bearerToken": "test-token",
                "platformUserId": "user-123",
                "platformId": "platform-123",
                "withClassify": True,
                "resetSessionMode": True,
            }

            result = await rag_service.send_backend_query(
                environment="ift",
                questions=questions,
                user_settings=user_settings,
                reset_session=True,
            )

            assert result == [mock_backend_response]
            # One POST for the question plus one extra for the session reset
            # (compare test_send_backend_query_no_reset, which expects one).
            assert mock_httpx_client.post.call_count == 2

    @pytest.mark.asyncio
    async def test_send_backend_query_no_reset(self, mock_httpx_client, mock_backend_response):
        """Test backend query without session reset."""
        mock_httpx_client.post.return_value.json.return_value = mock_backend_response

        with patch('app.services.rag_service.httpx.AsyncClient', return_value=mock_httpx_client):
            rag_service = RagService()

            questions = [QuestionRequest(body="Question", with_docs=True)]
            user_settings = {"resetSessionMode": False}

            result = await rag_service.send_backend_query(
                environment="ift",
                questions=questions,
                user_settings=user_settings,
                reset_session=False,
            )

            assert result == [mock_backend_response]
            assert mock_httpx_client.post.call_count == 1

    @pytest.mark.asyncio
    async def test_build_bench_headers(self):
        """Test building headers for bench mode."""
        with patch('app.services.rag_service.httpx.AsyncClient'):
            rag_service = RagService()

            user_settings = {
                "bearerToken": "my-token",
                "systemPlatform": "my-platform",
            }

            headers = rag_service._build_bench_headers("ift", user_settings, "req-123")

            assert headers["Request-Id"] == "req-123"
            assert headers["System-Id"] == "brief-bench-ift"
            assert headers["Authorization"] == "Bearer my-token"
            assert headers["System-Platform"] == "my-platform"
            assert headers["Content-Type"] == "application/json"

    @pytest.mark.asyncio
    async def test_build_backend_headers(self):
        """Test building headers for backend mode."""
        with patch('app.services.rag_service.httpx.AsyncClient'):
            rag_service = RagService()

            user_settings = {
                "bearerToken": "my-token",
                "platformUserId": "user-456",
                "platformId": "platform-789",
            }

            headers = rag_service._build_backend_headers(user_settings)

            assert headers["Authorization"] == "Bearer my-token"
            assert headers["Platform-User-Id"] == "user-456"
            assert headers["Platform-Id"] == "platform-789"
            assert headers["Content-Type"] == "application/json"

    @pytest.mark.asyncio
    async def test_create_client_with_mtls(self):
        """Test creating HTTP client with mTLS configuration."""
        with patch('app.services.rag_service.settings') as mock_settings:
            # Only IFT has certificate paths configured; PSI and PROD are empty.
            mock_settings.IFT_RAG_CERT_CERT = "/path/to/client.crt"
            mock_settings.IFT_RAG_CERT_KEY = "/path/to/client.key"
            mock_settings.IFT_RAG_CERT_CA = "/path/to/ca.crt"
            mock_settings.PSI_RAG_CERT_CERT = ""
            mock_settings.PSI_RAG_CERT_KEY = ""
            mock_settings.PSI_RAG_CERT_CA = ""
            mock_settings.PROD_RAG_CERT_CERT = ""
            mock_settings.PROD_RAG_CERT_KEY = ""
            mock_settings.PROD_RAG_CERT_CA = ""

            with patch('app.services.rag_service.httpx.AsyncClient') as MockAsyncClient:
                RagService()

                # One client per environment is created at construction time.
                assert MockAsyncClient.call_count == 3

                # The first client created is expected to be the IFT one.
                first_call_kwargs = MockAsyncClient.call_args_list[0].kwargs
                assert first_call_kwargs["cert"] == ("/path/to/client.crt", "/path/to/client.key")
                assert first_call_kwargs["verify"] == "/path/to/ca.crt"

    @pytest.mark.asyncio
    async def test_create_client_without_mtls(self):
        """Test creating HTTP client without mTLS."""
        with patch('app.services.rag_service.settings') as mock_settings:
            mock_settings.IFT_RAG_CERT_CERT = ""
            mock_settings.IFT_RAG_CERT_KEY = ""
            mock_settings.IFT_RAG_CERT_CA = ""
            mock_settings.PSI_RAG_CERT_CERT = ""
            mock_settings.PSI_RAG_CERT_KEY = ""
            mock_settings.PSI_RAG_CERT_CA = ""
            mock_settings.PROD_RAG_CERT_CERT = ""
            mock_settings.PROD_RAG_CERT_KEY = ""
            mock_settings.PROD_RAG_CERT_CA = ""

            with patch('app.services.rag_service.httpx.AsyncClient') as MockAsyncClient:
                RagService()

                assert MockAsyncClient.call_count == 3

                # With no certificate paths configured, every client must fall
                # back to no client cert and default TLS verification.
                for client_call in MockAsyncClient.call_args_list:
                    client_kwargs = client_call.kwargs
                    assert client_kwargs["cert"] is None
                    assert client_kwargs["verify"] is True

    @pytest.mark.asyncio
    async def test_send_bench_query_http_error(self, mock_httpx_client):
        """Test bench query with HTTP error."""
        error_response = MagicMock()
        error_response.status_code = 500
        error_response.text = "Internal Server Error"

        mock_httpx_client.post.side_effect = httpx.HTTPStatusError(
            "Server error",
            request=None,
            response=error_response,
        )

        with patch('app.services.rag_service.httpx.AsyncClient', return_value=mock_httpx_client):
            rag_service = RagService()

            questions = [QuestionRequest(body="Test", with_docs=True)]
            user_settings = {}

            with pytest.raises(httpx.HTTPStatusError):
                await rag_service.send_bench_query(
                    environment="ift",
                    questions=questions,
                    user_settings=user_settings,
                )

    @pytest.mark.asyncio
    async def test_send_backend_query_http_error(self, mock_httpx_client):
        """Test backend query with HTTP error on ask endpoint."""
        error_response = MagicMock()
        error_response.status_code = 503
        error_response.text = "Service Unavailable"

        mock_httpx_client.post.side_effect = httpx.HTTPStatusError(
            "Service error",
            request=None,
            response=error_response,
        )

        with patch('app.services.rag_service.httpx.AsyncClient', return_value=mock_httpx_client):
            rag_service = RagService()

            questions = [QuestionRequest(body="Test", with_docs=True)]
            user_settings = {"resetSessionMode": False}

            with pytest.raises(httpx.HTTPStatusError):
                await rag_service.send_backend_query(
                    environment="ift",
                    questions=questions,
                    user_settings=user_settings,
                    reset_session=False,
                )

    @pytest.mark.asyncio
    async def test_get_base_url(self):
        """Test building base URL for environment."""
        with patch('app.services.rag_service.httpx.AsyncClient'):
            with patch('app.services.rag_service.settings') as mock_settings:
                mock_settings.IFT_RAG_HOST = "rag-ift.example.com"
                mock_settings.IFT_RAG_PORT = 8443

                service = RagService()
                url = service._get_base_url("ift")

                assert url == "https://rag-ift.example.com:8443"

    @pytest.mark.asyncio
    async def test_close_clients(self, mock_httpx_client):
        """Test closing all HTTP clients."""
        with patch('app.services.rag_service.httpx.AsyncClient', return_value=mock_httpx_client):
            service = RagService()
            await service.close()

        # All three per-environment clients share this mock, so aclose is
        # counted once per client.
        assert mock_httpx_client.aclose.call_count == 3

    @pytest.mark.asyncio
    async def test_async_context_manager(self, mock_httpx_client):
        """Test using RagService as async context manager."""
        with patch('app.services.rag_service.httpx.AsyncClient', return_value=mock_httpx_client):
            async with RagService() as service:
                assert service is not None

        # Leaving the ``async with`` block must close all three clients.
        assert mock_httpx_client.aclose.call_count == 3

    @pytest.mark.asyncio
    async def test_send_bench_query_general_exception(self, mock_httpx_client):
        """Test bench query with general exception (not HTTP error)."""
        mock_httpx_client.post.side_effect = Exception("Network error")

        with patch('app.services.rag_service.httpx.AsyncClient', return_value=mock_httpx_client):
            rag_service = RagService()

            questions = [QuestionRequest(body="Test", with_docs=True)]
            user_settings = {}

            # ``match`` ties the broad Exception check to the injected failure,
            # so an unrelated error (e.g. a bad call signature) fails the test.
            with pytest.raises(Exception, match="Network error"):
                await rag_service.send_bench_query(
                    environment="ift",
                    questions=questions,
                    user_settings=user_settings,
                )

    @pytest.mark.asyncio
    async def test_send_backend_query_general_exception(self, mock_httpx_client):
        """Test backend query with general exception (not HTTP error)."""
        mock_httpx_client.post.side_effect = Exception("Connection timeout")

        with patch('app.services.rag_service.httpx.AsyncClient', return_value=mock_httpx_client):
            rag_service = RagService()

            questions = [QuestionRequest(body="Test", with_docs=True)]
            user_settings = {"resetSessionMode": False}

            # ``match`` ties the broad Exception check to the injected failure.
            with pytest.raises(Exception, match="Connection timeout"):
                await rag_service.send_backend_query(
                    environment="ift",
                    questions=questions,
                    user_settings=user_settings,
                    reset_session=False,
                )
|