"""Integration tests for the chat SSE endpoint. Mocks SearchService.search() and the OpenAI streaming response to verify: 1. Valid SSE format with sources, token, and done events 2. Citation numbering matches the sources array 3. Creator param forwarded to search 4. Empty/invalid query returns 422 5. LLM error produces an SSE error event These tests use a standalone ASGI client that does NOT require a running PostgreSQL instance — the DB session dependency is overridden with a mock. """ from __future__ import annotations import json from typing import Any from unittest.mock import AsyncMock, MagicMock, patch import pytest import pytest_asyncio from httpx import ASGITransport, AsyncClient # Ensure backend/ is on sys.path import pathlib import sys sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent.parent)) from database import get_session # noqa: E402 from main import app # noqa: E402 # ── Standalone test client (no DB required) ────────────────────────────────── @pytest_asyncio.fixture() async def chat_client(): """Async HTTP test client that mocks out the DB session entirely.""" mock_session = AsyncMock() async def _mock_get_session(): yield mock_session app.dependency_overrides[get_session] = _mock_get_session transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://testserver") as ac: yield ac app.dependency_overrides.pop(get_session, None) # ── Helpers ────────────────────────────────────────────────────────────────── def _parse_sse(body: str) -> list[dict[str, Any]]: """Parse SSE text into a list of {event, data} dicts.""" events: list[dict[str, Any]] = [] current_event: str | None = None current_data: str | None = None for line in body.split("\n"): if line.startswith("event: "): current_event = line[len("event: "):] elif line.startswith("data: "): current_data = line[len("data: "):] elif line == "" and current_event is not None and current_data is not None: try: parsed = json.loads(current_data) except json.JSONDecodeError: parsed = current_data events.append({"event": current_event, "data": parsed}) current_event = None current_data = None return events def _fake_search_result( items: list[dict[str, Any]] | None = None, cascade_tier: str = "global", ) -> dict[str, Any]: """Build a fake SearchService.search() return value.""" if items is None: items = [ { "title": "Snare Compression", "slug": "snare-compression", "technique_page_slug": "snare-compression", "creator_name": "Keota", "topic_category": "Mixing", "summary": "How to compress a snare drum for punch and presence.", "section_anchor": "", "section_heading": "", "type": "technique_page", "score": 0.9, }, { "title": "Parallel Processing", "slug": "parallel-processing", "technique_page_slug": "parallel-processing", "creator_name": "Skope", "topic_category": "Mixing", "summary": "Using parallel compression for dynamics control.", "section_anchor": "bus-setup", "section_heading": "Bus Setup", "type": "technique_page", "score": 0.85, }, ] return { "items": items, "partial_matches": [], "total": len(items), "query": "snare compression", "fallback_used": False, "cascade_tier": cascade_tier, } def _mock_openai_stream(chunks: list[str]): """Create a mock async iterator that yields OpenAI-style stream chunks.""" class FakeChoice: def __init__(self, text: str | None): self.delta = MagicMock() self.delta.content = text class FakeChunk: def __init__(self, text: str | None): self.choices = [FakeChoice(text)] class FakeStream: def __init__(self, chunks: list[str]): self._chunks = chunks self._index = 0 def __aiter__(self): return self async def __anext__(self): if self._index >= len(self._chunks): raise StopAsyncIteration chunk = FakeChunk(self._chunks[self._index]) self._index += 1 return chunk return FakeStream(chunks) def _mock_openai_stream_error(): """Create a mock async iterator that raises mid-stream.""" class FakeStream: def __init__(self): self._yielded = False def __aiter__(self): return self async def __anext__(self): if not self._yielded: self._yielded = True raise RuntimeError("LLM connection lost") raise StopAsyncIteration return FakeStream() # ── Tests ──────────────────────────────────────────────────────────────────── @pytest.mark.asyncio async def test_chat_sse_format_and_events(chat_client): """SSE stream contains sources, token(s), and done events in order.""" search_result = _fake_search_result() token_chunks = ["Snare compression ", "uses [1] to add ", "punch. See also [2]."] mock_openai_client = MagicMock() mock_openai_client.chat.completions.create = AsyncMock( return_value=_mock_openai_stream(token_chunks) ) with ( patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result), patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client), ): resp = await chat_client.post("/api/v1/chat", json={"query": "snare compression"}) assert resp.status_code == 200 assert "text/event-stream" in resp.headers.get("content-type", "") events = _parse_sse(resp.text) event_types = [e["event"] for e in events] # Must have sources first, then tokens, then done assert event_types[0] == "sources" assert "token" in event_types assert event_types[-1] == "done" # Sources event is a list sources_data = events[0]["data"] assert isinstance(sources_data, list) assert len(sources_data) == 2 # Done event has cascade_tier done_data = events[-1]["data"] assert "cascade_tier" in done_data @pytest.mark.asyncio async def test_chat_citation_numbering(chat_client): """Citation numbers in sources array match 1-based indexing.""" search_result = _fake_search_result() mock_openai_client = MagicMock() mock_openai_client.chat.completions.create = AsyncMock( return_value=_mock_openai_stream(["hello"]) ) with ( patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result), patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client), ): resp = await chat_client.post("/api/v1/chat", json={"query": "compression"}) events = _parse_sse(resp.text) sources = events[0]["data"] assert sources[0]["number"] == 1 assert sources[0]["title"] == "Snare Compression" assert sources[1]["number"] == 2 assert sources[1]["title"] == "Parallel Processing" assert sources[1]["section_anchor"] == "bus-setup" @pytest.mark.asyncio async def test_chat_creator_forwarded_to_search(chat_client): """Creator parameter is passed through to SearchService.search().""" search_result = _fake_search_result() mock_openai_client = MagicMock() mock_openai_client.chat.completions.create = AsyncMock( return_value=_mock_openai_stream(["ok"]) ) with ( patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result) as mock_search, patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client), ): resp = await chat_client.post( "/api/v1/chat", json={"query": "drum mixing", "creator": "keota"}, ) assert resp.status_code == 200 mock_search.assert_called_once() call_kwargs = mock_search.call_args.kwargs assert call_kwargs.get("creator") == "keota" @pytest.mark.asyncio async def test_chat_empty_query_returns_422(chat_client): """An empty query string should fail Pydantic validation with 422.""" resp = await chat_client.post("/api/v1/chat", json={"query": ""}) assert resp.status_code == 422 @pytest.mark.asyncio async def test_chat_missing_query_returns_422(chat_client): """Missing query field should fail with 422.""" resp = await chat_client.post("/api/v1/chat", json={}) assert resp.status_code == 422 @pytest.mark.asyncio async def test_chat_llm_error_produces_error_event(chat_client): """When the LLM raises mid-stream, an error SSE event is emitted.""" search_result = _fake_search_result() mock_openai_client = MagicMock() mock_openai_client.chat.completions.create = AsyncMock( return_value=_mock_openai_stream_error() ) with ( patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result), patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client), ): resp = await chat_client.post("/api/v1/chat", json={"query": "test error"}) assert resp.status_code == 200 # SSE streams always return 200 events = _parse_sse(resp.text) event_types = [e["event"] for e in events] assert "sources" in event_types assert "error" in event_types error_event = next(e for e in events if e["event"] == "error") assert "message" in error_event["data"]