# - WebSocketManager in backend/websocket/manager.py with per-experiment and global subscriptions
# - Redis pub/sub bridge (sync + async) broadcasting events to relevant WebSocket clients
# - Deque-based replay buffers with since_ts/limit filtering for reconnection support
# - Runtime subscribe/unsubscribe and stats API
# - Enhanced /ws endpoint in main.py with subscribe/unsubscribe/replay actions
# - 35 tests in test_ws_manager.py, all passing
# 128 lines · 3.7 KiB · Python
"""Tests for backend/main.py — FastAPI application."""
|
|
|
|
import os
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def _isolate_settings(tmp_path, monkeypatch):
    """Run every test against a temp SQLite DB with Redis disabled.

    Sets ``DATABASE_URL``, ``REDIS_URL`` and ``DATA_DIR`` in the process
    environment for the duration of the test. It then builds a fresh
    ``config.Settings`` from that environment (ignoring any ``.env`` file)
    and installs it as both ``config.settings`` and ``main.settings``.
    Finally it re-runs ``main._init_db()`` / ``main._init_redis()`` and
    creates all tables from ``models.Base`` on ``main.engine``.

    ``monkeypatch`` restores the previous ``config.settings`` and
    ``main.settings`` objects on teardown. This keeps the per-test settings,
    which point at ``tmp_path``, from leaking into later tests or other
    test modules.
    NOTE(review): the engine/Redis state created by ``_init_db()`` and
    ``_init_redis()`` is not reset on teardown.
    """
    env = {
        "DATABASE_URL": f"sqlite:///{tmp_path / 'test.db'}",
        "REDIS_URL": "",
        "DATA_DIR": str(tmp_path),
    }
    with patch.dict(os.environ, env, clear=False):
        # Rebuild settings so they pick up the patched environment.
        import config

        new_settings = config.Settings(_env_file=None)
        monkeypatch.setattr(config, "settings", new_settings)

        # main keeps its own reference to the settings object; patch it too.
        import main

        monkeypatch.setattr(main, "settings", new_settings)
        main._init_db()
        main._init_redis()

        # Create tables on the freshly initialised engine.
        from models import Base

        Base.metadata.create_all(bind=main.engine)

        yield


@pytest.fixture
def client():
    """Provide a ``TestClient`` bound to the application object in ``main``."""
    import main

    test_client = TestClient(main.app)
    return test_client


class TestHealthEndpoint:
    """Checks for the ``/health`` endpoint."""

    def test_health_returns_ok(self, client):
        response = client.get("/health")
        assert response.status_code == 200

        payload = response.json()
        assert payload["status"] == "ok"
        assert payload["database"] is True
        # Redis is not configured in tests (REDIS_URL is empty); the endpoint
        # still reports the in-process fallback as healthy.
        assert payload["redis"] is True

    def test_health_response_schema(self, client):
        payload = client.get("/health").json()
        # The response carries exactly these three keys, nothing more.
        assert payload.keys() == {"status", "database", "redis"}


class TestCORSMiddleware:
    """Tests for the CORS middleware configuration."""

    def test_cors_headers_present(self, client):
        origin = "http://localhost:3000"
        resp = client.options(
            "/health",
            headers={
                "Origin": origin,
                "Access-Control-Request-Method": "GET",
            },
        )
        # Assuming Starlette's CORSMiddleware: an accepted preflight answers
        # 200, while a rejected one answers 400. Under a wildcard config the
        # 400 can still carry the allow-origin header, so presence alone
        # proves nothing.
        assert resp.status_code == 200
        # The header must actually permit this origin: either the wildcard or
        # the origin echoed back.
        assert resp.headers.get("access-control-allow-origin") in {"*", origin}


class TestWebSocket:
    """Tests for the ``/ws`` WebSocket endpoint."""

    def test_websocket_connect_and_echo(self, client):
        with client.websocket_connect("/ws") as ws:
            ws.send_json({"type": "ping"})
            reply = ws.receive_json()
            assert reply["type"] == "ack"
            assert reply["data"]["type"] == "ping"

    def test_websocket_disconnect_cleanup(self, client):
        from main import ws_manager

        initial_count = ws_manager.connection_count
        with client.websocket_connect("/ws") as ws:
            # websocket_connect() returns once the server has accepted, which
            # may be before the manager has recorded the connection. A full
            # ping/ack round trip first makes the count check deterministic.
            # NOTE(review): assumes the endpoint registers the socket with
            # ws_manager before it starts answering messages; confirm in main.py.
            ws.send_json({"type": "ping"})
            assert ws.receive_json()["type"] == "ack"
            assert ws_manager.connection_count == initial_count + 1
        # After disconnect, connection should be removed
        assert ws_manager.connection_count == initial_count


class TestRouterMounting:
    """Sanity checks that the app's routes and OpenAPI schema are wired up."""

    def test_openapi_schema_loads(self, client):
        response = client.get("/openapi.json")
        assert response.status_code == 200
        # The app title surfaces in the generated schema's info block.
        assert response.json()["info"]["title"] == "PromptLooper"

    def test_unknown_route_returns_404(self, client):
        assert client.get("/api/nonexistent").status_code == 404


class TestConnectionManager:
    """Unit tests for ``websocket.manager.WebSocketManager`` in isolation."""

    def test_broadcast_on_empty_manager(self):
        """WebSocketManager.broadcast_global on empty manager should not raise."""
        import asyncio

        from websocket.manager import WebSocketManager

        async def _broadcast_once():
            # Build the manager inside the running loop so any asyncio
            # primitives it creates are bound to the loop that uses them.
            manager = WebSocketManager()
            await manager.broadcast_global({"test": True})
            return manager

        # A private loop replaces asyncio.get_event_loop(), which is
        # deprecated when no loop is running. It is never installed as the
        # current loop, so other tests' event-loop state is left untouched,
        # and it is always closed.
        loop = asyncio.new_event_loop()
        try:
            manager = loop.run_until_complete(_broadcast_once())
        finally:
            loop.close()

        assert manager.connection_count == 0


class TestGetDb:
    """Tests for the ``main.get_db`` session dependency."""

    def test_get_db_yields_session(self):
        from main import get_db

        gen = get_db()
        session = next(gen)
        assert session is not None

        # Resume the generator so its post-yield cleanup runs, as at the end
        # of a normal request. Require it to stop after exactly one yield;
        # the old `except StopIteration: pass` silently accepted a second
        # yield, which would leave cleanup unexecuted.
        with pytest.raises(StopIteration):
            next(gen)


class TestGetRedis:
    """Tests for ``main.get_redis``."""

    def test_get_redis_returns_none_in_process_mode(self):
        import main

        # The autouse fixture sets REDIS_URL to "", so Redis is not
        # configured and no client should be exposed.
        assert main.get_redis() is None