promptlooper/backend/main.py
John Lightner 30fd15ec7a MAESTRO: Implement WebSocket connection manager with per-experiment routing, Redis pub/sub bridge, and message replay
- WebSocketManager in backend/websocket/manager.py with per-experiment and global subscriptions
- Redis pub/sub bridge (sync + async) broadcasting events to relevant WebSocket clients
- Deque-based replay buffers with since_ts/limit filtering for reconnection support
- Runtime subscribe/unsubscribe and stats API
- Enhanced /ws endpoint in main.py with subscribe/unsubscribe/replay actions
- 35 tests in test_ws_manager.py, all passing
2026-04-07 03:34:21 -05:00

218 lines
7.4 KiB
Python

"""PromptLooper FastAPI application."""
import importlib
import logging
from contextlib import AsyncExitStack, asynccontextmanager
from typing import AsyncGenerator

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

from config import settings
# ---------------------------------------------------------------------------
# Database engine & session factory. Both stay None until _init_db() runs
# (lifespan() calls it at application startup).
# ---------------------------------------------------------------------------
engine = SessionLocal = None
def _init_db() -> None:
    """Build the module-wide SQLAlchemy engine and its session factory.

    Rebinds the module globals ``engine`` and ``SessionLocal``. For SQLite
    URLs the driver option ``check_same_thread=False`` is passed, which lets
    a connection be used from a thread other than the one that opened it.
    Sessions are created with ``autoflush=False`` and
    ``expire_on_commit=False``.
    """
    global engine, SessionLocal
    sqlite_opts = {"check_same_thread": False} if settings.is_sqlite else {}
    engine = create_engine(settings.effective_database_url, connect_args=sqlite_opts)
    SessionLocal = sessionmaker(
        bind=engine,
        autoflush=False,
        expire_on_commit=False,
    )
def get_db():
    """FastAPI dependency: yield one database session, then close it.

    Requires ``_init_db()`` to have run so that ``SessionLocal`` is set.
    The session's context manager closes it on every exit path, including
    when an exception is thrown back into the dependency.
    """
    with SessionLocal() as session:
        yield session
# ---------------------------------------------------------------------------
# Redis helper
# ---------------------------------------------------------------------------
# Shared Redis client. Set by _init_redis() at startup; stays None when no
# Redis URL is configured (in-process / single-container mode).
_redis_client = None
def _init_redis() -> None:
    """Create the module-wide Redis client, or clear it when Redis is unused.

    Without ``settings.redis_url`` the app runs with no Redis, and
    ``_redis_client`` is reset to None. The ``redis`` package is imported
    only when a URL is configured, so it is not needed otherwise. Responses
    are decoded to ``str`` (``decode_responses=True``).
    """
    global _redis_client
    url = settings.redis_url
    if url:
        import redis as redis_lib
        _redis_client = redis_lib.Redis.from_url(url, decode_responses=True)
    else:
        _redis_client = None
def get_redis():
    """Hand back the shared Redis client created by ``_init_redis()``.

    The result is None in single-container mode (no Redis URL configured)
    and also before application startup has run.
    """
    return _redis_client
# ---------------------------------------------------------------------------
# WebSocket connection manager
# ---------------------------------------------------------------------------
# Imported here, next to its instantiation, rather than with the
# top-of-file imports.
from websocket.manager import WebSocketManager
# Module-level manager shared by the /ws endpoint and by lifespan(), which
# starts and stops its Redis listener.
ws_manager = WebSocketManager()
# ---------------------------------------------------------------------------
# Lifecycle
# ---------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Startup and shutdown lifecycle hooks.

    Startup: create the DB engine and session factory, connect Redis (if
    configured), and start the WebSocket <-> Redis bridge.

    Shutdown: stop the bridge, close the Redis client, and dispose the
    engine, in that order.

    Each teardown step is registered on an AsyncExitStack as soon as its
    resource exists. This means:
      * teardown runs in reverse order of setup;
      * it also runs when an exception is thrown into the lifespan at the
        ``yield`` (e.g. server cancellation);
      * later steps still run if an earlier one raises;
      * a failure part-way through startup still releases whatever was
        already created.
    """
    async with AsyncExitStack() as stack:
        _init_db()
        if engine is not None:
            stack.callback(engine.dispose)

        _init_redis()
        if _redis_client is not None:
            stack.callback(_redis_client.close)

        # Start WebSocket <-> Redis bridge. The client may be None when no
        # Redis URL is configured; it is handed to the manager as-is.
        await ws_manager.start_redis_listener(_redis_client)
        stack.push_async_callback(ws_manager.stop_redis_listener)

        yield
# ---------------------------------------------------------------------------
# Application
# ---------------------------------------------------------------------------
app = FastAPI(
    lifespan=lifespan,
    title="PromptLooper",
    version="0.1.0",
    description="LLM pipeline tuning workbench",
)

# CORS: every origin, method and header is accepted, and credentials are
# allowed.
# NOTE(review): origins are hard-coded to "*" here. Nothing in this block
# reads an env var or setting to restrict them in production, so confirm
# whether that happens elsewhere. Starlette's CORSMiddleware docs also say
# wildcard origins/methods/headers should not be combined with
# allow_credentials=True; verify the intended policy.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# ---------------------------------------------------------------------------
# Health endpoint
# ---------------------------------------------------------------------------
@app.get("/health", tags=["system"])
def health_check() -> dict:
    """Check DB and Redis connectivity.

    Probes the database with ``SELECT 1`` and Redis with ``PING``.

    Returns:
        ``{"status": "ok" | "degraded", "database": bool, "redis": bool}``.
        ``status`` is "ok" only when both probes pass. With no Redis URL
        configured the app runs in-process and Redis is reported healthy.

    Probe errors are logged (with traceback) instead of being raised, so the
    endpoint always answers.
    """
    logger = logging.getLogger(__name__)

    # Database check. SessionLocal stays None until startup has run _init_db().
    db_ok = False
    if SessionLocal is not None:
        try:
            with SessionLocal() as session:
                session.execute(text("SELECT 1"))
            db_ok = True
        except Exception:
            # Best-effort probe: report "degraded" but keep the cause visible.
            logger.warning("Health check: database probe failed", exc_info=True)

    # Redis check
    redis_ok = False
    if not settings.redis_url:
        redis_ok = True  # No Redis needed — in-process mode
    elif _redis_client is not None:
        try:
            _redis_client.ping()
            redis_ok = True
        except Exception:
            logger.warning("Health check: Redis ping failed", exc_info=True)

    return {"status": "ok" if (db_ok and redis_ok) else "degraded", "database": db_ok, "redis": redis_ok}
# ---------------------------------------------------------------------------
# WebSocket endpoint
# ---------------------------------------------------------------------------
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """WebSocket connection for real-time dashboard updates.

    Initial subscriptions come from the comma-separated ``experiment_ids``
    query parameter. When it is absent or empty, ``experiment_ids=None`` is
    passed to the manager.

    Clients may send JSON messages to control their subscriptions:
    - {"action": "subscribe", "experiment_id": "..."}
    - {"action": "unsubscribe", "experiment_id": "..."}
    - {"action": "replay", "experiment_id": "...", "since_ts": ..., "limit": ...}

    Any other JSON value (including non-object values such as lists) is
    echoed back as {"type": "ack", "data": ...}. A frame that is not valid
    JSON is answered with {"type": "error", ...} and the connection stays
    open. ``ws_manager.disconnect()`` is called on every exit path, not only
    on WebSocketDisconnect.
    """
    # Parse query params for initial experiment subscriptions
    exp_ids_param = websocket.query_params.get("experiment_ids", "")
    experiment_ids = [e.strip() for e in exp_ids_param.split(",") if e.strip()] or None
    await ws_manager.connect(websocket, experiment_ids=experiment_ids)
    try:
        while True:
            try:
                data = await websocket.receive_json()
            except ValueError:
                # json.JSONDecodeError is a ValueError. Reject this frame but
                # keep the connection alive instead of tearing it down.
                await websocket.send_json({"type": "error", "detail": "Invalid JSON"})
                continue

            # Only JSON objects carry an "action". Other values have no .get(),
            # so they fall through to the echo branch.
            action = data.get("action", "") if isinstance(data, dict) else ""

            if action == "subscribe" and data.get("experiment_id"):
                ws_manager.subscribe(websocket, data["experiment_id"])
                await websocket.send_json({"type": "ack", "action": "subscribe", "experiment_id": data["experiment_id"]})
            elif action == "unsubscribe" and data.get("experiment_id"):
                ws_manager.unsubscribe(websocket, data["experiment_id"])
                await websocket.send_json({"type": "ack", "action": "unsubscribe", "experiment_id": data["experiment_id"]})
            elif action == "replay":
                # NOTE(review): since_ts / limit are passed through unvalidated;
                # confirm ws_manager.replay tolerates wrong types from clients.
                count = await ws_manager.replay(
                    websocket,
                    experiment_id=data.get("experiment_id"),
                    since_ts=data.get("since_ts"),
                    limit=data.get("limit"),
                )
                await websocket.send_json({"type": "ack", "action": "replay", "events_replayed": count})
            else:
                await websocket.send_json({"type": "ack", "data": data})
    except WebSocketDisconnect:
        pass  # Normal client-initiated close; cleanup happens in `finally`.
    finally:
        # Also runs when an unexpected error escapes the loop, so the manager
        # always gets to drop this socket.
        ws_manager.disconnect(websocket)
# ---------------------------------------------------------------------------
# Mount routers
# ---------------------------------------------------------------------------
# Router modules are imported by name inside _mount_routers() rather than at
# the top of this file, to avoid circular imports and so router files can be
# added independently. A router that does not exist yet, or whose module
# defines no ``router`` object, is simply not mounted.
def _mount_routers() -> None:
    """Import and mount every known API router, skipping unavailable ones.

    For each ``(module, prefix, tags)`` entry, the module is imported and its
    ``router`` attribute is included on ``app`` under ``prefix``.

    Skipped with only a debug log:
      * the router module, or one of its parent packages, does not exist;
      * the module exists but defines no ``router`` (a stub).

    Skipped with a warning and traceback: any other ImportError or
    AttributeError while importing an existing module (e.g. a router whose
    own dependencies fail to import), or an AttributeError while mounting
    it. This keeps startup best-effort without hiding broken routers.
    """
    logger = logging.getLogger(__name__)
    router_configs = [
        ("routers.auth", "/api/auth", ["auth"]),
        ("routers.projects", "/api/projects", ["projects"]),
        ("routers.experiments", "/api/experiments", ["experiments"]),
        ("routers.runs", "/api/runs", ["runs"]),
        ("routers.endpoints", "/api/endpoints", ["endpoints"]),
        ("routers.export", "/api/export", ["export"]),
        ("routers.webhooks", "/api/webhooks", ["webhooks"]),
        ("routers.admin", "/api/admin", ["admin"]),
    ]
    for module_name, prefix, tags in router_configs:
        try:
            mod = importlib.import_module(module_name)
        except ModuleNotFoundError as exc:
            # The error names the module it could not find. When that is the
            # router module itself or one of its parent packages, the router
            # simply hasn't been written yet. Any other name means an existing
            # router has a missing dependency.
            if exc.name and f"{module_name}.".startswith(f"{exc.name}."):
                logger.debug("Router %s not implemented yet; skipping", module_name)
            else:
                logger.warning("Router %s failed to import; skipping", module_name, exc_info=True)
            continue
        except (ImportError, AttributeError):
            logger.warning("Router %s failed to import; skipping", module_name, exc_info=True)
            continue

        router = getattr(mod, "router", None)
        if router is None:
            logger.debug("Router module %s defines no `router`; skipping", module_name)
            continue
        try:
            app.include_router(router, prefix=prefix, tags=tags)
        except AttributeError:
            # Best-effort: a malformed `router` object must not stop startup,
            # but it is reported.
            logger.warning("Router %s could not be mounted; skipping", module_name, exc_info=True)


_mount_routers()