chrysopedia/backend/chat_service.py
jlightner a9589bfc93 test: Built ChatService with retrieve-prompt-stream pipeline, POST /api…
- "backend/chat_service.py"
- "backend/routers/chat.py"
- "backend/main.py"
- "backend/tests/test_chat.py"

GSD-Task: S03/T01
2026-04-04 05:19:44 +00:00

178 lines
6.5 KiB
Python

"""Chat service: retrieve context via search, stream LLM response as SSE events.
Assembles a numbered context block from search results, then streams
completion tokens from an OpenAI-compatible API. Yields SSE-formatted
events: sources, token, done, and error.
"""
from __future__ import annotations
import json
import logging
import time
import traceback
from typing import Any, AsyncIterator
import openai
from sqlalchemy.ext.asyncio import AsyncSession
from config import Settings
from search_service import SearchService
logger = logging.getLogger("chrysopedia.chat")
# System prompt template for the LLM call; {context_block} is filled with the
# numbered source list produced by _build_context_block().
_SYSTEM_PROMPT_TEMPLATE = """\
You are Chrysopedia, an expert encyclopedic assistant for music production techniques.
Answer the user's question using ONLY the numbered sources below. Cite sources by
writing [N] inline (e.g. [1], [2]) where N is the source number. If the sources
do not contain enough information, say so honestly — do not invent facts.
Sources:
{context_block}
"""
# Upper bound on search results fetched per chat query (passed as the search
# `limit`), which also caps the number of numbered sources in the prompt.
_MAX_CONTEXT_SOURCES = 10
class ChatService:
    """Retrieve context from search, stream an LLM response with citations.

    Pipeline: run a search for the user's query, emit the matched items as
    a ``sources`` SSE event, then stream completion tokens from an
    OpenAI-compatible chat API grounded on the numbered context block.
    """

    def __init__(self, settings: Settings) -> None:
        # One SearchService and one async OpenAI-compatible client per
        # service instance; endpoint and credentials come from settings.
        self.settings = settings
        self._search = SearchService(settings)
        self._openai = openai.AsyncOpenAI(
            base_url=settings.llm_api_url,
            api_key=settings.llm_api_key,
        )

    async def stream_response(
        self,
        query: str,
        db: AsyncSession,
        creator: str | None = None,
    ) -> AsyncIterator[str]:
        """Yield SSE-formatted events for a chat query.

        Args:
            query: the user's question; used for retrieval and sent
                verbatim to the LLM as the user message.
            db: async DB session forwarded to the search service.
            creator: optional creator filter forwarded to the search.

        Protocol:
            1. ``event: sources\ndata: <json array of citation metadata>\n\n``
            2. ``event: token\ndata: <text chunk>\n\n`` (repeated)
            3. ``event: done\ndata: <json with cascade_tier>\n\n``
        On error: ``event: error\ndata: <json with message>\n\n`` and the
        generator returns (no ``done`` event follows an error).
        """
        start = time.monotonic()
        # ── 1. Retrieve context via search ──────────────────────────────
        try:
            search_result = await self._search.search(
                query=query,
                scope="all",
                limit=_MAX_CONTEXT_SOURCES,
                db=db,
                creator=creator,
            )
        except Exception:
            # Search failure is terminal: emit a generic error event
            # (details stay in the log) and stop the stream.
            logger.exception("chat_search_error query=%r creator=%r", query, creator)
            yield _sse("error", {"message": "Search failed"})
            return
        items: list[dict[str, Any]] = search_result.get("items", [])
        cascade_tier: str = search_result.get("cascade_tier", "")
        # ── 2. Build citation metadata and context block ────────────────
        sources = _build_sources(items)
        context_block = _build_context_block(items)
        logger.info(
            "chat_search query=%r creator=%r cascade_tier=%s source_count=%d",
            query, creator, cascade_tier, len(sources),
        )
        # Emit sources event first so the client can render citations
        # before any completion tokens arrive.
        yield _sse("sources", sources)
        # ── 3. Stream LLM completion ────────────────────────────────────
        system_prompt = _SYSTEM_PROMPT_TEMPLATE.format(context_block=context_block)
        try:
            stream = await self._openai.chat.completions.create(
                model=self.settings.llm_model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": query},
                ],
                stream=True,
                temperature=0.3,
                max_tokens=2048,
            )
            async for chunk in stream:
                # Guard: chunks may arrive with empty choices or an empty
                # delta (e.g. role-only or keep-alive chunks); forward
                # only chunks that carry real content.
                choice = chunk.choices[0] if chunk.choices else None
                if choice and choice.delta and choice.delta.content:
                    yield _sse("token", choice.delta.content)
        except Exception:
            # Sources (and possibly some tokens) were already emitted;
            # signal the mid-stream failure with an error event.
            tb = traceback.format_exc()
            logger.error("chat_llm_error query=%r\n%s", query, tb)
            yield _sse("error", {"message": "LLM generation failed"})
            return
        # ── 4. Done event ───────────────────────────────────────────────
        latency_ms = (time.monotonic() - start) * 1000
        logger.info(
            "chat_done query=%r creator=%r cascade_tier=%s source_count=%d latency_ms=%.1f",
            query, creator, cascade_tier, len(sources), latency_ms,
        )
        yield _sse("done", {"cascade_tier": cascade_tier})
# ── Helpers ──────────────────────────────────────────────────────────────────
def _sse(event: str, data: Any) -> str:
"""Format a single SSE event string."""
payload = json.dumps(data) if not isinstance(data, str) else data
return f"event: {event}\ndata: {payload}\n\n"
def _build_sources(items: list[dict[str, Any]]) -> list[dict[str, str]]:
"""Build a numbered citation metadata list from search result items."""
sources: list[dict[str, str]] = []
for idx, item in enumerate(items, start=1):
sources.append({
"number": idx,
"title": item.get("title", ""),
"slug": item.get("technique_page_slug", "") or item.get("slug", ""),
"creator_name": item.get("creator_name", ""),
"topic_category": item.get("topic_category", ""),
"summary": (item.get("summary", "") or "")[:200],
"section_anchor": item.get("section_anchor", ""),
"section_heading": item.get("section_heading", ""),
})
return sources
def _build_context_block(items: list[dict[str, Any]]) -> str:
"""Build a numbered context block string for the LLM system prompt."""
if not items:
return "(No sources available)"
lines: list[str] = []
for idx, item in enumerate(items, start=1):
title = item.get("title", "Untitled")
creator = item.get("creator_name", "")
summary = item.get("summary", "")
section = item.get("section_heading", "")
parts = [f"[{idx}] {title}"]
if creator:
parts.append(f"by {creator}")
if section:
parts.append(f"{section}")
header = " ".join(parts)
lines.append(header)
if summary:
lines.append(f" {summary}")
lines.append("")
return "\n".join(lines)