chrysopedia/backend/chat_service.py
jlightner 3cbb614654 test: Rewrote _SYSTEM_PROMPT_TEMPLATE with citation density rules, resp…
- "backend/chat_service.py"

GSD-Task: S09/T02
2026-04-04 14:45:09 +00:00

519 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Chat service: retrieve context via search, stream LLM response as SSE events.
Assembles a numbered context block from search results, then streams
completion tokens from an OpenAI-compatible API. Yields SSE-formatted
events: sources, token, done, and error.
Multi-turn memory: When a conversation_id is provided, prior messages are
loaded from Redis, injected into the LLM messages array, and the new
user+assistant turn is appended after streaming completes. History is
capped at 10 turn pairs (20 messages) and expires after 1 hour of
inactivity.
"""
from __future__ import annotations
import json
import logging
import time
import traceback
import uuid
from typing import Any, AsyncIterator
import openai
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from config import Settings
from models import Creator
from search_service import SearchService
logger = logging.getLogger("chrysopedia.chat")
_SYSTEM_PROMPT_TEMPLATE = """\
You are Chrysopedia, an expert assistant for music production techniques — \
synthesis, sound design, mixing, sampling, and audio processing.
## Rules
- Use ONLY the numbered sources below. Do not invent facts.
- Cite every factual claim inline with [N] immediately after the claim \
(e.g. "Parallel compression adds sustain [2] while preserving transients [1].").
- When sources disagree, present both perspectives with their citations.
- If the sources lack enough information, say so honestly.
## Response format
- Aim for 24 short paragraphs. Expand only when the question warrants detail.
- Use bullet lists for steps, signal chains, or parameter lists.
- **Bold** key terms on first mention.
- Use audio/synthesis/mixing terminology naturally — do not over-explain \
standard concepts (e.g. LFO, sidechain, wet/dry) unless the user asks.
Sources:
{context_block}
"""
_MAX_CONTEXT_SOURCES = 10
_MAX_TURN_PAIRS = 10
_HISTORY_TTL_SECONDS = 3600 # 1 hour
def _redis_key(conversation_id: str) -> str:
return f"chrysopedia:chat:{conversation_id}"
class ChatService:
    """Retrieve context from search, stream an LLM response with citations."""

    def __init__(self, settings: Settings, redis: Any = None) -> None:
        # Application settings: LLM endpoints, API key, and model names.
        self.settings = settings
        # Search backend used to retrieve the numbered context sources.
        self._search = SearchService(settings)
        # Primary OpenAI-compatible streaming client.
        self._openai = openai.AsyncOpenAI(
            base_url=settings.llm_api_url,
            api_key=settings.llm_api_key,
        )
        # Fallback client used when the primary endpoint fails.
        # NOTE(review): reuses the primary api_key — confirm the fallback
        # endpoint accepts the same credential.
        self._fallback_openai = openai.AsyncOpenAI(
            base_url=settings.llm_fallback_url,
            api_key=settings.llm_api_key,
        )
        # Optional async Redis client for multi-turn memory; None disables it.
        self._redis = redis
async def _load_history(self, conversation_id: str) -> list[dict[str, str]]:
"""Load conversation history from Redis. Returns empty list on miss."""
if not self._redis:
return []
try:
raw = await self._redis.get(_redis_key(conversation_id))
if raw:
return json.loads(raw)
except Exception:
logger.warning("chat_history_load_error cid=%s", conversation_id, exc_info=True)
return []
async def _save_history(
self,
conversation_id: str,
history: list[dict[str, str]],
user_msg: str,
assistant_msg: str,
) -> None:
"""Append the new turn pair and persist to Redis with TTL refresh."""
if not self._redis:
return
history.append({"role": "user", "content": user_msg})
history.append({"role": "assistant", "content": assistant_msg})
# Cap at _MAX_TURN_PAIRS (keep most recent)
if len(history) > _MAX_TURN_PAIRS * 2:
history = history[-_MAX_TURN_PAIRS * 2:]
try:
await self._redis.set(
_redis_key(conversation_id),
json.dumps(history),
ex=_HISTORY_TTL_SECONDS,
)
except Exception:
logger.warning("chat_history_save_error cid=%s", conversation_id, exc_info=True)
async def _inject_personality(
self,
system_prompt: str,
db: AsyncSession,
creator_name: str,
weight: float,
) -> str:
"""Query creator personality_profile and append a voice block to the system prompt.
Falls back to the unmodified prompt on DB error, missing creator, or null profile.
"""
try:
result = await db.execute(
select(Creator).where(Creator.name == creator_name)
)
creator_row = result.scalars().first()
except Exception:
logger.warning("chat_personality_db_error creator=%r", creator_name, exc_info=True)
return system_prompt
if creator_row is None or creator_row.personality_profile is None:
logger.debug("chat_personality_skip creator=%r reason=%s",
creator_name,
"not_found" if creator_row is None else "null_profile")
return system_prompt
profile = creator_row.personality_profile
voice_block = _build_personality_block(creator_name, profile, weight)
return system_prompt + "\n\n" + voice_block
async def _log_usage(
self,
db: AsyncSession,
user_id: Any | None,
client_ip: str | None,
creator_slug: str | None,
query: str,
usage: dict[str, int],
cascade_tier: str,
model: str,
latency_ms: float,
) -> None:
"""Insert a ChatUsageLog row. Non-blocking — errors logged, not raised."""
try:
from models import ChatUsageLog
log_entry = ChatUsageLog(
user_id=user_id,
client_ip=client_ip,
creator_slug=creator_slug,
query=query[:2000], # truncate very long queries
prompt_tokens=usage.get("prompt_tokens", 0),
completion_tokens=usage.get("completion_tokens", 0),
total_tokens=usage.get("total_tokens", 0),
cascade_tier=cascade_tier,
model=model,
latency_ms=latency_ms,
)
db.add(log_entry)
await db.commit()
except Exception:
logger.error(
"chat_usage_log_insert_error user=%s ip=%s",
user_id, client_ip, exc_info=True,
)
try:
await db.rollback()
except Exception:
pass
async def stream_response(
self,
query: str,
db: AsyncSession,
creator: str | None = None,
conversation_id: str | None = None,
personality_weight: float = 0.0,
user_id: Any | None = None,
client_ip: str | None = None,
) -> AsyncIterator[str]:
"""Yield SSE-formatted events for a chat query.
Protocol:
1. ``event: sources\ndata: <json array of citation metadata>\n\n``
2. ``event: token\ndata: <text chunk>\n\n`` (repeated)
3. ``event: done\ndata: <json with cascade_tier, conversation_id>\n\n``
On error: ``event: error\ndata: <json with message>\n\n``
"""
start = time.monotonic()
# Assign conversation_id if not provided (single-turn becomes trackable)
if conversation_id is None:
conversation_id = str(uuid.uuid4())
# ── 0. Load conversation history ────────────────────────────────
history = await self._load_history(conversation_id)
# ── 1. Retrieve context via search ──────────────────────────────
try:
search_result = await self._search.search(
query=query,
scope="all",
limit=_MAX_CONTEXT_SOURCES,
db=db,
creator=creator,
)
except Exception:
logger.exception("chat_search_error query=%r creator=%r", query, creator)
yield _sse("error", {"message": "Search failed"})
return
items: list[dict[str, Any]] = search_result.get("items", [])
cascade_tier: str = search_result.get("cascade_tier", "")
# ── 2. Build citation metadata and context block ────────────────
sources = _build_sources(items)
context_block = _build_context_block(items)
logger.info(
"chat_search query=%r creator=%r cascade_tier=%s source_count=%d cid=%s",
query, creator, cascade_tier, len(sources), conversation_id,
)
# Emit sources event first
yield _sse("sources", sources)
# ── 3. Stream LLM completion ────────────────────────────────────
system_prompt = _SYSTEM_PROMPT_TEMPLATE.format(context_block=context_block)
# Inject creator personality voice when weight > 0
if personality_weight > 0 and creator:
system_prompt = await self._inject_personality(
system_prompt, db, creator, personality_weight,
)
# Scale temperature with personality weight: 0.3 (encyclopedic) → 0.5 (full personality)
temperature = 0.3 + (personality_weight * 0.2)
messages: list[dict[str, str]] = [
{"role": "system", "content": system_prompt},
]
# Inject conversation history between system prompt and current query
messages.extend(history)
messages.append({"role": "user", "content": query})
accumulated_response = ""
usage_data: dict[str, int] | None = None
fallback_used = False
try:
stream = await self._openai.chat.completions.create(
model=self.settings.llm_model,
messages=messages,
stream=True,
stream_options={"include_usage": True},
temperature=temperature,
max_tokens=2048,
)
async for chunk in stream:
# The final chunk with stream_options carries usage in chunk.usage
if hasattr(chunk, "usage") and chunk.usage is not None:
usage_data = {
"prompt_tokens": chunk.usage.prompt_tokens or 0,
"completion_tokens": chunk.usage.completion_tokens or 0,
"total_tokens": chunk.usage.total_tokens or 0,
}
choice = chunk.choices[0] if chunk.choices else None
if choice and choice.delta and choice.delta.content:
text = choice.delta.content
accumulated_response += text
yield _sse("token", text)
except (openai.APIConnectionError, openai.APITimeoutError, openai.InternalServerError) as exc:
logger.warning(
"chat_llm_fallback primary failed (%s: %s), retrying with fallback at %s",
type(exc).__name__, exc, self.settings.llm_fallback_url,
)
fallback_used = True
accumulated_response = ""
usage_data = None
try:
stream = await self._fallback_openai.chat.completions.create(
model=self.settings.llm_fallback_model,
messages=messages,
stream=True,
stream_options={"include_usage": True},
temperature=temperature,
max_tokens=2048,
)
async for chunk in stream:
if hasattr(chunk, "usage") and chunk.usage is not None:
usage_data = {
"prompt_tokens": chunk.usage.prompt_tokens or 0,
"completion_tokens": chunk.usage.completion_tokens or 0,
"total_tokens": chunk.usage.total_tokens or 0,
}
choice = chunk.choices[0] if chunk.choices else None
if choice and choice.delta and choice.delta.content:
text = choice.delta.content
accumulated_response += text
yield _sse("token", text)
except Exception:
tb = traceback.format_exc()
logger.error("chat_llm_error fallback also failed query=%r cid=%s\n%s", query, conversation_id, tb)
yield _sse("error", {"message": "LLM generation failed"})
return
except Exception:
tb = traceback.format_exc()
logger.error("chat_llm_error query=%r cid=%s\n%s", query, conversation_id, tb)
yield _sse("error", {"message": "LLM generation failed"})
return
# ── 4. Save conversation history ────────────────────────────────
await self._save_history(conversation_id, history, query, accumulated_response)
# ── 5. Log token usage ──────────────────────────────────────────
latency_ms = (time.monotonic() - start) * 1000
# Fallback: estimate tokens from character counts if stream_options not available
if usage_data is None:
prompt_chars = sum(len(m.get("content", "")) for m in messages)
est_prompt = prompt_chars // 4
est_completion = len(accumulated_response) // 4
usage_data = {
"prompt_tokens": est_prompt,
"completion_tokens": est_completion,
"total_tokens": est_prompt + est_completion,
}
logger.warning("chat_usage_estimated cid=%s (stream_options usage not available)", conversation_id)
await self._log_usage(
db=db,
user_id=user_id,
client_ip=client_ip,
creator_slug=creator,
query=query,
usage=usage_data,
cascade_tier=cascade_tier,
model=self.settings.llm_fallback_model if fallback_used else self.settings.llm_model,
latency_ms=latency_ms,
)
# ── 6. Done event ───────────────────────────────────────────────
logger.info(
"chat_done query=%r creator=%r cascade_tier=%s source_count=%d latency_ms=%.1f cid=%s tokens=%d",
query, creator, cascade_tier, len(sources), latency_ms, conversation_id,
usage_data.get("total_tokens", 0),
)
yield _sse("done", {"cascade_tier": cascade_tier, "conversation_id": conversation_id, "fallback_used": fallback_used})
# ── Helpers ──────────────────────────────────────────────────────────────────
def _sse(event: str, data: Any) -> str:
"""Format a single SSE event string."""
payload = json.dumps(data) if not isinstance(data, str) else data
return f"event: {event}\ndata: {payload}\n\n"
def _build_sources(items: list[dict[str, Any]]) -> list[dict[str, str]]:
"""Build a numbered citation metadata list from search result items."""
sources: list[dict[str, str]] = []
for idx, item in enumerate(items, start=1):
sources.append({
"number": idx,
"title": item.get("title", ""),
"slug": item.get("technique_page_slug", "") or item.get("slug", ""),
"creator_name": item.get("creator_name", ""),
"topic_category": item.get("topic_category", ""),
"summary": (item.get("summary", "") or "")[:200],
"section_anchor": item.get("section_anchor", ""),
"section_heading": item.get("section_heading", ""),
"source_video_id": item.get("source_video_id", ""),
"start_time": item.get("start_time"),
"end_time": item.get("end_time"),
"video_filename": item.get("video_filename", ""),
})
return sources
def _build_context_block(items: list[dict[str, Any]]) -> str:
"""Build a numbered context block string for the LLM system prompt."""
if not items:
return "(No sources available)"
lines: list[str] = []
for idx, item in enumerate(items, start=1):
title = item.get("title", "Untitled")
creator = item.get("creator_name", "")
summary = item.get("summary", "")
section = item.get("section_heading", "")
parts = [f"[{idx}] {title}"]
if creator:
parts.append(f"by {creator}")
if section:
parts.append(f"{section}")
header = " ".join(parts)
lines.append(header)
if summary:
lines.append(f" {summary}")
lines.append("")
return "\n".join(lines)
def _build_personality_block(creator_name: str, profile: dict[str, Any], weight: float) -> str:
"""Build a personality voice injection block from a creator's personality_profile JSONB.
The ``weight`` (0.01.0) controls progressive inclusion of personality
fields via 5 tiers of continuous interpolation:
- < 0.2: no personality block (empty string)
- 0.20.39: basic tone — teaching_style, formality, energy + subtle hint
- 0.40.59: + descriptors, explanation_approach + adopt-voice instruction
- 0.60.79: + signature_phrases (count scaled by weight) + creator-voice
- 0.80.89: + distinctive_terms, sound_descriptions, sound_words,
self_references, pacing + fully-embody instruction
- >= 0.9: + full summary paragraph
"""
if weight < 0.2:
return ""
vocab = profile.get("vocabulary", {})
tone = profile.get("tone", {})
style = profile.get("style_markers", {})
teaching_style = tone.get("teaching_style", "")
energy = tone.get("energy", "moderate")
formality = tone.get("formality", "conversational")
descriptors = tone.get("descriptors", [])
phrases = vocab.get("signature_phrases", [])
parts: list[str] = []
# --- Tier 1 (0.20.39): basic tone ---
if weight < 0.4:
parts.append(
f"When relevant, subtly reference {creator_name}'s communication style."
)
elif weight < 0.6:
parts.append(f"Adopt {creator_name}'s tone and communication style.")
elif weight < 0.8:
parts.append(
f"Respond as {creator_name} would, using their voice and manner."
)
else:
parts.append(
f"Fully embody {creator_name} — use their exact phrases, energy, and teaching approach."
)
if teaching_style:
parts.append(f"Teaching style: {teaching_style}.")
parts.append(f"Match their {formality} {energy} tone.")
# --- Tier 2 (0.4+): descriptors, explanation_approach, uses_analogies, audience_engagement ---
if weight >= 0.4:
if descriptors:
parts.append(f"Tone: {', '.join(descriptors[:5])}.")
explanation = style.get("explanation_approach", "")
if explanation:
parts.append(f"Explanation approach: {explanation}.")
if style.get("uses_analogies"):
parts.append("Use analogies when helpful.")
if style.get("audience_engagement"):
parts.append(f"Audience engagement: {style['audience_engagement']}.")
# --- Tier 3 (0.6+): signature phrases (count scaled by weight) ---
if weight >= 0.6 and phrases:
count = max(2, round(weight * len(phrases)))
parts.append(f"Use their signature phrases: {', '.join(phrases[:count])}.")
# --- Tier 4 (0.8+): distinctive_terms, sound_descriptions, sound_words, self_references, pacing ---
if weight >= 0.8:
distinctive = vocab.get("distinctive_terms", [])
if distinctive:
parts.append(f"Distinctive terms: {', '.join(distinctive)}.")
sound_desc = vocab.get("sound_descriptions", [])
if sound_desc:
parts.append(f"Sound descriptions: {', '.join(sound_desc)}.")
sound_words = style.get("sound_words", [])
if sound_words:
parts.append(f"Sound words: {', '.join(sound_words)}.")
self_refs = style.get("self_references", "")
if self_refs:
parts.append(f"Self-references: {self_refs}.")
pacing = style.get("pacing", "")
if pacing:
parts.append(f"Pacing: {pacing}.")
# --- Tier 5 (0.9+): full summary paragraph ---
if weight >= 0.9:
summary = profile.get("summary", "")
if summary:
parts.append(summary)
return " ".join(parts)