test: Added automatic primary→fallback LLM endpoint switching in ChatService…
- "backend/chat_service.py" - "backend/tests/test_chat.py" - "docker-compose.yml" GSD-Task: S08/T01
This commit is contained in:
parent
f2edb1f375
commit
7b048ccbaf
3 changed files with 165 additions and 2 deletions
|
|
@ -59,6 +59,10 @@ class ChatService:
|
|||
base_url=settings.llm_api_url,
|
||||
api_key=settings.llm_api_key,
|
||||
)
|
||||
self._fallback_openai = openai.AsyncOpenAI(
|
||||
base_url=settings.llm_fallback_url,
|
||||
api_key=settings.llm_api_key,
|
||||
)
|
||||
self._redis = redis
|
||||
|
||||
async def _load_history(self, conversation_id: str) -> list[dict[str, str]]:
|
||||
|
|
@ -244,6 +248,7 @@ class ChatService:
|
|||
|
||||
accumulated_response = ""
|
||||
usage_data: dict[str, int] | None = None
|
||||
fallback_used = False
|
||||
|
||||
try:
|
||||
stream = await self._openai.chat.completions.create(
|
||||
|
|
@ -269,6 +274,44 @@ class ChatService:
|
|||
accumulated_response += text
|
||||
yield _sse("token", text)
|
||||
|
||||
except (openai.APIConnectionError, openai.APITimeoutError, openai.InternalServerError) as exc:
|
||||
logger.warning(
|
||||
"chat_llm_fallback primary failed (%s: %s), retrying with fallback at %s",
|
||||
type(exc).__name__, exc, self.settings.llm_fallback_url,
|
||||
)
|
||||
fallback_used = True
|
||||
accumulated_response = ""
|
||||
usage_data = None
|
||||
|
||||
try:
|
||||
stream = await self._fallback_openai.chat.completions.create(
|
||||
model=self.settings.llm_fallback_model,
|
||||
messages=messages,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
temperature=temperature,
|
||||
max_tokens=2048,
|
||||
)
|
||||
|
||||
async for chunk in stream:
|
||||
if hasattr(chunk, "usage") and chunk.usage is not None:
|
||||
usage_data = {
|
||||
"prompt_tokens": chunk.usage.prompt_tokens or 0,
|
||||
"completion_tokens": chunk.usage.completion_tokens or 0,
|
||||
"total_tokens": chunk.usage.total_tokens or 0,
|
||||
}
|
||||
choice = chunk.choices[0] if chunk.choices else None
|
||||
if choice and choice.delta and choice.delta.content:
|
||||
text = choice.delta.content
|
||||
accumulated_response += text
|
||||
yield _sse("token", text)
|
||||
|
||||
except Exception:
|
||||
tb = traceback.format_exc()
|
||||
logger.error("chat_llm_error fallback also failed query=%r cid=%s\n%s", query, conversation_id, tb)
|
||||
yield _sse("error", {"message": "LLM generation failed"})
|
||||
return
|
||||
|
||||
except Exception:
|
||||
tb = traceback.format_exc()
|
||||
logger.error("chat_llm_error query=%r cid=%s\n%s", query, conversation_id, tb)
|
||||
|
|
@ -301,7 +344,7 @@ class ChatService:
|
|||
query=query,
|
||||
usage=usage_data,
|
||||
cascade_tier=cascade_tier,
|
||||
model=self.settings.llm_model,
|
||||
model=self.settings.llm_fallback_model if fallback_used else self.settings.llm_model,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
|
||||
|
|
@ -311,7 +354,7 @@ class ChatService:
|
|||
query, creator, cascade_tier, len(sources), latency_ms, conversation_id,
|
||||
usage_data.get("total_tokens", 0),
|
||||
)
|
||||
yield _sse("done", {"cascade_tier": cascade_tier, "conversation_id": conversation_id})
|
||||
yield _sse("done", {"cascade_tier": cascade_tier, "conversation_id": conversation_id, "fallback_used": fallback_used})
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import openai
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
# Ensure backend/ is on sys.path
|
||||
|
|
@ -958,3 +959,120 @@ async def test_personality_weight_string_returns_422(chat_client):
|
|||
json={"query": "test", "personality_weight": "high"},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
# ── LLM fallback tests ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_fallback_on_connection_error(chat_client):
    """A primary-LLM APIConnectionError must be served by the fallback client.

    The primary client's ``create()`` raises, the fallback's ``create()``
    returns a canned stream. The SSE response must contain the fallback
    tokens, no ``error`` event, and ``fallback_used=True`` in the ``done``
    payload.
    """
    search_result = _fake_search_result()

    # Primary: create() always fails with a connection error.
    primary_llm = MagicMock()
    primary_llm.chat.completions.create = AsyncMock(
        side_effect=openai.APIConnectionError(request=MagicMock()),
    )

    # Fallback: create() hands back a two-chunk stream.
    fallback_llm = MagicMock()
    fallback_llm.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["fallback ", "answer"]),
    )

    # AsyncOpenAI is called 3 times in ChatService.__init__:
    #   1. SearchService (irrelevant, search is mocked)
    #   2. self._openai (primary)
    #   3. self._fallback_openai (fallback)
    # so the mocks are handed out by construction order.
    clients_by_order = {2: primary_llm, 3: fallback_llm}
    constructions = 0

    def _make_client(**kwargs):
        nonlocal constructions
        constructions += 1
        if constructions in clients_by_order:
            return clients_by_order[constructions]
        return MagicMock()

    search_patch = patch(
        "chat_service.SearchService.search",
        new_callable=AsyncMock,
        return_value=search_result,
    )
    client_patch = patch("chat_service.openai.AsyncOpenAI", side_effect=_make_client)
    with search_patch, client_patch:
        resp = await chat_client.post("/api/v1/chat", json={"query": "test fallback"})

    assert resp.status_code == 200
    events = _parse_sse(resp.text)
    seen_events = [event["event"] for event in events]

    for expected in ("sources", "token", "done"):
        assert expected in seen_events
    assert "error" not in seen_events

    # Verify tokens came from fallback
    streamed_text = "".join(
        event["data"] for event in events if event["event"] == "token"
    )
    assert "fallback answer" in streamed_text

    # Done event should have fallback_used=True
    done_payload = next(event for event in events if event["event"] == "done")["data"]
    assert done_payload["fallback_used"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_fallback_on_internal_server_error(chat_client):
    """A primary-LLM InternalServerError must be served by the fallback client.

    The primary client's ``create()`` raises a 500-style error, the
    fallback's ``create()`` returns a canned stream. The SSE response must
    contain the fallback tokens, no ``error`` event, and
    ``fallback_used=True`` in the ``done`` payload.
    """
    search_result = _fake_search_result()

    # Primary: create() always fails with InternalServerError.
    primary_llm = MagicMock()
    primary_llm.chat.completions.create = AsyncMock(
        side_effect=openai.InternalServerError(
            message="GPU OOM",
            response=MagicMock(status_code=500),
            body=None,
        ),
    )

    # Fallback: create() hands back a two-chunk stream.
    fallback_llm = MagicMock()
    fallback_llm.chat.completions.create = AsyncMock(
        return_value=_mock_openai_stream(["recovered ", "response"]),
    )

    # Mocks are handed out by AsyncOpenAI construction order: the 2nd client
    # built is treated as the primary and the 3rd as the fallback; any other
    # construction gets a throwaway MagicMock.
    clients_by_order = {2: primary_llm, 3: fallback_llm}
    constructions = 0

    def _make_client(**kwargs):
        nonlocal constructions
        constructions += 1
        if constructions in clients_by_order:
            return clients_by_order[constructions]
        return MagicMock()

    search_patch = patch(
        "chat_service.SearchService.search",
        new_callable=AsyncMock,
        return_value=search_result,
    )
    client_patch = patch("chat_service.openai.AsyncOpenAI", side_effect=_make_client)
    with search_patch, client_patch:
        resp = await chat_client.post("/api/v1/chat", json={"query": "test ise fallback"})

    assert resp.status_code == 200
    events = _parse_sse(resp.text)
    seen_events = [event["event"] for event in events]

    for expected in ("sources", "token", "done"):
        assert expected in seen_events
    assert "error" not in seen_events

    # Verify tokens from fallback
    streamed_text = "".join(
        event["data"] for event in events if event["event"] == "token"
    )
    assert "recovered response" in streamed_text

    # Done event should have fallback_used=True
    done_payload = next(event for event in events if event["event"] == "done")["data"]
    assert done_payload["fallback_used"] is True
|
||||
|
|
|
|||
|
|
@ -121,6 +121,8 @@ services:
|
|||
REDIS_URL: redis://chrysopedia-redis:6379/0
|
||||
QDRANT_URL: http://chrysopedia-qdrant:6333
|
||||
EMBEDDING_API_URL: http://chrysopedia-ollama:11434/v1
|
||||
LLM_FALLBACK_URL: http://chrysopedia-ollama:11434/v1
|
||||
LLM_FALLBACK_MODEL: fyn-llm-agent-chat
|
||||
PROMPTS_PATH: /prompts
|
||||
volumes:
|
||||
- /vmPool/r/services/chrysopedia_data:/data
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue