"""PromptLooper FastAPI application."""

import importlib
import logging
from contextlib import asynccontextmanager, suppress
from typing import AsyncGenerator, Iterator

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session, sessionmaker

from config import settings

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Database engine & session factory (lazy, created at startup)
# ---------------------------------------------------------------------------

engine = None
SessionLocal = None


def _init_db() -> None:
    """Create the SQLAlchemy engine and session factory.

    Populates the module-level ``engine`` and ``SessionLocal`` globals from
    ``settings.effective_database_url``.
    """
    global engine, SessionLocal
    connect_args: dict[str, object] = {}
    if settings.is_sqlite:
        # Allow a SQLite connection to be used from threads other than the
        # one that created it.
        connect_args["check_same_thread"] = False
    engine = create_engine(
        settings.effective_database_url,
        connect_args=connect_args,
    )
    SessionLocal = sessionmaker(bind=engine, autoflush=False, expire_on_commit=False)


def get_db() -> Iterator[Session]:
    """FastAPI dependency that yields a database session.

    The session is always closed once the consumer is done with it.

    Raises:
        RuntimeError: if called before ``_init_db()`` has run (e.g. outside
            the application lifespan).
    """
    if SessionLocal is None:
        raise RuntimeError("Database is not initialised; _init_db() must run first")
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


# ---------------------------------------------------------------------------
# Redis helper
# ---------------------------------------------------------------------------

_redis_client = None


def _init_redis() -> None:
    """Create the Redis client if ``settings.redis_url`` is set, else clear it."""
    global _redis_client
    if not settings.redis_url:
        _redis_client = None
        return
    # Imported lazily so the redis package is only required when Redis is
    # actually configured.
    import redis as redis_lib

    _redis_client = redis_lib.Redis.from_url(settings.redis_url, decode_responses=True)


def get_redis():
    """Return the Redis client (or None in single-container mode)."""
    return _redis_client


# ---------------------------------------------------------------------------
# WebSocket connection manager
# ---------------------------------------------------------------------------


class ConnectionManager:
    """Manage active WebSocket connections and fan messages out to them."""

    def __init__(self) -> None:
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket) -> None:
        """Accept the WebSocket handshake and register the connection."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket) -> None:
        """Unregister *websocket*; a no-op if it is not (or no longer) registered.

        This must be idempotent: a socket can be removed both by ``broadcast``
        (after a failed send) and by the endpoint when the client goes away.
        """
        with suppress(ValueError):
            self.active_connections.remove(websocket)

    async def broadcast(self, message: dict) -> None:
        """Send *message* as JSON to every connection, dropping any that fail."""
        # Iterate over a snapshot because failed connections are removed
        # from the live list inside the loop.
        for connection in list(self.active_connections):
            try:
                await connection.send_json(message)
            except Exception:
                logger.debug("Dropping WebSocket after failed send", exc_info=True)
                self.disconnect(connection)


ws_manager = ConnectionManager()

# ---------------------------------------------------------------------------
# Lifecycle
# ---------------------------------------------------------------------------


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Startup and shutdown lifecycle hooks.

    Cleanup runs in ``finally`` so resources are released even if startup
    fails part-way or the application exits with an error.
    """
    try:
        _init_db()
        _init_redis()
        yield
    finally:
        # Shutdown: clean up connections
        if _redis_client is not None:
            _redis_client.close()
        if engine is not None:
            engine.dispose()


# ---------------------------------------------------------------------------
# Application
# ---------------------------------------------------------------------------

app = FastAPI(
    title="PromptLooper",
    description="LLM pipeline tuning workbench",
    version="0.1.0",
    lifespan=lifespan,
)

# CORS — allows all origins (development setting).
# NOTE(review): no environment override is wired up here yet. A wildcard
# origin list combined with allow_credentials=True is permissive; restrict
# allow_origins before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ---------------------------------------------------------------------------
# Health endpoint
# ---------------------------------------------------------------------------


@app.get("/health", tags=["system"])
def health_check() -> dict:
    """Check DB and Redis connectivity.

    Returns a dict with ``status`` ("ok" when both checks pass, otherwise
    "degraded") plus the per-dependency booleans ``database`` and ``redis``.
    Failures are logged rather than raised so the endpoint always responds.
    """
    db_ok = False
    redis_ok = False

    # Database check
    if SessionLocal is not None:
        try:
            with SessionLocal() as session:
                session.execute(text("SELECT 1"))
            db_ok = True
        except Exception:
            logger.warning("Health check: database unreachable", exc_info=True)

    # Redis check
    if not settings.redis_url:
        redis_ok = True  # No Redis needed — in-process mode
    elif _redis_client is not None:
        try:
            _redis_client.ping()
            redis_ok = True
        except Exception:
            logger.warning("Health check: Redis unreachable", exc_info=True)

    return {
        "status": "ok" if (db_ok and redis_ok) else "degraded",
        "database": db_ok,
        "redis": redis_ok,
    }


# ---------------------------------------------------------------------------
# WebSocket endpoint
# ---------------------------------------------------------------------------


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """WebSocket connection for real-time dashboard updates.

    Each JSON message received is acknowledged with ``{"type": "ack", "data": ...}``.
    The connection is unregistered however the loop ends (client disconnect
    or any error such as a malformed JSON frame), so dead sockets never
    linger in ``ws_manager``.
    """
    await ws_manager.connect(websocket)
    try:
        while True:
            # Keep connection alive; handle incoming messages if needed
            data = await websocket.receive_json()
            # Echo back or handle client messages in future
            await websocket.send_json({"type": "ack", "data": data})
    except WebSocketDisconnect:
        pass  # Normal client-initiated close.
    finally:
        ws_manager.disconnect(websocket)


# ---------------------------------------------------------------------------
# Mount routers (stubs — actual implementations come later)
# ---------------------------------------------------------------------------
# Router imports are deferred to avoid circular imports and allow
# stub files to be created independently. Each router will be mounted
# as it is implemented; routers that do not exist yet are skipped.


def _is_missing_module(exc: ImportError, module_name: str) -> bool:
    """Return True if *exc* means *module_name* (or a parent package) is absent."""
    missing = getattr(exc, "name", None)
    return (
        isinstance(exc, ModuleNotFoundError)
        and missing is not None
        and (module_name == missing or module_name.startswith(missing + "."))
    )


def _mount_routers() -> None:
    """Import and mount all routers, skipping ones that are unavailable.

    A router module that does not exist yet is skipped quietly (debug log).
    A module that exists but fails to import, or lacks a ``router``
    attribute, is also skipped, but with a warning so the failure is visible.
    """
    router_configs = [
        ("routers.auth", "/api/auth", ["auth"]),
        ("routers.projects", "/api/projects", ["projects"]),
        ("routers.experiments", "/api/experiments", ["experiments"]),
        ("routers.runs", "/api/runs", ["runs"]),
        ("routers.endpoints", "/api/endpoints", ["endpoints"]),
        ("routers.export", "/api/export", ["export"]),
        ("routers.webhooks", "/api/webhooks", ["webhooks"]),
        ("routers.admin", "/api/admin", ["admin"]),
    ]
    for module_name, prefix, tags in router_configs:
        try:
            mod = importlib.import_module(module_name)
        except ImportError as exc:
            if _is_missing_module(exc, module_name):
                logger.debug("Router %s not yet implemented; skipping", module_name)
            else:
                # The module exists but one of ITS imports failed — a real bug.
                logger.warning("Router %s failed to import; skipping", module_name, exc_info=True)
            continue
        except AttributeError:
            logger.warning("Router %s failed to import; skipping", module_name, exc_info=True)
            continue

        router = getattr(mod, "router", None)
        if router is None:
            logger.warning("Module %s has no `router` attribute; skipping", module_name)
            continue
        app.include_router(router, prefix=prefix, tags=tags)


_mount_routers()