- WebSocketManager in backend/websocket/manager.py with per-experiment and global subscriptions - Redis pub/sub bridge (sync + async) broadcasting events to relevant WebSocket clients - Deque-based replay buffers with since_ts/limit filtering for reconnection support - Runtime subscribe/unsubscribe and stats API - Enhanced /ws endpoint in main.py with subscribe/unsubscribe/replay actions - 35 tests in test_ws_manager.py, all passing
218 lines
7.4 KiB
Python
218 lines
7.4 KiB
Python
"""PromptLooper FastAPI application."""
|
|
|
|
from contextlib import asynccontextmanager
|
|
from typing import AsyncGenerator
|
|
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from sqlalchemy import create_engine, text
|
|
from sqlalchemy.orm import sessionmaker
|
|
|
|
from config import settings
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Database engine & session factory (lazy, created at startup)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# SQLAlchemy engine; stays None until _init_db() assigns it.
engine = None
# Session factory bound to `engine`; stays None until _init_db() assigns it.
SessionLocal = None
|
|
|
|
|
|
def _init_db() -> None:
    """Build the module-level SQLAlchemy engine and session factory.

    Rebinds the globals ``engine`` and ``SessionLocal``. The database URL
    comes from ``settings.effective_database_url``.
    """
    global engine, SessionLocal

    # For SQLite the driver's same-thread check is switched off
    # (presumably because sessions are used from several threads — confirm).
    connect_args = {"check_same_thread": False} if settings.is_sqlite else {}

    engine = create_engine(settings.effective_database_url, connect_args=connect_args)
    SessionLocal = sessionmaker(bind=engine, autoflush=False, expire_on_commit=False)
|
|
|
|
|
|
def get_db():
    """Yield one database session for use as a FastAPI dependency.

    A new session is created from ``SessionLocal`` and is closed when the
    generator is finalised, whether or not the consumer raised.
    """
    # Leaving the ``with`` block closes the session, exactly like an
    # explicit try/finally around ``close()``.
    with SessionLocal() as session:
        yield session
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Redis helper
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Shared Redis client; None until _init_redis() runs, and also None when no Redis URL is configured.
_redis_client = None
|
|
|
|
|
|
def _init_redis() -> None:
    """Create the module-level Redis client when a Redis URL is configured.

    Rebinds the global ``_redis_client``: a client built from
    ``settings.redis_url`` (responses decoded to ``str``) when the URL is
    set, otherwise ``None``.
    """
    global _redis_client

    if settings.redis_url:
        # Imported lazily so the package is only needed when Redis is in use.
        import redis as redis_lib

        _redis_client = redis_lib.Redis.from_url(settings.redis_url, decode_responses=True)
    else:
        _redis_client = None
|
|
|
|
|
|
def get_redis():
    """Return the module-level Redis client.

    The value is ``None`` when no Redis URL is configured (single-container
    mode) or when ``_init_redis()`` has not run yet.
    """
    return _redis_client
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# WebSocket connection manager
|
|
# ---------------------------------------------------------------------------
|
|
|
|
from websocket.manager import WebSocketManager
|
|
|
|
# Single shared manager instance used by the lifespan hooks and the /ws endpoint.
ws_manager = WebSocketManager()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Lifecycle
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Run startup before the application serves, and shutdown afterwards.

    Startup initialises the database engine, the Redis client, and the
    WebSocket <-> Redis bridge. Shutdown stops the bridge, closes the Redis
    client and disposes the engine.

    Cleanup is wrapped in ``try``/``finally`` so that resources created so
    far are released even when a later startup step fails or an exception
    (including cancellation) is delivered at the ``yield``.

    Args:
        app: The FastAPI application; required by the lifespan protocol but
            not used here.
    """
    try:
        _init_db()
        _init_redis()
        # Start WebSocket <-> Redis bridge
        await ws_manager.start_redis_listener(_redis_client)
        try:
            yield
        finally:
            # Only stop the listener if it was started successfully.
            await ws_manager.stop_redis_listener()
    finally:
        # Shutdown: release whatever startup managed to create.
        if _redis_client is not None:
            _redis_client.close()
        if engine is not None:
            engine.dispose()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Application
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# The ASGI application object; startup and shutdown are delegated to `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    title="PromptLooper",
    version="0.1.0",
    description="LLM pipeline tuning workbench",
)
|
|
|
|
# CORS: every origin, method and header is allowed, with credentials enabled.
# NOTE(review): nothing here reads the environment, so this permissive policy
# also applies in production as written — confirm whether wildcard origins
# together with credentials is intended before deploying.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Health endpoint
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@app.get("/health", tags=["system"])
def health_check() -> dict:
    """Report database and Redis connectivity.

    Returns:
        A dict with three keys: ``"status"`` (``"ok"`` when both probes
        pass, otherwise ``"degraded"``), ``"database"`` and ``"redis"``
        (booleans). Redis counts as healthy when no Redis URL is configured.

    Probe failures never raise. They are logged at warning level so an
    operator can see why the service is degraded.
    """
    import logging

    log = logging.getLogger(__name__)

    db_ok = False
    redis_ok = False

    # Database check: skipped (and reported False) before _init_db() has run.
    if SessionLocal is not None:
        try:
            with SessionLocal() as session:
                session.execute(text("SELECT 1"))
                db_ok = True
        except Exception:
            # Top-level probe boundary: report "degraded" instead of failing.
            log.warning("Health check: database probe failed", exc_info=True)

    # Redis check
    if not settings.redis_url:
        redis_ok = True  # No Redis needed — in-process mode
    elif _redis_client is not None:
        try:
            _redis_client.ping()
            redis_ok = True
        except Exception:
            log.warning("Health check: Redis probe failed", exc_info=True)

    healthy = db_ok and redis_ok
    return {"status": "ok" if healthy else "degraded", "database": db_ok, "redis": redis_ok}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# WebSocket endpoint
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """WebSocket connection for real-time dashboard updates.

    Initial subscriptions are taken from the optional, comma-separated
    ``experiment_ids`` query parameter; when it is absent or empty, ``None``
    is passed to the manager.

    Clients may send JSON messages to control their subscriptions:
    - {"action": "subscribe", "experiment_id": "..."}
    - {"action": "unsubscribe", "experiment_id": "..."}
    - {"action": "replay", "experiment_id": "...", "since_ts": ..., "limit": ...}

    Every message is answered with an ``ack`` frame. Anything that is not
    one of the actions above (including JSON payloads that are not objects)
    is echoed back as ``{"type": "ack", "data": <payload>}``.

    The socket is always unregistered from the manager when this handler
    exits, whether through a normal disconnect, malformed input, a send
    failure, or task cancellation.
    """
    # Parse query params for initial experiment subscriptions
    exp_ids_param = websocket.query_params.get("experiment_ids", "")
    experiment_ids = [e.strip() for e in exp_ids_param.split(",") if e.strip()] or None
    await ws_manager.connect(websocket, experiment_ids=experiment_ids)
    try:
        while True:
            data = await websocket.receive_json()
            # A JSON array or scalar has no .get(); treat it as an
            # unrecognised message rather than crashing the handler.
            fields = data if isinstance(data, dict) else {}
            action = fields.get("action", "")
            experiment_id = fields.get("experiment_id")
            if action == "subscribe" and experiment_id:
                ws_manager.subscribe(websocket, experiment_id)
                await websocket.send_json({"type": "ack", "action": "subscribe", "experiment_id": experiment_id})
            elif action == "unsubscribe" and experiment_id:
                ws_manager.unsubscribe(websocket, experiment_id)
                await websocket.send_json({"type": "ack", "action": "unsubscribe", "experiment_id": experiment_id})
            elif action == "replay":
                count = await ws_manager.replay(
                    websocket,
                    experiment_id=experiment_id,
                    since_ts=fields.get("since_ts"),
                    limit=fields.get("limit"),
                )
                await websocket.send_json({"type": "ack", "action": "replay", "events_replayed": count})
            else:
                await websocket.send_json({"type": "ack", "data": data})
    except WebSocketDisconnect:
        # Normal client disconnect; cleanup happens in ``finally``.
        pass
    finally:
        # Runs exactly once after a successful connect(), on every exit path,
        # so the manager never keeps a reference to a dead socket.
        ws_manager.disconnect(websocket)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Mount routers (stubs — actual implementations come later)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Router imports are deferred to avoid circular imports and to allow
# stub files to be created independently. Each router is mounted once
# it is implemented; routers that cannot be imported yet are skipped.
|
|
|
|
def _mount_routers() -> None:
    """Import and mount every known router on ``app``, skipping missing ones.

    Each entry names a module expected to expose a ``router`` attribute,
    the URL prefix to mount it under, and its OpenAPI tags. A router that
    cannot be imported or mounted is skipped and never raises:

    - If the router module (or its parent package) does not exist yet, the
      skip is logged at debug level.
    - Any other ``ImportError`` or ``AttributeError`` (for example a broken
      import inside an existing router, or a module without ``router``) is
      logged at warning level with a traceback, so that real defects are
      not hidden.
    """
    import importlib
    import logging

    log = logging.getLogger(__name__)

    router_configs = [
        ("routers.auth", "/api/auth", ["auth"]),
        ("routers.projects", "/api/projects", ["projects"]),
        ("routers.experiments", "/api/experiments", ["experiments"]),
        ("routers.runs", "/api/runs", ["runs"]),
        ("routers.endpoints", "/api/endpoints", ["endpoints"]),
        ("routers.export", "/api/export", ["export"]),
        ("routers.webhooks", "/api/webhooks", ["webhooks"]),
        ("routers.admin", "/api/admin", ["admin"]),
    ]
    for module_name, prefix, tags in router_configs:
        try:
            mod = importlib.import_module(module_name)
            app.include_router(mod.router, prefix=prefix, tags=tags)
        except ModuleNotFoundError as exc:
            missing = exc.name or ""
            if missing and (module_name == missing or module_name.startswith(missing + ".")):
                # The router module itself is absent: not yet implemented.
                log.debug("Router %s not implemented yet; skipping", module_name)
            else:
                # The router exists but one of its own imports is missing.
                log.warning("Router %s skipped: import failed", module_name, exc_info=True)
        except (ImportError, AttributeError):
            log.warning("Router %s skipped: import failed", module_name, exc_info=True)
|
|
|
|
|
|
# Runs at import time: mounts whichever routers are importable onto the module-level `app`.
_mount_routers()
|