386 lines
15 KiB
Python
386 lines
15 KiB
Python
"""Chat service: retrieve context via search, stream LLM response as SSE events.
|
||
|
||
Assembles a numbered context block from search results, then streams
|
||
completion tokens from an OpenAI-compatible API. Yields SSE-formatted
|
||
events: sources, token, done, and error.
|
||
|
||
Multi-turn memory: When a conversation_id is provided, prior messages are
|
||
loaded from Redis, injected into the LLM messages array, and the new
|
||
user+assistant turn is appended after streaming completes. History is
|
||
capped at 10 turn pairs (20 messages) and expires after 1 hour of
|
||
inactivity.
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
import time
|
||
import traceback
|
||
import uuid
|
||
from typing import Any, AsyncIterator
|
||
|
||
import openai
|
||
from sqlalchemy import select
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from config import Settings
|
||
from models import Creator
|
||
from search_service import SearchService
|
||
|
||
logger = logging.getLogger("chrysopedia.chat")
|
||
|
||
_SYSTEM_PROMPT_TEMPLATE = """\
|
||
You are Chrysopedia, an expert encyclopedic assistant for music production techniques.
|
||
Answer the user's question using ONLY the numbered sources below. Cite sources by
|
||
writing [N] inline (e.g. [1], [2]) where N is the source number. If the sources
|
||
do not contain enough information, say so honestly — do not invent facts.
|
||
|
||
Sources:
|
||
{context_block}
|
||
"""
|
||
|
||
_MAX_CONTEXT_SOURCES = 10
|
||
_MAX_TURN_PAIRS = 10
|
||
_HISTORY_TTL_SECONDS = 3600 # 1 hour
|
||
|
||
|
||
def _redis_key(conversation_id: str) -> str:
|
||
return f"chrysopedia:chat:{conversation_id}"
|
||
|
||
|
||
class ChatService:
    """Retrieve context from search, stream an LLM response with citations.

    Orchestrates three collaborators: a :class:`SearchService` for retrieval,
    an OpenAI-compatible async client for completion streaming, and an
    optional Redis client for multi-turn conversation memory. Redis and
    personality-lookup failures are deliberately best-effort: they are
    logged and the chat proceeds without that feature.
    """

    def __init__(self, settings: Settings, redis: Any = None) -> None:
        """Wire up the search and LLM clients.

        Args:
            settings: App settings supplying ``llm_api_url``, ``llm_api_key``
                and ``llm_model``.
            redis: Optional async Redis client (assumed to expose async
                ``get``/``set`` — TODO confirm client type). When None,
                history load/save become no-ops (single-turn behavior).
        """
        self.settings = settings
        self._search = SearchService(settings)
        self._openai = openai.AsyncOpenAI(
            base_url=settings.llm_api_url,
            api_key=settings.llm_api_key,
        )
        self._redis = redis

    async def _load_history(self, conversation_id: str) -> list[dict[str, str]]:
        """Load conversation history from Redis. Returns empty list on miss.

        Any Redis/JSON error is swallowed (logged at WARNING with traceback)
        so a broken cache never blocks answering the current question.
        """
        if not self._redis:
            return []
        try:
            raw = await self._redis.get(_redis_key(conversation_id))
            if raw:
                # Stored by _save_history as a JSON list of
                # {"role": ..., "content": ...} message dicts.
                return json.loads(raw)
        except Exception:
            logger.warning("chat_history_load_error cid=%s", conversation_id, exc_info=True)
        return []

    async def _save_history(
        self,
        conversation_id: str,
        history: list[dict[str, str]],
        user_msg: str,
        assistant_msg: str,
    ) -> None:
        """Append the new turn pair and persist to Redis with TTL refresh.

        Args:
            conversation_id: Key suffix identifying the conversation.
            history: Prior messages; the new user/assistant pair is appended
                to this list in place.
            user_msg: The user's message for this turn.
            assistant_msg: The full assistant response for this turn.

        Note: the cap below rebinds the local name, so only the persisted
        JSON is trimmed — the caller's list keeps the untrimmed appends.
        Save errors are logged and ignored (best-effort memory).
        """
        if not self._redis:
            return
        history.append({"role": "user", "content": user_msg})
        history.append({"role": "assistant", "content": assistant_msg})
        # Cap at _MAX_TURN_PAIRS (keep most recent)
        if len(history) > _MAX_TURN_PAIRS * 2:
            history = history[-_MAX_TURN_PAIRS * 2:]
        try:
            # ex= refreshes the TTL on every save, so the conversation
            # expires only after _HISTORY_TTL_SECONDS of inactivity.
            await self._redis.set(
                _redis_key(conversation_id),
                json.dumps(history),
                ex=_HISTORY_TTL_SECONDS,
            )
        except Exception:
            logger.warning("chat_history_save_error cid=%s", conversation_id, exc_info=True)

    async def _inject_personality(
        self,
        system_prompt: str,
        db: AsyncSession,
        creator_name: str,
        weight: float,
    ) -> str:
        """Query creator personality_profile and append a voice block to the system prompt.

        Args:
            system_prompt: The prompt built so far.
            db: Open async SQLAlchemy session used for the Creator lookup.
            creator_name: Exact Creator.name to match.
            weight: Personality strength 0.0-1.0, forwarded to
                _build_personality_block (below 0.2 it produces no text).

        Falls back to the unmodified prompt on DB error, missing creator, or null profile.
        """
        try:
            result = await db.execute(
                select(Creator).where(Creator.name == creator_name)
            )
            creator_row = result.scalars().first()
        except Exception:
            # DB problems must not kill the chat stream — skip personality.
            logger.warning("chat_personality_db_error creator=%r", creator_name, exc_info=True)
            return system_prompt

        if creator_row is None or creator_row.personality_profile is None:
            logger.debug("chat_personality_skip creator=%r reason=%s",
                         creator_name,
                         "not_found" if creator_row is None else "null_profile")
            return system_prompt

        # personality_profile is presumably a JSONB dict — see
        # _build_personality_block for the expected keys.
        profile = creator_row.personality_profile
        voice_block = _build_personality_block(creator_name, profile, weight)
        return system_prompt + "\n\n" + voice_block

    async def stream_response(
        self,
        query: str,
        db: AsyncSession,
        creator: str | None = None,
        conversation_id: str | None = None,
        personality_weight: float = 0.0,
    ) -> AsyncIterator[str]:
        """Yield SSE-formatted events for a chat query.

        Args:
            query: The user's question.
            db: Open async SQLAlchemy session (search + personality lookup).
            creator: Optional creator filter / personality voice source.
            conversation_id: Existing conversation to continue; a fresh
                UUID4 is assigned when None.
            personality_weight: 0.0-1.0; >0 (with a creator) injects a
                personality block and raises sampling temperature.

        Protocol:
            1. ``event: sources\ndata: <json array of citation metadata>\n\n``
            2. ``event: token\ndata: <text chunk>\n\n`` (repeated)
            3. ``event: done\ndata: <json with cascade_tier, conversation_id>\n\n``
        On error: ``event: error\ndata: <json with message>\n\n``
        (the generator returns immediately after an error event).
        """
        start = time.monotonic()

        # Assign conversation_id if not provided (single-turn becomes trackable)
        if conversation_id is None:
            conversation_id = str(uuid.uuid4())

        # ── 0. Load conversation history ────────────────────────────────
        history = await self._load_history(conversation_id)

        # ── 1. Retrieve context via search ──────────────────────────────
        try:
            search_result = await self._search.search(
                query=query,
                scope="all",
                limit=_MAX_CONTEXT_SOURCES,
                db=db,
                creator=creator,
            )
        except Exception:
            logger.exception("chat_search_error query=%r creator=%r", query, creator)
            yield _sse("error", {"message": "Search failed"})
            return

        items: list[dict[str, Any]] = search_result.get("items", [])
        cascade_tier: str = search_result.get("cascade_tier", "")

        # ── 2. Build citation metadata and context block ────────────────
        sources = _build_sources(items)
        context_block = _build_context_block(items)

        logger.info(
            "chat_search query=%r creator=%r cascade_tier=%s source_count=%d cid=%s",
            query, creator, cascade_tier, len(sources), conversation_id,
        )

        # Emit sources event first
        yield _sse("sources", sources)

        # ── 3. Stream LLM completion ────────────────────────────────────
        system_prompt = _SYSTEM_PROMPT_TEMPLATE.format(context_block=context_block)

        # Inject creator personality voice when weight > 0
        if personality_weight > 0 and creator:
            system_prompt = await self._inject_personality(
                system_prompt, db, creator, personality_weight,
            )

        # Scale temperature with personality weight: 0.3 (encyclopedic) → 0.5 (full personality)
        temperature = 0.3 + (personality_weight * 0.2)

        messages: list[dict[str, str]] = [
            {"role": "system", "content": system_prompt},
        ]
        # Inject conversation history between system prompt and current query
        messages.extend(history)
        messages.append({"role": "user", "content": query})

        # Accumulated for history persistence after the stream completes.
        accumulated_response = ""

        try:
            stream = await self._openai.chat.completions.create(
                model=self.settings.llm_model,
                messages=messages,
                stream=True,
                temperature=temperature,
                max_tokens=2048,
            )

            async for chunk in stream:
                # Some providers emit keep-alive chunks with empty choices
                # or empty deltas — guard before reading .content.
                choice = chunk.choices[0] if chunk.choices else None
                if choice and choice.delta and choice.delta.content:
                    text = choice.delta.content
                    accumulated_response += text
                    yield _sse("token", text)

        except Exception:
            tb = traceback.format_exc()
            logger.error("chat_llm_error query=%r cid=%s\n%s", query, conversation_id, tb)
            yield _sse("error", {"message": "LLM generation failed"})
            # Note: a turn that errored mid-stream is not saved to history.
            return

        # ── 4. Save conversation history ────────────────────────────────
        await self._save_history(conversation_id, history, query, accumulated_response)

        # ── 5. Done event ───────────────────────────────────────────────
        latency_ms = (time.monotonic() - start) * 1000
        logger.info(
            "chat_done query=%r creator=%r cascade_tier=%s source_count=%d latency_ms=%.1f cid=%s",
            query, creator, cascade_tier, len(sources), latency_ms, conversation_id,
        )
        yield _sse("done", {"cascade_tier": cascade_tier, "conversation_id": conversation_id})
|
||
|
||
|
||
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||
|
||
|
||
def _sse(event: str, data: Any) -> str:
|
||
"""Format a single SSE event string."""
|
||
payload = json.dumps(data) if not isinstance(data, str) else data
|
||
return f"event: {event}\ndata: {payload}\n\n"
|
||
|
||
|
||
def _build_sources(items: list[dict[str, Any]]) -> list[dict[str, str]]:
|
||
"""Build a numbered citation metadata list from search result items."""
|
||
sources: list[dict[str, str]] = []
|
||
for idx, item in enumerate(items, start=1):
|
||
sources.append({
|
||
"number": idx,
|
||
"title": item.get("title", ""),
|
||
"slug": item.get("technique_page_slug", "") or item.get("slug", ""),
|
||
"creator_name": item.get("creator_name", ""),
|
||
"topic_category": item.get("topic_category", ""),
|
||
"summary": (item.get("summary", "") or "")[:200],
|
||
"section_anchor": item.get("section_anchor", ""),
|
||
"section_heading": item.get("section_heading", ""),
|
||
"source_video_id": item.get("source_video_id", ""),
|
||
"start_time": item.get("start_time"),
|
||
"end_time": item.get("end_time"),
|
||
"video_filename": item.get("video_filename", ""),
|
||
})
|
||
return sources
|
||
|
||
|
||
def _build_context_block(items: list[dict[str, Any]]) -> str:
|
||
"""Build a numbered context block string for the LLM system prompt."""
|
||
if not items:
|
||
return "(No sources available)"
|
||
|
||
lines: list[str] = []
|
||
for idx, item in enumerate(items, start=1):
|
||
title = item.get("title", "Untitled")
|
||
creator = item.get("creator_name", "")
|
||
summary = item.get("summary", "")
|
||
section = item.get("section_heading", "")
|
||
|
||
parts = [f"[{idx}] {title}"]
|
||
if creator:
|
||
parts.append(f"by {creator}")
|
||
if section:
|
||
parts.append(f"— {section}")
|
||
header = " ".join(parts)
|
||
|
||
lines.append(header)
|
||
if summary:
|
||
lines.append(f" {summary}")
|
||
lines.append("")
|
||
|
||
return "\n".join(lines)
|
||
|
||
|
||
def _build_personality_block(creator_name: str, profile: dict[str, Any], weight: float) -> str:
|
||
"""Build a personality voice injection block from a creator's personality_profile JSONB.
|
||
|
||
The ``weight`` (0.0–1.0) controls progressive inclusion of personality
|
||
fields via 5 tiers of continuous interpolation:
|
||
|
||
- < 0.2: no personality block (empty string)
|
||
- 0.2–0.39: basic tone — teaching_style, formality, energy + subtle hint
|
||
- 0.4–0.59: + descriptors, explanation_approach + adopt-voice instruction
|
||
- 0.6–0.79: + signature_phrases (count scaled by weight) + creator-voice
|
||
- 0.8–0.89: + distinctive_terms, sound_descriptions, sound_words,
|
||
self_references, pacing + fully-embody instruction
|
||
- >= 0.9: + full summary paragraph
|
||
"""
|
||
if weight < 0.2:
|
||
return ""
|
||
|
||
vocab = profile.get("vocabulary", {})
|
||
tone = profile.get("tone", {})
|
||
style = profile.get("style_markers", {})
|
||
|
||
teaching_style = tone.get("teaching_style", "")
|
||
energy = tone.get("energy", "moderate")
|
||
formality = tone.get("formality", "conversational")
|
||
descriptors = tone.get("descriptors", [])
|
||
phrases = vocab.get("signature_phrases", [])
|
||
|
||
parts: list[str] = []
|
||
|
||
# --- Tier 1 (0.2–0.39): basic tone ---
|
||
if weight < 0.4:
|
||
parts.append(
|
||
f"When relevant, subtly reference {creator_name}'s communication style."
|
||
)
|
||
elif weight < 0.6:
|
||
parts.append(f"Adopt {creator_name}'s tone and communication style.")
|
||
elif weight < 0.8:
|
||
parts.append(
|
||
f"Respond as {creator_name} would, using their voice and manner."
|
||
)
|
||
else:
|
||
parts.append(
|
||
f"Fully embody {creator_name} — use their exact phrases, energy, and teaching approach."
|
||
)
|
||
|
||
if teaching_style:
|
||
parts.append(f"Teaching style: {teaching_style}.")
|
||
parts.append(f"Match their {formality} {energy} tone.")
|
||
|
||
# --- Tier 2 (0.4+): descriptors, explanation_approach, uses_analogies, audience_engagement ---
|
||
if weight >= 0.4:
|
||
if descriptors:
|
||
parts.append(f"Tone: {', '.join(descriptors[:5])}.")
|
||
explanation = style.get("explanation_approach", "")
|
||
if explanation:
|
||
parts.append(f"Explanation approach: {explanation}.")
|
||
if style.get("uses_analogies"):
|
||
parts.append("Use analogies when helpful.")
|
||
if style.get("audience_engagement"):
|
||
parts.append(f"Audience engagement: {style['audience_engagement']}.")
|
||
|
||
# --- Tier 3 (0.6+): signature phrases (count scaled by weight) ---
|
||
if weight >= 0.6 and phrases:
|
||
count = max(2, round(weight * len(phrases)))
|
||
parts.append(f"Use their signature phrases: {', '.join(phrases[:count])}.")
|
||
|
||
# --- Tier 4 (0.8+): distinctive_terms, sound_descriptions, sound_words, self_references, pacing ---
|
||
if weight >= 0.8:
|
||
distinctive = vocab.get("distinctive_terms", [])
|
||
if distinctive:
|
||
parts.append(f"Distinctive terms: {', '.join(distinctive)}.")
|
||
sound_desc = vocab.get("sound_descriptions", [])
|
||
if sound_desc:
|
||
parts.append(f"Sound descriptions: {', '.join(sound_desc)}.")
|
||
sound_words = style.get("sound_words", [])
|
||
if sound_words:
|
||
parts.append(f"Sound words: {', '.join(sound_words)}.")
|
||
self_refs = style.get("self_references", "")
|
||
if self_refs:
|
||
parts.append(f"Self-references: {self_refs}.")
|
||
pacing = style.get("pacing", "")
|
||
if pacing:
|
||
parts.append(f"Pacing: {pacing}.")
|
||
|
||
# --- Tier 5 (0.9+): full summary paragraph ---
|
||
if weight >= 0.9:
|
||
summary = profile.get("summary", "")
|
||
if summary:
|
||
parts.append(summary)
|
||
|
||
return " ".join(parts)
|