MAESTRO: Implement WebSocket connection manager with per-experiment routing, Redis pub/sub bridge, and message replay
- WebSocketManager in backend/websocket/manager.py with per-experiment and global subscriptions - Redis pub/sub bridge (sync + async) broadcasting events to relevant WebSocket clients - Deque-based replay buffers with since_ts/limit filtering for reconnection support - Runtime subscribe/unsubscribe and stats API - Enhanced /ws endpoint in main.py with subscribe/unsubscribe/replay actions - 35 tests in test_ws_manager.py, all passing
This commit is contained in:
parent
e42117c8ee
commit
30fd15ec7a
5 changed files with 886 additions and 37 deletions
|
|
@ -44,7 +44,8 @@ Implement the core experiment execution engine: LLM adapters, response caching,
|
|||
- [x] Implement backend/routers/export.py — export best config in JSON, .env, and YAML formats as defined in the spec. Include metadata (score, experiment name, timestamp). The report endpoint should generate a markdown summary of the experiment: config space explored, top 5 configs, score distributions, token usage, timing stats.
|
||||
<!-- Completed: Full export router with 4 endpoints: /best (JSON with weighted score, metadata), /env (flattened KEY=VALUE with comments), /yaml (simple serializer, no PyYAML dependency), /report (markdown with config space, top N configs, score distributions, token usage, timing stats). Auth required on all endpoints. 34 tests in test_export.py, all passing. -->
|
||||
|
||||
- [ ] Implement backend/websocket/manager.py — WebSocket connection manager that: maintains active connections per experiment and globally, receives Redis pub/sub messages and broadcasts to relevant connections, handles connection/disconnection cleanly, supports reconnection with message replay (last N events).
|
||||
- [x] Implement backend/websocket/manager.py — WebSocket connection manager that: maintains active connections per experiment and globally, receives Redis pub/sub messages and broadcasts to relevant connections, handles connection/disconnection cleanly, supports reconnection with message replay (last N events).
|
||||
<!-- Completed: WebSocketManager with per-experiment and global subscriptions, Redis pub/sub bridge (sync + async), deque-based replay buffers with since_ts/limit filtering, clean disconnect cleanup, runtime subscribe/unsubscribe, stats API. Integrated into main.py with enhanced /ws endpoint supporting subscribe/unsubscribe/replay actions and query-param-based initial subscriptions. 35 tests in test_ws_manager.py, all passing. -->
|
||||
|
||||
- [ ] Implement backend/routers/webhooks.py — CRUD for webhook configs. When events occur (in runner.py and sweep.py), dispatch webhook calls asynchronously via Celery. Include retry logic (3 attempts) and log delivery status.
|
||||
|
||||
|
|
|
|||
|
|
@ -67,28 +67,9 @@ def get_redis():
|
|||
# WebSocket connection manager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ConnectionManager:
|
||||
"""Manage active WebSocket connections."""
|
||||
from websocket.manager import WebSocketManager
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.active_connections: list[WebSocket] = []
|
||||
|
||||
async def connect(self, websocket: WebSocket) -> None:
|
||||
await websocket.accept()
|
||||
self.active_connections.append(websocket)
|
||||
|
||||
def disconnect(self, websocket: WebSocket) -> None:
|
||||
self.active_connections.remove(websocket)
|
||||
|
||||
async def broadcast(self, message: dict) -> None:
|
||||
for connection in list(self.active_connections):
|
||||
try:
|
||||
await connection.send_json(message)
|
||||
except Exception:
|
||||
self.disconnect(connection)
|
||||
|
||||
|
||||
ws_manager = ConnectionManager()
|
||||
ws_manager = WebSocketManager()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -100,8 +81,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||
"""Startup and shutdown lifecycle hooks."""
|
||||
_init_db()
|
||||
_init_redis()
|
||||
# Start WebSocket ↔ Redis bridge
|
||||
await ws_manager.start_redis_listener(_redis_client)
|
||||
yield
|
||||
# Shutdown: clean up connections
|
||||
await ws_manager.stop_redis_listener()
|
||||
if _redis_client is not None:
|
||||
_redis_client.close()
|
||||
if engine is not None:
|
||||
|
|
@ -167,14 +151,37 @@ def health_check() -> dict:
|
|||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket) -> None:
|
||||
"""WebSocket connection for real-time dashboard updates."""
|
||||
await ws_manager.connect(websocket)
|
||||
"""WebSocket connection for real-time dashboard updates.
|
||||
|
||||
Clients may send JSON messages to control their subscriptions:
|
||||
- {"action": "subscribe", "experiment_id": "..."}
|
||||
- {"action": "unsubscribe", "experiment_id": "..."}
|
||||
- {"action": "replay", "experiment_id": "...", "since_ts": ..., "limit": ...}
|
||||
"""
|
||||
# Parse query params for initial experiment subscriptions
|
||||
exp_ids_param = websocket.query_params.get("experiment_ids", "")
|
||||
experiment_ids = [e.strip() for e in exp_ids_param.split(",") if e.strip()] or None
|
||||
await ws_manager.connect(websocket, experiment_ids=experiment_ids)
|
||||
try:
|
||||
while True:
|
||||
# Keep connection alive; handle incoming messages if needed
|
||||
data = await websocket.receive_json()
|
||||
# Echo back or handle client messages in future
|
||||
await websocket.send_json({"type": "ack", "data": data})
|
||||
action = data.get("action", "")
|
||||
if action == "subscribe" and data.get("experiment_id"):
|
||||
ws_manager.subscribe(websocket, data["experiment_id"])
|
||||
await websocket.send_json({"type": "ack", "action": "subscribe", "experiment_id": data["experiment_id"]})
|
||||
elif action == "unsubscribe" and data.get("experiment_id"):
|
||||
ws_manager.unsubscribe(websocket, data["experiment_id"])
|
||||
await websocket.send_json({"type": "ack", "action": "unsubscribe", "experiment_id": data["experiment_id"]})
|
||||
elif action == "replay":
|
||||
count = await ws_manager.replay(
|
||||
websocket,
|
||||
experiment_id=data.get("experiment_id"),
|
||||
since_ts=data.get("since_ts"),
|
||||
limit=data.get("limit"),
|
||||
)
|
||||
await websocket.send_json({"type": "ack", "action": "replay", "events_replayed": count})
|
||||
else:
|
||||
await websocket.send_json({"type": "ack", "data": data})
|
||||
except WebSocketDisconnect:
|
||||
ws_manager.disconnect(websocket)
|
||||
|
||||
|
|
|
|||
|
|
@ -77,11 +77,11 @@ class TestWebSocket:
|
|||
|
||||
def test_websocket_disconnect_cleanup(self, client):
|
||||
from main import ws_manager
|
||||
initial_count = len(ws_manager.active_connections)
|
||||
initial_count = ws_manager.connection_count
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
assert len(ws_manager.active_connections) == initial_count + 1
|
||||
assert ws_manager.connection_count == initial_count + 1
|
||||
# After disconnect, connection should be removed
|
||||
assert len(ws_manager.active_connections) == initial_count
|
||||
assert ws_manager.connection_count == initial_count
|
||||
|
||||
|
||||
class TestRouterMounting:
|
||||
|
|
@ -97,16 +97,15 @@ class TestRouterMounting:
|
|||
|
||||
|
||||
class TestConnectionManager:
|
||||
def test_broadcast_removes_dead_connections(self):
|
||||
"""ConnectionManager.broadcast skips and removes broken connections."""
|
||||
from main import ConnectionManager
|
||||
manager = ConnectionManager()
|
||||
# No connections — broadcast should not raise
|
||||
def test_broadcast_on_empty_manager(self):
|
||||
"""WebSocketManager.broadcast_global on empty manager should not raise."""
|
||||
from websocket.manager import WebSocketManager
|
||||
manager = WebSocketManager()
|
||||
import asyncio
|
||||
asyncio.get_event_loop().run_until_complete(
|
||||
manager.broadcast({"test": True})
|
||||
manager.broadcast_global({"test": True})
|
||||
)
|
||||
assert len(manager.active_connections) == 0
|
||||
assert manager.connection_count == 0
|
||||
|
||||
|
||||
class TestGetDb:
|
||||
|
|
|
|||
484
backend/tests/test_ws_manager.py
Normal file
484
backend/tests/test_ws_manager.py
Normal file
|
|
@ -0,0 +1,484 @@
|
|||
"""Tests for the WebSocket connection manager."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from websocket.manager import ConnectionEntry, WebSocketManager
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_ws(accept_side_effect=None) -> AsyncMock:
|
||||
"""Create a mock WebSocket."""
|
||||
ws = AsyncMock()
|
||||
ws.accept = AsyncMock(side_effect=accept_side_effect)
|
||||
ws.send_json = AsyncMock()
|
||||
ws.receive_json = AsyncMock()
|
||||
return ws
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Connection lifecycle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConnect:
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_and_register_global(self):
|
||||
mgr = WebSocketManager()
|
||||
ws = _make_ws()
|
||||
await mgr.connect(ws)
|
||||
|
||||
assert mgr.connection_count == 1
|
||||
assert id(ws) in mgr._global_subs
|
||||
ws.accept.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_with_experiment_ids(self):
|
||||
mgr = WebSocketManager()
|
||||
ws = _make_ws()
|
||||
await mgr.connect(ws, experiment_ids=["exp-1", "exp-2"])
|
||||
|
||||
assert mgr.connection_count == 1
|
||||
assert id(ws) not in mgr._global_subs
|
||||
assert id(ws) in mgr._experiment_subs["exp-1"]
|
||||
assert id(ws) in mgr._experiment_subs["exp-2"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_connections(self):
|
||||
mgr = WebSocketManager()
|
||||
ws1 = _make_ws()
|
||||
ws2 = _make_ws()
|
||||
await mgr.connect(ws1)
|
||||
await mgr.connect(ws2, experiment_ids=["exp-1"])
|
||||
|
||||
assert mgr.connection_count == 2
|
||||
|
||||
|
||||
class TestDisconnect:
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_global(self):
|
||||
mgr = WebSocketManager()
|
||||
ws = _make_ws()
|
||||
await mgr.connect(ws)
|
||||
mgr.disconnect(ws)
|
||||
|
||||
assert mgr.connection_count == 0
|
||||
assert id(ws) not in mgr._global_subs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_experiment(self):
|
||||
mgr = WebSocketManager()
|
||||
ws = _make_ws()
|
||||
await mgr.connect(ws, experiment_ids=["exp-1"])
|
||||
mgr.disconnect(ws)
|
||||
|
||||
assert mgr.connection_count == 0
|
||||
assert "exp-1" not in mgr._experiment_subs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_unknown_ws_is_noop(self):
|
||||
mgr = WebSocketManager()
|
||||
ws = _make_ws()
|
||||
mgr.disconnect(ws) # Should not raise
|
||||
assert mgr.connection_count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_cleans_up_empty_experiment_sets(self):
|
||||
mgr = WebSocketManager()
|
||||
ws1 = _make_ws()
|
||||
ws2 = _make_ws()
|
||||
await mgr.connect(ws1, experiment_ids=["exp-1"])
|
||||
await mgr.connect(ws2, experiment_ids=["exp-1"])
|
||||
|
||||
mgr.disconnect(ws1)
|
||||
assert "exp-1" in mgr._experiment_subs
|
||||
assert len(mgr._experiment_subs["exp-1"]) == 1
|
||||
|
||||
mgr.disconnect(ws2)
|
||||
assert "exp-1" not in mgr._experiment_subs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Subscribe / unsubscribe at runtime
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSubscriptions:
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_adds_experiment(self):
|
||||
mgr = WebSocketManager()
|
||||
ws = _make_ws()
|
||||
await mgr.connect(ws)
|
||||
mgr.subscribe(ws, "exp-42")
|
||||
|
||||
assert id(ws) in mgr._experiment_subs["exp-42"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsubscribe_removes_experiment(self):
|
||||
mgr = WebSocketManager()
|
||||
ws = _make_ws()
|
||||
await mgr.connect(ws, experiment_ids=["exp-1"])
|
||||
mgr.unsubscribe(ws, "exp-1")
|
||||
|
||||
assert "exp-1" not in mgr._experiment_subs
|
||||
entry = mgr._connections[id(ws)]
|
||||
assert "exp-1" not in entry.experiment_ids
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_unknown_ws_is_noop(self):
|
||||
mgr = WebSocketManager()
|
||||
ws = _make_ws()
|
||||
mgr.subscribe(ws, "exp-1") # Not connected
|
||||
assert "exp-1" not in mgr._experiment_subs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsubscribe_unknown_ws_is_noop(self):
|
||||
mgr = WebSocketManager()
|
||||
ws = _make_ws()
|
||||
mgr.unsubscribe(ws, "exp-1") # Should not raise
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Broadcasting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBroadcast:
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_global(self):
|
||||
mgr = WebSocketManager()
|
||||
ws = _make_ws()
|
||||
await mgr.connect(ws)
|
||||
|
||||
await mgr.broadcast_global({"type": "test"})
|
||||
ws.send_json.assert_awaited_once_with({"type": "test"})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_to_experiment(self):
|
||||
mgr = WebSocketManager()
|
||||
ws_exp = _make_ws()
|
||||
ws_global = _make_ws()
|
||||
await mgr.connect(ws_exp, experiment_ids=["exp-1"])
|
||||
await mgr.connect(ws_global)
|
||||
|
||||
await mgr.broadcast_to_experiment("exp-1", {"type": "test"})
|
||||
ws_exp.send_json.assert_awaited_once()
|
||||
ws_global.send_json.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_routes_by_experiment_id(self):
|
||||
mgr = WebSocketManager()
|
||||
ws_exp = _make_ws()
|
||||
ws_global = _make_ws()
|
||||
await mgr.connect(ws_exp, experiment_ids=["exp-1"])
|
||||
await mgr.connect(ws_global)
|
||||
|
||||
msg = {"type": "run.completed", "experiment_id": "exp-1"}
|
||||
await mgr.broadcast(msg)
|
||||
|
||||
# experiment subscriber gets the message (from broadcast_to_experiment)
|
||||
assert ws_exp.send_json.await_count == 1
|
||||
# global subscriber gets it (from broadcast_global)
|
||||
assert ws_global.send_json.await_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_without_experiment_id_goes_global_only(self):
|
||||
mgr = WebSocketManager()
|
||||
ws_exp = _make_ws()
|
||||
ws_global = _make_ws()
|
||||
await mgr.connect(ws_exp, experiment_ids=["exp-1"])
|
||||
await mgr.connect(ws_global)
|
||||
|
||||
await mgr.broadcast({"type": "system.status"})
|
||||
ws_exp.send_json.assert_not_awaited()
|
||||
ws_global.send_json.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_removes_dead_connections(self):
|
||||
mgr = WebSocketManager()
|
||||
ws_good = _make_ws()
|
||||
ws_dead = _make_ws()
|
||||
ws_dead.send_json = AsyncMock(side_effect=RuntimeError("closed"))
|
||||
await mgr.connect(ws_good)
|
||||
await mgr.connect(ws_dead)
|
||||
|
||||
await mgr.broadcast_global({"type": "test"})
|
||||
|
||||
assert mgr.connection_count == 1
|
||||
assert id(ws_dead) not in mgr._connections
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_to_multiple_experiment_subscribers(self):
|
||||
mgr = WebSocketManager()
|
||||
ws1 = _make_ws()
|
||||
ws2 = _make_ws()
|
||||
await mgr.connect(ws1, experiment_ids=["exp-1"])
|
||||
await mgr.connect(ws2, experiment_ids=["exp-1"])
|
||||
|
||||
await mgr.broadcast_to_experiment("exp-1", {"type": "test"})
|
||||
ws1.send_json.assert_awaited_once()
|
||||
ws2.send_json.assert_awaited_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Replay
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReplay:
|
||||
@pytest.mark.asyncio
|
||||
async def test_replay_global_events(self):
|
||||
mgr = WebSocketManager()
|
||||
ws = _make_ws()
|
||||
await mgr.connect(ws)
|
||||
|
||||
# Store some events
|
||||
await mgr.broadcast_global({"type": "event1"})
|
||||
await mgr.broadcast_global({"type": "event2"})
|
||||
|
||||
# Reset send tracking
|
||||
ws.send_json.reset_mock()
|
||||
|
||||
count = await mgr.replay(ws)
|
||||
assert count == 2
|
||||
assert ws.send_json.await_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_replay_experiment_events(self):
|
||||
mgr = WebSocketManager()
|
||||
ws = _make_ws()
|
||||
await mgr.connect(ws, experiment_ids=["exp-1"])
|
||||
|
||||
await mgr.broadcast_to_experiment("exp-1", {"type": "e1"})
|
||||
await mgr.broadcast_to_experiment("exp-1", {"type": "e2"})
|
||||
await mgr.broadcast_to_experiment("exp-1", {"type": "e3"})
|
||||
|
||||
ws.send_json.reset_mock()
|
||||
|
||||
count = await mgr.replay(ws, experiment_id="exp-1")
|
||||
assert count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_replay_with_since_ts(self):
|
||||
mgr = WebSocketManager()
|
||||
ws = _make_ws()
|
||||
await mgr.connect(ws)
|
||||
|
||||
# Manually insert events with known timestamps
|
||||
ts_old = time.time() - 100
|
||||
ts_new = time.time()
|
||||
mgr._global_replay.append({"type": "old", "_ts": ts_old})
|
||||
mgr._global_replay.append({"type": "new", "_ts": ts_new})
|
||||
|
||||
count = await mgr.replay(ws, since_ts=ts_old + 1)
|
||||
assert count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_replay_with_limit(self):
|
||||
mgr = WebSocketManager()
|
||||
ws = _make_ws()
|
||||
await mgr.connect(ws)
|
||||
|
||||
for i in range(10):
|
||||
mgr._global_replay.append({"type": f"event{i}", "_ts": time.time()})
|
||||
|
||||
count = await mgr.replay(ws, limit=3)
|
||||
assert count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_replay_empty_buffer(self):
|
||||
mgr = WebSocketManager()
|
||||
ws = _make_ws()
|
||||
await mgr.connect(ws)
|
||||
|
||||
count = await mgr.replay(ws)
|
||||
assert count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_replay_respects_buffer_size(self):
|
||||
mgr = WebSocketManager(replay_size=5)
|
||||
ws = _make_ws()
|
||||
await mgr.connect(ws)
|
||||
|
||||
for i in range(10):
|
||||
await mgr.broadcast_global({"type": f"event{i}"})
|
||||
|
||||
ws.send_json.reset_mock()
|
||||
count = await mgr.replay(ws)
|
||||
assert count == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_replay_stops_on_send_failure(self):
|
||||
mgr = WebSocketManager()
|
||||
ws = _make_ws()
|
||||
call_count = 0
|
||||
|
||||
async def fail_after_one(msg):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count > 1:
|
||||
raise RuntimeError("closed")
|
||||
|
||||
ws.send_json = AsyncMock(side_effect=fail_after_one)
|
||||
await mgr.connect(ws)
|
||||
|
||||
for i in range(5):
|
||||
mgr._global_replay.append({"type": f"event{i}", "_ts": time.time()})
|
||||
|
||||
count = await mgr.replay(ws)
|
||||
assert count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Redis pub/sub bridge
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRedisListener:
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_redis_message_json(self):
|
||||
mgr = WebSocketManager()
|
||||
ws = _make_ws()
|
||||
await mgr.connect(ws)
|
||||
|
||||
msg = json.dumps({"type": "run.completed", "experiment_id": "exp-1"})
|
||||
await mgr._handle_redis_message(msg)
|
||||
|
||||
# Message should have been broadcast
|
||||
assert ws.send_json.await_count >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_redis_message_bytes(self):
|
||||
mgr = WebSocketManager()
|
||||
ws = _make_ws()
|
||||
await mgr.connect(ws)
|
||||
|
||||
msg = json.dumps({"type": "test"}).encode("utf-8")
|
||||
await mgr._handle_redis_message(msg)
|
||||
assert ws.send_json.await_count >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_redis_message_invalid_json(self):
|
||||
mgr = WebSocketManager()
|
||||
ws = _make_ws()
|
||||
await mgr.connect(ws)
|
||||
|
||||
# Should not raise, just log warning
|
||||
await mgr._handle_redis_message("not-json{{{")
|
||||
ws.send_json.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_redis_listener_none_is_noop(self):
|
||||
mgr = WebSocketManager()
|
||||
await mgr.start_redis_listener(None)
|
||||
assert mgr._redis_listener_task is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_redis_listener_when_none(self):
|
||||
mgr = WebSocketManager()
|
||||
await mgr.stop_redis_listener() # Should not raise
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_cancels_task(self):
|
||||
mgr = WebSocketManager()
|
||||
|
||||
# Create a long-running dummy task
|
||||
async def dummy():
|
||||
await asyncio.sleep(999)
|
||||
|
||||
mgr._redis_listener_task = asyncio.create_task(dummy())
|
||||
await mgr.stop_redis_listener()
|
||||
assert mgr._redis_listener_task is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_redis_listener_processes_messages(self):
|
||||
"""Test the sync Redis pubsub path with a mock."""
|
||||
mgr = WebSocketManager()
|
||||
ws = _make_ws()
|
||||
await mgr.connect(ws)
|
||||
|
||||
# Mock a sync Redis client
|
||||
mock_pubsub = MagicMock()
|
||||
call_count = 0
|
||||
|
||||
def get_message_side_effect(ignore_subscribe_messages=True, timeout=1.0):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return {
|
||||
"type": "message",
|
||||
"data": json.dumps({"type": "test_event"}),
|
||||
}
|
||||
# After first message, raise to stop the loop
|
||||
raise asyncio.CancelledError()
|
||||
|
||||
mock_pubsub.get_message = MagicMock(side_effect=get_message_side_effect)
|
||||
mock_pubsub.subscribe = MagicMock()
|
||||
mock_pubsub.unsubscribe = MagicMock()
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.pubsub.return_value = mock_pubsub
|
||||
|
||||
# Run the sync listener — it will process one message then get cancelled
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await mgr._redis_listen_sync(mock_redis, "promptlooper:events")
|
||||
|
||||
mock_pubsub.subscribe.assert_called_once_with("promptlooper:events")
|
||||
assert ws.send_json.await_count >= 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stats
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStats:
|
||||
@pytest.mark.asyncio
|
||||
async def test_stats_empty(self):
|
||||
mgr = WebSocketManager()
|
||||
stats = mgr.stats()
|
||||
assert stats["total_connections"] == 0
|
||||
assert stats["global_subscribers"] == 0
|
||||
assert stats["experiment_subscriptions"] == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stats_with_connections(self):
|
||||
mgr = WebSocketManager()
|
||||
ws1 = _make_ws()
|
||||
ws2 = _make_ws()
|
||||
await mgr.connect(ws1)
|
||||
await mgr.connect(ws2, experiment_ids=["exp-1"])
|
||||
|
||||
stats = mgr.stats()
|
||||
assert stats["total_connections"] == 2
|
||||
assert stats["global_subscribers"] == 1
|
||||
assert stats["experiment_subscriptions"]["exp-1"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_experiment_ids_property(self):
|
||||
mgr = WebSocketManager()
|
||||
ws = _make_ws()
|
||||
await mgr.connect(ws, experiment_ids=["exp-a", "exp-b"])
|
||||
|
||||
ids = mgr.experiment_ids
|
||||
assert set(ids) == {"exp-a", "exp-b"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connection_count_property(self):
|
||||
mgr = WebSocketManager()
|
||||
assert mgr.connection_count == 0
|
||||
ws = _make_ws()
|
||||
await mgr.connect(ws)
|
||||
assert mgr.connection_count == 1
|
||||
mgr.disconnect(ws)
|
||||
assert mgr.connection_count == 0
|
||||
358
backend/websocket/manager.py
Normal file
358
backend/websocket/manager.py
Normal file
|
|
@ -0,0 +1,358 @@
|
|||
"""WebSocket connection manager for PromptLooper.
|
||||
|
||||
Manages per-experiment and global WebSocket connections, bridges Redis
|
||||
pub/sub events to connected clients, and supports message replay on
|
||||
reconnection.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default number of events to retain for replay on reconnect
|
||||
DEFAULT_REPLAY_SIZE = 50
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionEntry:
|
||||
"""Tracks a single WebSocket connection and its subscriptions."""
|
||||
|
||||
websocket: WebSocket
|
||||
experiment_ids: set[str] = field(default_factory=set)
|
||||
connected_at: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
class WebSocketManager:
|
||||
"""Manages WebSocket connections with per-experiment routing and replay.
|
||||
|
||||
Features:
|
||||
- Per-experiment and global subscriptions
|
||||
- Redis pub/sub bridge (when Redis is available)
|
||||
- Message replay on reconnection (last N events per experiment + global)
|
||||
- Clean disconnect handling
|
||||
"""
|
||||
|
||||
def __init__(self, replay_size: int = DEFAULT_REPLAY_SIZE) -> None:
|
||||
self._replay_size = replay_size
|
||||
# All active connections keyed by id(websocket)
|
||||
self._connections: dict[int, ConnectionEntry] = {}
|
||||
# Experiment-specific connection sets: experiment_id -> set of ws ids
|
||||
self._experiment_subs: dict[str, set[int]] = {}
|
||||
# Global subscribers (receive all events)
|
||||
self._global_subs: set[int] = set()
|
||||
# Event replay buffers
|
||||
self._global_replay: deque[dict[str, Any]] = deque(maxlen=replay_size)
|
||||
self._experiment_replay: dict[str, deque[dict[str, Any]]] = {}
|
||||
# Redis listener task handle
|
||||
self._redis_listener_task: asyncio.Task | None = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Connection lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
websocket: WebSocket,
|
||||
experiment_ids: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Accept a WebSocket connection and register subscriptions.
|
||||
|
||||
Args:
|
||||
websocket: The incoming WebSocket connection.
|
||||
experiment_ids: Optional list of experiment IDs to subscribe to.
|
||||
If empty/None the connection receives global broadcasts only.
|
||||
"""
|
||||
await websocket.accept()
|
||||
ws_id = id(websocket)
|
||||
entry = ConnectionEntry(websocket=websocket)
|
||||
self._connections[ws_id] = entry
|
||||
|
||||
if experiment_ids:
|
||||
for exp_id in experiment_ids:
|
||||
self._subscribe_experiment(ws_id, exp_id)
|
||||
else:
|
||||
self._global_subs.add(ws_id)
|
||||
|
||||
def disconnect(self, websocket: WebSocket) -> None:
|
||||
"""Remove a WebSocket from all tracking structures."""
|
||||
ws_id = id(websocket)
|
||||
entry = self._connections.pop(ws_id, None)
|
||||
if entry is None:
|
||||
return
|
||||
|
||||
# Remove from global subs
|
||||
self._global_subs.discard(ws_id)
|
||||
|
||||
# Remove from all experiment subs
|
||||
for exp_id in list(entry.experiment_ids):
|
||||
exp_set = self._experiment_subs.get(exp_id)
|
||||
if exp_set is not None:
|
||||
exp_set.discard(ws_id)
|
||||
if not exp_set:
|
||||
del self._experiment_subs[exp_id]
|
||||
|
||||
def _subscribe_experiment(self, ws_id: int, experiment_id: str) -> None:
|
||||
"""Subscribe a connection to a specific experiment's events."""
|
||||
entry = self._connections.get(ws_id)
|
||||
if entry is None:
|
||||
return
|
||||
entry.experiment_ids.add(experiment_id)
|
||||
if experiment_id not in self._experiment_subs:
|
||||
self._experiment_subs[experiment_id] = set()
|
||||
self._experiment_subs[experiment_id].add(ws_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Subscription management (runtime)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def subscribe(self, websocket: WebSocket, experiment_id: str) -> None:
|
||||
"""Add an experiment subscription for an existing connection."""
|
||||
ws_id = id(websocket)
|
||||
if ws_id in self._connections:
|
||||
self._subscribe_experiment(ws_id, experiment_id)
|
||||
|
||||
def unsubscribe(self, websocket: WebSocket, experiment_id: str) -> None:
|
||||
"""Remove an experiment subscription for an existing connection."""
|
||||
ws_id = id(websocket)
|
||||
entry = self._connections.get(ws_id)
|
||||
if entry is None:
|
||||
return
|
||||
entry.experiment_ids.discard(experiment_id)
|
||||
exp_set = self._experiment_subs.get(experiment_id)
|
||||
if exp_set is not None:
|
||||
exp_set.discard(ws_id)
|
||||
if not exp_set:
|
||||
del self._experiment_subs[experiment_id]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Broadcasting
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def broadcast_to_experiment(
|
||||
self, experiment_id: str, message: dict[str, Any]
|
||||
) -> None:
|
||||
"""Send a message to all connections subscribed to an experiment."""
|
||||
# Store in replay buffer
|
||||
if experiment_id not in self._experiment_replay:
|
||||
self._experiment_replay[experiment_id] = deque(maxlen=self._replay_size)
|
||||
message_with_ts = {**message, "_ts": time.time()}
|
||||
self._experiment_replay[experiment_id].append(message_with_ts)
|
||||
|
||||
ws_ids = self._experiment_subs.get(experiment_id, set())
|
||||
await self._send_to_many(ws_ids, message)
|
||||
|
||||
async def broadcast_global(self, message: dict[str, Any]) -> None:
|
||||
"""Send a message to all global subscribers."""
|
||||
message_with_ts = {**message, "_ts": time.time()}
|
||||
self._global_replay.append(message_with_ts)
|
||||
await self._send_to_many(self._global_subs, message)
|
||||
|
||||
async def broadcast(self, message: dict[str, Any]) -> None:
|
||||
"""Route a message to the right subscribers based on experiment_id.
|
||||
|
||||
If the message contains an ``experiment_id`` field, it is sent to
|
||||
that experiment's subscribers *and* to global subscribers. Otherwise
|
||||
it is sent only to global subscribers.
|
||||
"""
|
||||
experiment_id = message.get("experiment_id")
|
||||
if experiment_id:
|
||||
await self.broadcast_to_experiment(str(experiment_id), message)
|
||||
await self.broadcast_global(message)
|
||||
|
||||
async def _send_to_many(
|
||||
self, ws_ids: set[int], message: dict[str, Any]
|
||||
) -> None:
|
||||
"""Send a JSON message to a set of connections, disconnecting failures."""
|
||||
dead: list[int] = []
|
||||
for ws_id in list(ws_ids):
|
||||
entry = self._connections.get(ws_id)
|
||||
if entry is None:
|
||||
dead.append(ws_id)
|
||||
continue
|
||||
try:
|
||||
await entry.websocket.send_json(message)
|
||||
except Exception:
|
||||
logger.debug("WebSocket send failed, removing connection %s", ws_id)
|
||||
dead.append(ws_id)
|
||||
# Clean up dead connections
|
||||
for ws_id in dead:
|
||||
ws_ids.discard(ws_id)
|
||||
entry = self._connections.pop(ws_id, None)
|
||||
if entry:
|
||||
self._global_subs.discard(ws_id)
|
||||
for exp_id in entry.experiment_ids:
|
||||
exp_set = self._experiment_subs.get(exp_id)
|
||||
if exp_set:
|
||||
exp_set.discard(ws_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Message replay
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def replay(
|
||||
self,
|
||||
websocket: WebSocket,
|
||||
experiment_id: str | None = None,
|
||||
since_ts: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> int:
|
||||
"""Replay recent events to a connection.
|
||||
|
||||
Args:
|
||||
websocket: Target connection.
|
||||
experiment_id: If provided, replay experiment-specific events;
|
||||
otherwise replay global events.
|
||||
since_ts: Only replay events after this Unix timestamp.
|
||||
limit: Max number of events to replay (default: all in buffer).
|
||||
|
||||
Returns:
|
||||
Number of events replayed.
|
||||
"""
|
||||
if experiment_id:
|
||||
buffer = self._experiment_replay.get(experiment_id, deque())
|
||||
else:
|
||||
buffer = self._global_replay
|
||||
|
||||
events = list(buffer)
|
||||
if since_ts is not None:
|
||||
events = [e for e in events if e.get("_ts", 0) > since_ts]
|
||||
if limit is not None:
|
||||
events = events[-limit:]
|
||||
|
||||
count = 0
|
||||
for event in events:
|
||||
try:
|
||||
await websocket.send_json(event)
|
||||
count += 1
|
||||
except Exception:
|
||||
break
|
||||
return count
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Redis pub/sub bridge
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def start_redis_listener(self, redis_client: Any) -> None:
|
||||
"""Start a background task that subscribes to Redis pub/sub and
|
||||
broadcasts received messages to connected WebSocket clients.
|
||||
|
||||
Uses async Redis pubsub if available, otherwise runs the blocking
|
||||
listener in a thread executor.
|
||||
"""
|
||||
if redis_client is None:
|
||||
return
|
||||
|
||||
self._redis_listener_task = asyncio.create_task(
|
||||
self._redis_listen_loop(redis_client)
|
||||
)
|
||||
|
||||
async def _redis_listen_loop(self, redis_client: Any) -> None:
|
||||
"""Listen for Redis pub/sub messages and broadcast them."""
|
||||
channel = "promptlooper:events"
|
||||
try:
|
||||
# Check if this is an async redis client (aioredis / redis.asyncio)
|
||||
if hasattr(redis_client, "pubsub") and asyncio.iscoroutinefunction(
|
||||
getattr(redis_client, "subscribe", None)
|
||||
):
|
||||
await self._redis_listen_async(redis_client, channel)
|
||||
else:
|
||||
await self._redis_listen_sync(redis_client, channel)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Redis listener cancelled")
|
||||
except Exception:
|
||||
logger.exception("Redis listener error")
|
||||
|
||||
async def _redis_listen_async(self, redis_client: Any, channel: str) -> None:
|
||||
"""Async Redis pub/sub listener."""
|
||||
pubsub = redis_client.pubsub()
|
||||
await pubsub.subscribe(channel)
|
||||
try:
|
||||
async for raw_message in pubsub.listen():
|
||||
if raw_message["type"] != "message":
|
||||
continue
|
||||
await self._handle_redis_message(raw_message["data"])
|
||||
finally:
|
||||
await pubsub.unsubscribe(channel)
|
||||
|
||||
async def _redis_listen_sync(self, redis_client: Any, channel: str) -> None:
|
||||
"""Sync Redis pub/sub listener (run blocking calls in executor)."""
|
||||
pubsub = redis_client.pubsub()
|
||||
pubsub.subscribe(channel)
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
while True:
|
||||
raw_message = await loop.run_in_executor(
|
||||
None, lambda: pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
|
||||
)
|
||||
if raw_message is None:
|
||||
continue
|
||||
if raw_message["type"] != "message":
|
||||
continue
|
||||
await self._handle_redis_message(raw_message["data"])
|
||||
finally:
|
||||
pubsub.unsubscribe(channel)
|
||||
|
||||
async def _handle_redis_message(self, data: str | bytes) -> None:
|
||||
"""Parse a Redis message and broadcast it."""
|
||||
try:
|
||||
if isinstance(data, bytes):
|
||||
data = data.decode("utf-8")
|
||||
message = json.loads(data)
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
logger.warning("Invalid Redis message: %s", data)
|
||||
return
|
||||
await self.broadcast(message)
|
||||
|
||||
async def stop_redis_listener(self) -> None:
|
||||
"""Cancel the Redis listener background task."""
|
||||
if self._redis_listener_task is not None:
|
||||
self._redis_listener_task.cancel()
|
||||
try:
|
||||
await self._redis_listener_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._redis_listener_task = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Status
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def connection_count(self) -> int:
|
||||
"""Total number of active connections."""
|
||||
return len(self._connections)
|
||||
|
||||
@property
|
||||
def experiment_ids(self) -> list[str]:
|
||||
"""List of experiment IDs with active subscriptions."""
|
||||
return list(self._experiment_subs.keys())
|
||||
|
||||
def stats(self) -> dict[str, Any]:
|
||||
"""Return connection statistics."""
|
||||
return {
|
||||
"total_connections": self.connection_count,
|
||||
"global_subscribers": len(self._global_subs),
|
||||
"experiment_subscriptions": {
|
||||
exp_id: len(ws_ids)
|
||||
for exp_id, ws_ids in self._experiment_subs.items()
|
||||
},
|
||||
"replay_buffer_sizes": {
|
||||
"global": len(self._global_replay),
|
||||
**{
|
||||
exp_id: len(buf)
|
||||
for exp_id, buf in self._experiment_replay.items()
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# Module-level singleton for use across the application
|
||||
manager = WebSocketManager()
|
||||
Loading…
Add table
Reference in a new issue