- "backend/chat_service.py" - "backend/routers/chat.py" - "backend/main.py" - "backend/tests/test_chat.py" GSD-Task: S03/T01
300 lines
10 KiB
Python
300 lines
10 KiB
Python
"""Integration tests for the chat SSE endpoint.

Mocks SearchService.search() and the OpenAI streaming response to verify:

1. Valid SSE format with sources, token, and done events
2. Citation numbering matches the sources array
3. Creator param forwarded to search
4. Empty/invalid query returns 422
5. LLM error produces an SSE error event

These tests use a standalone ASGI client that does NOT require a running
PostgreSQL instance — the DB session dependency is overridden with a mock.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from httpx import ASGITransport, AsyncClient
|
|
|
|
# Ensure backend/ is on sys.path
|
|
import pathlib
|
|
import sys
|
|
sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent.parent))
|
|
|
|
from database import get_session # noqa: E402
|
|
from main import app # noqa: E402
|
|
|
|
|
|
# ── Standalone test client (no DB required) ──────────────────────────────────
|
|
|
|
@pytest_asyncio.fixture()
async def chat_client():
    """Yield an async HTTP client bound to ``app`` with the DB session mocked.

    The ``get_session`` dependency is overridden with an async generator that
    yields a single ``AsyncMock``, so requests made through the client never
    need a real database session.

    Yields:
        An ``AsyncClient`` using ``ASGITransport(app=app)`` and the base URL
        ``http://testserver``.
    """
    mock_session = AsyncMock()

    async def _mock_get_session():
        yield mock_session

    app.dependency_overrides[get_session] = _mock_get_session
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://testserver") as ac:
            yield ac
    finally:
        # Always drop the override, even if client setup or shutdown raises,
        # so the mocked session cannot leak into tests that run afterwards.
        app.dependency_overrides.pop(get_session, None)
|
|
|
|
|
|
# ── Helpers ──────────────────────────────────────────────────────────────────
|
|
|
|
|
|
def _parse_sse(body: str) -> list[dict[str, Any]]:
|
|
"""Parse SSE text into a list of {event, data} dicts."""
|
|
events: list[dict[str, Any]] = []
|
|
current_event: str | None = None
|
|
current_data: str | None = None
|
|
|
|
for line in body.split("\n"):
|
|
if line.startswith("event: "):
|
|
current_event = line[len("event: "):]
|
|
elif line.startswith("data: "):
|
|
current_data = line[len("data: "):]
|
|
elif line == "" and current_event is not None and current_data is not None:
|
|
try:
|
|
parsed = json.loads(current_data)
|
|
except json.JSONDecodeError:
|
|
parsed = current_data
|
|
events.append({"event": current_event, "data": parsed})
|
|
current_event = None
|
|
current_data = None
|
|
|
|
return events
|
|
|
|
|
|
def _fake_search_result(
|
|
items: list[dict[str, Any]] | None = None,
|
|
cascade_tier: str = "global",
|
|
) -> dict[str, Any]:
|
|
"""Build a fake SearchService.search() return value."""
|
|
if items is None:
|
|
items = [
|
|
{
|
|
"title": "Snare Compression",
|
|
"slug": "snare-compression",
|
|
"technique_page_slug": "snare-compression",
|
|
"creator_name": "Keota",
|
|
"topic_category": "Mixing",
|
|
"summary": "How to compress a snare drum for punch and presence.",
|
|
"section_anchor": "",
|
|
"section_heading": "",
|
|
"type": "technique_page",
|
|
"score": 0.9,
|
|
},
|
|
{
|
|
"title": "Parallel Processing",
|
|
"slug": "parallel-processing",
|
|
"technique_page_slug": "parallel-processing",
|
|
"creator_name": "Skope",
|
|
"topic_category": "Mixing",
|
|
"summary": "Using parallel compression for dynamics control.",
|
|
"section_anchor": "bus-setup",
|
|
"section_heading": "Bus Setup",
|
|
"type": "technique_page",
|
|
"score": 0.85,
|
|
},
|
|
]
|
|
return {
|
|
"items": items,
|
|
"partial_matches": [],
|
|
"total": len(items),
|
|
"query": "snare compression",
|
|
"fallback_used": False,
|
|
"cascade_tier": cascade_tier,
|
|
}
|
|
|
|
|
|
def _mock_openai_stream(chunks: list[str]):
|
|
"""Create a mock async iterator that yields OpenAI-style stream chunks."""
|
|
|
|
class FakeChoice:
|
|
def __init__(self, text: str | None):
|
|
self.delta = MagicMock()
|
|
self.delta.content = text
|
|
|
|
class FakeChunk:
|
|
def __init__(self, text: str | None):
|
|
self.choices = [FakeChoice(text)]
|
|
|
|
class FakeStream:
|
|
def __init__(self, chunks: list[str]):
|
|
self._chunks = chunks
|
|
self._index = 0
|
|
|
|
def __aiter__(self):
|
|
return self
|
|
|
|
async def __anext__(self):
|
|
if self._index >= len(self._chunks):
|
|
raise StopAsyncIteration
|
|
chunk = FakeChunk(self._chunks[self._index])
|
|
self._index += 1
|
|
return chunk
|
|
|
|
return FakeStream(chunks)
|
|
|
|
|
|
def _mock_openai_stream_error():
|
|
"""Create a mock async iterator that raises mid-stream."""
|
|
|
|
class FakeStream:
|
|
def __init__(self):
|
|
self._yielded = False
|
|
|
|
def __aiter__(self):
|
|
return self
|
|
|
|
async def __anext__(self):
|
|
if not self._yielded:
|
|
self._yielded = True
|
|
raise RuntimeError("LLM connection lost")
|
|
raise StopAsyncIteration
|
|
|
|
return FakeStream()
|
|
|
|
|
|
# ── Tests ────────────────────────────────────────────────────────────────────
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_chat_sse_format_and_events(chat_client):
    """The response is an SSE stream that opens with sources, includes a token event, and closes with done."""
    llm_client = MagicMock()
    llm_client.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(
            ["Snare compression ", "uses [1] to add ", "punch. See also [2]."]
        )
    )

    with (
        patch(
            "chat_service.SearchService.search",
            new_callable=AsyncMock,
            return_value=_fake_search_result(),
        ),
        patch("chat_service.openai.AsyncOpenAI", return_value=llm_client),
    ):
        response = await chat_client.post(
            "/api/v1/chat", json={"query": "snare compression"}
        )

    assert response.status_code == 200
    content_type = response.headers.get("content-type", "")
    assert "text/event-stream" in content_type

    events = _parse_sse(response.text)
    names = [event["event"] for event in events]

    # Ordering contract: sources open the stream, done closes it, and at
    # least one token event is present.
    assert names[0] == "sources"
    assert "token" in names
    assert names[-1] == "done"

    # The sources payload is a list with one entry per fake search hit.
    sources_payload = events[0]["data"]
    assert isinstance(sources_payload, list)
    assert len(sources_payload) == 2

    # The closing event reports which cascade tier was used.
    assert "cascade_tier" in events[-1]["data"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_chat_citation_numbering(chat_client):
    """Entries in the sources event are numbered from 1 in search-result order.

    The status code and the type of the first event are checked before the
    sources payload is indexed, so an endpoint failure is reported as a clear
    assertion failure rather than an IndexError/KeyError on the payload.
    """
    search_result = _fake_search_result()

    mock_openai_client = MagicMock()
    mock_openai_client.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["hello"])
    )

    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result),
        patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client),
    ):
        resp = await chat_client.post("/api/v1/chat", json={"query": "compression"})

    assert resp.status_code == 200

    events = _parse_sse(resp.text)
    assert events, "expected at least one SSE event in the response body"
    assert events[0]["event"] == "sources"
    sources = events[0]["data"]

    # Numbers are 1-based and follow the order of the fake search items.
    assert sources[0]["number"] == 1
    assert sources[0]["title"] == "Snare Compression"
    assert sources[1]["number"] == 2
    assert sources[1]["title"] == "Parallel Processing"
    assert sources[1]["section_anchor"] == "bus-setup"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_chat_creator_forwarded_to_search(chat_client):
    """A ``creator`` in the request body reaches ``SearchService.search()`` as a keyword argument."""
    llm_client = MagicMock()
    llm_client.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["ok"])
    )
    request_body = {"query": "drum mixing", "creator": "keota"}

    with (
        patch(
            "chat_service.SearchService.search",
            new_callable=AsyncMock,
            return_value=_fake_search_result(),
        ) as search_spy,
        patch("chat_service.openai.AsyncOpenAI", return_value=llm_client),
    ):
        response = await chat_client.post("/api/v1/chat", json=request_body)

    assert response.status_code == 200

    # Exactly one search call, carrying the creator as a keyword argument.
    search_spy.assert_called_once()
    assert search_spy.call_args.kwargs.get("creator") == "keota"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_chat_empty_query_returns_422(chat_client):
    """Posting an empty ``query`` string is rejected with HTTP 422."""
    request_body = {"query": ""}
    response = await chat_client.post("/api/v1/chat", json=request_body)
    assert response.status_code == 422
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_chat_missing_query_returns_422(chat_client):
    """Posting a body without a ``query`` field is rejected with HTTP 422."""
    request_body = {}
    response = await chat_client.post("/api/v1/chat", json=request_body)
    assert response.status_code == 422
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_chat_llm_error_produces_error_event(chat_client):
    """When reading the LLM stream raises, the SSE body carries an error event with a message."""
    llm_client = MagicMock()
    llm_client.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream_error()
    )

    with (
        patch(
            "chat_service.SearchService.search",
            new_callable=AsyncMock,
            return_value=_fake_search_result(),
        ),
        patch("chat_service.openai.AsyncOpenAI", return_value=llm_client),
    ):
        response = await chat_client.post("/api/v1/chat", json={"query": "test error"})

    # SSE streams always return 200
    assert response.status_code == 200

    events = _parse_sse(response.text)
    seen_types = [event["event"] for event in events]

    assert "sources" in seen_types
    assert "error" in seen_types

    # Inspect the first error event in stream order.
    error_events = [event for event in events if event["event"] == "error"]
    assert "message" in error_events[0]["data"]
|