test: Added automatic primary→fallback LLM endpoint switching in ChatService…

- "backend/chat_service.py"
- "backend/tests/test_chat.py"
- "docker-compose.yml"

GSD-Task: S08/T01
This commit is contained in:
jlightner 2026-04-04 14:31:28 +00:00
parent f2edb1f375
commit 7b048ccbaf
3 changed files with 165 additions and 2 deletions

View file

@ -59,6 +59,10 @@ class ChatService:
base_url=settings.llm_api_url, base_url=settings.llm_api_url,
api_key=settings.llm_api_key, api_key=settings.llm_api_key,
) )
self._fallback_openai = openai.AsyncOpenAI(
base_url=settings.llm_fallback_url,
api_key=settings.llm_api_key,
)
self._redis = redis self._redis = redis
async def _load_history(self, conversation_id: str) -> list[dict[str, str]]: async def _load_history(self, conversation_id: str) -> list[dict[str, str]]:
@ -244,6 +248,7 @@ class ChatService:
accumulated_response = "" accumulated_response = ""
usage_data: dict[str, int] | None = None usage_data: dict[str, int] | None = None
fallback_used = False
try: try:
stream = await self._openai.chat.completions.create( stream = await self._openai.chat.completions.create(
@ -269,6 +274,44 @@ class ChatService:
accumulated_response += text accumulated_response += text
yield _sse("token", text) yield _sse("token", text)
except (openai.APIConnectionError, openai.APITimeoutError, openai.InternalServerError) as exc:
logger.warning(
"chat_llm_fallback primary failed (%s: %s), retrying with fallback at %s",
type(exc).__name__, exc, self.settings.llm_fallback_url,
)
fallback_used = True
accumulated_response = ""
usage_data = None
try:
stream = await self._fallback_openai.chat.completions.create(
model=self.settings.llm_fallback_model,
messages=messages,
stream=True,
stream_options={"include_usage": True},
temperature=temperature,
max_tokens=2048,
)
async for chunk in stream:
if hasattr(chunk, "usage") and chunk.usage is not None:
usage_data = {
"prompt_tokens": chunk.usage.prompt_tokens or 0,
"completion_tokens": chunk.usage.completion_tokens or 0,
"total_tokens": chunk.usage.total_tokens or 0,
}
choice = chunk.choices[0] if chunk.choices else None
if choice and choice.delta and choice.delta.content:
text = choice.delta.content
accumulated_response += text
yield _sse("token", text)
except Exception:
tb = traceback.format_exc()
logger.error("chat_llm_error fallback also failed query=%r cid=%s\n%s", query, conversation_id, tb)
yield _sse("error", {"message": "LLM generation failed"})
return
except Exception: except Exception:
tb = traceback.format_exc() tb = traceback.format_exc()
logger.error("chat_llm_error query=%r cid=%s\n%s", query, conversation_id, tb) logger.error("chat_llm_error query=%r cid=%s\n%s", query, conversation_id, tb)
@ -301,7 +344,7 @@ class ChatService:
query=query, query=query,
usage=usage_data, usage=usage_data,
cascade_tier=cascade_tier, cascade_tier=cascade_tier,
model=self.settings.llm_model, model=self.settings.llm_fallback_model if fallback_used else self.settings.llm_model,
latency_ms=latency_ms, latency_ms=latency_ms,
) )
@ -311,7 +354,7 @@ class ChatService:
query, creator, cascade_tier, len(sources), latency_ms, conversation_id, query, creator, cascade_tier, len(sources), latency_ms, conversation_id,
usage_data.get("total_tokens", 0), usage_data.get("total_tokens", 0),
) )
yield _sse("done", {"cascade_tier": cascade_tier, "conversation_id": conversation_id}) yield _sse("done", {"cascade_tier": cascade_tier, "conversation_id": conversation_id, "fallback_used": fallback_used})
# ── Helpers ────────────────────────────────────────────────────────────────── # ── Helpers ──────────────────────────────────────────────────────────────────

View file

@ -20,6 +20,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
import pytest_asyncio import pytest_asyncio
import openai
from httpx import ASGITransport, AsyncClient from httpx import ASGITransport, AsyncClient
# Ensure backend/ is on sys.path # Ensure backend/ is on sys.path
@ -958,3 +959,120 @@ async def test_personality_weight_string_returns_422(chat_client):
json={"query": "test", "personality_weight": "high"}, json={"query": "test", "personality_weight": "high"},
) )
assert resp.status_code == 422 assert resp.status_code == 422
# ── LLM fallback tests ──────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_chat_fallback_on_connection_error(chat_client):
    """A primary-LLM ``APIConnectionError`` is recovered via the fallback client.

    The primary client's ``create()`` raises, the fallback client streams
    two chunks, and the SSE output must contain ``sources``/``token``/``done``
    events, no ``error`` event, and ``fallback_used`` set to ``True`` on the
    ``done`` event.
    """
    canned_search = _fake_search_result()

    # Primary endpoint: create() always fails with a connection error.
    primary_client = MagicMock()
    primary_client.chat.completions.create = AsyncMock(
        side_effect=openai.APIConnectionError(request=MagicMock()),
    )

    # Fallback endpoint: create() hands back a working token stream.
    fallback_client = MagicMock()
    fallback_client.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["fallback ", "answer"]),
    )

    # AsyncOpenAI is constructed three times while ChatService is set up:
    #   1st -> SearchService (irrelevant here, search itself is patched)
    #   2nd -> self._openai (primary)
    #   3rd -> self._fallback_openai (fallback)
    # Any other construction just gets a throwaway MagicMock.
    clients_by_call = {2: primary_client, 3: fallback_client}
    constructed = 0

    def _client_factory(**kwargs):
        nonlocal constructed
        constructed += 1
        if constructed in clients_by_call:
            return clients_by_call[constructed]
        return MagicMock()

    with (
        patch(
            "chat_service.SearchService.search",
            new_callable=AsyncMock,
            return_value=canned_search,
        ),
        patch("chat_service.openai.AsyncOpenAI", side_effect=_client_factory),
    ):
        resp = await chat_client.post(
            "/api/v1/chat",
            json={"query": "test fallback"},
        )

    assert resp.status_code == 200

    events = _parse_sse(resp.text)
    seen_types = [evt["event"] for evt in events]
    for expected in ("sources", "token", "done"):
        assert expected in seen_types
    assert "error" not in seen_types

    # The streamed tokens must be the ones produced by the fallback client.
    streamed_text = "".join(
        evt["data"] for evt in events if evt["event"] == "token"
    )
    assert "fallback answer" in streamed_text

    # The done event must report that the fallback path was taken.
    done_event = next(evt for evt in events if evt["event"] == "done")
    assert done_event["data"]["fallback_used"] is True
@pytest.mark.asyncio
async def test_chat_fallback_on_internal_server_error(chat_client):
    """A primary-LLM ``InternalServerError`` is recovered via the fallback client.

    The primary client's ``create()`` raises a 500-style error, the fallback
    client streams two chunks, and the SSE output must contain
    ``sources``/``token``/``done`` events, no ``error`` event, and
    ``fallback_used`` set to ``True`` on the ``done`` event.
    """
    canned_search = _fake_search_result()

    # Primary endpoint: create() always fails with an InternalServerError.
    primary_failure = openai.InternalServerError(
        message="GPU OOM",
        response=MagicMock(status_code=500),
        body=None,
    )
    primary_client = MagicMock()
    primary_client.chat.completions.create = AsyncMock(
        side_effect=primary_failure,
    )

    # Fallback endpoint: create() hands back a working token stream.
    fallback_client = MagicMock()
    fallback_client.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["recovered ", "response"]),
    )

    # The 2nd and 3rd AsyncOpenAI constructions are mapped to the primary
    # and fallback clients; every other construction gets a throwaway mock.
    clients_by_call = {2: primary_client, 3: fallback_client}
    constructed = 0

    def _client_factory(**kwargs):
        nonlocal constructed
        constructed += 1
        if constructed in clients_by_call:
            return clients_by_call[constructed]
        return MagicMock()

    with (
        patch(
            "chat_service.SearchService.search",
            new_callable=AsyncMock,
            return_value=canned_search,
        ),
        patch("chat_service.openai.AsyncOpenAI", side_effect=_client_factory),
    ):
        resp = await chat_client.post(
            "/api/v1/chat",
            json={"query": "test ise fallback"},
        )

    assert resp.status_code == 200

    events = _parse_sse(resp.text)
    seen_types = [evt["event"] for evt in events]
    for expected in ("sources", "token", "done"):
        assert expected in seen_types
    assert "error" not in seen_types

    # The streamed tokens must be the ones produced by the fallback client.
    streamed_text = "".join(
        evt["data"] for evt in events if evt["event"] == "token"
    )
    assert "recovered response" in streamed_text

    # The done event must report that the fallback path was taken.
    done_event = next(evt for evt in events if evt["event"] == "done")
    assert done_event["data"]["fallback_used"] is True

View file

@ -121,6 +121,8 @@ services:
REDIS_URL: redis://chrysopedia-redis:6379/0 REDIS_URL: redis://chrysopedia-redis:6379/0
QDRANT_URL: http://chrysopedia-qdrant:6333 QDRANT_URL: http://chrysopedia-qdrant:6333
EMBEDDING_API_URL: http://chrysopedia-ollama:11434/v1 EMBEDDING_API_URL: http://chrysopedia-ollama:11434/v1
LLM_FALLBACK_URL: http://chrysopedia-ollama:11434/v1
LLM_FALLBACK_MODEL: fyn-llm-agent-chat
PROMPTS_PATH: /prompts PROMPTS_PATH: /prompts
volumes: volumes:
- /vmPool/r/services/chrysopedia_data:/data - /vmPool/r/services/chrysopedia_data:/data