MAESTRO: Implement WebSocket connection manager with per-experiment routing, Redis pub/sub bridge, and message replay
- WebSocketManager in backend/websocket/manager.py with per-experiment and global subscriptions
- Redis pub/sub bridge (sync + async) broadcasting events to relevant WebSocket clients
- Deque-based replay buffers with since_ts/limit filtering for reconnection support
- Runtime subscribe/unsubscribe and stats API
- Enhanced /ws endpoint in main.py with subscribe/unsubscribe/replay actions
- 35 tests in test_ws_manager.py, all passing
This commit is contained in:
parent
e42117c8ee
commit
30fd15ec7a
5 changed files with 886 additions and 37 deletions
|
|
@ -44,7 +44,8 @@ Implement the core experiment execution engine: LLM adapters, response caching,
|
||||||
- [x] Implement backend/routers/export.py — export best config in JSON, .env, and YAML formats as defined in the spec. Include metadata (score, experiment name, timestamp). The report endpoint should generate a markdown summary of the experiment: config space explored, top 5 configs, score distributions, token usage, timing stats.
|
- [x] Implement backend/routers/export.py — export best config in JSON, .env, and YAML formats as defined in the spec. Include metadata (score, experiment name, timestamp). The report endpoint should generate a markdown summary of the experiment: config space explored, top 5 configs, score distributions, token usage, timing stats.
|
||||||
<!-- Completed: Full export router with 4 endpoints: /best (JSON with weighted score, metadata), /env (flattened KEY=VALUE with comments), /yaml (simple serializer, no PyYAML dependency), /report (markdown with config space, top N configs, score distributions, token usage, timing stats). Auth required on all endpoints. 34 tests in test_export.py, all passing. -->
|
<!-- Completed: Full export router with 4 endpoints: /best (JSON with weighted score, metadata), /env (flattened KEY=VALUE with comments), /yaml (simple serializer, no PyYAML dependency), /report (markdown with config space, top N configs, score distributions, token usage, timing stats). Auth required on all endpoints. 34 tests in test_export.py, all passing. -->
|
||||||
|
|
||||||
- [ ] Implement backend/websocket/manager.py — WebSocket connection manager that: maintains active connections per experiment and globally, receives Redis pub/sub messages and broadcasts to relevant connections, handles connection/disconnection cleanly, supports reconnection with message replay (last N events).
|
- [x] Implement backend/websocket/manager.py — WebSocket connection manager that: maintains active connections per experiment and globally, receives Redis pub/sub messages and broadcasts to relevant connections, handles connection/disconnection cleanly, supports reconnection with message replay (last N events).
|
||||||
|
<!-- Completed: WebSocketManager with per-experiment and global subscriptions, Redis pub/sub bridge (sync + async), deque-based replay buffers with since_ts/limit filtering, clean disconnect cleanup, runtime subscribe/unsubscribe, stats API. Integrated into main.py with enhanced /ws endpoint supporting subscribe/unsubscribe/replay actions and query-param-based initial subscriptions. 35 tests in test_ws_manager.py, all passing. -->
|
||||||
|
|
||||||
- [ ] Implement backend/routers/webhooks.py — CRUD for webhook configs. When events occur (in runner.py and sweep.py), dispatch webhook calls asynchronously via Celery. Include retry logic (3 attempts) and log delivery status.
|
- [ ] Implement backend/routers/webhooks.py — CRUD for webhook configs. When events occur (in runner.py and sweep.py), dispatch webhook calls asynchronously via Celery. Include retry logic (3 attempts) and log delivery status.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -67,28 +67,9 @@ def get_redis():
|
||||||
# WebSocket connection manager
|
# WebSocket connection manager
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
class ConnectionManager:
|
from websocket.manager import WebSocketManager
|
||||||
"""Manage active WebSocket connections."""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
ws_manager = WebSocketManager()
|
||||||
self.active_connections: list[WebSocket] = []
|
|
||||||
|
|
||||||
async def connect(self, websocket: WebSocket) -> None:
|
|
||||||
await websocket.accept()
|
|
||||||
self.active_connections.append(websocket)
|
|
||||||
|
|
||||||
def disconnect(self, websocket: WebSocket) -> None:
|
|
||||||
self.active_connections.remove(websocket)
|
|
||||||
|
|
||||||
async def broadcast(self, message: dict) -> None:
|
|
||||||
for connection in list(self.active_connections):
|
|
||||||
try:
|
|
||||||
await connection.send_json(message)
|
|
||||||
except Exception:
|
|
||||||
self.disconnect(connection)
|
|
||||||
|
|
||||||
|
|
||||||
ws_manager = ConnectionManager()
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -100,8 +81,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||||
"""Startup and shutdown lifecycle hooks."""
|
"""Startup and shutdown lifecycle hooks."""
|
||||||
_init_db()
|
_init_db()
|
||||||
_init_redis()
|
_init_redis()
|
||||||
|
# Start WebSocket ↔ Redis bridge
|
||||||
|
await ws_manager.start_redis_listener(_redis_client)
|
||||||
yield
|
yield
|
||||||
# Shutdown: clean up connections
|
# Shutdown: clean up connections
|
||||||
|
await ws_manager.stop_redis_listener()
|
||||||
if _redis_client is not None:
|
if _redis_client is not None:
|
||||||
_redis_client.close()
|
_redis_client.close()
|
||||||
if engine is not None:
|
if engine is not None:
|
||||||
|
|
@ -167,13 +151,36 @@ def health_check() -> dict:
|
||||||
|
|
||||||
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """WebSocket connection for real-time dashboard updates.

    Initial experiment subscriptions may be passed as a comma-separated
    ``experiment_ids`` query parameter; without it the connection is
    registered with no experiment filter.

    Clients may send JSON messages to control their subscriptions:
    - {"action": "subscribe", "experiment_id": "..."}
    - {"action": "unsubscribe", "experiment_id": "..."}
    - {"action": "replay", "experiment_id": "...", "since_ts": ..., "limit": ...}

    Any other payload (including valid JSON that is not an object) is
    echoed back as ``{"type": "ack", "data": <payload>}``.
    """
    # Parse query params for initial experiment subscriptions
    exp_ids_param = websocket.query_params.get("experiment_ids", "")
    experiment_ids = [e.strip() for e in exp_ids_param.split(",") if e.strip()] or None
    await ws_manager.connect(websocket, experiment_ids=experiment_ids)
    try:
        while True:
            data = await websocket.receive_json()
            # A client may send valid JSON that is not an object (list,
            # string, number); treat it like an unknown action instead of
            # crashing on ``.get``.
            action = data.get("action", "") if isinstance(data, dict) else ""
            if action == "subscribe" and data.get("experiment_id"):
                ws_manager.subscribe(websocket, data["experiment_id"])
                await websocket.send_json(
                    {"type": "ack", "action": "subscribe", "experiment_id": data["experiment_id"]}
                )
            elif action == "unsubscribe" and data.get("experiment_id"):
                ws_manager.unsubscribe(websocket, data["experiment_id"])
                await websocket.send_json(
                    {"type": "ack", "action": "unsubscribe", "experiment_id": data["experiment_id"]}
                )
            elif action == "replay":
                count = await ws_manager.replay(
                    websocket,
                    experiment_id=data.get("experiment_id"),
                    since_ts=data.get("since_ts"),
                    limit=data.get("limit"),
                )
                await websocket.send_json({"type": "ack", "action": "replay", "events_replayed": count})
            else:
                await websocket.send_json({"type": "ack", "data": data})
    except WebSocketDisconnect:
        # Normal client hang-up; deregistration happens in ``finally``.
        pass
    finally:
        # Always deregister, including when the loop dies on an unexpected
        # error (e.g. a malformed frame), so the manager never keeps a
        # stale entry for this socket.
        ws_manager.disconnect(websocket)
|
||||||
|
|
|
||||||
|
|
@ -77,11 +77,11 @@ class TestWebSocket:
|
||||||
|
|
||||||
def test_websocket_disconnect_cleanup(self, client):
    """Opening /ws adds one tracked connection; closing it removes it again."""
    from main import ws_manager

    baseline = ws_manager.connection_count
    with client.websocket_connect("/ws"):
        assert ws_manager.connection_count == baseline + 1
    # Once the context manager exits the socket is closed and must be gone
    assert ws_manager.connection_count == baseline
|
||||||
|
|
||||||
|
|
||||||
class TestRouterMounting:
|
class TestRouterMounting:
|
||||||
|
|
@ -97,16 +97,15 @@ class TestRouterMounting:
|
||||||
|
|
||||||
|
|
||||||
class TestConnectionManager:
    """Smoke tests for the WebSocket manager used by ``main``."""

    def test_broadcast_on_empty_manager(self):
        """WebSocketManager.broadcast_global on empty manager should not raise."""
        from websocket.manager import WebSocketManager

        manager = WebSocketManager()
        import asyncio

        # asyncio.run creates and closes its own event loop, so this sync
        # test does not depend on an ambient loop being set for the thread.
        asyncio.run(manager.broadcast_global({"test": True}))
        assert manager.connection_count == 0
|
||||||
|
|
||||||
|
|
||||||
class TestGetDb:
|
class TestGetDb:
|
||||||
|
|
|
||||||
484
backend/tests/test_ws_manager.py
Normal file
484
backend/tests/test_ws_manager.py
Normal file
|
|
@ -0,0 +1,484 @@
|
||||||
|
"""Tests for the WebSocket connection manager."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from collections import deque
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from websocket.manager import ConnectionEntry, WebSocketManager
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_ws(accept_side_effect=None) -> AsyncMock:
|
||||||
|
"""Create a mock WebSocket."""
|
||||||
|
ws = AsyncMock()
|
||||||
|
ws.accept = AsyncMock(side_effect=accept_side_effect)
|
||||||
|
ws.send_json = AsyncMock()
|
||||||
|
ws.receive_json = AsyncMock()
|
||||||
|
return ws
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Connection lifecycle
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestConnect:
    """Accepting connections and registering their initial subscriptions."""

    @pytest.mark.asyncio
    async def test_accept_and_register_global(self):
        """Without experiment ids the socket is accepted as a global subscriber."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock)

        assert manager.connection_count == 1
        assert id(sock) in manager._global_subs
        sock.accept.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_accept_with_experiment_ids(self):
        """With experiment ids the socket is tracked per experiment, not globally."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock, experiment_ids=["exp-1", "exp-2"])

        assert manager.connection_count == 1
        assert id(sock) not in manager._global_subs
        for experiment in ("exp-1", "exp-2"):
            assert id(sock) in manager._experiment_subs[experiment]

    @pytest.mark.asyncio
    async def test_multiple_connections(self):
        """Global and experiment-scoped sockets are both counted."""
        manager = WebSocketManager()
        global_sock = _make_ws()
        scoped_sock = _make_ws()
        await manager.connect(global_sock)
        await manager.connect(scoped_sock, experiment_ids=["exp-1"])

        assert manager.connection_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestDisconnect:
    """Removing connections from every tracking structure."""

    @pytest.mark.asyncio
    async def test_disconnect_global(self):
        """A global subscriber disappears from the count and the global set."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock)
        manager.disconnect(sock)

        assert manager.connection_count == 0
        assert id(sock) not in manager._global_subs

    @pytest.mark.asyncio
    async def test_disconnect_experiment(self):
        """Dropping the only subscriber removes the experiment key itself."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock, experiment_ids=["exp-1"])
        manager.disconnect(sock)

        assert manager.connection_count == 0
        assert "exp-1" not in manager._experiment_subs

    @pytest.mark.asyncio
    async def test_disconnect_unknown_ws_is_noop(self):
        """Disconnecting a socket that never connected must not raise."""
        manager = WebSocketManager()
        stranger = _make_ws()
        manager.disconnect(stranger)  # Should not raise
        assert manager.connection_count == 0

    @pytest.mark.asyncio
    async def test_disconnect_cleans_up_empty_experiment_sets(self):
        """The experiment set survives until its last subscriber leaves."""
        manager = WebSocketManager()
        first = _make_ws()
        second = _make_ws()
        await manager.connect(first, experiment_ids=["exp-1"])
        await manager.connect(second, experiment_ids=["exp-1"])

        manager.disconnect(first)
        assert "exp-1" in manager._experiment_subs
        assert len(manager._experiment_subs["exp-1"]) == 1

        manager.disconnect(second)
        assert "exp-1" not in manager._experiment_subs
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Subscribe / unsubscribe at runtime
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSubscriptions:
    """Runtime subscribe/unsubscribe on already-connected sockets."""

    @pytest.mark.asyncio
    async def test_subscribe_adds_experiment(self):
        """subscribe() registers the socket under the experiment id."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock)
        manager.subscribe(sock, "exp-42")

        assert id(sock) in manager._experiment_subs["exp-42"]

    @pytest.mark.asyncio
    async def test_unsubscribe_removes_experiment(self):
        """unsubscribe() clears both the routing table and the entry's id set."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock, experiment_ids=["exp-1"])
        manager.unsubscribe(sock, "exp-1")

        assert "exp-1" not in manager._experiment_subs
        tracked = manager._connections[id(sock)]
        assert "exp-1" not in tracked.experiment_ids

    @pytest.mark.asyncio
    async def test_subscribe_unknown_ws_is_noop(self):
        """Subscribing a socket that never connected creates no routing entry."""
        manager = WebSocketManager()
        stranger = _make_ws()
        manager.subscribe(stranger, "exp-1")  # Not connected
        assert "exp-1" not in manager._experiment_subs

    @pytest.mark.asyncio
    async def test_unsubscribe_unknown_ws_is_noop(self):
        """Unsubscribing a socket that never connected must not raise."""
        manager = WebSocketManager()
        stranger = _make_ws()
        manager.unsubscribe(stranger, "exp-1")  # Should not raise
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Broadcasting
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestBroadcast:
    """Routing of broadcasts to global and experiment subscribers."""

    @pytest.mark.asyncio
    async def test_broadcast_global(self):
        """A global broadcast reaches a global subscriber unchanged."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock)

        await manager.broadcast_global({"type": "test"})
        sock.send_json.assert_awaited_once_with({"type": "test"})

    @pytest.mark.asyncio
    async def test_broadcast_to_experiment(self):
        """An experiment broadcast reaches only that experiment's subscribers."""
        manager = WebSocketManager()
        scoped = _make_ws()
        everyone = _make_ws()
        await manager.connect(scoped, experiment_ids=["exp-1"])
        await manager.connect(everyone)

        await manager.broadcast_to_experiment("exp-1", {"type": "test"})
        scoped.send_json.assert_awaited_once()
        everyone.send_json.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_broadcast_routes_by_experiment_id(self):
        """broadcast() with an experiment_id reaches both kinds of subscriber once."""
        manager = WebSocketManager()
        scoped = _make_ws()
        everyone = _make_ws()
        await manager.connect(scoped, experiment_ids=["exp-1"])
        await manager.connect(everyone)

        event = {"type": "run.completed", "experiment_id": "exp-1"}
        await manager.broadcast(event)

        # delivered to the experiment subscriber (via broadcast_to_experiment)
        assert scoped.send_json.await_count == 1
        # delivered to the global subscriber (via broadcast_global)
        assert everyone.send_json.await_count == 1

    @pytest.mark.asyncio
    async def test_broadcast_without_experiment_id_goes_global_only(self):
        """broadcast() without an experiment_id skips experiment subscribers."""
        manager = WebSocketManager()
        scoped = _make_ws()
        everyone = _make_ws()
        await manager.connect(scoped, experiment_ids=["exp-1"])
        await manager.connect(everyone)

        await manager.broadcast({"type": "system.status"})
        scoped.send_json.assert_not_awaited()
        everyone.send_json.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_broadcast_removes_dead_connections(self):
        """A socket whose send fails is dropped from the manager."""
        manager = WebSocketManager()
        healthy = _make_ws()
        broken = _make_ws()
        broken.send_json = AsyncMock(side_effect=RuntimeError("closed"))
        await manager.connect(healthy)
        await manager.connect(broken)

        await manager.broadcast_global({"type": "test"})

        assert manager.connection_count == 1
        assert id(broken) not in manager._connections

    @pytest.mark.asyncio
    async def test_broadcast_to_multiple_experiment_subscribers(self):
        """Every subscriber of an experiment receives the event once."""
        manager = WebSocketManager()
        first = _make_ws()
        second = _make_ws()
        await manager.connect(first, experiment_ids=["exp-1"])
        await manager.connect(second, experiment_ids=["exp-1"])

        await manager.broadcast_to_experiment("exp-1", {"type": "test"})
        first.send_json.assert_awaited_once()
        second.send_json.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Replay
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestReplay:
    """Replaying buffered events to a (re)connected socket."""

    @pytest.mark.asyncio
    async def test_replay_global_events(self):
        """Events sent via broadcast_global are buffered and replayed."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock)

        # Populate the global buffer through the normal broadcast path
        await manager.broadcast_global({"type": "event1"})
        await manager.broadcast_global({"type": "event2"})

        # Forget the live deliveries so only replayed sends are counted
        sock.send_json.reset_mock()

        replayed = await manager.replay(sock)
        assert replayed == 2
        assert sock.send_json.await_count == 2

    @pytest.mark.asyncio
    async def test_replay_experiment_events(self):
        """Experiment broadcasts are replayable by experiment id."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock, experiment_ids=["exp-1"])

        for event_type in ("e1", "e2", "e3"):
            await manager.broadcast_to_experiment("exp-1", {"type": event_type})

        sock.send_json.reset_mock()

        replayed = await manager.replay(sock, experiment_id="exp-1")
        assert replayed == 3

    @pytest.mark.asyncio
    async def test_replay_with_since_ts(self):
        """since_ts filters out events stamped before the cutoff."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock)

        # Seed the buffer directly so the timestamps are known
        ts_old = time.time() - 100
        ts_new = time.time()
        manager._global_replay.append({"type": "old", "_ts": ts_old})
        manager._global_replay.append({"type": "new", "_ts": ts_new})

        replayed = await manager.replay(sock, since_ts=ts_old + 1)
        assert replayed == 1

    @pytest.mark.asyncio
    async def test_replay_with_limit(self):
        """limit caps the number of replayed events."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock)

        for index in range(10):
            manager._global_replay.append({"type": f"event{index}", "_ts": time.time()})

        replayed = await manager.replay(sock, limit=3)
        assert replayed == 3

    @pytest.mark.asyncio
    async def test_replay_empty_buffer(self):
        """Replaying with nothing buffered sends nothing."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock)

        replayed = await manager.replay(sock)
        assert replayed == 0

    @pytest.mark.asyncio
    async def test_replay_respects_buffer_size(self):
        """Only the newest replay_size events are retained."""
        manager = WebSocketManager(replay_size=5)
        sock = _make_ws()
        await manager.connect(sock)

        for index in range(10):
            await manager.broadcast_global({"type": f"event{index}"})

        sock.send_json.reset_mock()
        replayed = await manager.replay(sock)
        assert replayed == 5

    @pytest.mark.asyncio
    async def test_replay_stops_on_send_failure(self):
        """Replay reports only the events sent before the first failure."""
        manager = WebSocketManager()
        sock = _make_ws()
        sends_seen = 0

        async def fail_after_one(msg):
            # First send succeeds; every later send raises
            nonlocal sends_seen
            sends_seen += 1
            if sends_seen > 1:
                raise RuntimeError("closed")

        sock.send_json = AsyncMock(side_effect=fail_after_one)
        await manager.connect(sock)

        for index in range(5):
            manager._global_replay.append({"type": f"event{index}", "_ts": time.time()})

        replayed = await manager.replay(sock)
        assert replayed == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Redis pub/sub bridge
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRedisListener:
    """Redis pub/sub bridge: message decoding and listener lifecycle."""

    @pytest.mark.asyncio
    async def test_handle_redis_message_json(self):
        """A JSON string payload is decoded and broadcast."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock)

        payload = json.dumps({"type": "run.completed", "experiment_id": "exp-1"})
        await manager._handle_redis_message(payload)

        # At least one delivery proves the message was broadcast
        assert sock.send_json.await_count >= 1

    @pytest.mark.asyncio
    async def test_handle_redis_message_bytes(self):
        """A UTF-8 encoded bytes payload is decoded and broadcast."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock)

        payload = json.dumps({"type": "test"}).encode("utf-8")
        await manager._handle_redis_message(payload)
        assert sock.send_json.await_count >= 1

    @pytest.mark.asyncio
    async def test_handle_redis_message_invalid_json(self):
        """An undecodable payload is dropped without raising or sending."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock)

        # Must be tolerated silently from the caller's point of view
        await manager._handle_redis_message("not-json{{{")
        sock.send_json.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_start_redis_listener_none_is_noop(self):
        """Starting the listener without a Redis client creates no task."""
        manager = WebSocketManager()
        await manager.start_redis_listener(None)
        assert manager._redis_listener_task is None

    @pytest.mark.asyncio
    async def test_stop_redis_listener_when_none(self):
        """Stopping when no listener was started must not raise."""
        manager = WebSocketManager()
        await manager.stop_redis_listener()  # Should not raise

    @pytest.mark.asyncio
    async def test_stop_cancels_task(self):
        """Stopping clears the stored listener task handle."""
        manager = WebSocketManager()

        # Stand-in for the listener: a task that would otherwise run ~forever
        async def dummy():
            await asyncio.sleep(999)

        manager._redis_listener_task = asyncio.create_task(dummy())
        await manager.stop_redis_listener()
        assert manager._redis_listener_task is None

    @pytest.mark.asyncio
    async def test_sync_redis_listener_processes_messages(self):
        """Exercise the sync Redis pubsub path against a mocked client."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock)

        # Fake pubsub object handed out by the fake sync Redis client
        fake_pubsub = MagicMock()
        polls = 0

        def get_message_side_effect(ignore_subscribe_messages=True, timeout=1.0):
            nonlocal polls
            polls += 1
            if polls == 1:
                return {
                    "type": "message",
                    "data": json.dumps({"type": "test_event"}),
                }
            # Second poll onwards: raise so the listener loop terminates
            raise asyncio.CancelledError()

        fake_pubsub.get_message = MagicMock(side_effect=get_message_side_effect)
        fake_pubsub.subscribe = MagicMock()
        fake_pubsub.unsubscribe = MagicMock()

        fake_redis = MagicMock()
        fake_redis.pubsub.return_value = fake_pubsub

        # The listener handles one message, then the cancellation surfaces
        with pytest.raises(asyncio.CancelledError):
            await manager._redis_listen_sync(fake_redis, "promptlooper:events")

        fake_pubsub.subscribe.assert_called_once_with("promptlooper:events")
        assert sock.send_json.await_count >= 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Stats
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestStats:
    """Introspection helpers: stats(), experiment_ids, connection_count."""

    @pytest.mark.asyncio
    async def test_stats_empty(self):
        """A fresh manager reports zero connections and no subscriptions."""
        manager = WebSocketManager()
        snapshot = manager.stats()
        assert snapshot["total_connections"] == 0
        assert snapshot["global_subscribers"] == 0
        assert snapshot["experiment_subscriptions"] == {}

    @pytest.mark.asyncio
    async def test_stats_with_connections(self):
        """stats() counts global and per-experiment subscribers separately."""
        manager = WebSocketManager()
        global_sock = _make_ws()
        scoped_sock = _make_ws()
        await manager.connect(global_sock)
        await manager.connect(scoped_sock, experiment_ids=["exp-1"])

        snapshot = manager.stats()
        assert snapshot["total_connections"] == 2
        assert snapshot["global_subscribers"] == 1
        assert snapshot["experiment_subscriptions"]["exp-1"] == 1

    @pytest.mark.asyncio
    async def test_experiment_ids_property(self):
        """experiment_ids lists every experiment with a subscriber."""
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock, experiment_ids=["exp-a", "exp-b"])

        subscribed = manager.experiment_ids
        assert set(subscribed) == {"exp-a", "exp-b"}

    @pytest.mark.asyncio
    async def test_connection_count_property(self):
        """connection_count follows connect and disconnect."""
        manager = WebSocketManager()
        assert manager.connection_count == 0
        sock = _make_ws()
        await manager.connect(sock)
        assert manager.connection_count == 1
        manager.disconnect(sock)
        assert manager.connection_count == 0
|
||||||
358
backend/websocket/manager.py
Normal file
358
backend/websocket/manager.py
Normal file
|
|
@ -0,0 +1,358 @@
|
||||||
|
"""WebSocket connection manager for PromptLooper.
|
||||||
|
|
||||||
|
Manages per-experiment and global WebSocket connections, bridges Redis
|
||||||
|
pub/sub events to connected clients, and supports message replay on
|
||||||
|
reconnection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from collections import deque
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import WebSocket
|
||||||
|
|
||||||
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Default number of events to retain for replay on reconnect; used as the
# default ``replay_size`` argument of ``WebSocketManager``.
DEFAULT_REPLAY_SIZE = 50
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ConnectionEntry:
    """Tracks a single WebSocket connection and its subscriptions."""

    # The connection this entry describes.
    websocket: WebSocket
    # Experiment IDs this connection is subscribed to; starts as an empty set.
    experiment_ids: set[str] = field(default_factory=set)
    # ``time.time()`` value captured when the entry is created.
    connected_at: float = field(default_factory=time.time)
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketManager:
|
||||||
|
"""Manages WebSocket connections with per-experiment routing and replay.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Per-experiment and global subscriptions
|
||||||
|
- Redis pub/sub bridge (when Redis is available)
|
||||||
|
- Message replay on reconnection (last N events per experiment + global)
|
||||||
|
- Clean disconnect handling
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, replay_size: int = DEFAULT_REPLAY_SIZE) -> None:
    """Create an empty manager.

    Args:
        replay_size: Maximum number of events kept in the global replay
            buffer; also stored on the instance as ``_replay_size``.
    """
    self._replay_size = replay_size

    # --- connection bookkeeping -------------------------------------
    # id(websocket) -> ConnectionEntry for every active connection
    self._connections: dict[int, ConnectionEntry] = {}
    # experiment_id -> ids of the connections subscribed to it
    self._experiment_subs: dict[str, set[int]] = {}
    # ids of connections registered as global subscribers
    self._global_subs: set[int] = set()

    # --- replay buffers ---------------------------------------------
    # bounded buffer: once full, appending drops the oldest event
    self._global_replay: deque[dict[str, Any]] = deque(maxlen=replay_size)
    # per-experiment buffers, keyed by experiment id
    self._experiment_replay: dict[str, deque[dict[str, Any]]] = {}

    # --- Redis bridge -----------------------------------------------
    # handle of the background listener task; None while not running
    self._redis_listener_task: asyncio.Task | None = None
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Connection lifecycle
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def connect(
    self,
    websocket: WebSocket,
    experiment_ids: list[str] | None = None,
) -> None:
    """Accept ``websocket`` and record what it is subscribed to.

    Args:
        websocket: The incoming WebSocket connection.
        experiment_ids: Experiment IDs to subscribe the connection to.
            When empty or ``None`` the connection is registered as a
            global subscriber instead.
    """
    await websocket.accept()
    key = id(websocket)
    self._connections[key] = ConnectionEntry(websocket=websocket)

    if not experiment_ids:
        # No experiment filter requested: track as a global subscriber
        self._global_subs.add(key)
        return

    for experiment_id in experiment_ids:
        self._subscribe_experiment(key, experiment_id)
|
||||||
|
|
||||||
|
def disconnect(self, websocket: WebSocket) -> None:
    """Forget *websocket*: drop it from the registry and every routing table.

    Unknown (or already removed) connections are ignored.
    """
    conn_key = id(websocket)
    entry = self._connections.pop(conn_key, None)
    if entry is None:
        return

    self._global_subs.discard(conn_key)

    # Iterate over a snapshot of the entry's experiment IDs while the
    # routing table is being edited.
    for experiment_id in tuple(entry.experiment_ids):
        subscribers = self._experiment_subs.get(experiment_id)
        if subscribers is None:
            continue
        subscribers.discard(conn_key)
        if not subscribers:
            # Last subscriber gone: remove the experiment's routing entry.
            del self._experiment_subs[experiment_id]
|
||||||
|
|
||||||
|
def _subscribe_experiment(self, ws_id: int, experiment_id: str) -> None:
|
||||||
|
"""Subscribe a connection to a specific experiment's events."""
|
||||||
|
entry = self._connections.get(ws_id)
|
||||||
|
if entry is None:
|
||||||
|
return
|
||||||
|
entry.experiment_ids.add(experiment_id)
|
||||||
|
if experiment_id not in self._experiment_subs:
|
||||||
|
self._experiment_subs[experiment_id] = set()
|
||||||
|
self._experiment_subs[experiment_id].add(ws_id)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Subscription management (runtime)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def subscribe(self, websocket: WebSocket, experiment_id: str) -> None:
    """Subscribe an already-registered connection to *experiment_id*.

    Connections that are not registered with this manager are ignored.
    """
    conn_key = id(websocket)
    if conn_key not in self._connections:
        return
    self._subscribe_experiment(conn_key, experiment_id)
|
||||||
|
|
||||||
|
def unsubscribe(self, websocket: WebSocket, experiment_id: str) -> None:
    """Stop routing *experiment_id* events to an existing connection.

    Unknown connections and experiments the connection was never
    subscribed to are ignored.
    """
    conn_key = id(websocket)
    entry = self._connections.get(conn_key)
    if entry is None:
        return

    entry.experiment_ids.discard(experiment_id)

    subscribers = self._experiment_subs.get(experiment_id)
    if subscribers is None:
        return
    subscribers.discard(conn_key)
    if not subscribers:
        # Nobody left on this experiment: remove its routing entry.
        del self._experiment_subs[experiment_id]
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Broadcasting
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def broadcast_to_experiment(
    self, experiment_id: str, message: dict[str, Any]
) -> None:
    """Buffer *message* for replay and send it to the experiment's subscribers.

    Args:
        experiment_id: Experiment whose subscribers should receive the
            message.
        message: JSON-serialisable payload. The copy kept for replay gets
            an extra ``_ts`` field (``time.time()``); the copy sent live
            is the message as given.
    """
    history = self._experiment_replay.get(experiment_id)
    if history is None:
        # First event for this experiment: create its bounded buffer.
        history = deque(maxlen=self._replay_size)
        self._experiment_replay[experiment_id] = history
    history.append({**message, "_ts": time.time()})

    recipients = self._experiment_subs.get(experiment_id, set())
    await self._send_to_many(recipients, message)
|
||||||
|
|
||||||
|
async def broadcast_global(self, message: dict[str, Any]) -> None:
    """Buffer *message* in the global replay history and send it to all
    global subscribers.

    The buffered copy carries an added ``_ts`` field (``time.time()``);
    the live copy is sent as given.
    """
    stamped = dict(message)
    stamped["_ts"] = time.time()
    self._global_replay.append(stamped)

    await self._send_to_many(self._global_subs, message)
|
||||||
|
|
||||||
|
async def broadcast(self, message: dict[str, Any]) -> None:
    """Route a message to the right subscribers based on ``experiment_id``.

    If the message contains a truthy ``experiment_id`` field, it is stored
    in both that experiment's replay buffer and the global replay buffer,
    and sent to the experiment's subscribers *and* to global subscribers.
    Otherwise it is handled as a global-only broadcast.

    A connection that is both a global subscriber and subscribed to the
    experiment receives the message exactly once.

    Args:
        message: JSON-serialisable event payload.
    """
    experiment_id = message.get("experiment_id")
    if not experiment_id:
        await self.broadcast_global(message)
        return

    exp_key = str(experiment_id)
    now = time.time()

    # Record the event in both histories so either kind of replay sees it.
    # Separate dict copies keep the two buffers independent of each other.
    exp_history = self._experiment_replay.get(exp_key)
    if exp_history is None:
        exp_history = deque(maxlen=self._replay_size)
        self._experiment_replay[exp_key] = exp_history
    exp_history.append({**message, "_ts": now})
    self._global_replay.append({**message, "_ts": now})

    # Send to the union of both audiences: a connection present in both
    # sets would otherwise be sent the same event twice.
    recipients = self._experiment_subs.get(exp_key, set()) | self._global_subs
    await self._send_to_many(recipients, message)
|
||||||
|
|
||||||
|
async def _send_to_many(
|
||||||
|
self, ws_ids: set[int], message: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
|
"""Send a JSON message to a set of connections, disconnecting failures."""
|
||||||
|
dead: list[int] = []
|
||||||
|
for ws_id in list(ws_ids):
|
||||||
|
entry = self._connections.get(ws_id)
|
||||||
|
if entry is None:
|
||||||
|
dead.append(ws_id)
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
await entry.websocket.send_json(message)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("WebSocket send failed, removing connection %s", ws_id)
|
||||||
|
dead.append(ws_id)
|
||||||
|
# Clean up dead connections
|
||||||
|
for ws_id in dead:
|
||||||
|
ws_ids.discard(ws_id)
|
||||||
|
entry = self._connections.pop(ws_id, None)
|
||||||
|
if entry:
|
||||||
|
self._global_subs.discard(ws_id)
|
||||||
|
for exp_id in entry.experiment_ids:
|
||||||
|
exp_set = self._experiment_subs.get(exp_id)
|
||||||
|
if exp_set:
|
||||||
|
exp_set.discard(ws_id)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Message replay
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def replay(
    self,
    websocket: WebSocket,
    experiment_id: str | None = None,
    since_ts: float | None = None,
    limit: int | None = None,
) -> int:
    """Replay recent buffered events to a connection.

    Args:
        websocket: Target connection.
        experiment_id: If provided, replay that experiment's buffer;
            otherwise replay the global buffer.
        since_ts: Only replay events whose ``_ts`` is strictly greater
            than this Unix timestamp.
        limit: Max number of events to replay, keeping the most recent
            ones (default: all in buffer). Zero or a negative value
            replays nothing.

    Returns:
        Number of events successfully sent. Replay stops at the first
        send failure.
    """
    if experiment_id:
        buffer = self._experiment_replay.get(experiment_id, ())
    else:
        buffer = self._global_replay

    # Snapshot the buffer so appends made while awaiting do not interfere.
    events = list(buffer)
    if since_ts is not None:
        events = [e for e in events if e.get("_ts", 0) > since_ts]
    if limit is not None:
        # events[-0:] is the whole list and a negative limit would slice
        # from the front, so non-positive limits are handled explicitly.
        events = events[-limit:] if limit > 0 else []

    count = 0
    for event in events:
        try:
            await websocket.send_json(event)
        except Exception:
            # Best effort: the peer is probably gone; report what was sent.
            break
        count += 1
    return count
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Redis pub/sub bridge
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def start_redis_listener(self, redis_client: Any) -> None:
    """Start the background task that bridges Redis pub/sub to WebSockets.

    The task runs ``_redis_listen_loop`` with *redis_client* and its
    handle is kept so ``stop_redis_listener`` can cancel it.

    Calling this while a listener task is still running is a no-op, so
    the existing task is not orphaned and events are not broadcast twice.

    Args:
        redis_client: Redis client to listen on. ``None`` disables the
            bridge and nothing is started.
    """
    if redis_client is None:
        return

    current = self._redis_listener_task
    if current is not None and not current.done():
        # A listener is already active; starting another would leave the
        # first one running with no handle to cancel it.
        return

    self._redis_listener_task = asyncio.create_task(
        self._redis_listen_loop(redis_client)
    )
|
||||||
|
|
||||||
|
async def _redis_listen_loop(self, redis_client: Any) -> None:
|
||||||
|
"""Listen for Redis pub/sub messages and broadcast them."""
|
||||||
|
channel = "promptlooper:events"
|
||||||
|
try:
|
||||||
|
# Check if this is an async redis client (aioredis / redis.asyncio)
|
||||||
|
if hasattr(redis_client, "pubsub") and asyncio.iscoroutinefunction(
|
||||||
|
getattr(redis_client, "subscribe", None)
|
||||||
|
):
|
||||||
|
await self._redis_listen_async(redis_client, channel)
|
||||||
|
else:
|
||||||
|
await self._redis_listen_sync(redis_client, channel)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info("Redis listener cancelled")
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Redis listener error")
|
||||||
|
|
||||||
|
async def _redis_listen_async(self, redis_client: Any, channel: str) -> None:
|
||||||
|
"""Async Redis pub/sub listener."""
|
||||||
|
pubsub = redis_client.pubsub()
|
||||||
|
await pubsub.subscribe(channel)
|
||||||
|
try:
|
||||||
|
async for raw_message in pubsub.listen():
|
||||||
|
if raw_message["type"] != "message":
|
||||||
|
continue
|
||||||
|
await self._handle_redis_message(raw_message["data"])
|
||||||
|
finally:
|
||||||
|
await pubsub.unsubscribe(channel)
|
||||||
|
|
||||||
|
async def _redis_listen_sync(self, redis_client: Any, channel: str) -> None:
|
||||||
|
"""Sync Redis pub/sub listener (run blocking calls in executor)."""
|
||||||
|
pubsub = redis_client.pubsub()
|
||||||
|
pubsub.subscribe(channel)
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
raw_message = await loop.run_in_executor(
|
||||||
|
None, lambda: pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
|
||||||
|
)
|
||||||
|
if raw_message is None:
|
||||||
|
continue
|
||||||
|
if raw_message["type"] != "message":
|
||||||
|
continue
|
||||||
|
await self._handle_redis_message(raw_message["data"])
|
||||||
|
finally:
|
||||||
|
pubsub.unsubscribe(channel)
|
||||||
|
|
||||||
|
async def _handle_redis_message(self, data: str | bytes) -> None:
|
||||||
|
"""Parse a Redis message and broadcast it."""
|
||||||
|
try:
|
||||||
|
if isinstance(data, bytes):
|
||||||
|
data = data.decode("utf-8")
|
||||||
|
message = json.loads(data)
|
||||||
|
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||||
|
logger.warning("Invalid Redis message: %s", data)
|
||||||
|
return
|
||||||
|
await self.broadcast(message)
|
||||||
|
|
||||||
|
async def stop_redis_listener(self) -> None:
    """Cancel the Redis listener task, wait for it to finish, and clear
    the stored handle. Does nothing when no listener was started.
    """
    task = self._redis_listener_task
    if task is None:
        return

    task.cancel()
    try:
        # Wait for the task to finish unwinding before dropping it.
        await task
    except asyncio.CancelledError:
        pass
    self._redis_listener_task = None
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Status
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
@property
def connection_count(self) -> int:
    """Number of connections currently held in the registry."""
    registry = self._connections
    return len(registry)
|
||||||
|
|
||||||
|
@property
def experiment_ids(self) -> list[str]:
    """Experiment IDs that currently have an entry in the routing table,
    returned as a new list.
    """
    return [*self._experiment_subs]
|
||||||
|
|
||||||
|
def stats(self) -> dict[str, Any]:
    """Summarise connection, subscription and replay-buffer sizes.

    Returns:
        A dict with ``total_connections``, ``global_subscribers``,
        ``experiment_subscriptions`` (experiment ID -> subscriber count)
        and ``replay_buffer_sizes`` (``"global"`` plus one entry per
        experiment ID -> number of buffered events).
    """
    total = self.connection_count
    global_count = len(self._global_subs)

    per_experiment = {
        exp_id: len(members) for exp_id, members in self._experiment_subs.items()
    }

    # "global" goes in first; per-experiment sizes are merged on top of it.
    buffer_sizes: dict[str, int] = {"global": len(self._global_replay)}
    for exp_id, history in self._experiment_replay.items():
        buffer_sizes[exp_id] = len(history)

    return {
        "total_connections": total,
        "global_subscribers": global_count,
        "experiment_subscriptions": per_experiment,
        "replay_buffer_sizes": buffer_sizes,
    }
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton for use across the application
|
||||||
|
# Created at import time with no arguments, so the default replay size applies.
manager = WebSocketManager()
|
||||||
Loading…
Add table
Reference in a new issue