feat: Added personality_weight (0.0–1.0) to chat API; modulates system…

- "backend/routers/chat.py"
- "backend/chat_service.py"
- "backend/tests/test_chat.py"

GSD-Task: S02/T01
This commit is contained in:
jlightner 2026-04-04 09:28:35 +00:00
parent 04630764a6
commit 0856827b59
3 changed files with 394 additions and 2 deletions

View file

@ -21,9 +21,11 @@ import uuid
from typing import Any, AsyncIterator
import openai
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from config import Settings
from models import Creator
from search_service import SearchService
logger = logging.getLogger("chrysopedia.chat")
@ -95,12 +97,43 @@ class ChatService:
except Exception:
logger.warning("chat_history_save_error cid=%s", conversation_id, exc_info=True)
async def _inject_personality(
    self,
    system_prompt: str,
    db: AsyncSession,
    creator_name: str,
    weight: float,
) -> str:
    """Append a creator voice block to ``system_prompt`` when a profile exists.

    Selects the ``Creator`` row whose ``name`` equals ``creator_name``. If the
    row carries a ``personality_profile``, the text built by
    ``_build_personality_block`` is appended after a blank line.

    The prompt is returned unchanged when the query raises, when no row
    matches, or when the row's ``personality_profile`` is ``None``.
    """
    try:
        query_result = await db.execute(
            select(Creator).where(Creator.name == creator_name)
        )
        match = query_result.scalars().first()
    except Exception:
        # A failed lookup is logged and the prompt is handed back untouched.
        logger.warning("chat_personality_db_error creator=%r", creator_name, exc_info=True)
        return system_prompt

    # Work out whether (and why) injection has to be skipped.
    skip_reason = None
    if match is None:
        skip_reason = "not_found"
    elif match.personality_profile is None:
        skip_reason = "null_profile"

    if skip_reason is not None:
        logger.debug("chat_personality_skip creator=%r reason=%s",
                     creator_name,
                     skip_reason)
        return system_prompt

    persona_text = _build_personality_block(creator_name, match.personality_profile, weight)
    return "\n\n".join((system_prompt, persona_text))
async def stream_response(
self,
query: str,
db: AsyncSession,
creator: str | None = None,
conversation_id: str | None = None,
personality_weight: float = 0.0,
) -> AsyncIterator[str]:
"""Yield SSE-formatted events for a chat query.
@ -151,6 +184,15 @@ class ChatService:
# ── 3. Stream LLM completion ────────────────────────────────────
system_prompt = _SYSTEM_PROMPT_TEMPLATE.format(context_block=context_block)
# Inject creator personality voice when weight > 0
if personality_weight > 0 and creator:
system_prompt = await self._inject_personality(
system_prompt, db, creator, personality_weight,
)
# Scale temperature with personality weight: 0.3 (encyclopedic) → 0.5 (full personality)
temperature = 0.3 + (personality_weight * 0.2)
messages: list[dict[str, str]] = [
{"role": "system", "content": system_prompt},
]
@ -165,7 +207,7 @@ class ChatService:
model=self.settings.llm_model,
messages=messages,
stream=True,
temperature=0.3,
temperature=temperature,
max_tokens=2048,
)
@ -245,3 +287,47 @@ def _build_context_block(items: list[dict[str, Any]]) -> str:
lines.append("")
return "\n".join(lines)
def _build_personality_block(creator_name: str, profile: dict[str, Any], weight: float) -> str:
"""Build a personality voice injection block from a creator's personality_profile JSONB.
The ``weight`` (0.01.0) determines how strongly the personality should
come through. At low weights the instruction is softer ("subtly adopt");
at high weights it is emphatic ("fully embody").
"""
vocab = profile.get("vocabulary", {})
tone = profile.get("tone", {})
style = profile.get("style_markers", {})
phrases = vocab.get("signature_phrases", [])
descriptors = tone.get("descriptors", [])
teaching_style = tone.get("teaching_style", "")
energy = tone.get("energy", "moderate")
formality = tone.get("formality", "conversational")
parts: list[str] = []
# Intensity qualifier
if weight >= 0.8:
parts.append(f"Fully embody {creator_name}'s voice and style.")
elif weight >= 0.4:
parts.append(f"Respond in {creator_name}'s voice.")
else:
parts.append(f"Subtly adopt {creator_name}'s communication style.")
if teaching_style:
parts.append(f"Teaching style: {teaching_style}.")
if descriptors:
parts.append(f"Tone: {', '.join(descriptors[:5])}.")
if phrases:
parts.append(f"Use their signature phrases: {', '.join(phrases[:6])}.")
parts.append(f"Match their {formality} {energy} tone.")
# Style markers
if style.get("uses_analogies"):
parts.append("Use analogies when helpful.")
if style.get("audience_engagement"):
parts.append(f"Audience engagement: {style['audience_engagement']}.")
return " ".join(parts)

View file

@ -29,6 +29,7 @@ class ChatRequest(BaseModel):
query: str = Field(..., min_length=1, max_length=1000)
creator: str | None = None
conversation_id: str | None = None
personality_weight: float = Field(default=0.0, ge=0.0, le=1.0)
@router.post("")
@ -45,7 +46,7 @@ async def chat(
- ``event: done`` completion metadata with cascade_tier, conversation_id
- ``event: error`` error message (on failure)
"""
logger.info("chat_request query=%r creator=%r cid=%r", body.query, body.creator, body.conversation_id)
logger.info("chat_request query=%r creator=%r cid=%r weight=%.2f", body.query, body.creator, body.conversation_id, body.personality_weight)
redis = await get_redis()
service = ChatService(settings, redis=redis)
@ -56,6 +57,7 @@ async def chat(
db=db,
creator=body.creator,
conversation_id=body.conversation_id,
personality_weight=body.personality_weight,
),
media_type="text/event-stream",
headers={

View file

@ -563,3 +563,307 @@ async def test_single_turn_fallback_no_redis_history(chat_client, mock_redis):
assert len(captured_messages) == 2
assert captured_messages[0]["role"] == "system"
assert captured_messages[1]["role"] == "user"
# ── Personality weight tests ─────────────────────────────────────────────────
# Sample creator personality profile shared by the personality tests below.
# It has three sections (vocabulary, tone, style_markers) plus a summary; the
# prompt-injection test asserts that the signature phrases, teaching style,
# formality and energy from this fixture show up in the system prompt.
_FAKE_PERSONALITY_PROFILE = {
    "vocabulary": {
        "signature_phrases": ["let's gooo", "that's fire"],
        "jargon_level": "mixed",
        "filler_words": [],
        "distinctive_terms": ["sauce", "vibes"],
        "sound_descriptions": ["crispy", "punchy"],
    },
    "tone": {
        "formality": "casual",
        "energy": "high",
        "humor": "occasional",
        "teaching_style": "hands-on demo-driven",
        "descriptors": ["enthusiastic", "direct", "encouraging"],
    },
    "style_markers": {
        "explanation_approach": "example-first",
        "uses_analogies": True,
        "analogy_examples": ["like cooking a steak"],
        "sound_words": ["brrr", "thwack"],
        "self_references": "I always",
        "audience_engagement": "asks rhetorical questions",
        "pacing": "fast",
    },
    "summary": "High-energy producer who teaches by doing.",
}
def _mock_creator_row(name: str, profile: dict | None):
"""Build a mock Creator ORM row with just the fields personality injection needs."""
row = MagicMock()
row.name = name
row.personality_profile = profile
return row
def _mock_db_execute(creator_row):
"""Return a mock db.execute that yields a scalars().first() result."""
mock_scalars = MagicMock()
mock_scalars.first.return_value = creator_row
mock_result = MagicMock()
mock_result.scalars.return_value = mock_scalars
return AsyncMock(return_value=mock_result)
@pytest.mark.asyncio
async def test_personality_weight_accepted_and_forwarded(chat_client):
    """A request carrying personality_weight succeeds and the weight reaches the LLM call.

    The weight is observed through the temperature handed to the completion
    call, which should be 0.3 + 0.7 * 0.2 = 0.44.
    """
    search_result = _fake_search_result()
    llm_call_kwargs = {}

    async def _record_call(**kwargs):
        llm_call_kwargs.update(kwargs)
        return _mock_openai_stream(["ok"])

    fake_llm = MagicMock()
    fake_llm.chat.completions.create = AsyncMock(side_effect=_record_call)

    request_body = {"query": "test", "creator": "Keota", "personality_weight": 0.7}
    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result),
        patch("chat_service.openai.AsyncOpenAI", return_value=fake_llm),
    ):
        response = await chat_client.post("/api/v1/chat", json=request_body)

    assert response.status_code == 200

    # The stream must have run to completion.
    seen_event_types = [event["event"] for event in _parse_sse(response.text)]
    assert "done" in seen_event_types

    # Temperature should reflect the weight: 0.3 + 0.7*0.2 = 0.44
    assert llm_call_kwargs.get("temperature") == pytest.approx(0.44)
@pytest.mark.asyncio
async def test_personality_prompt_injected_when_weight_and_profile(chat_client):
    """System prompt includes personality context when weight > 0 and profile exists.

    The DB session dependency is replaced by a mock whose ``execute`` returns a
    creator row carrying ``_FAKE_PERSONALITY_PROFILE``. The override is applied
    through ``patch.dict`` so ``app.dependency_overrides`` is restored when the
    block exits and the mock session cannot leak into other tests.
    """
    search_result = _fake_search_result()
    creator_row = _mock_creator_row("Keota", _FAKE_PERSONALITY_PROFILE)
    captured_messages = []

    mock_openai_client = MagicMock()

    async def _capture_create(**kwargs):
        captured_messages.extend(kwargs.get("messages", []))
        return _mock_openai_stream(["personality answer"])

    mock_openai_client.chat.completions.create = AsyncMock(side_effect=_capture_create)

    # Mock session whose execute() yields the creator row to the service.
    mock_session = AsyncMock()
    mock_session.execute = _mock_db_execute(creator_row)

    async def _mock_get_session():
        yield mock_session

    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result),
        patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client),
        patch.dict(app.dependency_overrides, {get_session: _mock_get_session}),
    ):
        resp = await chat_client.post(
            "/api/v1/chat",
            json={"query": "snare tips", "creator": "Keota", "personality_weight": 0.7},
        )

    assert resp.status_code == 200
    assert len(captured_messages) >= 2
    system_prompt = captured_messages[0]["content"]

    # Personality block should be appended
    assert "Keota" in system_prompt
    assert "let's gooo" in system_prompt
    assert "hands-on demo-driven" in system_prompt
    assert "casual" in system_prompt
    assert "high" in system_prompt
@pytest.mark.asyncio
async def test_personality_encyclopedic_fallback_null_profile(chat_client):
    """When weight > 0 but personality_profile is null, falls back to encyclopedic prompt.

    The mock session returns a creator row whose ``personality_profile`` is
    ``None``. The session override is applied through ``patch.dict`` so
    ``app.dependency_overrides`` is restored when the block exits.
    """
    search_result = _fake_search_result()
    creator_row = _mock_creator_row("NullCreator", None)
    captured_messages = []

    mock_openai_client = MagicMock()

    async def _capture_create(**kwargs):
        captured_messages.extend(kwargs.get("messages", []))
        return _mock_openai_stream(["encyclopedic answer"])

    mock_openai_client.chat.completions.create = AsyncMock(side_effect=_capture_create)

    mock_session = AsyncMock()
    mock_session.execute = _mock_db_execute(creator_row)

    async def _mock_get_session():
        yield mock_session

    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result),
        patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client),
        patch.dict(app.dependency_overrides, {get_session: _mock_get_session}),
    ):
        resp = await chat_client.post(
            "/api/v1/chat",
            json={"query": "reverb tips", "creator": "NullCreator", "personality_weight": 0.5},
        )

    assert resp.status_code == 200
    system_prompt = captured_messages[0]["content"]

    # Should be the standard encyclopedic prompt, no personality injection
    assert "Chrysopedia" in system_prompt
    assert "NullCreator" not in system_prompt
@pytest.mark.asyncio
async def test_personality_encyclopedic_fallback_missing_creator(chat_client):
    """When weight > 0 but creator doesn't exist in DB, falls back to encyclopedic prompt.

    The mock session's ``execute`` yields no row. The session override is
    applied through ``patch.dict`` so ``app.dependency_overrides`` is restored
    when the block exits.
    """
    search_result = _fake_search_result()
    captured_messages = []

    mock_openai_client = MagicMock()

    async def _capture_create(**kwargs):
        captured_messages.extend(kwargs.get("messages", []))
        return _mock_openai_stream(["encyclopedic answer"])

    mock_openai_client.chat.completions.create = AsyncMock(side_effect=_capture_create)

    mock_session = AsyncMock()
    mock_session.execute = _mock_db_execute(None)  # No creator found

    async def _mock_get_session():
        yield mock_session

    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result),
        patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client),
        patch.dict(app.dependency_overrides, {get_session: _mock_get_session}),
    ):
        resp = await chat_client.post(
            "/api/v1/chat",
            json={"query": "bass tips", "creator": "GhostCreator", "personality_weight": 0.8},
        )

    assert resp.status_code == 200
    system_prompt = captured_messages[0]["content"]
    assert "Chrysopedia" in system_prompt
    assert "GhostCreator" not in system_prompt
@pytest.mark.asyncio
async def test_personality_weight_zero_skips_profile_query(chat_client):
    """When weight is 0.0, no Creator query is made even if creator is set.

    Also checks that the completion call receives the base temperature of 0.3.
    The session override is applied through ``patch.dict`` so
    ``app.dependency_overrides`` is restored when the block exits.
    """
    search_result = _fake_search_result()
    captured_kwargs = {}

    mock_openai_client = MagicMock()

    async def _capture_create(**kwargs):
        captured_kwargs.update(kwargs)
        return _mock_openai_stream(["ok"])

    mock_openai_client.chat.completions.create = AsyncMock(side_effect=_capture_create)

    mock_session = AsyncMock()
    mock_session.execute = AsyncMock()  # Should NOT be called

    async def _mock_get_session():
        yield mock_session

    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result),
        patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client),
        patch.dict(app.dependency_overrides, {get_session: _mock_get_session}),
    ):
        resp = await chat_client.post(
            "/api/v1/chat",
            json={"query": "test", "creator": "Keota", "personality_weight": 0.0},
        )

    assert resp.status_code == 200

    # DB execute should not have been called for Creator lookup
    mock_session.execute.assert_not_called()

    # Temperature should be 0.3 (base)
    assert captured_kwargs.get("temperature") == pytest.approx(0.3)
@pytest.mark.asyncio
async def test_personality_temperature_scales_with_weight(chat_client):
    """Temperature reaches 0.5 when personality_weight is 1.0.

    Only the upper end of the scale is exercised here; the 0.3 base value at
    weight 0.0 is asserted by ``test_personality_weight_zero_skips_profile_query``.
    The session override is applied through ``patch.dict`` so
    ``app.dependency_overrides`` is restored when the block exits.
    """
    search_result = _fake_search_result()
    creator_row = _mock_creator_row("Keota", _FAKE_PERSONALITY_PROFILE)
    captured_kwargs = {}

    mock_openai_client = MagicMock()

    async def _capture_create(**kwargs):
        captured_kwargs.update(kwargs)
        return _mock_openai_stream(["warm"])

    mock_openai_client.chat.completions.create = AsyncMock(side_effect=_capture_create)

    mock_session = AsyncMock()
    mock_session.execute = _mock_db_execute(creator_row)

    async def _mock_get_session():
        yield mock_session

    with (
        patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result),
        patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client),
        patch.dict(app.dependency_overrides, {get_session: _mock_get_session}),
    ):
        resp = await chat_client.post(
            "/api/v1/chat",
            json={"query": "test", "creator": "Keota", "personality_weight": 1.0},
        )

    assert resp.status_code == 200
    assert captured_kwargs.get("temperature") == pytest.approx(0.5)
@pytest.mark.asyncio
async def test_personality_weight_above_1_returns_422(chat_client):
    """A personality_weight above 1.0 is rejected with HTTP 422."""
    out_of_range_body = {"query": "test", "personality_weight": 1.5}
    response = await chat_client.post("/api/v1/chat", json=out_of_range_body)
    assert response.status_code == 422
@pytest.mark.asyncio
async def test_personality_weight_below_0_returns_422(chat_client):
    """A negative personality_weight is rejected with HTTP 422."""
    negative_weight_body = {"query": "test", "personality_weight": -0.1}
    response = await chat_client.post("/api/v1/chat", json=negative_weight_body)
    assert response.status_code == 422
@pytest.mark.asyncio
async def test_personality_weight_string_returns_422(chat_client):
    """A non-numeric string for personality_weight is rejected with HTTP 422."""
    non_numeric_body = {"query": "test", "personality_weight": "high"}
    response = await chat_client.post("/api/v1/chat", json=non_numeric_body)
    assert response.status_code == 422