diff --git a/backend/chat_service.py b/backend/chat_service.py new file mode 100644 index 0000000..05fd926 --- /dev/null +++ b/backend/chat_service.py @@ -0,0 +1,178 @@ +"""Chat service: retrieve context via search, stream LLM response as SSE events. + +Assembles a numbered context block from search results, then streams +completion tokens from an OpenAI-compatible API. Yields SSE-formatted +events: sources, token, done, and error. +""" + +from __future__ import annotations + +import json +import logging +import time +import traceback +from typing import Any, AsyncIterator + +import openai +from sqlalchemy.ext.asyncio import AsyncSession + +from config import Settings +from search_service import SearchService + +logger = logging.getLogger("chrysopedia.chat") + +_SYSTEM_PROMPT_TEMPLATE = """\ +You are Chrysopedia, an expert encyclopedic assistant for music production techniques. +Answer the user's question using ONLY the numbered sources below. Cite sources by +writing [N] inline (e.g. [1], [2]) where N is the source number. If the sources +do not contain enough information, say so honestly — do not invent facts. + +Sources: +{context_block} +""" + +_MAX_CONTEXT_SOURCES = 10 + + +class ChatService: + """Retrieve context from search, stream an LLM response with citations.""" + + def __init__(self, settings: Settings) -> None: + self.settings = settings + self._search = SearchService(settings) + self._openai = openai.AsyncOpenAI( + base_url=settings.llm_api_url, + api_key=settings.llm_api_key, + ) + + async def stream_response( + self, + query: str, + db: AsyncSession, + creator: str | None = None, + ) -> AsyncIterator[str]: + """Yield SSE-formatted events for a chat query. + + Protocol: + 1. ``event: sources\ndata: \n\n`` + 2. ``event: token\ndata: \n\n`` (repeated) + 3. ``event: done\ndata: \n\n`` + On error: ``event: error\ndata: \n\n`` + """ + start = time.monotonic() + + # ── 1. Retrieve context via search ────────────────────────────── + try: + search_result = await self._search.search( + query=query, + scope="all", + limit=_MAX_CONTEXT_SOURCES, + db=db, + creator=creator, + ) + except Exception: + logger.exception("chat_search_error query=%r creator=%r", query, creator) + yield _sse("error", {"message": "Search failed"}) + return + + items: list[dict[str, Any]] = search_result.get("items", []) + cascade_tier: str = search_result.get("cascade_tier", "") + + # ── 2. Build citation metadata and context block ──────────────── + sources = _build_sources(items) + context_block = _build_context_block(items) + + logger.info( + "chat_search query=%r creator=%r cascade_tier=%s source_count=%d", + query, creator, cascade_tier, len(sources), + ) + + # Emit sources event first + yield _sse("sources", sources) + + # ── 3. Stream LLM completion ──────────────────────────────────── + system_prompt = _SYSTEM_PROMPT_TEMPLATE.format(context_block=context_block) + + try: + stream = await self._openai.chat.completions.create( + model=self.settings.llm_model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": query}, + ], + stream=True, + temperature=0.3, + max_tokens=2048, + ) + + async for chunk in stream: + choice = chunk.choices[0] if chunk.choices else None + if choice and choice.delta and choice.delta.content: + yield _sse("token", choice.delta.content) + + except Exception: + tb = traceback.format_exc() + logger.error("chat_llm_error query=%r\n%s", query, tb) + yield _sse("error", {"message": "LLM generation failed"}) + return + + # ── 4. Done event ─────────────────────────────────────────────── + latency_ms = (time.monotonic() - start) * 1000 + logger.info( + "chat_done query=%r creator=%r cascade_tier=%s source_count=%d latency_ms=%.1f", + query, creator, cascade_tier, len(sources), latency_ms, + ) + yield _sse("done", {"cascade_tier": cascade_tier}) + + +# ── Helpers ────────────────────────────────────────────────────────────────── + + +def _sse(event: str, data: Any) -> str: + """Format a single SSE event string.""" + payload = json.dumps(data) if not isinstance(data, str) else data + return f"event: {event}\ndata: {payload}\n\n" + + +def _build_sources(items: list[dict[str, Any]]) -> list[dict[str, str]]: + """Build a numbered citation metadata list from search result items.""" + sources: list[dict[str, str]] = [] + for idx, item in enumerate(items, start=1): + sources.append({ + "number": idx, + "title": item.get("title", ""), + "slug": item.get("technique_page_slug", "") or item.get("slug", ""), + "creator_name": item.get("creator_name", ""), + "topic_category": item.get("topic_category", ""), + "summary": (item.get("summary", "") or "")[:200], + "section_anchor": item.get("section_anchor", ""), + "section_heading": item.get("section_heading", ""), + }) + return sources + + +def _build_context_block(items: list[dict[str, Any]]) -> str: + """Build a numbered context block string for the LLM system prompt.""" + if not items: + return "(No sources available)" + + lines: list[str] = [] + for idx, item in enumerate(items, start=1): + title = item.get("title", "Untitled") + creator = item.get("creator_name", "") + summary = item.get("summary", "") + section = item.get("section_heading", "") + + parts = [f"[{idx}] {title}"] + if creator: + parts.append(f"by {creator}") + if section: + parts.append(f"— {section}") + header = " ".join(parts) + + lines.append(header) + if summary: + lines.append(f" {summary}") + lines.append("") + + return "\n".join(lines) diff --git a/backend/main.py b/backend/main.py index 5e604a0..fe6d30a 100644 --- a/backend/main.py +++ b/backend/main.py @@ -12,7 +12,7 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from config import get_settings -from routers import admin, auth, consent, creator_dashboard, creators, health, ingest, pipeline, reports, search, stats, techniques, topics, videos +from routers import admin, auth, chat, consent, creator_dashboard, creators, health, ingest, pipeline, reports, search, stats, techniques, topics, videos def _setup_logging() -> None: @@ -80,6 +80,7 @@ app.include_router(health.router) # Versioned API app.include_router(admin.router, prefix="/api/v1") app.include_router(auth.router, prefix="/api/v1") +app.include_router(chat.router, prefix="/api/v1") app.include_router(consent.router, prefix="/api/v1") app.include_router(creator_dashboard.router, prefix="/api/v1") app.include_router(creators.router, prefix="/api/v1") diff --git a/backend/routers/chat.py b/backend/routers/chat.py new file mode 100644 index 0000000..b38365f --- /dev/null +++ b/backend/routers/chat.py @@ -0,0 +1,60 @@ +"""Chat endpoint — POST /api/v1/chat with SSE streaming response. + +Accepts a query and optional creator filter, returns a Server-Sent Events +stream with sources, token, done, and error events. +""" + +from __future__ import annotations + +import logging + +from fastapi import APIRouter, Depends +from fastapi.responses import StreamingResponse +from pydantic import BaseModel, Field +from sqlalchemy.ext.asyncio import AsyncSession + +from chat_service import ChatService +from config import Settings, get_settings +from database import get_session + +logger = logging.getLogger("chrysopedia.chat.router") + +router = APIRouter(prefix="/chat", tags=["chat"]) + + +class ChatRequest(BaseModel): + """Request body for the chat endpoint.""" + + query: str = Field(..., min_length=1, max_length=1000) + creator: str | None = None + + +def _get_chat_service(settings: Settings = Depends(get_settings)) -> ChatService: + """Build a ChatService from current settings.""" + return ChatService(settings) + + +@router.post("") +async def chat( + body: ChatRequest, + db: AsyncSession = Depends(get_session), + service: ChatService = Depends(_get_chat_service), +) -> StreamingResponse: + """Stream a chat response as Server-Sent Events. + + SSE protocol: + - ``event: sources`` — citation metadata array (sent first) + - ``event: token`` — streamed text chunk (repeated) + - ``event: done`` — completion metadata with cascade_tier + - ``event: error`` — error message (on failure) + """ + logger.info("chat_request query=%r creator=%r", body.query, body.creator) + + return StreamingResponse( + service.stream_response(query=body.query, db=db, creator=body.creator), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "X-Accel-Buffering": "no", + }, + ) diff --git a/backend/tests/test_chat.py b/backend/tests/test_chat.py new file mode 100644 index 0000000..985e16c --- /dev/null +++ b/backend/tests/test_chat.py @@ -0,0 +1,300 @@ +"""Integration tests for the chat SSE endpoint. + +Mocks SearchService.search() and the OpenAI streaming response to verify: +1. Valid SSE format with sources, token, and done events +2. Citation numbering matches the sources array +3. Creator param forwarded to search +4. Empty/invalid query returns 422 +5. LLM error produces an SSE error event + +These tests use a standalone ASGI client that does NOT require a running +PostgreSQL instance — the DB session dependency is overridden with a mock. +""" + +from __future__ import annotations + +import json +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +import pytest_asyncio +from httpx import ASGITransport, AsyncClient + +# Ensure backend/ is on sys.path +import pathlib +import sys +sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent.parent)) + +from database import get_session # noqa: E402 +from main import app # noqa: E402 + + +# ── Standalone test client (no DB required) ────────────────────────────────── + +@pytest_asyncio.fixture() +async def chat_client(): + """Async HTTP test client that mocks out the DB session entirely.""" + mock_session = AsyncMock() + + async def _mock_get_session(): + yield mock_session + + app.dependency_overrides[get_session] = _mock_get_session + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://testserver") as ac: + yield ac + + app.dependency_overrides.pop(get_session, None) + + +# ── Helpers ────────────────────────────────────────────────────────────────── + + +def _parse_sse(body: str) -> list[dict[str, Any]]: + """Parse SSE text into a list of {event, data} dicts.""" + events: list[dict[str, Any]] = [] + current_event: str | None = None + current_data: str | None = None + + for line in body.split("\n"): + if line.startswith("event: "): + current_event = line[len("event: "):] + elif line.startswith("data: "): + current_data = line[len("data: "):] + elif line == "" and current_event is not None and current_data is not None: + try: + parsed = json.loads(current_data) + except json.JSONDecodeError: + parsed = current_data + events.append({"event": current_event, "data": parsed}) + current_event = None + current_data = None + + return events + + +def _fake_search_result( + items: list[dict[str, Any]] | None = None, + cascade_tier: str = "global", +) -> dict[str, Any]: + """Build a fake SearchService.search() return value.""" + if items is None: + items = [ + { + "title": "Snare Compression", + "slug": "snare-compression", + "technique_page_slug": "snare-compression", + "creator_name": "Keota", + "topic_category": "Mixing", + "summary": "How to compress a snare drum for punch and presence.", + "section_anchor": "", + "section_heading": "", + "type": "technique_page", + "score": 0.9, + }, + { + "title": "Parallel Processing", + "slug": "parallel-processing", + "technique_page_slug": "parallel-processing", + "creator_name": "Skope", + "topic_category": "Mixing", + "summary": "Using parallel compression for dynamics control.", + "section_anchor": "bus-setup", + "section_heading": "Bus Setup", + "type": "technique_page", + "score": 0.85, + }, + ] + return { + "items": items, + "partial_matches": [], + "total": len(items), + "query": "snare compression", + "fallback_used": False, + "cascade_tier": cascade_tier, + } + + +def _mock_openai_stream(chunks: list[str]): + """Create a mock async iterator that yields OpenAI-style stream chunks.""" + + class FakeChoice: + def __init__(self, text: str | None): + self.delta = MagicMock() + self.delta.content = text + + class FakeChunk: + def __init__(self, text: str | None): + self.choices = [FakeChoice(text)] + + class FakeStream: + def __init__(self, chunks: list[str]): + self._chunks = chunks + self._index = 0 + + def __aiter__(self): + return self + + async def __anext__(self): + if self._index >= len(self._chunks): + raise StopAsyncIteration + chunk = FakeChunk(self._chunks[self._index]) + self._index += 1 + return chunk + + return FakeStream(chunks) + + +def _mock_openai_stream_error(): + """Create a mock async iterator that raises mid-stream.""" + + class FakeStream: + def __init__(self): + self._yielded = False + + def __aiter__(self): + return self + + async def __anext__(self): + if not self._yielded: + self._yielded = True + raise RuntimeError("LLM connection lost") + raise StopAsyncIteration + + return FakeStream() + + +# ── Tests ──────────────────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_chat_sse_format_and_events(chat_client): + """SSE stream contains sources, token(s), and done events in order.""" + search_result = _fake_search_result() + token_chunks = ["Snare compression ", "uses [1] to add ", "punch. See also [2]."] + + mock_openai_client = MagicMock() + mock_openai_client.chat.completions.create = AsyncMock( + return_value=_mock_openai_stream(token_chunks) + ) + + with ( + patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result), + patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client), + ): + resp = await chat_client.post("/api/v1/chat", json={"query": "snare compression"}) + + assert resp.status_code == 200 + assert "text/event-stream" in resp.headers.get("content-type", "") + + events = _parse_sse(resp.text) + event_types = [e["event"] for e in events] + + # Must have sources first, then tokens, then done + assert event_types[0] == "sources" + assert "token" in event_types + assert event_types[-1] == "done" + + # Sources event is a list + sources_data = events[0]["data"] + assert isinstance(sources_data, list) + assert len(sources_data) == 2 + + # Done event has cascade_tier + done_data = events[-1]["data"] + assert "cascade_tier" in done_data + + +@pytest.mark.asyncio +async def test_chat_citation_numbering(chat_client): + """Citation numbers in sources array match 1-based indexing.""" + search_result = _fake_search_result() + + mock_openai_client = MagicMock() + mock_openai_client.chat.completions.create = AsyncMock( + return_value=_mock_openai_stream(["hello"]) + ) + + with ( + patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result), + patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client), + ): + resp = await chat_client.post("/api/v1/chat", json={"query": "compression"}) + + events = _parse_sse(resp.text) + sources = events[0]["data"] + + assert sources[0]["number"] == 1 + assert sources[0]["title"] == "Snare Compression" + assert sources[1]["number"] == 2 + assert sources[1]["title"] == "Parallel Processing" + assert sources[1]["section_anchor"] == "bus-setup" + + +@pytest.mark.asyncio +async def test_chat_creator_forwarded_to_search(chat_client): + """Creator parameter is passed through to SearchService.search().""" + search_result = _fake_search_result() + + mock_openai_client = MagicMock() + mock_openai_client.chat.completions.create = AsyncMock( + return_value=_mock_openai_stream(["ok"]) + ) + + with ( + patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result) as mock_search, + patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client), + ): + resp = await chat_client.post( + "/api/v1/chat", + json={"query": "drum mixing", "creator": "keota"}, + ) + + assert resp.status_code == 200 + mock_search.assert_called_once() + call_kwargs = mock_search.call_args.kwargs + assert call_kwargs.get("creator") == "keota" + + +@pytest.mark.asyncio +async def test_chat_empty_query_returns_422(chat_client): + """An empty query string should fail Pydantic validation with 422.""" + resp = await chat_client.post("/api/v1/chat", json={"query": ""}) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +async def test_chat_missing_query_returns_422(chat_client): + """Missing query field should fail with 422.""" + resp = await chat_client.post("/api/v1/chat", json={}) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +async def test_chat_llm_error_produces_error_event(chat_client): + """When the LLM raises mid-stream, an error SSE event is emitted.""" + search_result = _fake_search_result() + + mock_openai_client = MagicMock() + mock_openai_client.chat.completions.create = AsyncMock( + return_value=_mock_openai_stream_error() + ) + + with ( + patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result), + patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client), + ): + resp = await chat_client.post("/api/v1/chat", json={"query": "test error"}) + + assert resp.status_code == 200 # SSE streams always return 200 + + events = _parse_sse(resp.text) + event_types = [e["event"] for e in events] + + assert "sources" in event_types + assert "error" in event_types + + error_event = next(e for e in events if e["event"] == "error") + assert "message" in error_event["data"]