"""Chat service: retrieve context via search, stream LLM response as SSE events. Assembles a numbered context block from search results, then streams completion tokens from an OpenAI-compatible API. Yields SSE-formatted events: sources, token, done, and error. Multi-turn memory: When a conversation_id is provided, prior messages are loaded from Redis, injected into the LLM messages array, and the new user+assistant turn is appended after streaming completes. History is capped at 10 turn pairs (20 messages) and expires after 1 hour of inactivity. """ from __future__ import annotations import json import logging import time import traceback import uuid from typing import Any, AsyncIterator import openai from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from config import Settings from models import Creator from search_service import SearchService logger = logging.getLogger("chrysopedia.chat") _SYSTEM_PROMPT_TEMPLATE = """\ You are Chrysopedia, an expert assistant for music production techniques — \ synthesis, sound design, mixing, sampling, and audio processing. ## Rules - Use ONLY the numbered sources below. Do not invent facts. - Cite every factual claim inline with [N] immediately after the claim \ (e.g. "Parallel compression adds sustain [2] while preserving transients [1]."). - When sources disagree, present both perspectives with their citations. - If the sources lack enough information, say so honestly. ## Response format - Aim for 2–4 short paragraphs. Expand only when the question warrants detail. - Use bullet lists for steps, signal chains, or parameter lists. - **Bold** key terms on first mention. - Use audio/synthesis/mixing terminology naturally — do not over-explain \ standard concepts (e.g. LFO, sidechain, wet/dry) unless the user asks. Sources: {context_block} """ _MAX_CONTEXT_SOURCES = 10 _MAX_TURN_PAIRS = 10 _HISTORY_TTL_SECONDS = 3600 # 1 hour def _redis_key(conversation_id: str) -> str: return f"chrysopedia:chat:{conversation_id}" class ChatService: """Retrieve context from search, stream an LLM response with citations.""" def __init__(self, settings: Settings, redis=None) -> None: self.settings = settings self._search = SearchService(settings) self._openai = openai.AsyncOpenAI( base_url=settings.llm_api_url, api_key=settings.llm_api_key, ) self._fallback_openai = openai.AsyncOpenAI( base_url=settings.llm_fallback_url, api_key=settings.llm_api_key, ) self._redis = redis async def _load_history(self, conversation_id: str) -> list[dict[str, str]]: """Load conversation history from Redis. Returns empty list on miss.""" if not self._redis: return [] try: raw = await self._redis.get(_redis_key(conversation_id)) if raw: return json.loads(raw) except Exception: logger.warning("chat_history_load_error cid=%s", conversation_id, exc_info=True) return [] async def _save_history( self, conversation_id: str, history: list[dict[str, str]], user_msg: str, assistant_msg: str, ) -> None: """Append the new turn pair and persist to Redis with TTL refresh.""" if not self._redis: return history.append({"role": "user", "content": user_msg}) history.append({"role": "assistant", "content": assistant_msg}) # Cap at _MAX_TURN_PAIRS (keep most recent) if len(history) > _MAX_TURN_PAIRS * 2: history = history[-_MAX_TURN_PAIRS * 2:] try: await self._redis.set( _redis_key(conversation_id), json.dumps(history), ex=_HISTORY_TTL_SECONDS, ) except Exception: logger.warning("chat_history_save_error cid=%s", conversation_id, exc_info=True) async def _inject_personality( self, system_prompt: str, db: AsyncSession, creator_name: str, weight: float, ) -> str: """Query creator personality_profile and append a voice block to the system prompt. Falls back to the unmodified prompt on DB error, missing creator, or null profile. """ try: result = await db.execute( select(Creator).where(Creator.name == creator_name) ) creator_row = result.scalars().first() except Exception: logger.warning("chat_personality_db_error creator=%r", creator_name, exc_info=True) return system_prompt if creator_row is None or creator_row.personality_profile is None: logger.debug("chat_personality_skip creator=%r reason=%s", creator_name, "not_found" if creator_row is None else "null_profile") return system_prompt profile = creator_row.personality_profile voice_block = _build_personality_block(creator_name, profile, weight) return system_prompt + "\n\n" + voice_block async def _log_usage( self, db: AsyncSession, user_id: Any | None, client_ip: str | None, creator_slug: str | None, query: str, usage: dict[str, int], cascade_tier: str, model: str, latency_ms: float, ) -> None: """Insert a ChatUsageLog row. Non-blocking — errors logged, not raised.""" try: from models import ChatUsageLog log_entry = ChatUsageLog( user_id=user_id, client_ip=client_ip, creator_slug=creator_slug, query=query[:2000], # truncate very long queries prompt_tokens=usage.get("prompt_tokens", 0), completion_tokens=usage.get("completion_tokens", 0), total_tokens=usage.get("total_tokens", 0), cascade_tier=cascade_tier, model=model, latency_ms=latency_ms, ) db.add(log_entry) await db.commit() except Exception: logger.error( "chat_usage_log_insert_error user=%s ip=%s", user_id, client_ip, exc_info=True, ) try: await db.rollback() except Exception: pass async def stream_response( self, query: str, db: AsyncSession, creator: str | None = None, conversation_id: str | None = None, personality_weight: float = 0.0, user_id: Any | None = None, client_ip: str | None = None, ) -> AsyncIterator[str]: """Yield SSE-formatted events for a chat query. Protocol: 1. ``event: sources\ndata: \n\n`` 2. ``event: token\ndata: \n\n`` (repeated) 3. ``event: done\ndata: \n\n`` On error: ``event: error\ndata: \n\n`` """ start = time.monotonic() # Assign conversation_id if not provided (single-turn becomes trackable) if conversation_id is None: conversation_id = str(uuid.uuid4()) # ── 0. Load conversation history ──────────────────────────────── history = await self._load_history(conversation_id) # ── 1. Retrieve context via search ────────────────────────────── try: search_result = await self._search.search( query=query, scope="all", limit=_MAX_CONTEXT_SOURCES, db=db, creator=creator, ) except Exception: logger.exception("chat_search_error query=%r creator=%r", query, creator) yield _sse("error", {"message": "Search failed"}) return items: list[dict[str, Any]] = search_result.get("items", []) cascade_tier: str = search_result.get("cascade_tier", "") # ── 2. Build citation metadata and context block ──────────────── sources = _build_sources(items) context_block = _build_context_block(items) logger.info( "chat_search query=%r creator=%r cascade_tier=%s source_count=%d cid=%s", query, creator, cascade_tier, len(sources), conversation_id, ) # Emit sources event first yield _sse("sources", sources) # ── 3. Stream LLM completion ──────────────────────────────────── system_prompt = _SYSTEM_PROMPT_TEMPLATE.format(context_block=context_block) # Inject creator personality voice when weight > 0 if personality_weight > 0 and creator: system_prompt = await self._inject_personality( system_prompt, db, creator, personality_weight, ) # Scale temperature with personality weight: 0.3 (encyclopedic) → 0.5 (full personality) temperature = 0.3 + (personality_weight * 0.2) messages: list[dict[str, str]] = [ {"role": "system", "content": system_prompt}, ] # Inject conversation history between system prompt and current query messages.extend(history) messages.append({"role": "user", "content": query}) accumulated_response = "" usage_data: dict[str, int] | None = None fallback_used = False try: stream = await self._openai.chat.completions.create( model=self.settings.llm_model, messages=messages, stream=True, stream_options={"include_usage": True}, temperature=temperature, max_tokens=2048, ) async for chunk in stream: # The final chunk with stream_options carries usage in chunk.usage if hasattr(chunk, "usage") and chunk.usage is not None: usage_data = { "prompt_tokens": chunk.usage.prompt_tokens or 0, "completion_tokens": chunk.usage.completion_tokens or 0, "total_tokens": chunk.usage.total_tokens or 0, } choice = chunk.choices[0] if chunk.choices else None if choice and choice.delta and choice.delta.content: text = choice.delta.content accumulated_response += text yield _sse("token", text) except (openai.APIConnectionError, openai.APITimeoutError, openai.InternalServerError) as exc: logger.warning( "chat_llm_fallback primary failed (%s: %s), retrying with fallback at %s", type(exc).__name__, exc, self.settings.llm_fallback_url, ) fallback_used = True accumulated_response = "" usage_data = None try: stream = await self._fallback_openai.chat.completions.create( model=self.settings.llm_fallback_model, messages=messages, stream=True, stream_options={"include_usage": True}, temperature=temperature, max_tokens=2048, ) async for chunk in stream: if hasattr(chunk, "usage") and chunk.usage is not None: usage_data = { "prompt_tokens": chunk.usage.prompt_tokens or 0, "completion_tokens": chunk.usage.completion_tokens or 0, "total_tokens": chunk.usage.total_tokens or 0, } choice = chunk.choices[0] if chunk.choices else None if choice and choice.delta and choice.delta.content: text = choice.delta.content accumulated_response += text yield _sse("token", text) except Exception: tb = traceback.format_exc() logger.error("chat_llm_error fallback also failed query=%r cid=%s\n%s", query, conversation_id, tb) yield _sse("error", {"message": "LLM generation failed"}) return except Exception: tb = traceback.format_exc() logger.error("chat_llm_error query=%r cid=%s\n%s", query, conversation_id, tb) yield _sse("error", {"message": "LLM generation failed"}) return # ── 4. Save conversation history ──────────────────────────────── await self._save_history(conversation_id, history, query, accumulated_response) # ── 5. Log token usage ────────────────────────────────────────── latency_ms = (time.monotonic() - start) * 1000 # Fallback: estimate tokens from character counts if stream_options not available if usage_data is None: prompt_chars = sum(len(m.get("content", "")) for m in messages) est_prompt = prompt_chars // 4 est_completion = len(accumulated_response) // 4 usage_data = { "prompt_tokens": est_prompt, "completion_tokens": est_completion, "total_tokens": est_prompt + est_completion, } logger.warning("chat_usage_estimated cid=%s (stream_options usage not available)", conversation_id) await self._log_usage( db=db, user_id=user_id, client_ip=client_ip, creator_slug=creator, query=query, usage=usage_data, cascade_tier=cascade_tier, model=self.settings.llm_fallback_model if fallback_used else self.settings.llm_model, latency_ms=latency_ms, ) # ── 6. Done event ─────────────────────────────────────────────── logger.info( "chat_done query=%r creator=%r cascade_tier=%s source_count=%d latency_ms=%.1f cid=%s tokens=%d", query, creator, cascade_tier, len(sources), latency_ms, conversation_id, usage_data.get("total_tokens", 0), ) yield _sse("done", {"cascade_tier": cascade_tier, "conversation_id": conversation_id, "fallback_used": fallback_used}) # ── Helpers ────────────────────────────────────────────────────────────────── def _sse(event: str, data: Any) -> str: """Format a single SSE event string.""" payload = json.dumps(data) if not isinstance(data, str) else data return f"event: {event}\ndata: {payload}\n\n" def _build_sources(items: list[dict[str, Any]]) -> list[dict[str, str]]: """Build a numbered citation metadata list from search result items.""" sources: list[dict[str, str]] = [] for idx, item in enumerate(items, start=1): sources.append({ "number": idx, "title": item.get("title", ""), "slug": item.get("technique_page_slug", "") or item.get("slug", ""), "creator_name": item.get("creator_name", ""), "topic_category": item.get("topic_category", ""), "summary": (item.get("summary", "") or "")[:200], "section_anchor": item.get("section_anchor", ""), "section_heading": item.get("section_heading", ""), "source_video_id": item.get("source_video_id", ""), "start_time": item.get("start_time"), "end_time": item.get("end_time"), "video_filename": item.get("video_filename", ""), }) return sources def _build_context_block(items: list[dict[str, Any]]) -> str: """Build a numbered context block string for the LLM system prompt.""" if not items: return "(No sources available)" lines: list[str] = [] for idx, item in enumerate(items, start=1): title = item.get("title", "Untitled") creator = item.get("creator_name", "") summary = item.get("summary", "") section = item.get("section_heading", "") parts = [f"[{idx}] {title}"] if creator: parts.append(f"by {creator}") if section: parts.append(f"— {section}") header = " ".join(parts) lines.append(header) if summary: lines.append(f" {summary}") lines.append("") return "\n".join(lines) def _build_personality_block(creator_name: str, profile: dict[str, Any], weight: float) -> str: """Build a personality voice injection block from a creator's personality_profile JSONB. The ``weight`` (0.0–1.0) controls progressive inclusion of personality fields via 5 tiers of continuous interpolation: - < 0.2: no personality block (empty string) - 0.2–0.39: basic tone — teaching_style, formality, energy + subtle hint - 0.4–0.59: + descriptors, explanation_approach + adopt-voice instruction - 0.6–0.79: + signature_phrases (count scaled by weight) + creator-voice - 0.8–0.89: + distinctive_terms, sound_descriptions, sound_words, self_references, pacing + fully-embody instruction - >= 0.9: + full summary paragraph """ if weight < 0.2: return "" vocab = profile.get("vocabulary", {}) tone = profile.get("tone", {}) style = profile.get("style_markers", {}) teaching_style = tone.get("teaching_style", "") energy = tone.get("energy", "moderate") formality = tone.get("formality", "conversational") descriptors = tone.get("descriptors", []) phrases = vocab.get("signature_phrases", []) parts: list[str] = [] # --- Tier 1 (0.2–0.39): basic tone --- if weight < 0.4: parts.append( f"When relevant, subtly reference {creator_name}'s communication style." ) elif weight < 0.6: parts.append(f"Adopt {creator_name}'s tone and communication style.") elif weight < 0.8: parts.append( f"Respond as {creator_name} would, using their voice and manner." ) else: parts.append( f"Fully embody {creator_name} — use their exact phrases, energy, and teaching approach." ) if teaching_style: parts.append(f"Teaching style: {teaching_style}.") parts.append(f"Match their {formality} {energy} tone.") # --- Tier 2 (0.4+): descriptors, explanation_approach, uses_analogies, audience_engagement --- if weight >= 0.4: if descriptors: parts.append(f"Tone: {', '.join(descriptors[:5])}.") explanation = style.get("explanation_approach", "") if explanation: parts.append(f"Explanation approach: {explanation}.") if style.get("uses_analogies"): parts.append("Use analogies when helpful.") if style.get("audience_engagement"): parts.append(f"Audience engagement: {style['audience_engagement']}.") # --- Tier 3 (0.6+): signature phrases (count scaled by weight) --- if weight >= 0.6 and phrases: count = max(2, round(weight * len(phrases))) parts.append(f"Use their signature phrases: {', '.join(phrases[:count])}.") # --- Tier 4 (0.8+): distinctive_terms, sound_descriptions, sound_words, self_references, pacing --- if weight >= 0.8: distinctive = vocab.get("distinctive_terms", []) if distinctive: parts.append(f"Distinctive terms: {', '.join(distinctive)}.") sound_desc = vocab.get("sound_descriptions", []) if sound_desc: parts.append(f"Sound descriptions: {', '.join(sound_desc)}.") sound_words = style.get("sound_words", []) if sound_words: parts.append(f"Sound words: {', '.join(sound_words)}.") self_refs = style.get("self_references", "") if self_refs: parts.append(f"Self-references: {self_refs}.") pacing = style.get("pacing", "") if pacing: parts.append(f"Pacing: {pacing}.") # --- Tier 5 (0.9+): full summary paragraph --- if weight >= 0.9: summary = profile.get("summary", "") if summary: parts.append(summary) return " ".join(parts)