test: Added multi-turn conversation memory with Redis-backed history (1…
- "backend/chat_service.py" - "backend/routers/chat.py" - "backend/tests/test_chat.py" GSD-Task: S04/T01
This commit is contained in:
parent
0098254fdd
commit
29e60bbc99
3 changed files with 365 additions and 26 deletions
|
|
@ -3,6 +3,12 @@
|
||||||
Assembles a numbered context block from search results, then streams
|
Assembles a numbered context block from search results, then streams
|
||||||
completion tokens from an OpenAI-compatible API. Yields SSE-formatted
|
completion tokens from an OpenAI-compatible API. Yields SSE-formatted
|
||||||
events: sources, token, done, and error.
|
events: sources, token, done, and error.
|
||||||
|
|
||||||
|
Multi-turn memory: When a conversation_id is provided, prior messages are
|
||||||
|
loaded from Redis, injected into the LLM messages array, and the new
|
||||||
|
user+assistant turn is appended after streaming completes. History is
|
||||||
|
capped at 10 turn pairs (20 messages) and expires after 1 hour of
|
||||||
|
inactivity.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
@ -11,6 +17,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
import uuid
|
||||||
from typing import Any, AsyncIterator
|
from typing import Any, AsyncIterator
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
|
|
@ -32,35 +39,86 @@ Sources:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_MAX_CONTEXT_SOURCES = 10
|
_MAX_CONTEXT_SOURCES = 10
|
||||||
|
_MAX_TURN_PAIRS = 10
|
||||||
|
_HISTORY_TTL_SECONDS = 3600 # 1 hour
|
||||||
|
|
||||||
|
|
||||||
|
def _redis_key(conversation_id: str) -> str:
    """Return the Redis key that holds the history for *conversation_id*."""
    namespace = "chrysopedia:chat:"
    return f"{namespace}{conversation_id}"
|
||||||
|
|
||||||
|
|
||||||
class ChatService:
|
class ChatService:
|
||||||
"""Retrieve context from search, stream an LLM response with citations."""
|
"""Retrieve context from search, stream an LLM response with citations."""
|
||||||
|
|
||||||
def __init__(self, settings: Settings) -> None:
|
def __init__(self, settings: Settings, redis=None) -> None:
|
||||||
self.settings = settings
|
self.settings = settings
|
||||||
self._search = SearchService(settings)
|
self._search = SearchService(settings)
|
||||||
self._openai = openai.AsyncOpenAI(
|
self._openai = openai.AsyncOpenAI(
|
||||||
base_url=settings.llm_api_url,
|
base_url=settings.llm_api_url,
|
||||||
api_key=settings.llm_api_key,
|
api_key=settings.llm_api_key,
|
||||||
)
|
)
|
||||||
|
self._redis = redis
|
||||||
|
|
||||||
|
async def _load_history(self, conversation_id: str) -> list[dict[str, str]]:
    """Load prior conversation messages from Redis.

    Best-effort: returns an empty list when no Redis client is configured,
    when the key is missing or empty, when the read or JSON decode fails,
    or when the stored payload is not a list. Entries that are not
    ``{"role": "user" | "assistant", "content": <str>}`` mappings are
    dropped so a corrupt payload is never forwarded to the LLM.

    Args:
        conversation_id: Identifier used to derive the Redis key.

    Returns:
        The stored messages in stored order, as fresh role/content dicts.
    """
    if not self._redis:
        return []

    try:
        raw = await self._redis.get(_redis_key(conversation_id))
        if not raw:
            return []
        payload = json.loads(raw)
    except Exception:
        # History is an enhancement; a Redis or decode failure must not
        # break the chat request itself.
        logger.warning("chat_history_load_error cid=%s", conversation_id, exc_info=True)
        return []

    if not isinstance(payload, list):
        # list.extend() on a dict or str would inject keys or characters
        # into the messages array, so reject anything that is not a list.
        logger.warning(
            "chat_history_invalid_payload cid=%s type=%s",
            conversation_id,
            type(payload).__name__,
        )
        return []

    # Only user/assistant turns are ever written by _save_history; anything
    # else (including a stray "system" entry) is treated as corruption.
    return [
        {"role": entry["role"], "content": entry["content"]}
        for entry in payload
        if isinstance(entry, dict)
        and entry.get("role") in ("user", "assistant")
        and isinstance(entry.get("content"), str)
    ]
|
||||||
|
|
||||||
|
async def _save_history(
    self,
    conversation_id: str,
    history: list[dict[str, str]],
    user_msg: str,
    assistant_msg: str,
) -> None:
    """Persist *history* plus the new turn pair to Redis, refreshing the TTL.

    Best-effort: does nothing when no Redis client is configured, and logs
    (rather than raises) if serialization or the Redis write fails.

    Args:
        conversation_id: Identifier used to derive the Redis key.
        history: Messages loaded at the start of the request. Not mutated.
        user_msg: The user's query for this turn.
        assistant_msg: The full assistant response accumulated while streaming.
    """
    if not self._redis:
        return

    # Build a new list rather than appending to the caller's list, so the
    # caller's copy never diverges from what is actually stored.
    updated = [
        *history,
        {"role": "user", "content": user_msg},
        {"role": "assistant", "content": assistant_msg},
    ]

    # Cap at _MAX_TURN_PAIRS (keep most recent).
    if len(updated) > _MAX_TURN_PAIRS * 2:
        updated = updated[-_MAX_TURN_PAIRS * 2:]

    try:
        await self._redis.set(
            _redis_key(conversation_id),
            json.dumps(updated),
            ex=_HISTORY_TTL_SECONDS,  # every write restarts the inactivity window
        )
    except Exception:
        logger.warning("chat_history_save_error cid=%s", conversation_id, exc_info=True)
|
||||||
|
|
||||||
async def stream_response(
|
async def stream_response(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
db: AsyncSession,
|
db: AsyncSession,
|
||||||
creator: str | None = None,
|
creator: str | None = None,
|
||||||
|
conversation_id: str | None = None,
|
||||||
) -> AsyncIterator[str]:
|
) -> AsyncIterator[str]:
|
||||||
"""Yield SSE-formatted events for a chat query.
|
"""Yield SSE-formatted events for a chat query.
|
||||||
|
|
||||||
Protocol:
|
Protocol:
|
||||||
1. ``event: sources\ndata: <json array of citation metadata>\n\n``
|
1. ``event: sources\ndata: <json array of citation metadata>\n\n``
|
||||||
2. ``event: token\ndata: <text chunk>\n\n`` (repeated)
|
2. ``event: token\ndata: <text chunk>\n\n`` (repeated)
|
||||||
3. ``event: done\ndata: <json with cascade_tier>\n\n``
|
3. ``event: done\ndata: <json with cascade_tier, conversation_id>\n\n``
|
||||||
On error: ``event: error\ndata: <json with message>\n\n``
|
On error: ``event: error\ndata: <json with message>\n\n``
|
||||||
"""
|
"""
|
||||||
start = time.monotonic()
|
start = time.monotonic()
|
||||||
|
|
||||||
|
# Assign conversation_id if not provided (single-turn becomes trackable)
|
||||||
|
if conversation_id is None:
|
||||||
|
conversation_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# ── 0. Load conversation history ────────────────────────────────
|
||||||
|
history = await self._load_history(conversation_id)
|
||||||
|
|
||||||
# ── 1. Retrieve context via search ──────────────────────────────
|
# ── 1. Retrieve context via search ──────────────────────────────
|
||||||
try:
|
try:
|
||||||
search_result = await self._search.search(
|
search_result = await self._search.search(
|
||||||
|
|
@ -83,8 +141,8 @@ class ChatService:
|
||||||
context_block = _build_context_block(items)
|
context_block = _build_context_block(items)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"chat_search query=%r creator=%r cascade_tier=%s source_count=%d",
|
"chat_search query=%r creator=%r cascade_tier=%s source_count=%d cid=%s",
|
||||||
query, creator, cascade_tier, len(sources),
|
query, creator, cascade_tier, len(sources), conversation_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Emit sources event first
|
# Emit sources event first
|
||||||
|
|
@ -93,13 +151,19 @@ class ChatService:
|
||||||
# ── 3. Stream LLM completion ────────────────────────────────────
|
# ── 3. Stream LLM completion ────────────────────────────────────
|
||||||
system_prompt = _SYSTEM_PROMPT_TEMPLATE.format(context_block=context_block)
|
system_prompt = _SYSTEM_PROMPT_TEMPLATE.format(context_block=context_block)
|
||||||
|
|
||||||
|
messages: list[dict[str, str]] = [
|
||||||
|
{"role": "system", "content": system_prompt},
|
||||||
|
]
|
||||||
|
# Inject conversation history between system prompt and current query
|
||||||
|
messages.extend(history)
|
||||||
|
messages.append({"role": "user", "content": query})
|
||||||
|
|
||||||
|
accumulated_response = ""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
stream = await self._openai.chat.completions.create(
|
stream = await self._openai.chat.completions.create(
|
||||||
model=self.settings.llm_model,
|
model=self.settings.llm_model,
|
||||||
messages=[
|
messages=messages,
|
||||||
{"role": "system", "content": system_prompt},
|
|
||||||
{"role": "user", "content": query},
|
|
||||||
],
|
|
||||||
stream=True,
|
stream=True,
|
||||||
temperature=0.3,
|
temperature=0.3,
|
||||||
max_tokens=2048,
|
max_tokens=2048,
|
||||||
|
|
@ -108,21 +172,26 @@ class ChatService:
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
choice = chunk.choices[0] if chunk.choices else None
|
choice = chunk.choices[0] if chunk.choices else None
|
||||||
if choice and choice.delta and choice.delta.content:
|
if choice and choice.delta and choice.delta.content:
|
||||||
yield _sse("token", choice.delta.content)
|
text = choice.delta.content
|
||||||
|
accumulated_response += text
|
||||||
|
yield _sse("token", text)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
tb = traceback.format_exc()
|
tb = traceback.format_exc()
|
||||||
logger.error("chat_llm_error query=%r\n%s", query, tb)
|
logger.error("chat_llm_error query=%r cid=%s\n%s", query, conversation_id, tb)
|
||||||
yield _sse("error", {"message": "LLM generation failed"})
|
yield _sse("error", {"message": "LLM generation failed"})
|
||||||
return
|
return
|
||||||
|
|
||||||
# ── 4. Done event ───────────────────────────────────────────────
|
# ── 4. Save conversation history ────────────────────────────────
|
||||||
|
await self._save_history(conversation_id, history, query, accumulated_response)
|
||||||
|
|
||||||
|
# ── 5. Done event ───────────────────────────────────────────────
|
||||||
latency_ms = (time.monotonic() - start) * 1000
|
latency_ms = (time.monotonic() - start) * 1000
|
||||||
logger.info(
|
logger.info(
|
||||||
"chat_done query=%r creator=%r cascade_tier=%s source_count=%d latency_ms=%.1f",
|
"chat_done query=%r creator=%r cascade_tier=%s source_count=%d latency_ms=%.1f cid=%s",
|
||||||
query, creator, cascade_tier, len(sources), latency_ms,
|
query, creator, cascade_tier, len(sources), latency_ms, conversation_id,
|
||||||
)
|
)
|
||||||
yield _sse("done", {"cascade_tier": cascade_tier})
|
yield _sse("done", {"cascade_tier": cascade_tier, "conversation_id": conversation_id})
|
||||||
|
|
||||||
|
|
||||||
# ── Helpers ──────────────────────────────────────────────────────────────────
|
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from chat_service import ChatService
|
from chat_service import ChatService
|
||||||
from config import Settings, get_settings
|
from config import Settings, get_settings
|
||||||
from database import get_session
|
from database import get_session
|
||||||
|
from redis_client import get_redis
|
||||||
|
|
||||||
logger = logging.getLogger("chrysopedia.chat.router")
|
logger = logging.getLogger("chrysopedia.chat.router")
|
||||||
|
|
||||||
|
|
@ -27,31 +28,35 @@ class ChatRequest(BaseModel):
|
||||||
|
|
||||||
query: str = Field(..., min_length=1, max_length=1000)
|
query: str = Field(..., min_length=1, max_length=1000)
|
||||||
creator: str | None = None
|
creator: str | None = None
|
||||||
|
conversation_id: str | None = None
|
||||||
|
|
||||||
def _get_chat_service(settings: Settings = Depends(get_settings)) -> ChatService:
|
|
||||||
"""Build a ChatService from current settings."""
|
|
||||||
return ChatService(settings)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("")
|
@router.post("")
|
||||||
async def chat(
|
async def chat(
|
||||||
body: ChatRequest,
|
body: ChatRequest,
|
||||||
db: AsyncSession = Depends(get_session),
|
db: AsyncSession = Depends(get_session),
|
||||||
service: ChatService = Depends(_get_chat_service),
|
settings: Settings = Depends(get_settings),
|
||||||
) -> StreamingResponse:
|
) -> StreamingResponse:
|
||||||
"""Stream a chat response as Server-Sent Events.
|
"""Stream a chat response as Server-Sent Events.
|
||||||
|
|
||||||
SSE protocol:
|
SSE protocol:
|
||||||
- ``event: sources`` — citation metadata array (sent first)
|
- ``event: sources`` — citation metadata array (sent first)
|
||||||
- ``event: token`` — streamed text chunk (repeated)
|
- ``event: token`` — streamed text chunk (repeated)
|
||||||
- ``event: done`` — completion metadata with cascade_tier
|
- ``event: done`` — completion metadata with cascade_tier, conversation_id
|
||||||
- ``event: error`` — error message (on failure)
|
- ``event: error`` — error message (on failure)
|
||||||
"""
|
"""
|
||||||
logger.info("chat_request query=%r creator=%r", body.query, body.creator)
|
logger.info("chat_request query=%r creator=%r cid=%r", body.query, body.creator, body.conversation_id)
|
||||||
|
|
||||||
|
redis = await get_redis()
|
||||||
|
service = ChatService(settings, redis=redis)
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
service.stream_response(query=body.query, db=db, creator=body.creator),
|
service.stream_response(
|
||||||
|
query=body.query,
|
||||||
|
db=db,
|
||||||
|
creator=body.creator,
|
||||||
|
conversation_id=body.conversation_id,
|
||||||
|
),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
headers={
|
headers={
|
||||||
"Cache-Control": "no-cache",
|
"Cache-Control": "no-cache",
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ Mocks SearchService.search() and the OpenAI streaming response to verify:
|
||||||
3. Creator param forwarded to search
|
3. Creator param forwarded to search
|
||||||
4. Empty/invalid query returns 422
|
4. Empty/invalid query returns 422
|
||||||
5. LLM error produces an SSE error event
|
5. LLM error produces an SSE error event
|
||||||
|
6. Multi-turn conversation memory via Redis (load/save/cap/TTL)
|
||||||
|
|
||||||
These tests use a standalone ASGI client that does NOT require a running
|
These tests use a standalone ASGI client that does NOT require a running
|
||||||
PostgreSQL instance — the DB session dependency is overridden with a mock.
|
PostgreSQL instance — the DB session dependency is overridden with a mock.
|
||||||
|
|
@ -33,8 +34,28 @@ from main import app # noqa: E402
|
||||||
# ── Standalone test client (no DB required) ──────────────────────────────────
|
# ── Standalone test client (no DB required) ──────────────────────────────────
|
||||||
|
|
||||||
@pytest_asyncio.fixture()
|
@pytest_asyncio.fixture()
|
||||||
async def chat_client():
|
async def mock_redis():
|
||||||
"""Async HTTP test client that mocks out the DB session entirely."""
|
"""In-memory mock of an async Redis client (get/set)."""
|
||||||
|
store: dict[str, tuple[str, int | None]] = {}
|
||||||
|
|
||||||
|
mock = AsyncMock()
|
||||||
|
|
||||||
|
async def _get(key):
|
||||||
|
entry = store.get(key)
|
||||||
|
return entry[0] if entry else None
|
||||||
|
|
||||||
|
async def _set(key, value, ex=None):
|
||||||
|
store[key] = (value, ex)
|
||||||
|
|
||||||
|
mock.get = AsyncMock(side_effect=_get)
|
||||||
|
mock.set = AsyncMock(side_effect=_set)
|
||||||
|
mock._store = store # exposed for test assertions
|
||||||
|
return mock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture()
|
||||||
|
async def chat_client(mock_redis):
|
||||||
|
"""Async HTTP test client that mocks out the DB session and Redis."""
|
||||||
mock_session = AsyncMock()
|
mock_session = AsyncMock()
|
||||||
|
|
||||||
async def _mock_get_session():
|
async def _mock_get_session():
|
||||||
|
|
@ -44,7 +65,8 @@ async def chat_client():
|
||||||
|
|
||||||
transport = ASGITransport(app=app)
|
transport = ASGITransport(app=app)
|
||||||
async with AsyncClient(transport=transport, base_url="http://testserver") as ac:
|
async with AsyncClient(transport=transport, base_url="http://testserver") as ac:
|
||||||
yield ac
|
with patch("routers.chat.get_redis", new_callable=AsyncMock, return_value=mock_redis):
|
||||||
|
yield ac
|
||||||
|
|
||||||
app.dependency_overrides.pop(get_session, None)
|
app.dependency_overrides.pop(get_session, None)
|
||||||
|
|
||||||
|
|
@ -298,3 +320,246 @@ async def test_chat_llm_error_produces_error_event(chat_client):
|
||||||
|
|
||||||
error_event = next(e for e in events if e["event"] == "error")
|
error_event = next(e for e in events if e["event"] == "error")
|
||||||
assert "message" in error_event["data"]
|
assert "message" in error_event["data"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Multi-turn conversation memory tests ─────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _chat_request_with_mocks(query, conversation_id=None, creator=None, token_chunks=None):
    """Helper: return ``(request_body, token_chunks)`` for a single chat call.

    Optional fields are included in the body only when they are not None;
    ``token_chunks`` defaults to ``["ok"]``.
    """
    optional_fields = {"conversation_id": conversation_id, "creator": creator}
    payload: dict[str, Any] = {"query": query}
    payload.update(
        {name: value for name, value in optional_fields.items() if value is not None}
    )
    chunks = ["ok"] if token_chunks is None else token_chunks
    return payload, chunks
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_conversation_done_event_includes_conversation_id(chat_client):
    """The done SSE event echoes back a caller-supplied conversation_id."""
    search_result = _fake_search_result()
    mock_openai_client = MagicMock()
    mock_openai_client.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["hello"])
    )

    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result),
        patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client),
    ):
        resp = await chat_client.post(
            "/api/v1/chat",
            json={"query": "test", "conversation_id": "conv-abc-123"},
        )

    # Fail with a clear status mismatch rather than a StopIteration below.
    assert resp.status_code == 200

    events = _parse_sse(resp.text)
    done_data = next(e for e in events if e["event"] == "done")["data"]
    assert done_data["conversation_id"] == "conv-abc-123"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_conversation_auto_generates_id_when_omitted(chat_client):
    """When conversation_id is not provided, the done event carries an auto-generated UUID4."""
    import uuid  # local import: only this test needs it

    search_result = _fake_search_result()
    mock_openai_client = MagicMock()
    mock_openai_client.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["world"])
    )

    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result),
        patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client),
    ):
        resp = await chat_client.post("/api/v1/chat", json={"query": "no cid"})

    assert resp.status_code == 200

    events = _parse_sse(resp.text)
    done_data = next(e for e in events if e["event"] == "done")["data"]

    cid = done_data["conversation_id"]
    assert isinstance(cid, str)
    # A bare length check accepts any 36-character string; parse it instead.
    parsed = uuid.UUID(cid)
    assert parsed.version == 4
    # Round-trip guards against non-canonical spellings (braces, urn:, uppercase).
    assert str(parsed) == cid
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_conversation_history_saved_to_redis(chat_client, mock_redis):
    """A successful chat persists the user/assistant pair to Redis with a 3600 s TTL."""
    search_result = _fake_search_result()
    llm_client = MagicMock()
    llm_client.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["The answer ", "is 42."])
    )

    cid = "conv-save-test"
    search_patch = patch(
        "chat_service.SearchService.search",
        new_callable=AsyncMock,
        return_value=search_result,
    )
    openai_patch = patch("chat_service.openai.AsyncOpenAI", return_value=llm_client)

    with search_patch, openai_patch:
        resp = await chat_client.post(
            "/api/v1/chat",
            json={"query": "what is the answer", "conversation_id": cid},
        )

    assert resp.status_code == 200

    # The turn pair must be stored under the namespaced conversation key.
    key = f"chrysopedia:chat:{cid}"
    assert key in mock_redis._store
    payload, ttl = mock_redis._store[key]

    assert json.loads(payload) == [
        {"role": "user", "content": "what is the answer"},
        {"role": "assistant", "content": "The answer is 42."},
    ]
    assert ttl == 3600
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_conversation_history_injected_into_llm_messages(chat_client, mock_redis):
    """Stored history is sent to the LLM between the system prompt and the new user message."""
    search_result = _fake_search_result()

    # Seed Redis with one earlier turn pair.
    cid = "conv-inject-test"
    earlier_turns = [
        {"role": "user", "content": "what is reverb"},
        {"role": "assistant", "content": "Reverb simulates acoustic spaces."},
    ]
    mock_redis._store[f"chrysopedia:chat:{cid}"] = (json.dumps(earlier_turns), 3600)

    sent_messages = []
    llm_client = MagicMock()

    async def _record_and_stream(**kwargs):
        # Capture the messages array handed to the completion call.
        sent_messages.extend(kwargs.get("messages", []))
        return _mock_openai_stream(["follow-up answer"])

    llm_client.chat.completions.create = AsyncMock(side_effect=_record_and_stream)

    search_patch = patch(
        "chat_service.SearchService.search",
        new_callable=AsyncMock,
        return_value=search_result,
    )
    openai_patch = patch("chat_service.openai.AsyncOpenAI", return_value=llm_client)

    with search_patch, openai_patch:
        resp = await chat_client.post(
            "/api/v1/chat",
            json={"query": "how do I use it on drums", "conversation_id": cid},
        )

    assert resp.status_code == 200

    # Expected order: system, prior_user, prior_assistant, current_user
    assert len(sent_messages) == 4
    system_msg, prior_user, prior_assistant, current_user = sent_messages
    assert system_msg["role"] == "system"
    assert prior_user == {"role": "user", "content": "what is reverb"}
    assert prior_assistant == {"role": "assistant", "content": "Reverb simulates acoustic spaces."}
    assert current_user == {"role": "user", "content": "how do I use it on drums"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_conversation_history_capped_at_10_pairs(chat_client, mock_redis):
    """Stored history never exceeds 10 turn pairs (20 messages); the oldest pair is evicted."""
    search_result = _fake_search_result()
    cid = "conv-cap-test"
    key = f"chrysopedia:chat:{cid}"

    # Seed exactly 10 pairs (20 messages) so the next turn overflows the cap.
    seeded = [
        {"role": role, "content": f"{label} {i}"}
        for i in range(10)
        for role, label in (("user", "question"), ("assistant", "answer"))
    ]
    mock_redis._store[key] = (json.dumps(seeded), 3600)

    llm_client = MagicMock()
    llm_client.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["capped reply"])
    )

    search_patch = patch(
        "chat_service.SearchService.search",
        new_callable=AsyncMock,
        return_value=search_result,
    )
    openai_patch = patch("chat_service.openai.AsyncOpenAI", return_value=llm_client)

    with search_patch, openai_patch:
        resp = await chat_client.post(
            "/api/v1/chat",
            json={"query": "turn 11", "conversation_id": cid},
        )

    assert resp.status_code == 200

    # Turn 11 was appended, yet the stored window is still 20 messages.
    payload, _ = mock_redis._store[key]
    stored = json.loads(payload)

    assert len(stored) == 20  # 10 pairs
    # question 0 / answer 0 fell off the front
    assert stored[0] == {"role": "user", "content": "question 1"}
    # the new pair sits at the tail
    assert stored[-2] == {"role": "user", "content": "turn 11"}
    assert stored[-1] == {"role": "assistant", "content": "capped reply"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_conversation_ttl_refreshed_on_interaction(chat_client, mock_redis):
    """A new interaction rewrites the conversation key with the full 3600 s TTL."""
    search_result = _fake_search_result()
    cid = "conv-ttl-test"
    key = f"chrysopedia:chat:{cid}"

    # Seed an existing conversation whose simulated TTL has nearly run out.
    stale_history = [
        {"role": "user", "content": "old message"},
        {"role": "assistant", "content": "old reply"},
    ]
    mock_redis._store[key] = (json.dumps(stale_history), 100)

    llm_client = MagicMock()
    llm_client.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["refreshed"])
    )

    search_patch = patch(
        "chat_service.SearchService.search",
        new_callable=AsyncMock,
        return_value=search_result,
    )
    openai_patch = patch("chat_service.openai.AsyncOpenAI", return_value=llm_client)

    with search_patch, openai_patch:
        resp = await chat_client.post(
            "/api/v1/chat",
            json={"query": "refresh test", "conversation_id": cid},
        )

    assert resp.status_code == 200

    # The expiry recorded by the mock must be back at one hour.
    _, ttl = mock_redis._store[key]
    assert ttl == 3600
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_single_turn_fallback_no_redis_history(chat_client, mock_redis):
    """Without a conversation_id, the LLM receives only the system prompt and the query.

    The service auto-generates a fresh conversation_id, so nothing is stored
    under it yet and no prior messages are injected.
    """
    search_result = _fake_search_result()

    captured_messages = []
    mock_openai_client = MagicMock()

    async def _capture_create(**kwargs):
        # Record the messages array handed to the completion call.
        captured_messages.extend(kwargs.get("messages", []))
        return _mock_openai_stream(["standalone answer"])

    mock_openai_client.chat.completions.create = AsyncMock(side_effect=_capture_create)

    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result),
        patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client),
    ):
        resp = await chat_client.post("/api/v1/chat", json={"query": "standalone question"})

    assert resp.status_code == 200

    # Should only have system + user (no history injected since auto-generated cid is fresh)
    assert len(captured_messages) == 2
    assert captured_messages[0]["role"] == "system"
    # Check the content too, so a wrong or empty user message is caught.
    assert captured_messages[1] == {"role": "user", "content": "standalone question"}
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue