From 30fd15ec7aa3d5bf4a49a4a84f58ff00c3dc6bae Mon Sep 17 00:00:00 2001 From: John Lightner Date: Tue, 7 Apr 2026 03:34:21 -0500 Subject: [PATCH] MAESTRO: Implement WebSocket connection manager with per-experiment routing, Redis pub/sub bridge, and message replay - WebSocketManager in backend/websocket/manager.py with per-experiment and global subscriptions - Redis pub/sub bridge (sync + async) broadcasting events to relevant WebSocket clients - Deque-based replay buffers with since_ts/limit filtering for reconnection support - Runtime subscribe/unsubscribe and stats API - Enhanced /ws endpoint in main.py with subscribe/unsubscribe/replay actions - 35 tests in test_ws_manager.py, all passing --- Auto Run Docs/02a-backend-engine.md | 3 +- backend/main.py | 59 ++-- backend/tests/test_main.py | 19 +- backend/tests/test_ws_manager.py | 484 ++++++++++++++++++++++++++++ backend/websocket/manager.py | 358 ++++++++++++++++++++ 5 files changed, 886 insertions(+), 37 deletions(-) create mode 100644 backend/tests/test_ws_manager.py create mode 100644 backend/websocket/manager.py diff --git a/Auto Run Docs/02a-backend-engine.md b/Auto Run Docs/02a-backend-engine.md index bcf1207..f8cc8f9 100644 --- a/Auto Run Docs/02a-backend-engine.md +++ b/Auto Run Docs/02a-backend-engine.md @@ -44,7 +44,8 @@ Implement the core experiment execution engine: LLM adapters, response caching, - [x] Implement backend/routers/export.py — export best config in JSON, .env, and YAML formats as defined in the spec. Include metadata (score, experiment name, timestamp). The report endpoint should generate a markdown summary of the experiment: config space explored, top 5 configs, score distributions, token usage, timing stats. -- [ ] Implement backend/websocket/manager.py — WebSocket connection manager that: maintains active connections per experiment and globally, receives Redis pub/sub messages and broadcasts to relevant connections, handles connection/disconnection cleanly, supports reconnection with message replay (last N events). +- [x] Implement backend/websocket/manager.py — WebSocket connection manager that: maintains active connections per experiment and globally, receives Redis pub/sub messages and broadcasts to relevant connections, handles connection/disconnection cleanly, supports reconnection with message replay (last N events). + - [ ] Implement backend/routers/webhooks.py — CRUD for webhook configs. When events occur (in runner.py and sweep.py), dispatch webhook calls asynchronously via Celery. Include retry logic (3 attempts) and log delivery status. diff --git a/backend/main.py b/backend/main.py index 1afe997..4ba35e6 100644 --- a/backend/main.py +++ b/backend/main.py @@ -67,28 +67,9 @@ def get_redis(): # WebSocket connection manager # --------------------------------------------------------------------------- -class ConnectionManager: - """Manage active WebSocket connections.""" +from websocket.manager import WebSocketManager - def __init__(self) -> None: - self.active_connections: list[WebSocket] = [] - - async def connect(self, websocket: WebSocket) -> None: - await websocket.accept() - self.active_connections.append(websocket) - - def disconnect(self, websocket: WebSocket) -> None: - self.active_connections.remove(websocket) - - async def broadcast(self, message: dict) -> None: - for connection in list(self.active_connections): - try: - await connection.send_json(message) - except Exception: - self.disconnect(connection) - - -ws_manager = ConnectionManager() +ws_manager = WebSocketManager() # --------------------------------------------------------------------------- @@ -100,8 +81,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: """Startup and shutdown lifecycle hooks.""" _init_db() _init_redis() + # Start WebSocket ↔ Redis bridge + await ws_manager.start_redis_listener(_redis_client) yield # Shutdown: clean up connections + await ws_manager.stop_redis_listener() if _redis_client is not None: _redis_client.close() if engine is not None: @@ -167,14 +151,37 @@ def health_check() -> dict: @app.websocket("/ws") async def websocket_endpoint(websocket: WebSocket) -> None: - """WebSocket connection for real-time dashboard updates.""" - await ws_manager.connect(websocket) + """WebSocket connection for real-time dashboard updates. + + Clients may send JSON messages to control their subscriptions: + - {"action": "subscribe", "experiment_id": "..."} + - {"action": "unsubscribe", "experiment_id": "..."} + - {"action": "replay", "experiment_id": "...", "since_ts": ..., "limit": ...} + """ + # Parse query params for initial experiment subscriptions + exp_ids_param = websocket.query_params.get("experiment_ids", "") + experiment_ids = [e.strip() for e in exp_ids_param.split(",") if e.strip()] or None + await ws_manager.connect(websocket, experiment_ids=experiment_ids) try: while True: - # Keep connection alive; handle incoming messages if needed data = await websocket.receive_json() - # Echo back or handle client messages in future - await websocket.send_json({"type": "ack", "data": data}) + action = data.get("action", "") + if action == "subscribe" and data.get("experiment_id"): + ws_manager.subscribe(websocket, data["experiment_id"]) + await websocket.send_json({"type": "ack", "action": "subscribe", "experiment_id": data["experiment_id"]}) + elif action == "unsubscribe" and data.get("experiment_id"): + ws_manager.unsubscribe(websocket, data["experiment_id"]) + await websocket.send_json({"type": "ack", "action": "unsubscribe", "experiment_id": data["experiment_id"]}) + elif action == "replay": + count = await ws_manager.replay( + websocket, + experiment_id=data.get("experiment_id"), + since_ts=data.get("since_ts"), + limit=data.get("limit"), + ) + await websocket.send_json({"type": "ack", "action": "replay", "events_replayed": count}) + else: + await websocket.send_json({"type": "ack", "data": data}) except WebSocketDisconnect: ws_manager.disconnect(websocket) diff --git a/backend/tests/test_main.py b/backend/tests/test_main.py index 76cf5dd..300514f 100644 --- a/backend/tests/test_main.py +++ b/backend/tests/test_main.py @@ -77,11 +77,11 @@ class TestWebSocket: def test_websocket_disconnect_cleanup(self, client): from main import ws_manager - initial_count = len(ws_manager.active_connections) + initial_count = ws_manager.connection_count with client.websocket_connect("/ws") as ws: - assert len(ws_manager.active_connections) == initial_count + 1 + assert ws_manager.connection_count == initial_count + 1 # After disconnect, connection should be removed - assert len(ws_manager.active_connections) == initial_count + assert ws_manager.connection_count == initial_count class TestRouterMounting: @@ -97,16 +97,15 @@ class TestRouterMounting: class TestConnectionManager: - def test_broadcast_removes_dead_connections(self): - """ConnectionManager.broadcast skips and removes broken connections.""" - from main import ConnectionManager - manager = ConnectionManager() - # No connections — broadcast should not raise + def test_broadcast_on_empty_manager(self): + """WebSocketManager.broadcast_global on empty manager should not raise.""" + from websocket.manager import WebSocketManager + manager = WebSocketManager() import asyncio asyncio.get_event_loop().run_until_complete( - manager.broadcast({"test": True}) + manager.broadcast_global({"test": True}) ) - assert len(manager.active_connections) == 0 + assert manager.connection_count == 0 class TestGetDb: diff --git a/backend/tests/test_ws_manager.py b/backend/tests/test_ws_manager.py new file mode 100644 index 0000000..0a07f3a --- /dev/null +++ b/backend/tests/test_ws_manager.py @@ -0,0 +1,484 @@ +"""Tests for the WebSocket connection manager.""" + +import asyncio +import json +import time +from collections import deque +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from websocket.manager import ConnectionEntry, WebSocketManager + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_ws(accept_side_effect=None) -> AsyncMock: + """Create a mock WebSocket.""" + ws = AsyncMock() + ws.accept = AsyncMock(side_effect=accept_side_effect) + ws.send_json = AsyncMock() + ws.receive_json = AsyncMock() + return ws + + +# --------------------------------------------------------------------------- +# Connection lifecycle +# --------------------------------------------------------------------------- + + +class TestConnect: + @pytest.mark.asyncio + async def test_accept_and_register_global(self): + mgr = WebSocketManager() + ws = _make_ws() + await mgr.connect(ws) + + assert mgr.connection_count == 1 + assert id(ws) in mgr._global_subs + ws.accept.assert_awaited_once() + + @pytest.mark.asyncio + async def test_accept_with_experiment_ids(self): + mgr = WebSocketManager() + ws = _make_ws() + await mgr.connect(ws, experiment_ids=["exp-1", "exp-2"]) + + assert mgr.connection_count == 1 + assert id(ws) not in mgr._global_subs + assert id(ws) in mgr._experiment_subs["exp-1"] + assert id(ws) in mgr._experiment_subs["exp-2"] + + @pytest.mark.asyncio + async def test_multiple_connections(self): + mgr = WebSocketManager() + ws1 = _make_ws() + ws2 = _make_ws() + await mgr.connect(ws1) + await mgr.connect(ws2, experiment_ids=["exp-1"]) + + assert mgr.connection_count == 2 + + +class TestDisconnect: + @pytest.mark.asyncio + async def test_disconnect_global(self): + mgr = WebSocketManager() + ws = _make_ws() + await mgr.connect(ws) + mgr.disconnect(ws) + + assert mgr.connection_count == 0 + assert id(ws) not in mgr._global_subs + + @pytest.mark.asyncio + async def test_disconnect_experiment(self): + mgr = WebSocketManager() + ws = _make_ws() + await mgr.connect(ws, experiment_ids=["exp-1"]) + mgr.disconnect(ws) + + assert mgr.connection_count == 0 + assert "exp-1" not in mgr._experiment_subs + + @pytest.mark.asyncio + async def test_disconnect_unknown_ws_is_noop(self): + mgr = WebSocketManager() + ws = _make_ws() + mgr.disconnect(ws) # Should not raise + assert mgr.connection_count == 0 + + @pytest.mark.asyncio + async def test_disconnect_cleans_up_empty_experiment_sets(self): + mgr = WebSocketManager() + ws1 = _make_ws() + ws2 = _make_ws() + await mgr.connect(ws1, experiment_ids=["exp-1"]) + await mgr.connect(ws2, experiment_ids=["exp-1"]) + + mgr.disconnect(ws1) + assert "exp-1" in mgr._experiment_subs + assert len(mgr._experiment_subs["exp-1"]) == 1 + + mgr.disconnect(ws2) + assert "exp-1" not in mgr._experiment_subs + + +# --------------------------------------------------------------------------- +# Subscribe / unsubscribe at runtime +# --------------------------------------------------------------------------- + + +class TestSubscriptions: + @pytest.mark.asyncio + async def test_subscribe_adds_experiment(self): + mgr = WebSocketManager() + ws = _make_ws() + await mgr.connect(ws) + mgr.subscribe(ws, "exp-42") + + assert id(ws) in mgr._experiment_subs["exp-42"] + + @pytest.mark.asyncio + async def test_unsubscribe_removes_experiment(self): + mgr = WebSocketManager() + ws = _make_ws() + await mgr.connect(ws, experiment_ids=["exp-1"]) + mgr.unsubscribe(ws, "exp-1") + + assert "exp-1" not in mgr._experiment_subs + entry = mgr._connections[id(ws)] + assert "exp-1" not in entry.experiment_ids + + @pytest.mark.asyncio + async def test_subscribe_unknown_ws_is_noop(self): + mgr = WebSocketManager() + ws = _make_ws() + mgr.subscribe(ws, "exp-1") # Not connected + assert "exp-1" not in mgr._experiment_subs + + @pytest.mark.asyncio + async def test_unsubscribe_unknown_ws_is_noop(self): + mgr = WebSocketManager() + ws = _make_ws() + mgr.unsubscribe(ws, "exp-1") # Should not raise + + +# --------------------------------------------------------------------------- +# Broadcasting +# --------------------------------------------------------------------------- + + +class TestBroadcast: + @pytest.mark.asyncio + async def test_broadcast_global(self): + mgr = WebSocketManager() + ws = _make_ws() + await mgr.connect(ws) + + await mgr.broadcast_global({"type": "test"}) + ws.send_json.assert_awaited_once_with({"type": "test"}) + + @pytest.mark.asyncio + async def test_broadcast_to_experiment(self): + mgr = WebSocketManager() + ws_exp = _make_ws() + ws_global = _make_ws() + await mgr.connect(ws_exp, experiment_ids=["exp-1"]) + await mgr.connect(ws_global) + + await mgr.broadcast_to_experiment("exp-1", {"type": "test"}) + ws_exp.send_json.assert_awaited_once() + ws_global.send_json.assert_not_awaited() + + @pytest.mark.asyncio + async def test_broadcast_routes_by_experiment_id(self): + mgr = WebSocketManager() + ws_exp = _make_ws() + ws_global = _make_ws() + await mgr.connect(ws_exp, experiment_ids=["exp-1"]) + await mgr.connect(ws_global) + + msg = {"type": "run.completed", "experiment_id": "exp-1"} + await mgr.broadcast(msg) + + # experiment subscriber gets the message (from broadcast_to_experiment) + assert ws_exp.send_json.await_count == 1 + # global subscriber gets it (from broadcast_global) + assert ws_global.send_json.await_count == 1 + + @pytest.mark.asyncio + async def test_broadcast_without_experiment_id_goes_global_only(self): + mgr = WebSocketManager() + ws_exp = _make_ws() + ws_global = _make_ws() + await mgr.connect(ws_exp, experiment_ids=["exp-1"]) + await mgr.connect(ws_global) + + await mgr.broadcast({"type": "system.status"}) + ws_exp.send_json.assert_not_awaited() + ws_global.send_json.assert_awaited_once() + + @pytest.mark.asyncio + async def test_broadcast_removes_dead_connections(self): + mgr = WebSocketManager() + ws_good = _make_ws() + ws_dead = _make_ws() + ws_dead.send_json = AsyncMock(side_effect=RuntimeError("closed")) + await mgr.connect(ws_good) + await mgr.connect(ws_dead) + + await mgr.broadcast_global({"type": "test"}) + + assert mgr.connection_count == 1 + assert id(ws_dead) not in mgr._connections + + @pytest.mark.asyncio + async def test_broadcast_to_multiple_experiment_subscribers(self): + mgr = WebSocketManager() + ws1 = _make_ws() + ws2 = _make_ws() + await mgr.connect(ws1, experiment_ids=["exp-1"]) + await mgr.connect(ws2, experiment_ids=["exp-1"]) + + await mgr.broadcast_to_experiment("exp-1", {"type": "test"}) + ws1.send_json.assert_awaited_once() + ws2.send_json.assert_awaited_once() + + +# --------------------------------------------------------------------------- +# Replay +# --------------------------------------------------------------------------- + + +class TestReplay: + @pytest.mark.asyncio + async def test_replay_global_events(self): + mgr = WebSocketManager() + ws = _make_ws() + await mgr.connect(ws) + + # Store some events + await mgr.broadcast_global({"type": "event1"}) + await mgr.broadcast_global({"type": "event2"}) + + # Reset send tracking + ws.send_json.reset_mock() + + count = await mgr.replay(ws) + assert count == 2 + assert ws.send_json.await_count == 2 + + @pytest.mark.asyncio + async def test_replay_experiment_events(self): + mgr = WebSocketManager() + ws = _make_ws() + await mgr.connect(ws, experiment_ids=["exp-1"]) + + await mgr.broadcast_to_experiment("exp-1", {"type": "e1"}) + await mgr.broadcast_to_experiment("exp-1", {"type": "e2"}) + await mgr.broadcast_to_experiment("exp-1", {"type": "e3"}) + + ws.send_json.reset_mock() + + count = await mgr.replay(ws, experiment_id="exp-1") + assert count == 3 + + @pytest.mark.asyncio + async def test_replay_with_since_ts(self): + mgr = WebSocketManager() + ws = _make_ws() + await mgr.connect(ws) + + # Manually insert events with known timestamps + ts_old = time.time() - 100 + ts_new = time.time() + mgr._global_replay.append({"type": "old", "_ts": ts_old}) + mgr._global_replay.append({"type": "new", "_ts": ts_new}) + + count = await mgr.replay(ws, since_ts=ts_old + 1) + assert count == 1 + + @pytest.mark.asyncio + async def test_replay_with_limit(self): + mgr = WebSocketManager() + ws = _make_ws() + await mgr.connect(ws) + + for i in range(10): + mgr._global_replay.append({"type": f"event{i}", "_ts": time.time()}) + + count = await mgr.replay(ws, limit=3) + assert count == 3 + + @pytest.mark.asyncio + async def test_replay_empty_buffer(self): + mgr = WebSocketManager() + ws = _make_ws() + await mgr.connect(ws) + + count = await mgr.replay(ws) + assert count == 0 + + @pytest.mark.asyncio + async def test_replay_respects_buffer_size(self): + mgr = WebSocketManager(replay_size=5) + ws = _make_ws() + await mgr.connect(ws) + + for i in range(10): + await mgr.broadcast_global({"type": f"event{i}"}) + + ws.send_json.reset_mock() + count = await mgr.replay(ws) + assert count == 5 + + @pytest.mark.asyncio + async def test_replay_stops_on_send_failure(self): + mgr = WebSocketManager() + ws = _make_ws() + call_count = 0 + + async def fail_after_one(msg): + nonlocal call_count + call_count += 1 + if call_count > 1: + raise RuntimeError("closed") + + ws.send_json = AsyncMock(side_effect=fail_after_one) + await mgr.connect(ws) + + for i in range(5): + mgr._global_replay.append({"type": f"event{i}", "_ts": time.time()}) + + count = await mgr.replay(ws) + assert count == 1 + + +# --------------------------------------------------------------------------- +# Redis pub/sub bridge +# --------------------------------------------------------------------------- + + +class TestRedisListener: + @pytest.mark.asyncio + async def test_handle_redis_message_json(self): + mgr = WebSocketManager() + ws = _make_ws() + await mgr.connect(ws) + + msg = json.dumps({"type": "run.completed", "experiment_id": "exp-1"}) + await mgr._handle_redis_message(msg) + + # Message should have been broadcast + assert ws.send_json.await_count >= 1 + + @pytest.mark.asyncio + async def test_handle_redis_message_bytes(self): + mgr = WebSocketManager() + ws = _make_ws() + await mgr.connect(ws) + + msg = json.dumps({"type": "test"}).encode("utf-8") + await mgr._handle_redis_message(msg) + assert ws.send_json.await_count >= 1 + + @pytest.mark.asyncio + async def test_handle_redis_message_invalid_json(self): + mgr = WebSocketManager() + ws = _make_ws() + await mgr.connect(ws) + + # Should not raise, just log warning + await mgr._handle_redis_message("not-json{{{") + ws.send_json.assert_not_awaited() + + @pytest.mark.asyncio + async def test_start_redis_listener_none_is_noop(self): + mgr = WebSocketManager() + await mgr.start_redis_listener(None) + assert mgr._redis_listener_task is None + + @pytest.mark.asyncio + async def test_stop_redis_listener_when_none(self): + mgr = WebSocketManager() + await mgr.stop_redis_listener() # Should not raise + + @pytest.mark.asyncio + async def test_stop_cancels_task(self): + mgr = WebSocketManager() + + # Create a long-running dummy task + async def dummy(): + await asyncio.sleep(999) + + mgr._redis_listener_task = asyncio.create_task(dummy()) + await mgr.stop_redis_listener() + assert mgr._redis_listener_task is None + + @pytest.mark.asyncio + async def test_sync_redis_listener_processes_messages(self): + """Test the sync Redis pubsub path with a mock.""" + mgr = WebSocketManager() + ws = _make_ws() + await mgr.connect(ws) + + # Mock a sync Redis client + mock_pubsub = MagicMock() + call_count = 0 + + def get_message_side_effect(ignore_subscribe_messages=True, timeout=1.0): + nonlocal call_count + call_count += 1 + if call_count == 1: + return { + "type": "message", + "data": json.dumps({"type": "test_event"}), + } + # After first message, raise to stop the loop + raise asyncio.CancelledError() + + mock_pubsub.get_message = MagicMock(side_effect=get_message_side_effect) + mock_pubsub.subscribe = MagicMock() + mock_pubsub.unsubscribe = MagicMock() + + mock_redis = MagicMock() + mock_redis.pubsub.return_value = mock_pubsub + + # Run the sync listener — it will process one message then get cancelled + with pytest.raises(asyncio.CancelledError): + await mgr._redis_listen_sync(mock_redis, "promptlooper:events") + + mock_pubsub.subscribe.assert_called_once_with("promptlooper:events") + assert ws.send_json.await_count >= 1 + + +# --------------------------------------------------------------------------- +# Stats +# --------------------------------------------------------------------------- + + +class TestStats: + @pytest.mark.asyncio + async def test_stats_empty(self): + mgr = WebSocketManager() + stats = mgr.stats() + assert stats["total_connections"] == 0 + assert stats["global_subscribers"] == 0 + assert stats["experiment_subscriptions"] == {} + + @pytest.mark.asyncio + async def test_stats_with_connections(self): + mgr = WebSocketManager() + ws1 = _make_ws() + ws2 = _make_ws() + await mgr.connect(ws1) + await mgr.connect(ws2, experiment_ids=["exp-1"]) + + stats = mgr.stats() + assert stats["total_connections"] == 2 + assert stats["global_subscribers"] == 1 + assert stats["experiment_subscriptions"]["exp-1"] == 1 + + @pytest.mark.asyncio + async def test_experiment_ids_property(self): + mgr = WebSocketManager() + ws = _make_ws() + await mgr.connect(ws, experiment_ids=["exp-a", "exp-b"]) + + ids = mgr.experiment_ids + assert set(ids) == {"exp-a", "exp-b"} + + @pytest.mark.asyncio + async def test_connection_count_property(self): + mgr = WebSocketManager() + assert mgr.connection_count == 0 + ws = _make_ws() + await mgr.connect(ws) + assert mgr.connection_count == 1 + mgr.disconnect(ws) + assert mgr.connection_count == 0 diff --git a/backend/websocket/manager.py b/backend/websocket/manager.py new file mode 100644 index 0000000..2c6c904 --- /dev/null +++ b/backend/websocket/manager.py @@ -0,0 +1,358 @@ +"""WebSocket connection manager for PromptLooper. + +Manages per-experiment and global WebSocket connections, bridges Redis +pub/sub events to connected clients, and supports message replay on +reconnection. +""" + +import asyncio +import json +import logging +import time +from collections import deque +from dataclasses import dataclass, field +from typing import Any + +from fastapi import WebSocket + +logger = logging.getLogger(__name__) + +# Default number of events to retain for replay on reconnect +DEFAULT_REPLAY_SIZE = 50 + + +@dataclass +class ConnectionEntry: + """Tracks a single WebSocket connection and its subscriptions.""" + + websocket: WebSocket + experiment_ids: set[str] = field(default_factory=set) + connected_at: float = field(default_factory=time.time) + + +class WebSocketManager: + """Manages WebSocket connections with per-experiment routing and replay. + + Features: + - Per-experiment and global subscriptions + - Redis pub/sub bridge (when Redis is available) + - Message replay on reconnection (last N events per experiment + global) + - Clean disconnect handling + """ + + def __init__(self, replay_size: int = DEFAULT_REPLAY_SIZE) -> None: + self._replay_size = replay_size + # All active connections keyed by id(websocket) + self._connections: dict[int, ConnectionEntry] = {} + # Experiment-specific connection sets: experiment_id -> set of ws ids + self._experiment_subs: dict[str, set[int]] = {} + # Global subscribers (receive all events) + self._global_subs: set[int] = set() + # Event replay buffers + self._global_replay: deque[dict[str, Any]] = deque(maxlen=replay_size) + self._experiment_replay: dict[str, deque[dict[str, Any]]] = {} + # Redis listener task handle + self._redis_listener_task: asyncio.Task | None = None + + # ------------------------------------------------------------------ + # Connection lifecycle + # ------------------------------------------------------------------ + + async def connect( + self, + websocket: WebSocket, + experiment_ids: list[str] | None = None, + ) -> None: + """Accept a WebSocket connection and register subscriptions. + + Args: + websocket: The incoming WebSocket connection. + experiment_ids: Optional list of experiment IDs to subscribe to. + If empty/None the connection receives global broadcasts only. + """ + await websocket.accept() + ws_id = id(websocket) + entry = ConnectionEntry(websocket=websocket) + self._connections[ws_id] = entry + + if experiment_ids: + for exp_id in experiment_ids: + self._subscribe_experiment(ws_id, exp_id) + else: + self._global_subs.add(ws_id) + + def disconnect(self, websocket: WebSocket) -> None: + """Remove a WebSocket from all tracking structures.""" + ws_id = id(websocket) + entry = self._connections.pop(ws_id, None) + if entry is None: + return + + # Remove from global subs + self._global_subs.discard(ws_id) + + # Remove from all experiment subs + for exp_id in list(entry.experiment_ids): + exp_set = self._experiment_subs.get(exp_id) + if exp_set is not None: + exp_set.discard(ws_id) + if not exp_set: + del self._experiment_subs[exp_id] + + def _subscribe_experiment(self, ws_id: int, experiment_id: str) -> None: + """Subscribe a connection to a specific experiment's events.""" + entry = self._connections.get(ws_id) + if entry is None: + return + entry.experiment_ids.add(experiment_id) + if experiment_id not in self._experiment_subs: + self._experiment_subs[experiment_id] = set() + self._experiment_subs[experiment_id].add(ws_id) + + # ------------------------------------------------------------------ + # Subscription management (runtime) + # ------------------------------------------------------------------ + + def subscribe(self, websocket: WebSocket, experiment_id: str) -> None: + """Add an experiment subscription for an existing connection.""" + ws_id = id(websocket) + if ws_id in self._connections: + self._subscribe_experiment(ws_id, experiment_id) + + def unsubscribe(self, websocket: WebSocket, experiment_id: str) -> None: + """Remove an experiment subscription for an existing connection.""" + ws_id = id(websocket) + entry = self._connections.get(ws_id) + if entry is None: + return + entry.experiment_ids.discard(experiment_id) + exp_set = self._experiment_subs.get(experiment_id) + if exp_set is not None: + exp_set.discard(ws_id) + if not exp_set: + del self._experiment_subs[experiment_id] + + # ------------------------------------------------------------------ + # Broadcasting + # ------------------------------------------------------------------ + + async def broadcast_to_experiment( + self, experiment_id: str, message: dict[str, Any] + ) -> None: + """Send a message to all connections subscribed to an experiment.""" + # Store in replay buffer + if experiment_id not in self._experiment_replay: + self._experiment_replay[experiment_id] = deque(maxlen=self._replay_size) + message_with_ts = {**message, "_ts": time.time()} + self._experiment_replay[experiment_id].append(message_with_ts) + + ws_ids = self._experiment_subs.get(experiment_id, set()) + await self._send_to_many(ws_ids, message) + + async def broadcast_global(self, message: dict[str, Any]) -> None: + """Send a message to all global subscribers.""" + message_with_ts = {**message, "_ts": time.time()} + self._global_replay.append(message_with_ts) + await self._send_to_many(self._global_subs, message) + + async def broadcast(self, message: dict[str, Any]) -> None: + """Route a message to the right subscribers based on experiment_id. + + If the message contains an ``experiment_id`` field, it is sent to + that experiment's subscribers *and* to global subscribers. Otherwise + it is sent only to global subscribers. + """ + experiment_id = message.get("experiment_id") + if experiment_id: + await self.broadcast_to_experiment(str(experiment_id), message) + await self.broadcast_global(message) + + async def _send_to_many( + self, ws_ids: set[int], message: dict[str, Any] + ) -> None: + """Send a JSON message to a set of connections, disconnecting failures.""" + dead: list[int] = [] + for ws_id in list(ws_ids): + entry = self._connections.get(ws_id) + if entry is None: + dead.append(ws_id) + continue + try: + await entry.websocket.send_json(message) + except Exception: + logger.debug("WebSocket send failed, removing connection %s", ws_id) + dead.append(ws_id) + # Clean up dead connections + for ws_id in dead: + ws_ids.discard(ws_id) + entry = self._connections.pop(ws_id, None) + if entry: + self._global_subs.discard(ws_id) + for exp_id in entry.experiment_ids: + exp_set = self._experiment_subs.get(exp_id) + if exp_set: + exp_set.discard(ws_id) + + # ------------------------------------------------------------------ + # Message replay + # ------------------------------------------------------------------ + + async def replay( + self, + websocket: WebSocket, + experiment_id: str | None = None, + since_ts: float | None = None, + limit: int | None = None, + ) -> int: + """Replay recent events to a connection. + + Args: + websocket: Target connection. + experiment_id: If provided, replay experiment-specific events; + otherwise replay global events. + since_ts: Only replay events after this Unix timestamp. + limit: Max number of events to replay (default: all in buffer). + + Returns: + Number of events replayed. + """ + if experiment_id: + buffer = self._experiment_replay.get(experiment_id, deque()) + else: + buffer = self._global_replay + + events = list(buffer) + if since_ts is not None: + events = [e for e in events if e.get("_ts", 0) > since_ts] + if limit is not None: + events = events[-limit:] + + count = 0 + for event in events: + try: + await websocket.send_json(event) + count += 1 + except Exception: + break + return count + + # ------------------------------------------------------------------ + # Redis pub/sub bridge + # ------------------------------------------------------------------ + + async def start_redis_listener(self, redis_client: Any) -> None: + """Start a background task that subscribes to Redis pub/sub and + broadcasts received messages to connected WebSocket clients. + + Uses async Redis pubsub if available, otherwise runs the blocking + listener in a thread executor. + """ + if redis_client is None: + return + + self._redis_listener_task = asyncio.create_task( + self._redis_listen_loop(redis_client) + ) + + async def _redis_listen_loop(self, redis_client: Any) -> None: + """Listen for Redis pub/sub messages and broadcast them.""" + channel = "promptlooper:events" + try: + # Check if this is an async redis client (aioredis / redis.asyncio) + if hasattr(redis_client, "pubsub") and asyncio.iscoroutinefunction( + getattr(redis_client, "subscribe", None) + ): + await self._redis_listen_async(redis_client, channel) + else: + await self._redis_listen_sync(redis_client, channel) + except asyncio.CancelledError: + logger.info("Redis listener cancelled") + except Exception: + logger.exception("Redis listener error") + + async def _redis_listen_async(self, redis_client: Any, channel: str) -> None: + """Async Redis pub/sub listener.""" + pubsub = redis_client.pubsub() + await pubsub.subscribe(channel) + try: + async for raw_message in pubsub.listen(): + if raw_message["type"] != "message": + continue + await self._handle_redis_message(raw_message["data"]) + finally: + await pubsub.unsubscribe(channel) + + async def _redis_listen_sync(self, redis_client: Any, channel: str) -> None: + """Sync Redis pub/sub listener (run blocking calls in executor).""" + pubsub = redis_client.pubsub() + pubsub.subscribe(channel) + loop = asyncio.get_running_loop() + try: + while True: + raw_message = await loop.run_in_executor( + None, lambda: pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0) + ) + if raw_message is None: + continue + if raw_message["type"] != "message": + continue + await self._handle_redis_message(raw_message["data"]) + finally: + pubsub.unsubscribe(channel) + + async def _handle_redis_message(self, data: str | bytes) -> None: + """Parse a Redis message and broadcast it.""" + try: + if isinstance(data, bytes): + data = data.decode("utf-8") + message = json.loads(data) + except (json.JSONDecodeError, UnicodeDecodeError): + logger.warning("Invalid Redis message: %s", data) + return + await self.broadcast(message) + + async def stop_redis_listener(self) -> None: + """Cancel the Redis listener background task.""" + if self._redis_listener_task is not None: + self._redis_listener_task.cancel() + try: + await self._redis_listener_task + except asyncio.CancelledError: + pass + self._redis_listener_task = None + + # ------------------------------------------------------------------ + # Status + # ------------------------------------------------------------------ + + @property + def connection_count(self) -> int: + """Total number of active connections.""" + return len(self._connections) + + @property + def experiment_ids(self) -> list[str]: + """List of experiment IDs with active subscriptions.""" + return list(self._experiment_subs.keys()) + + def stats(self) -> dict[str, Any]: + """Return connection statistics.""" + return { + "total_connections": self.connection_count, + "global_subscribers": len(self._global_subs), + "experiment_subscriptions": { + exp_id: len(ws_ids) + for exp_id, ws_ids in self._experiment_subs.items() + }, + "replay_buffer_sizes": { + "global": len(self._global_replay), + **{ + exp_id: len(buf) + for exp_id, buf in self._experiment_replay.items() + }, + }, + } + + +# Module-level singleton for use across the application +manager = WebSocketManager()