# backend/chat_service.py — GSD-Task: S04/T01
# Related files: backend/routers/chat.py, backend/tests/test_chat.py
"""Chat service: retrieve context via search, stream LLM response as SSE events.

Assembles a numbered context block from search results, then streams
completion tokens from an OpenAI-compatible API. Yields SSE-formatted
events: sources, token, done, and error.

Multi-turn memory: When a conversation_id is provided, prior messages are
loaded from Redis, injected into the LLM messages array, and the new
user+assistant turn is appended after streaming completes. History is
capped at 10 turn pairs (20 messages) and expires after 1 hour of
inactivity.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import time
|
|
import traceback
|
|
import uuid
|
|
from typing import Any, AsyncIterator
|
|
|
|
import openai
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from config import Settings
|
|
from search_service import SearchService
|
|
|
|
logger = logging.getLogger("chrysopedia.chat")

# System prompt scaffold. {context_block} is filled with the numbered source
# list produced by _build_context_block(); the [N] citation convention here
# must match the "number" field emitted by _build_sources().
_SYSTEM_PROMPT_TEMPLATE = """\
You are Chrysopedia, an expert encyclopedic assistant for music production techniques.
Answer the user's question using ONLY the numbered sources below. Cite sources by
writing [N] inline (e.g. [1], [2]) where N is the source number. If the sources
do not contain enough information, say so honestly — do not invent facts.

Sources:
{context_block}
"""

# Maximum search hits retrieved and injected into the prompt as sources.
_MAX_CONTEXT_SOURCES = 10
# History cap: 10 user/assistant turn pairs, i.e. 20 stored messages.
_MAX_TURN_PAIRS = 10
_HISTORY_TTL_SECONDS = 3600  # 1 hour
|
|
|
|
|
|
def _redis_key(conversation_id: str) -> str:
|
|
return f"chrysopedia:chat:{conversation_id}"
|
|
|
|
|
|
class ChatService:
    """Retrieve context from search, stream an LLM response with citations.

    Holds a SearchService, an async OpenAI-compatible client, and an
    optional Redis client used only for multi-turn conversation history.
    """

    def __init__(self, settings: Settings, redis: Any = None) -> None:
        # redis: optional async client with awaitable get/set. When absent,
        # history load/save silently become no-ops (single-turn behavior).
        # NOTE(review): presumably redis.asyncio.Redis — confirm at wiring.
        self.settings = settings
        self._search = SearchService(settings)
        self._openai = openai.AsyncOpenAI(
            base_url=settings.llm_api_url,
            api_key=settings.llm_api_key,
        )
        self._redis = redis

    async def _load_history(self, conversation_id: str) -> list[dict[str, str]]:
        """Load conversation history from Redis. Returns empty list on miss."""
        if not self._redis:
            return []
        try:
            raw = await self._redis.get(_redis_key(conversation_id))
            if raw:
                # Stored as a JSON array of {"role": ..., "content": ...} dicts.
                return json.loads(raw)
        except Exception:
            # Best-effort: a Redis/JSON failure degrades to a fresh
            # conversation rather than failing the chat request.
            logger.warning("chat_history_load_error cid=%s", conversation_id, exc_info=True)
        return []

    async def _save_history(
        self,
        conversation_id: str,
        history: list[dict[str, str]],
        user_msg: str,
        assistant_msg: str,
    ) -> None:
        """Append the new turn pair and persist to Redis with TTL refresh."""
        if not self._redis:
            return
        history.append({"role": "user", "content": user_msg})
        history.append({"role": "assistant", "content": assistant_msg})
        # Cap at _MAX_TURN_PAIRS (keep most recent)
        if len(history) > _MAX_TURN_PAIRS * 2:
            # Rebinds the local name only; the caller's list keeps the
            # appended turns, but the persisted copy below is trimmed.
            history = history[-_MAX_TURN_PAIRS * 2:]
        try:
            await self._redis.set(
                _redis_key(conversation_id),
                json.dumps(history),
                ex=_HISTORY_TTL_SECONDS,  # sliding 1h expiry, refreshed each save
            )
        except Exception:
            # Best-effort: losing history must not break the response stream.
            logger.warning("chat_history_save_error cid=%s", conversation_id, exc_info=True)

    async def stream_response(
        self,
        query: str,
        db: AsyncSession,
        creator: str | None = None,
        conversation_id: str | None = None,
    ) -> AsyncIterator[str]:
        """Yield SSE-formatted events for a chat query.

        Protocol:
        1. ``event: sources\ndata: <json array of citation metadata>\n\n``
        2. ``event: token\ndata: <text chunk>\n\n`` (repeated)
        3. ``event: done\ndata: <json with cascade_tier, conversation_id>\n\n``
        On error: ``event: error\ndata: <json with message>\n\n``
        """
        start = time.monotonic()

        # Assign conversation_id if not provided (single-turn becomes trackable)
        if conversation_id is None:
            conversation_id = str(uuid.uuid4())

        # ── 0. Load conversation history ────────────────────────────────
        history = await self._load_history(conversation_id)

        # ── 1. Retrieve context via search ──────────────────────────────
        try:
            search_result = await self._search.search(
                query=query,
                scope="all",
                limit=_MAX_CONTEXT_SOURCES,
                db=db,
                creator=creator,
            )
        except Exception:
            # Search failure aborts before any event besides error is sent.
            logger.exception("chat_search_error query=%r creator=%r", query, creator)
            yield _sse("error", {"message": "Search failed"})
            return

        items: list[dict[str, Any]] = search_result.get("items", [])
        cascade_tier: str = search_result.get("cascade_tier", "")

        # ── 2. Build citation metadata and context block ────────────────
        sources = _build_sources(items)
        context_block = _build_context_block(items)

        logger.info(
            "chat_search query=%r creator=%r cascade_tier=%s source_count=%d cid=%s",
            query, creator, cascade_tier, len(sources), conversation_id,
        )

        # Emit sources event first
        yield _sse("sources", sources)

        # ── 3. Stream LLM completion ────────────────────────────────────
        system_prompt = _SYSTEM_PROMPT_TEMPLATE.format(context_block=context_block)

        messages: list[dict[str, str]] = [
            {"role": "system", "content": system_prompt},
        ]
        # Inject conversation history between system prompt and current query
        messages.extend(history)
        messages.append({"role": "user", "content": query})

        accumulated_response = ""

        try:
            stream = await self._openai.chat.completions.create(
                model=self.settings.llm_model,
                messages=messages,
                stream=True,
                temperature=0.3,  # low: factual, citation-grounded answers
                max_tokens=2048,
            )

            async for chunk in stream:
                # Guard: some providers send keep-alive chunks with no choices
                # or empty deltas; only forward real content.
                choice = chunk.choices[0] if chunk.choices else None
                if choice and choice.delta and choice.delta.content:
                    text = choice.delta.content
                    accumulated_response += text
                    yield _sse("token", text)

        except Exception:
            tb = traceback.format_exc()
            logger.error("chat_llm_error query=%r cid=%s\n%s", query, conversation_id, tb)
            yield _sse("error", {"message": "LLM generation failed"})
            return

        # ── 4. Save conversation history ────────────────────────────────
        # Only reached on full success: a failed/partial stream is not saved.
        await self._save_history(conversation_id, history, query, accumulated_response)

        # ── 5. Done event ───────────────────────────────────────────────
        latency_ms = (time.monotonic() - start) * 1000
        logger.info(
            "chat_done query=%r creator=%r cascade_tier=%s source_count=%d latency_ms=%.1f cid=%s",
            query, creator, cascade_tier, len(sources), latency_ms, conversation_id,
        )
        yield _sse("done", {"cascade_tier": cascade_tier, "conversation_id": conversation_id})
|
|
|
|
|
|
# ── Helpers ──────────────────────────────────────────────────────────────────
|
|
|
|
|
|
def _sse(event: str, data: Any) -> str:
|
|
"""Format a single SSE event string."""
|
|
payload = json.dumps(data) if not isinstance(data, str) else data
|
|
return f"event: {event}\ndata: {payload}\n\n"
|
|
|
|
|
|
def _build_sources(items: list[dict[str, Any]]) -> list[dict[str, str]]:
|
|
"""Build a numbered citation metadata list from search result items."""
|
|
sources: list[dict[str, str]] = []
|
|
for idx, item in enumerate(items, start=1):
|
|
sources.append({
|
|
"number": idx,
|
|
"title": item.get("title", ""),
|
|
"slug": item.get("technique_page_slug", "") or item.get("slug", ""),
|
|
"creator_name": item.get("creator_name", ""),
|
|
"topic_category": item.get("topic_category", ""),
|
|
"summary": (item.get("summary", "") or "")[:200],
|
|
"section_anchor": item.get("section_anchor", ""),
|
|
"section_heading": item.get("section_heading", ""),
|
|
})
|
|
return sources
|
|
|
|
|
|
def _build_context_block(items: list[dict[str, Any]]) -> str:
|
|
"""Build a numbered context block string for the LLM system prompt."""
|
|
if not items:
|
|
return "(No sources available)"
|
|
|
|
lines: list[str] = []
|
|
for idx, item in enumerate(items, start=1):
|
|
title = item.get("title", "Untitled")
|
|
creator = item.get("creator_name", "")
|
|
summary = item.get("summary", "")
|
|
section = item.get("section_heading", "")
|
|
|
|
parts = [f"[{idx}] {title}"]
|
|
if creator:
|
|
parts.append(f"by {creator}")
|
|
if section:
|
|
parts.append(f"— {section}")
|
|
header = " ".join(parts)
|
|
|
|
lines.append(header)
|
|
if summary:
|
|
lines.append(f" {summary}")
|
|
lines.append("")
|
|
|
|
return "\n".join(lines)
|