"""Tests for the WebSocket connection manager."""

import asyncio
import json
import time
from collections import deque
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from websocket.manager import ConnectionEntry, WebSocketManager


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_ws(accept_side_effect=None) -> AsyncMock:
    """Create a mock WebSocket.

    ``accept``, ``send_json`` and ``receive_json`` are explicit ``AsyncMock``
    attributes so tests can await them and inspect their call history.

    Args:
        accept_side_effect: Passed through as the ``side_effect`` of
            ``ws.accept``; for example an exception instance makes the awaited
            ``accept()`` raise.
    """
    ws = AsyncMock()
    ws.accept = AsyncMock(side_effect=accept_side_effect)
    ws.send_json = AsyncMock()
    ws.receive_json = AsyncMock()
    return ws


# ---------------------------------------------------------------------------
# Connection lifecycle
# ---------------------------------------------------------------------------


class TestConnect:
    """Registration behaviour of ``WebSocketManager.connect``."""

    @pytest.mark.asyncio
    async def test_accept_and_register_global(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)
        assert mgr.connection_count == 1
        # Without experiment_ids the socket is a global subscriber.
        assert id(ws) in mgr._global_subs
        ws.accept.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_accept_with_experiment_ids(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws, experiment_ids=["exp-1", "exp-2"])
        assert mgr.connection_count == 1
        # Experiment-scoped sockets are NOT also registered globally.
        assert id(ws) not in mgr._global_subs
        assert id(ws) in mgr._experiment_subs["exp-1"]
        assert id(ws) in mgr._experiment_subs["exp-2"]

    @pytest.mark.asyncio
    async def test_multiple_connections(self):
        mgr = WebSocketManager()
        ws1 = _make_ws()
        ws2 = _make_ws()
        await mgr.connect(ws1)
        await mgr.connect(ws2, experiment_ids=["exp-1"])
        assert mgr.connection_count == 2


class TestDisconnect:
    """Cleanup behaviour of ``WebSocketManager.disconnect``."""

    @pytest.mark.asyncio
    async def test_disconnect_global(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)
        mgr.disconnect(ws)
        assert mgr.connection_count == 0
        assert id(ws) not in mgr._global_subs

    @pytest.mark.asyncio
    async def test_disconnect_experiment(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws, experiment_ids=["exp-1"])
        mgr.disconnect(ws)
        assert mgr.connection_count == 0
        assert "exp-1" not in mgr._experiment_subs

    @pytest.mark.asyncio
    async def test_disconnect_unknown_ws_is_noop(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        mgr.disconnect(ws)  # Should not raise
        assert mgr.connection_count == 0

    @pytest.mark.asyncio
    async def test_disconnect_cleans_up_empty_experiment_sets(self):
        mgr = WebSocketManager()
        ws1 = _make_ws()
        ws2 = _make_ws()
        await mgr.connect(ws1, experiment_ids=["exp-1"])
        await mgr.connect(ws2, experiment_ids=["exp-1"])

        # The experiment set survives while at least one subscriber remains...
        mgr.disconnect(ws1)
        assert "exp-1" in mgr._experiment_subs
        assert len(mgr._experiment_subs["exp-1"]) == 1

        # ...and is removed once the last subscriber leaves.
        mgr.disconnect(ws2)
        assert "exp-1" not in mgr._experiment_subs


# ---------------------------------------------------------------------------
# Subscribe / unsubscribe at runtime
# ---------------------------------------------------------------------------


class TestSubscriptions:
    """Runtime ``subscribe`` / ``unsubscribe`` on an existing connection."""

    @pytest.mark.asyncio
    async def test_subscribe_adds_experiment(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)
        mgr.subscribe(ws, "exp-42")
        assert id(ws) in mgr._experiment_subs["exp-42"]

    @pytest.mark.asyncio
    async def test_unsubscribe_removes_experiment(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws, experiment_ids=["exp-1"])
        mgr.unsubscribe(ws, "exp-1")
        assert "exp-1" not in mgr._experiment_subs
        # The per-connection record must be updated as well as the index.
        entry = mgr._connections[id(ws)]
        assert "exp-1" not in entry.experiment_ids

    @pytest.mark.asyncio
    async def test_subscribe_unknown_ws_is_noop(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        mgr.subscribe(ws, "exp-1")  # Not connected
        assert "exp-1" not in mgr._experiment_subs

    @pytest.mark.asyncio
    async def test_unsubscribe_unknown_ws_is_noop(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        mgr.unsubscribe(ws, "exp-1")  # Should not raise
        # A no-op must not leave behind any bookkeeping (e.g. an empty set).
        assert "exp-1" not in mgr._experiment_subs
        assert mgr.connection_count == 0
# ---------------------------------------------------------------------------
# Broadcasting
# ---------------------------------------------------------------------------


class TestBroadcast:
    """Message fan-out to global and experiment subscribers."""

    @pytest.mark.asyncio
    async def test_broadcast_global(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)
        await mgr.broadcast_global({"type": "test"})
        ws.send_json.assert_awaited_once_with({"type": "test"})

    @pytest.mark.asyncio
    async def test_broadcast_to_experiment(self):
        mgr = WebSocketManager()
        ws_exp = _make_ws()
        ws_global = _make_ws()
        await mgr.connect(ws_exp, experiment_ids=["exp-1"])
        await mgr.connect(ws_global)
        await mgr.broadcast_to_experiment("exp-1", {"type": "test"})
        ws_exp.send_json.assert_awaited_once()
        ws_global.send_json.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_broadcast_routes_by_experiment_id(self):
        mgr = WebSocketManager()
        ws_exp = _make_ws()
        ws_global = _make_ws()
        await mgr.connect(ws_exp, experiment_ids=["exp-1"])
        await mgr.connect(ws_global)
        msg = {"type": "run.completed", "experiment_id": "exp-1"}
        await mgr.broadcast(msg)
        # experiment subscriber gets the message (from broadcast_to_experiment)
        assert ws_exp.send_json.await_count == 1
        # global subscriber gets it (from broadcast_global)
        assert ws_global.send_json.await_count == 1

    @pytest.mark.asyncio
    async def test_broadcast_without_experiment_id_goes_global_only(self):
        mgr = WebSocketManager()
        ws_exp = _make_ws()
        ws_global = _make_ws()
        await mgr.connect(ws_exp, experiment_ids=["exp-1"])
        await mgr.connect(ws_global)
        await mgr.broadcast({"type": "system.status"})
        ws_exp.send_json.assert_not_awaited()
        ws_global.send_json.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_broadcast_removes_dead_connections(self):
        mgr = WebSocketManager()
        ws_good = _make_ws()
        ws_dead = _make_ws()
        # Simulate a socket whose send fails (e.g. already closed).
        ws_dead.send_json = AsyncMock(side_effect=RuntimeError("closed"))
        await mgr.connect(ws_good)
        await mgr.connect(ws_dead)
        await mgr.broadcast_global({"type": "test"})
        assert mgr.connection_count == 1
        assert id(ws_dead) not in mgr._connections
        # Pruning the dead socket must not prevent delivery to healthy ones.
        ws_good.send_json.assert_awaited_once_with({"type": "test"})

    @pytest.mark.asyncio
    async def test_broadcast_to_multiple_experiment_subscribers(self):
        mgr = WebSocketManager()
        ws1 = _make_ws()
        ws2 = _make_ws()
        await mgr.connect(ws1, experiment_ids=["exp-1"])
        await mgr.connect(ws2, experiment_ids=["exp-1"])
        await mgr.broadcast_to_experiment("exp-1", {"type": "test"})
        ws1.send_json.assert_awaited_once()
        ws2.send_json.assert_awaited_once()


# ---------------------------------------------------------------------------
# Replay
# ---------------------------------------------------------------------------


class TestReplay:
    """Re-sending buffered events to a (re)connecting client."""

    @pytest.mark.asyncio
    async def test_replay_global_events(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)

        # Store some events
        await mgr.broadcast_global({"type": "event1"})
        await mgr.broadcast_global({"type": "event2"})

        # Reset send tracking so only replayed sends are counted.
        ws.send_json.reset_mock()

        count = await mgr.replay(ws)
        assert count == 2
        assert ws.send_json.await_count == 2

    @pytest.mark.asyncio
    async def test_replay_experiment_events(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws, experiment_ids=["exp-1"])

        await mgr.broadcast_to_experiment("exp-1", {"type": "e1"})
        await mgr.broadcast_to_experiment("exp-1", {"type": "e2"})
        await mgr.broadcast_to_experiment("exp-1", {"type": "e3"})

        ws.send_json.reset_mock()
        count = await mgr.replay(ws, experiment_id="exp-1")
        assert count == 3

    @pytest.mark.asyncio
    async def test_replay_with_since_ts(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)

        # Manually insert events with known timestamps (read the clock once).
        now = time.time()
        ts_old = now - 100
        ts_new = now
        mgr._global_replay.append({"type": "old", "_ts": ts_old})
        mgr._global_replay.append({"type": "new", "_ts": ts_new})

        # Only the event newer than the cutoff is replayed.
        count = await mgr.replay(ws, since_ts=ts_old + 1)
        assert count == 1

    @pytest.mark.asyncio
    async def test_replay_with_limit(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)

        for i in range(10):
            mgr._global_replay.append({"type": f"event{i}", "_ts": time.time()})

        count = await mgr.replay(ws, limit=3)
        assert count == 3

    @pytest.mark.asyncio
    async def test_replay_empty_buffer(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)
        count = await mgr.replay(ws)
        assert count == 0

    @pytest.mark.asyncio
    async def test_replay_respects_buffer_size(self):
        mgr = WebSocketManager(replay_size=5)
        ws = _make_ws()
        await mgr.connect(ws)

        for i in range(10):
            await mgr.broadcast_global({"type": f"event{i}"})

        ws.send_json.reset_mock()
        # Only the last `replay_size` events are retained.
        count = await mgr.replay(ws)
        assert count == 5

    @pytest.mark.asyncio
    async def test_replay_stops_on_send_failure(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        call_count = 0

        async def fail_after_one(msg):
            # First send succeeds; every later send raises.
            nonlocal call_count
            call_count += 1
            if call_count > 1:
                raise RuntimeError("closed")

        ws.send_json = AsyncMock(side_effect=fail_after_one)
        await mgr.connect(ws)

        for i in range(5):
            mgr._global_replay.append({"type": f"event{i}", "_ts": time.time()})

        count = await mgr.replay(ws)
        assert count == 1


# ---------------------------------------------------------------------------
# Redis pub/sub bridge
# ---------------------------------------------------------------------------


class TestRedisListener:
    """Redis pub/sub message handling and listener task lifecycle."""

    @pytest.mark.asyncio
    async def test_handle_redis_message_json(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)

        msg = json.dumps({"type": "run.completed", "experiment_id": "exp-1"})
        await mgr._handle_redis_message(msg)

        # Message should have been broadcast
        assert ws.send_json.await_count >= 1

    @pytest.mark.asyncio
    async def test_handle_redis_message_bytes(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)

        msg = json.dumps({"type": "test"}).encode("utf-8")
        await mgr._handle_redis_message(msg)
        assert ws.send_json.await_count >= 1

    @pytest.mark.asyncio
    async def test_handle_redis_message_invalid_json(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)

        # Should not raise, just log warning
        await mgr._handle_redis_message("not-json{{{")
        ws.send_json.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_start_redis_listener_none_is_noop(self):
        mgr = WebSocketManager()
        await mgr.start_redis_listener(None)
        assert mgr._redis_listener_task is None

    @pytest.mark.asyncio
    async def test_stop_redis_listener_when_none(self):
        mgr = WebSocketManager()
        await mgr.stop_redis_listener()  # Should not raise

    @pytest.mark.asyncio
    async def test_stop_cancels_task(self):
        mgr = WebSocketManager()

        # A long-running stand-in for the real listener loop.
        async def dummy():
            await asyncio.sleep(999)

        task = asyncio.create_task(dummy())
        mgr._redis_listener_task = task
        try:
            await mgr.stop_redis_listener()
            assert mgr._redis_listener_task is None
            # Give the loop a bounded chance to finish the cancellation (in
            # case stop_redis_listener only requested it), then verify the
            # task was actually cancelled rather than merely forgotten.
            await asyncio.wait({task}, timeout=1.0)
            assert task.cancelled()
        finally:
            # No-op when already done; prevents leaking the sleeper on failure.
            task.cancel()

    @pytest.mark.asyncio
    async def test_sync_redis_listener_processes_messages(self):
        """Test the sync Redis pubsub path with a mock."""
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws)

        # Mock a sync Redis client
        mock_pubsub = MagicMock()
        call_count = 0

        def get_message_side_effect(ignore_subscribe_messages=True, timeout=1.0):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return {
                    "type": "message",
                    "data": json.dumps({"type": "test_event"}),
                }
            # After first message, raise to stop the loop
            raise asyncio.CancelledError()

        mock_pubsub.get_message = MagicMock(side_effect=get_message_side_effect)
        mock_pubsub.subscribe = MagicMock()
        mock_pubsub.unsubscribe = MagicMock()

        mock_redis = MagicMock()
        mock_redis.pubsub.return_value = mock_pubsub

        # Run the sync listener — it will process one message then get cancelled
        with pytest.raises(asyncio.CancelledError):
            await mgr._redis_listen_sync(mock_redis, "promptlooper:events")

        mock_pubsub.subscribe.assert_called_once_with("promptlooper:events")
        assert ws.send_json.await_count >= 1


# ---------------------------------------------------------------------------
# Stats
# ---------------------------------------------------------------------------


class TestStats:
    """Introspection helpers: ``stats()`` and the count/id properties."""

    @pytest.mark.asyncio
    async def test_stats_empty(self):
        mgr = WebSocketManager()
        stats = mgr.stats()
        assert stats["total_connections"] == 0
        assert stats["global_subscribers"] == 0
        assert stats["experiment_subscriptions"] == {}

    @pytest.mark.asyncio
    async def test_stats_with_connections(self):
        mgr = WebSocketManager()
        ws1 = _make_ws()
        ws2 = _make_ws()
        await mgr.connect(ws1)
        await mgr.connect(ws2, experiment_ids=["exp-1"])

        stats = mgr.stats()
        assert stats["total_connections"] == 2
        assert stats["global_subscribers"] == 1
        assert stats["experiment_subscriptions"]["exp-1"] == 1

    @pytest.mark.asyncio
    async def test_experiment_ids_property(self):
        mgr = WebSocketManager()
        ws = _make_ws()
        await mgr.connect(ws, experiment_ids=["exp-a", "exp-b"])
        ids = mgr.experiment_ids
        assert set(ids) == {"exp-a", "exp-b"}

    @pytest.mark.asyncio
    async def test_connection_count_property(self):
        mgr = WebSocketManager()
        assert mgr.connection_count == 0
        ws = _make_ws()
        await mgr.connect(ws)
        assert mgr.connection_count == 1
        mgr.disconnect(ws)
        assert mgr.connection_count == 0