"""Tests for backend/main.py — FastAPI application."""

import asyncio
import os
from contextlib import suppress
from unittest.mock import patch

import pytest
from fastapi.testclient import TestClient


@pytest.fixture(autouse=True)
def _isolate_settings(tmp_path, monkeypatch):
    """Ensure tests use a temp SQLite DB and no Redis.

    Points ``DATABASE_URL`` at a per-test SQLite file under ``tmp_path``,
    blanks ``REDIS_URL`` and sets ``DATA_DIR``. It then rebuilds the settings
    object and re-runs ``main``'s DB/Redis initialisers so they pick up the
    test environment. The module-level ``settings`` references on ``config``
    and ``main`` are patched through ``monkeypatch``, so they are restored
    when the test finishes instead of leaking into later tests.
    """
    env = {
        "DATABASE_URL": f"sqlite:///{tmp_path / 'test.db'}",
        "REDIS_URL": "",
        "DATA_DIR": str(tmp_path),
    }
    with patch.dict(os.environ, env, clear=False):
        # Rebuild settings from the patched environment only (ignore any .env).
        import config

        new_settings = config.Settings(_env_file=None)
        monkeypatch.setattr(config, "settings", new_settings)

        # main keeps its own reference to the settings object; patch it too.
        import main

        monkeypatch.setattr(main, "settings", new_settings)
        main._init_db()
        main._init_redis()

        # Create tables in the fresh temp database.
        from models import Base

        Base.metadata.create_all(bind=main.engine)
        yield

        # Release pooled connections to the temp SQLite file before
        # tmp_path is cleaned up (assumes main.engine is a SQLAlchemy
        # Engine, as implied by its use as ``bind=`` above).
        main.engine.dispose()


@pytest.fixture
def client():
    """Return a TestClient bound to the FastAPI app defined in ``main``."""
    from main import app

    return TestClient(app)


class TestHealthEndpoint:
    def test_health_returns_ok(self, client):
        resp = client.get("/health")
        assert resp.status_code == 200
        data = resp.json()
        assert data["status"] == "ok"
        assert data["database"] is True
        assert data["redis"] is True  # in-process mode counts as ok

    def test_health_response_schema(self, client):
        resp = client.get("/health")
        assert resp.status_code == 200
        data = resp.json()
        assert set(data.keys()) == {"status", "database", "redis"}


class TestCORSMiddleware:
    def test_cors_headers_present(self, client):
        """A CORS preflight from a browser origin gets an allow-origin header."""
        resp = client.options(
            "/health",
            headers={
                "Origin": "http://localhost:3000",
                "Access-Control-Request-Method": "GET",
            },
        )
        assert "access-control-allow-origin" in resp.headers


class TestWebSocket:
    def test_websocket_connect_and_echo(self, client):
        with client.websocket_connect("/ws") as ws:
            ws.send_json({"type": "ping"})
            data = ws.receive_json()
            assert data["type"] == "ack"
            assert data["data"]["type"] == "ping"

    def test_websocket_disconnect_cleanup(self, client):
        from main import ws_manager

        initial_count = len(ws_manager.active_connections)
        with client.websocket_connect("/ws"):
            assert len(ws_manager.active_connections) == initial_count + 1
        # After disconnect, connection should be removed
        assert len(ws_manager.active_connections) == initial_count


class TestRouterMounting:
    def test_openapi_schema_loads(self, client):
        resp = client.get("/openapi.json")
        assert resp.status_code == 200
        schema = resp.json()
        assert schema["info"]["title"] == "PromptLooper"

    def test_unknown_route_returns_404(self, client):
        resp = client.get("/api/nonexistent")
        assert resp.status_code == 404


class TestConnectionManager:
    def test_broadcast_removes_dead_connections(self):
        """ConnectionManager.broadcast with no connections does not raise.

        NOTE(review): despite the test name, only the empty-manager case is
        exercised here; removal of broken connections is not covered.
        """
        from main import ConnectionManager

        manager = ConnectionManager()
        # Use a private loop rather than asyncio.get_event_loop(): implicit
        # loop creation is deprecated. asyncio.run() would also work, but it
        # resets the thread's current loop, which can break other tests.
        loop = asyncio.new_event_loop()
        try:
            loop.run_until_complete(manager.broadcast({"test": True}))
        finally:
            loop.close()
        assert len(manager.active_connections) == 0


class TestGetDb:
    def test_get_db_yields_session(self):
        from main import get_db

        gen = get_db()
        try:
            session = next(gen)
            assert session is not None
        finally:
            # Resume the generator so get_db's post-yield cleanup runs even
            # when the assertion above fails; exhaustion raises StopIteration.
            with suppress(StopIteration):
                next(gen)


class TestGetRedis:
    def test_get_redis_returns_none_in_process_mode(self):
        from main import get_redis

        # In test setup, Redis is not configured
        assert get_redis() is None