Tests for a FastAPI application that provides:
- CORS middleware (permissive, intended for development)
- a /health endpoint that checks database and Redis connectivity
- a /ws WebSocket endpoint backed by a ConnectionManager for real-time updates
- async lifespan hooks that initialise and tear down the DB engine and Redis
- a get_db dependency that manages database sessions
- dynamic router mounting that silently skips missing router modules
The suite contains 10 tests covering these endpoints and utilities (the lifespan hooks themselves are not exercised).
129 lines · 3.8 KiB · Python
"""Tests for backend/main.py — FastAPI application."""
|
|
|
|
import os
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def _isolate_settings(tmp_path, monkeypatch):
    """Point the app at a per-test SQLite DB with Redis disabled.

    The fixture does the following for the duration of the test:

    - Patches ``DATABASE_URL``, ``REDIS_URL`` (empty) and ``DATA_DIR`` in
      ``os.environ``.
    - Rebuilds ``config.Settings`` from that environment.
    - Swaps the new settings into both ``config`` and ``main``.
    - Re-runs ``main._init_db()`` and ``main._init_redis()``.
    - Creates every table from ``models.Base`` on ``main.engine``.

    On teardown, the engine is disposed and the previous settings objects
    are restored by ``monkeypatch``. This keeps module state pointing at
    this test's ``tmp_path`` from outliving the test.
    """
    env = {
        "DATABASE_URL": f"sqlite:///{tmp_path / 'test.db'}",
        "REDIS_URL": "",
        "DATA_DIR": str(tmp_path),
    }
    with patch.dict(os.environ, env, clear=False):
        # Reload settings so they pick up the test env. _env_file=None stops a
        # local .env file from overriding the values above. This is done
        # before importing main, so anything main reads from config at import
        # time already sees the test settings.
        import config

        new_settings = config.Settings(_env_file=None)
        monkeypatch.setattr(config, "settings", new_settings)

        # main keeps its own reference to the settings object; patch it too.
        # raising=False mirrors the original plain assignment.
        import main

        monkeypatch.setattr(main, "settings", new_settings, raising=False)

        main._init_db()
        try:
            main._init_redis()

            from models import Base

            Base.metadata.create_all(bind=main.engine)
            yield
        finally:
            # Release pooled connections to the per-test SQLite file. This
            # assumes main.engine is a SQLAlchemy Engine, as create_all(bind=...)
            # above implies; confirm if _init_db changes.
            main.engine.dispose()
|
|
|
|
|
|
@pytest.fixture
def client():
    """Yield a ``TestClient`` for ``main.app`` and close it after the test.

    The client is not entered as a context manager, so the app's lifespan
    hooks do not run. DB and Redis setup comes from ``_isolate_settings``.
    """
    from main import app

    test_client = TestClient(app)
    try:
        yield test_client
    finally:
        # Release the client's underlying transport/session.
        test_client.close()
|
|
|
|
|
|
class TestHealthEndpoint:
    """Tests for the ``/health`` endpoint."""

    def test_health_returns_ok(self, client):
        """A healthy app answers 200 and reports every backend as up."""
        response = client.get("/health")
        assert response.status_code == 200

        body = response.json()
        assert body["status"] == "ok"
        assert body["database"] is True
        # Redis is not configured in tests; in-process mode must still
        # be reported as healthy.
        assert body["redis"] is True

    def test_health_response_schema(self, client):
        """The payload carries exactly the status, database and redis keys."""
        payload = client.get("/health").json()
        expected_keys = {"status", "database", "redis"}
        assert set(payload) == expected_keys
|
|
|
|
|
|
class TestCORSMiddleware:
    """Tests for the CORS middleware configuration."""

    def test_cors_headers_present(self, client):
        """A preflight request from localhost:3000 gets an allow-origin header."""
        preflight_headers = {
            "Origin": "http://localhost:3000",
            "Access-Control-Request-Method": "GET",
        }
        response = client.options("/health", headers=preflight_headers)
        assert "access-control-allow-origin" in response.headers
|
|
|
|
|
|
class TestWebSocket:
    """Tests for the ``/ws`` endpoint and its connection bookkeeping."""

    def test_websocket_connect_and_echo(self, client):
        """A JSON message is answered with an ``ack`` that echoes it back."""
        with client.websocket_connect("/ws") as ws:
            ws.send_json({"type": "ping"})
            data = ws.receive_json()
            assert data["type"] == "ack"
            assert data["data"]["type"] == "ping"

    def test_websocket_disconnect_cleanup(self, client):
        """The manager tracks an open socket and forgets it once it closes."""
        from main import ws_manager

        initial_count = len(ws_manager.active_connections)
        with client.websocket_connect("/ws") as ws:
            # Finish one ping/ack round trip before counting. If the server
            # registers the socket only after accept() (the usual
            # ConnectionManager pattern), the handshake can reach the client
            # first. Counting straight after connect would then race the
            # server; an ack proves the endpoint is already past connect.
            ws.send_json({"type": "ping"})
            ws.receive_json()
            assert len(ws_manager.active_connections) == initial_count + 1
        # Leaving the block closes the socket; the server should drop it.
        assert len(ws_manager.active_connections) == initial_count
|
|
|
|
|
|
class TestRouterMounting:
    """Smoke tests for the assembled application's routing."""

    def test_openapi_schema_loads(self, client):
        """The OpenAPI schema is served and carries the app title."""
        response = client.get("/openapi.json")
        assert response.status_code == 200
        title = response.json()["info"]["title"]
        assert title == "PromptLooper"

    def test_unknown_route_returns_404(self, client):
        """A path that no router mounts falls through to 404."""
        assert client.get("/api/nonexistent").status_code == 404
|
|
|
|
|
|
class TestConnectionManager:
    """Unit tests for ``main.ConnectionManager`` used directly, without the app."""

    def test_broadcast_removes_dead_connections(self):
        """Broadcasting with no registered connections is a harmless no-op.

        NOTE(review): despite its name, this test never registers a broken
        connection, so dead-connection removal itself is not exercised yet.
        """
        import asyncio

        from main import ConnectionManager

        manager = ConnectionManager()
        # asyncio.run gives the coroutine a fresh event loop and closes it
        # afterwards. Calling asyncio.get_event_loop() with no running loop is
        # deprecated and raises on newer Python versions.
        asyncio.run(manager.broadcast({"test": True}))
        assert len(manager.active_connections) == 0
|
|
|
|
|
|
class TestGetDb:
    """Tests for the ``get_db`` session dependency."""

    def test_get_db_yields_session(self):
        """``get_db`` yields one non-None session and then finishes cleanly."""
        import contextlib

        from main import get_db

        # Mirror how FastAPI wraps generator dependencies. The generator's
        # teardown runs even if the assertion fails. A generator that yields
        # a second time raises RuntimeError instead of passing silently.
        with contextlib.contextmanager(get_db)() as session:
            assert session is not None
|
|
|
|
|
|
class TestGetRedis:
    """Tests for the ``get_redis`` accessor."""

    def test_get_redis_returns_none_in_process_mode(self):
        """With REDIS_URL set to "" by the autouse ``_isolate_settings``
        fixture, no Redis client is returned."""
        from main import get_redis

        redis_client = get_redis()
        assert redis_client is None
|