- WebSocketManager in backend/websocket/manager.py with per-experiment and global subscriptions
- Redis pub/sub bridge (sync + async) broadcasting events to relevant WebSocket clients
- Deque-based replay buffers with since_ts/limit filtering for reconnection support
- Runtime subscribe/unsubscribe and stats API
- Enhanced /ws endpoint in main.py with subscribe/unsubscribe/replay actions
- 35 tests in test_ws_manager.py, all passing
358 lines
13 KiB
Python
358 lines
13 KiB
Python
"""WebSocket connection manager for PromptLooper.

Manages per-experiment and global WebSocket connections, bridges Redis
pub/sub events to connected clients, and supports message replay on
reconnection.
"""

import asyncio
import json
import logging
import time
from collections import deque
from dataclasses import dataclass, field
from typing import Any

from fastapi import WebSocket

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Default number of events to retain for replay on reconnect
DEFAULT_REPLAY_SIZE = 50


@dataclass
class ConnectionEntry:
    """Tracks a single WebSocket connection and its subscriptions."""

    # The socket that outgoing JSON messages are written to.
    websocket: WebSocket
    # Experiment IDs this connection is subscribed to. A default factory is
    # used so every entry gets its own set rather than a shared one.
    experiment_ids: set[str] = field(default_factory=set)
    # Unix timestamp (time.time()) taken when the entry is created.
    connected_at: float = field(default_factory=time.time)


class WebSocketManager:
    """Manages WebSocket connections with per-experiment routing and replay.

    Features:
    - Per-experiment and global subscriptions
    - Redis pub/sub bridge (when Redis is available)
    - Message replay on reconnection (last N events per experiment + global)
    - Clean disconnect handling

    Connections are keyed by ``id(websocket)``. Each ConnectionEntry holds a
    reference to its socket, so an id cannot be reused while the connection
    is still registered.
    """

    def __init__(self, replay_size: int = DEFAULT_REPLAY_SIZE) -> None:
        """Create an empty manager.

        Args:
            replay_size: Maximum number of events kept in each replay buffer
                (the global one and every per-experiment one).
        """
        self._replay_size = replay_size
        # All active connections keyed by id(websocket)
        self._connections: dict[int, ConnectionEntry] = {}
        # Experiment-specific connection sets: experiment_id -> set of ws ids
        self._experiment_subs: dict[str, set[int]] = {}
        # Global subscribers (receive all events)
        self._global_subs: set[int] = set()
        # Event replay buffers
        self._global_replay: deque[dict[str, Any]] = deque(maxlen=replay_size)
        # NOTE(review): per-experiment buffers are created on first broadcast
        # and nothing in this class removes them, so this dict grows with the
        # number of distinct experiment IDs seen.
        self._experiment_replay: dict[str, deque[dict[str, Any]]] = {}
        # Redis listener task handle
        self._redis_listener_task: asyncio.Task | None = None

    # ------------------------------------------------------------------
    # Connection lifecycle
    # ------------------------------------------------------------------

    async def connect(
        self,
        websocket: WebSocket,
        experiment_ids: list[str] | None = None,
    ) -> None:
        """Accept a WebSocket connection and register subscriptions.

        Args:
            websocket: The incoming WebSocket connection.
            experiment_ids: Optional list of experiment IDs to subscribe to.
                If empty/None the connection receives global broadcasts only.
        """
        await websocket.accept()
        ws_id = id(websocket)
        self._connections[ws_id] = ConnectionEntry(websocket=websocket)

        if experiment_ids:
            for exp_id in experiment_ids:
                self._subscribe_experiment(ws_id, exp_id)
        else:
            self._global_subs.add(ws_id)

    def disconnect(self, websocket: WebSocket) -> None:
        """Remove a WebSocket from all tracking structures.

        Unknown sockets are ignored. The socket itself is not closed here.
        """
        self._drop_connection(id(websocket))

    def _drop_connection(self, ws_id: int) -> None:
        """Forget a connection id and every subscription that refers to it."""
        entry = self._connections.pop(ws_id, None)
        if entry is None:
            return
        self._global_subs.discard(ws_id)
        # Copy first: _remove_experiment_sub mutates entry.experiment_ids.
        for exp_id in list(entry.experiment_ids):
            self._remove_experiment_sub(ws_id, exp_id)
            entry.experiment_ids.discard(exp_id)

    def _remove_experiment_sub(self, ws_id: int, experiment_id: str) -> None:
        """Drop ws_id from one experiment's subscriber set, pruning empty sets."""
        subscribers = self._experiment_subs.get(experiment_id)
        if subscribers is None:
            return
        subscribers.discard(ws_id)
        if not subscribers:
            # Keep experiment_ids / stats() free of zero-subscriber entries.
            del self._experiment_subs[experiment_id]

    def _subscribe_experiment(self, ws_id: int, experiment_id: str) -> None:
        """Subscribe a connection to a specific experiment's events."""
        entry = self._connections.get(ws_id)
        if entry is None:
            return
        entry.experiment_ids.add(experiment_id)
        self._experiment_subs.setdefault(experiment_id, set()).add(ws_id)

    # ------------------------------------------------------------------
    # Subscription management (runtime)
    # ------------------------------------------------------------------

    def subscribe(self, websocket: WebSocket, experiment_id: str) -> None:
        """Add an experiment subscription for an existing connection.

        Does nothing if the socket is not registered. An existing global
        subscription is left in place.
        """
        ws_id = id(websocket)
        if ws_id in self._connections:
            self._subscribe_experiment(ws_id, experiment_id)

    def unsubscribe(self, websocket: WebSocket, experiment_id: str) -> None:
        """Remove an experiment subscription for an existing connection.

        Does nothing if the socket is not registered.
        """
        ws_id = id(websocket)
        entry = self._connections.get(ws_id)
        if entry is None:
            return
        entry.experiment_ids.discard(experiment_id)
        self._remove_experiment_sub(ws_id, experiment_id)

    # ------------------------------------------------------------------
    # Broadcasting
    # ------------------------------------------------------------------

    async def broadcast_to_experiment(
        self, experiment_id: str, message: dict[str, Any]
    ) -> None:
        """Send a message to all connections subscribed to an experiment.

        A copy of the message with a ``_ts`` Unix timestamp is appended to
        the experiment's replay buffer; live subscribers receive the message
        as passed in, without ``_ts``.
        """
        # Store in replay buffer
        if experiment_id not in self._experiment_replay:
            self._experiment_replay[experiment_id] = deque(maxlen=self._replay_size)
        message_with_ts = {**message, "_ts": time.time()}
        self._experiment_replay[experiment_id].append(message_with_ts)

        ws_ids = self._experiment_subs.get(experiment_id, set())
        await self._send_to_many(ws_ids, message)

    async def broadcast_global(self, message: dict[str, Any]) -> None:
        """Send a message to all global subscribers.

        A timestamped copy is appended to the global replay buffer; live
        subscribers receive the message without ``_ts``.
        """
        message_with_ts = {**message, "_ts": time.time()}
        self._global_replay.append(message_with_ts)
        await self._send_to_many(self._global_subs, message)

    async def broadcast(self, message: dict[str, Any]) -> None:
        """Route a message to the right subscribers based on experiment_id.

        If the message contains an ``experiment_id`` field, it is sent to
        that experiment's subscribers *and* to global subscribers. Otherwise
        it is sent only to global subscribers.

        A connection that is both a global subscriber and subscribed to the
        experiment is sent the message once per route.
        """
        experiment_id = message.get("experiment_id")
        if experiment_id:
            await self.broadcast_to_experiment(str(experiment_id), message)
        await self.broadcast_global(message)

    async def _send_to_many(
        self, ws_ids: set[int], message: dict[str, Any]
    ) -> None:
        """Send a JSON message to a set of connections, disconnecting failures.

        Args:
            ws_ids: Subscriber set to deliver to. Ids whose send fails, or
                that no longer have a registered connection, are removed
                from this set in place.
            message: JSON-serialisable payload passed to ``send_json``.
        """
        dead: list[int] = []
        # Iterate over a snapshot so the set can be changed while a send is
        # awaiting.
        for ws_id in list(ws_ids):
            entry = self._connections.get(ws_id)
            if entry is None:
                dead.append(ws_id)
                continue
            try:
                await entry.websocket.send_json(message)
            except Exception:
                # Best effort: any send failure is treated as a dead socket.
                logger.debug("WebSocket send failed, removing connection %s", ws_id)
                dead.append(ws_id)

        # Clean up dead connections through the same path disconnect() uses,
        # so emptied experiment sets are pruned here as well.
        for ws_id in dead:
            ws_ids.discard(ws_id)
            self._drop_connection(ws_id)

    # ------------------------------------------------------------------
    # Message replay
    # ------------------------------------------------------------------

    async def replay(
        self,
        websocket: WebSocket,
        experiment_id: str | None = None,
        since_ts: float | None = None,
        limit: int | None = None,
    ) -> int:
        """Replay recent events to a connection.

        Args:
            websocket: Target connection.
            experiment_id: If provided, replay experiment-specific events;
                otherwise replay global events.
            since_ts: Only replay events after this Unix timestamp.
            limit: Max number of events to replay (default: all in buffer).
                The most recent ``limit`` events are kept; zero or a
                negative value replays nothing.

        Returns:
            Number of events replayed. Replay stops at the first send that
            fails.
        """
        if experiment_id:
            buffer = self._experiment_replay.get(experiment_id, deque())
        else:
            buffer = self._global_replay

        events = list(buffer)
        if since_ts is not None:
            events = [e for e in events if e.get("_ts", 0) > since_ts]
        if limit is not None:
            # events[-0:] is the whole list, so a non-positive limit must be
            # handled explicitly instead of being fed to the slice.
            events = events[-limit:] if limit > 0 else []

        count = 0
        for event in events:
            try:
                await websocket.send_json(event)
            except Exception:
                logger.debug("WebSocket send failed during replay, stopping")
                break
            count += 1
        return count

    # ------------------------------------------------------------------
    # Redis pub/sub bridge
    # ------------------------------------------------------------------

    async def start_redis_listener(self, redis_client: Any) -> None:
        """Start a background task that subscribes to Redis pub/sub and
        broadcasts received messages to connected WebSocket clients.

        Uses async Redis pubsub if available, otherwise runs the blocking
        listener in a thread executor. Does nothing when ``redis_client`` is
        None or when a previously started listener is still running.
        """
        if redis_client is None:
            return

        current = self._redis_listener_task
        if current is not None and not current.done():
            # A second listener would deliver every event twice and the first
            # task handle would be lost.
            logger.debug("Redis listener already running")
            return

        self._redis_listener_task = asyncio.create_task(
            self._redis_listen_loop(redis_client)
        )

    @staticmethod
    def _is_async_redis(redis_client: Any) -> bool:
        """Return True when the client looks like an asyncio Redis client.

        Accepts a coroutine ``subscribe`` on the client (the original probe)
        or a coroutine ``execute_command``, which is how redis.asyncio.Redis
        differs from the blocking client (it has no client-level
        ``subscribe``).
        """
        if not hasattr(redis_client, "pubsub"):
            return False
        return any(
            asyncio.iscoroutinefunction(getattr(redis_client, attr, None))
            for attr in ("subscribe", "execute_command")
        )

    async def _redis_listen_loop(self, redis_client: Any) -> None:
        """Listen for Redis pub/sub messages and broadcast them.

        Cancellation and errors are logged and end the loop; neither is
        re-raised.
        """
        channel = "promptlooper:events"
        try:
            # Check if this is an async redis client (aioredis / redis.asyncio)
            if self._is_async_redis(redis_client):
                await self._redis_listen_async(redis_client, channel)
            else:
                await self._redis_listen_sync(redis_client, channel)
        except asyncio.CancelledError:
            logger.info("Redis listener cancelled")
        except Exception:
            logger.exception("Redis listener error")

    async def _redis_listen_async(self, redis_client: Any, channel: str) -> None:
        """Async Redis pub/sub listener."""
        pubsub = redis_client.pubsub()
        await pubsub.subscribe(channel)
        try:
            async for raw_message in pubsub.listen():
                # Skip subscribe/unsubscribe confirmations.
                if raw_message["type"] != "message":
                    continue
                await self._handle_redis_message(raw_message["data"])
        finally:
            await pubsub.unsubscribe(channel)

    async def _redis_listen_sync(self, redis_client: Any, channel: str) -> None:
        """Sync Redis pub/sub listener (run blocking calls in executor)."""
        pubsub = redis_client.pubsub()
        pubsub.subscribe(channel)
        loop = asyncio.get_running_loop()
        try:
            while True:
                # Poll with a timeout in a worker thread so the event loop is
                # not blocked while waiting for the next message.
                raw_message = await loop.run_in_executor(
                    None,
                    lambda: pubsub.get_message(
                        ignore_subscribe_messages=True, timeout=1.0
                    ),
                )
                if raw_message is None:
                    continue
                if raw_message["type"] != "message":
                    continue
                await self._handle_redis_message(raw_message["data"])
        finally:
            pubsub.unsubscribe(channel)

    async def _handle_redis_message(self, data: str | bytes) -> None:
        """Parse a Redis message and broadcast it.

        Payloads that are not valid UTF-8 JSON, or whose top-level JSON value
        is not an object, are logged and dropped.
        """
        try:
            if isinstance(data, bytes):
                data = data.decode("utf-8")
            message = json.loads(data)
        except (json.JSONDecodeError, UnicodeDecodeError, TypeError):
            logger.warning("Invalid Redis message: %s", data)
            return
        if not isinstance(message, dict):
            # broadcast() calls message.get(); letting a list or scalar
            # through would raise and terminate the listener loop.
            logger.warning("Invalid Redis message: %s", data)
            return
        await self.broadcast(message)

    async def stop_redis_listener(self) -> None:
        """Cancel the Redis listener background task."""
        if self._redis_listener_task is not None:
            self._redis_listener_task.cancel()
            try:
                await self._redis_listener_task
            except asyncio.CancelledError:
                pass
            self._redis_listener_task = None

    # ------------------------------------------------------------------
    # Status
    # ------------------------------------------------------------------

    @property
    def connection_count(self) -> int:
        """Total number of active connections."""
        return len(self._connections)

    @property
    def experiment_ids(self) -> list[str]:
        """List of experiment IDs with active subscriptions."""
        return list(self._experiment_subs.keys())

    def stats(self) -> dict[str, Any]:
        """Return connection statistics.

        Returns:
            A dict with ``total_connections``, ``global_subscribers``,
            ``experiment_subscriptions`` (experiment ID -> subscriber count)
            and ``replay_buffer_sizes`` (``"global"`` plus one key per
            experiment ID -> buffered event count).
        """
        return {
            "total_connections": self.connection_count,
            "global_subscribers": len(self._global_subs),
            "experiment_subscriptions": {
                exp_id: len(ws_ids)
                for exp_id, ws_ids in self._experiment_subs.items()
            },
            "replay_buffer_sizes": {
                "global": len(self._global_replay),
                **{
                    exp_id: len(buf)
                    for exp_id, buf in self._experiment_replay.items()
                },
            },
        }


# Module-level singleton for use across the application
manager = WebSocketManager()