diff --git a/Auto Run Docs/01-scaffold.md b/Auto Run Docs/01-scaffold.md index 152c93d..8107b5d 100644 --- a/Auto Run Docs/01-scaffold.md +++ b/Auto Run Docs/01-scaffold.md @@ -29,7 +29,8 @@ Set up the PromptLooper repository, Docker infrastructure, and basic project ske - [x] Create backend/schemas.py with Pydantic request/response schemas for all API endpoints. Include create/update/response schemas for Project, Experiment, Run, Endpoint, and Webhook. Include the Score input schema and export format schemas. > Created backend/schemas.py with all Pydantic v2 schemas using ConfigDict(from_attributes=True) for ORM compatibility. Includes: Project (create/update/response/list), Experiment (create/update/response/list), Run (response/list/detail with nested stages+scores), StageResult (response), Score (input/response), Endpoint (create/update/response/list), Webhook (create/update/response/list), Auth (setup/login/token/user), Export (run row with scores dict, export response), and Health. 30 tests in tests/test_schemas.py all passing. All 64 backend tests pass. -- [ ] Create backend/main.py with the FastAPI application. Set up CORS middleware, mount all routers (even if they're stubs), configure the WebSocket endpoint, add the /health endpoint that checks DB and Redis connectivity, and add startup/shutdown lifecycle hooks. +- [x] Create backend/main.py with the FastAPI application. Set up CORS middleware, mount all routers (even if they're stubs), configure the WebSocket endpoint, add the /health endpoint that checks DB and Redis connectivity, and add startup/shutdown lifecycle hooks. + > Created backend/main.py with: CORS middleware (allow all origins), /health endpoint checking DB (SELECT 1) and Redis (ping) connectivity, /ws WebSocket endpoint with ConnectionManager for real-time broadcasts, async lifespan hooks for DB engine + Redis init/teardown, get_db dependency yielding sessions, dynamic router mounting (silently skips missing routers). 10 tests in tests/test_main.py covering health, CORS, WebSocket connect/disconnect/echo, OpenAPI schema, 404s, broadcast, get_db, and get_redis. All 74 backend tests pass. - [ ] Create backend/auth.py implementing JWT token generation/verification, API key validation, and the first-boot setup flow. The setup endpoint should check if any users exist — if not, accept username + password to create the admin account. Include a dependency function for route-level auth that supports both JWT and API key. diff --git a/backend/main.py b/backend/main.py new file mode 100644 index 0000000..1afe997 --- /dev/null +++ b/backend/main.py @@ -0,0 +1,211 @@ +"""PromptLooper FastAPI application.""" + +from contextlib import asynccontextmanager +from typing import AsyncGenerator + +from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from fastapi.middleware.cors import CORSMiddleware +from sqlalchemy import create_engine, text +from sqlalchemy.orm import sessionmaker + +from config import settings + + +# --------------------------------------------------------------------------- +# Database engine & session factory (lazy, created at startup) +# --------------------------------------------------------------------------- + +engine = None +SessionLocal = None + + +def _init_db() -> None: + """Create the SQLAlchemy engine and session factory.""" + global engine, SessionLocal + connect_args = {} + if settings.is_sqlite: + connect_args["check_same_thread"] = False + engine = create_engine( + settings.effective_database_url, + connect_args=connect_args, + ) + SessionLocal = sessionmaker(bind=engine, autoflush=False, expire_on_commit=False) + + +def get_db(): + """FastAPI dependency that yields a database session.""" + db = SessionLocal() + try: + yield db + finally: + db.close() + + +# --------------------------------------------------------------------------- +# Redis helper +# --------------------------------------------------------------------------- + +_redis_client = None + + +def _init_redis() -> None: + """Connect to Redis if configured.""" + global _redis_client + if not settings.redis_url: + _redis_client = None + return + import redis as redis_lib + _redis_client = redis_lib.Redis.from_url(settings.redis_url, decode_responses=True) + + +def get_redis(): + """Return the Redis client (or None in single-container mode).""" + return _redis_client + + +# --------------------------------------------------------------------------- +# WebSocket connection manager +# --------------------------------------------------------------------------- + +class ConnectionManager: + """Manage active WebSocket connections.""" + + def __init__(self) -> None: + self.active_connections: list[WebSocket] = [] + + async def connect(self, websocket: WebSocket) -> None: + await websocket.accept() + self.active_connections.append(websocket) + + def disconnect(self, websocket: WebSocket) -> None: + self.active_connections.remove(websocket) + + async def broadcast(self, message: dict) -> None: + for connection in list(self.active_connections): + try: + await connection.send_json(message) + except Exception: + self.disconnect(connection) + + +ws_manager = ConnectionManager() + + +# --------------------------------------------------------------------------- +# Lifecycle +# --------------------------------------------------------------------------- + +@asynccontextmanager +async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: + """Startup and shutdown lifecycle hooks.""" + _init_db() + _init_redis() + yield + # Shutdown: clean up connections + if _redis_client is not None: + _redis_client.close() + if engine is not None: + engine.dispose() + + +# --------------------------------------------------------------------------- +# Application +# --------------------------------------------------------------------------- + +app = FastAPI( + title="PromptLooper", + description="LLM pipeline tuning workbench", + version="0.1.0", + lifespan=lifespan, +) + +# CORS — allow all origins in development; tighten in production via env +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +# --------------------------------------------------------------------------- +# Health endpoint +# --------------------------------------------------------------------------- + +@app.get("/health", tags=["system"]) +def health_check() -> dict: + """Check DB and Redis connectivity.""" + db_ok = False + redis_ok = False + + # Database check + if SessionLocal is not None: + try: + with SessionLocal() as session: + session.execute(text("SELECT 1")) + db_ok = True + except Exception: + pass + + # Redis check + if not settings.redis_url: + redis_ok = True # No Redis needed — in-process mode + elif _redis_client is not None: + try: + _redis_client.ping() + redis_ok = True + except Exception: + pass + + return {"status": "ok" if (db_ok and redis_ok) else "degraded", "database": db_ok, "redis": redis_ok} + + +# --------------------------------------------------------------------------- +# WebSocket endpoint +# --------------------------------------------------------------------------- + +@app.websocket("/ws") +async def websocket_endpoint(websocket: WebSocket) -> None: + """WebSocket connection for real-time dashboard updates.""" + await ws_manager.connect(websocket) + try: + while True: + # Keep connection alive; handle incoming messages if needed + data = await websocket.receive_json() + # Echo back or handle client messages in future + await websocket.send_json({"type": "ack", "data": data}) + except WebSocketDisconnect: + ws_manager.disconnect(websocket) + + +# --------------------------------------------------------------------------- +# Mount routers (stubs — actual implementations come later) +# --------------------------------------------------------------------------- + +# Router imports are deferred to avoid circular imports and allow +# stub files to be created independently. Each router will be mounted +# as it is implemented. For now we register empty prefixes. + +def _mount_routers() -> None: + """Import and mount all routers. Silently skip missing ones.""" + router_configs = [ + ("routers.auth", "/api/auth", ["auth"]), + ("routers.projects", "/api/projects", ["projects"]), + ("routers.experiments", "/api/experiments", ["experiments"]), + ("routers.runs", "/api/runs", ["runs"]), + ("routers.endpoints", "/api/endpoints", ["endpoints"]), + ("routers.export", "/api/export", ["export"]), + ("routers.webhooks", "/api/webhooks", ["webhooks"]), + ("routers.admin", "/api/admin", ["admin"]), + ] + for module_name, prefix, tags in router_configs: + try: + import importlib + mod = importlib.import_module(module_name) + app.include_router(mod.router, prefix=prefix, tags=tags) + except (ImportError, AttributeError): + pass # Router not yet implemented + + +_mount_routers() diff --git a/backend/tests/test_main.py b/backend/tests/test_main.py new file mode 100644 index 0000000..76cf5dd --- /dev/null +++ b/backend/tests/test_main.py @@ -0,0 +1,129 @@ +"""Tests for backend/main.py — FastAPI application.""" + +import os +from unittest.mock import patch + +import pytest +from fastapi.testclient import TestClient + + +@pytest.fixture(autouse=True) +def _isolate_settings(tmp_path): + """Ensure tests use a temp SQLite DB and no Redis.""" + env = { + "DATABASE_URL": f"sqlite:///{tmp_path / 'test.db'}", + "REDIS_URL": "", + "DATA_DIR": str(tmp_path), + } + with patch.dict(os.environ, env, clear=False): + # Reload settings so it picks up test env + import config + new_settings = config.Settings(_env_file=None) + config.settings = new_settings + + # Patch main's reference too + import main + main.settings = new_settings + main._init_db() + main._init_redis() + + # Create tables + from models import Base + Base.metadata.create_all(bind=main.engine) + + yield + + +@pytest.fixture +def client(): + from main import app + return TestClient(app) + + +class TestHealthEndpoint: + def test_health_returns_ok(self, client): + resp = client.get("/health") + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "ok" + assert data["database"] is True + assert data["redis"] is True # in-process mode counts as ok + + def test_health_response_schema(self, client): + resp = client.get("/health") + data = resp.json() + assert set(data.keys()) == {"status", "database", "redis"} + + +class TestCORSMiddleware: + def test_cors_headers_present(self, client): + resp = client.options( + "/health", + headers={ + "Origin": "http://localhost:3000", + "Access-Control-Request-Method": "GET", + }, + ) + assert "access-control-allow-origin" in resp.headers + + +class TestWebSocket: + def test_websocket_connect_and_echo(self, client): + with client.websocket_connect("/ws") as ws: + ws.send_json({"type": "ping"}) + data = ws.receive_json() + assert data["type"] == "ack" + assert data["data"]["type"] == "ping" + + def test_websocket_disconnect_cleanup(self, client): + from main import ws_manager + initial_count = len(ws_manager.active_connections) + with client.websocket_connect("/ws") as ws: + assert len(ws_manager.active_connections) == initial_count + 1 + # After disconnect, connection should be removed + assert len(ws_manager.active_connections) == initial_count + + +class TestRouterMounting: + def test_openapi_schema_loads(self, client): + resp = client.get("/openapi.json") + assert resp.status_code == 200 + schema = resp.json() + assert schema["info"]["title"] == "PromptLooper" + + def test_unknown_route_returns_404(self, client): + resp = client.get("/api/nonexistent") + assert resp.status_code == 404 + + +class TestConnectionManager: + def test_broadcast_removes_dead_connections(self): + """ConnectionManager.broadcast skips and removes broken connections.""" + from main import ConnectionManager + manager = ConnectionManager() + # No connections — broadcast should not raise + import asyncio + asyncio.get_event_loop().run_until_complete( + manager.broadcast({"test": True}) + ) + assert len(manager.active_connections) == 0 + + +class TestGetDb: + def test_get_db_yields_session(self): + from main import get_db + gen = get_db() + session = next(gen) + assert session is not None + # Clean up + try: + next(gen) + except StopIteration: + pass + + +class TestGetRedis: + def test_get_redis_returns_none_in_process_mode(self): + from main import get_redis + # In test setup, Redis is not configured + assert get_redis() is None