feat: Added personality_weight (0.0–1.0) to chat API; modulates system…

- "backend/routers/chat.py"
- "backend/chat_service.py"
- "backend/tests/test_chat.py"

GSD-Task: S02/T01
This commit is contained in:
jlightner 2026-04-04 09:28:35 +00:00
parent 04630764a6
commit 0856827b59
3 changed files with 394 additions and 2 deletions

View file

@ -21,9 +21,11 @@ import uuid
from typing import Any, AsyncIterator from typing import Any, AsyncIterator
import openai import openai
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from config import Settings from config import Settings
from models import Creator
from search_service import SearchService from search_service import SearchService
logger = logging.getLogger("chrysopedia.chat") logger = logging.getLogger("chrysopedia.chat")
@ -95,12 +97,43 @@ class ChatService:
except Exception: except Exception:
logger.warning("chat_history_save_error cid=%s", conversation_id, exc_info=True) logger.warning("chat_history_save_error cid=%s", conversation_id, exc_info=True)
async def _inject_personality(
self,
system_prompt: str,
db: AsyncSession,
creator_name: str,
weight: float,
) -> str:
"""Query creator personality_profile and append a voice block to the system prompt.
Falls back to the unmodified prompt on DB error, missing creator, or null profile.
"""
try:
result = await db.execute(
select(Creator).where(Creator.name == creator_name)
)
creator_row = result.scalars().first()
except Exception:
logger.warning("chat_personality_db_error creator=%r", creator_name, exc_info=True)
return system_prompt
if creator_row is None or creator_row.personality_profile is None:
logger.debug("chat_personality_skip creator=%r reason=%s",
creator_name,
"not_found" if creator_row is None else "null_profile")
return system_prompt
profile = creator_row.personality_profile
voice_block = _build_personality_block(creator_name, profile, weight)
return system_prompt + "\n\n" + voice_block
async def stream_response( async def stream_response(
self, self,
query: str, query: str,
db: AsyncSession, db: AsyncSession,
creator: str | None = None, creator: str | None = None,
conversation_id: str | None = None, conversation_id: str | None = None,
personality_weight: float = 0.0,
) -> AsyncIterator[str]: ) -> AsyncIterator[str]:
"""Yield SSE-formatted events for a chat query. """Yield SSE-formatted events for a chat query.
@ -151,6 +184,15 @@ class ChatService:
# ── 3. Stream LLM completion ──────────────────────────────────── # ── 3. Stream LLM completion ────────────────────────────────────
system_prompt = _SYSTEM_PROMPT_TEMPLATE.format(context_block=context_block) system_prompt = _SYSTEM_PROMPT_TEMPLATE.format(context_block=context_block)
# Inject creator personality voice when weight > 0
if personality_weight > 0 and creator:
system_prompt = await self._inject_personality(
system_prompt, db, creator, personality_weight,
)
# Scale temperature with personality weight: 0.3 (encyclopedic) → 0.5 (full personality)
temperature = 0.3 + (personality_weight * 0.2)
messages: list[dict[str, str]] = [ messages: list[dict[str, str]] = [
{"role": "system", "content": system_prompt}, {"role": "system", "content": system_prompt},
] ]
@ -165,7 +207,7 @@ class ChatService:
model=self.settings.llm_model, model=self.settings.llm_model,
messages=messages, messages=messages,
stream=True, stream=True,
temperature=0.3, temperature=temperature,
max_tokens=2048, max_tokens=2048,
) )
@ -245,3 +287,47 @@ def _build_context_block(items: list[dict[str, Any]]) -> str:
lines.append("") lines.append("")
return "\n".join(lines) return "\n".join(lines)
def _build_personality_block(creator_name: str, profile: dict[str, Any], weight: float) -> str:
"""Build a personality voice injection block from a creator's personality_profile JSONB.
The ``weight`` (0.01.0) determines how strongly the personality should
come through. At low weights the instruction is softer ("subtly adopt");
at high weights it is emphatic ("fully embody").
"""
vocab = profile.get("vocabulary", {})
tone = profile.get("tone", {})
style = profile.get("style_markers", {})
phrases = vocab.get("signature_phrases", [])
descriptors = tone.get("descriptors", [])
teaching_style = tone.get("teaching_style", "")
energy = tone.get("energy", "moderate")
formality = tone.get("formality", "conversational")
parts: list[str] = []
# Intensity qualifier
if weight >= 0.8:
parts.append(f"Fully embody {creator_name}'s voice and style.")
elif weight >= 0.4:
parts.append(f"Respond in {creator_name}'s voice.")
else:
parts.append(f"Subtly adopt {creator_name}'s communication style.")
if teaching_style:
parts.append(f"Teaching style: {teaching_style}.")
if descriptors:
parts.append(f"Tone: {', '.join(descriptors[:5])}.")
if phrases:
parts.append(f"Use their signature phrases: {', '.join(phrases[:6])}.")
parts.append(f"Match their {formality} {energy} tone.")
# Style markers
if style.get("uses_analogies"):
parts.append("Use analogies when helpful.")
if style.get("audience_engagement"):
parts.append(f"Audience engagement: {style['audience_engagement']}.")
return " ".join(parts)

View file

@ -29,6 +29,7 @@ class ChatRequest(BaseModel):
query: str = Field(..., min_length=1, max_length=1000) query: str = Field(..., min_length=1, max_length=1000)
creator: str | None = None creator: str | None = None
conversation_id: str | None = None conversation_id: str | None = None
personality_weight: float = Field(default=0.0, ge=0.0, le=1.0)
@router.post("") @router.post("")
@ -45,7 +46,7 @@ async def chat(
- ``event: done`` completion metadata with cascade_tier, conversation_id - ``event: done`` completion metadata with cascade_tier, conversation_id
- ``event: error`` error message (on failure) - ``event: error`` error message (on failure)
""" """
logger.info("chat_request query=%r creator=%r cid=%r", body.query, body.creator, body.conversation_id) logger.info("chat_request query=%r creator=%r cid=%r weight=%.2f", body.query, body.creator, body.conversation_id, body.personality_weight)
redis = await get_redis() redis = await get_redis()
service = ChatService(settings, redis=redis) service = ChatService(settings, redis=redis)
@ -56,6 +57,7 @@ async def chat(
db=db, db=db,
creator=body.creator, creator=body.creator,
conversation_id=body.conversation_id, conversation_id=body.conversation_id,
personality_weight=body.personality_weight,
), ),
media_type="text/event-stream", media_type="text/event-stream",
headers={ headers={

View file

@ -563,3 +563,307 @@ async def test_single_turn_fallback_no_redis_history(chat_client, mock_redis):
assert len(captured_messages) == 2 assert len(captured_messages) == 2
assert captured_messages[0]["role"] == "system" assert captured_messages[0]["role"] == "system"
assert captured_messages[1]["role"] == "user" assert captured_messages[1]["role"] == "user"
# ── Personality weight tests ─────────────────────────────────────────────────
_FAKE_PERSONALITY_PROFILE = {
"vocabulary": {
"signature_phrases": ["let's gooo", "that's fire"],
"jargon_level": "mixed",
"filler_words": [],
"distinctive_terms": ["sauce", "vibes"],
"sound_descriptions": ["crispy", "punchy"],
},
"tone": {
"formality": "casual",
"energy": "high",
"humor": "occasional",
"teaching_style": "hands-on demo-driven",
"descriptors": ["enthusiastic", "direct", "encouraging"],
},
"style_markers": {
"explanation_approach": "example-first",
"uses_analogies": True,
"analogy_examples": ["like cooking a steak"],
"sound_words": ["brrr", "thwack"],
"self_references": "I always",
"audience_engagement": "asks rhetorical questions",
"pacing": "fast",
},
"summary": "High-energy producer who teaches by doing.",
}
def _mock_creator_row(name: str, profile: dict | None):
"""Build a mock Creator ORM row with just the fields personality injection needs."""
row = MagicMock()
row.name = name
row.personality_profile = profile
return row
def _mock_db_execute(creator_row):
"""Return a mock db.execute that yields a scalars().first() result."""
mock_scalars = MagicMock()
mock_scalars.first.return_value = creator_row
mock_result = MagicMock()
mock_result.scalars.return_value = mock_scalars
return AsyncMock(return_value=mock_result)
@pytest.mark.asyncio
async def test_personality_weight_accepted_and_forwarded(chat_client):
"""personality_weight is accepted in the request and forwarded to stream_response."""
search_result = _fake_search_result()
captured_kwargs = {}
mock_openai_client = MagicMock()
async def _capture_create(**kwargs):
captured_kwargs.update(kwargs)
return _mock_openai_stream(["ok"])
mock_openai_client.chat.completions.create = AsyncMock(side_effect=_capture_create)
with (
patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result),
patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client),
):
resp = await chat_client.post(
"/api/v1/chat",
json={"query": "test", "creator": "Keota", "personality_weight": 0.7},
)
assert resp.status_code == 200
events = _parse_sse(resp.text)
event_types = [e["event"] for e in events]
assert "done" in event_types
# Temperature should reflect the weight: 0.3 + 0.7*0.2 = 0.44
assert captured_kwargs.get("temperature") == pytest.approx(0.44)
@pytest.mark.asyncio
async def test_personality_prompt_injected_when_weight_and_profile(chat_client):
"""System prompt includes personality context when weight > 0 and profile exists."""
search_result = _fake_search_result()
creator_row = _mock_creator_row("Keota", _FAKE_PERSONALITY_PROFILE)
captured_messages = []
mock_openai_client = MagicMock()
async def _capture_create(**kwargs):
captured_messages.extend(kwargs.get("messages", []))
return _mock_openai_stream(["personality answer"])
mock_openai_client.chat.completions.create = AsyncMock(side_effect=_capture_create)
with (
patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result),
patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client),
):
# We need to mock db.execute inside the service — override the session
mock_session = AsyncMock()
mock_session.execute = _mock_db_execute(creator_row)
async def _mock_get_session():
yield mock_session
app.dependency_overrides[get_session] = _mock_get_session
resp = await chat_client.post(
"/api/v1/chat",
json={"query": "snare tips", "creator": "Keota", "personality_weight": 0.7},
)
assert resp.status_code == 200
assert len(captured_messages) >= 2
system_prompt = captured_messages[0]["content"]
# Personality block should be appended
assert "Keota" in system_prompt
assert "let's gooo" in system_prompt
assert "hands-on demo-driven" in system_prompt
assert "casual" in system_prompt
assert "high" in system_prompt
@pytest.mark.asyncio
async def test_personality_encyclopedic_fallback_null_profile(chat_client):
"""When weight > 0 but personality_profile is null, falls back to encyclopedic prompt."""
search_result = _fake_search_result()
creator_row = _mock_creator_row("NullCreator", None)
captured_messages = []
mock_openai_client = MagicMock()
async def _capture_create(**kwargs):
captured_messages.extend(kwargs.get("messages", []))
return _mock_openai_stream(["encyclopedic answer"])
mock_openai_client.chat.completions.create = AsyncMock(side_effect=_capture_create)
with (
patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result),
patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client),
):
mock_session = AsyncMock()
mock_session.execute = _mock_db_execute(creator_row)
async def _mock_get_session():
yield mock_session
app.dependency_overrides[get_session] = _mock_get_session
resp = await chat_client.post(
"/api/v1/chat",
json={"query": "reverb tips", "creator": "NullCreator", "personality_weight": 0.5},
)
assert resp.status_code == 200
system_prompt = captured_messages[0]["content"]
# Should be the standard encyclopedic prompt, no personality injection
assert "Chrysopedia" in system_prompt
assert "NullCreator" not in system_prompt
@pytest.mark.asyncio
async def test_personality_encyclopedic_fallback_missing_creator(chat_client):
"""When weight > 0 but creator doesn't exist in DB, falls back to encyclopedic prompt."""
search_result = _fake_search_result()
captured_messages = []
mock_openai_client = MagicMock()
async def _capture_create(**kwargs):
captured_messages.extend(kwargs.get("messages", []))
return _mock_openai_stream(["encyclopedic answer"])
mock_openai_client.chat.completions.create = AsyncMock(side_effect=_capture_create)
with (
patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result),
patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client),
):
mock_session = AsyncMock()
mock_session.execute = _mock_db_execute(None) # No creator found
async def _mock_get_session():
yield mock_session
app.dependency_overrides[get_session] = _mock_get_session
resp = await chat_client.post(
"/api/v1/chat",
json={"query": "bass tips", "creator": "GhostCreator", "personality_weight": 0.8},
)
assert resp.status_code == 200
system_prompt = captured_messages[0]["content"]
assert "Chrysopedia" in system_prompt
assert "GhostCreator" not in system_prompt
@pytest.mark.asyncio
async def test_personality_weight_zero_skips_profile_query(chat_client):
"""When weight is 0.0, no Creator query is made even if creator is set."""
search_result = _fake_search_result()
captured_kwargs = {}
mock_openai_client = MagicMock()
async def _capture_create(**kwargs):
captured_kwargs.update(kwargs)
return _mock_openai_stream(["ok"])
mock_openai_client.chat.completions.create = AsyncMock(side_effect=_capture_create)
with (
patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result),
patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client),
):
mock_session = AsyncMock()
mock_session.execute = AsyncMock() # Should NOT be called
async def _mock_get_session():
yield mock_session
app.dependency_overrides[get_session] = _mock_get_session
resp = await chat_client.post(
"/api/v1/chat",
json={"query": "test", "creator": "Keota", "personality_weight": 0.0},
)
assert resp.status_code == 200
# DB execute should not have been called for Creator lookup
mock_session.execute.assert_not_called()
# Temperature should be 0.3 (base)
assert captured_kwargs.get("temperature") == pytest.approx(0.3)
@pytest.mark.asyncio
async def test_personality_temperature_scales_with_weight(chat_client):
"""Temperature scales: 0.3 at weight=0.0, 0.5 at weight=1.0."""
search_result = _fake_search_result()
creator_row = _mock_creator_row("Keota", _FAKE_PERSONALITY_PROFILE)
captured_kwargs = {}
mock_openai_client = MagicMock()
async def _capture_create(**kwargs):
captured_kwargs.update(kwargs)
return _mock_openai_stream(["warm"])
mock_openai_client.chat.completions.create = AsyncMock(side_effect=_capture_create)
with (
patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result),
patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client),
):
mock_session = AsyncMock()
mock_session.execute = _mock_db_execute(creator_row)
async def _mock_get_session():
yield mock_session
app.dependency_overrides[get_session] = _mock_get_session
resp = await chat_client.post(
"/api/v1/chat",
json={"query": "test", "creator": "Keota", "personality_weight": 1.0},
)
assert resp.status_code == 200
assert captured_kwargs.get("temperature") == pytest.approx(0.5)
@pytest.mark.asyncio
async def test_personality_weight_above_1_returns_422(chat_client):
"""personality_weight > 1.0 fails Pydantic validation with 422."""
resp = await chat_client.post(
"/api/v1/chat",
json={"query": "test", "personality_weight": 1.5},
)
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_personality_weight_below_0_returns_422(chat_client):
"""personality_weight < 0.0 fails Pydantic validation with 422."""
resp = await chat_client.post(
"/api/v1/chat",
json={"query": "test", "personality_weight": -0.1},
)
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_personality_weight_string_returns_422(chat_client):
"""personality_weight as a non-numeric string fails validation with 422."""
resp = await chat_client.post(
"/api/v1/chat",
json={"query": "test", "personality_weight": "high"},
)
assert resp.status_code == 422