chrysopedia/backend/chat_service.py
jlightner d1efdbb3fa feat: Added personality_weight (0.0–1.0) to chat API; modulates system…
- "backend/routers/chat.py"
- "backend/chat_service.py"
- "backend/tests/test_chat.py"

GSD-Task: S02/T01
2026-04-04 09:28:35 +00:00

333 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Chat service: retrieve context via search, stream LLM response as SSE events.
Assembles a numbered context block from search results, then streams
completion tokens from an OpenAI-compatible API. Yields SSE-formatted
events: sources, token, done, and error.
Multi-turn memory: When a conversation_id is provided, prior messages are
loaded from Redis, injected into the LLM messages array, and the new
user+assistant turn is appended after streaming completes. History is
capped at 10 turn pairs (20 messages) and expires after 1 hour of
inactivity.
"""
from __future__ import annotations
import json
import logging
import time
import traceback
import uuid
from typing import Any, AsyncIterator
import openai
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from config import Settings
from models import Creator
from search_service import SearchService
# Module-level logger for the chat service.
logger = logging.getLogger("chrysopedia.chat")
# System prompt sent as the first LLM message of every chat request.
# ``{context_block}`` is filled in via str.format() with the numbered source
# list, so any literal brace added to this text must be doubled ("{{" / "}}").
# The wording restricts the model to the listed sources and asks for inline
# [N] citations that match the numbering of the emitted ``sources`` event.
_SYSTEM_PROMPT_TEMPLATE = """\
You are Chrysopedia, an expert encyclopedic assistant for music production techniques.
Answer the user's question using ONLY the numbered sources below. Cite sources by
writing [N] inline (e.g. [1], [2]) where N is the source number. If the sources
do not contain enough information, say so honestly — do not invent facts.
Sources:
{context_block}
"""
# Search result limit per query, and therefore the maximum number of sources.
_MAX_CONTEXT_SOURCES = 10
# History cap in user+assistant turn pairs (two stored messages per pair).
_MAX_TURN_PAIRS = 10
# Redis expiry for stored history; reset every time a new turn is saved.
_HISTORY_TTL_SECONDS = 3600 # 1 hour
def _redis_key(conversation_id: str) -> str:
return f"chrysopedia:chat:{conversation_id}"
class ChatService:
    """Retrieve context from search, stream an LLM response with citations.

    Holds a ``SearchService``, an ``openai.AsyncOpenAI`` client configured from
    ``settings.llm_api_url`` / ``settings.llm_api_key``, and an optional Redis
    client for multi-turn history. When ``redis`` is ``None`` history is never
    loaded or saved, so every request behaves as single-turn.
    """

    def __init__(self, settings: Settings, redis=None) -> None:
        self.settings = settings
        self._search = SearchService(settings)
        self._openai = openai.AsyncOpenAI(
            base_url=settings.llm_api_url,
            api_key=settings.llm_api_key,
        )
        # Optional async Redis client; only ``get`` and ``set(..., ex=...)`` are
        # used. ``None`` disables conversation history.
        self._redis = redis

    async def _load_history(self, conversation_id: str) -> list[dict[str, str]]:
        """Load conversation history from Redis.

        Returns an empty list when Redis is not configured, the key is missing,
        reading/decoding fails (logged), or the stored value is not a JSON
        array. A non-list value must not be spliced into the LLM ``messages``.
        """
        if not self._redis:
            return []
        try:
            raw = await self._redis.get(_redis_key(conversation_id))
            if not raw:
                return []
            history = json.loads(raw)
        except Exception:
            logger.warning("chat_history_load_error cid=%s", conversation_id, exc_info=True)
            return []
        if not isinstance(history, list):
            logger.warning(
                "chat_history_invalid cid=%s type=%s",
                conversation_id, type(history).__name__,
            )
            return []
        return history

    async def _save_history(
        self,
        conversation_id: str,
        history: list[dict[str, str]],
        user_msg: str,
        assistant_msg: str,
    ) -> None:
        """Append the new turn pair and persist to Redis with TTL refresh.

        Builds a new list instead of mutating ``history`` in place, keeps only
        the most recent ``_MAX_TURN_PAIRS`` pairs, and resets the key expiry to
        ``_HISTORY_TTL_SECONDS``. Write failures are logged, not raised: losing
        history must not fail a response that has already been streamed.
        """
        if not self._redis:
            return
        updated = [
            *history,
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": assistant_msg},
        ]
        # Cap at _MAX_TURN_PAIRS (keep most recent); slicing is a no-op when shorter.
        updated = updated[-_MAX_TURN_PAIRS * 2:]
        try:
            await self._redis.set(
                _redis_key(conversation_id),
                json.dumps(updated),
                ex=_HISTORY_TTL_SECONDS,
            )
        except Exception:
            logger.warning("chat_history_save_error cid=%s", conversation_id, exc_info=True)

    async def _inject_personality(
        self,
        system_prompt: str,
        db: AsyncSession,
        creator_name: str,
        weight: float,
    ) -> str:
        """Query creator personality_profile and append a voice block to the system prompt.

        The creator row is selected with ``Creator.name == creator_name``; the
        block from ``_build_personality_block`` is appended after a blank line.
        Falls back to the unmodified prompt on DB error, missing creator, null
        profile, or a profile that cannot be rendered. Personality is
        best-effort: by the time this runs the ``sources`` event has already
        been emitted, so an exception here would abort the stream without an
        ``error`` event.
        """
        try:
            result = await db.execute(
                select(Creator).where(Creator.name == creator_name)
            )
            creator_row = result.scalars().first()
        except Exception:
            logger.warning("chat_personality_db_error creator=%r", creator_name, exc_info=True)
            return system_prompt

        if creator_row is None or creator_row.personality_profile is None:
            logger.debug("chat_personality_skip creator=%r reason=%s",
                         creator_name,
                         "not_found" if creator_row is None else "null_profile")
            return system_prompt

        try:
            voice_block = _build_personality_block(
                creator_name, creator_row.personality_profile, weight,
            )
        except Exception:
            logger.warning("chat_personality_build_error creator=%r", creator_name, exc_info=True)
            return system_prompt
        return system_prompt + "\n\n" + voice_block

    async def stream_response(
        self,
        query: str,
        db: AsyncSession,
        creator: str | None = None,
        conversation_id: str | None = None,
        personality_weight: float = 0.0,
    ) -> AsyncIterator[str]:
        """Yield SSE-formatted events for a chat query.

        Protocol:
        1. ``event: sources\ndata: <json array of citation metadata>\n\n``
        2. ``event: token\ndata: <text chunk>\n\n`` (repeated)
        3. ``event: done\ndata: <json with cascade_tier, conversation_id>\n\n``
        On error: ``event: error\ndata: <json with message>\n\n``

        ``personality_weight`` is clamped to 0.0–1.0. It selects the voice
        intensity when ``creator`` is given and scales the sampling
        temperature from 0.3 to 0.5.
        """
        start = time.monotonic()

        # Assign conversation_id if not provided (single-turn becomes trackable)
        if conversation_id is None:
            conversation_id = str(uuid.uuid4())

        # The API contract is 0.0–1.0; clamp so an out-of-range value cannot
        # push the temperature outside the 0.3–0.5 band computed below.
        weight = min(max(personality_weight, 0.0), 1.0)

        # ── 0. Load conversation history ────────────────────────────────
        history = await self._load_history(conversation_id)

        # ── 1. Retrieve context via search ──────────────────────────────
        try:
            search_result = await self._search.search(
                query=query,
                scope="all",
                limit=_MAX_CONTEXT_SOURCES,
                db=db,
                creator=creator,
            )
        except Exception:
            logger.exception("chat_search_error query=%r creator=%r", query, creator)
            yield _sse("error", {"message": "Search failed"})
            return

        # ``or`` also guards against keys that are present but null.
        items: list[dict[str, Any]] = search_result.get("items") or []
        cascade_tier: str = search_result.get("cascade_tier") or ""

        # ── 2. Build citation metadata and context block ────────────────
        sources = _build_sources(items)
        context_block = _build_context_block(items)
        logger.info(
            "chat_search query=%r creator=%r cascade_tier=%s source_count=%d cid=%s",
            query, creator, cascade_tier, len(sources), conversation_id,
        )
        # Emit sources event first so citations can be rendered before tokens arrive.
        yield _sse("sources", sources)

        # ── 3. Stream LLM completion ────────────────────────────────────
        system_prompt = _SYSTEM_PROMPT_TEMPLATE.format(context_block=context_block)
        # Inject creator personality voice when weight > 0
        if weight > 0 and creator:
            system_prompt = await self._inject_personality(
                system_prompt, db, creator, weight,
            )
        # Scale temperature with personality weight: 0.3 (encyclopedic) → 0.5 (full personality)
        temperature = 0.3 + (weight * 0.2)

        # System prompt, then prior turns, then the current query.
        messages: list[dict[str, str]] = [
            {"role": "system", "content": system_prompt},
            *history,
            {"role": "user", "content": query},
        ]

        chunks: list[str] = []
        try:
            stream = await self._openai.chat.completions.create(
                model=self.settings.llm_model,
                messages=messages,
                stream=True,
                temperature=temperature,
                max_tokens=2048,
            )
            async for chunk in stream:
                # Some chunks (e.g. role-only or final usage chunks) carry no text.
                choice = chunk.choices[0] if chunk.choices else None
                if choice and choice.delta and choice.delta.content:
                    text = choice.delta.content
                    chunks.append(text)
                    yield _sse("token", text)
        except Exception:
            tb = traceback.format_exc()
            logger.error("chat_llm_error query=%r cid=%s\n%s", query, conversation_id, tb)
            yield _sse("error", {"message": "LLM generation failed"})
            return

        # ── 4. Save conversation history ────────────────────────────────
        await self._save_history(conversation_id, history, query, "".join(chunks))

        # ── 5. Done event ───────────────────────────────────────────────
        latency_ms = (time.monotonic() - start) * 1000
        logger.info(
            "chat_done query=%r creator=%r cascade_tier=%s source_count=%d latency_ms=%.1f cid=%s",
            query, creator, cascade_tier, len(sources), latency_ms, conversation_id,
        )
        yield _sse("done", {"cascade_tier": cascade_tier, "conversation_id": conversation_id})
# ── Helpers ──────────────────────────────────────────────────────────────────
def _sse(event: str, data: Any) -> str:
"""Format a single SSE event string."""
payload = json.dumps(data) if not isinstance(data, str) else data
return f"event: {event}\ndata: {payload}\n\n"
def _build_sources(items: list[dict[str, Any]]) -> list[dict[str, str]]:
"""Build a numbered citation metadata list from search result items."""
sources: list[dict[str, str]] = []
for idx, item in enumerate(items, start=1):
sources.append({
"number": idx,
"title": item.get("title", ""),
"slug": item.get("technique_page_slug", "") or item.get("slug", ""),
"creator_name": item.get("creator_name", ""),
"topic_category": item.get("topic_category", ""),
"summary": (item.get("summary", "") or "")[:200],
"section_anchor": item.get("section_anchor", ""),
"section_heading": item.get("section_heading", ""),
})
return sources
def _build_context_block(items: list[dict[str, Any]]) -> str:
"""Build a numbered context block string for the LLM system prompt."""
if not items:
return "(No sources available)"
lines: list[str] = []
for idx, item in enumerate(items, start=1):
title = item.get("title", "Untitled")
creator = item.get("creator_name", "")
summary = item.get("summary", "")
section = item.get("section_heading", "")
parts = [f"[{idx}] {title}"]
if creator:
parts.append(f"by {creator}")
if section:
parts.append(f"{section}")
header = " ".join(parts)
lines.append(header)
if summary:
lines.append(f" {summary}")
lines.append("")
return "\n".join(lines)
def _build_personality_block(creator_name: str, profile: dict[str, Any], weight: float) -> str:
"""Build a personality voice injection block from a creator's personality_profile JSONB.
The ``weight`` (0.01.0) determines how strongly the personality should
come through. At low weights the instruction is softer ("subtly adopt");
at high weights it is emphatic ("fully embody").
"""
vocab = profile.get("vocabulary", {})
tone = profile.get("tone", {})
style = profile.get("style_markers", {})
phrases = vocab.get("signature_phrases", [])
descriptors = tone.get("descriptors", [])
teaching_style = tone.get("teaching_style", "")
energy = tone.get("energy", "moderate")
formality = tone.get("formality", "conversational")
parts: list[str] = []
# Intensity qualifier
if weight >= 0.8:
parts.append(f"Fully embody {creator_name}'s voice and style.")
elif weight >= 0.4:
parts.append(f"Respond in {creator_name}'s voice.")
else:
parts.append(f"Subtly adopt {creator_name}'s communication style.")
if teaching_style:
parts.append(f"Teaching style: {teaching_style}.")
if descriptors:
parts.append(f"Tone: {', '.join(descriptors[:5])}.")
if phrases:
parts.append(f"Use their signature phrases: {', '.join(phrases[:6])}.")
parts.append(f"Match their {formality} {energy} tone.")
# Style markers
if style.get("uses_analogies"):
parts.append("Use analogies when helpful.")
if style.get("audience_engagement"):
parts.append(f"Audience engagement: {style['audience_engagement']}.")
return " ".join(parts)