diff --git a/backend/chat_service.py b/backend/chat_service.py index 05fd926..4cf0ba5 100644 --- a/backend/chat_service.py +++ b/backend/chat_service.py @@ -3,6 +3,12 @@ Assembles a numbered context block from search results, then streams completion tokens from an OpenAI-compatible API. Yields SSE-formatted events: sources, token, done, and error. + +Multi-turn memory: When a conversation_id is provided, prior messages are +loaded from Redis, injected into the LLM messages array, and the new +user+assistant turn is appended after streaming completes. History is +capped at 10 turn pairs (20 messages) and expires after 1 hour of +inactivity. """ from __future__ import annotations @@ -11,6 +17,7 @@ import json import logging import time import traceback +import uuid from typing import Any, AsyncIterator import openai @@ -32,35 +39,86 @@ Sources: """ _MAX_CONTEXT_SOURCES = 10 +_MAX_TURN_PAIRS = 10 +_HISTORY_TTL_SECONDS = 3600 # 1 hour + + +def _redis_key(conversation_id: str) -> str: + return f"chrysopedia:chat:{conversation_id}" class ChatService: """Retrieve context from search, stream an LLM response with citations.""" - def __init__(self, settings: Settings) -> None: + def __init__(self, settings: Settings, redis=None) -> None: self.settings = settings self._search = SearchService(settings) self._openai = openai.AsyncOpenAI( base_url=settings.llm_api_url, api_key=settings.llm_api_key, ) + self._redis = redis + + async def _load_history(self, conversation_id: str) -> list[dict[str, str]]: + """Load conversation history from Redis. 
Returns an empty list on a miss, on a load error, on a non-list payload, or when no Redis client is configured.""" + if not self._redis: + return [] + try: + raw = await self._redis.get(_redis_key(conversation_id)) + loaded = json.loads(raw) if raw else [] + return loaded if isinstance(loaded, list) else [] + except Exception: + logger.warning("chat_history_load_error cid=%s", conversation_id, exc_info=True) + return [] + + async def _save_history( + self, + conversation_id: str, + history: list[dict[str, str]], + user_msg: str, + assistant_msg: str, + ) -> None: + """Append the new turn pair, keep only the most recent _MAX_TURN_PAIRS pairs, and persist to Redis with a refreshed TTL; save errors are logged, not raised.""" + if not self._redis: + return + # Build a new list rather than appending in place, so the caller's list is not mutated + history = [*history, {"role": "user", "content": user_msg}, {"role": "assistant", "content": assistant_msg}] + # Cap at _MAX_TURN_PAIRS (keep most recent) + if len(history) > _MAX_TURN_PAIRS * 2: + history = history[-_MAX_TURN_PAIRS * 2:] + try: + await self._redis.set( + _redis_key(conversation_id), + json.dumps(history), + ex=_HISTORY_TTL_SECONDS, + ) + except Exception: + logger.warning("chat_history_save_error cid=%s", conversation_id, exc_info=True) async def stream_response( self, query: str, db: AsyncSession, creator: str | None = None, + conversation_id: str | None = None, ) -> AsyncIterator[str]: """Yield SSE-formatted events for a chat query. Protocol: 1. ``event: sources\ndata: \n\n`` 2. ``event: token\ndata: \n\n`` (repeated) - 3. ``event: done\ndata: \n\n`` + 3. ``event: done\ndata: \n\n`` On error: ``event: error\ndata: \n\n`` """ start = time.monotonic() + # Assign conversation_id if not provided (single-turn becomes trackable) + if conversation_id is None: + conversation_id = str(uuid.uuid4()) + + # ── 0. Load conversation history ──────────────────────────────── + history = await self._load_history(conversation_id) + # ── 1. 
Retrieve context via search ────────────────────────────── try: search_result = await self._search.search( @@ -83,8 +141,8 @@ class ChatService: context_block = _build_context_block(items) logger.info( - "chat_search query=%r creator=%r cascade_tier=%s source_count=%d", - query, creator, cascade_tier, len(sources), + "chat_search query=%r creator=%r cascade_tier=%s source_count=%d cid=%s", + query, creator, cascade_tier, len(sources), conversation_id, ) # Emit sources event first @@ -93,13 +151,19 @@ class ChatService: # ── 3. Stream LLM completion ──────────────────────────────────── system_prompt = _SYSTEM_PROMPT_TEMPLATE.format(context_block=context_block) + messages: list[dict[str, str]] = [ + {"role": "system", "content": system_prompt}, + ] + # Inject conversation history between system prompt and current query + messages.extend(history) + messages.append({"role": "user", "content": query}) + + accumulated_response = "" + try: stream = await self._openai.chat.completions.create( model=self.settings.llm_model, - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": query}, - ], + messages=messages, stream=True, temperature=0.3, max_tokens=2048, @@ -108,21 +172,26 @@ class ChatService: async for chunk in stream: choice = chunk.choices[0] if chunk.choices else None if choice and choice.delta and choice.delta.content: - yield _sse("token", choice.delta.content) + text = choice.delta.content + accumulated_response += text + yield _sse("token", text) except Exception: tb = traceback.format_exc() - logger.error("chat_llm_error query=%r\n%s", query, tb) + logger.error("chat_llm_error query=%r cid=%s\n%s", query, conversation_id, tb) yield _sse("error", {"message": "LLM generation failed"}) return - # ── 4. Done event ─────────────────────────────────────────────── + # ── 4. Save conversation history ──────────────────────────────── + await self._save_history(conversation_id, history, query, accumulated_response) + + # ── 5. 
Done event ─────────────────────────────────────────────── latency_ms = (time.monotonic() - start) * 1000 logger.info( - "chat_done query=%r creator=%r cascade_tier=%s source_count=%d latency_ms=%.1f", - query, creator, cascade_tier, len(sources), latency_ms, + "chat_done query=%r creator=%r cascade_tier=%s source_count=%d latency_ms=%.1f cid=%s", + query, creator, cascade_tier, len(sources), latency_ms, conversation_id, ) - yield _sse("done", {"cascade_tier": cascade_tier}) + yield _sse("done", {"cascade_tier": cascade_tier, "conversation_id": conversation_id}) # ── Helpers ────────────────────────────────────────────────────────────────── diff --git a/backend/routers/chat.py b/backend/routers/chat.py index b38365f..fd7ad87 100644 --- a/backend/routers/chat.py +++ b/backend/routers/chat.py @@ -16,6 +16,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from chat_service import ChatService from config import Settings, get_settings from database import get_session +from redis_client import get_redis logger = logging.getLogger("chrysopedia.chat.router") @@ -27,31 +28,35 @@ class ChatRequest(BaseModel): query: str = Field(..., min_length=1, max_length=1000) creator: str | None = None - - -def _get_chat_service(settings: Settings = Depends(get_settings)) -> ChatService: - """Build a ChatService from current settings.""" - return ChatService(settings) + conversation_id: str | None = Field(default=None, min_length=1, max_length=128) @router.post("") async def chat( body: ChatRequest, db: AsyncSession = Depends(get_session), - service: ChatService = Depends(_get_chat_service), + settings: Settings = Depends(get_settings), ) -> StreamingResponse: """Stream a chat response as Server-Sent Events. 
SSE protocol: - ``event: sources`` — citation metadata array (sent first) - ``event: token`` — streamed text chunk (repeated) - - ``event: done`` — completion metadata with cascade_tier + - ``event: done`` — completion metadata with cascade_tier, conversation_id - ``event: error`` — error message (on failure) """ - logger.info("chat_request query=%r creator=%r", body.query, body.creator) + logger.info("chat_request query=%r creator=%r cid=%r", body.query, body.creator, body.conversation_id) + + redis = await get_redis() + service = ChatService(settings, redis=redis) return StreamingResponse( - service.stream_response(query=body.query, db=db, creator=body.creator), + service.stream_response( + query=body.query, + db=db, + creator=body.creator, + conversation_id=body.conversation_id, + ), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", diff --git a/backend/tests/test_chat.py b/backend/tests/test_chat.py index 985e16c..d45d68f 100644 --- a/backend/tests/test_chat.py +++ b/backend/tests/test_chat.py @@ -6,6 +6,7 @@ Mocks SearchService.search() and the OpenAI streaming response to verify: 3. Creator param forwarded to search 4. Empty/invalid query returns 422 5. LLM error produces an SSE error event +6. Multi-turn conversation memory via Redis (load/save/cap/TTL) These tests use a standalone ASGI client that does NOT require a running PostgreSQL instance — the DB session dependency is overridden with a mock. 
@@ -33,8 +34,28 @@ from main import app # noqa: E402 # ── Standalone test client (no DB required) ────────────────────────────────── @pytest_asyncio.fixture() -async def chat_client(): - """Async HTTP test client that mocks out the DB session entirely.""" +async def mock_redis(): + """In-memory mock of an async Redis client (get/set); the ``ex`` TTL is stored alongside the value but never enforced.""" + store: dict[str, tuple[str, int | None]] = {} + + mock = AsyncMock() + + async def _get(key): + entry = store.get(key) + return entry[0] if entry else None + + async def _set(key, value, ex=None): + store[key] = (value, ex) + + mock.get = AsyncMock(side_effect=_get) + mock.set = AsyncMock(side_effect=_set) + mock._store = store # exposed for test assertions + return mock + + +@pytest_asyncio.fixture() +async def chat_client(mock_redis): + """Async HTTP test client that mocks out the DB session and Redis.""" mock_session = AsyncMock() async def _mock_get_session(): @@ -44,7 +65,8 @@ async def chat_client(): transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://testserver") as ac: - yield ac + with patch("routers.chat.get_redis", new_callable=AsyncMock, return_value=mock_redis): + yield ac app.dependency_overrides.pop(get_session, None) @@ -298,3 +320,246 @@ async def test_chat_llm_error_produces_error_event(chat_client): error_event = next(e for e in events if e["event"] == "error") assert "message" in error_event["data"] + + +# ── Multi-turn conversation memory tests ───────────────────────────────────── + + +def _chat_request_with_mocks(query, conversation_id=None, creator=None, token_chunks=None): + """Helper: build the request JSON body and the token chunks (default ``["ok"]``) for a single chat call; applies no patches itself.""" + if token_chunks is None: + token_chunks = ["ok"] + body: dict[str, Any] = {"query": query} + if conversation_id is not None: + body["conversation_id"] = conversation_id + if creator is not None: + body["creator"] = creator + return body, token_chunks + + +@pytest.mark.asyncio +async def 
test_conversation_done_event_includes_conversation_id(chat_client): + """The done SSE event echoes back an explicitly provided conversation_id.""" + search_result = _fake_search_result() + mock_openai_client = MagicMock() + mock_openai_client.chat.completions.create = AsyncMock( + return_value=_mock_openai_stream(["hello"]) + ) + + # With explicit conversation_id + with ( + patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result), + patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client), + ): + resp = await chat_client.post( + "/api/v1/chat", + json={"query": "test", "conversation_id": "conv-abc-123"}, + ) + + events = _parse_sse(resp.text) + done_data = next(e for e in events if e["event"] == "done")["data"] + assert done_data["conversation_id"] == "conv-abc-123" + + +@pytest.mark.asyncio +async def test_conversation_auto_generates_id_when_omitted(chat_client): + """When conversation_id is not provided, the done event still includes one (auto-generated).""" + search_result = _fake_search_result() + mock_openai_client = MagicMock() + mock_openai_client.chat.completions.create = AsyncMock( + return_value=_mock_openai_stream(["world"]) + ) + + with ( + patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result), + patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client), + ): + resp = await chat_client.post("/api/v1/chat", json={"query": "no cid"}) + + events = _parse_sse(resp.text) + done_data = next(e for e in events if e["event"] == "done")["data"] + # Auto-generated id: only its type and length are checked here + cid = done_data["conversation_id"] + assert isinstance(cid, str) + assert len(cid) == 36 # canonical UUID string length + + +@pytest.mark.asyncio +async def test_conversation_history_saved_to_redis(chat_client, mock_redis): + """After a successful chat, user+assistant messages are saved to Redis.""" + search_result = _fake_search_result() + mock_openai_client = MagicMock() + 
mock_openai_client.chat.completions.create = AsyncMock( + return_value=_mock_openai_stream(["The answer ", "is 42."]) + ) + + cid = "conv-save-test" + with ( + patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result), + patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client), + ): + resp = await chat_client.post( + "/api/v1/chat", + json={"query": "what is the answer", "conversation_id": cid}, + ) + + assert resp.status_code == 200 + + # Verify Redis received the history + redis_key = f"chrysopedia:chat:{cid}" + assert redis_key in mock_redis._store + stored_json, ttl = mock_redis._store[redis_key] + history = json.loads(stored_json) + + assert len(history) == 2 + assert history[0] == {"role": "user", "content": "what is the answer"} + assert history[1] == {"role": "assistant", "content": "The answer is 42."} + assert ttl == 3600 + + +@pytest.mark.asyncio +async def test_conversation_history_injected_into_llm_messages(chat_client, mock_redis): + """Prior conversation history is injected between system prompt and user message.""" + search_result = _fake_search_result() + + # Pre-populate Redis with conversation history + cid = "conv-inject-test" + prior_history = [ + {"role": "user", "content": "what is reverb"}, + {"role": "assistant", "content": "Reverb simulates acoustic spaces."}, + ] + mock_redis._store[f"chrysopedia:chat:{cid}"] = (json.dumps(prior_history), 3600) + + captured_messages = [] + + mock_openai_client = MagicMock() + + async def _capture_create(**kwargs): + captured_messages.extend(kwargs.get("messages", [])) + return _mock_openai_stream(["follow-up answer"]) + + mock_openai_client.chat.completions.create = AsyncMock(side_effect=_capture_create) + + with ( + patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result), + patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client), + ): + resp = await chat_client.post( + "/api/v1/chat", + 
json={"query": "how do I use it on drums", "conversation_id": cid}, + ) + + assert resp.status_code == 200 + + # Messages should be: system, prior_user, prior_assistant, current_user + assert len(captured_messages) == 4 + assert captured_messages[0]["role"] == "system" + assert captured_messages[1] == {"role": "user", "content": "what is reverb"} + assert captured_messages[2] == {"role": "assistant", "content": "Reverb simulates acoustic spaces."} + assert captured_messages[3] == {"role": "user", "content": "how do I use it on drums"} + + +@pytest.mark.asyncio +async def test_conversation_history_capped_at_10_pairs(chat_client, mock_redis): + """History is capped at 10 turn pairs (20 messages). Oldest turns are dropped.""" + search_result = _fake_search_result() + cid = "conv-cap-test" + + # Pre-populate with 10 turn pairs (20 messages) — at the cap + prior_history = [] + for i in range(10): + prior_history.append({"role": "user", "content": f"question {i}"}) + prior_history.append({"role": "assistant", "content": f"answer {i}"}) + + mock_redis._store[f"chrysopedia:chat:{cid}"] = (json.dumps(prior_history), 3600) + + mock_openai_client = MagicMock() + mock_openai_client.chat.completions.create = AsyncMock( + return_value=_mock_openai_stream(["capped reply"]) + ) + + with ( + patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result), + patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client), + ): + resp = await chat_client.post( + "/api/v1/chat", + json={"query": "turn 11", "conversation_id": cid}, + ) + + assert resp.status_code == 200 + + # After adding turn 11, history should still be 20 messages (10 pairs) + redis_key = f"chrysopedia:chat:{cid}" + stored_json, _ = mock_redis._store[redis_key] + history = json.loads(stored_json) + + assert len(history) == 20 # 10 pairs + # Oldest pair (question 0 / answer 0) should be dropped + assert history[0] == {"role": "user", "content": "question 1"} + # New pair should 
be at the end + assert history[-2] == {"role": "user", "content": "turn 11"} + assert history[-1] == {"role": "assistant", "content": "capped reply"} + + +@pytest.mark.asyncio +async def test_conversation_ttl_refreshed_on_interaction(chat_client, mock_redis): + """Each interaction refreshes the Redis TTL to 1 hour.""" + search_result = _fake_search_result() + cid = "conv-ttl-test" + + # Pre-populate with a short simulated TTL + prior_history = [ + {"role": "user", "content": "old message"}, + {"role": "assistant", "content": "old reply"}, + ] + mock_redis._store[f"chrysopedia:chat:{cid}"] = (json.dumps(prior_history), 100) + + mock_openai_client = MagicMock() + mock_openai_client.chat.completions.create = AsyncMock( + return_value=_mock_openai_stream(["refreshed"]) + ) + + with ( + patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result), + patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client), + ): + resp = await chat_client.post( + "/api/v1/chat", + json={"query": "refresh test", "conversation_id": cid}, + ) + + assert resp.status_code == 200 + + # TTL should be refreshed to 3600 + redis_key = f"chrysopedia:chat:{cid}" + _, ttl = mock_redis._store[redis_key] + assert ttl == 3600 + + +@pytest.mark.asyncio +async def test_single_turn_fallback_no_redis_history(chat_client, mock_redis): + """When no conversation_id is provided, a fresh one is generated, so no prior history is injected and behavior matches single-turn.""" + search_result = _fake_search_result() + + captured_messages = [] + mock_openai_client = MagicMock() + + async def _capture_create(**kwargs): + captured_messages.extend(kwargs.get("messages", [])) + return _mock_openai_stream(["standalone answer"]) + + mock_openai_client.chat.completions.create = AsyncMock(side_effect=_capture_create) + + with ( + patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result), + patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client), + ): + resp 
= await chat_client.post("/api/v1/chat", json={"query": "standalone question"}) + + assert resp.status_code == 200 + + # Should only have system + user (no history injected since auto-generated cid is fresh) + assert len(captured_messages) == 2 + assert captured_messages[0]["role"] == "system" + assert captured_messages[1]["role"] == "user"