promptlooper/backend/main.py
John Lightner 15ca2c922a MAESTRO: Create backend/main.py with FastAPI app, CORS, health check, WebSocket, and router mounting
FastAPI application with:
- CORS middleware (permissive for dev)
- /health endpoint checking DB and Redis connectivity
- /ws WebSocket endpoint with ConnectionManager for real-time updates
- Async lifespan hooks for DB engine and Redis init/teardown
- get_db dependency for session management
- Dynamic router mounting that silently skips missing router modules
- 10 tests covering all endpoints and utilities
2026-04-07 01:56:40 -05:00

211 lines
6.5 KiB
Python

"""PromptLooper FastAPI application."""
import importlib
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

from config import settings
# ---------------------------------------------------------------------------
# Database engine & session factory (lazy, created at startup)
# ---------------------------------------------------------------------------
engine = None
SessionLocal = None


def _init_db() -> None:
    """Build the module-wide SQLAlchemy engine and its session factory.

    Rebinds the module globals ``engine`` and ``SessionLocal``. The engine
    is created from ``settings.effective_database_url``; when
    ``settings.is_sqlite`` is true, ``check_same_thread=False`` is passed
    to the driver through ``connect_args``.
    """
    global engine, SessionLocal

    # Only SQLite gets an extra driver option; other backends get none.
    driver_options = {"check_same_thread": False} if settings.is_sqlite else {}

    engine = create_engine(
        settings.effective_database_url, connect_args=driver_options
    )
    SessionLocal = sessionmaker(
        bind=engine,
        autoflush=False,
        expire_on_commit=False,
    )
def get_db():
    """FastAPI dependency that yields a database session.

    Yields:
        A session created by the module-level ``SessionLocal`` factory.
        The session is closed when the generator is finalised, whether
        the consumer finished normally or raised.

    Raises:
        RuntimeError: If ``SessionLocal`` is still ``None``, i.e. the
            session factory has not been created by ``_init_db`` yet.
    """
    # Fail with an explicit message instead of "'NoneType' is not callable".
    if SessionLocal is None:
        raise RuntimeError(
            "Database session factory is not initialised; "
            "_init_db() must run before get_db() is used."
        )

    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
# ---------------------------------------------------------------------------
# Redis helper
# ---------------------------------------------------------------------------
_redis_client = None


def _init_redis() -> None:
    """Set the module-level Redis client from ``settings.redis_url``.

    With a falsy ``settings.redis_url`` the client is reset to ``None``.
    Otherwise a client is built with ``Redis.from_url`` using
    ``decode_responses=True``.
    """
    global _redis_client

    if settings.redis_url:
        # Deferred import: the ``redis`` package is only imported on this branch.
        import redis as redis_lib

        _redis_client = redis_lib.Redis.from_url(
            settings.redis_url, decode_responses=True
        )
    else:
        _redis_client = None
def get_redis():
    """Expose the shared Redis client.

    Returns the module-level ``_redis_client`` unchanged, which is ``None``
    when no client has been set (single-container mode).
    """
    return _redis_client
# ---------------------------------------------------------------------------
# WebSocket connection manager
# ---------------------------------------------------------------------------
class ConnectionManager:
    """Track accepted WebSocket connections and fan messages out to them."""

    def __init__(self) -> None:
        # Connections in the order they were accepted.
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket) -> None:
        """Accept the WebSocket handshake and start tracking the connection."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket) -> None:
        """Stop tracking ``websocket``.

        Safe to call more than once: a connection that is not tracked
        (for example because a failed broadcast already dropped it) is
        ignored instead of raising ``ValueError``.
        """
        try:
            self.active_connections.remove(websocket)
        except ValueError:
            # Already removed; nothing left to clean up.
            pass

    async def broadcast(self, message: dict) -> None:
        """Send ``message`` as JSON to every tracked connection.

        A connection whose send raises is dropped from the manager; the
        remaining connections still receive the message.
        """
        # Iterate over a snapshot because disconnect() mutates the list.
        for connection in list(self.active_connections):
            try:
                await connection.send_json(message)
            except Exception:
                self.disconnect(connection)


ws_manager = ConnectionManager()
# ---------------------------------------------------------------------------
# Lifecycle
# ---------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Startup and shutdown lifecycle hooks.

    Startup calls ``_init_db()`` and then ``_init_redis()``. Teardown
    closes the Redis client (if one was set) and disposes the engine (if
    one was created).

    Teardown sits in ``finally`` blocks so it also runs when startup fails
    part-way or when an exception is raised into the context while the
    application is running. The engine is disposed even if closing the
    Redis client raises.
    """
    try:
        _init_db()
        _init_redis()
        yield
    finally:
        try:
            if _redis_client is not None:
                _redis_client.close()
        finally:
            if engine is not None:
                engine.dispose()
# ---------------------------------------------------------------------------
# Application
# ---------------------------------------------------------------------------
app = FastAPI(
    lifespan=lifespan,
    title="PromptLooper",
    version="0.1.0",
    description="LLM pipeline tuning workbench",
)

# CORS is fully permissive here: any origin, method and header, with
# credentials allowed. Intended for development.
# NOTE(review): the earlier comment said to tighten this in production via
# env, but nothing in this block reads an origin list from configuration;
# confirm where that happens.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# ---------------------------------------------------------------------------
# Health endpoint
# ---------------------------------------------------------------------------
@app.get("/health", tags=["system"])
def health_check() -> dict:
    """Check DB and Redis connectivity.

    Returns:
        A dict with three keys:

        * ``"database"``: ``True`` when ``SELECT 1`` ran on a fresh session.
        * ``"redis"``: ``True`` when no Redis URL is configured, or when
          the client answered ``ping()``.
        * ``"status"``: ``"ok"`` if both flags are true, else ``"degraded"``.

    Probe failures never propagate; they are logged with their traceback
    so a degraded status can be diagnosed.
    """
    log = logging.getLogger(__name__)

    # Database check: stays False if the session factory does not exist yet.
    db_ok = False
    if SessionLocal is not None:
        try:
            with SessionLocal() as session:
                session.execute(text("SELECT 1"))
        except Exception:
            log.warning("Health check: database probe failed", exc_info=True)
        else:
            db_ok = True

    # Redis check.
    if not settings.redis_url:
        redis_ok = True  # No Redis needed: in-process mode
    elif _redis_client is None:
        redis_ok = False  # Configured, but no client has been created
    else:
        try:
            _redis_client.ping()
        except Exception:
            log.warning("Health check: Redis probe failed", exc_info=True)
            redis_ok = False
        else:
            redis_ok = True

    return {
        "status": "ok" if (db_ok and redis_ok) else "degraded",
        "database": db_ok,
        "redis": redis_ok,
    }
# ---------------------------------------------------------------------------
# WebSocket endpoint
# ---------------------------------------------------------------------------
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """WebSocket connection for real-time dashboard updates.

    Registers the socket with ``ws_manager``, then acknowledges every JSON
    message with ``{"type": "ack", "data": <message>}`` until the client
    disconnects.

    The socket is removed from ``ws_manager`` however the loop ends, so an
    unexpected error (for example a frame that is not valid JSON) cannot
    leave a dead connection registered. Errors other than
    ``WebSocketDisconnect`` still propagate after cleanup.
    """
    await ws_manager.connect(websocket)
    try:
        while True:
            # Keep the connection alive; real message handling can go here.
            data = await websocket.receive_json()
            await websocket.send_json({"type": "ack", "data": data})
    except WebSocketDisconnect:
        # Normal client departure; cleanup happens in ``finally``.
        pass
    finally:
        # A failed broadcast may already have removed this socket.
        if websocket in ws_manager.active_connections:
            ws_manager.disconnect(websocket)
# ---------------------------------------------------------------------------
# Mount routers (stubs — actual implementations come later)
# ---------------------------------------------------------------------------
# Router imports are deferred to avoid circular imports and allow
# stub files to be created independently. Each router is mounted
# once its module exists; routers that are not implemented yet are skipped.
def _mount_routers() -> None:
    """Import and mount all routers, skipping the ones that are missing.

    Each entry names a module, the URL prefix to mount it under, and its
    OpenAPI tags. Mounting stays best-effort and never raises for a router
    that cannot be imported, but the outcome is logged:

    * module (or its parent package) not found: debug, treated as not
      implemented yet;
    * module found but failing to import: warning with traceback;
    * module imported but exposing no ``router``: debug, skipped.
    """
    log = logging.getLogger(__name__)

    router_configs = [
        ("routers.auth", "/api/auth", ["auth"]),
        ("routers.projects", "/api/projects", ["projects"]),
        ("routers.experiments", "/api/experiments", ["experiments"]),
        ("routers.runs", "/api/runs", ["runs"]),
        ("routers.endpoints", "/api/endpoints", ["endpoints"]),
        ("routers.export", "/api/export", ["export"]),
        ("routers.webhooks", "/api/webhooks", ["webhooks"]),
        ("routers.admin", "/api/admin", ["admin"]),
    ]

    for module_name, prefix, tags in router_configs:
        try:
            mod = importlib.import_module(module_name)
        except ModuleNotFoundError as exc:
            # ``exc.name`` is the module that could not be found. Only the
            # router itself or its parent package counts as "not yet
            # implemented"; anything else is a missing dependency.
            missing = exc.name
            is_stub = missing is not None and (
                missing == module_name
                or module_name.startswith(f"{missing}.")
            )
            if is_stub:
                log.debug("Router %s not implemented yet; skipping", module_name)
            else:
                log.warning(
                    "Router %s could not be imported", module_name, exc_info=True
                )
            continue
        except (ImportError, AttributeError):
            log.warning(
                "Router %s could not be imported", module_name, exc_info=True
            )
            continue

        router = getattr(mod, "router", None)
        if router is None:
            log.debug("Module %s exposes no router; skipping", module_name)
            continue

        app.include_router(router, prefix=prefix, tags=tags)


_mount_routers()