"""WebSocket connection manager for PromptLooper. Manages per-experiment and global WebSocket connections, bridges Redis pub/sub events to connected clients, and supports message replay on reconnection. """ import asyncio import json import logging import time from collections import deque from dataclasses import dataclass, field from typing import Any from fastapi import WebSocket logger = logging.getLogger(__name__) # Default number of events to retain for replay on reconnect DEFAULT_REPLAY_SIZE = 50 @dataclass class ConnectionEntry: """Tracks a single WebSocket connection and its subscriptions.""" websocket: WebSocket experiment_ids: set[str] = field(default_factory=set) connected_at: float = field(default_factory=time.time) class WebSocketManager: """Manages WebSocket connections with per-experiment routing and replay. Features: - Per-experiment and global subscriptions - Redis pub/sub bridge (when Redis is available) - Message replay on reconnection (last N events per experiment + global) - Clean disconnect handling """ def __init__(self, replay_size: int = DEFAULT_REPLAY_SIZE) -> None: self._replay_size = replay_size # All active connections keyed by id(websocket) self._connections: dict[int, ConnectionEntry] = {} # Experiment-specific connection sets: experiment_id -> set of ws ids self._experiment_subs: dict[str, set[int]] = {} # Global subscribers (receive all events) self._global_subs: set[int] = set() # Event replay buffers self._global_replay: deque[dict[str, Any]] = deque(maxlen=replay_size) self._experiment_replay: dict[str, deque[dict[str, Any]]] = {} # Redis listener task handle self._redis_listener_task: asyncio.Task | None = None # ------------------------------------------------------------------ # Connection lifecycle # ------------------------------------------------------------------ async def connect( self, websocket: WebSocket, experiment_ids: list[str] | None = None, ) -> None: """Accept a WebSocket connection and register subscriptions. Args: websocket: The incoming WebSocket connection. experiment_ids: Optional list of experiment IDs to subscribe to. If empty/None the connection receives global broadcasts only. """ await websocket.accept() ws_id = id(websocket) entry = ConnectionEntry(websocket=websocket) self._connections[ws_id] = entry if experiment_ids: for exp_id in experiment_ids: self._subscribe_experiment(ws_id, exp_id) else: self._global_subs.add(ws_id) def disconnect(self, websocket: WebSocket) -> None: """Remove a WebSocket from all tracking structures.""" ws_id = id(websocket) entry = self._connections.pop(ws_id, None) if entry is None: return # Remove from global subs self._global_subs.discard(ws_id) # Remove from all experiment subs for exp_id in list(entry.experiment_ids): exp_set = self._experiment_subs.get(exp_id) if exp_set is not None: exp_set.discard(ws_id) if not exp_set: del self._experiment_subs[exp_id] def _subscribe_experiment(self, ws_id: int, experiment_id: str) -> None: """Subscribe a connection to a specific experiment's events.""" entry = self._connections.get(ws_id) if entry is None: return entry.experiment_ids.add(experiment_id) if experiment_id not in self._experiment_subs: self._experiment_subs[experiment_id] = set() self._experiment_subs[experiment_id].add(ws_id) # ------------------------------------------------------------------ # Subscription management (runtime) # ------------------------------------------------------------------ def subscribe(self, websocket: WebSocket, experiment_id: str) -> None: """Add an experiment subscription for an existing connection.""" ws_id = id(websocket) if ws_id in self._connections: self._subscribe_experiment(ws_id, experiment_id) def unsubscribe(self, websocket: WebSocket, experiment_id: str) -> None: """Remove an experiment subscription for an existing connection.""" ws_id = id(websocket) entry = self._connections.get(ws_id) if entry is None: return entry.experiment_ids.discard(experiment_id) exp_set = self._experiment_subs.get(experiment_id) if exp_set is not None: exp_set.discard(ws_id) if not exp_set: del self._experiment_subs[experiment_id] # ------------------------------------------------------------------ # Broadcasting # ------------------------------------------------------------------ async def broadcast_to_experiment( self, experiment_id: str, message: dict[str, Any] ) -> None: """Send a message to all connections subscribed to an experiment.""" # Store in replay buffer if experiment_id not in self._experiment_replay: self._experiment_replay[experiment_id] = deque(maxlen=self._replay_size) message_with_ts = {**message, "_ts": time.time()} self._experiment_replay[experiment_id].append(message_with_ts) ws_ids = self._experiment_subs.get(experiment_id, set()) await self._send_to_many(ws_ids, message) async def broadcast_global(self, message: dict[str, Any]) -> None: """Send a message to all global subscribers.""" message_with_ts = {**message, "_ts": time.time()} self._global_replay.append(message_with_ts) await self._send_to_many(self._global_subs, message) async def broadcast(self, message: dict[str, Any]) -> None: """Route a message to the right subscribers based on experiment_id. If the message contains an ``experiment_id`` field, it is sent to that experiment's subscribers *and* to global subscribers. Otherwise it is sent only to global subscribers. """ experiment_id = message.get("experiment_id") if experiment_id: await self.broadcast_to_experiment(str(experiment_id), message) await self.broadcast_global(message) async def _send_to_many( self, ws_ids: set[int], message: dict[str, Any] ) -> None: """Send a JSON message to a set of connections, disconnecting failures.""" dead: list[int] = [] for ws_id in list(ws_ids): entry = self._connections.get(ws_id) if entry is None: dead.append(ws_id) continue try: await entry.websocket.send_json(message) except Exception: logger.debug("WebSocket send failed, removing connection %s", ws_id) dead.append(ws_id) # Clean up dead connections for ws_id in dead: ws_ids.discard(ws_id) entry = self._connections.pop(ws_id, None) if entry: self._global_subs.discard(ws_id) for exp_id in entry.experiment_ids: exp_set = self._experiment_subs.get(exp_id) if exp_set: exp_set.discard(ws_id) # ------------------------------------------------------------------ # Message replay # ------------------------------------------------------------------ async def replay( self, websocket: WebSocket, experiment_id: str | None = None, since_ts: float | None = None, limit: int | None = None, ) -> int: """Replay recent events to a connection. Args: websocket: Target connection. experiment_id: If provided, replay experiment-specific events; otherwise replay global events. since_ts: Only replay events after this Unix timestamp. limit: Max number of events to replay (default: all in buffer). Returns: Number of events replayed. """ if experiment_id: buffer = self._experiment_replay.get(experiment_id, deque()) else: buffer = self._global_replay events = list(buffer) if since_ts is not None: events = [e for e in events if e.get("_ts", 0) > since_ts] if limit is not None: events = events[-limit:] count = 0 for event in events: try: await websocket.send_json(event) count += 1 except Exception: break return count # ------------------------------------------------------------------ # Redis pub/sub bridge # ------------------------------------------------------------------ async def start_redis_listener(self, redis_client: Any) -> None: """Start a background task that subscribes to Redis pub/sub and broadcasts received messages to connected WebSocket clients. Uses async Redis pubsub if available, otherwise runs the blocking listener in a thread executor. """ if redis_client is None: return self._redis_listener_task = asyncio.create_task( self._redis_listen_loop(redis_client) ) async def _redis_listen_loop(self, redis_client: Any) -> None: """Listen for Redis pub/sub messages and broadcast them.""" channel = "promptlooper:events" try: # Check if this is an async redis client (aioredis / redis.asyncio) if hasattr(redis_client, "pubsub") and asyncio.iscoroutinefunction( getattr(redis_client, "subscribe", None) ): await self._redis_listen_async(redis_client, channel) else: await self._redis_listen_sync(redis_client, channel) except asyncio.CancelledError: logger.info("Redis listener cancelled") except Exception: logger.exception("Redis listener error") async def _redis_listen_async(self, redis_client: Any, channel: str) -> None: """Async Redis pub/sub listener.""" pubsub = redis_client.pubsub() await pubsub.subscribe(channel) try: async for raw_message in pubsub.listen(): if raw_message["type"] != "message": continue await self._handle_redis_message(raw_message["data"]) finally: await pubsub.unsubscribe(channel) async def _redis_listen_sync(self, redis_client: Any, channel: str) -> None: """Sync Redis pub/sub listener (run blocking calls in executor).""" pubsub = redis_client.pubsub() pubsub.subscribe(channel) loop = asyncio.get_running_loop() try: while True: raw_message = await loop.run_in_executor( None, lambda: pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0) ) if raw_message is None: continue if raw_message["type"] != "message": continue await self._handle_redis_message(raw_message["data"]) finally: pubsub.unsubscribe(channel) async def _handle_redis_message(self, data: str | bytes) -> None: """Parse a Redis message and broadcast it.""" try: if isinstance(data, bytes): data = data.decode("utf-8") message = json.loads(data) except (json.JSONDecodeError, UnicodeDecodeError): logger.warning("Invalid Redis message: %s", data) return await self.broadcast(message) async def stop_redis_listener(self) -> None: """Cancel the Redis listener background task.""" if self._redis_listener_task is not None: self._redis_listener_task.cancel() try: await self._redis_listener_task except asyncio.CancelledError: pass self._redis_listener_task = None # ------------------------------------------------------------------ # Status # ------------------------------------------------------------------ @property def connection_count(self) -> int: """Total number of active connections.""" return len(self._connections) @property def experiment_ids(self) -> list[str]: """List of experiment IDs with active subscriptions.""" return list(self._experiment_subs.keys()) def stats(self) -> dict[str, Any]: """Return connection statistics.""" return { "total_connections": self.connection_count, "global_subscribers": len(self._global_subs), "experiment_subscriptions": { exp_id: len(ws_ids) for exp_id, ws_ids in self._experiment_subs.items() }, "replay_buffer_sizes": { "global": len(self._global_replay), **{ exp_id: len(buf) for exp_id, buf in self._experiment_replay.items() }, }, } # Module-level singleton for use across the application manager = WebSocketManager()