# media-rip/backend/tests/test_sse.py
#
# 326 lines
# 11 KiB
# Python

"""Tests for the SSE event streaming endpoint and generator.
Covers: active-job DB filtering, init replay, live job_update events,
disconnect cleanup, keepalive ping, job_removed broadcasting, session
isolation, and the HTTP endpoint itself.
"""
from __future__ import annotations
import asyncio
import contextlib
import json
import uuid
from datetime import datetime, timezone
from unittest.mock import patch
from app.core.database import create_job, create_session, get_active_jobs_by_session
from app.models.job import Job, ProgressEvent
from app.routers.sse import event_generator
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_job(session_id: str, *, status: str = "queued", **overrides) -> Job:
    """Build a Job for *session_id* with sane defaults.

    Defaults: ``id`` is a fresh UUID4 string, ``url`` is
    ``"https://example.com/video"`` and ``created_at`` is the current UTC time
    in ISO-8601 form. Every keyword in *overrides* either replaces one of those
    defaults or is passed straight through to ``Job``, so no override is
    silently dropped.
    """
    fields = {
        "id": str(uuid.uuid4()),
        "url": "https://example.com/video",
        "created_at": datetime.now(timezone.utc).isoformat(),
    }
    # Caller-supplied values win over the defaults above.
    fields.update(overrides)
    return Job(session_id=session_id, status=status, **fields)
async def _collect_events(gen, *, count: int = 1, timeout: float = 5.0):
"""Consume *count* events from an async generator with a safety timeout."""
events = []
async for event in gen:
events.append(event)
if len(events) >= count:
break
return events
# ---------------------------------------------------------------------------
# Database query tests
# ---------------------------------------------------------------------------
class TestGetActiveJobsBySession:
    """get_active_jobs_by_session must leave out jobs in terminal states."""

    async def test_returns_only_non_terminal(self, db):
        session_id = str(uuid.uuid4())
        await create_session(db, session_id)

        # One job per status; insertion order matches the creation order.
        jobs_by_status = {
            state: _make_job(session_id, status=state)
            for state in ("queued", "downloading", "completed", "failed")
        }
        for job in jobs_by_status.values():
            await create_job(db, job)

        returned = await get_active_jobs_by_session(db, session_id)
        returned_ids = {job.id for job in returned}

        for state in ("queued", "downloading"):
            assert jobs_by_status[state].id in returned_ids
        for state in ("completed", "failed"):
            assert jobs_by_status[state].id not in returned_ids

    async def test_empty_when_all_terminal(self, db):
        session_id = str(uuid.uuid4())
        await create_session(db, session_id)

        terminal_states = ("completed", "failed", "expired")
        for state in terminal_states:
            await create_job(db, _make_job(session_id, status=state))

        assert await get_active_jobs_by_session(db, session_id) == []
# ---------------------------------------------------------------------------
# Generator-level tests (direct, no HTTP)
# ---------------------------------------------------------------------------
class TestEventGeneratorInit:
    """Init event replays current non-terminal jobs."""

    async def test_init_event_with_jobs(self, db, broker):
        sid = str(uuid.uuid4())
        await create_session(db, sid)
        job = _make_job(sid, status="queued")
        await create_job(db, job)

        # aclosing() runs the generator's finally block (broker unsubscribe)
        # even when an assertion below fails, so no subscriber leaks.
        async with contextlib.aclosing(event_generator(sid, broker, db)) as gen:
            events = await _collect_events(gen, count=1)
            assert len(events) == 1
            assert events[0]["event"] == "init"
            payload = json.loads(events[0]["data"])
            assert len(payload["jobs"]) == 1
            assert payload["jobs"][0]["id"] == job.id

    async def test_init_event_empty_session(self, db, broker):
        sid = str(uuid.uuid4())
        await create_session(db, sid)

        async with contextlib.aclosing(event_generator(sid, broker, db)) as gen:
            events = await _collect_events(gen, count=1)
            assert len(events) == 1
            assert events[0]["event"] == "init"
            payload = json.loads(events[0]["data"])
            assert payload["jobs"] == []
class TestEventGeneratorLiveStream:
    """Live job_update and dict events arrive correctly."""

    async def test_progress_event_delivery(self, db, broker):
        sid = str(uuid.uuid4())
        await create_session(db, sid)

        # aclosing() guarantees generator cleanup even if an assertion fails.
        async with contextlib.aclosing(event_generator(sid, broker, db)) as gen:
            # Consume init so the next event is the live one.
            await _collect_events(gen, count=1)

            progress = ProgressEvent(
                job_id="job-1", status="downloading", percent=42.0,
            )
            # Use _publish_sync since we're on the event loop already.
            broker._publish_sync(sid, progress)

            events = await _collect_events(gen, count=1)
            assert len(events) == 1
            assert events[0]["event"] == "job_update"
            data = json.loads(events[0]["data"])
            assert data["job_id"] == "job-1"
            assert data["percent"] == 42.0

    async def test_dict_event_delivery(self, db, broker):
        sid = str(uuid.uuid4())
        await create_session(db, sid)

        async with contextlib.aclosing(event_generator(sid, broker, db)) as gen:
            await _collect_events(gen, count=1)  # init

            # A pre-shaped {"event", "data"} dict is forwarded under its own
            # event name, with the data JSON-encoded.
            broker._publish_sync(
                sid, {"event": "job_removed", "data": {"job_id": "abc"}},
            )

            events = await _collect_events(gen, count=1)
            assert len(events) == 1
            assert events[0]["event"] == "job_removed"
            data = json.loads(events[0]["data"])
            assert data["job_id"] == "abc"
class TestEventGeneratorDisconnect:
    """Closing the SSE generator must drop its broker subscription."""

    async def test_unsubscribe_on_close(self, db, broker):
        session_id = str(uuid.uuid4())
        await create_session(db, session_id)

        stream = event_generator(session_id, broker, db)
        # Pull the init event; by then exactly one subscriber is registered.
        await _collect_events(stream, count=1)

        subscribers = broker._subscribers
        assert session_id in subscribers
        assert len(subscribers[session_id]) == 1

        # aclose() drives the generator's finally block, which unsubscribes.
        await stream.aclose()
        assert session_id not in broker._subscribers
class TestEventGeneratorKeepalive:
    """Verify that a ping event is sent after the keepalive timeout."""

    async def test_ping_after_timeout(self, db, broker):
        sid = str(uuid.uuid4())
        await create_session(db, sid)

        # Patch the timeout to a very short value for test speed.
        with patch("app.routers.sse.KEEPALIVE_TIMEOUT", 0.1):
            # aclosing() closes the generator (still inside the patch) even
            # if an assertion fails.
            async with contextlib.aclosing(event_generator(sid, broker, db)) as gen:
                await _collect_events(gen, count=1)  # init

                # Nothing is published, so the next event should be a ping.
                events = await _collect_events(gen, count=1)
                assert len(events) == 1
                assert events[0]["event"] == "ping"
                assert events[0]["data"] == ""
class TestSessionIsolation:
    """Jobs for one session don't leak into another session's init."""

    async def test_init_only_contains_own_session(self, db, broker):
        sid_a = str(uuid.uuid4())
        sid_b = str(uuid.uuid4())
        await create_session(db, sid_a)
        await create_session(db, sid_b)

        job_a = _make_job(sid_a, status="queued")
        job_b = _make_job(sid_b, status="downloading")
        await create_job(db, job_a)
        await create_job(db, job_b)

        # Connect as session A. aclosing() ensures the generator's finally
        # block (broker unsubscribe) runs even if an assertion fails.
        async with contextlib.aclosing(event_generator(sid_a, broker, db)) as gen:
            events = await _collect_events(gen, count=1)
            assert len(events) == 1
            assert events[0]["event"] == "init"
            payload = json.loads(events[0]["data"])
            job_ids = [j["id"] for j in payload["jobs"]]
            assert job_a.id in job_ids
            assert job_b.id not in job_ids
# ---------------------------------------------------------------------------
# HTTP-level integration test
# ---------------------------------------------------------------------------
class TestSSEEndpointHTTP:
    """Integration test hitting the real HTTP endpoint via httpx."""

    async def test_sse_endpoint_returns_init(self, client):
        """GET /api/events returns 200 with text/event-stream and an init event.

        httpx's ``ASGITransport`` calls ``await app(scope, receive, send)`` and
        waits for the *entire* response body — so an infinite SSE stream hangs
        it forever. We bypass the transport and invoke the ASGI app directly
        with custom ``receive``/``send`` callables. Once the body contains
        ``"jobs"`` (i.e. the init event has been sent) we set a disconnect
        event; ``EventSourceResponse``'s ``_listen_for_disconnect`` task picks
        that up, cancels the task group, and returns normally.
        """
        # Access the underlying ASGI app wired by the client fixture.
        # NOTE(review): ``_transport`` is a private httpx attribute and may
        # change between httpx versions.
        test_app = client._transport.app
        # Captured by ``send`` and asserted on once the app returns.
        received_status: int | None = None
        received_content_type: str | None = None
        received_body = b""
        # Set by ``send`` once the init payload arrives; this releases ``receive``.
        disconnected = asyncio.Event()

        async def receive() -> dict:
            # Block until the test decides to hang up, then report a client
            # disconnect. No ``http.request`` body message is ever delivered.
            await disconnected.wait()
            return {"type": "http.disconnect"}

        async def send(message: dict) -> None:
            # Record the response start (status and content-type header) and
            # accumulate every body chunk the app streams out.
            nonlocal received_status, received_content_type, received_body
            if message["type"] == "http.response.start":
                received_status = message["status"]
                # ASGI header names/values are bytes, hence the bytes compare.
                for k, v in message.get("headers", []):
                    if k == b"content-type":
                        received_content_type = v.decode()
            elif message["type"] == "http.response.body":
                received_body += message.get("body", b"")
                # Signal disconnect as soon as the init event payload arrives.
                if b'"jobs"' in received_body:
                    disconnected.set()

        # Minimal ASGI HTTP scope for GET /api/events with no request headers.
        # NOTE(review): no cookie is sent, so this assumes the endpoint serves a
        # request that has no existing session — confirm.
        scope = {
            "type": "http",
            "asgi": {"version": "3.0"},
            "http_version": "1.1",
            "method": "GET",
            "headers": [],
            "scheme": "http",
            "path": "/api/events",
            "raw_path": b"/api/events",
            "query_string": b"",
            "server": ("testserver", 80),
            "client": ("127.0.0.1", 1234),
            "root_path": "",
        }
        # Safety timeout in case disconnect signalling doesn't terminate the app.
        # The TimeoutError is suppressed, so the assertions below still judge
        # whatever was captured before the deadline.
        with contextlib.suppress(TimeoutError):
            async with asyncio.timeout(5.0):
                await test_app(scope, receive, send)
        assert received_status == 200
        assert received_content_type is not None
        assert "text/event-stream" in received_content_type
        assert b'"jobs"' in received_body
class TestJobRemovedViaDELETE:
    """job_removed broadcasting as performed by DELETE /api/downloads/{id}.

    NOTE: this does not call the HTTP endpoint. It publishes the message the
    DELETE handler sends and checks that a subscriber for the session receives
    it unchanged.
    """

    async def test_delete_publishes_job_removed(self, db, broker):
        """Create a job, subscribe, publish job_removed, verify it arrives."""
        sid = str(uuid.uuid4())
        await create_session(db, sid)
        job = _make_job(sid, status="queued")
        await create_job(db, job)

        # Subscribe to the broker for this session.
        queue = broker.subscribe(sid)
        try:
            # Simulate what the DELETE handler does: publish job_removed.
            # _publish_sync is used because we're already on the event loop.
            broker._publish_sync(
                sid,
                {"event": "job_removed", "data": {"job_id": job.id}},
            )
            event = queue.get_nowait()
            assert event["event"] == "job_removed"
            assert event["data"]["job_id"] == job.id
        finally:
            # Always drop the subscription so a failed assertion doesn't leak it.
            broker.unsubscribe(sid, queue)