MAESTRO: Implement WebSocket connection manager with per-experiment routing, Redis pub/sub bridge, and message replay

- WebSocketManager in backend/websocket/manager.py with per-experiment and global subscriptions
- Redis pub/sub bridge (sync + async) broadcasting events to relevant WebSocket clients
- Deque-based replay buffers with since_ts/limit filtering for reconnection support
- Runtime subscribe/unsubscribe and stats API
- Enhanced /ws endpoint in main.py with subscribe/unsubscribe/replay actions
- 35 tests in test_ws_manager.py, all passing
This commit is contained in:
John Lightner 2026-04-07 03:34:21 -05:00
parent e42117c8ee
commit 30fd15ec7a
5 changed files with 886 additions and 37 deletions

View file

@ -44,7 +44,8 @@ Implement the core experiment execution engine: LLM adapters, response caching,
- [x] Implement backend/routers/export.py — export best config in JSON, .env, and YAML formats as defined in the spec. Include metadata (score, experiment name, timestamp). The report endpoint should generate a markdown summary of the experiment: config space explored, top 5 configs, score distributions, token usage, timing stats. - [x] Implement backend/routers/export.py — export best config in JSON, .env, and YAML formats as defined in the spec. Include metadata (score, experiment name, timestamp). The report endpoint should generate a markdown summary of the experiment: config space explored, top 5 configs, score distributions, token usage, timing stats.
<!-- Completed: Full export router with 4 endpoints: /best (JSON with weighted score, metadata), /env (flattened KEY=VALUE with comments), /yaml (simple serializer, no PyYAML dependency), /report (markdown with config space, top N configs, score distributions, token usage, timing stats). Auth required on all endpoints. 34 tests in test_export.py, all passing. --> <!-- Completed: Full export router with 4 endpoints: /best (JSON with weighted score, metadata), /env (flattened KEY=VALUE with comments), /yaml (simple serializer, no PyYAML dependency), /report (markdown with config space, top N configs, score distributions, token usage, timing stats). Auth required on all endpoints. 34 tests in test_export.py, all passing. -->
- [ ] Implement backend/websocket/manager.py — WebSocket connection manager that: maintains active connections per experiment and globally, receives Redis pub/sub messages and broadcasts to relevant connections, handles connection/disconnection cleanly, supports reconnection with message replay (last N events). - [x] Implement backend/websocket/manager.py — WebSocket connection manager that: maintains active connections per experiment and globally, receives Redis pub/sub messages and broadcasts to relevant connections, handles connection/disconnection cleanly, supports reconnection with message replay (last N events).
<!-- Completed: WebSocketManager with per-experiment and global subscriptions, Redis pub/sub bridge (sync + async), deque-based replay buffers with since_ts/limit filtering, clean disconnect cleanup, runtime subscribe/unsubscribe, stats API. Integrated into main.py with enhanced /ws endpoint supporting subscribe/unsubscribe/replay actions and query-param-based initial subscriptions. 35 tests in test_ws_manager.py, all passing. -->
- [ ] Implement backend/routers/webhooks.py — CRUD for webhook configs. When events occur (in runner.py and sweep.py), dispatch webhook calls asynchronously via Celery. Include retry logic (3 attempts) and log delivery status. - [ ] Implement backend/routers/webhooks.py — CRUD for webhook configs. When events occur (in runner.py and sweep.py), dispatch webhook calls asynchronously via Celery. Include retry logic (3 attempts) and log delivery status.

View file

@ -67,28 +67,9 @@ def get_redis():
# WebSocket connection manager # WebSocket connection manager
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class ConnectionManager: from websocket.manager import WebSocketManager
"""Manage active WebSocket connections."""
def __init__(self) -> None: ws_manager = WebSocketManager()
self.active_connections: list[WebSocket] = []
async def connect(self, websocket: WebSocket) -> None:
await websocket.accept()
self.active_connections.append(websocket)
def disconnect(self, websocket: WebSocket) -> None:
self.active_connections.remove(websocket)
async def broadcast(self, message: dict) -> None:
for connection in list(self.active_connections):
try:
await connection.send_json(message)
except Exception:
self.disconnect(connection)
ws_manager = ConnectionManager()
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -100,8 +81,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""Startup and shutdown lifecycle hooks.""" """Startup and shutdown lifecycle hooks."""
_init_db() _init_db()
_init_redis() _init_redis()
# Start WebSocket ↔ Redis bridge
await ws_manager.start_redis_listener(_redis_client)
yield yield
# Shutdown: clean up connections # Shutdown: clean up connections
await ws_manager.stop_redis_listener()
if _redis_client is not None: if _redis_client is not None:
_redis_client.close() _redis_client.close()
if engine is not None: if engine is not None:
@ -167,13 +151,36 @@ def health_check() -> dict:
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """WebSocket connection for real-time dashboard updates.

    Initial subscriptions may be given with the ``experiment_ids`` query
    parameter (comma-separated). Without it the connection is registered
    as a global subscriber.

    Clients may send JSON messages to control their subscriptions:
    - {"action": "subscribe", "experiment_id": "..."}
    - {"action": "unsubscribe", "experiment_id": "..."}
    - {"action": "replay", "experiment_id": "...", "since_ts": ..., "limit": ...}

    Every handled message is answered with a ``{"type": "ack", ...}``
    frame. Malformed frames are answered with ``{"type": "error", ...}``
    and the connection stays open. Any other message is echoed back in an
    ack, as before.
    """
    # Parse query params for initial experiment subscriptions
    exp_ids_param = websocket.query_params.get("experiment_ids", "")
    experiment_ids = [e.strip() for e in exp_ids_param.split(",") if e.strip()] or None
    await ws_manager.connect(websocket, experiment_ids=experiment_ids)
    try:
        while True:
            try:
                data = await websocket.receive_json()
            except ValueError:
                # json.JSONDecodeError is a ValueError: a frame that is not
                # valid JSON must not tear down the whole connection.
                await websocket.send_json({"type": "error", "detail": "invalid JSON"})
                continue
            if not isinstance(data, dict):
                await websocket.send_json(
                    {"type": "error", "detail": "message must be a JSON object"}
                )
                continue

            action = data.get("action", "")
            raw_exp_id = data.get("experiment_id")
            # bool is an int subclass; exclude it so True/False are not IDs.
            if raw_exp_id is not None and (
                isinstance(raw_exp_id, bool) or not isinstance(raw_exp_id, (str, int))
            ):
                await websocket.send_json(
                    {"type": "error", "detail": "experiment_id must be a string or integer"}
                )
                continue
            # Events are routed under the string form of the experiment ID,
            # so numeric IDs from clients are normalised the same way.
            experiment_id = str(raw_exp_id) if raw_exp_id else None

            if action == "subscribe" and experiment_id:
                ws_manager.subscribe(websocket, experiment_id)
                await websocket.send_json(
                    {"type": "ack", "action": "subscribe", "experiment_id": experiment_id}
                )
            elif action == "unsubscribe" and experiment_id:
                ws_manager.unsubscribe(websocket, experiment_id)
                await websocket.send_json(
                    {"type": "ack", "action": "unsubscribe", "experiment_id": experiment_id}
                )
            elif action == "replay":
                since_ts = data.get("since_ts")
                limit = data.get("limit")
                bad_since = since_ts is not None and (
                    isinstance(since_ts, bool) or not isinstance(since_ts, (int, float))
                )
                # A limit below 1 is rejected here: a negative-zero slice
                # (``events[-0:]``) would select every buffered event.
                bad_limit = limit is not None and (
                    isinstance(limit, bool) or not isinstance(limit, int) or limit < 1
                )
                if bad_since or bad_limit:
                    await websocket.send_json(
                        {
                            "type": "error",
                            "detail": "since_ts must be a number and limit a positive integer",
                        }
                    )
                    continue
                count = await ws_manager.replay(
                    websocket,
                    experiment_id=experiment_id,
                    since_ts=since_ts,
                    limit=limit,
                )
                await websocket.send_json(
                    {"type": "ack", "action": "replay", "events_replayed": count}
                )
            else:
                await websocket.send_json({"type": "ack", "data": data})
    except WebSocketDisconnect:
        # Normal client departure; cleanup happens in ``finally``.
        pass
    finally:
        # Runs for every exit path (disconnect or unexpected error) so the
        # manager never keeps a reference to a socket whose handler is gone.
        ws_manager.disconnect(websocket)

View file

@ -77,11 +77,11 @@ class TestWebSocket:
def test_websocket_disconnect_cleanup(self, client):
    """The manager's connection count rises on connect and falls back on close."""
    from main import ws_manager

    baseline = ws_manager.connection_count
    expected_while_open = baseline + 1
    with client.websocket_connect("/ws") as ws:
        assert ws_manager.connection_count == expected_while_open
    # Leaving the context closes the socket; the manager must forget it.
    assert ws_manager.connection_count == baseline
class TestRouterMounting: class TestRouterMounting:
@ -97,16 +97,15 @@ class TestRouterMounting:
class TestConnectionManager:
    """Smoke test for the manager used by ``main``."""

    def test_broadcast_on_empty_manager(self):
        """WebSocketManager.broadcast_global on empty manager should not raise."""
        import asyncio

        from websocket.manager import WebSocketManager

        manager = WebSocketManager()
        # No connections — broadcast should not raise.
        # asyncio.run creates and closes its own loop; get_event_loop()
        # without a running loop is deprecated.
        asyncio.run(manager.broadcast_global({"test": True}))
        assert manager.connection_count == 0
class TestGetDb: class TestGetDb:

View file

@ -0,0 +1,484 @@
"""Tests for the WebSocket connection manager."""
import asyncio
import json
import time
from collections import deque
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from websocket.manager import ConnectionEntry, WebSocketManager
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_ws(accept_side_effect=None) -> AsyncMock:
"""Create a mock WebSocket."""
ws = AsyncMock()
ws.accept = AsyncMock(side_effect=accept_side_effect)
ws.send_json = AsyncMock()
ws.receive_json = AsyncMock()
return ws
# ---------------------------------------------------------------------------
# Connection lifecycle
# ---------------------------------------------------------------------------
class TestConnect:
    """Registration behaviour of ``WebSocketManager.connect``."""

    @pytest.mark.asyncio
    async def test_accept_and_register_global(self):
        """Without experiment IDs the socket is accepted and tracked as global."""
        manager = WebSocketManager()
        sock = _make_ws()

        await manager.connect(sock)

        sock.accept.assert_awaited_once()
        assert id(sock) in manager._global_subs
        assert manager.connection_count == 1

    @pytest.mark.asyncio
    async def test_accept_with_experiment_ids(self):
        """Experiment IDs register per-experiment subscriptions instead of global."""
        manager = WebSocketManager()
        sock = _make_ws()

        await manager.connect(sock, experiment_ids=["exp-1", "exp-2"])

        sock_key = id(sock)
        assert manager.connection_count == 1
        assert sock_key not in manager._global_subs
        for exp in ("exp-1", "exp-2"):
            assert sock_key in manager._experiment_subs[exp]

    @pytest.mark.asyncio
    async def test_multiple_connections(self):
        """Global and experiment-scoped sockets are both counted."""
        manager = WebSocketManager()
        global_sock = _make_ws()
        scoped_sock = _make_ws()

        await manager.connect(global_sock)
        await manager.connect(scoped_sock, experiment_ids=["exp-1"])

        assert manager.connection_count == 2
class TestDisconnect:
    """Cleanup behaviour of ``WebSocketManager.disconnect``."""

    @pytest.mark.asyncio
    async def test_disconnect_global(self):
        """A global subscriber is removed from every structure."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock)

        manager.disconnect(sock)

        assert id(sock) not in manager._global_subs
        assert manager.connection_count == 0

    @pytest.mark.asyncio
    async def test_disconnect_experiment(self):
        """The last subscriber leaving removes the experiment entry."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock, experiment_ids=["exp-1"])

        manager.disconnect(sock)

        assert "exp-1" not in manager._experiment_subs
        assert manager.connection_count == 0

    @pytest.mark.asyncio
    async def test_disconnect_unknown_ws_is_noop(self):
        """Disconnecting a socket that never connected must not raise."""
        manager = WebSocketManager()
        stranger = _make_ws()

        manager.disconnect(stranger)

        assert manager.connection_count == 0

    @pytest.mark.asyncio
    async def test_disconnect_cleans_up_empty_experiment_sets(self):
        """The experiment set survives until its final subscriber leaves."""
        manager = WebSocketManager()
        first = _make_ws()
        second = _make_ws()
        await manager.connect(first, experiment_ids=["exp-1"])
        await manager.connect(second, experiment_ids=["exp-1"])

        manager.disconnect(first)
        assert "exp-1" in manager._experiment_subs
        assert len(manager._experiment_subs["exp-1"]) == 1

        manager.disconnect(second)
        assert "exp-1" not in manager._experiment_subs
# ---------------------------------------------------------------------------
# Subscribe / unsubscribe at runtime
# ---------------------------------------------------------------------------
class TestSubscriptions:
    """Runtime ``subscribe`` / ``unsubscribe`` on the manager."""

    @pytest.mark.asyncio
    async def test_subscribe_adds_experiment(self):
        """A connected socket can add an experiment subscription later."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock)

        manager.subscribe(sock, "exp-42")

        assert id(sock) in manager._experiment_subs["exp-42"]

    @pytest.mark.asyncio
    async def test_unsubscribe_removes_experiment(self):
        """Unsubscribing clears both the routing table and the entry."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock, experiment_ids=["exp-1"])

        manager.unsubscribe(sock, "exp-1")

        assert "exp-1" not in manager._experiment_subs
        tracked = manager._connections[id(sock)]
        assert "exp-1" not in tracked.experiment_ids

    @pytest.mark.asyncio
    async def test_subscribe_unknown_ws_is_noop(self):
        """Subscribing a socket that never connected creates nothing."""
        manager = WebSocketManager()
        stranger = _make_ws()

        manager.subscribe(stranger, "exp-1")

        assert "exp-1" not in manager._experiment_subs

    @pytest.mark.asyncio
    async def test_unsubscribe_unknown_ws_is_noop(self):
        """Unsubscribing a socket that never connected must not raise."""
        manager = WebSocketManager()
        stranger = _make_ws()

        manager.unsubscribe(stranger, "exp-1")
# ---------------------------------------------------------------------------
# Broadcasting
# ---------------------------------------------------------------------------
class TestBroadcast:
    """Routing of messages to global and per-experiment subscribers."""

    @pytest.mark.asyncio
    async def test_broadcast_global(self):
        """A global broadcast reaches a global subscriber unchanged."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock)
        payload = {"type": "test"}

        await manager.broadcast_global(payload)

        sock.send_json.assert_awaited_once_with({"type": "test"})

    @pytest.mark.asyncio
    async def test_broadcast_to_experiment(self):
        """An experiment broadcast skips global-only subscribers."""
        manager = WebSocketManager()
        scoped_sock = _make_ws()
        global_sock = _make_ws()
        await manager.connect(scoped_sock, experiment_ids=["exp-1"])
        await manager.connect(global_sock)

        await manager.broadcast_to_experiment("exp-1", {"type": "test"})

        global_sock.send_json.assert_not_awaited()
        scoped_sock.send_json.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_broadcast_routes_by_experiment_id(self):
        """A message tagged with an experiment reaches both audiences once."""
        manager = WebSocketManager()
        scoped_sock = _make_ws()
        global_sock = _make_ws()
        await manager.connect(scoped_sock, experiment_ids=["exp-1"])
        await manager.connect(global_sock)
        event = {"type": "run.completed", "experiment_id": "exp-1"}

        await manager.broadcast(event)

        # The experiment subscriber and the global subscriber each
        # receive the event exactly once.
        assert scoped_sock.send_json.await_count == 1
        assert global_sock.send_json.await_count == 1

    @pytest.mark.asyncio
    async def test_broadcast_without_experiment_id_goes_global_only(self):
        """An untagged message is not delivered to experiment subscribers."""
        manager = WebSocketManager()
        scoped_sock = _make_ws()
        global_sock = _make_ws()
        await manager.connect(scoped_sock, experiment_ids=["exp-1"])
        await manager.connect(global_sock)

        await manager.broadcast({"type": "system.status"})

        global_sock.send_json.assert_awaited_once()
        scoped_sock.send_json.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_broadcast_removes_dead_connections(self):
        """A socket whose send raises is dropped from the manager."""
        manager = WebSocketManager()
        healthy = _make_ws()
        broken = _make_ws()
        broken.send_json = AsyncMock(side_effect=RuntimeError("closed"))
        await manager.connect(healthy)
        await manager.connect(broken)

        await manager.broadcast_global({"type": "test"})

        assert id(broken) not in manager._connections
        assert manager.connection_count == 1

    @pytest.mark.asyncio
    async def test_broadcast_to_multiple_experiment_subscribers(self):
        """Every subscriber of the experiment receives the message."""
        manager = WebSocketManager()
        subscribers = [_make_ws(), _make_ws()]
        for sock in subscribers:
            await manager.connect(sock, experiment_ids=["exp-1"])

        await manager.broadcast_to_experiment("exp-1", {"type": "test"})

        for sock in subscribers:
            sock.send_json.assert_awaited_once()
# ---------------------------------------------------------------------------
# Replay
# ---------------------------------------------------------------------------
class TestReplay:
    """Replay of buffered events to a connection."""

    @pytest.mark.asyncio
    async def test_replay_global_events(self):
        """Events sent via broadcast_global can be replayed afterwards."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock)
        # Populate the global buffer through the public API.
        for name in ("event1", "event2"):
            await manager.broadcast_global({"type": name})
        # Forget the live sends so only replay sends are counted.
        sock.send_json.reset_mock()

        replayed = await manager.replay(sock)

        assert replayed == 2
        assert sock.send_json.await_count == 2

    @pytest.mark.asyncio
    async def test_replay_experiment_events(self):
        """Experiment broadcasts are replayed from that experiment's buffer."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock, experiment_ids=["exp-1"])
        for name in ("e1", "e2", "e3"):
            await manager.broadcast_to_experiment("exp-1", {"type": name})
        sock.send_json.reset_mock()

        replayed = await manager.replay(sock, experiment_id="exp-1")

        assert replayed == 3

    @pytest.mark.asyncio
    async def test_replay_with_since_ts(self):
        """Only events newer than ``since_ts`` are replayed."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock)
        # Seed the buffer directly so the timestamps are known.
        stale_ts = time.time() - 100
        fresh_ts = time.time()
        manager._global_replay.append({"type": "old", "_ts": stale_ts})
        manager._global_replay.append({"type": "new", "_ts": fresh_ts})

        replayed = await manager.replay(sock, since_ts=stale_ts + 1)

        assert replayed == 1

    @pytest.mark.asyncio
    async def test_replay_with_limit(self):
        """``limit`` caps the number of replayed events."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock)
        for index in range(10):
            manager._global_replay.append({"type": f"event{index}", "_ts": time.time()})

        replayed = await manager.replay(sock, limit=3)

        assert replayed == 3

    @pytest.mark.asyncio
    async def test_replay_empty_buffer(self):
        """Replaying with nothing buffered sends nothing."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock)

        replayed = await manager.replay(sock)

        assert replayed == 0

    @pytest.mark.asyncio
    async def test_replay_respects_buffer_size(self):
        """The buffer keeps at most ``replay_size`` events."""
        manager = WebSocketManager(replay_size=5)
        sock = _make_ws()
        await manager.connect(sock)
        for index in range(10):
            await manager.broadcast_global({"type": f"event{index}"})
        sock.send_json.reset_mock()

        replayed = await manager.replay(sock)

        assert replayed == 5

    @pytest.mark.asyncio
    async def test_replay_stops_on_send_failure(self):
        """Replay counts only the sends that succeeded before a failure."""
        manager = WebSocketManager()
        sock = _make_ws()
        attempts = 0

        async def fail_after_one(msg):
            nonlocal attempts
            attempts += 1
            if attempts > 1:
                raise RuntimeError("closed")

        sock.send_json = AsyncMock(side_effect=fail_after_one)
        await manager.connect(sock)
        for index in range(5):
            manager._global_replay.append({"type": f"event{index}", "_ts": time.time()})

        replayed = await manager.replay(sock)

        assert replayed == 1
# ---------------------------------------------------------------------------
# Redis pub/sub bridge
# ---------------------------------------------------------------------------
class TestRedisListener:
    """Redis pub/sub bridge: message parsing and listener lifecycle."""

    @pytest.mark.asyncio
    async def test_handle_redis_message_json(self):
        """A JSON string payload is parsed and broadcast."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock)
        payload = json.dumps({"type": "run.completed", "experiment_id": "exp-1"})

        await manager._handle_redis_message(payload)

        # The global subscriber must have been sent the event.
        assert sock.send_json.await_count >= 1

    @pytest.mark.asyncio
    async def test_handle_redis_message_bytes(self):
        """A UTF-8 bytes payload is decoded, parsed and broadcast."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock)
        payload = json.dumps({"type": "test"}).encode("utf-8")

        await manager._handle_redis_message(payload)

        assert sock.send_json.await_count >= 1

    @pytest.mark.asyncio
    async def test_handle_redis_message_invalid_json(self):
        """Unparseable payloads are dropped without raising."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock)

        await manager._handle_redis_message("not-json{{{")

        sock.send_json.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_start_redis_listener_none_is_noop(self):
        """Starting with no Redis client creates no task."""
        manager = WebSocketManager()

        await manager.start_redis_listener(None)

        assert manager._redis_listener_task is None

    @pytest.mark.asyncio
    async def test_stop_redis_listener_when_none(self):
        """Stopping a listener that never started must not raise."""
        manager = WebSocketManager()

        await manager.stop_redis_listener()

    @pytest.mark.asyncio
    async def test_stop_cancels_task(self):
        """Stopping cancels the running task and clears the handle."""
        manager = WebSocketManager()

        # Stand-in for the listener: a task that would otherwise never end.
        async def dummy():
            await asyncio.sleep(999)

        manager._redis_listener_task = asyncio.create_task(dummy())

        await manager.stop_redis_listener()

        assert manager._redis_listener_task is None

    @pytest.mark.asyncio
    async def test_sync_redis_listener_processes_messages(self):
        """Test the sync Redis pubsub path with a mock."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock)

        polls = 0

        def get_message_side_effect(ignore_subscribe_messages=True, timeout=1.0):
            nonlocal polls
            polls += 1
            if polls == 1:
                return {
                    "type": "message",
                    "data": json.dumps({"type": "test_event"}),
                }
            # Second poll: raise to break out of the listener's loop.
            raise asyncio.CancelledError()

        # Fake synchronous Redis client whose pubsub yields one message.
        fake_pubsub = MagicMock()
        fake_pubsub.subscribe = MagicMock()
        fake_pubsub.unsubscribe = MagicMock()
        fake_pubsub.get_message = MagicMock(side_effect=get_message_side_effect)
        fake_redis = MagicMock()
        fake_redis.pubsub.return_value = fake_pubsub

        # The listener handles the single message, then the cancellation
        # raised by the second poll propagates to the caller.
        with pytest.raises(asyncio.CancelledError):
            await manager._redis_listen_sync(fake_redis, "promptlooper:events")

        fake_pubsub.subscribe.assert_called_once_with("promptlooper:events")
        assert sock.send_json.await_count >= 1
# ---------------------------------------------------------------------------
# Stats
# ---------------------------------------------------------------------------
class TestStats:
    """``stats()`` and the read-only status properties."""

    @pytest.mark.asyncio
    async def test_stats_empty(self):
        """A fresh manager reports zero connections and no subscriptions."""
        manager = WebSocketManager()

        snapshot = manager.stats()

        assert snapshot["experiment_subscriptions"] == {}
        assert snapshot["global_subscribers"] == 0
        assert snapshot["total_connections"] == 0

    @pytest.mark.asyncio
    async def test_stats_with_connections(self):
        """Counts reflect one global and one experiment-scoped socket."""
        manager = WebSocketManager()
        global_sock = _make_ws()
        scoped_sock = _make_ws()
        await manager.connect(global_sock)
        await manager.connect(scoped_sock, experiment_ids=["exp-1"])

        snapshot = manager.stats()

        assert snapshot["total_connections"] == 2
        assert snapshot["global_subscribers"] == 1
        assert snapshot["experiment_subscriptions"]["exp-1"] == 1

    @pytest.mark.asyncio
    async def test_experiment_ids_property(self):
        """``experiment_ids`` lists every experiment with a subscriber."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock, experiment_ids=["exp-a", "exp-b"])

        listed = manager.experiment_ids

        assert set(listed) == {"exp-a", "exp-b"}

    @pytest.mark.asyncio
    async def test_connection_count_property(self):
        """``connection_count`` follows connect and disconnect."""
        manager = WebSocketManager()
        assert manager.connection_count == 0

        sock = _make_ws()
        await manager.connect(sock)
        assert manager.connection_count == 1

        manager.disconnect(sock)
        assert manager.connection_count == 0

View file

@ -0,0 +1,358 @@
"""WebSocket connection manager for PromptLooper.
Manages per-experiment and global WebSocket connections, bridges Redis
pub/sub events to connected clients, and supports message replay on
reconnection.
"""
import asyncio
import json
import logging
import time
from collections import deque
from dataclasses import dataclass, field
from typing import Any
from fastapi import WebSocket
# Module logger, named after this module.
logger = logging.getLogger(__name__)
# Default capacity of each replay buffer (the global buffer and every
# per-experiment buffer): how many recent events are kept for replay.
DEFAULT_REPLAY_SIZE = 50
@dataclass
class ConnectionEntry:
    """Record kept by the manager for one accepted WebSocket."""

    # The socket that messages for this connection are sent through.
    websocket: WebSocket
    # Experiment IDs this connection is currently subscribed to.
    experiment_ids: set[str] = field(default_factory=set)
    # ``time.time()`` value captured when the entry was created.
    connected_at: float = field(default_factory=time.time)
class WebSocketManager:
    """Manages WebSocket connections with per-experiment routing and replay.

    Features:
    - Per-experiment and global subscriptions
    - Redis pub/sub bridge (when Redis is available)
    - Message replay on reconnection (last N events per experiment + global)
    - Clean disconnect handling

    Connections are keyed by ``id(websocket)``. Experiment IDs are always
    normalised with ``str()`` so that subscriptions, replay lookups and
    routed events agree on the key even when a caller passes a number.
    """

    def __init__(self, replay_size: int = DEFAULT_REPLAY_SIZE) -> None:
        """Create an empty manager.

        Args:
            replay_size: Capacity of the global replay buffer and of each
                per-experiment replay buffer.
        """
        self._replay_size = replay_size
        # All active connections keyed by id(websocket)
        self._connections: dict[int, ConnectionEntry] = {}
        # Experiment-specific connection sets: experiment_id -> set of ws ids
        self._experiment_subs: dict[str, set[int]] = {}
        # Global subscribers (receive all events passed to broadcast_global
        # or broadcast)
        self._global_subs: set[int] = set()
        # Event replay buffers
        self._global_replay: deque[dict[str, Any]] = deque(maxlen=replay_size)
        # NOTE(review): one bounded deque per experiment ID ever seen; the
        # dict itself is never pruned in this class.
        self._experiment_replay: dict[str, deque[dict[str, Any]]] = {}
        # Redis listener task handle
        self._redis_listener_task: asyncio.Task | None = None

    # ------------------------------------------------------------------
    # Connection lifecycle
    # ------------------------------------------------------------------
    async def connect(
        self,
        websocket: WebSocket,
        experiment_ids: list[str] | None = None,
    ) -> None:
        """Accept a WebSocket connection and register subscriptions.

        Args:
            websocket: The incoming WebSocket connection.
            experiment_ids: Optional list of experiment IDs to subscribe to.
                If empty/None the connection receives global broadcasts only.
        """
        await websocket.accept()
        ws_id = id(websocket)
        self._connections[ws_id] = ConnectionEntry(websocket=websocket)
        if experiment_ids:
            for exp_id in experiment_ids:
                self._subscribe_experiment(ws_id, exp_id)
        else:
            self._global_subs.add(ws_id)

    def disconnect(self, websocket: WebSocket) -> None:
        """Remove a WebSocket from all tracking structures.

        Unknown sockets are ignored, so calling this twice is harmless.
        """
        self._remove_connection(id(websocket))

    def _remove_connection(self, ws_id: int) -> None:
        """Forget a connection by key and prune experiment sets it emptied."""
        entry = self._connections.pop(ws_id, None)
        self._global_subs.discard(ws_id)
        if entry is None:
            return
        for exp_id in list(entry.experiment_ids):
            self._drop_subscription(ws_id, exp_id)

    def _drop_subscription(self, ws_id: int, experiment_id: str) -> None:
        """Remove one id from an experiment's set, deleting the set if empty."""
        exp_set = self._experiment_subs.get(experiment_id)
        if exp_set is None:
            return
        exp_set.discard(ws_id)
        if not exp_set:
            del self._experiment_subs[experiment_id]

    def _subscribe_experiment(self, ws_id: int, experiment_id: str) -> None:
        """Subscribe a connection to a specific experiment's events."""
        entry = self._connections.get(ws_id)
        if entry is None:
            return
        experiment_id = str(experiment_id)
        entry.experiment_ids.add(experiment_id)
        self._experiment_subs.setdefault(experiment_id, set()).add(ws_id)

    # ------------------------------------------------------------------
    # Subscription management (runtime)
    # ------------------------------------------------------------------
    def subscribe(self, websocket: WebSocket, experiment_id: str) -> None:
        """Add an experiment subscription for an existing connection.

        A socket that is not currently connected is ignored. A global
        subscriber stays global after subscribing to an experiment.
        """
        ws_id = id(websocket)
        if ws_id in self._connections:
            self._subscribe_experiment(ws_id, experiment_id)

    def unsubscribe(self, websocket: WebSocket, experiment_id: str) -> None:
        """Remove an experiment subscription for an existing connection.

        A socket that is not currently connected is ignored.
        """
        ws_id = id(websocket)
        entry = self._connections.get(ws_id)
        if entry is None:
            return
        experiment_id = str(experiment_id)
        entry.experiment_ids.discard(experiment_id)
        self._drop_subscription(ws_id, experiment_id)

    # ------------------------------------------------------------------
    # Broadcasting
    # ------------------------------------------------------------------
    def _record_experiment_event(
        self, experiment_id: str, message: dict[str, Any], ts: float
    ) -> None:
        """Append a timestamped copy of ``message`` to an experiment buffer."""
        buffer = self._experiment_replay.get(experiment_id)
        if buffer is None:
            buffer = deque(maxlen=self._replay_size)
            self._experiment_replay[experiment_id] = buffer
        buffer.append({**message, "_ts": ts})

    def _record_global_event(self, message: dict[str, Any], ts: float) -> None:
        """Append a timestamped copy of ``message`` to the global buffer."""
        self._global_replay.append({**message, "_ts": ts})

    async def broadcast_to_experiment(
        self, experiment_id: str, message: dict[str, Any]
    ) -> None:
        """Send a message to all connections subscribed to an experiment.

        The message is also stored (with a ``_ts`` timestamp) in that
        experiment's replay buffer, whether or not anyone is subscribed.
        """
        experiment_id = str(experiment_id)
        self._record_experiment_event(experiment_id, message, time.time())
        ws_ids = self._experiment_subs.get(experiment_id, set())
        await self._send_to_many(ws_ids, message)

    async def broadcast_global(self, message: dict[str, Any]) -> None:
        """Send a message to all global subscribers.

        The message is also stored (with a ``_ts`` timestamp) in the
        global replay buffer.
        """
        self._record_global_event(message, time.time())
        await self._send_to_many(self._global_subs, message)

    async def broadcast(self, message: dict[str, Any]) -> None:
        """Route a message to the right subscribers based on experiment_id.

        If the message contains an ``experiment_id`` field, it is sent to
        that experiment's subscribers *and* to global subscribers. Otherwise
        it is sent only to global subscribers. The event is recorded in
        every replay buffer it was routed to.

        Each connection receives the message at most once, even when it is
        both a global subscriber and subscribed to the experiment.
        """
        now = time.time()
        # Work on a copy: the union must not leak into the global set.
        targets = set(self._global_subs)
        experiment_id = message.get("experiment_id")
        if experiment_id:
            exp_key = str(experiment_id)
            self._record_experiment_event(exp_key, message, now)
            targets |= self._experiment_subs.get(exp_key, set())
        self._record_global_event(message, now)
        await self._send_to_many(targets, message)

    async def _send_to_many(
        self, ws_ids: set[int], message: dict[str, Any]
    ) -> None:
        """Send a JSON message to a set of connections, disconnecting failures.

        Any connection whose send raises is removed from the manager
        exactly as ``disconnect`` would remove it.
        """
        dead: list[int] = []
        # Iterate a snapshot: sends await, and the live set may change
        # while this coroutine is suspended.
        for ws_id in list(ws_ids):
            entry = self._connections.get(ws_id)
            if entry is None:
                dead.append(ws_id)
                continue
            try:
                await entry.websocket.send_json(message)
            except Exception:
                # Best-effort delivery: any send error marks the socket dead.
                logger.debug("WebSocket send failed, removing connection %s", ws_id)
                dead.append(ws_id)
        # Clean up dead connections
        for ws_id in dead:
            ws_ids.discard(ws_id)
            self._remove_connection(ws_id)

    # ------------------------------------------------------------------
    # Message replay
    # ------------------------------------------------------------------
    async def replay(
        self,
        websocket: WebSocket,
        experiment_id: str | None = None,
        since_ts: float | None = None,
        limit: int | None = None,
    ) -> int:
        """Replay recent events to a connection.

        Replayed events carry the ``_ts`` field added when they were
        buffered.

        Args:
            websocket: Target connection.
            experiment_id: If provided, replay experiment-specific events;
                otherwise replay global events.
            since_ts: Only replay events after this Unix timestamp.
            limit: Max number of events to replay (default: all in buffer).
                The most recent ``limit`` events are kept; a value of zero
                or less replays nothing.

        Returns:
            Number of events replayed. Replay stops at the first failed
            send, so this may be lower than the number selected.
        """
        if experiment_id:
            buffer = self._experiment_replay.get(str(experiment_id), deque())
        else:
            buffer = self._global_replay
        events = list(buffer)
        if since_ts is not None:
            events = [e for e in events if e.get("_ts", 0) > since_ts]
        if limit is not None:
            # ``events[-0:]`` is the whole list, so zero needs its own branch.
            events = events[-limit:] if limit > 0 else []
        count = 0
        for event in events:
            try:
                await websocket.send_json(event)
            except Exception:
                logger.debug("Replay send failed after %d events", count)
                break
            count += 1
        return count

    # ------------------------------------------------------------------
    # Redis pub/sub bridge
    # ------------------------------------------------------------------
    async def start_redis_listener(self, redis_client: Any) -> None:
        """Start a background task that subscribes to Redis pub/sub and
        broadcasts received messages to connected WebSocket clients.

        Uses async Redis pubsub if available, otherwise runs the blocking
        listener in a thread executor. Does nothing when ``redis_client``
        is None or when a listener task is already running.
        """
        if redis_client is None:
            return
        current = self._redis_listener_task
        if current is not None and not current.done():
            # A second start would orphan the task already running.
            return
        self._redis_listener_task = asyncio.create_task(
            self._redis_listen_loop(redis_client)
        )

    @staticmethod
    def _is_async_redis(redis_client: Any) -> bool:
        """Guess whether ``redis_client`` is an asyncio Redis client.

        True when the client exposes a coroutine ``subscribe`` or a
        coroutine ``execute_command``.
        NOTE(review): the ``execute_command`` probe targets ``redis.asyncio``
        clients, which have no client-level ``subscribe`` — confirm against
        the redis package version in use.
        """
        return any(
            asyncio.iscoroutinefunction(getattr(redis_client, attr, None))
            for attr in ("subscribe", "execute_command")
        )

    async def _redis_listen_loop(self, redis_client: Any) -> None:
        """Listen for Redis pub/sub messages and broadcast them.

        Cancellation and errors are logged here and end the task; they do
        not propagate to whoever awaits it.
        """
        channel = "promptlooper:events"
        try:
            if hasattr(redis_client, "pubsub") and self._is_async_redis(redis_client):
                await self._redis_listen_async(redis_client, channel)
            else:
                await self._redis_listen_sync(redis_client, channel)
        except asyncio.CancelledError:
            logger.info("Redis listener cancelled")
        except Exception:
            logger.exception("Redis listener error")

    async def _redis_listen_async(self, redis_client: Any, channel: str) -> None:
        """Async Redis pub/sub listener."""
        pubsub = redis_client.pubsub()
        await pubsub.subscribe(channel)
        try:
            async for raw_message in pubsub.listen():
                if raw_message["type"] != "message":
                    continue
                await self._handle_redis_message(raw_message["data"])
        finally:
            await pubsub.unsubscribe(channel)

    async def _redis_listen_sync(self, redis_client: Any, channel: str) -> None:
        """Sync Redis pub/sub listener (run blocking calls in executor)."""
        pubsub = redis_client.pubsub()
        pubsub.subscribe(channel)
        loop = asyncio.get_running_loop()

        def poll() -> Any:
            # Bounded wait so the loop regains control at least once a second.
            return pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)

        try:
            while True:
                raw_message = await loop.run_in_executor(None, poll)
                if raw_message is None or raw_message["type"] != "message":
                    continue
                await self._handle_redis_message(raw_message["data"])
        finally:
            try:
                pubsub.unsubscribe(channel)
            finally:
                # Close the pubsub handle even if unsubscribe fails.
                pubsub.close()

    async def _handle_redis_message(self, data: str | bytes) -> None:
        """Parse a Redis message and broadcast it.

        Payloads that are not valid UTF-8 JSON, or whose JSON value is not
        an object, are logged and dropped so one bad message cannot stop
        the listener.
        """
        try:
            if isinstance(data, bytes):
                data = data.decode("utf-8")
            message = json.loads(data)
        except (json.JSONDecodeError, UnicodeDecodeError):
            logger.warning("Invalid Redis message: %s", data)
            return
        if not isinstance(message, dict):
            # broadcast() reads message["experiment_id"] via .get().
            logger.warning("Invalid Redis message: %s", data)
            return
        await self.broadcast(message)

    async def stop_redis_listener(self) -> None:
        """Cancel the Redis listener background task and clear the handle."""
        task = self._redis_listener_task
        if task is None:
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        self._redis_listener_task = None

    # ------------------------------------------------------------------
    # Status
    # ------------------------------------------------------------------
    @property
    def connection_count(self) -> int:
        """Total number of active connections."""
        return len(self._connections)

    @property
    def experiment_ids(self) -> list[str]:
        """List of experiment IDs with active subscriptions."""
        return list(self._experiment_subs)

    def stats(self) -> dict[str, Any]:
        """Return connection statistics.

        Keys: ``total_connections``, ``global_subscribers``,
        ``experiment_subscriptions`` (experiment ID -> subscriber count) and
        ``replay_buffer_sizes`` (``"global"`` plus one key per experiment ID).
        """
        return {
            "total_connections": self.connection_count,
            "global_subscribers": len(self._global_subs),
            "experiment_subscriptions": {
                exp_id: len(ws_ids)
                for exp_id, ws_ids in self._experiment_subs.items()
            },
            "replay_buffer_sizes": {
                "global": len(self._global_replay),
                **{
                    exp_id: len(buf)
                    for exp_id, buf in self._experiment_replay.items()
                },
            },
        }
# Module-level singleton for use across the application.
# NOTE(review): main.py in this same commit builds its own
# ``ws_manager = WebSocketManager()`` instead of importing this instance —
# confirm which of the two callers are expected to share.
manager = WebSocketManager()