chrysopedia/backend/tests/test_chat.py
jlightner 899ab742a8 test: Added automatic primary→fallback LLM endpoint switching in ChatSe…
- "backend/chat_service.py"
- "backend/tests/test_chat.py"
- "docker-compose.yml"

GSD-Task: S08/T01
2026-04-04 14:31:28 +00:00

1078 lines
38 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Integration tests for the chat SSE endpoint.
Mocks SearchService.search() and the OpenAI streaming response to verify:
1. Valid SSE format with sources, token, and done events
2. Citation numbering matches the sources array
3. Creator param forwarded to search
4. Empty/invalid query returns 422
5. LLM error produces an SSE error event
6. Multi-turn conversation memory via Redis (load/save/cap/TTL)
These tests use a standalone ASGI client that does NOT require a running
PostgreSQL instance — the DB session dependency is overridden with a mock.
"""
from __future__ import annotations
import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
import openai
from httpx import ASGITransport, AsyncClient
# Ensure backend/ is on sys.path
import pathlib
import sys
sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent.parent))
from database import get_session # noqa: E402
from main import app # noqa: E402
# ── Standalone test client (no DB required) ──────────────────────────────────
@pytest_asyncio.fixture()
async def mock_redis():
    """In-memory mock of an async Redis client (get/set)."""
    # Backing store maps key -> (serialized value, TTL passed via ``ex``).
    store: dict[str, tuple[str, int | None]] = {}

    async def _fake_get(key):
        # Mirror redis-py semantics: missing keys resolve to None.
        hit = store.get(key)
        if hit is None:
            return None
        return hit[0]

    async def _fake_set(key, value, ex=None):
        # Keep the requested TTL alongside the value so tests can assert on it.
        store[key] = (value, ex)

    mock = AsyncMock()
    mock.get = AsyncMock(side_effect=_fake_get)
    mock.set = AsyncMock(side_effect=_fake_set)
    mock._store = store  # exposed for test assertions
    return mock
@pytest_asyncio.fixture()
async def chat_client(mock_redis):
    """Async HTTP test client that mocks out the DB session and Redis.

    Overrides the FastAPI ``get_session`` dependency with an AsyncMock and
    patches ``routers.chat.get_redis`` so no real infrastructure is needed.
    The override removal is wrapped in ``try``/``finally`` so that an
    exception raised while the fixture is live cannot leak the mock session
    into later tests (the original cleanup was skipped on error).
    """
    mock_session = AsyncMock()

    async def _mock_get_session():
        yield mock_session

    app.dependency_overrides[get_session] = _mock_get_session
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://testserver") as ac:
            with patch("routers.chat.get_redis", new_callable=AsyncMock, return_value=mock_redis):
                yield ac
    finally:
        # Always restore the real dependency, even if the test body raised.
        app.dependency_overrides.pop(get_session, None)
# ── Helpers ──────────────────────────────────────────────────────────────────
def _parse_sse(body: str) -> list[dict[str, Any]]:
"""Parse SSE text into a list of {event, data} dicts."""
events: list[dict[str, Any]] = []
current_event: str | None = None
current_data: str | None = None
for line in body.split("\n"):
if line.startswith("event: "):
current_event = line[len("event: "):]
elif line.startswith("data: "):
current_data = line[len("data: "):]
elif line == "" and current_event is not None and current_data is not None:
try:
parsed = json.loads(current_data)
except json.JSONDecodeError:
parsed = current_data
events.append({"event": current_event, "data": parsed})
current_event = None
current_data = None
return events
def _fake_search_result(
items: list[dict[str, Any]] | None = None,
cascade_tier: str = "global",
) -> dict[str, Any]:
"""Build a fake SearchService.search() return value."""
if items is None:
items = [
{
"title": "Snare Compression",
"slug": "snare-compression",
"technique_page_slug": "snare-compression",
"creator_name": "Keota",
"topic_category": "Mixing",
"summary": "How to compress a snare drum for punch and presence.",
"section_anchor": "",
"section_heading": "",
"type": "technique_page",
"score": 0.9,
},
{
"title": "Parallel Processing",
"slug": "parallel-processing",
"technique_page_slug": "parallel-processing",
"creator_name": "Skope",
"topic_category": "Mixing",
"summary": "Using parallel compression for dynamics control.",
"section_anchor": "bus-setup",
"section_heading": "Bus Setup",
"type": "technique_page",
"score": 0.85,
},
]
return {
"items": items,
"partial_matches": [],
"total": len(items),
"query": "snare compression",
"fallback_used": False,
"cascade_tier": cascade_tier,
}
def _mock_openai_stream(chunks: list[str]):
"""Create a mock async iterator that yields OpenAI-style stream chunks."""
class FakeChoice:
def __init__(self, text: str | None):
self.delta = MagicMock()
self.delta.content = text
class FakeChunk:
def __init__(self, text: str | None):
self.choices = [FakeChoice(text)]
class FakeStream:
def __init__(self, chunks: list[str]):
self._chunks = chunks
self._index = 0
def __aiter__(self):
return self
async def __anext__(self):
if self._index >= len(self._chunks):
raise StopAsyncIteration
chunk = FakeChunk(self._chunks[self._index])
self._index += 1
return chunk
return FakeStream(chunks)
def _mock_openai_stream_error():
"""Create a mock async iterator that raises mid-stream."""
class FakeStream:
def __init__(self):
self._yielded = False
def __aiter__(self):
return self
async def __anext__(self):
if not self._yielded:
self._yielded = True
raise RuntimeError("LLM connection lost")
raise StopAsyncIteration
return FakeStream()
# ── Tests ────────────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_chat_sse_format_and_events(chat_client):
    """SSE stream contains sources, token(s), and done events in order."""
    llm_client = MagicMock()
    llm_client.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(
            ["Snare compression ", "uses [1] to add ", "punch. See also [2]."]
        )
    )
    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=_fake_search_result()),
        patch("chat_service.openai.AsyncOpenAI", return_value=llm_client),
    ):
        resp = await chat_client.post("/api/v1/chat", json={"query": "snare compression"})

    assert resp.status_code == 200
    assert "text/event-stream" in resp.headers.get("content-type", "")
    events = _parse_sse(resp.text)
    ordering = [e["event"] for e in events]
    # Stream contract: sources first, at least one token, done last.
    assert ordering[0] == "sources"
    assert "token" in ordering
    assert ordering[-1] == "done"
    # Sources payload is a list with one entry per search hit.
    assert isinstance(events[0]["data"], list)
    assert len(events[0]["data"]) == 2
    # Done payload carries the cascade tier used by search.
    assert "cascade_tier" in events[-1]["data"]
@pytest.mark.asyncio
async def test_chat_citation_numbering(chat_client):
    """Citation numbers in sources array match 1-based indexing."""
    llm_client = MagicMock()
    llm_client.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["hello"])
    )
    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=_fake_search_result()),
        patch("chat_service.openai.AsyncOpenAI", return_value=llm_client),
    ):
        resp = await chat_client.post("/api/v1/chat", json={"query": "compression"})

    sources = _parse_sse(resp.text)[0]["data"]
    # Numbers are 1-based and follow the order search returned the hits.
    assert [s["number"] for s in sources] == [1, 2]
    assert sources[0]["title"] == "Snare Compression"
    assert sources[1]["title"] == "Parallel Processing"
    assert sources[1]["section_anchor"] == "bus-setup"
@pytest.mark.asyncio
async def test_chat_creator_forwarded_to_search(chat_client):
    """Creator parameter is passed through to SearchService.search()."""
    llm_client = MagicMock()
    llm_client.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["ok"])
    )
    with (
        patch(
            "chat_service.SearchService.search",
            new_callable=AsyncMock,
            return_value=_fake_search_result(),
        ) as mock_search,
        patch("chat_service.openai.AsyncOpenAI", return_value=llm_client),
    ):
        resp = await chat_client.post(
            "/api/v1/chat",
            json={"query": "drum mixing", "creator": "keota"},
        )

    assert resp.status_code == 200
    mock_search.assert_called_once()
    # The creator filter must reach the search layer untouched.
    assert mock_search.call_args.kwargs.get("creator") == "keota"
@pytest.mark.asyncio
async def test_chat_empty_query_returns_422(chat_client):
    """An empty query string should fail Pydantic validation with 422."""
    response = await chat_client.post("/api/v1/chat", json={"query": ""})
    assert response.status_code == 422
@pytest.mark.asyncio
async def test_chat_missing_query_returns_422(chat_client):
    """Missing query field should fail with 422."""
    response = await chat_client.post("/api/v1/chat", json={})
    assert response.status_code == 422
@pytest.mark.asyncio
async def test_chat_llm_error_produces_error_event(chat_client):
    """When the LLM raises mid-stream, an error SSE event is emitted."""
    llm_client = MagicMock()
    llm_client.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream_error()
    )
    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=_fake_search_result()),
        patch("chat_service.openai.AsyncOpenAI", return_value=llm_client),
    ):
        resp = await chat_client.post("/api/v1/chat", json={"query": "test error"})

    assert resp.status_code == 200  # SSE streams always return 200
    events = _parse_sse(resp.text)
    kinds = {e["event"] for e in events}
    assert "sources" in kinds
    assert "error" in kinds
    # The error payload must carry a human-readable message.
    error_event = next(e for e in events if e["event"] == "error")
    assert "message" in error_event["data"]
# ── Multi-turn conversation memory tests ─────────────────────────────────────
def _chat_request_with_mocks(query, conversation_id=None, creator=None, token_chunks=None):
"""Helper: build a request JSON and patched mocks for a single chat call."""
if token_chunks is None:
token_chunks = ["ok"]
body: dict[str, Any] = {"query": query}
if conversation_id is not None:
body["conversation_id"] = conversation_id
if creator is not None:
body["creator"] = creator
return body, token_chunks
@pytest.mark.asyncio
async def test_conversation_done_event_includes_conversation_id(chat_client):
    """The done SSE event includes conversation_id — both when provided and auto-generated."""
    llm_client = MagicMock()
    llm_client.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["hello"])
    )
    # With explicit conversation_id
    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=_fake_search_result()),
        patch("chat_service.openai.AsyncOpenAI", return_value=llm_client),
    ):
        resp = await chat_client.post(
            "/api/v1/chat",
            json={"query": "test", "conversation_id": "conv-abc-123"},
        )

    done = next(e for e in _parse_sse(resp.text) if e["event"] == "done")
    assert done["data"]["conversation_id"] == "conv-abc-123"
@pytest.mark.asyncio
async def test_conversation_auto_generates_id_when_omitted(chat_client):
    """When conversation_id is not provided, the done event still includes one (auto-generated)."""
    llm_client = MagicMock()
    llm_client.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["world"])
    )
    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=_fake_search_result()),
        patch("chat_service.openai.AsyncOpenAI", return_value=llm_client),
    ):
        resp = await chat_client.post("/api/v1/chat", json={"query": "no cid"})

    done = next(e for e in _parse_sse(resp.text) if e["event"] == "done")
    cid = done["data"]["conversation_id"]
    assert isinstance(cid, str)
    assert len(cid) == 36  # UUID4 format
@pytest.mark.asyncio
async def test_conversation_history_saved_to_redis(chat_client, mock_redis):
    """After a successful chat, user+assistant messages are saved to Redis."""
    cid = "conv-save-test"
    llm_client = MagicMock()
    llm_client.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["The answer ", "is 42."])
    )
    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=_fake_search_result()),
        patch("chat_service.openai.AsyncOpenAI", return_value=llm_client),
    ):
        resp = await chat_client.post(
            "/api/v1/chat",
            json={"query": "what is the answer", "conversation_id": cid},
        )

    assert resp.status_code == 200
    # Verify Redis received the full turn under the namespaced key.
    redis_key = f"chrysopedia:chat:{cid}"
    assert redis_key in mock_redis._store
    stored_json, ttl = mock_redis._store[redis_key]
    assert json.loads(stored_json) == [
        {"role": "user", "content": "what is the answer"},
        {"role": "assistant", "content": "The answer is 42."},
    ]
    assert ttl == 3600
@pytest.mark.asyncio
async def test_conversation_history_injected_into_llm_messages(chat_client, mock_redis):
    """Prior conversation history is injected between system prompt and user message."""
    cid = "conv-inject-test"
    # Pre-populate Redis with conversation history
    mock_redis._store[f"chrysopedia:chat:{cid}"] = (
        json.dumps(
            [
                {"role": "user", "content": "what is reverb"},
                {"role": "assistant", "content": "Reverb simulates acoustic spaces."},
            ]
        ),
        3600,
    )
    captured_messages = []
    llm_client = MagicMock()

    async def _capture_create(**kwargs):
        captured_messages.extend(kwargs.get("messages", []))
        return _mock_openai_stream(["follow-up answer"])

    llm_client.chat.completions.create = AsyncMock(side_effect=_capture_create)
    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=_fake_search_result()),
        patch("chat_service.openai.AsyncOpenAI", return_value=llm_client),
    ):
        resp = await chat_client.post(
            "/api/v1/chat",
            json={"query": "how do I use it on drums", "conversation_id": cid},
        )

    assert resp.status_code == 200
    # Messages should be: system, prior_user, prior_assistant, current_user
    assert len(captured_messages) == 4
    assert captured_messages[0]["role"] == "system"
    assert captured_messages[1] == {"role": "user", "content": "what is reverb"}
    assert captured_messages[2] == {"role": "assistant", "content": "Reverb simulates acoustic spaces."}
    assert captured_messages[3] == {"role": "user", "content": "how do I use it on drums"}
@pytest.mark.asyncio
async def test_conversation_history_capped_at_10_pairs(chat_client, mock_redis):
    """History is capped at 10 turn pairs (20 messages). Oldest turns are dropped."""
    cid = "conv-cap-test"
    # Pre-populate with 10 turn pairs (20 messages) — at the cap
    seeded = []
    for i in range(10):
        seeded += [
            {"role": "user", "content": f"question {i}"},
            {"role": "assistant", "content": f"answer {i}"},
        ]
    mock_redis._store[f"chrysopedia:chat:{cid}"] = (json.dumps(seeded), 3600)
    llm_client = MagicMock()
    llm_client.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["capped reply"])
    )
    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=_fake_search_result()),
        patch("chat_service.openai.AsyncOpenAI", return_value=llm_client),
    ):
        resp = await chat_client.post(
            "/api/v1/chat",
            json={"query": "turn 11", "conversation_id": cid},
        )

    assert resp.status_code == 200
    stored_json, _ = mock_redis._store[f"chrysopedia:chat:{cid}"]
    history = json.loads(stored_json)
    # Still 10 pairs: the oldest pair fell off, the new pair was appended.
    assert len(history) == 20
    assert history[0] == {"role": "user", "content": "question 1"}
    assert history[-2] == {"role": "user", "content": "turn 11"}
    assert history[-1] == {"role": "assistant", "content": "capped reply"}
@pytest.mark.asyncio
async def test_conversation_ttl_refreshed_on_interaction(chat_client, mock_redis):
    """Each interaction refreshes the Redis TTL to 1 hour."""
    cid = "conv-ttl-test"
    # Pre-populate with a short simulated TTL
    mock_redis._store[f"chrysopedia:chat:{cid}"] = (
        json.dumps(
            [
                {"role": "user", "content": "old message"},
                {"role": "assistant", "content": "old reply"},
            ]
        ),
        100,
    )
    llm_client = MagicMock()
    llm_client.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["refreshed"])
    )
    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=_fake_search_result()),
        patch("chat_service.openai.AsyncOpenAI", return_value=llm_client),
    ):
        resp = await chat_client.post(
            "/api/v1/chat",
            json={"query": "refresh test", "conversation_id": cid},
        )

    assert resp.status_code == 200
    # TTL should be refreshed to 3600
    _, ttl = mock_redis._store[f"chrysopedia:chat:{cid}"]
    assert ttl == 3600
@pytest.mark.asyncio
async def test_single_turn_fallback_no_redis_history(chat_client, mock_redis):
    """When no conversation_id is provided, no history is loaded and behavior matches single-turn."""
    captured_messages = []
    llm_client = MagicMock()

    async def _capture_create(**kwargs):
        captured_messages.extend(kwargs.get("messages", []))
        return _mock_openai_stream(["standalone answer"])

    llm_client.chat.completions.create = AsyncMock(side_effect=_capture_create)
    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=_fake_search_result()),
        patch("chat_service.openai.AsyncOpenAI", return_value=llm_client),
    ):
        resp = await chat_client.post("/api/v1/chat", json={"query": "standalone question"})

    assert resp.status_code == 200
    # Should only have system + user (no history injected since auto-generated cid is fresh)
    assert [m["role"] for m in captured_messages] == ["system", "user"]
# ── Personality weight tests ─────────────────────────────────────────────────
# Representative creator personality profile. It deliberately populates every
# field the personality-injection tiers read (vocabulary, tone, style_markers,
# summary) so the tier tests below can assert on presence/absence of each.
_FAKE_PERSONALITY_PROFILE: dict[str, Any] = {
    "vocabulary": {
        "signature_phrases": ["let's gooo", "that's fire"],
        "jargon_level": "mixed",
        "filler_words": [],
        "distinctive_terms": ["sauce", "vibes"],
        "sound_descriptions": ["crispy", "punchy"],
    },
    "tone": {
        "formality": "casual",
        "energy": "high",
        "humor": "occasional",
        "teaching_style": "hands-on demo-driven",
        "descriptors": ["enthusiastic", "direct", "encouraging"],
    },
    "style_markers": {
        "explanation_approach": "example-first",
        "uses_analogies": True,
        "analogy_examples": ["like cooking a steak"],
        "sound_words": ["brrr", "thwack"],
        "self_references": "I always",
        "audience_engagement": "asks rhetorical questions",
        "pacing": "fast",
    },
    "summary": "High-energy producer who teaches by doing.",
}
def _mock_creator_row(name: str, profile: dict | None):
"""Build a mock Creator ORM row with just the fields personality injection needs."""
row = MagicMock()
row.name = name
row.personality_profile = profile
return row
def _mock_db_execute(creator_row):
"""Return a mock db.execute that yields a scalars().first() result."""
mock_scalars = MagicMock()
mock_scalars.first.return_value = creator_row
mock_result = MagicMock()
mock_result.scalars.return_value = mock_scalars
return AsyncMock(return_value=mock_result)
@pytest.mark.asyncio
async def test_personality_weight_accepted_and_forwarded(chat_client):
    """personality_weight is accepted in the request and forwarded to stream_response."""
    captured_kwargs = {}
    llm_client = MagicMock()

    async def _capture_create(**kwargs):
        captured_kwargs.update(kwargs)
        return _mock_openai_stream(["ok"])

    llm_client.chat.completions.create = AsyncMock(side_effect=_capture_create)
    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=_fake_search_result()),
        patch("chat_service.openai.AsyncOpenAI", return_value=llm_client),
    ):
        resp = await chat_client.post(
            "/api/v1/chat",
            json={"query": "test", "creator": "Keota", "personality_weight": 0.7},
        )

    assert resp.status_code == 200
    assert "done" in [e["event"] for e in _parse_sse(resp.text)]
    # Temperature should reflect the weight: 0.3 + 0.7*0.2 = 0.44
    assert captured_kwargs.get("temperature") == pytest.approx(0.44)
@pytest.mark.asyncio
async def test_personality_prompt_injected_when_weight_and_profile(chat_client):
    """System prompt includes personality context when weight > 0 and profile exists."""
    captured_messages = []
    llm_client = MagicMock()

    async def _capture_create(**kwargs):
        captured_messages.extend(kwargs.get("messages", []))
        return _mock_openai_stream(["personality answer"])

    llm_client.chat.completions.create = AsyncMock(side_effect=_capture_create)
    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=_fake_search_result()),
        patch("chat_service.openai.AsyncOpenAI", return_value=llm_client),
    ):
        # We need to mock db.execute inside the service — override the session
        mock_session = AsyncMock()
        mock_session.execute = _mock_db_execute(_mock_creator_row("Keota", _FAKE_PERSONALITY_PROFILE))

        async def _mock_get_session():
            yield mock_session

        app.dependency_overrides[get_session] = _mock_get_session
        resp = await chat_client.post(
            "/api/v1/chat",
            json={"query": "snare tips", "creator": "Keota", "personality_weight": 0.7},
        )

    assert resp.status_code == 200
    assert len(captured_messages) >= 2
    system_prompt = captured_messages[0]["content"]
    # weight=0.7 → tier 3: signature phrases YES, distinctive_terms NO
    for expected in (
        "Keota",
        "Respond as Keota would",
        "hands-on demo-driven",
        "casual",
        "high",
        "let's gooo",    # signature phrases included at 0.6+
        "enthusiastic",  # descriptors at 0.4+
        "example-first", # explanation_approach at 0.4+
    ):
        assert expected in system_prompt
    # Tier 4 fields (0.8+) should NOT be present
    for absent in ("sauce", "crispy", "brrr"):
        assert absent not in system_prompt
def test_personality_block_continuous_interpolation_tiers():
    """Progressive field inclusion across 5 interpolation tiers."""
    from chat_service import _build_personality_block
    profile = _FAKE_PERSONALITY_PROFILE
    # weight < 0.2: empty — no personality injection at all
    for w in (0.0, 0.1, 0.15, 0.19):
        result = _build_personality_block("Keota", profile, w)
        assert result == "", f"weight={w} should produce empty block"
    # weight 0.2-0.39: basic tone only
    for w in (0.2, 0.3, 0.39):
        result = _build_personality_block("Keota", profile, w)
        assert "subtly reference Keota" in result
        assert "hands-on demo-driven" in result
        assert "casual" in result and "high" in result
        # Should NOT include descriptors, explanation_approach, phrases
        assert "enthusiastic" not in result
        assert "example-first" not in result
        assert "let's gooo" not in result
    # weight 0.4-0.59: + descriptors, explanation_approach
    for w in (0.4, 0.5, 0.59):
        result = _build_personality_block("Keota", profile, w)
        assert "Adopt Keota" in result
        assert "enthusiastic" in result  # descriptors
        assert "example-first" in result  # explanation_approach
        assert "analogies" in result  # uses_analogies
        # Should NOT include phrases or tier-4 fields
        assert "let's gooo" not in result
        assert "sauce" not in result
    # weight 0.6-0.79: + signature phrases
    for w in (0.6, 0.7, 0.79):
        result = _build_personality_block("Keota", profile, w)
        assert "Respond as Keota would" in result
        assert "let's gooo" in result  # signature phrases
        assert "enthusiastic" in result  # still has descriptors
        # Should NOT include tier-4 fields
        assert "sauce" not in result
        assert "crispy" not in result
    # weight 0.8-0.89: + distinctive_terms, sound_descriptions, etc.
    for w in (0.8, 0.85, 0.89):
        result = _build_personality_block("Keota", profile, w)
        assert "Fully embody Keota" in result
        assert "sauce" in result  # distinctive_terms
        assert "crispy" in result  # sound_descriptions
        assert "brrr" in result  # sound_words
        assert "I always" in result  # self_references
        assert "fast" in result  # pacing
        # Should NOT include summary
        assert "High-energy producer" not in result
    # weight >= 0.9: + summary
    for w in (0.9, 0.95, 1.0):
        result = _build_personality_block("Keota", profile, w)
        assert "Fully embody Keota" in result
        assert "High-energy producer" in result  # summary
def test_personality_block_phrase_count_scales_with_weight():
    """Signature phrase count = max(2, round(weight * len(phrases)))."""
    from chat_service import _build_personality_block
    # Profile with 6 phrases to make scaling visible
    six_phrase_profile = {
        "vocabulary": {
            "signature_phrases": ["p1", "p2", "p3", "p4", "p5", "p6"],
        },
        "tone": {},
        "style_markers": {},
    }
    # weight=0.6: max(2, round(0.6*6)) = max(2,4) = 4 → first 4 phrases
    block = _build_personality_block("Test", six_phrase_profile, 0.6)
    assert "p4" in block
    assert "p5" not in block
    # weight=1.0: max(2, round(1.0*6)) = 6 → all phrases
    block = _build_personality_block("Test", six_phrase_profile, 1.0)
    assert "p6" in block
@pytest.mark.asyncio
async def test_personality_encyclopedic_fallback_null_profile(chat_client):
    """When weight > 0 but personality_profile is null, falls back to encyclopedic prompt."""
    captured_messages = []
    llm_client = MagicMock()

    async def _capture_create(**kwargs):
        captured_messages.extend(kwargs.get("messages", []))
        return _mock_openai_stream(["encyclopedic answer"])

    llm_client.chat.completions.create = AsyncMock(side_effect=_capture_create)
    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=_fake_search_result()),
        patch("chat_service.openai.AsyncOpenAI", return_value=llm_client),
    ):
        mock_session = AsyncMock()
        mock_session.execute = _mock_db_execute(_mock_creator_row("NullCreator", None))

        async def _mock_get_session():
            yield mock_session

        app.dependency_overrides[get_session] = _mock_get_session
        resp = await chat_client.post(
            "/api/v1/chat",
            json={"query": "reverb tips", "creator": "NullCreator", "personality_weight": 0.5},
        )

    assert resp.status_code == 200
    system_prompt = captured_messages[0]["content"]
    # Should be the standard encyclopedic prompt, no personality injection
    assert "Chrysopedia" in system_prompt
    assert "NullCreator" not in system_prompt
@pytest.mark.asyncio
async def test_personality_encyclopedic_fallback_missing_creator(chat_client):
    """When weight > 0 but creator doesn't exist in DB, falls back to encyclopedic prompt."""
    captured_messages = []
    llm_client = MagicMock()

    async def _capture_create(**kwargs):
        captured_messages.extend(kwargs.get("messages", []))
        return _mock_openai_stream(["encyclopedic answer"])

    llm_client.chat.completions.create = AsyncMock(side_effect=_capture_create)
    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=_fake_search_result()),
        patch("chat_service.openai.AsyncOpenAI", return_value=llm_client),
    ):
        mock_session = AsyncMock()
        mock_session.execute = _mock_db_execute(None)  # No creator found

        async def _mock_get_session():
            yield mock_session

        app.dependency_overrides[get_session] = _mock_get_session
        resp = await chat_client.post(
            "/api/v1/chat",
            json={"query": "bass tips", "creator": "GhostCreator", "personality_weight": 0.8},
        )

    assert resp.status_code == 200
    system_prompt = captured_messages[0]["content"]
    assert "Chrysopedia" in system_prompt
    assert "GhostCreator" not in system_prompt
@pytest.mark.asyncio
async def test_personality_weight_zero_skips_profile_query(chat_client):
    """When weight is 0.0, no Creator query is made even if creator is set."""
    captured_kwargs = {}
    llm_client = MagicMock()

    async def _capture_create(**kwargs):
        captured_kwargs.update(kwargs)
        return _mock_openai_stream(["ok"])

    llm_client.chat.completions.create = AsyncMock(side_effect=_capture_create)
    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=_fake_search_result()),
        patch("chat_service.openai.AsyncOpenAI", return_value=llm_client),
    ):
        mock_session = AsyncMock()
        mock_session.execute = AsyncMock()  # Should NOT be called

        async def _mock_get_session():
            yield mock_session

        app.dependency_overrides[get_session] = _mock_get_session
        resp = await chat_client.post(
            "/api/v1/chat",
            json={"query": "test", "creator": "Keota", "personality_weight": 0.0},
        )

    assert resp.status_code == 200
    # DB execute should not have been called for Creator lookup
    mock_session.execute.assert_not_called()
    # Temperature should be 0.3 (base)
    assert captured_kwargs.get("temperature") == pytest.approx(0.3)
@pytest.mark.asyncio
async def test_personality_temperature_scales_with_weight(chat_client):
    """Temperature scales: 0.3 at weight=0.0, 0.5 at weight=1.0."""
    captured_kwargs = {}
    llm_client = MagicMock()

    async def _capture_create(**kwargs):
        captured_kwargs.update(kwargs)
        return _mock_openai_stream(["warm"])

    llm_client.chat.completions.create = AsyncMock(side_effect=_capture_create)
    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=_fake_search_result()),
        patch("chat_service.openai.AsyncOpenAI", return_value=llm_client),
    ):
        mock_session = AsyncMock()
        mock_session.execute = _mock_db_execute(_mock_creator_row("Keota", _FAKE_PERSONALITY_PROFILE))

        async def _mock_get_session():
            yield mock_session

        app.dependency_overrides[get_session] = _mock_get_session
        resp = await chat_client.post(
            "/api/v1/chat",
            json={"query": "test", "creator": "Keota", "personality_weight": 1.0},
        )

    assert resp.status_code == 200
    # Full weight: 0.3 + 1.0*0.2 = 0.5
    assert captured_kwargs.get("temperature") == pytest.approx(0.5)
@pytest.mark.asyncio
async def test_personality_weight_above_1_returns_422(chat_client):
    """personality_weight > 1.0 fails Pydantic validation with 422."""
    response = await chat_client.post(
        "/api/v1/chat", json={"query": "test", "personality_weight": 1.5}
    )
    assert response.status_code == 422
@pytest.mark.asyncio
async def test_personality_weight_below_0_returns_422(chat_client):
    """personality_weight < 0.0 fails Pydantic validation with 422."""
    response = await chat_client.post(
        "/api/v1/chat", json={"query": "test", "personality_weight": -0.1}
    )
    assert response.status_code == 422
@pytest.mark.asyncio
async def test_personality_weight_string_returns_422(chat_client):
    """personality_weight as a non-numeric string fails validation with 422."""
    response = await chat_client.post(
        "/api/v1/chat", json={"query": "test", "personality_weight": "high"}
    )
    assert response.status_code == 422
# ── LLM fallback tests ──────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_chat_fallback_on_connection_error(chat_client):
    """When primary LLM raises APIConnectionError, fallback client serves the response."""
    # Primary client raises on create()
    mock_primary = MagicMock()
    mock_primary.chat.completions.create = AsyncMock(
        side_effect=openai.APIConnectionError(request=MagicMock()),
    )
    # Fallback client succeeds
    mock_fallback = MagicMock()
    mock_fallback.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["fallback ", "answer"]),
    )
    # AsyncOpenAI is called 3 times in ChatService.__init__:
    # 1. SearchService (irrelevant, search is mocked)
    # 2. self._openai (primary)
    # 3. self._fallback_openai (fallback)
    scripted = {2: mock_primary, 3: mock_fallback}
    calls = 0

    def _make_client(**kwargs):
        nonlocal calls
        calls += 1
        return scripted.get(calls, MagicMock())

    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=_fake_search_result()),
        patch("chat_service.openai.AsyncOpenAI", side_effect=_make_client),
    ):
        resp = await chat_client.post("/api/v1/chat", json={"query": "test fallback"})

    assert resp.status_code == 200
    events = _parse_sse(resp.text)
    kinds = [e["event"] for e in events]
    assert "sources" in kinds
    assert "token" in kinds
    assert "done" in kinds
    assert "error" not in kinds
    # Verify tokens came from fallback
    combined = "".join(e["data"] for e in events if e["event"] == "token")
    assert "fallback answer" in combined
    # Done event should have fallback_used=True
    done_data = next(e for e in events if e["event"] == "done")["data"]
    assert done_data["fallback_used"] is True
@pytest.mark.asyncio
async def test_chat_fallback_on_internal_server_error(chat_client):
    """When primary LLM raises InternalServerError, fallback client serves the response."""
    # Primary client raises InternalServerError on create()
    mock_primary = MagicMock()
    mock_primary.chat.completions.create = AsyncMock(
        side_effect=openai.InternalServerError(
            message="GPU OOM",
            response=MagicMock(status_code=500),
            body=None,
        ),
    )
    # Fallback client succeeds
    mock_fallback = MagicMock()
    mock_fallback.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["recovered ", "response"]),
    )
    # AsyncOpenAI call order in ChatService.__init__:
    # 1=SearchService (search is mocked), 2=primary, 3=fallback.
    scripted = {2: mock_primary, 3: mock_fallback}
    calls = 0

    def _make_client(**kwargs):
        nonlocal calls
        calls += 1
        return scripted.get(calls, MagicMock())

    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=_fake_search_result()),
        patch("chat_service.openai.AsyncOpenAI", side_effect=_make_client),
    ):
        resp = await chat_client.post("/api/v1/chat", json={"query": "test ise fallback"})

    assert resp.status_code == 200
    events = _parse_sse(resp.text)
    kinds = [e["event"] for e in events]
    assert "sources" in kinds
    assert "token" in kinds
    assert "done" in kinds
    assert "error" not in kinds
    # Verify tokens from fallback
    combined = "".join(e["data"] for e in events if e["event"] == "token")
    assert "recovered response" in combined
    # Done event should have fallback_used=True
    done_data = next(e for e in events if e["event"] == "done")["data"]
    assert done_data["fallback_used"] is True