chrysopedia/backend/tests/test_chat.py
jlightner a9589bfc93 test: Built ChatService with retrieve-prompt-stream pipeline, POST /api…
- "backend/chat_service.py"
- "backend/routers/chat.py"
- "backend/main.py"
- "backend/tests/test_chat.py"

GSD-Task: S03/T01
2026-04-04 05:19:44 +00:00

300 lines
10 KiB
Python

"""Integration tests for the chat SSE endpoint.
Mocks SearchService.search() and the OpenAI streaming response to verify:
1. Valid SSE format with sources, token, and done events
2. Citation numbering matches the sources array
3. Creator param forwarded to search
4. Empty/invalid query returns 422
5. LLM error produces an SSE error event
These tests use a standalone ASGI client that does NOT require a running
PostgreSQL instance — the DB session dependency is overridden with a mock.
"""
from __future__ import annotations
import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
# Ensure backend/ is on sys.path
import pathlib
import sys
sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent.parent))
from database import get_session # noqa: E402
from main import app # noqa: E402
# ── Standalone test client (no DB required) ──────────────────────────────────
@pytest_asyncio.fixture()
async def chat_client():
    """Async HTTP test client that mocks out the DB session entirely.

    Overrides the ``get_session`` dependency with an ``AsyncMock`` so no
    PostgreSQL instance is needed, yields an httpx client bound to the app
    via ASGITransport, and always removes the override afterwards.
    """
    mock_session = AsyncMock()

    async def _mock_get_session():
        yield mock_session

    app.dependency_overrides[get_session] = _mock_get_session
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://testserver") as ac:
            yield ac
    finally:
        # Without the finally, a failing test (exception thrown into this
        # generator at the yield) would leak the override into other tests.
        app.dependency_overrides.pop(get_session, None)
# ── Helpers ──────────────────────────────────────────────────────────────────
def _parse_sse(body: str) -> list[dict[str, Any]]:
"""Parse SSE text into a list of {event, data} dicts."""
events: list[dict[str, Any]] = []
current_event: str | None = None
current_data: str | None = None
for line in body.split("\n"):
if line.startswith("event: "):
current_event = line[len("event: "):]
elif line.startswith("data: "):
current_data = line[len("data: "):]
elif line == "" and current_event is not None and current_data is not None:
try:
parsed = json.loads(current_data)
except json.JSONDecodeError:
parsed = current_data
events.append({"event": current_event, "data": parsed})
current_event = None
current_data = None
return events
def _fake_search_result(
items: list[dict[str, Any]] | None = None,
cascade_tier: str = "global",
) -> dict[str, Any]:
"""Build a fake SearchService.search() return value."""
if items is None:
items = [
{
"title": "Snare Compression",
"slug": "snare-compression",
"technique_page_slug": "snare-compression",
"creator_name": "Keota",
"topic_category": "Mixing",
"summary": "How to compress a snare drum for punch and presence.",
"section_anchor": "",
"section_heading": "",
"type": "technique_page",
"score": 0.9,
},
{
"title": "Parallel Processing",
"slug": "parallel-processing",
"technique_page_slug": "parallel-processing",
"creator_name": "Skope",
"topic_category": "Mixing",
"summary": "Using parallel compression for dynamics control.",
"section_anchor": "bus-setup",
"section_heading": "Bus Setup",
"type": "technique_page",
"score": 0.85,
},
]
return {
"items": items,
"partial_matches": [],
"total": len(items),
"query": "snare compression",
"fallback_used": False,
"cascade_tier": cascade_tier,
}
def _mock_openai_stream(chunks: list[str]):
"""Create a mock async iterator that yields OpenAI-style stream chunks."""
class FakeChoice:
def __init__(self, text: str | None):
self.delta = MagicMock()
self.delta.content = text
class FakeChunk:
def __init__(self, text: str | None):
self.choices = [FakeChoice(text)]
class FakeStream:
def __init__(self, chunks: list[str]):
self._chunks = chunks
self._index = 0
def __aiter__(self):
return self
async def __anext__(self):
if self._index >= len(self._chunks):
raise StopAsyncIteration
chunk = FakeChunk(self._chunks[self._index])
self._index += 1
return chunk
return FakeStream(chunks)
def _mock_openai_stream_error():
"""Create a mock async iterator that raises mid-stream."""
class FakeStream:
def __init__(self):
self._yielded = False
def __aiter__(self):
return self
async def __anext__(self):
if not self._yielded:
self._yielded = True
raise RuntimeError("LLM connection lost")
raise StopAsyncIteration
return FakeStream()
# ── Tests ────────────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_chat_sse_format_and_events(chat_client):
    """SSE stream contains sources, token(s), and done events in order."""
    fake_openai = MagicMock()
    fake_openai.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(
            ["Snare compression ", "uses [1] to add ", "punch. See also [2]."]
        )
    )
    with (
        patch(
            "chat_service.SearchService.search",
            new_callable=AsyncMock,
            return_value=_fake_search_result(),
        ),
        patch("chat_service.openai.AsyncOpenAI", return_value=fake_openai),
    ):
        resp = await chat_client.post("/api/v1/chat", json={"query": "snare compression"})

    assert resp.status_code == 200
    assert "text/event-stream" in resp.headers.get("content-type", "")

    parsed = _parse_sse(resp.text)
    kinds = [evt["event"] for evt in parsed]
    # Sources open the stream, tokens appear somewhere, done closes it.
    assert kinds[0] == "sources"
    assert "token" in kinds
    assert kinds[-1] == "done"
    # The sources payload is the two-item list from the fake search result.
    opening_payload = parsed[0]["data"]
    assert isinstance(opening_payload, list)
    assert len(opening_payload) == 2
    # The terminal event reports which cascade tier answered.
    assert "cascade_tier" in parsed[-1]["data"]
@pytest.mark.asyncio
async def test_chat_citation_numbering(chat_client):
    """Citation numbers in sources array match 1-based indexing."""
    fake_openai = MagicMock()
    fake_openai.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["hello"])
    )
    with (
        patch(
            "chat_service.SearchService.search",
            new_callable=AsyncMock,
            return_value=_fake_search_result(),
        ),
        patch("chat_service.openai.AsyncOpenAI", return_value=fake_openai),
    ):
        resp = await chat_client.post("/api/v1/chat", json={"query": "compression"})

    cited = _parse_sse(resp.text)[0]["data"]
    # Numbers are 1-based and aligned with the fake search result order.
    expected = [(1, "Snare Compression"), (2, "Parallel Processing")]
    for source, (number, title) in zip(cited, expected):
        assert source["number"] == number
        assert source["title"] == title
    assert cited[1]["section_anchor"] == "bus-setup"
@pytest.mark.asyncio
async def test_chat_creator_forwarded_to_search(chat_client):
    """Creator parameter is passed through to SearchService.search()."""
    fake_openai = MagicMock()
    fake_openai.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["ok"])
    )
    with (
        patch(
            "chat_service.SearchService.search",
            new_callable=AsyncMock,
            return_value=_fake_search_result(),
        ) as search_mock,
        patch("chat_service.openai.AsyncOpenAI", return_value=fake_openai),
    ):
        resp = await chat_client.post(
            "/api/v1/chat",
            json={"query": "drum mixing", "creator": "keota"},
        )

    assert resp.status_code == 200
    search_mock.assert_called_once()
    # The creator filter must reach the search call as a keyword argument.
    assert search_mock.call_args.kwargs.get("creator") == "keota"
@pytest.mark.asyncio
async def test_chat_empty_query_returns_422(chat_client):
    """A blank query string must be rejected by request validation with 422."""
    response = await chat_client.post("/api/v1/chat", json={"query": ""})
    assert response.status_code == 422
@pytest.mark.asyncio
async def test_chat_missing_query_returns_422(chat_client):
    """Omitting the query field entirely must fail validation with 422."""
    response = await chat_client.post("/api/v1/chat", json={})
    assert response.status_code == 422
@pytest.mark.asyncio
async def test_chat_llm_error_produces_error_event(chat_client):
    """When the LLM raises mid-stream, an error SSE event is emitted."""
    fake_openai = MagicMock()
    fake_openai.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream_error()
    )
    with (
        patch(
            "chat_service.SearchService.search",
            new_callable=AsyncMock,
            return_value=_fake_search_result(),
        ),
        patch("chat_service.openai.AsyncOpenAI", return_value=fake_openai),
    ):
        resp = await chat_client.post("/api/v1/chat", json={"query": "test error"})

    assert resp.status_code == 200  # SSE streams always return 200
    parsed = _parse_sse(resp.text)
    seen = {evt["event"] for evt in parsed}
    assert "sources" in seen
    assert "error" in seen
    failure = next(evt for evt in parsed if evt["event"] == "error")
    assert "message" in failure["data"]