test: Added multi-turn conversation memory with Redis-backed history (1…

- "backend/chat_service.py"
- "backend/routers/chat.py"
- "backend/tests/test_chat.py"

GSD-Task: S04/T01
This commit is contained in:
jlightner 2026-04-04 07:50:30 +00:00
parent 0098254fdd
commit 29e60bbc99
3 changed files with 365 additions and 26 deletions

View file

@ -3,6 +3,12 @@
Assembles a numbered context block from search results, then streams Assembles a numbered context block from search results, then streams
completion tokens from an OpenAI-compatible API. Yields SSE-formatted completion tokens from an OpenAI-compatible API. Yields SSE-formatted
events: sources, token, done, and error. events: sources, token, done, and error.
Multi-turn memory: When a conversation_id is provided, prior messages are
loaded from Redis, injected into the LLM messages array, and the new
user+assistant turn is appended after streaming completes. History is
capped at 10 turn pairs (20 messages) and expires after 1 hour of
inactivity.
""" """
from __future__ import annotations from __future__ import annotations
@ -11,6 +17,7 @@ import json
import logging import logging
import time import time
import traceback import traceback
import uuid
from typing import Any, AsyncIterator from typing import Any, AsyncIterator
import openai import openai
@ -32,35 +39,86 @@ Sources:
""" """
_MAX_CONTEXT_SOURCES = 10 _MAX_CONTEXT_SOURCES = 10
_MAX_TURN_PAIRS = 10
_HISTORY_TTL_SECONDS = 3600 # 1 hour
def _redis_key(conversation_id: str) -> str:
return f"chrysopedia:chat:{conversation_id}"
class ChatService: class ChatService:
"""Retrieve context from search, stream an LLM response with citations.""" """Retrieve context from search, stream an LLM response with citations."""
def __init__(self, settings: Settings) -> None: def __init__(self, settings: Settings, redis=None) -> None:
self.settings = settings self.settings = settings
self._search = SearchService(settings) self._search = SearchService(settings)
self._openai = openai.AsyncOpenAI( self._openai = openai.AsyncOpenAI(
base_url=settings.llm_api_url, base_url=settings.llm_api_url,
api_key=settings.llm_api_key, api_key=settings.llm_api_key,
) )
self._redis = redis
async def _load_history(self, conversation_id: str) -> list[dict[str, str]]:
    """Load conversation history from Redis.

    Returns an empty list when Redis is not configured, the key is absent,
    the read or JSON decode fails, or the stored value is not a list of
    chat messages. Failures are logged and swallowed: history is
    best-effort and must not break the chat request itself.
    """
    if not self._redis:
        return []
    try:
        raw = await self._redis.get(_redis_key(conversation_id))
        if not raw:
            return []
        history = json.loads(raw)
    except Exception:
        logger.warning("chat_history_load_error cid=%s", conversation_id, exc_info=True)
        return []

    # The result is spliced directly into the LLM messages array, so reject
    # anything that is not a list of mappings with string role/content.
    well_formed = isinstance(history, list) and all(
        isinstance(message, dict)
        and isinstance(message.get("role"), str)
        and isinstance(message.get("content"), str)
        for message in history
    )
    if not well_formed:
        logger.warning("chat_history_malformed cid=%s", conversation_id)
        return []
    return history
async def _save_history(
    self,
    conversation_id: str,
    history: list[dict[str, str]],
    user_msg: str,
    assistant_msg: str,
) -> None:
    """Persist ``history`` plus the new turn pair to Redis, refreshing the TTL.

    The caller's ``history`` list is left untouched: a new list is built,
    trimmed to the most recent ``_MAX_TURN_PAIRS`` pairs, and written with
    an expiry of ``_HISTORY_TTL_SECONDS``. Does nothing when Redis is not
    configured. Serialization and write failures are logged and swallowed.
    """
    if not self._redis:
        return

    updated = [
        *history,
        {"role": "user", "content": user_msg},
        {"role": "assistant", "content": assistant_msg},
    ]
    # Keep only the most recent turns so the stored history stays bounded.
    max_messages = _MAX_TURN_PAIRS * 2
    if len(updated) > max_messages:
        updated = updated[-max_messages:]

    try:
        await self._redis.set(
            _redis_key(conversation_id),
            json.dumps(updated),
            ex=_HISTORY_TTL_SECONDS,
        )
    except Exception:
        logger.warning("chat_history_save_error cid=%s", conversation_id, exc_info=True)
async def stream_response( async def stream_response(
self, self,
query: str, query: str,
db: AsyncSession, db: AsyncSession,
creator: str | None = None, creator: str | None = None,
conversation_id: str | None = None,
) -> AsyncIterator[str]: ) -> AsyncIterator[str]:
"""Yield SSE-formatted events for a chat query. """Yield SSE-formatted events for a chat query.
Protocol: Protocol:
1. ``event: sources\ndata: <json array of citation metadata>\n\n`` 1. ``event: sources\ndata: <json array of citation metadata>\n\n``
2. ``event: token\ndata: <text chunk>\n\n`` (repeated) 2. ``event: token\ndata: <text chunk>\n\n`` (repeated)
3. ``event: done\ndata: <json with cascade_tier>\n\n`` 3. ``event: done\ndata: <json with cascade_tier, conversation_id>\n\n``
On error: ``event: error\ndata: <json with message>\n\n`` On error: ``event: error\ndata: <json with message>\n\n``
""" """
start = time.monotonic() start = time.monotonic()
# Assign conversation_id if not provided (single-turn becomes trackable)
if conversation_id is None:
conversation_id = str(uuid.uuid4())
# ── 0. Load conversation history ────────────────────────────────
history = await self._load_history(conversation_id)
# ── 1. Retrieve context via search ────────────────────────────── # ── 1. Retrieve context via search ──────────────────────────────
try: try:
search_result = await self._search.search( search_result = await self._search.search(
@ -83,8 +141,8 @@ class ChatService:
context_block = _build_context_block(items) context_block = _build_context_block(items)
logger.info( logger.info(
"chat_search query=%r creator=%r cascade_tier=%s source_count=%d", "chat_search query=%r creator=%r cascade_tier=%s source_count=%d cid=%s",
query, creator, cascade_tier, len(sources), query, creator, cascade_tier, len(sources), conversation_id,
) )
# Emit sources event first # Emit sources event first
@ -93,13 +151,19 @@ class ChatService:
# ── 3. Stream LLM completion ──────────────────────────────────── # ── 3. Stream LLM completion ────────────────────────────────────
system_prompt = _SYSTEM_PROMPT_TEMPLATE.format(context_block=context_block) system_prompt = _SYSTEM_PROMPT_TEMPLATE.format(context_block=context_block)
messages: list[dict[str, str]] = [
{"role": "system", "content": system_prompt},
]
# Inject conversation history between system prompt and current query
messages.extend(history)
messages.append({"role": "user", "content": query})
accumulated_response = ""
try: try:
stream = await self._openai.chat.completions.create( stream = await self._openai.chat.completions.create(
model=self.settings.llm_model, model=self.settings.llm_model,
messages=[ messages=messages,
{"role": "system", "content": system_prompt},
{"role": "user", "content": query},
],
stream=True, stream=True,
temperature=0.3, temperature=0.3,
max_tokens=2048, max_tokens=2048,
@ -108,21 +172,26 @@ class ChatService:
async for chunk in stream: async for chunk in stream:
choice = chunk.choices[0] if chunk.choices else None choice = chunk.choices[0] if chunk.choices else None
if choice and choice.delta and choice.delta.content: if choice and choice.delta and choice.delta.content:
yield _sse("token", choice.delta.content) text = choice.delta.content
accumulated_response += text
yield _sse("token", text)
except Exception: except Exception:
tb = traceback.format_exc() tb = traceback.format_exc()
logger.error("chat_llm_error query=%r\n%s", query, tb) logger.error("chat_llm_error query=%r cid=%s\n%s", query, conversation_id, tb)
yield _sse("error", {"message": "LLM generation failed"}) yield _sse("error", {"message": "LLM generation failed"})
return return
# ── 4. Done event ─────────────────────────────────────────────── # ── 4. Save conversation history ────────────────────────────────
await self._save_history(conversation_id, history, query, accumulated_response)
# ── 5. Done event ───────────────────────────────────────────────
latency_ms = (time.monotonic() - start) * 1000 latency_ms = (time.monotonic() - start) * 1000
logger.info( logger.info(
"chat_done query=%r creator=%r cascade_tier=%s source_count=%d latency_ms=%.1f", "chat_done query=%r creator=%r cascade_tier=%s source_count=%d latency_ms=%.1f cid=%s",
query, creator, cascade_tier, len(sources), latency_ms, query, creator, cascade_tier, len(sources), latency_ms, conversation_id,
) )
yield _sse("done", {"cascade_tier": cascade_tier}) yield _sse("done", {"cascade_tier": cascade_tier, "conversation_id": conversation_id})
# ── Helpers ────────────────────────────────────────────────────────────────── # ── Helpers ──────────────────────────────────────────────────────────────────

View file

@ -16,6 +16,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from chat_service import ChatService from chat_service import ChatService
from config import Settings, get_settings from config import Settings, get_settings
from database import get_session from database import get_session
from redis_client import get_redis
logger = logging.getLogger("chrysopedia.chat.router") logger = logging.getLogger("chrysopedia.chat.router")
@ -27,31 +28,35 @@ class ChatRequest(BaseModel):
query: str = Field(..., min_length=1, max_length=1000) query: str = Field(..., min_length=1, max_length=1000)
creator: str | None = None creator: str | None = None
conversation_id: str | None = None
def _get_chat_service(settings: Settings = Depends(get_settings)) -> ChatService:
"""Build a ChatService from current settings."""
return ChatService(settings)
@router.post("") @router.post("")
async def chat( async def chat(
body: ChatRequest, body: ChatRequest,
db: AsyncSession = Depends(get_session), db: AsyncSession = Depends(get_session),
service: ChatService = Depends(_get_chat_service), settings: Settings = Depends(get_settings),
) -> StreamingResponse: ) -> StreamingResponse:
"""Stream a chat response as Server-Sent Events. """Stream a chat response as Server-Sent Events.
SSE protocol: SSE protocol:
- ``event: sources`` citation metadata array (sent first) - ``event: sources`` citation metadata array (sent first)
- ``event: token`` streamed text chunk (repeated) - ``event: token`` streamed text chunk (repeated)
- ``event: done`` completion metadata with cascade_tier - ``event: done`` completion metadata with cascade_tier, conversation_id
- ``event: error`` error message (on failure) - ``event: error`` error message (on failure)
""" """
logger.info("chat_request query=%r creator=%r", body.query, body.creator) logger.info("chat_request query=%r creator=%r cid=%r", body.query, body.creator, body.conversation_id)
redis = await get_redis()
service = ChatService(settings, redis=redis)
return StreamingResponse( return StreamingResponse(
service.stream_response(query=body.query, db=db, creator=body.creator), service.stream_response(
query=body.query,
db=db,
creator=body.creator,
conversation_id=body.conversation_id,
),
media_type="text/event-stream", media_type="text/event-stream",
headers={ headers={
"Cache-Control": "no-cache", "Cache-Control": "no-cache",

View file

@ -6,6 +6,7 @@ Mocks SearchService.search() and the OpenAI streaming response to verify:
3. Creator param forwarded to search 3. Creator param forwarded to search
4. Empty/invalid query returns 422 4. Empty/invalid query returns 422
5. LLM error produces an SSE error event 5. LLM error produces an SSE error event
6. Multi-turn conversation memory via Redis (load/save/cap/TTL)
These tests use a standalone ASGI client that does NOT require a running These tests use a standalone ASGI client that does NOT require a running
PostgreSQL instance the DB session dependency is overridden with a mock. PostgreSQL instance the DB session dependency is overridden with a mock.
@ -33,8 +34,28 @@ from main import app # noqa: E402
# ── Standalone test client (no DB required) ────────────────────────────────── # ── Standalone test client (no DB required) ──────────────────────────────────
@pytest_asyncio.fixture() @pytest_asyncio.fixture()
async def chat_client(): async def mock_redis():
"""Async HTTP test client that mocks out the DB session entirely.""" """In-memory mock of an async Redis client (get/set)."""
store: dict[str, tuple[str, int | None]] = {}
mock = AsyncMock()
async def _get(key):
entry = store.get(key)
return entry[0] if entry else None
async def _set(key, value, ex=None):
store[key] = (value, ex)
mock.get = AsyncMock(side_effect=_get)
mock.set = AsyncMock(side_effect=_set)
mock._store = store # exposed for test assertions
return mock
@pytest_asyncio.fixture()
async def chat_client(mock_redis):
"""Async HTTP test client that mocks out the DB session and Redis."""
mock_session = AsyncMock() mock_session = AsyncMock()
async def _mock_get_session(): async def _mock_get_session():
@ -44,7 +65,8 @@ async def chat_client():
transport = ASGITransport(app=app) transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://testserver") as ac: async with AsyncClient(transport=transport, base_url="http://testserver") as ac:
yield ac with patch("routers.chat.get_redis", new_callable=AsyncMock, return_value=mock_redis):
yield ac
app.dependency_overrides.pop(get_session, None) app.dependency_overrides.pop(get_session, None)
@ -298,3 +320,246 @@ async def test_chat_llm_error_produces_error_event(chat_client):
error_event = next(e for e in events if e["event"] == "error") error_event = next(e for e in events if e["event"] == "error")
assert "message" in error_event["data"] assert "message" in error_event["data"]
# ── Multi-turn conversation memory tests ─────────────────────────────────────
def _chat_request_with_mocks(query, conversation_id=None, creator=None, token_chunks=None):
"""Helper: build a request JSON and patched mocks for a single chat call."""
if token_chunks is None:
token_chunks = ["ok"]
body: dict[str, Any] = {"query": query}
if conversation_id is not None:
body["conversation_id"] = conversation_id
if creator is not None:
body["creator"] = creator
return body, token_chunks
@pytest.mark.asyncio
async def test_conversation_done_event_includes_conversation_id(chat_client):
    """The done SSE event echoes back a caller-supplied conversation_id.

    Only the explicit-id path is exercised here.
    """
    search_result = _fake_search_result()

    mock_openai_client = MagicMock()
    mock_openai_client.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["hello"])
    )

    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result),
        patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client),
    ):
        resp = await chat_client.post(
            "/api/v1/chat",
            json={"query": "test", "conversation_id": "conv-abc-123"},
        )

    # Fail on a bad status before trying to parse the SSE body.
    assert resp.status_code == 200

    events = _parse_sse(resp.text)
    done_data = next(e for e in events if e["event"] == "done")["data"]
    assert done_data["conversation_id"] == "conv-abc-123"
@pytest.mark.asyncio
async def test_conversation_auto_generates_id_when_omitted(chat_client):
    """When conversation_id is omitted, the done event carries a generated UUID4."""
    import uuid  # local import: the module's import block is not visible here

    search_result = _fake_search_result()

    mock_openai_client = MagicMock()
    mock_openai_client.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["world"])
    )

    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result),
        patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client),
    ):
        resp = await chat_client.post("/api/v1/chat", json={"query": "no cid"})

    assert resp.status_code == 200

    events = _parse_sse(resp.text)
    done_data = next(e for e in events if e["event"] == "done")["data"]

    cid = done_data["conversation_id"]
    assert isinstance(cid, str)
    # A bare length check accepts any 36-character string; parse it instead.
    parsed = uuid.UUID(cid)
    assert parsed.version == 4
    # UUID() also accepts braces/URN/no-hyphen forms; require the canonical one.
    assert str(parsed) == cid
@pytest.mark.asyncio
async def test_conversation_history_saved_to_redis(chat_client, mock_redis):
    """A successful chat stores the user/assistant turn pair in Redis with a 1h TTL."""
    conversation = "conv-save-test"

    llm_client = MagicMock()
    llm_client.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["The answer ", "is 42."])
    )
    search_patch = patch(
        "chat_service.SearchService.search",
        new_callable=AsyncMock,
        return_value=_fake_search_result(),
    )
    openai_patch = patch("chat_service.openai.AsyncOpenAI", return_value=llm_client)

    with search_patch, openai_patch:
        response = await chat_client.post(
            "/api/v1/chat",
            json={"query": "what is the answer", "conversation_id": conversation},
        )

    assert response.status_code == 200

    # The fake Redis keeps (value, ex) tuples keyed by the Redis key.
    key = f"chrysopedia:chat:{conversation}"
    assert key in mock_redis._store
    payload, expiry = mock_redis._store[key]
    saved = json.loads(payload)
    assert len(saved) == 2
    user_turn, assistant_turn = saved
    assert user_turn == {"role": "user", "content": "what is the answer"}
    assert assistant_turn == {"role": "assistant", "content": "The answer is 42."}
    assert expiry == 3600
@pytest.mark.asyncio
async def test_conversation_history_injected_into_llm_messages(chat_client, mock_redis):
    """Stored history is sent to the LLM between the system prompt and the new query."""
    conversation = "conv-inject-test"
    earlier_user = {"role": "user", "content": "what is reverb"}
    earlier_assistant = {"role": "assistant", "content": "Reverb simulates acoustic spaces."}

    # Seed the fake Redis with one prior turn pair.
    mock_redis._store[f"chrysopedia:chat:{conversation}"] = (
        json.dumps([earlier_user, earlier_assistant]),
        3600,
    )

    sent_messages = []

    async def _record_create(**kwargs):
        sent_messages.extend(kwargs.get("messages", []))
        return _mock_openai_stream(["follow-up answer"])

    llm_client = MagicMock()
    llm_client.chat.completions.create = AsyncMock(side_effect=_record_create)

    search_patch = patch(
        "chat_service.SearchService.search",
        new_callable=AsyncMock,
        return_value=_fake_search_result(),
    )
    openai_patch = patch("chat_service.openai.AsyncOpenAI", return_value=llm_client)

    with search_patch, openai_patch:
        response = await chat_client.post(
            "/api/v1/chat",
            json={"query": "how do I use it on drums", "conversation_id": conversation},
        )

    assert response.status_code == 200

    # Expected order: system, prior user, prior assistant, current user.
    assert len(sent_messages) == 4
    system_msg, prior_user, prior_assistant, current_user = sent_messages
    assert system_msg["role"] == "system"
    assert prior_user == {"role": "user", "content": "what is reverb"}
    assert prior_assistant == {"role": "assistant", "content": "Reverb simulates acoustic spaces."}
    assert current_user == {"role": "user", "content": "how do I use it on drums"}
@pytest.mark.asyncio
async def test_conversation_history_capped_at_10_pairs(chat_client, mock_redis):
    """Stored history never exceeds 10 turn pairs; the oldest pair is evicted first."""
    conversation = "conv-cap-test"
    key = f"chrysopedia:chat:{conversation}"

    # Seed exactly 10 pairs (20 messages), so the next turn must evict one pair.
    seeded = [
        {"role": role, "content": f"{label} {index}"}
        for index in range(10)
        for role, label in (("user", "question"), ("assistant", "answer"))
    ]
    mock_redis._store[key] = (json.dumps(seeded), 3600)

    llm_client = MagicMock()
    llm_client.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["capped reply"])
    )
    search_patch = patch(
        "chat_service.SearchService.search",
        new_callable=AsyncMock,
        return_value=_fake_search_result(),
    )
    openai_patch = patch("chat_service.openai.AsyncOpenAI", return_value=llm_client)

    with search_patch, openai_patch:
        response = await chat_client.post(
            "/api/v1/chat",
            json={"query": "turn 11", "conversation_id": conversation},
        )

    assert response.status_code == 200

    payload, _ = mock_redis._store[key]
    saved = json.loads(payload)
    # Still 10 pairs after the eleventh turn.
    assert len(saved) == 20
    # "question 0" / "answer 0" were dropped, so the list now starts at pair 1.
    assert saved[0] == {"role": "user", "content": "question 1"}
    # The newest pair sits at the tail.
    assert saved[-2] == {"role": "user", "content": "turn 11"}
    assert saved[-1] == {"role": "assistant", "content": "capped reply"}
@pytest.mark.asyncio
async def test_conversation_ttl_refreshed_on_interaction(chat_client, mock_redis):
    """A new turn rewrites the history key with the full one-hour expiry."""
    conversation = "conv-ttl-test"
    key = f"chrysopedia:chat:{conversation}"

    # Seed an entry whose recorded expiry is well below the full TTL.
    seeded = [
        {"role": "user", "content": "old message"},
        {"role": "assistant", "content": "old reply"},
    ]
    mock_redis._store[key] = (json.dumps(seeded), 100)

    llm_client = MagicMock()
    llm_client.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["refreshed"])
    )
    search_patch = patch(
        "chat_service.SearchService.search",
        new_callable=AsyncMock,
        return_value=_fake_search_result(),
    )
    openai_patch = patch("chat_service.openai.AsyncOpenAI", return_value=llm_client)

    with search_patch, openai_patch:
        response = await chat_client.post(
            "/api/v1/chat",
            json={"query": "refresh test", "conversation_id": conversation},
        )

    assert response.status_code == 200

    # The second tuple element is the `ex` value passed to the last set().
    _, expiry = mock_redis._store[key]
    assert expiry == 3600
@pytest.mark.asyncio
async def test_single_turn_fallback_no_redis_history(chat_client, mock_redis):
    """Without a conversation_id the LLM receives only the system prompt and the query."""
    sent_messages = []

    async def _record_create(**kwargs):
        sent_messages.extend(kwargs.get("messages", []))
        return _mock_openai_stream(["standalone answer"])

    llm_client = MagicMock()
    llm_client.chat.completions.create = AsyncMock(side_effect=_record_create)

    search_patch = patch(
        "chat_service.SearchService.search",
        new_callable=AsyncMock,
        return_value=_fake_search_result(),
    )
    openai_patch = patch("chat_service.openai.AsyncOpenAI", return_value=llm_client)

    with search_patch, openai_patch:
        response = await chat_client.post("/api/v1/chat", json={"query": "standalone question"})

    assert response.status_code == 200

    # No conversation_id was sent, so no prior turns appear in the payload.
    assert len(sent_messages) == 2
    system_msg, user_msg = sent_messages
    assert system_msg["role"] == "system"
    assert user_msg["role"] == "user"