"""PromptLooper FastAPI application.""" from contextlib import asynccontextmanager from typing import AsyncGenerator from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.middleware.cors import CORSMiddleware from sqlalchemy import create_engine, text from sqlalchemy.orm import sessionmaker from config import settings # --------------------------------------------------------------------------- # Database engine & session factory (lazy, created at startup) # --------------------------------------------------------------------------- engine = None SessionLocal = None def _init_db() -> None: """Create the SQLAlchemy engine and session factory.""" global engine, SessionLocal connect_args = {} if settings.is_sqlite: connect_args["check_same_thread"] = False engine = create_engine( settings.effective_database_url, connect_args=connect_args, ) SessionLocal = sessionmaker(bind=engine, autoflush=False, expire_on_commit=False) def get_db(): """FastAPI dependency that yields a database session.""" db = SessionLocal() try: yield db finally: db.close() # --------------------------------------------------------------------------- # Redis helper # --------------------------------------------------------------------------- _redis_client = None def _init_redis() -> None: """Connect to Redis if configured.""" global _redis_client if not settings.redis_url: _redis_client = None return import redis as redis_lib _redis_client = redis_lib.Redis.from_url(settings.redis_url, decode_responses=True) def get_redis(): """Return the Redis client (or None in single-container mode).""" return _redis_client # --------------------------------------------------------------------------- # WebSocket connection manager # --------------------------------------------------------------------------- from websocket.manager import WebSocketManager ws_manager = WebSocketManager() # --------------------------------------------------------------------------- # Lifecycle # 
# ---------------------------------------------------------------------------


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Startup and shutdown lifecycle hooks.

    Startup creates the database engine/session factory, connects Redis (if
    configured) and starts the WebSocket <-> Redis bridge. Shutdown stops the
    bridge, closes the Redis client and disposes the engine.

    Shutdown steps run in nested ``finally`` blocks so each resource is still
    released when startup fails part-way, the app is cancelled while serving,
    or an earlier shutdown step raises.
    """
    listener_started = False
    try:
        _init_db()
        _init_redis()
        # Start WebSocket <-> Redis bridge
        await ws_manager.start_redis_listener(_redis_client)
        listener_started = True
        yield
    finally:
        try:
            # Only stop what was actually started.
            if listener_started:
                await ws_manager.stop_redis_listener()
        finally:
            try:
                if _redis_client is not None:
                    _redis_client.close()
            finally:
                if engine is not None:
                    engine.dispose()


# ---------------------------------------------------------------------------
# Application
# ---------------------------------------------------------------------------

app = FastAPI(
    title="PromptLooper",
    description="LLM pipeline tuning workbench",
    version="0.1.0",
    lifespan=lifespan,
)

# CORS — currently allows every origin, method and header, with credentials.
# NOTE(review): the intent was to tighten this in production via env, but no
# environment-driven origin list is read here — confirm before deploying.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ---------------------------------------------------------------------------
# Health endpoint
# ---------------------------------------------------------------------------


@app.get("/health", tags=["system"])
def health_check() -> dict:
    """Check DB and Redis connectivity.

    Returns a dict with boolean ``database`` and ``redis`` flags and a
    ``status`` of ``"ok"`` when both are healthy, otherwise ``"degraded"``.
    Redis counts as healthy when no Redis URL is configured.
    """
    db_ok = False
    redis_ok = False

    # Database check (best-effort: any failure just reports unhealthy)
    if SessionLocal is not None:
        try:
            with SessionLocal() as session:
                session.execute(text("SELECT 1"))
            db_ok = True
        except Exception:
            pass

    # Redis check
    if not settings.redis_url:
        redis_ok = True  # No Redis needed — in-process mode
    elif _redis_client is not None:
        try:
            _redis_client.ping()
            redis_ok = True
        except Exception:
            pass

    return {"status": "ok" if (db_ok and redis_ok) else "degraded", "database": db_ok, "redis": redis_ok}


# ---------------------------------------------------------------------------
# WebSocket endpoint
# ---------------------------------------------------------------------------


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """WebSocket connection for real-time dashboard updates.

    Initial subscriptions may be given as a comma-separated
    ``experiment_ids`` query parameter. Clients may then send JSON messages to
    control their subscriptions:

    - {"action": "subscribe", "experiment_id": "..."}
    - {"action": "unsubscribe", "experiment_id": "..."}
    - {"action": "replay", "experiment_id": "...", "since_ts": ..., "limit": ...}

    Each handled action is answered with an ``ack``. Any other JSON object is
    echoed back inside an ``ack``. Frames that are not valid JSON, or not a
    JSON object, get a ``{"type": "error", ...}`` reply and the connection
    stays open. The socket is always unregistered from the manager when the
    handler exits.
    """
    # Parse query params for initial experiment subscriptions
    exp_ids_param = websocket.query_params.get("experiment_ids", "")
    experiment_ids = [e.strip() for e in exp_ids_param.split(",") if e.strip()] or None

    await ws_manager.connect(websocket, experiment_ids=experiment_ids)
    try:
        while True:
            try:
                data = await websocket.receive_json()
            except ValueError:
                # json.JSONDecodeError is a ValueError; one malformed frame
                # should not tear down the connection.
                await websocket.send_json({"type": "error", "error": "invalid JSON"})
                continue
            if not isinstance(data, dict):
                # e.g. a bare list/number/string — there is no "action" key to read.
                await websocket.send_json({"type": "error", "error": "expected a JSON object"})
                continue

            action = data.get("action", "")

            if action == "subscribe" and data.get("experiment_id"):
                ws_manager.subscribe(websocket, data["experiment_id"])
                await websocket.send_json({"type": "ack", "action": "subscribe", "experiment_id": data["experiment_id"]})
            elif action == "unsubscribe" and data.get("experiment_id"):
                ws_manager.unsubscribe(websocket, data["experiment_id"])
                await websocket.send_json({"type": "ack", "action": "unsubscribe", "experiment_id": data["experiment_id"]})
            elif action == "replay":
                count = await ws_manager.replay(
                    websocket,
                    experiment_id=data.get("experiment_id"),
                    since_ts=data.get("since_ts"),
                    limit=data.get("limit"),
                )
                await websocket.send_json({"type": "ack", "action": "replay", "events_replayed": count})
            else:
                await websocket.send_json({"type": "ack", "data": data})
    except WebSocketDisconnect:
        pass  # Normal client disconnect; cleanup happens in ``finally``.
    finally:
        # Unregister on every exit path, including unexpected errors, so the
        # manager never keeps broadcasting to a dead socket.
        ws_manager.disconnect(websocket)


# ---------------------------------------------------------------------------
# Mount routers (stubs — actual implementations come later)
# ---------------------------------------------------------------------------
# Router imports are deferred to avoid circular imports and allow
# stub files to be created independently. Each router is mounted
# once its module exists and defines a ``router`` attribute.


def _mount_routers() -> None:
    """Import and mount all routers, skipping ones that do not exist yet.

    A router is skipped silently when its module (or a parent package such as
    ``routers``) is not found, or when the module defines no ``router``
    attribute. Any other failure while importing an existing router module is
    logged with its traceback and that router is skipped. This keeps startup
    best-effort without hiding real bugs.
    """
    import importlib
    import logging

    logger = logging.getLogger(__name__)

    router_configs = [
        ("routers.auth", "/api/auth", ["auth"]),
        ("routers.projects", "/api/projects", ["projects"]),
        ("routers.experiments", "/api/experiments", ["experiments"]),
        ("routers.runs", "/api/runs", ["runs"]),
        ("routers.endpoints", "/api/endpoints", ["endpoints"]),
        ("routers.export", "/api/export", ["export"]),
        ("routers.webhooks", "/api/webhooks", ["webhooks"]),
        ("routers.admin", "/api/admin", ["admin"]),
    ]
    for module_name, prefix, tags in router_configs:
        try:
            mod = importlib.import_module(module_name)
        except ModuleNotFoundError as exc:
            missing = exc.name
            if missing is not None and (module_name == missing or module_name.startswith(missing + ".")):
                continue  # Router not yet implemented
            # The router module exists but one of *its* imports is missing.
            logger.exception("Skipping router %s: import failed", module_name)
            continue
        except (ImportError, AttributeError):
            logger.exception("Skipping router %s: import failed", module_name)
            continue

        router = getattr(mod, "router", None)
        if router is None:
            continue  # Stub module without a router yet
        app.include_router(router, prefix=prefix, tags=tags)


_mount_routers()