"""Chat service: retrieve context via search, stream LLM response as SSE events. Assembles a numbered context block from search results, then streams completion tokens from an OpenAI-compatible API. Yields SSE-formatted events: sources, token, done, and error. """ from __future__ import annotations import json import logging import time import traceback from typing import Any, AsyncIterator import openai from sqlalchemy.ext.asyncio import AsyncSession from config import Settings from search_service import SearchService logger = logging.getLogger("chrysopedia.chat") _SYSTEM_PROMPT_TEMPLATE = """\ You are Chrysopedia, an expert encyclopedic assistant for music production techniques. Answer the user's question using ONLY the numbered sources below. Cite sources by writing [N] inline (e.g. [1], [2]) where N is the source number. If the sources do not contain enough information, say so honestly — do not invent facts. Sources: {context_block} """ _MAX_CONTEXT_SOURCES = 10 class ChatService: """Retrieve context from search, stream an LLM response with citations.""" def __init__(self, settings: Settings) -> None: self.settings = settings self._search = SearchService(settings) self._openai = openai.AsyncOpenAI( base_url=settings.llm_api_url, api_key=settings.llm_api_key, ) async def stream_response( self, query: str, db: AsyncSession, creator: str | None = None, ) -> AsyncIterator[str]: """Yield SSE-formatted events for a chat query. Protocol: 1. ``event: sources\ndata: \n\n`` 2. ``event: token\ndata: \n\n`` (repeated) 3. ``event: done\ndata: \n\n`` On error: ``event: error\ndata: \n\n`` """ start = time.monotonic() # ── 1. Retrieve context via search ────────────────────────────── try: search_result = await self._search.search( query=query, scope="all", limit=_MAX_CONTEXT_SOURCES, db=db, creator=creator, ) except Exception: logger.exception("chat_search_error query=%r creator=%r", query, creator) yield _sse("error", {"message": "Search failed"}) return items: list[dict[str, Any]] = search_result.get("items", []) cascade_tier: str = search_result.get("cascade_tier", "") # ── 2. Build citation metadata and context block ──────────────── sources = _build_sources(items) context_block = _build_context_block(items) logger.info( "chat_search query=%r creator=%r cascade_tier=%s source_count=%d", query, creator, cascade_tier, len(sources), ) # Emit sources event first yield _sse("sources", sources) # ── 3. Stream LLM completion ──────────────────────────────────── system_prompt = _SYSTEM_PROMPT_TEMPLATE.format(context_block=context_block) try: stream = await self._openai.chat.completions.create( model=self.settings.llm_model, messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": query}, ], stream=True, temperature=0.3, max_tokens=2048, ) async for chunk in stream: choice = chunk.choices[0] if chunk.choices else None if choice and choice.delta and choice.delta.content: yield _sse("token", choice.delta.content) except Exception: tb = traceback.format_exc() logger.error("chat_llm_error query=%r\n%s", query, tb) yield _sse("error", {"message": "LLM generation failed"}) return # ── 4. Done event ─────────────────────────────────────────────── latency_ms = (time.monotonic() - start) * 1000 logger.info( "chat_done query=%r creator=%r cascade_tier=%s source_count=%d latency_ms=%.1f", query, creator, cascade_tier, len(sources), latency_ms, ) yield _sse("done", {"cascade_tier": cascade_tier}) # ── Helpers ────────────────────────────────────────────────────────────────── def _sse(event: str, data: Any) -> str: """Format a single SSE event string.""" payload = json.dumps(data) if not isinstance(data, str) else data return f"event: {event}\ndata: {payload}\n\n" def _build_sources(items: list[dict[str, Any]]) -> list[dict[str, str]]: """Build a numbered citation metadata list from search result items.""" sources: list[dict[str, str]] = [] for idx, item in enumerate(items, start=1): sources.append({ "number": idx, "title": item.get("title", ""), "slug": item.get("technique_page_slug", "") or item.get("slug", ""), "creator_name": item.get("creator_name", ""), "topic_category": item.get("topic_category", ""), "summary": (item.get("summary", "") or "")[:200], "section_anchor": item.get("section_anchor", ""), "section_heading": item.get("section_heading", ""), }) return sources def _build_context_block(items: list[dict[str, Any]]) -> str: """Build a numbered context block string for the LLM system prompt.""" if not items: return "(No sources available)" lines: list[str] = [] for idx, item in enumerate(items, start=1): title = item.get("title", "Untitled") creator = item.get("creator_name", "") summary = item.get("summary", "") section = item.get("section_heading", "") parts = [f"[{idx}] {title}"] if creator: parts.append(f"by {creator}") if section: parts.append(f"— {section}") header = " ".join(parts) lines.append(header) if summary: lines.append(f" {summary}") lines.append("") return "\n".join(lines)