"""Chat service: retrieve context via search, stream LLM response as SSE events. Assembles a numbered context block from search results, then streams completion tokens from an OpenAI-compatible API. Yields SSE-formatted events: sources, token, done, and error. Multi-turn memory: When a conversation_id is provided, prior messages are loaded from Redis, injected into the LLM messages array, and the new user+assistant turn is appended after streaming completes. History is capped at 10 turn pairs (20 messages) and expires after 1 hour of inactivity. """ from __future__ import annotations import json import logging import time import traceback import uuid from typing import Any, AsyncIterator import openai from sqlalchemy.ext.asyncio import AsyncSession from config import Settings from search_service import SearchService logger = logging.getLogger("chrysopedia.chat") _SYSTEM_PROMPT_TEMPLATE = """\ You are Chrysopedia, an expert encyclopedic assistant for music production techniques. Answer the user's question using ONLY the numbered sources below. Cite sources by writing [N] inline (e.g. [1], [2]) where N is the source number. If the sources do not contain enough information, say so honestly — do not invent facts. Sources: {context_block} """ _MAX_CONTEXT_SOURCES = 10 _MAX_TURN_PAIRS = 10 _HISTORY_TTL_SECONDS = 3600 # 1 hour def _redis_key(conversation_id: str) -> str: return f"chrysopedia:chat:{conversation_id}" class ChatService: """Retrieve context from search, stream an LLM response with citations.""" def __init__(self, settings: Settings, redis=None) -> None: self.settings = settings self._search = SearchService(settings) self._openai = openai.AsyncOpenAI( base_url=settings.llm_api_url, api_key=settings.llm_api_key, ) self._redis = redis async def _load_history(self, conversation_id: str) -> list[dict[str, str]]: """Load conversation history from Redis. Returns empty list on miss.""" if not self._redis: return [] try: raw = await self._redis.get(_redis_key(conversation_id)) if raw: return json.loads(raw) except Exception: logger.warning("chat_history_load_error cid=%s", conversation_id, exc_info=True) return [] async def _save_history( self, conversation_id: str, history: list[dict[str, str]], user_msg: str, assistant_msg: str, ) -> None: """Append the new turn pair and persist to Redis with TTL refresh.""" if not self._redis: return history.append({"role": "user", "content": user_msg}) history.append({"role": "assistant", "content": assistant_msg}) # Cap at _MAX_TURN_PAIRS (keep most recent) if len(history) > _MAX_TURN_PAIRS * 2: history = history[-_MAX_TURN_PAIRS * 2:] try: await self._redis.set( _redis_key(conversation_id), json.dumps(history), ex=_HISTORY_TTL_SECONDS, ) except Exception: logger.warning("chat_history_save_error cid=%s", conversation_id, exc_info=True) async def stream_response( self, query: str, db: AsyncSession, creator: str | None = None, conversation_id: str | None = None, ) -> AsyncIterator[str]: """Yield SSE-formatted events for a chat query. Protocol: 1. ``event: sources\ndata: \n\n`` 2. ``event: token\ndata: \n\n`` (repeated) 3. ``event: done\ndata: \n\n`` On error: ``event: error\ndata: \n\n`` """ start = time.monotonic() # Assign conversation_id if not provided (single-turn becomes trackable) if conversation_id is None: conversation_id = str(uuid.uuid4()) # ── 0. Load conversation history ──────────────────────────────── history = await self._load_history(conversation_id) # ── 1. Retrieve context via search ────────────────────────────── try: search_result = await self._search.search( query=query, scope="all", limit=_MAX_CONTEXT_SOURCES, db=db, creator=creator, ) except Exception: logger.exception("chat_search_error query=%r creator=%r", query, creator) yield _sse("error", {"message": "Search failed"}) return items: list[dict[str, Any]] = search_result.get("items", []) cascade_tier: str = search_result.get("cascade_tier", "") # ── 2. Build citation metadata and context block ──────────────── sources = _build_sources(items) context_block = _build_context_block(items) logger.info( "chat_search query=%r creator=%r cascade_tier=%s source_count=%d cid=%s", query, creator, cascade_tier, len(sources), conversation_id, ) # Emit sources event first yield _sse("sources", sources) # ── 3. Stream LLM completion ──────────────────────────────────── system_prompt = _SYSTEM_PROMPT_TEMPLATE.format(context_block=context_block) messages: list[dict[str, str]] = [ {"role": "system", "content": system_prompt}, ] # Inject conversation history between system prompt and current query messages.extend(history) messages.append({"role": "user", "content": query}) accumulated_response = "" try: stream = await self._openai.chat.completions.create( model=self.settings.llm_model, messages=messages, stream=True, temperature=0.3, max_tokens=2048, ) async for chunk in stream: choice = chunk.choices[0] if chunk.choices else None if choice and choice.delta and choice.delta.content: text = choice.delta.content accumulated_response += text yield _sse("token", text) except Exception: tb = traceback.format_exc() logger.error("chat_llm_error query=%r cid=%s\n%s", query, conversation_id, tb) yield _sse("error", {"message": "LLM generation failed"}) return # ── 4. Save conversation history ──────────────────────────────── await self._save_history(conversation_id, history, query, accumulated_response) # ── 5. Done event ─────────────────────────────────────────────── latency_ms = (time.monotonic() - start) * 1000 logger.info( "chat_done query=%r creator=%r cascade_tier=%s source_count=%d latency_ms=%.1f cid=%s", query, creator, cascade_tier, len(sources), latency_ms, conversation_id, ) yield _sse("done", {"cascade_tier": cascade_tier, "conversation_id": conversation_id}) # ── Helpers ────────────────────────────────────────────────────────────────── def _sse(event: str, data: Any) -> str: """Format a single SSE event string.""" payload = json.dumps(data) if not isinstance(data, str) else data return f"event: {event}\ndata: {payload}\n\n" def _build_sources(items: list[dict[str, Any]]) -> list[dict[str, str]]: """Build a numbered citation metadata list from search result items.""" sources: list[dict[str, str]] = [] for idx, item in enumerate(items, start=1): sources.append({ "number": idx, "title": item.get("title", ""), "slug": item.get("technique_page_slug", "") or item.get("slug", ""), "creator_name": item.get("creator_name", ""), "topic_category": item.get("topic_category", ""), "summary": (item.get("summary", "") or "")[:200], "section_anchor": item.get("section_anchor", ""), "section_heading": item.get("section_heading", ""), }) return sources def _build_context_block(items: list[dict[str, Any]]) -> str: """Build a numbered context block string for the LLM system prompt.""" if not items: return "(No sources available)" lines: list[str] = [] for idx, item in enumerate(items, start=1): title = item.get("title", "Untitled") creator = item.get("creator_name", "") summary = item.get("summary", "") section = item.get("section_heading", "") parts = [f"[{idx}] {title}"] if creator: parts.append(f"by {creator}") if section: parts.append(f"— {section}") header = " ".join(parts) lines.append(header) if summary: lines.append(f" {summary}") lines.append("") return "\n".join(lines)