promptlooper/backend/websocket/manager.py
John Lightner 30fd15ec7a MAESTRO: Implement WebSocket connection manager with per-experiment routing, Redis pub/sub bridge, and message replay
- WebSocketManager in backend/websocket/manager.py with per-experiment and global subscriptions
- Redis pub/sub bridge (sync + async) broadcasting events to relevant WebSocket clients
- Deque-based replay buffers with since_ts/limit filtering for reconnection support
- Runtime subscribe/unsubscribe and stats API
- Enhanced /ws endpoint in main.py with subscribe/unsubscribe/replay actions
- 35 tests in test_ws_manager.py, all passing
2026-04-07 03:34:21 -05:00

358 lines
13 KiB
Python

"""WebSocket connection manager for PromptLooper.
Manages per-experiment and global WebSocket connections, bridges Redis
pub/sub events to connected clients, and supports message replay on
reconnection.
"""
import asyncio
import json
import logging
import time
from collections import deque
from dataclasses import dataclass, field
from typing import Any
from fastapi import WebSocket
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Default capacity of each replay buffer: how many recent events are kept
# (globally and per experiment) so they can be re-sent after a reconnect.
# Used as the default ``replay_size`` of WebSocketManager.
DEFAULT_REPLAY_SIZE = 50
@dataclass
class ConnectionEntry:
    """Bookkeeping record for one registered WebSocket connection.

    The manager stores one entry per connection and uses it to find the
    socket to send on and the experiment subscriptions to tear down when
    the connection goes away.
    """

    # The underlying socket that messages are sent on.
    websocket: WebSocket
    # IDs of the experiments this connection is subscribed to. Each entry
    # gets its own fresh set (never shared between instances).
    experiment_ids: set[str] = field(default_factory=set)
    # Unix timestamp (seconds) captured when the entry was created.
    connected_at: float = field(default_factory=time.time)
class WebSocketManager:
    """Manages WebSocket connections with per-experiment routing and replay.

    Features:
    - Per-experiment and global subscriptions
    - Redis pub/sub bridge (when Redis is available)
    - Message replay on reconnection (last N events per experiment + global)
    - Clean disconnect handling

    Connections are keyed by ``id(websocket)``. Each registered connection
    is held by a ``ConnectionEntry``, which keeps the socket object alive,
    so the id cannot be reused while the connection is registered.

    The manager keeps one invariant across all code paths: an experiment ID
    is present in the subscription index only while it has at least one
    subscriber.
    """

    def __init__(self, replay_size: int = DEFAULT_REPLAY_SIZE) -> None:
        """Create an empty manager.

        Args:
            replay_size: Maximum number of events retained in the global
                replay buffer and in each per-experiment replay buffer.
        """
        self._replay_size = replay_size
        # All active connections keyed by id(websocket)
        self._connections: dict[int, ConnectionEntry] = {}
        # Experiment-specific connection sets: experiment_id -> set of ws ids
        self._experiment_subs: dict[str, set[int]] = {}
        # Global subscribers (receive all events)
        self._global_subs: set[int] = set()
        # Event replay buffers (bounded; oldest events fall off the left)
        self._global_replay: deque[dict[str, Any]] = deque(maxlen=replay_size)
        self._experiment_replay: dict[str, deque[dict[str, Any]]] = {}
        # Redis listener task handle
        self._redis_listener_task: asyncio.Task | None = None

    # ------------------------------------------------------------------
    # Connection lifecycle
    # ------------------------------------------------------------------

    async def connect(
        self,
        websocket: WebSocket,
        experiment_ids: list[str] | None = None,
    ) -> None:
        """Accept a WebSocket connection and register subscriptions.

        Args:
            websocket: The incoming WebSocket connection.
            experiment_ids: Optional list of experiment IDs to subscribe to.
                If empty/None the connection receives global broadcasts only.
        """
        await websocket.accept()
        ws_id = id(websocket)
        self._connections[ws_id] = ConnectionEntry(websocket=websocket)
        if not experiment_ids:
            self._global_subs.add(ws_id)
            return
        for exp_id in experiment_ids:
            self._subscribe_experiment(ws_id, exp_id)

    def disconnect(self, websocket: WebSocket) -> None:
        """Remove a WebSocket from all tracking structures.

        Unknown sockets are ignored, so calling this twice is harmless.
        """
        self._forget(id(websocket))

    def _subscribe_experiment(self, ws_id: int, experiment_id: str) -> None:
        """Subscribe a connection to a specific experiment's events.

        Does nothing when ``ws_id`` is not a registered connection.
        """
        entry = self._connections.get(ws_id)
        if entry is None:
            return
        entry.experiment_ids.add(experiment_id)
        self._experiment_subs.setdefault(experiment_id, set()).add(ws_id)

    def _drop_subscription(self, ws_id: int, experiment_id: str) -> None:
        """Remove ``ws_id`` from one experiment's subscriber set.

        The experiment is removed from the index once its last subscriber
        is gone, so ``experiment_ids``/``stats()`` never list experiments
        nobody is listening to.
        """
        subscribers = self._experiment_subs.get(experiment_id)
        if subscribers is None:
            return
        subscribers.discard(ws_id)
        if not subscribers:
            del self._experiment_subs[experiment_id]

    def _forget(self, ws_id: int) -> None:
        """Purge a connection id from every tracking structure.

        Shared by ``disconnect`` and by the dead-connection cleanup in
        ``_send_to_many`` so both paths leave identical state behind.
        """
        entry = self._connections.pop(ws_id, None)
        self._global_subs.discard(ws_id)
        if entry is not None:
            affected = list(entry.experiment_ids)
        else:
            # No entry to consult (stale id): sweep every experiment so
            # the id cannot linger in any subscriber set.
            affected = list(self._experiment_subs)
        for exp_id in affected:
            self._drop_subscription(ws_id, exp_id)

    # ------------------------------------------------------------------
    # Subscription management (runtime)
    # ------------------------------------------------------------------

    def subscribe(self, websocket: WebSocket, experiment_id: str) -> None:
        """Add an experiment subscription for an existing connection.

        Unregistered sockets are ignored.
        """
        # _subscribe_experiment already ignores unknown connection ids.
        self._subscribe_experiment(id(websocket), experiment_id)

    def unsubscribe(self, websocket: WebSocket, experiment_id: str) -> None:
        """Remove an experiment subscription for an existing connection.

        Unregistered sockets are ignored.
        """
        ws_id = id(websocket)
        entry = self._connections.get(ws_id)
        if entry is None:
            return
        entry.experiment_ids.discard(experiment_id)
        self._drop_subscription(ws_id, experiment_id)

    # ------------------------------------------------------------------
    # Broadcasting
    # ------------------------------------------------------------------

    def _experiment_buffer(self, experiment_id: str) -> deque[dict[str, Any]]:
        """Return the replay buffer for an experiment, creating it on demand."""
        buffer = self._experiment_replay.get(experiment_id)
        if buffer is None:
            buffer = deque(maxlen=self._replay_size)
            self._experiment_replay[experiment_id] = buffer
        return buffer

    async def broadcast_to_experiment(
        self, experiment_id: str, message: dict[str, Any]
    ) -> None:
        """Send a message to all connections subscribed to an experiment.

        A copy of the message stamped with ``_ts`` (Unix time) is stored in
        the experiment's replay buffer; live recipients get the message
        as passed in.
        """
        self._experiment_buffer(experiment_id).append(
            {**message, "_ts": time.time()}
        )
        recipients = self._experiment_subs.get(experiment_id, set())
        await self._send_to_many(recipients, message)

    async def broadcast_global(self, message: dict[str, Any]) -> None:
        """Send a message to all global subscribers.

        A copy of the message stamped with ``_ts`` (Unix time) is stored in
        the global replay buffer; live recipients get the message as
        passed in.
        """
        self._global_replay.append({**message, "_ts": time.time()})
        await self._send_to_many(self._global_subs, message)

    async def broadcast(self, message: dict[str, Any]) -> None:
        """Route a message to the right subscribers based on experiment_id.

        If the message contains an ``experiment_id`` field, it is sent to
        that experiment's subscribers *and* to global subscribers. Otherwise
        it is sent only to global subscribers.

        A connection that is both a global subscriber and subscribed to the
        experiment receives the message exactly once. The event is recorded
        in both replay buffers with the same ``_ts``.
        """
        experiment_id = message.get("experiment_id")
        if not experiment_id:
            await self.broadcast_global(message)
            return

        exp_key = str(experiment_id)
        stamped = {**message, "_ts": time.time()}
        # Separate copies so the two buffers never share a mutable dict.
        self._experiment_buffer(exp_key).append(stamped)
        self._global_replay.append(dict(stamped))

        # Union the audiences first: a client that connected globally and
        # later subscribed to this experiment must not get the event twice.
        recipients = self._experiment_subs.get(exp_key, set()) | self._global_subs
        await self._send_to_many(recipients, message)

    async def _send_to_many(
        self, ws_ids: set[int], message: dict[str, Any]
    ) -> None:
        """Send a JSON message to a set of connections, disconnecting failures.

        Args:
            ws_ids: Connection ids to send to. Failed ids are discarded from
                this set as well as from the manager's own structures.
            message: JSON-serialisable payload.
        """
        dead: list[int] = []
        # Iterate over a snapshot: the set may change while we are suspended
        # in ``send_json`` (other tasks can connect/disconnect).
        for ws_id in list(ws_ids):
            entry = self._connections.get(ws_id)
            if entry is None:
                dead.append(ws_id)
                continue
            try:
                await entry.websocket.send_json(message)
            except Exception:
                logger.debug("WebSocket send failed, removing connection %s", ws_id)
                dead.append(ws_id)

        # Clean up dead connections through the same path as disconnect()
        # so empty experiment sets are pruned here too.
        for ws_id in dead:
            ws_ids.discard(ws_id)
            self._forget(ws_id)

    # ------------------------------------------------------------------
    # Message replay
    # ------------------------------------------------------------------

    async def replay(
        self,
        websocket: WebSocket,
        experiment_id: str | None = None,
        since_ts: float | None = None,
        limit: int | None = None,
    ) -> int:
        """Replay recent events to a connection.

        Args:
            websocket: Target connection.
            experiment_id: If provided, replay experiment-specific events;
                otherwise replay global events.
            since_ts: Only replay events after this Unix timestamp.
            limit: Max number of events to replay (default: all in buffer).
                The most recent ``limit`` events are kept; a limit of zero
                or less replays nothing.

        Returns:
            Number of events replayed. Replay stops at the first failed
            send, so this may be less than the number of matching events.
        """
        if experiment_id:
            source = self._experiment_replay.get(experiment_id, ())
        else:
            source = self._global_replay

        events = list(source)
        if since_ts is not None:
            events = [event for event in events if event.get("_ts", 0) > since_ts]
        if limit is not None:
            # ``events[-0:]`` is the whole list, so non-positive limits
            # must be handled explicitly.
            events = events[-limit:] if limit > 0 else []

        sent = 0
        for event in events:
            try:
                await websocket.send_json(event)
            except Exception:
                logger.debug("Replay aborted after %d events: send failed", sent)
                break
            sent += 1
        return sent

    # ------------------------------------------------------------------
    # Redis pub/sub bridge
    # ------------------------------------------------------------------

    async def start_redis_listener(self, redis_client: Any) -> None:
        """Start a background task that subscribes to Redis pub/sub and
        broadcasts received messages to connected WebSocket clients.

        Uses async Redis pubsub if available, otherwise runs the blocking
        listener in a thread executor. Does nothing when ``redis_client``
        is None. If a listener is already running it is stopped first, so
        at most one listener is active and events are never bridged twice.
        """
        if redis_client is None:
            return
        running = self._redis_listener_task
        if running is not None and not running.done():
            logger.warning("Redis listener already running; restarting it")
            await self.stop_redis_listener()
        self._redis_listener_task = asyncio.create_task(
            self._redis_listen_loop(redis_client)
        )

    @staticmethod
    def _is_async_redis_client(redis_client: Any) -> bool:
        """Best-effort check for an asyncio-based Redis client.

        Recognises clients that expose a coroutine ``subscribe`` directly,
        and clients whose ``execute_command`` is a coroutine function.
        """
        if not hasattr(redis_client, "pubsub"):
            return False
        return any(
            asyncio.iscoroutinefunction(getattr(redis_client, attr, None))
            for attr in ("subscribe", "execute_command")
        )

    async def _redis_listen_loop(self, redis_client: Any) -> None:
        """Listen for Redis pub/sub messages and broadcast them.

        Runs until cancelled or until the listener raises; both outcomes
        are logged here rather than propagated to the task's awaiter.
        """
        channel = "promptlooper:events"
        try:
            if self._is_async_redis_client(redis_client):
                await self._redis_listen_async(redis_client, channel)
            else:
                await self._redis_listen_sync(redis_client, channel)
        except asyncio.CancelledError:
            logger.info("Redis listener cancelled")
        except Exception:
            logger.exception("Redis listener error")

    async def _redis_listen_async(self, redis_client: Any, channel: str) -> None:
        """Async Redis pub/sub listener.

        Subscribes to ``channel`` and forwards every ``"message"`` item;
        always unsubscribes on the way out.
        """
        pubsub = redis_client.pubsub()
        await pubsub.subscribe(channel)
        try:
            async for raw_message in pubsub.listen():
                # Skip subscribe/unsubscribe confirmations and the like.
                if raw_message["type"] == "message":
                    await self._handle_redis_message(raw_message["data"])
        finally:
            await pubsub.unsubscribe(channel)

    async def _redis_listen_sync(self, redis_client: Any, channel: str) -> None:
        """Sync Redis pub/sub listener (run blocking calls in executor).

        Polls with a one-second timeout so the coroutine regularly returns
        to the event loop and can be cancelled promptly.
        """
        pubsub = redis_client.pubsub()
        pubsub.subscribe(channel)
        loop = asyncio.get_running_loop()

        def poll() -> Any:
            return pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)

        try:
            while True:
                raw_message = await loop.run_in_executor(None, poll)
                if raw_message is None or raw_message["type"] != "message":
                    continue
                await self._handle_redis_message(raw_message["data"])
        finally:
            pubsub.unsubscribe(channel)

    async def _handle_redis_message(self, data: str | bytes) -> None:
        """Parse a Redis message and broadcast it.

        Payloads that cannot be decoded, are not valid JSON, or are not a
        JSON object are logged and dropped. Dropping them keeps one bad
        publisher from raising inside the listener loop and ending it.
        """
        try:
            if isinstance(data, bytes):
                data = data.decode("utf-8")
            message = json.loads(data)
        except (json.JSONDecodeError, UnicodeDecodeError, TypeError):
            logger.warning("Invalid Redis message: %s", data)
            return
        if not isinstance(message, dict):
            # broadcast() routes on message.get("experiment_id"), so only
            # JSON objects are meaningful events.
            logger.warning("Ignoring non-object Redis message: %s", data)
            return
        await self.broadcast(message)

    async def stop_redis_listener(self) -> None:
        """Cancel the Redis listener background task.

        Waits for the task to finish; a no-op when no listener was started.
        """
        task = self._redis_listener_task
        if task is None:
            return
        self._redis_listener_task = None
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    # ------------------------------------------------------------------
    # Status
    # ------------------------------------------------------------------

    @property
    def connection_count(self) -> int:
        """Total number of active connections."""
        return len(self._connections)

    @property
    def experiment_ids(self) -> list[str]:
        """List of experiment IDs with active subscriptions."""
        return list(self._experiment_subs)

    def stats(self) -> dict[str, Any]:
        """Return connection statistics.

        Returns:
            A dict with ``total_connections``, ``global_subscribers``,
            ``experiment_subscriptions`` (experiment ID -> subscriber count)
            and ``replay_buffer_sizes`` (``"global"`` plus one key per
            experiment ID -> buffered event count).
        """
        subscription_counts = {
            exp_id: len(ws_ids) for exp_id, ws_ids in self._experiment_subs.items()
        }
        buffer_sizes: dict[str, int] = {"global": len(self._global_replay)}
        buffer_sizes.update(
            (exp_id, len(buf)) for exp_id, buf in self._experiment_replay.items()
        )
        return {
            "total_connections": self.connection_count,
            "global_subscribers": len(self._global_subs),
            "experiment_subscriptions": subscription_counts,
            "replay_buffer_sizes": buffer_sizes,
        }
# Module-level singleton for use across the application. It is created at
# import time with the default replay size (DEFAULT_REPLAY_SIZE).
manager = WebSocketManager()