180 lines
6.8 KiB
Python
180 lines
6.8 KiB
Python
# tests/conftest.py
|
|
import asyncio
|
|
import sys
|
|
from typing import AsyncGenerator
|
|
|
|
import pytest
|
|
from dotenv import load_dotenv
|
|
from httpx import AsyncClient, ASGITransport
|
|
from sqlalchemy import text
|
|
from sqlalchemy.ext.asyncio import (
|
|
AsyncEngine,
|
|
AsyncSession,
|
|
async_sessionmaker,
|
|
create_async_engine,
|
|
)
|
|
from sqlalchemy.exc import ProgrammingError
|
|
from asyncpg.exceptions import InvalidCatalogNameError
|
|
|
|
# Populate os.environ from a .env file BEFORE the dataloader imports below.
# NOTE(review): presumably dataloader.config builds APP_CONFIG from the
# environment at import time, which is why this call must precede those
# imports — confirm in dataloader.config.
load_dotenv()
|
|
|
|
from dataloader.api import app_main
|
|
from dataloader.api.v1.router import get_service
|
|
from dataloader.api.v1.service import JobsService
|
|
from dataloader.context import APP_CTX
|
|
from dataloader.storage.db import Base
|
|
from dataloader.config import APP_CONFIG
|
|
|
|
# Windows only: replace the default (Proactor-based) event loop policy with the
# selector-based one before any test creates a loop. Other platforms keep the
# interpreter's default policy untouched.
if sys.platform == "win32":
    asyncio.set_event_loop_policy(
        asyncio.WindowsSelectorEventLoopPolicy()
    )
|
|
|
|
|
|
def _quote_ident(name: str) -> str:
    """Return ``name`` as a double-quoted PostgreSQL identifier.

    Quoting keeps the exact case and permits characters (e.g. ``-``) that an
    unquoted identifier cannot contain, so the created database has precisely
    the name we later connect to. Embedded double quotes are escaped by
    doubling, per PostgreSQL's quoted-identifier rules.
    """
    return '"' + name.replace('"', '""') + '"'


async def _create_schema(conn) -> None:
    """Create the enum type, tables, indexes, NOTIFY function and triggers.

    ``conn`` is expected to be an ``AsyncConnection`` with an open transaction
    (the caller uses ``engine.begin()``). Statements are issued one per
    ``execute()`` call rather than as one multi-statement string.
    """
    # The enum is created explicitly before the tables that reference it.
    await conn.execute(
        text("CREATE TYPE dl_status AS ENUM ('queued','running','succeeded','failed','canceled','lost')")
    )

    # Tables come from the ORM metadata; run_sync bridges to the sync DDL API.
    await conn.run_sync(Base.metadata.create_all)

    index_statements = (
        "CREATE INDEX ix_dl_jobs_claim ON dl_jobs(queue, available_at, priority, created_at) WHERE status = 'queued'",
        "CREATE INDEX ix_dl_jobs_running_lease ON dl_jobs(lease_expires_at) WHERE status = 'running'",
        "CREATE INDEX ix_dl_jobs_status_queue ON dl_jobs(status, queue)",
    )
    for statement in index_statements:
        await conn.execute(text(statement))

    # Trigger function: NOTIFY on channel 'dl_jobs' (payload = queue name) for
    # every INSERT, and for an UPDATE that changes status/available_at such
    # that the row is 'queued' and already available.
    await conn.execute(
        text(
            """
            CREATE OR REPLACE FUNCTION notify_job_ready() RETURNS trigger AS $$
            BEGIN
              IF (TG_OP = 'INSERT') THEN
                PERFORM pg_notify('dl_jobs', NEW.queue);
                RETURN NEW;
              ELSIF (TG_OP = 'UPDATE') THEN
                IF NEW.status = 'queued' AND NEW.available_at <= now()
                   AND (OLD.status IS DISTINCT FROM NEW.status OR OLD.available_at IS DISTINCT FROM NEW.available_at) THEN
                  PERFORM pg_notify('dl_jobs', NEW.queue);
                END IF;
                RETURN NEW;
              END IF;
              RETURN NEW;
            END $$ LANGUAGE plpgsql;
            """
        )
    )
    await conn.execute(
        text("CREATE TRIGGER dl_jobs_notify_ins AFTER INSERT ON dl_jobs FOR EACH ROW EXECUTE FUNCTION notify_job_ready()")
    )
    await conn.execute(
        text("CREATE TRIGGER dl_jobs_notify_upd AFTER UPDATE OF status, available_at ON dl_jobs FOR EACH ROW EXECUTE FUNCTION notify_job_ready()")
    )


@pytest.fixture(scope="session")
async def db_engine() -> AsyncGenerator[AsyncEngine, None]:
    """
    Create a temporary, isolated test database for the whole test session.

    - Connects to the server's ``postgres`` database (AUTOCOMMIT, since
      CREATE/DROP DATABASE cannot run inside a transaction) and recreates
      ``<configured database>_test``, dropping any leftover from a failed run.
    - Builds the schema on it and yields an engine bound to that database.
    - On teardown — and also if setup fails part-way — disposes the engines
      and drops the test database.

    NOTE(review): a session-scoped ``async def`` fixture under plain
    ``@pytest.fixture`` assumes pytest-asyncio runs in auto mode with a
    session-scoped event loop — confirm against the project's pytest config.
    ``DROP DATABASE ... WITH (FORCE)`` requires PostgreSQL 13+.
    """
    from sqlalchemy.engine import make_url

    pg_settings = APP_CONFIG.pg
    test_db_name = f"{pg_settings.database}_test"
    quoted_test_db = _quote_ident(test_db_name)

    # Swap only the database component of the URL. Plain string replacement
    # would also rewrite e.g. a username equal to the database name.
    base_url = make_url(pg_settings.url)
    system_url = base_url.set(database="postgres")
    test_url = base_url.set(database=test_db_name)

    system_engine = create_async_engine(system_url, isolation_level="AUTOCOMMIT")
    try:
        async with system_engine.connect() as conn:
            # Drop the test DB if it survived a previous failed run.
            await conn.execute(text(f"DROP DATABASE IF EXISTS {quoted_test_db} WITH (FORCE)"))
            await conn.execute(text(f"CREATE DATABASE {quoted_test_db}"))

        try:
            engine = create_async_engine(test_url)
            try:
                async with engine.begin() as conn:
                    await _create_schema(conn)

                yield engine
            finally:
                await engine.dispose()
        finally:
            # Always remove the test DB once it has been created, even when
            # schema creation failed before the yield.
            async with system_engine.connect() as conn:
                await conn.execute(text(f"DROP DATABASE IF EXISTS {quoted_test_db} WITH (FORCE)"))
    finally:
        await system_engine.dispose()
|
|
|
|
|
|
@pytest.fixture(scope="session")
async def app_context(db_engine: AsyncEngine) -> AsyncGenerator[None, None]:
    """
    Point the global APP_CTX at the test database for the whole session.

    The engine and sessionmaker currently held by APP_CTX are remembered,
    replaced by ``db_engine`` and a sessionmaker bound to it (``AsyncSession``
    class, ``expire_on_commit=False``), and put back after the session ends.
    Code that resolves its database via APP_CTX therefore talks to the
    isolated test database while this fixture is active.
    """
    saved_engine, saved_sessionmaker = APP_CTX._engine, APP_CTX._sessionmaker

    test_sessionmaker = async_sessionmaker(
        db_engine,
        expire_on_commit=False,
        class_=AsyncSession,
    )
    APP_CTX._engine = db_engine
    APP_CTX._sessionmaker = test_sessionmaker

    yield

    # Put the application's own engine/sessionmaker back.
    APP_CTX._engine, APP_CTX._sessionmaker = saved_engine, saved_sessionmaker
|
|
|
|
|
|
@pytest.fixture
async def db_session(db_engine: AsyncEngine) -> AsyncGenerator[AsyncSession, None]:
    """
    Provide a session whose changes are discarded after each test.

    A single connection is checked out from ``db_engine`` and an outer
    transaction is begun on it. The session joins that transaction with
    ``join_transaction_mode="create_savepoint"``: any ``commit()`` or
    ``rollback()`` performed by the code under test only releases or rolls
    back a SAVEPOINT and never ends the outer transaction. Without this, an
    app-level rollback would end the outer transaction and the session would
    autobegin a real one whose later commit would persist data across tests.

    On teardown the session is closed, the outer transaction is rolled back
    and the connection is returned to the pool — also when an intermediate
    step fails, because the connection is managed by ``async with``.
    """
    async with db_engine.connect() as connection:
        outer_trans = await connection.begin()
        session = AsyncSession(
            bind=connection,
            expire_on_commit=False,
            join_transaction_mode="create_savepoint",
        )
        try:
            yield session
        finally:
            await session.close()
            # Guard against a transaction that is already inactive (e.g. the
            # connection was invalidated during the test).
            if outer_trans.is_active:
                await outer_trans.rollback()
|
|
|
|
|
|
@pytest.fixture
async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
    """
    Yield an HTTPX client wired directly to the FastAPI app (no network).

    The app's ``get_session`` dependency is overridden to hand out the
    per-test ``db_session``, so every request made through this client runs
    on the same connection/transaction as the test body and the test can
    inspect what the API wrote. After the client closes, all entries in
    ``app_main.dependency_overrides`` are cleared.
    """
    # The service obtains its session via Depends(get_session), so overriding
    # this provider is enough to redirect it to the test session.
    from dataloader.context import get_session

    async def override_get_session() -> AsyncGenerator[AsyncSession, None]:
        """Hand the app the fixture-owned session instead of opening a new one."""
        yield db_session

    app_main.dependency_overrides[get_session] = override_get_session

    async with AsyncClient(
        transport=ASGITransport(app=app_main),
        base_url="http://test",
    ) as http_client:
        yield http_client

    # Wipe every override (not only get_session) so no state leaks onward.
    app_main.dependency_overrides.clear()
|