"""Chat service: retrieve context via search, stream LLM response as SSE events. Assembles a numbered context block from search results, then streams completion tokens from an OpenAI-compatible API. Yields SSE-formatted events: sources, token, done, and error. Multi-turn memory: When a conversation_id is provided, prior messages are loaded from Redis, injected into the LLM messages array, and the new user+assistant turn is appended after streaming completes. History is capped at 10 turn pairs (20 messages) and expires after 1 hour of inactivity. """ from __future__ import annotations import json import logging import time import traceback import uuid from typing import Any, AsyncIterator import openai from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from config import Settings from models import Creator from search_service import SearchService logger = logging.getLogger("chrysopedia.chat") _SYSTEM_PROMPT_TEMPLATE = """\ You are Chrysopedia, an expert encyclopedic assistant for music production techniques. Answer the user's question using ONLY the numbered sources below. Cite sources by writing [N] inline (e.g. [1], [2]) where N is the source number. If the sources do not contain enough information, say so honestly — do not invent facts. Sources: {context_block} """ _MAX_CONTEXT_SOURCES = 10 _MAX_TURN_PAIRS = 10 _HISTORY_TTL_SECONDS = 3600 # 1 hour def _redis_key(conversation_id: str) -> str: return f"chrysopedia:chat:{conversation_id}" class ChatService: """Retrieve context from search, stream an LLM response with citations.""" def __init__(self, settings: Settings, redis=None) -> None: self.settings = settings self._search = SearchService(settings) self._openai = openai.AsyncOpenAI( base_url=settings.llm_api_url, api_key=settings.llm_api_key, ) self._redis = redis async def _load_history(self, conversation_id: str) -> list[dict[str, str]]: """Load conversation history from Redis. Returns empty list on miss.""" if not self._redis: return [] try: raw = await self._redis.get(_redis_key(conversation_id)) if raw: return json.loads(raw) except Exception: logger.warning("chat_history_load_error cid=%s", conversation_id, exc_info=True) return [] async def _save_history( self, conversation_id: str, history: list[dict[str, str]], user_msg: str, assistant_msg: str, ) -> None: """Append the new turn pair and persist to Redis with TTL refresh.""" if not self._redis: return history.append({"role": "user", "content": user_msg}) history.append({"role": "assistant", "content": assistant_msg}) # Cap at _MAX_TURN_PAIRS (keep most recent) if len(history) > _MAX_TURN_PAIRS * 2: history = history[-_MAX_TURN_PAIRS * 2:] try: await self._redis.set( _redis_key(conversation_id), json.dumps(history), ex=_HISTORY_TTL_SECONDS, ) except Exception: logger.warning("chat_history_save_error cid=%s", conversation_id, exc_info=True) async def _inject_personality( self, system_prompt: str, db: AsyncSession, creator_name: str, weight: float, ) -> str: """Query creator personality_profile and append a voice block to the system prompt. Falls back to the unmodified prompt on DB error, missing creator, or null profile. """ try: result = await db.execute( select(Creator).where(Creator.name == creator_name) ) creator_row = result.scalars().first() except Exception: logger.warning("chat_personality_db_error creator=%r", creator_name, exc_info=True) return system_prompt if creator_row is None or creator_row.personality_profile is None: logger.debug("chat_personality_skip creator=%r reason=%s", creator_name, "not_found" if creator_row is None else "null_profile") return system_prompt profile = creator_row.personality_profile voice_block = _build_personality_block(creator_name, profile, weight) return system_prompt + "\n\n" + voice_block async def stream_response( self, query: str, db: AsyncSession, creator: str | None = None, conversation_id: str | None = None, personality_weight: float = 0.0, ) -> AsyncIterator[str]: """Yield SSE-formatted events for a chat query. Protocol: 1. ``event: sources\ndata: \n\n`` 2. ``event: token\ndata: \n\n`` (repeated) 3. ``event: done\ndata: \n\n`` On error: ``event: error\ndata: \n\n`` """ start = time.monotonic() # Assign conversation_id if not provided (single-turn becomes trackable) if conversation_id is None: conversation_id = str(uuid.uuid4()) # ── 0. Load conversation history ──────────────────────────────── history = await self._load_history(conversation_id) # ── 1. Retrieve context via search ────────────────────────────── try: search_result = await self._search.search( query=query, scope="all", limit=_MAX_CONTEXT_SOURCES, db=db, creator=creator, ) except Exception: logger.exception("chat_search_error query=%r creator=%r", query, creator) yield _sse("error", {"message": "Search failed"}) return items: list[dict[str, Any]] = search_result.get("items", []) cascade_tier: str = search_result.get("cascade_tier", "") # ── 2. Build citation metadata and context block ──────────────── sources = _build_sources(items) context_block = _build_context_block(items) logger.info( "chat_search query=%r creator=%r cascade_tier=%s source_count=%d cid=%s", query, creator, cascade_tier, len(sources), conversation_id, ) # Emit sources event first yield _sse("sources", sources) # ── 3. Stream LLM completion ──────────────────────────────────── system_prompt = _SYSTEM_PROMPT_TEMPLATE.format(context_block=context_block) # Inject creator personality voice when weight > 0 if personality_weight > 0 and creator: system_prompt = await self._inject_personality( system_prompt, db, creator, personality_weight, ) # Scale temperature with personality weight: 0.3 (encyclopedic) → 0.5 (full personality) temperature = 0.3 + (personality_weight * 0.2) messages: list[dict[str, str]] = [ {"role": "system", "content": system_prompt}, ] # Inject conversation history between system prompt and current query messages.extend(history) messages.append({"role": "user", "content": query}) accumulated_response = "" try: stream = await self._openai.chat.completions.create( model=self.settings.llm_model, messages=messages, stream=True, temperature=temperature, max_tokens=2048, ) async for chunk in stream: choice = chunk.choices[0] if chunk.choices else None if choice and choice.delta and choice.delta.content: text = choice.delta.content accumulated_response += text yield _sse("token", text) except Exception: tb = traceback.format_exc() logger.error("chat_llm_error query=%r cid=%s\n%s", query, conversation_id, tb) yield _sse("error", {"message": "LLM generation failed"}) return # ── 4. Save conversation history ──────────────────────────────── await self._save_history(conversation_id, history, query, accumulated_response) # ── 5. Done event ─────────────────────────────────────────────── latency_ms = (time.monotonic() - start) * 1000 logger.info( "chat_done query=%r creator=%r cascade_tier=%s source_count=%d latency_ms=%.1f cid=%s", query, creator, cascade_tier, len(sources), latency_ms, conversation_id, ) yield _sse("done", {"cascade_tier": cascade_tier, "conversation_id": conversation_id}) # ── Helpers ────────────────────────────────────────────────────────────────── def _sse(event: str, data: Any) -> str: """Format a single SSE event string.""" payload = json.dumps(data) if not isinstance(data, str) else data return f"event: {event}\ndata: {payload}\n\n" def _build_sources(items: list[dict[str, Any]]) -> list[dict[str, str]]: """Build a numbered citation metadata list from search result items.""" sources: list[dict[str, str]] = [] for idx, item in enumerate(items, start=1): sources.append({ "number": idx, "title": item.get("title", ""), "slug": item.get("technique_page_slug", "") or item.get("slug", ""), "creator_name": item.get("creator_name", ""), "topic_category": item.get("topic_category", ""), "summary": (item.get("summary", "") or "")[:200], "section_anchor": item.get("section_anchor", ""), "section_heading": item.get("section_heading", ""), }) return sources def _build_context_block(items: list[dict[str, Any]]) -> str: """Build a numbered context block string for the LLM system prompt.""" if not items: return "(No sources available)" lines: list[str] = [] for idx, item in enumerate(items, start=1): title = item.get("title", "Untitled") creator = item.get("creator_name", "") summary = item.get("summary", "") section = item.get("section_heading", "") parts = [f"[{idx}] {title}"] if creator: parts.append(f"by {creator}") if section: parts.append(f"— {section}") header = " ".join(parts) lines.append(header) if summary: lines.append(f" {summary}") lines.append("") return "\n".join(lines) def _build_personality_block(creator_name: str, profile: dict[str, Any], weight: float) -> str: """Build a personality voice injection block from a creator's personality_profile JSONB. The ``weight`` (0.0–1.0) controls progressive inclusion of personality fields via 5 tiers of continuous interpolation: - < 0.2: no personality block (empty string) - 0.2–0.39: basic tone — teaching_style, formality, energy + subtle hint - 0.4–0.59: + descriptors, explanation_approach + adopt-voice instruction - 0.6–0.79: + signature_phrases (count scaled by weight) + creator-voice - 0.8–0.89: + distinctive_terms, sound_descriptions, sound_words, self_references, pacing + fully-embody instruction - >= 0.9: + full summary paragraph """ if weight < 0.2: return "" vocab = profile.get("vocabulary", {}) tone = profile.get("tone", {}) style = profile.get("style_markers", {}) teaching_style = tone.get("teaching_style", "") energy = tone.get("energy", "moderate") formality = tone.get("formality", "conversational") descriptors = tone.get("descriptors", []) phrases = vocab.get("signature_phrases", []) parts: list[str] = [] # --- Tier 1 (0.2–0.39): basic tone --- if weight < 0.4: parts.append( f"When relevant, subtly reference {creator_name}'s communication style." ) elif weight < 0.6: parts.append(f"Adopt {creator_name}'s tone and communication style.") elif weight < 0.8: parts.append( f"Respond as {creator_name} would, using their voice and manner." ) else: parts.append( f"Fully embody {creator_name} — use their exact phrases, energy, and teaching approach." ) if teaching_style: parts.append(f"Teaching style: {teaching_style}.") parts.append(f"Match their {formality} {energy} tone.") # --- Tier 2 (0.4+): descriptors, explanation_approach, uses_analogies, audience_engagement --- if weight >= 0.4: if descriptors: parts.append(f"Tone: {', '.join(descriptors[:5])}.") explanation = style.get("explanation_approach", "") if explanation: parts.append(f"Explanation approach: {explanation}.") if style.get("uses_analogies"): parts.append("Use analogies when helpful.") if style.get("audience_engagement"): parts.append(f"Audience engagement: {style['audience_engagement']}.") # --- Tier 3 (0.6+): signature phrases (count scaled by weight) --- if weight >= 0.6 and phrases: count = max(2, round(weight * len(phrases))) parts.append(f"Use their signature phrases: {', '.join(phrases[:count])}.") # --- Tier 4 (0.8+): distinctive_terms, sound_descriptions, sound_words, self_references, pacing --- if weight >= 0.8: distinctive = vocab.get("distinctive_terms", []) if distinctive: parts.append(f"Distinctive terms: {', '.join(distinctive)}.") sound_desc = vocab.get("sound_descriptions", []) if sound_desc: parts.append(f"Sound descriptions: {', '.join(sound_desc)}.") sound_words = style.get("sound_words", []) if sound_words: parts.append(f"Sound words: {', '.join(sound_words)}.") self_refs = style.get("self_references", "") if self_refs: parts.append(f"Self-references: {self_refs}.") pacing = style.get("pacing", "") if pacing: parts.append(f"Pacing: {pacing}.") # --- Tier 5 (0.9+): full summary paragraph --- if weight >= 0.9: summary = profile.get("summary", "") if summary: parts.append(summary) return " ".join(parts)