- "backend/chat_service.py" - "backend/tests/test_chat.py" - "docker-compose.yml" GSD-Task: S08/T01
1078 lines
38 KiB
Python
1078 lines
38 KiB
Python
"""Integration tests for the chat SSE endpoint.

Mocks SearchService.search() and the OpenAI streaming response to verify:

1. Valid SSE format with sources, token, and done events
2. Citation numbering matches the sources array
3. Creator param forwarded to search
4. Empty or missing query returns 422
5. LLM error produces an SSE error event
6. Multi-turn conversation memory via Redis (load/save/cap/TTL)
7. Personality weight: prompt injection tiers, temperature scaling, and
   encyclopedic fallback when no profile/creator is available

These tests use a standalone ASGI client that does NOT require a running
PostgreSQL instance — the DB session dependency is overridden with a mock,
and the Redis client is replaced with an in-memory fake.
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
from typing import Any
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
import pytest
|
||
import pytest_asyncio
|
||
import openai
|
||
from httpx import ASGITransport, AsyncClient
|
||
|
||
# Ensure backend/ is on sys.path
|
||
import pathlib
|
||
import sys
|
||
sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent.parent))
|
||
|
||
from database import get_session # noqa: E402
|
||
from main import app # noqa: E402
|
||
|
||
|
||
# ── Standalone test client (no DB required) ──────────────────────────────────
|
||
|
||
@pytest_asyncio.fixture()
async def mock_redis():
    """Provide an AsyncMock standing in for an async Redis client.

    Only ``get`` and ``set`` have real behaviour: values live in a plain dict,
    stored as ``(value, ex)`` tuples so tests can inspect the TTL that was
    requested. That dict is exposed as ``_store`` for assertions and seeding.
    """
    backing: dict[str, tuple[str, int | None]] = {}

    async def fake_get(key):
        # A key that was never set reads back as None.
        if key not in backing:
            return None
        value, _ttl = backing[key]
        return value

    async def fake_set(key, value, ex=None):
        # Record the requested TTL next to the value; nothing actually expires.
        backing[key] = (value, ex)

    redis_double = AsyncMock()
    redis_double.get = AsyncMock(side_effect=fake_get)
    redis_double.set = AsyncMock(side_effect=fake_set)
    redis_double._store = backing  # exposed for test assertions
    return redis_double
|
||
|
||
|
||
@pytest_asyncio.fixture()
async def chat_client(mock_redis):
    """Async HTTP test client with the DB session and Redis mocked out.

    ``get_session`` is overridden with a generator yielding a bare
    ``AsyncMock`` session, and ``routers.chat.get_redis`` is patched to return
    the ``mock_redis`` fixture. Individual tests may install their own
    ``get_session`` override while the client is alive. On teardown the
    override that existed before this fixture ran (if any) is restored.
    """
    mock_session = AsyncMock()

    async def _mock_get_session():
        yield mock_session

    # Remember any pre-existing override (e.g. one installed by a conftest) so
    # teardown restores it instead of unconditionally discarding it.
    missing = object()
    previous_override = app.dependency_overrides.get(get_session, missing)
    app.dependency_overrides[get_session] = _mock_get_session

    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://testserver") as ac:
            with patch("routers.chat.get_redis", new_callable=AsyncMock, return_value=mock_redis):
                yield ac
    finally:
        # try/finally guarantees cleanup even when client construction or the
        # patch raises before the yield is reached.
        if previous_override is missing:
            app.dependency_overrides.pop(get_session, None)
        else:
            app.dependency_overrides[get_session] = previous_override
|
||
|
||
|
||
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||
|
||
|
||
def _parse_sse(body: str) -> list[dict[str, Any]]:
|
||
"""Parse SSE text into a list of {event, data} dicts."""
|
||
events: list[dict[str, Any]] = []
|
||
current_event: str | None = None
|
||
current_data: str | None = None
|
||
|
||
for line in body.split("\n"):
|
||
if line.startswith("event: "):
|
||
current_event = line[len("event: "):]
|
||
elif line.startswith("data: "):
|
||
current_data = line[len("data: "):]
|
||
elif line == "" and current_event is not None and current_data is not None:
|
||
try:
|
||
parsed = json.loads(current_data)
|
||
except json.JSONDecodeError:
|
||
parsed = current_data
|
||
events.append({"event": current_event, "data": parsed})
|
||
current_event = None
|
||
current_data = None
|
||
|
||
return events
|
||
|
||
|
||
def _fake_search_result(
|
||
items: list[dict[str, Any]] | None = None,
|
||
cascade_tier: str = "global",
|
||
) -> dict[str, Any]:
|
||
"""Build a fake SearchService.search() return value."""
|
||
if items is None:
|
||
items = [
|
||
{
|
||
"title": "Snare Compression",
|
||
"slug": "snare-compression",
|
||
"technique_page_slug": "snare-compression",
|
||
"creator_name": "Keota",
|
||
"topic_category": "Mixing",
|
||
"summary": "How to compress a snare drum for punch and presence.",
|
||
"section_anchor": "",
|
||
"section_heading": "",
|
||
"type": "technique_page",
|
||
"score": 0.9,
|
||
},
|
||
{
|
||
"title": "Parallel Processing",
|
||
"slug": "parallel-processing",
|
||
"technique_page_slug": "parallel-processing",
|
||
"creator_name": "Skope",
|
||
"topic_category": "Mixing",
|
||
"summary": "Using parallel compression for dynamics control.",
|
||
"section_anchor": "bus-setup",
|
||
"section_heading": "Bus Setup",
|
||
"type": "technique_page",
|
||
"score": 0.85,
|
||
},
|
||
]
|
||
return {
|
||
"items": items,
|
||
"partial_matches": [],
|
||
"total": len(items),
|
||
"query": "snare compression",
|
||
"fallback_used": False,
|
||
"cascade_tier": cascade_tier,
|
||
}
|
||
|
||
|
||
def _mock_openai_stream(chunks: list[str]):
|
||
"""Create a mock async iterator that yields OpenAI-style stream chunks."""
|
||
|
||
class FakeChoice:
|
||
def __init__(self, text: str | None):
|
||
self.delta = MagicMock()
|
||
self.delta.content = text
|
||
|
||
class FakeChunk:
|
||
def __init__(self, text: str | None):
|
||
self.choices = [FakeChoice(text)]
|
||
|
||
class FakeStream:
|
||
def __init__(self, chunks: list[str]):
|
||
self._chunks = chunks
|
||
self._index = 0
|
||
|
||
def __aiter__(self):
|
||
return self
|
||
|
||
async def __anext__(self):
|
||
if self._index >= len(self._chunks):
|
||
raise StopAsyncIteration
|
||
chunk = FakeChunk(self._chunks[self._index])
|
||
self._index += 1
|
||
return chunk
|
||
|
||
return FakeStream(chunks)
|
||
|
||
|
||
def _mock_openai_stream_error():
|
||
"""Create a mock async iterator that raises mid-stream."""
|
||
|
||
class FakeStream:
|
||
def __init__(self):
|
||
self._yielded = False
|
||
|
||
def __aiter__(self):
|
||
return self
|
||
|
||
async def __anext__(self):
|
||
if not self._yielded:
|
||
self._yielded = True
|
||
raise RuntimeError("LLM connection lost")
|
||
raise StopAsyncIteration
|
||
|
||
return FakeStream()
|
||
|
||
|
||
# ── Tests ────────────────────────────────────────────────────────────────────
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_chat_sse_format_and_events(chat_client):
    """The stream opens with ``sources``, carries ``token`` events, and closes with ``done``."""
    chunks = ["Snare compression ", "uses [1] to add ", "punch. See also [2]."]
    llm = MagicMock()
    llm.chat.completions.create = AsyncMock(return_value=_mock_openai_stream(chunks))

    with patch(
        "chat_service.SearchService.search",
        new_callable=AsyncMock,
        return_value=_fake_search_result(),
    ), patch("chat_service.openai.AsyncOpenAI", return_value=llm):
        response = await chat_client.post("/api/v1/chat", json={"query": "snare compression"})

    assert response.status_code == 200
    assert "text/event-stream" in response.headers.get("content-type", "")

    parsed = _parse_sse(response.text)
    kinds = [item["event"] for item in parsed]

    # Ordering contract: sources first, at least one token, done last.
    assert kinds[0] == "sources"
    assert "token" in kinds
    assert kinds[-1] == "done"

    # The sources payload is a list with one entry per search hit.
    first_payload = parsed[0]["data"]
    assert isinstance(first_payload, list)
    assert len(first_payload) == 2

    # The terminal event reports which search tier produced the context.
    assert "cascade_tier" in parsed[-1]["data"]
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_chat_citation_numbering(chat_client):
    """Each source carries a 1-based ``number`` matching its array position."""
    llm = MagicMock()
    llm.chat.completions.create = AsyncMock(return_value=_mock_openai_stream(["hello"]))

    with patch(
        "chat_service.SearchService.search",
        new_callable=AsyncMock,
        return_value=_fake_search_result(),
    ), patch("chat_service.openai.AsyncOpenAI", return_value=llm):
        response = await chat_client.post("/api/v1/chat", json={"query": "compression"})

    sources = _parse_sse(response.text)[0]["data"]

    first, second = sources[0], sources[1]
    assert first["number"] == 1
    assert first["title"] == "Snare Compression"
    assert second["number"] == 2
    assert second["title"] == "Parallel Processing"
    # Section-level hits keep their anchor so the citation can deep-link.
    assert second["section_anchor"] == "bus-setup"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_chat_creator_forwarded_to_search(chat_client):
    """A ``creator`` in the request body reaches SearchService.search() as a keyword argument."""
    llm = MagicMock()
    llm.chat.completions.create = AsyncMock(return_value=_mock_openai_stream(["ok"]))
    payload = {"query": "drum mixing", "creator": "keota"}

    with patch(
        "chat_service.SearchService.search",
        new_callable=AsyncMock,
        return_value=_fake_search_result(),
    ) as search_spy, patch("chat_service.openai.AsyncOpenAI", return_value=llm):
        response = await chat_client.post("/api/v1/chat", json=payload)

    assert response.status_code == 200
    search_spy.assert_called_once()
    forwarded = search_spy.call_args.kwargs
    assert forwarded.get("creator") == "keota"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_chat_empty_query_returns_422(chat_client):
    """Request validation rejects a blank ``query`` string with HTTP 422."""
    response = await chat_client.post("/api/v1/chat", json={"query": ""})
    assert response.status_code == 422
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_chat_missing_query_returns_422(chat_client):
    """A body without any ``query`` field is rejected with HTTP 422."""
    empty_body: dict = {}
    response = await chat_client.post("/api/v1/chat", json=empty_body)
    assert response.status_code == 422
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_chat_llm_error_produces_error_event(chat_client):
    """A failure while iterating the LLM stream surfaces as an ``error`` SSE event."""
    llm = MagicMock()
    llm.chat.completions.create = AsyncMock(return_value=_mock_openai_stream_error())

    with patch(
        "chat_service.SearchService.search",
        new_callable=AsyncMock,
        return_value=_fake_search_result(),
    ), patch("chat_service.openai.AsyncOpenAI", return_value=llm):
        response = await chat_client.post("/api/v1/chat", json={"query": "test error"})

    # Errors are reported in-band, so the HTTP status is still 200.
    assert response.status_code == 200

    parsed = _parse_sse(response.text)
    kinds = {item["event"] for item in parsed}
    assert "sources" in kinds
    assert "error" in kinds

    failure = next(item for item in parsed if item["event"] == "error")
    assert "message" in failure["data"]
|
||
|
||
|
||
# ── Multi-turn conversation memory tests ─────────────────────────────────────
|
||
|
||
|
||
def _chat_request_with_mocks(query, conversation_id=None, creator=None, token_chunks=None):
|
||
"""Helper: build a request JSON and patched mocks for a single chat call."""
|
||
if token_chunks is None:
|
||
token_chunks = ["ok"]
|
||
body: dict[str, Any] = {"query": query}
|
||
if conversation_id is not None:
|
||
body["conversation_id"] = conversation_id
|
||
if creator is not None:
|
||
body["creator"] = creator
|
||
return body, token_chunks
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_conversation_done_event_includes_conversation_id(chat_client):
    """The done SSE event echoes back a caller-supplied conversation_id.

    The auto-generated case (no conversation_id in the request) is covered by
    ``test_conversation_auto_generates_id_when_omitted``.
    """
    search_result = _fake_search_result()
    mock_openai_client = MagicMock()
    mock_openai_client.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["hello"])
    )

    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result),
        patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client),
    ):
        resp = await chat_client.post(
            "/api/v1/chat",
            json={"query": "test", "conversation_id": "conv-abc-123"},
        )

    assert resp.status_code == 200

    events = _parse_sse(resp.text)
    # Filter into a list rather than next(): a StopIteration escaping a
    # coroutine becomes an opaque RuntimeError instead of a clear failure.
    done_events = [e for e in events if e["event"] == "done"]
    assert done_events, "stream did not emit a done event"
    assert done_events[0]["data"]["conversation_id"] == "conv-abc-123"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_conversation_auto_generates_id_when_omitted(chat_client):
    """When conversation_id is omitted, the done event carries a generated UUID4."""
    import uuid

    search_result = _fake_search_result()
    mock_openai_client = MagicMock()
    mock_openai_client.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["world"])
    )

    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result),
        patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client),
    ):
        resp = await chat_client.post("/api/v1/chat", json={"query": "no cid"})

    assert resp.status_code == 200

    events = _parse_sse(resp.text)
    done_events = [e for e in events if e["event"] == "done"]
    assert done_events, "stream did not emit a done event"

    cid = done_events[0]["data"]["conversation_id"]
    assert isinstance(cid, str)
    # A length check alone accepts any 36-character string. Parse it instead
    # and require the canonical hyphenated form of a version-4 UUID.
    parsed = uuid.UUID(cid)
    assert parsed.version == 4
    assert str(parsed) == cid
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_conversation_history_saved_to_redis(chat_client, mock_redis):
    """A completed exchange persists the user turn and the joined assistant reply."""
    conversation = "conv-save-test"
    llm = MagicMock()
    llm.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["The answer ", "is 42."])
    )

    with patch(
        "chat_service.SearchService.search",
        new_callable=AsyncMock,
        return_value=_fake_search_result(),
    ), patch("chat_service.openai.AsyncOpenAI", return_value=llm):
        response = await chat_client.post(
            "/api/v1/chat",
            json={"query": "what is the answer", "conversation_id": conversation},
        )

    assert response.status_code == 200

    # Redis should now hold the JSON-encoded turn pair with a one-hour TTL.
    key = f"chrysopedia:chat:{conversation}"
    assert key in mock_redis._store
    raw_history, expiry = mock_redis._store[key]
    saved = json.loads(raw_history)

    assert len(saved) == 2
    assert saved[0] == {"role": "user", "content": "what is the answer"}
    assert saved[1] == {"role": "assistant", "content": "The answer is 42."}
    assert expiry == 3600
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_conversation_history_injected_into_llm_messages(chat_client, mock_redis):
    """Stored turns are placed between the system prompt and the new user message."""
    conversation = "conv-inject-test"
    earlier_turns = [
        {"role": "user", "content": "what is reverb"},
        {"role": "assistant", "content": "Reverb simulates acoustic spaces."},
    ]
    # Seed Redis as if a previous request had already saved this exchange.
    mock_redis._store[f"chrysopedia:chat:{conversation}"] = (json.dumps(earlier_turns), 3600)

    sent_messages = []

    async def _record_messages(**kwargs):
        sent_messages.extend(kwargs.get("messages", []))
        return _mock_openai_stream(["follow-up answer"])

    llm = MagicMock()
    llm.chat.completions.create = AsyncMock(side_effect=_record_messages)

    with patch(
        "chat_service.SearchService.search",
        new_callable=AsyncMock,
        return_value=_fake_search_result(),
    ), patch("chat_service.openai.AsyncOpenAI", return_value=llm):
        response = await chat_client.post(
            "/api/v1/chat",
            json={"query": "how do I use it on drums", "conversation_id": conversation},
        )

    assert response.status_code == 200

    # Expected layout: system, prior user, prior assistant, current user.
    assert len(sent_messages) == 4
    assert sent_messages[0]["role"] == "system"
    assert sent_messages[1:] == [
        {"role": "user", "content": "what is reverb"},
        {"role": "assistant", "content": "Reverb simulates acoustic spaces."},
        {"role": "user", "content": "how do I use it on drums"},
    ]
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_conversation_history_capped_at_10_pairs(chat_client, mock_redis):
    """History keeps at most 10 turn pairs (20 messages); the oldest pair is evicted."""
    conversation = "conv-cap-test"
    key = f"chrysopedia:chat:{conversation}"

    # Start exactly at the cap: 10 question/answer pairs.
    full_history = [
        message
        for turn in range(10)
        for message in (
            {"role": "user", "content": f"question {turn}"},
            {"role": "assistant", "content": f"answer {turn}"},
        )
    ]
    mock_redis._store[key] = (json.dumps(full_history), 3600)

    llm = MagicMock()
    llm.chat.completions.create = AsyncMock(return_value=_mock_openai_stream(["capped reply"]))

    with patch(
        "chat_service.SearchService.search",
        new_callable=AsyncMock,
        return_value=_fake_search_result(),
    ), patch("chat_service.openai.AsyncOpenAI", return_value=llm):
        response = await chat_client.post(
            "/api/v1/chat",
            json={"query": "turn 11", "conversation_id": conversation},
        )

    assert response.status_code == 200

    stored_json, _ = mock_redis._store[key]
    trimmed = json.loads(stored_json)

    # Still 10 pairs after the 11th turn was appended.
    assert len(trimmed) == 20
    # question 0 / answer 0 were dropped, so the window now starts at question 1.
    assert trimmed[0] == {"role": "user", "content": "question 1"}
    # The newest pair sits at the tail.
    assert trimmed[-2] == {"role": "user", "content": "turn 11"}
    assert trimmed[-1] == {"role": "assistant", "content": "capped reply"}
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_conversation_ttl_refreshed_on_interaction(chat_client, mock_redis):
    """Every interaction re-saves the conversation with a fresh one-hour TTL."""
    conversation = "conv-ttl-test"
    key = f"chrysopedia:chat:{conversation}"

    # Seed with a deliberately short TTL so a refresh is observable.
    seeded = [
        {"role": "user", "content": "old message"},
        {"role": "assistant", "content": "old reply"},
    ]
    mock_redis._store[key] = (json.dumps(seeded), 100)

    llm = MagicMock()
    llm.chat.completions.create = AsyncMock(return_value=_mock_openai_stream(["refreshed"]))

    with patch(
        "chat_service.SearchService.search",
        new_callable=AsyncMock,
        return_value=_fake_search_result(),
    ), patch("chat_service.openai.AsyncOpenAI", return_value=llm):
        response = await chat_client.post(
            "/api/v1/chat",
            json={"query": "refresh test", "conversation_id": conversation},
        )

    assert response.status_code == 200

    _, expiry = mock_redis._store[key]
    assert expiry == 3600
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_single_turn_fallback_no_redis_history(chat_client, mock_redis):
    """When no conversation_id is provided, no history is loaded and behavior matches single-turn."""
    search_result = _fake_search_result()

    captured_messages = []
    mock_openai_client = MagicMock()

    async def _capture_create(**kwargs):
        captured_messages.extend(kwargs.get("messages", []))
        return _mock_openai_stream(["standalone answer"])

    mock_openai_client.chat.completions.create = AsyncMock(side_effect=_capture_create)

    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result),
        patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client),
    ):
        resp = await chat_client.post("/api/v1/chat", json={"query": "standalone question"})

    assert resp.status_code == 200

    # Only system + user: the auto-generated conversation_id is fresh, so no
    # history is injected. The user turn must carry the raw query verbatim
    # (the same layout the history-injection test checks), not just the role.
    assert len(captured_messages) == 2
    assert captured_messages[0]["role"] == "system"
    assert captured_messages[1] == {"role": "user", "content": "standalone question"}
|
||
|
||
|
||
# ── Personality weight tests ─────────────────────────────────────────────────
|
||
|
||
|
||
_FAKE_PERSONALITY_PROFILE = {
|
||
"vocabulary": {
|
||
"signature_phrases": ["let's gooo", "that's fire"],
|
||
"jargon_level": "mixed",
|
||
"filler_words": [],
|
||
"distinctive_terms": ["sauce", "vibes"],
|
||
"sound_descriptions": ["crispy", "punchy"],
|
||
},
|
||
"tone": {
|
||
"formality": "casual",
|
||
"energy": "high",
|
||
"humor": "occasional",
|
||
"teaching_style": "hands-on demo-driven",
|
||
"descriptors": ["enthusiastic", "direct", "encouraging"],
|
||
},
|
||
"style_markers": {
|
||
"explanation_approach": "example-first",
|
||
"uses_analogies": True,
|
||
"analogy_examples": ["like cooking a steak"],
|
||
"sound_words": ["brrr", "thwack"],
|
||
"self_references": "I always",
|
||
"audience_engagement": "asks rhetorical questions",
|
||
"pacing": "fast",
|
||
},
|
||
"summary": "High-energy producer who teaches by doing.",
|
||
}
|
||
|
||
|
||
def _mock_creator_row(name: str, profile: dict | None):
|
||
"""Build a mock Creator ORM row with just the fields personality injection needs."""
|
||
row = MagicMock()
|
||
row.name = name
|
||
row.personality_profile = profile
|
||
return row
|
||
|
||
|
||
def _mock_db_execute(creator_row):
|
||
"""Return a mock db.execute that yields a scalars().first() result."""
|
||
mock_scalars = MagicMock()
|
||
mock_scalars.first.return_value = creator_row
|
||
mock_result = MagicMock()
|
||
mock_result.scalars.return_value = mock_scalars
|
||
return AsyncMock(return_value=mock_result)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_personality_weight_accepted_and_forwarded(chat_client):
    """The endpoint accepts ``personality_weight`` and it shapes the LLM temperature."""
    create_kwargs = {}

    async def _record_kwargs(**kwargs):
        create_kwargs.update(kwargs)
        return _mock_openai_stream(["ok"])

    llm = MagicMock()
    llm.chat.completions.create = AsyncMock(side_effect=_record_kwargs)

    with patch(
        "chat_service.SearchService.search",
        new_callable=AsyncMock,
        return_value=_fake_search_result(),
    ), patch("chat_service.openai.AsyncOpenAI", return_value=llm):
        response = await chat_client.post(
            "/api/v1/chat",
            json={"query": "test", "creator": "Keota", "personality_weight": 0.7},
        )

    assert response.status_code == 200
    kinds = [item["event"] for item in _parse_sse(response.text)]
    assert "done" in kinds
    # Temperature should reflect the weight: 0.3 + 0.7*0.2 = 0.44
    assert create_kwargs.get("temperature") == pytest.approx(0.44)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_personality_prompt_injected_when_weight_and_profile(chat_client):
    """With weight > 0 and a stored profile, the system prompt carries personality context."""
    sent_messages = []

    async def _record_messages(**kwargs):
        sent_messages.extend(kwargs.get("messages", []))
        return _mock_openai_stream(["personality answer"])

    llm = MagicMock()
    llm.chat.completions.create = AsyncMock(side_effect=_record_messages)
    keota = _mock_creator_row("Keota", _FAKE_PERSONALITY_PROFILE)

    with patch(
        "chat_service.SearchService.search",
        new_callable=AsyncMock,
        return_value=_fake_search_result(),
    ), patch("chat_service.openai.AsyncOpenAI", return_value=llm):
        # Swap in a session whose execute() returns the Keota row for the
        # Creator lookup done inside the service.
        session = AsyncMock()
        session.execute = _mock_db_execute(keota)

        async def _session_override():
            yield session

        app.dependency_overrides[get_session] = _session_override

        response = await chat_client.post(
            "/api/v1/chat",
            json={"query": "snare tips", "creator": "Keota", "personality_weight": 0.7},
        )

    assert response.status_code == 200
    assert len(sent_messages) >= 2

    system_prompt = sent_messages[0]["content"]
    # weight=0.7 lands in tier 3: signature phrases are in, tier-4 vocabulary is out.
    expected_present = (
        "Keota",
        "Respond as Keota would",
        "hands-on demo-driven",
        "casual",
        "high",
        "let's gooo",     # signature phrases (0.6+)
        "enthusiastic",   # tone descriptors (0.4+)
        "example-first",  # explanation_approach (0.4+)
    )
    expected_absent = (
        "sauce",   # distinctive_terms (0.8+)
        "crispy",  # sound_descriptions (0.8+)
        "brrr",    # sound_words (0.8+)
    )
    for fragment in expected_present:
        assert fragment in system_prompt, fragment
    for fragment in expected_absent:
        assert fragment not in system_prompt, fragment
|
||
|
||
|
||
def test_personality_block_continuous_interpolation_tiers():
    """Each weight band pulls progressively more profile fields into the block."""
    from chat_service import _build_personality_block

    profile = _FAKE_PERSONALITY_PROFILE

    # Below 0.2 the personality block is suppressed entirely.
    for w in (0.0, 0.1, 0.15, 0.19):
        assert _build_personality_block("Keota", profile, w) == "", f"weight={w} should produce empty block"

    # (weights, substrings that must appear, substrings that must not appear)
    bands = [
        # 0.2–0.39: basic tone only
        (
            (0.2, 0.3, 0.39),
            ("subtly reference Keota", "hands-on demo-driven", "casual", "high"),
            ("enthusiastic", "example-first", "let's gooo"),
        ),
        # 0.4–0.59: + descriptors, explanation_approach, uses_analogies
        (
            (0.4, 0.5, 0.59),
            ("Adopt Keota", "enthusiastic", "example-first", "analogies"),
            ("let's gooo", "sauce"),
        ),
        # 0.6–0.79: + signature phrases
        (
            (0.6, 0.7, 0.79),
            ("Respond as Keota would", "let's gooo", "enthusiastic"),
            ("sauce", "crispy"),
        ),
        # 0.8–0.89: + distinctive_terms, sound_descriptions, sound_words,
        # self_references, pacing — but not yet the summary
        (
            (0.8, 0.85, 0.89),
            ("Fully embody Keota", "sauce", "crispy", "brrr", "I always", "fast"),
            ("High-energy producer",),
        ),
        # >= 0.9: + summary
        (
            (0.9, 0.95, 1.0),
            ("Fully embody Keota", "High-energy producer"),
            (),
        ),
    ]

    for weights, required, forbidden in bands:
        for w in weights:
            block = _build_personality_block("Keota", profile, w)
            for fragment in required:
                assert fragment in block, f"weight={w}: expected {fragment!r}"
            for fragment in forbidden:
                assert fragment not in block, f"weight={w}: unexpected {fragment!r}"
|
||
|
||
|
||
def test_personality_block_phrase_count_scales_with_weight():
    """Signature phrase count = max(2, round(weight * len(phrases)))."""
    from chat_service import _build_personality_block

    # Profile with 6 phrases to make scaling visible
    profile = {
        "vocabulary": {
            "signature_phrases": ["p1", "p2", "p3", "p4", "p5", "p6"],
        },
        "tone": {},
        "style_markers": {},
    }

    # weight=0.6: max(2, round(0.6*6)) = max(2, 4) = 4 → exactly the first 4 phrases
    result = _build_personality_block("Test", profile, 0.6)
    for phrase in ("p1", "p2", "p3", "p4"):
        assert phrase in result, phrase
    assert "p5" not in result
    assert "p6" not in result

    # weight=1.0: max(2, round(1.0*6)) = 6 → all phrases
    result = _build_personality_block("Test", profile, 1.0)
    for phrase in ("p1", "p2", "p3", "p4", "p5", "p6"):
        assert phrase in result, phrase

    # Floor: with 2 phrases, round(0.6*2) = 1, but max(2, 1) keeps both.
    # Phrases are only included from weight 0.6 up, so this small profile is
    # the only way to exercise the lower bound of the formula.
    small_profile = {
        "vocabulary": {"signature_phrases": ["first-phrase", "second-phrase"]},
        "tone": {},
        "style_markers": {},
    }
    result = _build_personality_block("Test", small_profile, 0.6)
    assert "first-phrase" in result
    assert "second-phrase" in result
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_personality_encyclopedic_fallback_null_profile(chat_client):
    """A creator with no personality_profile falls back to the encyclopedic prompt."""
    sent_messages = []

    async def _record_messages(**kwargs):
        sent_messages.extend(kwargs.get("messages", []))
        return _mock_openai_stream(["encyclopedic answer"])

    llm = MagicMock()
    llm.chat.completions.create = AsyncMock(side_effect=_record_messages)
    profileless = _mock_creator_row("NullCreator", None)

    with patch(
        "chat_service.SearchService.search",
        new_callable=AsyncMock,
        return_value=_fake_search_result(),
    ), patch("chat_service.openai.AsyncOpenAI", return_value=llm):
        session = AsyncMock()
        session.execute = _mock_db_execute(profileless)

        async def _session_override():
            yield session

        app.dependency_overrides[get_session] = _session_override

        response = await chat_client.post(
            "/api/v1/chat",
            json={"query": "reverb tips", "creator": "NullCreator", "personality_weight": 0.5},
        )

    assert response.status_code == 200
    system_prompt = sent_messages[0]["content"]
    # The default encyclopedic prompt is used and the creator is never named.
    assert "Chrysopedia" in system_prompt
    assert "NullCreator" not in system_prompt
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_personality_encyclopedic_fallback_missing_creator(chat_client):
    """An unknown creator (no DB row) falls back to the encyclopedic prompt."""
    sent_messages = []

    async def _record_messages(**kwargs):
        sent_messages.extend(kwargs.get("messages", []))
        return _mock_openai_stream(["encyclopedic answer"])

    llm = MagicMock()
    llm.chat.completions.create = AsyncMock(side_effect=_record_messages)

    with patch(
        "chat_service.SearchService.search",
        new_callable=AsyncMock,
        return_value=_fake_search_result(),
    ), patch("chat_service.openai.AsyncOpenAI", return_value=llm):
        session = AsyncMock()
        session.execute = _mock_db_execute(None)  # lookup finds no creator

        async def _session_override():
            yield session

        app.dependency_overrides[get_session] = _session_override

        response = await chat_client.post(
            "/api/v1/chat",
            json={"query": "bass tips", "creator": "GhostCreator", "personality_weight": 0.8},
        )

    assert response.status_code == 200
    system_prompt = sent_messages[0]["content"]
    assert "Chrysopedia" in system_prompt
    assert "GhostCreator" not in system_prompt
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_personality_weight_zero_skips_profile_query(chat_client):
    """A weight of 0.0 skips the Creator lookup and keeps the base temperature."""
    create_kwargs = {}

    async def _record_kwargs(**kwargs):
        create_kwargs.update(kwargs)
        return _mock_openai_stream(["ok"])

    llm = MagicMock()
    llm.chat.completions.create = AsyncMock(side_effect=_record_kwargs)

    with patch(
        "chat_service.SearchService.search",
        new_callable=AsyncMock,
        return_value=_fake_search_result(),
    ), patch("chat_service.openai.AsyncOpenAI", return_value=llm):
        # A bare execute mock lets the test prove no Creator query was issued.
        session = AsyncMock()
        session.execute = AsyncMock()

        async def _session_override():
            yield session

        app.dependency_overrides[get_session] = _session_override

        response = await chat_client.post(
            "/api/v1/chat",
            json={"query": "test", "creator": "Keota", "personality_weight": 0.0},
        )

    assert response.status_code == 200
    session.execute.assert_not_called()
    # With zero personality weight the base temperature of 0.3 applies.
    assert create_kwargs.get("temperature") == pytest.approx(0.3)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_personality_temperature_scales_with_weight(chat_client):
    """Sampling temperature rises with personality weight.

    The documented range is 0.3 at weight=0.0 up to 0.5 at weight=1.0; this
    test exercises the upper end using a known creator profile.
    """
    search_payload = _fake_search_result()
    keota_row = _mock_creator_row("Keota", _FAKE_PERSONALITY_PROFILE)

    llm_kwargs = {}

    async def _record_create(**kwargs):
        llm_kwargs.update(kwargs)
        return _mock_openai_stream(["warm"])

    fake_llm = MagicMock()
    fake_llm.chat.completions.create = AsyncMock(side_effect=_record_create)

    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_payload),
        patch("chat_service.openai.AsyncOpenAI", return_value=fake_llm),
    ):
        session = AsyncMock()
        session.execute = _mock_db_execute(keota_row)

        async def _override_session():
            yield session

        app.dependency_overrides[get_session] = _override_session

        request_body = {"query": "test", "creator": "Keota", "personality_weight": 1.0}
        resp = await chat_client.post("/api/v1/chat", json=request_body)

    assert resp.status_code == 200
    assert llm_kwargs.get("temperature") == pytest.approx(0.5)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_personality_weight_above_1_returns_422(chat_client):
    """A personality_weight above the 1.0 ceiling fails Pydantic validation (422)."""
    request_body = {"query": "test", "personality_weight": 1.5}
    response = await chat_client.post("/api/v1/chat", json=request_body)
    assert response.status_code == 422
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_personality_weight_below_0_returns_422(chat_client):
    """A negative personality_weight fails Pydantic validation (422)."""
    request_body = {"query": "test", "personality_weight": -0.1}
    response = await chat_client.post("/api/v1/chat", json=request_body)
    assert response.status_code == 422
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_personality_weight_string_returns_422(chat_client):
    """A non-numeric string for personality_weight fails validation (422)."""
    request_body = {"query": "test", "personality_weight": "high"}
    response = await chat_client.post("/api/v1/chat", json=request_body)
    assert response.status_code == 422
|
||
|
||
|
||
# ── LLM fallback tests ──────────────────────────────────────────────────────
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_chat_fallback_on_connection_error(chat_client):
    """When primary LLM raises APIConnectionError, fallback client serves the response.

    Verifies that the primary client is attempted, the fallback client is then
    used, the streamed tokens come from the fallback, and the ``done`` event
    reports ``fallback_used=True``.
    """
    search_result = _fake_search_result()

    # Primary client raises on create()
    mock_primary = MagicMock()
    mock_primary.chat.completions.create = AsyncMock(
        side_effect=openai.APIConnectionError(request=MagicMock()),
    )

    # Fallback client succeeds
    mock_fallback = MagicMock()
    mock_fallback.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["fallback ", "answer"]),
    )

    # AsyncOpenAI is called 3 times in ChatService.__init__:
    #   1. SearchService (irrelevant, search is mocked)
    #   2. self._openai (primary)
    #   3. self._fallback_openai (fallback)
    # This depends on construction order; the call_count and await assertions
    # below make the test fail loudly if that order ever changes.
    call_count = 0

    def _make_client(**kwargs):
        nonlocal call_count
        call_count += 1
        if call_count == 2:
            return mock_primary
        if call_count == 3:
            return mock_fallback
        return MagicMock()

    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result),
        patch("chat_service.openai.AsyncOpenAI", side_effect=_make_client),
    ):
        resp = await chat_client.post("/api/v1/chat", json={"query": "test fallback"})

    assert resp.status_code == 200

    # Guard the construction-order assumption _make_client relies on.
    assert call_count >= 3, f"expected >= 3 AsyncOpenAI constructions, got {call_count}"
    # The primary must actually have been tried before the fallback served the reply.
    mock_primary.chat.completions.create.assert_awaited()
    mock_fallback.chat.completions.create.assert_awaited_once()

    events = _parse_sse(resp.text)
    event_types = [e["event"] for e in events]

    assert "sources" in event_types
    assert "token" in event_types
    assert "done" in event_types
    assert "error" not in event_types

    # Verify tokens came from fallback
    token_texts = [e["data"] for e in events if e["event"] == "token"]
    combined = "".join(token_texts)
    assert "fallback answer" in combined

    # Done event should have fallback_used=True
    done_data = next(e for e in events if e["event"] == "done")["data"]
    assert done_data["fallback_used"] is True
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_chat_fallback_on_internal_server_error(chat_client):
    """When primary LLM raises InternalServerError, fallback client serves the response.

    Verifies that the primary client is attempted, the fallback client is then
    used, the streamed tokens come from the fallback, and the ``done`` event
    reports ``fallback_used=True``.
    """
    search_result = _fake_search_result()

    # Primary client raises InternalServerError on create()
    mock_primary = MagicMock()
    mock_primary.chat.completions.create = AsyncMock(
        side_effect=openai.InternalServerError(
            message="GPU OOM",
            response=MagicMock(status_code=500),
            body=None,
        ),
    )

    # Fallback client succeeds
    mock_fallback = MagicMock()
    mock_fallback.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["recovered ", "response"]),
    )

    # Same construction-order assumption as the connection-error test:
    # call 2 is the primary client, call 3 is the fallback client. The
    # call_count and await assertions below fail loudly if that changes.
    call_count = 0

    def _make_client(**kwargs):
        nonlocal call_count
        call_count += 1
        if call_count == 2:
            return mock_primary
        if call_count == 3:
            return mock_fallback
        return MagicMock()

    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result),
        patch("chat_service.openai.AsyncOpenAI", side_effect=_make_client),
    ):
        resp = await chat_client.post("/api/v1/chat", json={"query": "test ise fallback"})

    assert resp.status_code == 200

    # Guard the construction-order assumption _make_client relies on.
    assert call_count >= 3, f"expected >= 3 AsyncOpenAI constructions, got {call_count}"
    # The primary must actually have been tried before the fallback served the reply.
    mock_primary.chat.completions.create.assert_awaited()
    mock_fallback.chat.completions.create.assert_awaited_once()

    events = _parse_sse(resp.text)
    event_types = [e["event"] for e in events]

    assert "sources" in event_types
    assert "token" in event_types
    assert "done" in event_types
    assert "error" not in event_types

    # Verify tokens from fallback
    token_texts = [e["data"] for e in events if e["event"] == "token"]
    combined = "".join(token_texts)
    assert "recovered response" in combined

    # Done event should have fallback_used=True
    done_data = next(e for e in events if e["event"] == "done")["data"]
    assert done_data["fallback_used"] is True
|