# dataloader/tests/conftest.py
import asyncio
import sys
from typing import AsyncGenerator

import pytest
from dotenv import load_dotenv
from httpx import ASGITransport, AsyncClient
from sqlalchemy import text
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)

# Load .env before the application imports to ensure config is available
load_dotenv()

from dataloader.api import app_main
from dataloader.config import APP_CONFIG
from dataloader.context import APP_CTX
from dataloader.storage.db import Base

# For Windows: use the SelectorEventLoop, which is more stable for these tests
if sys.platform == "win32":
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
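
# Note: session-scoped async fixtures need a session-scoped event loop. With
# pytest-asyncio < 0.23 that is commonly provided by redefining the
# `event_loop` fixture (sketch below); newer versions configure it via
# `asyncio_default_fixture_loop_scope` in pytest settings instead. This is an
# assumption about the test runner, not something this file pins down.
#
#     @pytest.fixture(scope="session")
#     def event_loop():
#         loop = asyncio.new_event_loop()
#         yield loop
#         loop.close()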

@pytest.fixture(scope="session")
async def db_engine() -> AsyncGenerator[AsyncEngine, None]:
    """
    Creates a temporary, isolated test database for the entire test session.

    - Connects to the default 'postgres' database to create a new test DB.
    - Yields an engine connected to the new test DB.
    - Drops the test DB after the session is complete.
    """
    pg_settings = APP_CONFIG.pg
    test_db_name = f"{pg_settings.database}_test"

    # DSN for connecting to the 'postgres' DB to manage the test DB
    system_dsn = pg_settings.url.replace(f"/{pg_settings.database}", "/postgres")
    system_engine = create_async_engine(system_dsn, isolation_level="AUTOCOMMIT")

    # DSN for the new test database
    test_dsn = pg_settings.url.replace(f"/{pg_settings.database}", f"/{test_db_name}")

    async with system_engine.connect() as conn:
        # Drop the test DB if it exists from a previous failed run
        # (WITH (FORCE) requires PostgreSQL 13+)
        await conn.execute(text(f"DROP DATABASE IF EXISTS {test_db_name} WITH (FORCE)"))
        await conn.execute(text(f"CREATE DATABASE {test_db_name}"))

    # Create an engine connected to the new test database
    engine = create_async_engine(test_dsn)

    # Create all tables and DDL objects, executing the DDL statements one by
    # one. The enum type goes first because the tables depend on it.
    async with engine.begin() as conn:
        await conn.execute(
            text(
                "CREATE TYPE dl_status AS ENUM "
                "('queued','running','succeeded','failed','canceled','lost')"
            )
        )

        # Create tables
        await conn.run_sync(Base.metadata.create_all)

        # Create the partial and composite indexes used by the job queue
        await conn.execute(
            text(
                "CREATE INDEX ix_dl_jobs_claim "
                "ON dl_jobs(queue, available_at, priority, created_at) "
                "WHERE status = 'queued'"
            )
        )
        await conn.execute(
            text(
                "CREATE INDEX ix_dl_jobs_running_lease "
                "ON dl_jobs(lease_expires_at) WHERE status = 'running'"
            )
        )
        await conn.execute(
            text("CREATE INDEX ix_dl_jobs_status_queue ON dl_jobs(status, queue)")
        )

        # Create the notification function and triggers
        await conn.execute(
            text(
                """
                CREATE OR REPLACE FUNCTION notify_job_ready() RETURNS trigger AS $$
                BEGIN
                    IF (TG_OP = 'INSERT') THEN
                        PERFORM pg_notify('dl_jobs', NEW.queue);
                        RETURN NEW;
                    ELSIF (TG_OP = 'UPDATE') THEN
                        IF NEW.status = 'queued' AND NEW.available_at <= now()
                           AND (OLD.status IS DISTINCT FROM NEW.status
                                OR OLD.available_at IS DISTINCT FROM NEW.available_at) THEN
                            PERFORM pg_notify('dl_jobs', NEW.queue);
                        END IF;
                        RETURN NEW;
                    END IF;
                    RETURN NEW;
                END $$ LANGUAGE plpgsql;
                """
            )
        )
        await conn.execute(
            text(
                "CREATE TRIGGER dl_jobs_notify_ins AFTER INSERT ON dl_jobs "
                "FOR EACH ROW EXECUTE FUNCTION notify_job_ready()"
            )
        )
        await conn.execute(
            text(
                "CREATE TRIGGER dl_jobs_notify_upd AFTER UPDATE OF status, available_at "
                "ON dl_jobs FOR EACH ROW EXECUTE FUNCTION notify_job_ready()"
            )
        )

    yield engine

    # --- Teardown ---
    await engine.dispose()

    # Drop the test database
    async with system_engine.connect() as conn:
        await conn.execute(text(f"DROP DATABASE {test_db_name} WITH (FORCE)"))
    await system_engine.dispose()
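
# Usage sketch (hypothetical test, not part of this module): the fixture can be
# consumed directly to check that the schema was created. `pg_type` is the
# standard PostgreSQL catalog; the test name is illustrative.
#
#     async def test_enum_exists(db_engine: AsyncEngine) -> None:
#         async with db_engine.connect() as conn:
#             res = await conn.execute(
#                 text("SELECT 1 FROM pg_type WHERE typname = 'dl_status'")
#             )
#             assert res.scalar() == 1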

@pytest.fixture(scope="session")
async def app_context(db_engine: AsyncEngine) -> AsyncGenerator[None, None]:
    """
    Overrides the global APP_CTX with the test database engine and sessionmaker.

    This ensures that the app, when tested, uses the isolated test DB.
    """
    # Touching the private attributes directly is deliberate: the tests swap
    # the engine/sessionmaker underneath the already-constructed context.
    original_engine = APP_CTX._engine
    original_sessionmaker = APP_CTX._sessionmaker

    APP_CTX._engine = db_engine
    APP_CTX._sessionmaker = async_sessionmaker(
        db_engine, expire_on_commit=False, class_=AsyncSession
    )

    yield

    # Restore the original context
    APP_CTX._engine = original_engine
    APP_CTX._sessionmaker = original_sessionmaker
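
# Note: `app_context` is not autouse, so tests that reach the DB through the
# global APP_CTX (rather than through the `client`/`db_session` fixtures) must
# request it explicitly. A minimal sketch:
#
#     async def test_ctx_uses_test_engine(app_context: None, db_engine: AsyncEngine) -> None:
#         assert APP_CTX._engine is db_engine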

@pytest.fixture
async def db_session(db_engine: AsyncEngine) -> AsyncGenerator[AsyncSession, None]:
    """
    Provides a transactional session for tests.

    A single connection is acquired from the engine, a transaction is started,
    and a new session is created bound to that connection. At the end of the
    test the transaction is rolled back and the connection is closed, ensuring
    full test isolation.
    """
    connection = await db_engine.connect()
    trans = await connection.begin()

    # Create a sessionmaker bound to the single connection
    TestSession = async_sessionmaker(
        bind=connection,
        expire_on_commit=False,
        class_=AsyncSession,
    )
    session = TestSession()
    try:
        yield session
    finally:
        # Close the session, roll back the transaction, and release the connection
        await session.close()
        await trans.rollback()
        await connection.close()
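
# Usage sketch (hypothetical test): anything written through `db_session` is
# rolled back when the test ends, so no cross-test cleanup is needed. The
# column list in the INSERT is illustrative; the real dl_jobs table may
# require more fields.
#
#     async def test_insert_is_isolated(db_session: AsyncSession) -> None:
#         await db_session.execute(
#             text("INSERT INTO dl_jobs (queue, status) VALUES ('q1', 'queued')")
#         )
#         count = (await db_session.execute(text("SELECT count(*) FROM dl_jobs"))).scalar()
#         assert count == 1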

@pytest.fixture
async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
    """
    Provides an HTTP client for the FastAPI application.

    It overrides the 'get_session' dependency to use the test-scoped,
    transactional session provided by the 'db_session' fixture. This ensures
    that API calls operate within the same transaction as the test function,
    allowing for consistent state checks.
    """

    async def override_get_session() -> AsyncGenerator[AsyncSession, None]:
        """Simply yields the session created by the db_session fixture."""
        yield db_session

    # The service depends on the session, so we override the session provider
    # that it receives via Depends(get_session).
    from dataloader.context import get_session

    app_main.dependency_overrides[get_session] = override_get_session
    try:
        transport = ASGITransport(app=app_main)
        async with AsyncClient(transport=transport, base_url="http://test") as c:
            yield c
    finally:
        # Always remove the dependency override, even if the test body raises
        app_main.dependency_overrides.clear()
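
# Usage sketch (hypothetical test; the route path "/api/v1/jobs" is assumed,
# not taken from this file): requests run against the same transaction as the
# test, so state created via the API is visible through `db_session` and
# disappears on rollback.
#
#     async def test_jobs_endpoint(client: AsyncClient) -> None:
#         resp = await client.get("/api/v1/jobs")
#         assert resp.status_code == 200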