test: Built ChatService with retrieve-prompt-stream pipeline, POST /api…
- "backend/chat_service.py" - "backend/routers/chat.py" - "backend/main.py" - "backend/tests/test_chat.py" GSD-Task: S03/T01
This commit is contained in:
parent
9530c85b9c
commit
5e0ce753a5
4 changed files with 540 additions and 1 deletions
178
backend/chat_service.py
Normal file
178
backend/chat_service.py
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
"""Chat service: retrieve context via search, stream LLM response as SSE events.
|
||||
|
||||
Assembles a numbered context block from search results, then streams
|
||||
completion tokens from an OpenAI-compatible API. Yields SSE-formatted
|
||||
events: sources, token, done, and error.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
from typing import Any, AsyncIterator
|
||||
|
||||
import openai
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from config import Settings
|
||||
from search_service import SearchService
|
||||
|
||||
logger = logging.getLogger("chrysopedia.chat")
|
||||
|
||||
_SYSTEM_PROMPT_TEMPLATE = """\
|
||||
You are Chrysopedia, an expert encyclopedic assistant for music production techniques.
|
||||
Answer the user's question using ONLY the numbered sources below. Cite sources by
|
||||
writing [N] inline (e.g. [1], [2]) where N is the source number. If the sources
|
||||
do not contain enough information, say so honestly — do not invent facts.
|
||||
|
||||
Sources:
|
||||
{context_block}
|
||||
"""
|
||||
|
||||
# Upper bound on the number of search hits requested as LLM context; passed as
# the ``limit`` argument of the search call in ``ChatService.stream_response``.
_MAX_CONTEXT_SOURCES = 10
|
||||
|
||||
|
||||
class ChatService:
    """Retrieve context from search, stream an LLM response with citations.

    Each instance owns a ``SearchService`` and an ``openai.AsyncOpenAI`` client,
    both constructed from the settings object passed to ``__init__``.
    """

    def __init__(self, settings: Settings) -> None:
        """Build the search service and the OpenAI-compatible client.

        Args:
            settings: Application settings. ``llm_api_url`` and ``llm_api_key``
                configure the client here; ``llm_model`` is read per request.
        """
        self.settings = settings
        self._search = SearchService(settings)
        self._openai = openai.AsyncOpenAI(
            base_url=settings.llm_api_url,
            api_key=settings.llm_api_key,
        )

    async def stream_response(
        self,
        query: str,
        db: AsyncSession,
        creator: str | None = None,
    ) -> AsyncIterator[str]:
        """Yield SSE-formatted events for a chat query.

        Events are emitted in this order:

        1. ``sources`` -- JSON array of citation metadata (sent first).
        2. ``token`` -- raw text chunk from the LLM (repeated).
        3. ``done`` -- JSON object carrying ``cascade_tier``.

        On failure a single ``error`` event with a JSON ``message`` is emitted
        and the stream ends. A search failure produces it before any
        ``sources`` event; an LLM failure produces it after ``sources`` (and
        after any tokens already streamed).

        Args:
            query: The user's question; used for search and as the user message.
            db: Database session handed through to the search service.
            creator: Optional creator filter handed through to the search service.

        Yields:
            Complete SSE event strings, each terminated by a blank line.
        """
        start = time.monotonic()

        # ── 1. Retrieve context via search ──────────────────────────────
        try:
            search_result = await self._search.search(
                query=query,
                scope="all",
                limit=_MAX_CONTEXT_SOURCES,
                db=db,
                creator=creator,
            )
        except Exception:
            logger.exception("chat_search_error query=%r creator=%r", query, creator)
            yield _sse("error", {"message": "Search failed"})
            return

        # ``or`` instead of a ``.get`` default: a key that is present but None
        # must not reach the builders below, which run outside any try block.
        items: list[dict[str, Any]] = search_result.get("items") or []
        cascade_tier: str = search_result.get("cascade_tier") or ""

        # ── 2. Build citation metadata and context block ────────────────
        # Both builders number items from 1 in the same order, so an inline
        # [N] citation in the answer maps to sources[N - 1].
        sources = _build_sources(items)
        context_block = _build_context_block(items)

        logger.info(
            "chat_search query=%r creator=%r cascade_tier=%s source_count=%d",
            query, creator, cascade_tier, len(sources),
        )

        # Emit sources event first
        yield _sse("sources", sources)

        # ── 3. Stream LLM completion ────────────────────────────────────
        system_prompt = _SYSTEM_PROMPT_TEMPLATE.format(context_block=context_block)

        try:
            stream = await self._openai.chat.completions.create(
                model=self.settings.llm_model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": query},
                ],
                stream=True,
                temperature=0.3,
                max_tokens=2048,
            )

            async for chunk in stream:
                # Chunks may arrive with no choices or an empty delta; skip them.
                choice = chunk.choices[0] if chunk.choices else None
                if choice and choice.delta and choice.delta.content:
                    yield _sse("token", choice.delta.content)

        except Exception:
            # logger.exception attaches the traceback itself, matching the
            # search error path above.
            logger.exception("chat_llm_error query=%r", query)
            yield _sse("error", {"message": "LLM generation failed"})
            return

        # ── 4. Done event ───────────────────────────────────────────────
        latency_ms = (time.monotonic() - start) * 1000
        logger.info(
            "chat_done query=%r creator=%r cascade_tier=%s source_count=%d latency_ms=%.1f",
            query, creator, cascade_tier, len(sources), latency_ms,
        )
        yield _sse("done", {"cascade_tier": cascade_tier})
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _sse(event: str, data: Any) -> str:
|
||||
"""Format a single SSE event string."""
|
||||
payload = json.dumps(data) if not isinstance(data, str) else data
|
||||
return f"event: {event}\ndata: {payload}\n\n"
|
||||
|
||||
|
||||
def _build_sources(items: list[dict[str, Any]]) -> list[dict[str, str]]:
|
||||
"""Build a numbered citation metadata list from search result items."""
|
||||
sources: list[dict[str, str]] = []
|
||||
for idx, item in enumerate(items, start=1):
|
||||
sources.append({
|
||||
"number": idx,
|
||||
"title": item.get("title", ""),
|
||||
"slug": item.get("technique_page_slug", "") or item.get("slug", ""),
|
||||
"creator_name": item.get("creator_name", ""),
|
||||
"topic_category": item.get("topic_category", ""),
|
||||
"summary": (item.get("summary", "") or "")[:200],
|
||||
"section_anchor": item.get("section_anchor", ""),
|
||||
"section_heading": item.get("section_heading", ""),
|
||||
})
|
||||
return sources
|
||||
|
||||
|
||||
def _build_context_block(items: list[dict[str, Any]]) -> str:
|
||||
"""Build a numbered context block string for the LLM system prompt."""
|
||||
if not items:
|
||||
return "(No sources available)"
|
||||
|
||||
lines: list[str] = []
|
||||
for idx, item in enumerate(items, start=1):
|
||||
title = item.get("title", "Untitled")
|
||||
creator = item.get("creator_name", "")
|
||||
summary = item.get("summary", "")
|
||||
section = item.get("section_heading", "")
|
||||
|
||||
parts = [f"[{idx}] {title}"]
|
||||
if creator:
|
||||
parts.append(f"by {creator}")
|
||||
if section:
|
||||
parts.append(f"— {section}")
|
||||
header = " ".join(parts)
|
||||
|
||||
lines.append(header)
|
||||
if summary:
|
||||
lines.append(f" {summary}")
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
|
@ -12,7 +12,7 @@ from fastapi import FastAPI
|
|||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from config import get_settings
|
||||
from routers import admin, auth, consent, creator_dashboard, creators, health, ingest, pipeline, reports, search, stats, techniques, topics, videos
|
||||
from routers import admin, auth, chat, consent, creator_dashboard, creators, health, ingest, pipeline, reports, search, stats, techniques, topics, videos
|
||||
|
||||
|
||||
def _setup_logging() -> None:
|
||||
|
|
@ -80,6 +80,7 @@ app.include_router(health.router)
|
|||
# Versioned API
|
||||
app.include_router(admin.router, prefix="/api/v1")
|
||||
app.include_router(auth.router, prefix="/api/v1")
|
||||
app.include_router(chat.router, prefix="/api/v1")
|
||||
app.include_router(consent.router, prefix="/api/v1")
|
||||
app.include_router(creator_dashboard.router, prefix="/api/v1")
|
||||
app.include_router(creators.router, prefix="/api/v1")
|
||||
|
|
|
|||
60
backend/routers/chat.py
Normal file
60
backend/routers/chat.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
"""Chat endpoint — POST /api/v1/chat with SSE streaming response.
|
||||
|
||||
Accepts a query and optional creator filter, returns a Server-Sent Events
|
||||
stream with sources, token, done, and error events.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from chat_service import ChatService
|
||||
from config import Settings, get_settings
|
||||
from database import get_session
|
||||
|
||||
logger = logging.getLogger("chrysopedia.chat.router")
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
    """Request body for the chat endpoint."""

    # Question text; required, between 1 and 1000 characters long.
    query: str = Field(..., min_length=1, max_length=1000)
    # Optional creator filter; the endpoint forwards it unchanged to
    # ChatService.stream_response.
    creator: str | None = None
|
||||
|
||||
|
||||
def _get_chat_service(settings: Settings = Depends(get_settings)) -> ChatService:
    """FastAPI dependency that constructs a ChatService from the resolved settings."""
    service = ChatService(settings)
    return service
|
||||
|
||||
|
||||
@router.post("")
async def chat(
    body: ChatRequest,
    db: AsyncSession = Depends(get_session),
    service: ChatService = Depends(_get_chat_service),
) -> StreamingResponse:
    """Stream a chat response as Server-Sent Events.

    The response body is produced by ``ChatService.stream_response`` and
    carries these event types:

    - ``event: sources`` — citation metadata array (sent first)
    - ``event: token`` — streamed text chunk (repeated)
    - ``event: done`` — completion metadata with cascade_tier
    - ``event: error`` — error message (on failure)
    """
    logger.info("chat_request query=%r creator=%r", body.query, body.creator)

    # NOTE(review): the generator uses ``db`` lazily while the response is
    # being streamed — confirm ``get_session`` keeps the session open for the
    # lifetime of the stream on the FastAPI version in use.
    event_stream = service.stream_response(
        query=body.query,
        db=db,
        creator=body.creator,
    )
    # Ask clients and intermediaries not to cache the stream, and
    # (presumably for an nginx-style reverse proxy) not to buffer it.
    sse_headers = {
        "Cache-Control": "no-cache",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        event_stream,
        media_type="text/event-stream",
        headers=sse_headers,
    )
|
||||
300
backend/tests/test_chat.py
Normal file
300
backend/tests/test_chat.py
Normal file
|
|
@ -0,0 +1,300 @@
|
|||
"""Integration tests for the chat SSE endpoint.
|
||||
|
||||
Mocks SearchService.search() and the OpenAI streaming response to verify:
|
||||
1. Valid SSE format with sources, token, and done events
|
||||
2. Citation numbering matches the sources array
|
||||
3. Creator param forwarded to search
|
||||
4. Empty/invalid query returns 422
|
||||
5. LLM error produces an SSE error event
|
||||
|
||||
These tests use a standalone ASGI client that does NOT require a running
|
||||
PostgreSQL instance — the DB session dependency is overridden with a mock.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
# Ensure backend/ is on sys.path
|
||||
import pathlib
|
||||
import sys
|
||||
sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent.parent))
|
||||
|
||||
from database import get_session # noqa: E402
|
||||
from main import app # noqa: E402
|
||||
|
||||
|
||||
# ── Standalone test client (no DB required) ──────────────────────────────────
|
||||
|
||||
@pytest_asyncio.fixture()
async def chat_client():
    """Yield an in-process HTTP client whose DB session dependency is a mock.

    The ``get_session`` override is installed before the client is created
    and removed again once the test has finished with it.
    """
    fake_db = AsyncMock()

    async def _override_get_session():
        yield fake_db

    app.dependency_overrides[get_session] = _override_get_session

    async with AsyncClient(
        transport=ASGITransport(app=app),
        base_url="http://testserver",
    ) as client:
        yield client

    app.dependency_overrides.pop(get_session, None)
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _parse_sse(body: str) -> list[dict[str, Any]]:
|
||||
"""Parse SSE text into a list of {event, data} dicts."""
|
||||
events: list[dict[str, Any]] = []
|
||||
current_event: str | None = None
|
||||
current_data: str | None = None
|
||||
|
||||
for line in body.split("\n"):
|
||||
if line.startswith("event: "):
|
||||
current_event = line[len("event: "):]
|
||||
elif line.startswith("data: "):
|
||||
current_data = line[len("data: "):]
|
||||
elif line == "" and current_event is not None and current_data is not None:
|
||||
try:
|
||||
parsed = json.loads(current_data)
|
||||
except json.JSONDecodeError:
|
||||
parsed = current_data
|
||||
events.append({"event": current_event, "data": parsed})
|
||||
current_event = None
|
||||
current_data = None
|
||||
|
||||
return events
|
||||
|
||||
|
||||
def _fake_search_result(
|
||||
items: list[dict[str, Any]] | None = None,
|
||||
cascade_tier: str = "global",
|
||||
) -> dict[str, Any]:
|
||||
"""Build a fake SearchService.search() return value."""
|
||||
if items is None:
|
||||
items = [
|
||||
{
|
||||
"title": "Snare Compression",
|
||||
"slug": "snare-compression",
|
||||
"technique_page_slug": "snare-compression",
|
||||
"creator_name": "Keota",
|
||||
"topic_category": "Mixing",
|
||||
"summary": "How to compress a snare drum for punch and presence.",
|
||||
"section_anchor": "",
|
||||
"section_heading": "",
|
||||
"type": "technique_page",
|
||||
"score": 0.9,
|
||||
},
|
||||
{
|
||||
"title": "Parallel Processing",
|
||||
"slug": "parallel-processing",
|
||||
"technique_page_slug": "parallel-processing",
|
||||
"creator_name": "Skope",
|
||||
"topic_category": "Mixing",
|
||||
"summary": "Using parallel compression for dynamics control.",
|
||||
"section_anchor": "bus-setup",
|
||||
"section_heading": "Bus Setup",
|
||||
"type": "technique_page",
|
||||
"score": 0.85,
|
||||
},
|
||||
]
|
||||
return {
|
||||
"items": items,
|
||||
"partial_matches": [],
|
||||
"total": len(items),
|
||||
"query": "snare compression",
|
||||
"fallback_used": False,
|
||||
"cascade_tier": cascade_tier,
|
||||
}
|
||||
|
||||
|
||||
def _mock_openai_stream(chunks: list[str]):
|
||||
"""Create a mock async iterator that yields OpenAI-style stream chunks."""
|
||||
|
||||
class FakeChoice:
|
||||
def __init__(self, text: str | None):
|
||||
self.delta = MagicMock()
|
||||
self.delta.content = text
|
||||
|
||||
class FakeChunk:
|
||||
def __init__(self, text: str | None):
|
||||
self.choices = [FakeChoice(text)]
|
||||
|
||||
class FakeStream:
|
||||
def __init__(self, chunks: list[str]):
|
||||
self._chunks = chunks
|
||||
self._index = 0
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
if self._index >= len(self._chunks):
|
||||
raise StopAsyncIteration
|
||||
chunk = FakeChunk(self._chunks[self._index])
|
||||
self._index += 1
|
||||
return chunk
|
||||
|
||||
return FakeStream(chunks)
|
||||
|
||||
|
||||
def _mock_openai_stream_error():
|
||||
"""Create a mock async iterator that raises mid-stream."""
|
||||
|
||||
class FakeStream:
|
||||
def __init__(self):
|
||||
self._yielded = False
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
if not self._yielded:
|
||||
self._yielded = True
|
||||
raise RuntimeError("LLM connection lost")
|
||||
raise StopAsyncIteration
|
||||
|
||||
return FakeStream()
|
||||
|
||||
|
||||
# ── Tests ────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_sse_format_and_events(chat_client):
    """SSE stream contains sources, token(s), and done events in order.

    Checks the exact event sequence, that the streamed tokens reproduce the
    LLM chunks, and that the done event reports the search cascade tier.
    """
    search_result = _fake_search_result()
    token_chunks = ["Snare compression ", "uses [1] to add ", "punch. See also [2]."]

    mock_openai_client = MagicMock()
    mock_openai_client.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(token_chunks)
    )

    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result),
        patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client),
    ):
        resp = await chat_client.post("/api/v1/chat", json={"query": "snare compression"})

    assert resp.status_code == 200
    assert "text/event-stream" in resp.headers.get("content-type", "")

    events = _parse_sse(resp.text)
    event_types = [e["event"] for e in events]

    # Exactly: sources first, one token event per LLM chunk, then done.
    assert event_types == ["sources"] + ["token"] * len(token_chunks) + ["done"]

    # Sources event is a list with one entry per search hit
    sources_data = events[0]["data"]
    assert isinstance(sources_data, list)
    assert len(sources_data) == 2

    # Token events carry the LLM text unchanged and in order
    streamed_text = "".join(e["data"] for e in events if e["event"] == "token")
    assert streamed_text == "".join(token_chunks)

    # Done event reports the cascade tier returned by search
    done_data = events[-1]["data"]
    assert "cascade_tier" in done_data
    assert done_data["cascade_tier"] == "global"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_citation_numbering(chat_client):
    """Sources are numbered from 1, in the order the search returned them."""
    llm_client = MagicMock()
    llm_client.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["hello"])
    )

    search_patch = patch(
        "chat_service.SearchService.search",
        new_callable=AsyncMock,
        return_value=_fake_search_result(),
    )
    openai_patch = patch("chat_service.openai.AsyncOpenAI", return_value=llm_client)

    with search_patch, openai_patch:
        response = await chat_client.post("/api/v1/chat", json={"query": "compression"})

    # The first event of the stream is the sources array.
    cited = _parse_sse(response.text)[0]["data"]

    first = cited[0]
    assert first["number"] == 1
    assert first["title"] == "Snare Compression"

    second = cited[1]
    assert second["number"] == 2
    assert second["title"] == "Parallel Processing"
    assert second["section_anchor"] == "bus-setup"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_creator_forwarded_to_search(chat_client):
    """The ``creator`` request field reaches SearchService.search() as a keyword."""
    llm_client = MagicMock()
    llm_client.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["ok"])
    )

    search_patch = patch(
        "chat_service.SearchService.search",
        new_callable=AsyncMock,
        return_value=_fake_search_result(),
    )
    openai_patch = patch("chat_service.openai.AsyncOpenAI", return_value=llm_client)
    request_body = {"query": "drum mixing", "creator": "keota"}

    with search_patch as mock_search, openai_patch:
        response = await chat_client.post("/api/v1/chat", json=request_body)

    assert response.status_code == 200
    mock_search.assert_called_once()
    assert mock_search.call_args.kwargs.get("creator") == "keota"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_empty_query_returns_422(chat_client):
    """A zero-length query is rejected by request validation with HTTP 422."""
    payload = {"query": ""}
    response = await chat_client.post("/api/v1/chat", json=payload)
    assert response.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_missing_query_returns_422(chat_client):
    """A body without the ``query`` field is rejected with HTTP 422."""
    payload = {}
    response = await chat_client.post("/api/v1/chat", json=payload)
    assert response.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_llm_error_produces_error_event(chat_client):
    """A failure while iterating the LLM stream surfaces as an SSE error event."""
    llm_client = MagicMock()
    llm_client.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream_error()
    )

    search_patch = patch(
        "chat_service.SearchService.search",
        new_callable=AsyncMock,
        return_value=_fake_search_result(),
    )
    openai_patch = patch("chat_service.openai.AsyncOpenAI", return_value=llm_client)

    with search_patch, openai_patch:
        response = await chat_client.post("/api/v1/chat", json={"query": "test error"})

    # The HTTP status is already sent when the stream fails, so it stays 200.
    assert response.status_code == 200

    parsed = _parse_sse(response.text)
    seen_types = [entry["event"] for entry in parsed]

    assert "sources" in seen_types
    assert "error" in seen_types

    failure = next(entry for entry in parsed if entry["event"] == "error")
    assert "message" in failure["data"]
|
||||
Loading…
Add table
Reference in a new issue