promptlooper/backend/tests/test_ws_manager.py
John Lightner 30fd15ec7a MAESTRO: Implement WebSocket connection manager with per-experiment routing, Redis pub/sub bridge, and message replay
- WebSocketManager in backend/websocket/manager.py with per-experiment and global subscriptions
- Redis pub/sub bridge (sync + async) broadcasting events to relevant WebSocket clients
- Deque-based replay buffers with since_ts/limit filtering for reconnection support
- Runtime subscribe/unsubscribe and stats API
- Enhanced /ws endpoint in main.py with subscribe/unsubscribe/replay actions
- 35 tests in test_ws_manager.py, all passing
2026-04-07 03:34:21 -05:00

484 lines
15 KiB
Python

"""Tests for the WebSocket connection manager."""
import asyncio
import json
import time
from collections import deque
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from websocket.manager import ConnectionEntry, WebSocketManager
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_ws(accept_side_effect=None) -> AsyncMock:
    """Return a mock WebSocket whose coroutine methods can be awaited and inspected.

    Args:
        accept_side_effect: Optional ``side_effect`` forwarded to the mocked
            ``accept()`` (e.g. an exception instance to simulate a failed
            handshake). ``None`` means ``accept()`` simply succeeds.

    Returns:
        An ``AsyncMock`` whose ``accept``, ``send_json`` and ``receive_json``
        attributes are each their own ``AsyncMock``.
    """
    fake = AsyncMock()
    fake.accept = AsyncMock(side_effect=accept_side_effect)
    # Give each messaging method its own mock so tests can assert on them independently.
    for method_name in ("send_json", "receive_json"):
        setattr(fake, method_name, AsyncMock())
    return fake
# ---------------------------------------------------------------------------
# Connection lifecycle
# ---------------------------------------------------------------------------
class TestConnect:
    """``connect()`` accepts the socket and records its subscriptions."""

    @pytest.mark.asyncio
    async def test_accept_and_register_global(self):
        manager = WebSocketManager()
        sock = _make_ws()

        await manager.connect(sock)

        assert manager.connection_count == 1
        # No experiment filter given, so the socket joins the global subscriber set.
        assert id(sock) in manager._global_subs
        sock.accept.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_accept_with_experiment_ids(self):
        manager = WebSocketManager()
        sock = _make_ws()

        await manager.connect(sock, experiment_ids=["exp-1", "exp-2"])

        assert manager.connection_count == 1
        # A socket scoped to experiments is not registered as a global subscriber.
        assert id(sock) not in manager._global_subs
        for exp_id in ("exp-1", "exp-2"):
            assert id(sock) in manager._experiment_subs[exp_id]

    @pytest.mark.asyncio
    async def test_multiple_connections(self):
        manager = WebSocketManager()
        first, second = _make_ws(), _make_ws()

        await manager.connect(first)
        await manager.connect(second, experiment_ids=["exp-1"])

        assert manager.connection_count == 2
class TestDisconnect:
    """``disconnect()`` forgets a socket and prunes emptied subscription sets."""

    @pytest.mark.asyncio
    async def test_disconnect_global(self):
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock)

        manager.disconnect(sock)

        assert manager.connection_count == 0
        assert id(sock) not in manager._global_subs

    @pytest.mark.asyncio
    async def test_disconnect_experiment(self):
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock, experiment_ids=["exp-1"])

        manager.disconnect(sock)

        assert manager.connection_count == 0
        # The last subscriber left, so the experiment key itself is gone.
        assert "exp-1" not in manager._experiment_subs

    @pytest.mark.asyncio
    async def test_disconnect_unknown_ws_is_noop(self):
        manager = WebSocketManager()
        stranger = _make_ws()

        manager.disconnect(stranger)  # never connected: must not raise

        assert manager.connection_count == 0

    @pytest.mark.asyncio
    async def test_disconnect_cleans_up_empty_experiment_sets(self):
        manager = WebSocketManager()
        first, second = _make_ws(), _make_ws()
        for sock in (first, second):
            await manager.connect(sock, experiment_ids=["exp-1"])

        manager.disconnect(first)
        # One subscriber remains, so the set survives with a single member.
        assert "exp-1" in manager._experiment_subs
        assert len(manager._experiment_subs["exp-1"]) == 1

        manager.disconnect(second)
        assert "exp-1" not in manager._experiment_subs
# ---------------------------------------------------------------------------
# Subscribe / unsubscribe at runtime
# ---------------------------------------------------------------------------
class TestSubscriptions:
    """Runtime ``subscribe()`` / ``unsubscribe()`` on existing connections."""

    @pytest.mark.asyncio
    async def test_subscribe_adds_experiment(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)
        mgr.subscribe(ws, "exp-42")
        assert id(ws) in mgr._experiment_subs["exp-42"]
        # The per-connection entry must track the new subscription too (the
        # mirror of what test_unsubscribe_removes_experiment checks).
        # Otherwise disconnect-time cleanup cannot find it.
        entry = mgr._connections[id(ws)]
        assert "exp-42" in entry.experiment_ids
        # Subscribing must not register an extra connection.
        assert mgr.connection_count == 1

    @pytest.mark.asyncio
    async def test_unsubscribe_removes_experiment(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws, experiment_ids=["exp-1"])
        mgr.unsubscribe(ws, "exp-1")
        # Last subscriber gone, so the experiment key is pruned.
        assert "exp-1" not in mgr._experiment_subs
        entry = mgr._connections[id(ws)]
        assert "exp-1" not in entry.experiment_ids

    @pytest.mark.asyncio
    async def test_subscribe_unknown_ws_is_noop(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        mgr.subscribe(ws, "exp-1")  # Not connected
        assert "exp-1" not in mgr._experiment_subs
        # A no-op must not implicitly register the socket either.
        assert mgr.connection_count == 0

    @pytest.mark.asyncio
    async def test_unsubscribe_unknown_ws_is_noop(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        mgr.unsubscribe(ws, "exp-1")  # Should not raise
        # Previously this test asserted nothing; pin that no state was created.
        assert mgr.connection_count == 0
        assert "exp-1" not in mgr._experiment_subs
# ---------------------------------------------------------------------------
# Broadcasting
# ---------------------------------------------------------------------------
class TestBroadcast:
    """Fan-out of messages to global and per-experiment subscribers."""

    @pytest.mark.asyncio
    async def test_broadcast_global(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)
        await mgr.broadcast_global({"type": "test"})
        ws.send_json.assert_awaited_once_with({"type": "test"})

    @pytest.mark.asyncio
    async def test_broadcast_to_experiment(self):
        mgr = WebSocketManager()
        ws_exp = _make_ws()
        ws_global = _make_ws()
        await mgr.connect(ws_exp, experiment_ids=["exp-1"])
        await mgr.connect(ws_global)
        await mgr.broadcast_to_experiment("exp-1", {"type": "test"})
        ws_exp.send_json.assert_awaited_once()
        # Experiment-scoped sends must not reach global subscribers.
        ws_global.send_json.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_broadcast_routes_by_experiment_id(self):
        mgr = WebSocketManager()
        ws_exp = _make_ws()
        ws_global = _make_ws()
        await mgr.connect(ws_exp, experiment_ids=["exp-1"])
        await mgr.connect(ws_global)
        msg = {"type": "run.completed", "experiment_id": "exp-1"}
        await mgr.broadcast(msg)
        # experiment subscriber gets the message (from broadcast_to_experiment)
        assert ws_exp.send_json.await_count == 1
        # global subscriber gets it (from broadcast_global)
        assert ws_global.send_json.await_count == 1

    @pytest.mark.asyncio
    async def test_broadcast_without_experiment_id_goes_global_only(self):
        mgr = WebSocketManager()
        ws_exp = _make_ws()
        ws_global = _make_ws()
        await mgr.connect(ws_exp, experiment_ids=["exp-1"])
        await mgr.connect(ws_global)
        await mgr.broadcast({"type": "system.status"})
        ws_exp.send_json.assert_not_awaited()
        ws_global.send_json.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_broadcast_removes_dead_connections(self):
        mgr = WebSocketManager()
        ws_good = _make_ws()
        ws_dead = _make_ws()
        ws_dead.send_json = AsyncMock(side_effect=RuntimeError("closed"))
        await mgr.connect(ws_good)
        await mgr.connect(ws_dead)
        await mgr.broadcast_global({"type": "test"})
        # Only the socket whose send failed is pruned, from every index...
        assert mgr.connection_count == 1
        assert id(ws_dead) not in mgr._connections
        assert id(ws_dead) not in mgr._global_subs
        assert id(ws_good) in mgr._connections
        # ...and the failure must not stop delivery to the healthy socket.
        ws_good.send_json.assert_awaited_once_with({"type": "test"})

    @pytest.mark.asyncio
    async def test_broadcast_to_multiple_experiment_subscribers(self):
        mgr = WebSocketManager()
        ws1 = _make_ws()
        ws2 = _make_ws()
        await mgr.connect(ws1, experiment_ids=["exp-1"])
        await mgr.connect(ws2, experiment_ids=["exp-1"])
        await mgr.broadcast_to_experiment("exp-1", {"type": "test"})
        ws1.send_json.assert_awaited_once()
        ws2.send_json.assert_awaited_once()
# ---------------------------------------------------------------------------
# Replay
# ---------------------------------------------------------------------------
class TestReplay:
    """Replay buffers that let reconnecting clients catch up on missed events."""

    @pytest.mark.asyncio
    async def test_replay_global_events(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)
        # Store some events
        await mgr.broadcast_global({"type": "event1"})
        await mgr.broadcast_global({"type": "event2"})
        # Reset send tracking so only replayed sends are counted
        ws.send_json.reset_mock()
        count = await mgr.replay(ws)
        assert count == 2
        assert ws.send_json.await_count == 2

    @pytest.mark.asyncio
    async def test_replay_experiment_events(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws, experiment_ids=["exp-1"])
        await mgr.broadcast_to_experiment("exp-1", {"type": "e1"})
        await mgr.broadcast_to_experiment("exp-1", {"type": "e2"})
        await mgr.broadcast_to_experiment("exp-1", {"type": "e3"})
        ws.send_json.reset_mock()
        count = await mgr.replay(ws, experiment_id="exp-1")
        assert count == 3
        # The reported count must match what was actually sent.
        assert ws.send_json.await_count == 3

    @pytest.mark.asyncio
    async def test_replay_with_since_ts(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)
        # Manually insert events with known timestamps
        ts_old = time.time() - 100
        ts_new = time.time()
        mgr._global_replay.append({"type": "old", "_ts": ts_old})
        mgr._global_replay.append({"type": "new", "_ts": ts_new})
        # A cutoff between the two timestamps keeps only the newer event.
        count = await mgr.replay(ws, since_ts=ts_old + 1)
        assert count == 1
        assert ws.send_json.await_count == 1

    @pytest.mark.asyncio
    async def test_replay_with_limit(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)
        for i in range(10):
            mgr._global_replay.append({"type": f"event{i}", "_ts": time.time()})
        count = await mgr.replay(ws, limit=3)
        assert count == 3
        assert ws.send_json.await_count == 3

    @pytest.mark.asyncio
    async def test_replay_empty_buffer(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)
        count = await mgr.replay(ws)
        assert count == 0
        ws.send_json.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_replay_respects_buffer_size(self):
        mgr = WebSocketManager(replay_size=5)
        ws = _make_ws()
        await mgr.connect(ws)
        for i in range(10):
            await mgr.broadcast_global({"type": f"event{i}"})
        ws.send_json.reset_mock()
        # Only the last `replay_size` events are retained.
        count = await mgr.replay(ws)
        assert count == 5

    @pytest.mark.asyncio
    async def test_replay_stops_on_send_failure(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        call_count = 0

        async def fail_after_one(msg):
            nonlocal call_count
            call_count += 1
            if call_count > 1:
                raise RuntimeError("closed")

        ws.send_json = AsyncMock(side_effect=fail_after_one)
        await mgr.connect(ws)
        for i in range(5):
            mgr._global_replay.append({"type": f"event{i}", "_ts": time.time()})
        count = await mgr.replay(ws)
        assert count == 1
        # count == 1 alone would also hold if replay skipped failures and kept
        # going; asserting the attempt count proves it actually stopped.
        assert call_count == 2
# ---------------------------------------------------------------------------
# Redis pub/sub bridge
# ---------------------------------------------------------------------------
class TestRedisListener:
    """Bridge from Redis pub/sub messages to WebSocket broadcasts."""

    @pytest.mark.asyncio
    async def test_handle_redis_message_json(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)
        msg = json.dumps({"type": "run.completed", "experiment_id": "exp-1"})
        await mgr._handle_redis_message(msg)
        # Message should have been broadcast
        assert ws.send_json.await_count >= 1

    @pytest.mark.asyncio
    async def test_handle_redis_message_bytes(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)
        msg = json.dumps({"type": "test"}).encode("utf-8")
        await mgr._handle_redis_message(msg)
        assert ws.send_json.await_count >= 1

    @pytest.mark.asyncio
    async def test_handle_redis_message_invalid_json(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)
        # Should not raise, just log warning
        await mgr._handle_redis_message("not-json{{{")
        ws.send_json.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_start_redis_listener_none_is_noop(self):
        mgr = WebSocketManager()
        await mgr.start_redis_listener(None)
        assert mgr._redis_listener_task is None

    @pytest.mark.asyncio
    async def test_stop_redis_listener_when_none(self):
        mgr = WebSocketManager()
        await mgr.stop_redis_listener()  # Should not raise

    @pytest.mark.asyncio
    async def test_stop_cancels_task(self):
        mgr = WebSocketManager()

        # Create a long-running dummy task standing in for the listener
        async def dummy():
            await asyncio.sleep(999)

        task = asyncio.create_task(dummy())
        mgr._redis_listener_task = task
        await mgr.stop_redis_listener()
        assert mgr._redis_listener_task is None
        # Resetting the attribute alone would leak a running task. Yield once so
        # a cancel() that was requested but not awaited gets delivered, then
        # check the task really was cancelled.
        await asyncio.sleep(0)
        assert task.cancelled()

    @pytest.mark.asyncio
    async def test_sync_redis_listener_processes_messages(self):
        """Test the sync Redis pubsub path with a mock."""
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)
        # Mock a sync Redis client
        mock_pubsub = MagicMock()
        call_count = 0

        def get_message_side_effect(ignore_subscribe_messages=True, timeout=1.0):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return {
                    "type": "message",
                    "data": json.dumps({"type": "test_event"}),
                }
            # After first message, raise to stop the loop
            raise asyncio.CancelledError()

        mock_pubsub.get_message = MagicMock(side_effect=get_message_side_effect)
        mock_pubsub.subscribe = MagicMock()
        mock_pubsub.unsubscribe = MagicMock()
        mock_redis = MagicMock()
        mock_redis.pubsub.return_value = mock_pubsub
        # Run the sync listener — it will process one message then get cancelled
        with pytest.raises(asyncio.CancelledError):
            await mgr._redis_listen_sync(mock_redis, "promptlooper:events")
        mock_pubsub.subscribe.assert_called_once_with("promptlooper:events")
        assert ws.send_json.await_count >= 1
# ---------------------------------------------------------------------------
# Stats
# ---------------------------------------------------------------------------
class TestStats:
    """Introspection helpers: ``stats()``, ``experiment_ids`` and ``connection_count``."""

    @pytest.mark.asyncio
    async def test_stats_empty(self):
        snapshot = WebSocketManager().stats()
        assert snapshot["total_connections"] == 0
        assert snapshot["global_subscribers"] == 0
        assert snapshot["experiment_subscriptions"] == {}

    @pytest.mark.asyncio
    async def test_stats_with_connections(self):
        manager = WebSocketManager()
        global_sock = _make_ws()
        scoped_sock = _make_ws()
        await manager.connect(global_sock)
        await manager.connect(scoped_sock, experiment_ids=["exp-1"])

        snapshot = manager.stats()

        assert snapshot["total_connections"] == 2
        assert snapshot["global_subscribers"] == 1
        assert snapshot["experiment_subscriptions"]["exp-1"] == 1

    @pytest.mark.asyncio
    async def test_experiment_ids_property(self):
        manager = WebSocketManager()
        sock = _make_ws()
        await manager.connect(sock, experiment_ids=["exp-a", "exp-b"])

        assert set(manager.experiment_ids) == {"exp-a", "exp-b"}

    @pytest.mark.asyncio
    async def test_connection_count_property(self):
        manager = WebSocketManager()
        sock = _make_ws()
        assert manager.connection_count == 0

        await manager.connect(sock)
        assert manager.connection_count == 1

        manager.disconnect(sock)
        assert manager.connection_count == 0