MAESTRO: Create backend/main.py with FastAPI app, CORS, health check, WebSocket, and router mounting
FastAPI application with: - CORS middleware (permissive for dev) - /health endpoint checking DB and Redis connectivity - /ws WebSocket endpoint with ConnectionManager for real-time updates - Async lifespan hooks for DB engine and Redis init/teardown - get_db dependency for session management - Dynamic router mounting that silently skips missing router modules - 10 tests covering all endpoints and utilities
This commit is contained in:
parent
42668eeeb1
commit
15ca2c922a
3 changed files with 342 additions and 1 deletions
|
|
@ -29,7 +29,8 @@ Set up the PromptLooper repository, Docker infrastructure, and basic project ske
|
|||
- [x] Create backend/schemas.py with Pydantic request/response schemas for all API endpoints. Include create/update/response schemas for Project, Experiment, Run, Endpoint, and Webhook. Include the Score input schema and export format schemas.
|
||||
> Created backend/schemas.py with all Pydantic v2 schemas using ConfigDict(from_attributes=True) for ORM compatibility. Includes: Project (create/update/response/list), Experiment (create/update/response/list), Run (response/list/detail with nested stages+scores), StageResult (response), Score (input/response), Endpoint (create/update/response/list), Webhook (create/update/response/list), Auth (setup/login/token/user), Export (run row with scores dict, export response), and Health. 30 tests in tests/test_schemas.py all passing. All 64 backend tests pass.
|
||||
|
||||
- [ ] Create backend/main.py with the FastAPI application. Set up CORS middleware, mount all routers (even if they're stubs), configure the WebSocket endpoint, add the /health endpoint that checks DB and Redis connectivity, and add startup/shutdown lifecycle hooks.
|
||||
- [x] Create backend/main.py with the FastAPI application. Set up CORS middleware, mount all routers (even if they're stubs), configure the WebSocket endpoint, add the /health endpoint that checks DB and Redis connectivity, and add startup/shutdown lifecycle hooks.
|
||||
> Created backend/main.py with: CORS middleware (allow all origins), /health endpoint checking DB (SELECT 1) and Redis (ping) connectivity, /ws WebSocket endpoint with ConnectionManager for real-time broadcasts, async lifespan hooks for DB engine + Redis init/teardown, get_db dependency yielding sessions, dynamic router mounting (silently skips missing routers). 10 tests in tests/test_main.py covering health, CORS, WebSocket connect/disconnect/echo, OpenAPI schema, 404s, broadcast, get_db, and get_redis. All 74 backend tests pass.
|
||||
|
||||
- [ ] Create backend/auth.py implementing JWT token generation/verification, API key validation, and the first-boot setup flow. The setup endpoint should check if any users exist — if not, accept username + password to create the admin account. Include a dependency function for route-level auth that supports both JWT and API key.
|
||||
|
||||
|
|
|
|||
211
backend/main.py
Normal file
211
backend/main.py
Normal file
|
|
@ -0,0 +1,211 @@
|
|||
"""PromptLooper FastAPI application."""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from config import settings
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Database engine & session factory (lazy, created at startup)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
engine = None
|
||||
SessionLocal = None
|
||||
|
||||
|
||||
def _init_db() -> None:
|
||||
"""Create the SQLAlchemy engine and session factory."""
|
||||
global engine, SessionLocal
|
||||
connect_args = {}
|
||||
if settings.is_sqlite:
|
||||
connect_args["check_same_thread"] = False
|
||||
engine = create_engine(
|
||||
settings.effective_database_url,
|
||||
connect_args=connect_args,
|
||||
)
|
||||
SessionLocal = sessionmaker(bind=engine, autoflush=False, expire_on_commit=False)
|
||||
|
||||
|
||||
def get_db():
|
||||
"""FastAPI dependency that yields a database session."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Redis helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_redis_client = None
|
||||
|
||||
|
||||
def _init_redis() -> None:
|
||||
"""Connect to Redis if configured."""
|
||||
global _redis_client
|
||||
if not settings.redis_url:
|
||||
_redis_client = None
|
||||
return
|
||||
import redis as redis_lib
|
||||
_redis_client = redis_lib.Redis.from_url(settings.redis_url, decode_responses=True)
|
||||
|
||||
|
||||
def get_redis():
|
||||
"""Return the Redis client (or None in single-container mode)."""
|
||||
return _redis_client
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WebSocket connection manager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ConnectionManager:
|
||||
"""Manage active WebSocket connections."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.active_connections: list[WebSocket] = []
|
||||
|
||||
async def connect(self, websocket: WebSocket) -> None:
|
||||
await websocket.accept()
|
||||
self.active_connections.append(websocket)
|
||||
|
||||
def disconnect(self, websocket: WebSocket) -> None:
|
||||
self.active_connections.remove(websocket)
|
||||
|
||||
async def broadcast(self, message: dict) -> None:
|
||||
for connection in list(self.active_connections):
|
||||
try:
|
||||
await connection.send_json(message)
|
||||
except Exception:
|
||||
self.disconnect(connection)
|
||||
|
||||
|
||||
ws_manager = ConnectionManager()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
"""Startup and shutdown lifecycle hooks."""
|
||||
_init_db()
|
||||
_init_redis()
|
||||
yield
|
||||
# Shutdown: clean up connections
|
||||
if _redis_client is not None:
|
||||
_redis_client.close()
|
||||
if engine is not None:
|
||||
engine.dispose()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Application
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
app = FastAPI(
|
||||
title="PromptLooper",
|
||||
description="LLM pipeline tuning workbench",
|
||||
version="0.1.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# CORS — allow all origins in development; tighten in production via env
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Health endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.get("/health", tags=["system"])
|
||||
def health_check() -> dict:
|
||||
"""Check DB and Redis connectivity."""
|
||||
db_ok = False
|
||||
redis_ok = False
|
||||
|
||||
# Database check
|
||||
if SessionLocal is not None:
|
||||
try:
|
||||
with SessionLocal() as session:
|
||||
session.execute(text("SELECT 1"))
|
||||
db_ok = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Redis check
|
||||
if not settings.redis_url:
|
||||
redis_ok = True # No Redis needed — in-process mode
|
||||
elif _redis_client is not None:
|
||||
try:
|
||||
_redis_client.ping()
|
||||
redis_ok = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {"status": "ok" if (db_ok and redis_ok) else "degraded", "database": db_ok, "redis": redis_ok}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WebSocket endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket) -> None:
|
||||
"""WebSocket connection for real-time dashboard updates."""
|
||||
await ws_manager.connect(websocket)
|
||||
try:
|
||||
while True:
|
||||
# Keep connection alive; handle incoming messages if needed
|
||||
data = await websocket.receive_json()
|
||||
# Echo back or handle client messages in future
|
||||
await websocket.send_json({"type": "ack", "data": data})
|
||||
except WebSocketDisconnect:
|
||||
ws_manager.disconnect(websocket)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mount routers (stubs — actual implementations come later)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Router imports are deferred to avoid circular imports and allow
|
||||
# stub files to be created independently. Each router will be mounted
|
||||
# as it is implemented. For now we register empty prefixes.
|
||||
|
||||
def _mount_routers() -> None:
|
||||
"""Import and mount all routers. Silently skip missing ones."""
|
||||
router_configs = [
|
||||
("routers.auth", "/api/auth", ["auth"]),
|
||||
("routers.projects", "/api/projects", ["projects"]),
|
||||
("routers.experiments", "/api/experiments", ["experiments"]),
|
||||
("routers.runs", "/api/runs", ["runs"]),
|
||||
("routers.endpoints", "/api/endpoints", ["endpoints"]),
|
||||
("routers.export", "/api/export", ["export"]),
|
||||
("routers.webhooks", "/api/webhooks", ["webhooks"]),
|
||||
("routers.admin", "/api/admin", ["admin"]),
|
||||
]
|
||||
for module_name, prefix, tags in router_configs:
|
||||
try:
|
||||
import importlib
|
||||
mod = importlib.import_module(module_name)
|
||||
app.include_router(mod.router, prefix=prefix, tags=tags)
|
||||
except (ImportError, AttributeError):
|
||||
pass # Router not yet implemented
|
||||
|
||||
|
||||
_mount_routers()
|
||||
129
backend/tests/test_main.py
Normal file
129
backend/tests/test_main.py
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
"""Tests for backend/main.py — FastAPI application."""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolate_settings(tmp_path):
|
||||
"""Ensure tests use a temp SQLite DB and no Redis."""
|
||||
env = {
|
||||
"DATABASE_URL": f"sqlite:///{tmp_path / 'test.db'}",
|
||||
"REDIS_URL": "",
|
||||
"DATA_DIR": str(tmp_path),
|
||||
}
|
||||
with patch.dict(os.environ, env, clear=False):
|
||||
# Reload settings so it picks up test env
|
||||
import config
|
||||
new_settings = config.Settings(_env_file=None)
|
||||
config.settings = new_settings
|
||||
|
||||
# Patch main's reference too
|
||||
import main
|
||||
main.settings = new_settings
|
||||
main._init_db()
|
||||
main._init_redis()
|
||||
|
||||
# Create tables
|
||||
from models import Base
|
||||
Base.metadata.create_all(bind=main.engine)
|
||||
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
from main import app
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
class TestHealthEndpoint:
|
||||
def test_health_returns_ok(self, client):
|
||||
resp = client.get("/health")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["database"] is True
|
||||
assert data["redis"] is True # in-process mode counts as ok
|
||||
|
||||
def test_health_response_schema(self, client):
|
||||
resp = client.get("/health")
|
||||
data = resp.json()
|
||||
assert set(data.keys()) == {"status", "database", "redis"}
|
||||
|
||||
|
||||
class TestCORSMiddleware:
|
||||
def test_cors_headers_present(self, client):
|
||||
resp = client.options(
|
||||
"/health",
|
||||
headers={
|
||||
"Origin": "http://localhost:3000",
|
||||
"Access-Control-Request-Method": "GET",
|
||||
},
|
||||
)
|
||||
assert "access-control-allow-origin" in resp.headers
|
||||
|
||||
|
||||
class TestWebSocket:
|
||||
def test_websocket_connect_and_echo(self, client):
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
ws.send_json({"type": "ping"})
|
||||
data = ws.receive_json()
|
||||
assert data["type"] == "ack"
|
||||
assert data["data"]["type"] == "ping"
|
||||
|
||||
def test_websocket_disconnect_cleanup(self, client):
|
||||
from main import ws_manager
|
||||
initial_count = len(ws_manager.active_connections)
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
assert len(ws_manager.active_connections) == initial_count + 1
|
||||
# After disconnect, connection should be removed
|
||||
assert len(ws_manager.active_connections) == initial_count
|
||||
|
||||
|
||||
class TestRouterMounting:
|
||||
def test_openapi_schema_loads(self, client):
|
||||
resp = client.get("/openapi.json")
|
||||
assert resp.status_code == 200
|
||||
schema = resp.json()
|
||||
assert schema["info"]["title"] == "PromptLooper"
|
||||
|
||||
def test_unknown_route_returns_404(self, client):
|
||||
resp = client.get("/api/nonexistent")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
class TestConnectionManager:
|
||||
def test_broadcast_removes_dead_connections(self):
|
||||
"""ConnectionManager.broadcast skips and removes broken connections."""
|
||||
from main import ConnectionManager
|
||||
manager = ConnectionManager()
|
||||
# No connections — broadcast should not raise
|
||||
import asyncio
|
||||
asyncio.get_event_loop().run_until_complete(
|
||||
manager.broadcast({"test": True})
|
||||
)
|
||||
assert len(manager.active_connections) == 0
|
||||
|
||||
|
||||
class TestGetDb:
|
||||
def test_get_db_yields_session(self):
|
||||
from main import get_db
|
||||
gen = get_db()
|
||||
session = next(gen)
|
||||
assert session is not None
|
||||
# Clean up
|
||||
try:
|
||||
next(gen)
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
|
||||
class TestGetRedis:
|
||||
def test_get_redis_returns_none_in_process_mode(self):
|
||||
from main import get_redis
|
||||
# In test setup, Redis is not configured
|
||||
assert get_redis() is None
|
||||
Loading…
Add table
Reference in a new issue