# Change summary:
# - WebSocketManager in backend/websocket/manager.py with per-experiment and global subscriptions.
# - Redis pub/sub bridge (sync + async) that broadcasts events to the relevant WebSocket clients.
# - Deque-based replay buffers with since_ts/limit filtering for reconnection support.
# - Runtime subscribe/unsubscribe and a stats API.
# - Enhanced /ws endpoint in main.py with subscribe/unsubscribe/replay actions.
# - 35 tests in test_ws_manager.py, all passing.
# (Listing metadata: 484 lines, 15 KiB, Python.)
"""Tests for the WebSocket connection manager."""
|
|
|
|
import asyncio
|
|
import json
|
|
import time
|
|
from collections import deque
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from websocket.manager import ConnectionEntry, WebSocketManager
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _make_ws(accept_side_effect=None) -> AsyncMock:
    """Build an ``AsyncMock`` that stands in for a WebSocket connection.

    The mock exposes awaitable ``accept``, ``send_json`` and ``receive_json``
    methods, each its own independent ``AsyncMock``.

    Args:
        accept_side_effect: Optional ``side_effect`` for ``accept`` -- for
            example an exception instance, which makes awaiting ``accept``
            raise it.

    Returns:
        The configured mock WebSocket.
    """
    fake_socket = AsyncMock()
    fake_socket.accept = AsyncMock(side_effect=accept_side_effect)
    # Explicit AsyncMock children so each method can be asserted on separately.
    for method_name in ("send_json", "receive_json"):
        setattr(fake_socket, method_name, AsyncMock())
    return fake_socket
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Connection lifecycle
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestConnect:
    """Tests for ``WebSocketManager.connect`` acceptance and registration."""

    @pytest.mark.asyncio
    async def test_accept_and_register_global(self):
        # No experiment_ids -> the connection becomes a global subscriber.
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)

        assert mgr.connection_count == 1
        assert id(ws) in mgr._global_subs
        ws.accept.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_accept_with_experiment_ids(self):
        # Explicit experiment_ids -> per-experiment subscriptions only, not global.
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws, experiment_ids=["exp-1", "exp-2"])

        assert mgr.connection_count == 1
        assert id(ws) not in mgr._global_subs
        assert id(ws) in mgr._experiment_subs["exp-1"]
        assert id(ws) in mgr._experiment_subs["exp-2"]

    @pytest.mark.asyncio
    async def test_multiple_connections(self):
        """Several connections are tracked independently, each with its own subscriptions."""
        mgr = WebSocketManager()
        ws1 = _make_ws()
        ws2 = _make_ws()
        await mgr.connect(ws1)
        await mgr.connect(ws2, experiment_ids=["exp-1"])

        assert mgr.connection_count == 2
        # ws1 registered globally, ws2 scoped to exp-1 -- the same routing the
        # two single-connection tests above assert, now checked side by side.
        assert id(ws1) in mgr._global_subs
        assert id(ws1) not in mgr._experiment_subs["exp-1"]
        assert id(ws2) not in mgr._global_subs
        assert id(ws2) in mgr._experiment_subs["exp-1"]
        ws1.accept.assert_awaited_once()
        ws2.accept.assert_awaited_once()
|
|
|
|
|
|
class TestDisconnect:
    """``WebSocketManager.disconnect`` drops the connection and prunes its subscriptions."""

    @pytest.mark.asyncio
    async def test_disconnect_global(self):
        manager = WebSocketManager()
        client = _make_ws()
        await manager.connect(client)
        manager.disconnect(client)

        assert manager.connection_count == 0
        assert id(client) not in manager._global_subs

    @pytest.mark.asyncio
    async def test_disconnect_experiment(self):
        manager = WebSocketManager()
        client = _make_ws()
        await manager.connect(client, experiment_ids=["exp-1"])
        manager.disconnect(client)

        assert manager.connection_count == 0
        # The only subscriber left, so the experiment key itself must be gone.
        assert "exp-1" not in manager._experiment_subs

    @pytest.mark.asyncio
    async def test_disconnect_unknown_ws_is_noop(self):
        manager = WebSocketManager()
        stranger = _make_ws()
        # Never connected: disconnect has to tolerate it without raising.
        manager.disconnect(stranger)
        assert manager.connection_count == 0

    @pytest.mark.asyncio
    async def test_disconnect_cleans_up_empty_experiment_sets(self):
        manager = WebSocketManager()
        first, second = _make_ws(), _make_ws()
        for client in (first, second):
            await manager.connect(client, experiment_ids=["exp-1"])

        manager.disconnect(first)
        # One subscriber remains, so the experiment's set must survive.
        assert "exp-1" in manager._experiment_subs
        assert len(manager._experiment_subs["exp-1"]) == 1

        manager.disconnect(second)
        # Now empty -> the key is removed rather than left as an empty set.
        assert "exp-1" not in manager._experiment_subs
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Subscribe / unsubscribe at runtime
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSubscriptions:
    """Runtime ``subscribe`` / ``unsubscribe`` on connected and unknown sockets."""

    @pytest.mark.asyncio
    async def test_subscribe_adds_experiment(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)
        mgr.subscribe(ws, "exp-42")

        assert id(ws) in mgr._experiment_subs["exp-42"]

    @pytest.mark.asyncio
    async def test_unsubscribe_removes_experiment(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws, experiment_ids=["exp-1"])
        mgr.unsubscribe(ws, "exp-1")

        # Both the reverse index and the connection's own record are updated.
        assert "exp-1" not in mgr._experiment_subs
        entry = mgr._connections[id(ws)]
        assert "exp-1" not in entry.experiment_ids

    @pytest.mark.asyncio
    async def test_subscribe_unknown_ws_is_noop(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        mgr.subscribe(ws, "exp-1")  # Not connected

        assert "exp-1" not in mgr._experiment_subs
        # Subscribing must not implicitly register the socket as a connection.
        assert mgr.connection_count == 0

    @pytest.mark.asyncio
    async def test_unsubscribe_unknown_ws_is_noop(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        mgr.unsubscribe(ws, "exp-1")  # Should not raise

        # Previously this test asserted nothing; pin that no state was created.
        assert mgr.connection_count == 0
        assert "exp-1" not in mgr._experiment_subs
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Broadcasting
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestBroadcast:
    """Routing of ``broadcast_global``, ``broadcast_to_experiment`` and ``broadcast``."""

    @staticmethod
    async def _manager_with_scoped_and_global():
        """Return ``(manager, exp-1 subscriber, global subscriber)``, connected in that order."""
        manager = WebSocketManager()
        scoped = _make_ws()
        unscoped = _make_ws()
        await manager.connect(scoped, experiment_ids=["exp-1"])
        await manager.connect(unscoped)
        return manager, scoped, unscoped

    @pytest.mark.asyncio
    async def test_broadcast_global(self):
        manager = WebSocketManager()
        client = _make_ws()
        await manager.connect(client)

        await manager.broadcast_global({"type": "test"})
        client.send_json.assert_awaited_once_with({"type": "test"})

    @pytest.mark.asyncio
    async def test_broadcast_to_experiment(self):
        manager, scoped, unscoped = await self._manager_with_scoped_and_global()

        await manager.broadcast_to_experiment("exp-1", {"type": "test"})
        scoped.send_json.assert_awaited_once()
        # Global subscribers are not part of an experiment-only fan-out.
        unscoped.send_json.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_broadcast_routes_by_experiment_id(self):
        manager, scoped, unscoped = await self._manager_with_scoped_and_global()

        await manager.broadcast({"type": "run.completed", "experiment_id": "exp-1"})

        # Experiment subscriber: delivered via the per-experiment fan-out.
        assert scoped.send_json.await_count == 1
        # Global subscriber: delivered via the global fan-out.
        assert unscoped.send_json.await_count == 1

    @pytest.mark.asyncio
    async def test_broadcast_without_experiment_id_goes_global_only(self):
        manager, scoped, unscoped = await self._manager_with_scoped_and_global()

        await manager.broadcast({"type": "system.status"})
        scoped.send_json.assert_not_awaited()
        unscoped.send_json.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_broadcast_removes_dead_connections(self):
        manager = WebSocketManager()
        healthy = _make_ws()
        broken = _make_ws()
        # Sending to this socket always fails, as if it were already closed.
        broken.send_json = AsyncMock(side_effect=RuntimeError("closed"))
        await manager.connect(healthy)
        await manager.connect(broken)

        await manager.broadcast_global({"type": "test"})

        assert manager.connection_count == 1
        assert id(broken) not in manager._connections

    @pytest.mark.asyncio
    async def test_broadcast_to_multiple_experiment_subscribers(self):
        manager = WebSocketManager()
        subscribers = (_make_ws(), _make_ws())
        for client in subscribers:
            await manager.connect(client, experiment_ids=["exp-1"])

        await manager.broadcast_to_experiment("exp-1", {"type": "test"})
        for client in subscribers:
            client.send_json.assert_awaited_once()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Replay
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestReplay:
    """Tests for ``WebSocketManager.replay`` over the replay buffers.

    ``test_replay_global_events`` establishes that the count returned by
    ``replay`` equals the number of ``send_json`` awaits (after the mock is
    reset), so the other tests also check the actual sends, not just the
    returned number.
    """

    @pytest.mark.asyncio
    async def test_replay_global_events(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)

        # Store some events
        await mgr.broadcast_global({"type": "event1"})
        await mgr.broadcast_global({"type": "event2"})

        # Reset send tracking so only replayed sends are counted.
        ws.send_json.reset_mock()

        count = await mgr.replay(ws)
        assert count == 2
        assert ws.send_json.await_count == 2

    @pytest.mark.asyncio
    async def test_replay_experiment_events(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws, experiment_ids=["exp-1"])

        await mgr.broadcast_to_experiment("exp-1", {"type": "e1"})
        await mgr.broadcast_to_experiment("exp-1", {"type": "e2"})
        await mgr.broadcast_to_experiment("exp-1", {"type": "e3"})

        ws.send_json.reset_mock()

        count = await mgr.replay(ws, experiment_id="exp-1")
        assert count == 3
        assert ws.send_json.await_count == 3

    @pytest.mark.asyncio
    async def test_replay_with_since_ts(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)

        # Manually insert events with known timestamps
        ts_old = time.time() - 100
        ts_new = time.time()
        mgr._global_replay.append({"type": "old", "_ts": ts_old})
        mgr._global_replay.append({"type": "new", "_ts": ts_new})

        # Cut-off sits between the two events, so only "new" qualifies.
        count = await mgr.replay(ws, since_ts=ts_old + 1)
        assert count == 1
        assert ws.send_json.await_count == 1

    @pytest.mark.asyncio
    async def test_replay_with_limit(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)

        for i in range(10):
            mgr._global_replay.append({"type": f"event{i}", "_ts": time.time()})

        count = await mgr.replay(ws, limit=3)
        assert count == 3
        assert ws.send_json.await_count == 3

    @pytest.mark.asyncio
    async def test_replay_empty_buffer(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)

        count = await mgr.replay(ws)
        assert count == 0
        ws.send_json.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_replay_respects_buffer_size(self):
        mgr = WebSocketManager(replay_size=5)
        ws = _make_ws()
        await mgr.connect(ws)

        for i in range(10):
            await mgr.broadcast_global({"type": f"event{i}"})

        ws.send_json.reset_mock()
        # Only the last replay_size events are retained for replay.
        count = await mgr.replay(ws)
        assert count == 5
        assert ws.send_json.await_count == 5

    @pytest.mark.asyncio
    async def test_replay_stops_on_send_failure(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        call_count = 0

        async def fail_after_one(msg):
            nonlocal call_count
            call_count += 1
            if call_count > 1:
                raise RuntimeError("closed")

        ws.send_json = AsyncMock(side_effect=fail_after_one)
        await mgr.connect(ws)

        for i in range(5):
            mgr._global_replay.append({"type": f"event{i}", "_ts": time.time()})

        count = await mgr.replay(ws)
        assert count == 1
        # fail_after_one raises on *every* call after the first, so count == 1
        # holds whether or not replay keeps going. Only the number of send
        # attempts shows that it stopped at the first failure.
        assert call_count == 2
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Redis pub/sub bridge
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestRedisListener:
    """Tests for the Redis pub/sub bridge and its listener-task lifecycle."""

    @pytest.mark.asyncio
    async def test_handle_redis_message_json(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)

        msg = json.dumps({"type": "run.completed", "experiment_id": "exp-1"})
        await mgr._handle_redis_message(msg)

        # Message should have been broadcast
        assert ws.send_json.await_count >= 1

    @pytest.mark.asyncio
    async def test_handle_redis_message_bytes(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)

        # Raw bytes payloads are accepted as well as str.
        msg = json.dumps({"type": "test"}).encode("utf-8")
        await mgr._handle_redis_message(msg)
        assert ws.send_json.await_count >= 1

    @pytest.mark.asyncio
    async def test_handle_redis_message_invalid_json(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)

        # Should not raise, just log warning
        await mgr._handle_redis_message("not-json{{{")
        ws.send_json.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_start_redis_listener_none_is_noop(self):
        mgr = WebSocketManager()
        await mgr.start_redis_listener(None)
        assert mgr._redis_listener_task is None

    @pytest.mark.asyncio
    async def test_stop_redis_listener_when_none(self):
        mgr = WebSocketManager()
        await mgr.stop_redis_listener()  # Should not raise

    @pytest.mark.asyncio
    async def test_stop_cancels_task(self):
        mgr = WebSocketManager()

        # Create a long-running dummy task
        async def dummy():
            await asyncio.sleep(999)

        task = asyncio.create_task(dummy())
        mgr._redis_listener_task = task
        await mgr.stop_redis_listener()
        # Yield once so that a cancel() requested (but not awaited) by
        # stop_redis_listener() finishes propagating into the task.
        await asyncio.sleep(0)

        assert mgr._redis_listener_task is None
        # Clearing the attribute alone would leak a still-running task.
        assert task.cancelled()

    @pytest.mark.asyncio
    async def test_sync_redis_listener_processes_messages(self):
        """Test the sync Redis pubsub path with a mock."""
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)

        # Mock a sync Redis client
        mock_pubsub = MagicMock()
        call_count = 0

        def get_message_side_effect(ignore_subscribe_messages=True, timeout=1.0):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return {
                    "type": "message",
                    "data": json.dumps({"type": "test_event"}),
                }
            # After first message, raise to stop the loop
            raise asyncio.CancelledError()

        mock_pubsub.get_message = MagicMock(side_effect=get_message_side_effect)
        mock_pubsub.subscribe = MagicMock()
        mock_pubsub.unsubscribe = MagicMock()

        mock_redis = MagicMock()
        mock_redis.pubsub.return_value = mock_pubsub

        # Run the sync listener — it will process one message then get cancelled
        with pytest.raises(asyncio.CancelledError):
            await mgr._redis_listen_sync(mock_redis, "promptlooper:events")

        mock_pubsub.subscribe.assert_called_once_with("promptlooper:events")
        assert ws.send_json.await_count >= 1
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Stats
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestStats:
    """Introspection helpers: ``stats()``, ``experiment_ids`` and ``connection_count``."""

    @pytest.mark.asyncio
    async def test_stats_empty(self):
        snapshot = WebSocketManager().stats()
        assert snapshot["total_connections"] == 0
        assert snapshot["global_subscribers"] == 0
        assert snapshot["experiment_subscriptions"] == {}

    @pytest.mark.asyncio
    async def test_stats_with_connections(self):
        manager = WebSocketManager()
        global_client, scoped_client = _make_ws(), _make_ws()
        await manager.connect(global_client)
        await manager.connect(scoped_client, experiment_ids=["exp-1"])

        snapshot = manager.stats()
        assert snapshot["total_connections"] == 2
        assert snapshot["global_subscribers"] == 1
        assert snapshot["experiment_subscriptions"]["exp-1"] == 1

    @pytest.mark.asyncio
    async def test_experiment_ids_property(self):
        manager = WebSocketManager()
        client = _make_ws()
        await manager.connect(client, experiment_ids=["exp-a", "exp-b"])

        # Compare as a set: the property's ordering is not part of the contract here.
        assert set(manager.experiment_ids) == {"exp-a", "exp-b"}

    @pytest.mark.asyncio
    async def test_connection_count_property(self):
        manager = WebSocketManager()
        assert manager.connection_count == 0

        client = _make_ws()
        await manager.connect(client)
        assert manager.connection_count == 1

        manager.disconnect(client)
        assert manager.connection_count == 0
|