feat: Added personality_weight (0.0–1.0) to chat API; modulates system…
- "backend/routers/chat.py" - "backend/chat_service.py" - "backend/tests/test_chat.py" GSD-Task: S02/T01
This commit is contained in:
parent
04630764a6
commit
0856827b59
3 changed files with 394 additions and 2 deletions
|
|
@ -21,9 +21,11 @@ import uuid
|
|||
from typing import Any, AsyncIterator
|
||||
|
||||
import openai
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from config import Settings
|
||||
from models import Creator
|
||||
from search_service import SearchService
|
||||
|
||||
logger = logging.getLogger("chrysopedia.chat")
|
||||
|
|
@ -95,12 +97,43 @@ class ChatService:
|
|||
except Exception:
|
||||
logger.warning("chat_history_save_error cid=%s", conversation_id, exc_info=True)
|
||||
|
||||
async def _inject_personality(
|
||||
self,
|
||||
system_prompt: str,
|
||||
db: AsyncSession,
|
||||
creator_name: str,
|
||||
weight: float,
|
||||
) -> str:
|
||||
"""Query creator personality_profile and append a voice block to the system prompt.
|
||||
|
||||
Falls back to the unmodified prompt on DB error, missing creator, or null profile.
|
||||
"""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(Creator).where(Creator.name == creator_name)
|
||||
)
|
||||
creator_row = result.scalars().first()
|
||||
except Exception:
|
||||
logger.warning("chat_personality_db_error creator=%r", creator_name, exc_info=True)
|
||||
return system_prompt
|
||||
|
||||
if creator_row is None or creator_row.personality_profile is None:
|
||||
logger.debug("chat_personality_skip creator=%r reason=%s",
|
||||
creator_name,
|
||||
"not_found" if creator_row is None else "null_profile")
|
||||
return system_prompt
|
||||
|
||||
profile = creator_row.personality_profile
|
||||
voice_block = _build_personality_block(creator_name, profile, weight)
|
||||
return system_prompt + "\n\n" + voice_block
|
||||
|
||||
async def stream_response(
|
||||
self,
|
||||
query: str,
|
||||
db: AsyncSession,
|
||||
creator: str | None = None,
|
||||
conversation_id: str | None = None,
|
||||
personality_weight: float = 0.0,
|
||||
) -> AsyncIterator[str]:
|
||||
"""Yield SSE-formatted events for a chat query.
|
||||
|
||||
|
|
@ -151,6 +184,15 @@ class ChatService:
|
|||
# ── 3. Stream LLM completion ────────────────────────────────────
|
||||
system_prompt = _SYSTEM_PROMPT_TEMPLATE.format(context_block=context_block)
|
||||
|
||||
# Inject creator personality voice when weight > 0
|
||||
if personality_weight > 0 and creator:
|
||||
system_prompt = await self._inject_personality(
|
||||
system_prompt, db, creator, personality_weight,
|
||||
)
|
||||
|
||||
# Scale temperature with personality weight: 0.3 (encyclopedic) → 0.5 (full personality)
|
||||
temperature = 0.3 + (personality_weight * 0.2)
|
||||
|
||||
messages: list[dict[str, str]] = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
]
|
||||
|
|
@ -165,7 +207,7 @@ class ChatService:
|
|||
model=self.settings.llm_model,
|
||||
messages=messages,
|
||||
stream=True,
|
||||
temperature=0.3,
|
||||
temperature=temperature,
|
||||
max_tokens=2048,
|
||||
)
|
||||
|
||||
|
|
@ -245,3 +287,47 @@ def _build_context_block(items: list[dict[str, Any]]) -> str:
|
|||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _build_personality_block(creator_name: str, profile: dict[str, Any], weight: float) -> str:
|
||||
"""Build a personality voice injection block from a creator's personality_profile JSONB.
|
||||
|
||||
The ``weight`` (0.0–1.0) determines how strongly the personality should
|
||||
come through. At low weights the instruction is softer ("subtly adopt");
|
||||
at high weights it is emphatic ("fully embody").
|
||||
"""
|
||||
vocab = profile.get("vocabulary", {})
|
||||
tone = profile.get("tone", {})
|
||||
style = profile.get("style_markers", {})
|
||||
|
||||
phrases = vocab.get("signature_phrases", [])
|
||||
descriptors = tone.get("descriptors", [])
|
||||
teaching_style = tone.get("teaching_style", "")
|
||||
energy = tone.get("energy", "moderate")
|
||||
formality = tone.get("formality", "conversational")
|
||||
|
||||
parts: list[str] = []
|
||||
|
||||
# Intensity qualifier
|
||||
if weight >= 0.8:
|
||||
parts.append(f"Fully embody {creator_name}'s voice and style.")
|
||||
elif weight >= 0.4:
|
||||
parts.append(f"Respond in {creator_name}'s voice.")
|
||||
else:
|
||||
parts.append(f"Subtly adopt {creator_name}'s communication style.")
|
||||
|
||||
if teaching_style:
|
||||
parts.append(f"Teaching style: {teaching_style}.")
|
||||
if descriptors:
|
||||
parts.append(f"Tone: {', '.join(descriptors[:5])}.")
|
||||
if phrases:
|
||||
parts.append(f"Use their signature phrases: {', '.join(phrases[:6])}.")
|
||||
parts.append(f"Match their {formality} {energy} tone.")
|
||||
|
||||
# Style markers
|
||||
if style.get("uses_analogies"):
|
||||
parts.append("Use analogies when helpful.")
|
||||
if style.get("audience_engagement"):
|
||||
parts.append(f"Audience engagement: {style['audience_engagement']}.")
|
||||
|
||||
return " ".join(parts)
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ class ChatRequest(BaseModel):
|
|||
query: str = Field(..., min_length=1, max_length=1000)
|
||||
creator: str | None = None
|
||||
conversation_id: str | None = None
|
||||
personality_weight: float = Field(default=0.0, ge=0.0, le=1.0)
|
||||
|
||||
|
||||
@router.post("")
|
||||
|
|
@ -45,7 +46,7 @@ async def chat(
|
|||
- ``event: done`` — completion metadata with cascade_tier, conversation_id
|
||||
- ``event: error`` — error message (on failure)
|
||||
"""
|
||||
logger.info("chat_request query=%r creator=%r cid=%r", body.query, body.creator, body.conversation_id)
|
||||
logger.info("chat_request query=%r creator=%r cid=%r weight=%.2f", body.query, body.creator, body.conversation_id, body.personality_weight)
|
||||
|
||||
redis = await get_redis()
|
||||
service = ChatService(settings, redis=redis)
|
||||
|
|
@ -56,6 +57,7 @@ async def chat(
|
|||
db=db,
|
||||
creator=body.creator,
|
||||
conversation_id=body.conversation_id,
|
||||
personality_weight=body.personality_weight,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
|
|
|
|||
|
|
@ -563,3 +563,307 @@ async def test_single_turn_fallback_no_redis_history(chat_client, mock_redis):
|
|||
assert len(captured_messages) == 2
|
||||
assert captured_messages[0]["role"] == "system"
|
||||
assert captured_messages[1]["role"] == "user"
|
||||
|
||||
|
||||
# ── Personality weight tests ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
_FAKE_PERSONALITY_PROFILE = {
|
||||
"vocabulary": {
|
||||
"signature_phrases": ["let's gooo", "that's fire"],
|
||||
"jargon_level": "mixed",
|
||||
"filler_words": [],
|
||||
"distinctive_terms": ["sauce", "vibes"],
|
||||
"sound_descriptions": ["crispy", "punchy"],
|
||||
},
|
||||
"tone": {
|
||||
"formality": "casual",
|
||||
"energy": "high",
|
||||
"humor": "occasional",
|
||||
"teaching_style": "hands-on demo-driven",
|
||||
"descriptors": ["enthusiastic", "direct", "encouraging"],
|
||||
},
|
||||
"style_markers": {
|
||||
"explanation_approach": "example-first",
|
||||
"uses_analogies": True,
|
||||
"analogy_examples": ["like cooking a steak"],
|
||||
"sound_words": ["brrr", "thwack"],
|
||||
"self_references": "I always",
|
||||
"audience_engagement": "asks rhetorical questions",
|
||||
"pacing": "fast",
|
||||
},
|
||||
"summary": "High-energy producer who teaches by doing.",
|
||||
}
|
||||
|
||||
|
||||
def _mock_creator_row(name: str, profile: dict | None):
|
||||
"""Build a mock Creator ORM row with just the fields personality injection needs."""
|
||||
row = MagicMock()
|
||||
row.name = name
|
||||
row.personality_profile = profile
|
||||
return row
|
||||
|
||||
|
||||
def _mock_db_execute(creator_row):
|
||||
"""Return a mock db.execute that yields a scalars().first() result."""
|
||||
mock_scalars = MagicMock()
|
||||
mock_scalars.first.return_value = creator_row
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value = mock_scalars
|
||||
return AsyncMock(return_value=mock_result)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_personality_weight_accepted_and_forwarded(chat_client):
|
||||
"""personality_weight is accepted in the request and forwarded to stream_response."""
|
||||
search_result = _fake_search_result()
|
||||
|
||||
captured_kwargs = {}
|
||||
mock_openai_client = MagicMock()
|
||||
|
||||
async def _capture_create(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return _mock_openai_stream(["ok"])
|
||||
|
||||
mock_openai_client.chat.completions.create = AsyncMock(side_effect=_capture_create)
|
||||
|
||||
with (
|
||||
patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result),
|
||||
patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client),
|
||||
):
|
||||
resp = await chat_client.post(
|
||||
"/api/v1/chat",
|
||||
json={"query": "test", "creator": "Keota", "personality_weight": 0.7},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
events = _parse_sse(resp.text)
|
||||
event_types = [e["event"] for e in events]
|
||||
assert "done" in event_types
|
||||
# Temperature should reflect the weight: 0.3 + 0.7*0.2 = 0.44
|
||||
assert captured_kwargs.get("temperature") == pytest.approx(0.44)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_personality_prompt_injected_when_weight_and_profile(chat_client):
|
||||
"""System prompt includes personality context when weight > 0 and profile exists."""
|
||||
search_result = _fake_search_result()
|
||||
creator_row = _mock_creator_row("Keota", _FAKE_PERSONALITY_PROFILE)
|
||||
|
||||
captured_messages = []
|
||||
mock_openai_client = MagicMock()
|
||||
|
||||
async def _capture_create(**kwargs):
|
||||
captured_messages.extend(kwargs.get("messages", []))
|
||||
return _mock_openai_stream(["personality answer"])
|
||||
|
||||
mock_openai_client.chat.completions.create = AsyncMock(side_effect=_capture_create)
|
||||
|
||||
with (
|
||||
patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result),
|
||||
patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client),
|
||||
):
|
||||
# We need to mock db.execute inside the service — override the session
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute = _mock_db_execute(creator_row)
|
||||
|
||||
async def _mock_get_session():
|
||||
yield mock_session
|
||||
|
||||
app.dependency_overrides[get_session] = _mock_get_session
|
||||
|
||||
resp = await chat_client.post(
|
||||
"/api/v1/chat",
|
||||
json={"query": "snare tips", "creator": "Keota", "personality_weight": 0.7},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert len(captured_messages) >= 2
|
||||
|
||||
system_prompt = captured_messages[0]["content"]
|
||||
# Personality block should be appended
|
||||
assert "Keota" in system_prompt
|
||||
assert "let's gooo" in system_prompt
|
||||
assert "hands-on demo-driven" in system_prompt
|
||||
assert "casual" in system_prompt
|
||||
assert "high" in system_prompt
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_personality_encyclopedic_fallback_null_profile(chat_client):
|
||||
"""When weight > 0 but personality_profile is null, falls back to encyclopedic prompt."""
|
||||
search_result = _fake_search_result()
|
||||
creator_row = _mock_creator_row("NullCreator", None)
|
||||
|
||||
captured_messages = []
|
||||
mock_openai_client = MagicMock()
|
||||
|
||||
async def _capture_create(**kwargs):
|
||||
captured_messages.extend(kwargs.get("messages", []))
|
||||
return _mock_openai_stream(["encyclopedic answer"])
|
||||
|
||||
mock_openai_client.chat.completions.create = AsyncMock(side_effect=_capture_create)
|
||||
|
||||
with (
|
||||
patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result),
|
||||
patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client),
|
||||
):
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute = _mock_db_execute(creator_row)
|
||||
|
||||
async def _mock_get_session():
|
||||
yield mock_session
|
||||
|
||||
app.dependency_overrides[get_session] = _mock_get_session
|
||||
|
||||
resp = await chat_client.post(
|
||||
"/api/v1/chat",
|
||||
json={"query": "reverb tips", "creator": "NullCreator", "personality_weight": 0.5},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
system_prompt = captured_messages[0]["content"]
|
||||
# Should be the standard encyclopedic prompt, no personality injection
|
||||
assert "Chrysopedia" in system_prompt
|
||||
assert "NullCreator" not in system_prompt
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_personality_encyclopedic_fallback_missing_creator(chat_client):
|
||||
"""When weight > 0 but creator doesn't exist in DB, falls back to encyclopedic prompt."""
|
||||
search_result = _fake_search_result()
|
||||
|
||||
captured_messages = []
|
||||
mock_openai_client = MagicMock()
|
||||
|
||||
async def _capture_create(**kwargs):
|
||||
captured_messages.extend(kwargs.get("messages", []))
|
||||
return _mock_openai_stream(["encyclopedic answer"])
|
||||
|
||||
mock_openai_client.chat.completions.create = AsyncMock(side_effect=_capture_create)
|
||||
|
||||
with (
|
||||
patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result),
|
||||
patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client),
|
||||
):
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute = _mock_db_execute(None) # No creator found
|
||||
|
||||
async def _mock_get_session():
|
||||
yield mock_session
|
||||
|
||||
app.dependency_overrides[get_session] = _mock_get_session
|
||||
|
||||
resp = await chat_client.post(
|
||||
"/api/v1/chat",
|
||||
json={"query": "bass tips", "creator": "GhostCreator", "personality_weight": 0.8},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
system_prompt = captured_messages[0]["content"]
|
||||
assert "Chrysopedia" in system_prompt
|
||||
assert "GhostCreator" not in system_prompt
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_personality_weight_zero_skips_profile_query(chat_client):
|
||||
"""When weight is 0.0, no Creator query is made even if creator is set."""
|
||||
search_result = _fake_search_result()
|
||||
|
||||
captured_kwargs = {}
|
||||
mock_openai_client = MagicMock()
|
||||
|
||||
async def _capture_create(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return _mock_openai_stream(["ok"])
|
||||
|
||||
mock_openai_client.chat.completions.create = AsyncMock(side_effect=_capture_create)
|
||||
|
||||
with (
|
||||
patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result),
|
||||
patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client),
|
||||
):
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute = AsyncMock() # Should NOT be called
|
||||
|
||||
async def _mock_get_session():
|
||||
yield mock_session
|
||||
|
||||
app.dependency_overrides[get_session] = _mock_get_session
|
||||
|
||||
resp = await chat_client.post(
|
||||
"/api/v1/chat",
|
||||
json={"query": "test", "creator": "Keota", "personality_weight": 0.0},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
# DB execute should not have been called for Creator lookup
|
||||
mock_session.execute.assert_not_called()
|
||||
# Temperature should be 0.3 (base)
|
||||
assert captured_kwargs.get("temperature") == pytest.approx(0.3)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_personality_temperature_scales_with_weight(chat_client):
|
||||
"""Temperature scales: 0.3 at weight=0.0, 0.5 at weight=1.0."""
|
||||
search_result = _fake_search_result()
|
||||
creator_row = _mock_creator_row("Keota", _FAKE_PERSONALITY_PROFILE)
|
||||
|
||||
captured_kwargs = {}
|
||||
mock_openai_client = MagicMock()
|
||||
|
||||
async def _capture_create(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return _mock_openai_stream(["warm"])
|
||||
|
||||
mock_openai_client.chat.completions.create = AsyncMock(side_effect=_capture_create)
|
||||
|
||||
with (
|
||||
patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result),
|
||||
patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client),
|
||||
):
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute = _mock_db_execute(creator_row)
|
||||
|
||||
async def _mock_get_session():
|
||||
yield mock_session
|
||||
|
||||
app.dependency_overrides[get_session] = _mock_get_session
|
||||
|
||||
resp = await chat_client.post(
|
||||
"/api/v1/chat",
|
||||
json={"query": "test", "creator": "Keota", "personality_weight": 1.0},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert captured_kwargs.get("temperature") == pytest.approx(0.5)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_personality_weight_above_1_returns_422(chat_client):
|
||||
"""personality_weight > 1.0 fails Pydantic validation with 422."""
|
||||
resp = await chat_client.post(
|
||||
"/api/v1/chat",
|
||||
json={"query": "test", "personality_weight": 1.5},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_personality_weight_below_0_returns_422(chat_client):
|
||||
"""personality_weight < 0.0 fails Pydantic validation with 422."""
|
||||
resp = await chat_client.post(
|
||||
"/api/v1/chat",
|
||||
json={"query": "test", "personality_weight": -0.1},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_personality_weight_string_returns_422(chat_client):
|
||||
"""personality_weight as a non-numeric string fails validation with 422."""
|
||||
resp = await chat_client.post(
|
||||
"/api/v1/chat",
|
||||
json={"query": "test", "personality_weight": "high"},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue