- "backend/chat_service.py" - "backend/routers/chat.py" - "backend/main.py" - "backend/tests/test_chat.py" GSD-Task: S03/T01
178 lines
6.5 KiB
Python
178 lines
6.5 KiB
Python
"""Chat service: retrieve context via search, stream LLM response as SSE events.
|
|
|
|
Assembles a numbered context block from search results, then streams
|
|
completion tokens from an OpenAI-compatible API. Yields SSE-formatted
|
|
events: sources, token, done, and error.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import time
|
|
import traceback
|
|
from typing import Any, AsyncIterator
|
|
|
|
import openai
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from config import Settings
|
|
from search_service import SearchService
|
|
|
|
# Module-level logger registered under the "chrysopedia.chat" name.
logger = logging.getLogger("chrysopedia.chat")

# System prompt for the LLM. ``{context_block}`` is filled in with
# ``str.format`` using the numbered source listing, so the citation numbers
# the model is asked to write ([1], [2], ...) refer to that listing.
_SYSTEM_PROMPT_TEMPLATE = """\
You are Chrysopedia, an expert encyclopedic assistant for music production techniques.
Answer the user's question using ONLY the numbered sources below. Cite sources by
writing [N] inline (e.g. [1], [2]) where N is the source number. If the sources
do not contain enough information, say so honestly — do not invent facts.

Sources:
{context_block}
"""

# Passed as ``limit`` to the search call, which caps how many search hits are
# requested as context sources.
_MAX_CONTEXT_SOURCES = 10
|
|
|
|
|
|
class ChatService:
    """Retrieve context from search, stream an LLM response with citations.

    One instance owns a ``SearchService`` and an ``openai.AsyncOpenAI`` client,
    both configured from the ``Settings`` object passed to the constructor.
    """

    def __init__(self, settings: Settings) -> None:
        """Keep *settings* and build the search service and async LLM client.

        Args:
            settings: Application settings; ``llm_api_url`` and ``llm_api_key``
                configure the client here, ``llm_model`` is read per request.
        """
        self.settings = settings
        self._search = SearchService(settings)
        self._openai = openai.AsyncOpenAI(
            base_url=settings.llm_api_url,
            api_key=settings.llm_api_key,
        )

    async def stream_response(
        self,
        query: str,
        db: AsyncSession,
        creator: str | None = None,
    ) -> AsyncIterator[str]:
        r"""Yield SSE-formatted events for a chat query.

        Protocol:
        1. ``event: sources\ndata: <json array of citation metadata>\n\n``
        2. ``event: token\ndata: <text chunk>\n\n`` (repeated)
        3. ``event: done\ndata: <json with cascade_tier>\n\n``
        On error: ``event: error\ndata: <json with message>\n\n``

        Args:
            query: The user's question; used both as the search query and as
                the user message sent to the LLM.
            db: Database session forwarded to the search call.
            creator: Optional creator filter forwarded to the search call.

        Yields:
            SSE event strings in the order described above. If the search or
            the LLM call raises an ``Exception``, a single ``error`` event is
            yielded and the stream ends without a ``done`` event.
        """
        start = time.monotonic()

        # ── 1. Retrieve context via search ──────────────────────────────
        try:
            search_result = await self._search.search(
                query=query,
                scope="all",
                limit=_MAX_CONTEXT_SOURCES,
                db=db,
                creator=creator,
            )
        except Exception:
            logger.exception("chat_search_error query=%r creator=%r", query, creator)
            yield _sse("error", {"message": "Search failed"})
            return

        items: list[dict[str, Any]] = search_result.get("items", [])
        cascade_tier: str = search_result.get("cascade_tier", "")

        # ── 2. Build citation metadata and context block ────────────────
        # Both helpers number items from 1 in the same order, so the [N] the
        # model cites lines up with the "number" field sent to the client.
        sources = _build_sources(items)
        context_block = _build_context_block(items)

        logger.info(
            "chat_search query=%r creator=%r cascade_tier=%s source_count=%d",
            query, creator, cascade_tier, len(sources),
        )

        # Emit sources event first
        yield _sse("sources", sources)

        # ── 3. Stream LLM completion ────────────────────────────────────
        system_prompt = _SYSTEM_PROMPT_TEMPLATE.format(context_block=context_block)

        try:
            stream = await self._openai.chat.completions.create(
                model=self.settings.llm_model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": query},
                ],
                stream=True,
                temperature=0.3,
                max_tokens=2048,
            )

            async for chunk in stream:
                # Skip chunks that carry no choices or no text delta.
                choice = chunk.choices[0] if chunk.choices else None
                if choice and choice.delta and choice.delta.content:
                    yield _sse("token", choice.delta.content)

        except Exception:
            # logger.exception attaches the active traceback itself, the same
            # way the search-failure handler above reports its error.
            logger.exception("chat_llm_error query=%r", query)
            yield _sse("error", {"message": "LLM generation failed"})
            return

        # ── 4. Done event ───────────────────────────────────────────────
        latency_ms = (time.monotonic() - start) * 1000
        logger.info(
            "chat_done query=%r creator=%r cascade_tier=%s source_count=%d latency_ms=%.1f",
            query, creator, cascade_tier, len(sources), latency_ms,
        )
        yield _sse("done", {"cascade_tier": cascade_tier})
|
|
|
|
|
|
# ── Helpers ──────────────────────────────────────────────────────────────────
|
|
|
|
|
|
def _sse(event: str, data: Any) -> str:
|
|
"""Format a single SSE event string."""
|
|
payload = json.dumps(data) if not isinstance(data, str) else data
|
|
return f"event: {event}\ndata: {payload}\n\n"
|
|
|
|
|
|
def _build_sources(items: list[dict[str, Any]]) -> list[dict[str, str]]:
|
|
"""Build a numbered citation metadata list from search result items."""
|
|
sources: list[dict[str, str]] = []
|
|
for idx, item in enumerate(items, start=1):
|
|
sources.append({
|
|
"number": idx,
|
|
"title": item.get("title", ""),
|
|
"slug": item.get("technique_page_slug", "") or item.get("slug", ""),
|
|
"creator_name": item.get("creator_name", ""),
|
|
"topic_category": item.get("topic_category", ""),
|
|
"summary": (item.get("summary", "") or "")[:200],
|
|
"section_anchor": item.get("section_anchor", ""),
|
|
"section_heading": item.get("section_heading", ""),
|
|
})
|
|
return sources
|
|
|
|
|
|
def _build_context_block(items: list[dict[str, Any]]) -> str:
|
|
"""Build a numbered context block string for the LLM system prompt."""
|
|
if not items:
|
|
return "(No sources available)"
|
|
|
|
lines: list[str] = []
|
|
for idx, item in enumerate(items, start=1):
|
|
title = item.get("title", "Untitled")
|
|
creator = item.get("creator_name", "")
|
|
summary = item.get("summary", "")
|
|
section = item.get("section_heading", "")
|
|
|
|
parts = [f"[{idx}] {title}"]
|
|
if creator:
|
|
parts.append(f"by {creator}")
|
|
if section:
|
|
parts.append(f"— {section}")
|
|
header = " ".join(parts)
|
|
|
|
lines.append(header)
|
|
if summary:
|
|
lines.append(f" {summary}")
|
|
lines.append("")
|
|
|
|
return "\n".join(lines)
|