From 0856827b590efb00802374c56f0efa1b8f84f916 Mon Sep 17 00:00:00 2001 From: jlightner Date: Sat, 4 Apr 2026 09:28:35 +0000 Subject: [PATCH] =?UTF-8?q?feat:=20Added=20personality=5Fweight=20(0.0?= =?UTF-8?q?=E2=80=931.0)=20to=20chat=20API;=20modulates=20system=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - "backend/routers/chat.py" - "backend/chat_service.py" - "backend/tests/test_chat.py" GSD-Task: S02/T01 --- backend/chat_service.py | 88 ++++++++++- backend/routers/chat.py | 4 +- backend/tests/test_chat.py | 304 +++++++++++++++++++++++++++++++++++++ 3 files changed, 394 insertions(+), 2 deletions(-) diff --git a/backend/chat_service.py b/backend/chat_service.py index 4cf0ba5..13dbea9 100644 --- a/backend/chat_service.py +++ b/backend/chat_service.py @@ -21,9 +21,11 @@ import uuid from typing import Any, AsyncIterator import openai +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from config import Settings +from models import Creator from search_service import SearchService logger = logging.getLogger("chrysopedia.chat") @@ -95,12 +97,43 @@ class ChatService: except Exception: logger.warning("chat_history_save_error cid=%s", conversation_id, exc_info=True) + async def _inject_personality( + self, + system_prompt: str, + db: AsyncSession, + creator_name: str, + weight: float, + ) -> str: + """Query creator personality_profile and append a voice block to the system prompt. + + Falls back to the unmodified prompt on DB error, missing creator, or null profile. + """ + try: + result = await db.execute( + select(Creator).where(Creator.name == creator_name) + ) + creator_row = result.scalars().first() + except Exception: + logger.warning("chat_personality_db_error creator=%r", creator_name, exc_info=True) + return system_prompt + + if creator_row is None or creator_row.personality_profile is None: + logger.debug("chat_personality_skip creator=%r reason=%s", + creator_name, + "not_found" if creator_row is None else "null_profile") + return system_prompt + + profile = creator_row.personality_profile + voice_block = _build_personality_block(creator_name, profile, weight) + return system_prompt + "\n\n" + voice_block + async def stream_response( self, query: str, db: AsyncSession, creator: str | None = None, conversation_id: str | None = None, + personality_weight: float = 0.0, ) -> AsyncIterator[str]: """Yield SSE-formatted events for a chat query. @@ -151,6 +184,15 @@ class ChatService: # ── 3. Stream LLM completion ──────────────────────────────────── system_prompt = _SYSTEM_PROMPT_TEMPLATE.format(context_block=context_block) + # Inject creator personality voice when weight > 0 + if personality_weight > 0 and creator: + system_prompt = await self._inject_personality( + system_prompt, db, creator, personality_weight, + ) + + # Scale temperature with personality weight: 0.3 (encyclopedic) → 0.5 (full personality) + temperature = 0.3 + (personality_weight * 0.2) + messages: list[dict[str, str]] = [ {"role": "system", "content": system_prompt}, ] @@ -165,7 +207,7 @@ class ChatService: model=self.settings.llm_model, messages=messages, stream=True, - temperature=0.3, + temperature=temperature, max_tokens=2048, ) @@ -245,3 +287,47 @@ def _build_context_block(items: list[dict[str, Any]]) -> str: lines.append("") return "\n".join(lines) + + +def _build_personality_block(creator_name: str, profile: dict[str, Any], weight: float) -> str: + """Build a personality voice injection block from a creator's personality_profile JSONB. + + The ``weight`` (0.0–1.0) determines how strongly the personality should + come through. At low weights the instruction is softer ("subtly adopt"); + at high weights it is emphatic ("fully embody"). + """ + vocab = profile.get("vocabulary", {}) + tone = profile.get("tone", {}) + style = profile.get("style_markers", {}) + + phrases = vocab.get("signature_phrases", []) + descriptors = tone.get("descriptors", []) + teaching_style = tone.get("teaching_style", "") + energy = tone.get("energy", "moderate") + formality = tone.get("formality", "conversational") + + parts: list[str] = [] + + # Intensity qualifier + if weight >= 0.8: + parts.append(f"Fully embody {creator_name}'s voice and style.") + elif weight >= 0.4: + parts.append(f"Respond in {creator_name}'s voice.") + else: + parts.append(f"Subtly adopt {creator_name}'s communication style.") + + if teaching_style: + parts.append(f"Teaching style: {teaching_style}.") + if descriptors: + parts.append(f"Tone: {', '.join(descriptors[:5])}.") + if phrases: + parts.append(f"Use their signature phrases: {', '.join(phrases[:6])}.") + parts.append(f"Match their {formality} {energy} tone.") + + # Style markers + if style.get("uses_analogies"): + parts.append("Use analogies when helpful.") + if style.get("audience_engagement"): + parts.append(f"Audience engagement: {style['audience_engagement']}.") + + return " ".join(parts) diff --git a/backend/routers/chat.py b/backend/routers/chat.py index fd7ad87..55f93d6 100644 --- a/backend/routers/chat.py +++ b/backend/routers/chat.py @@ -29,6 +29,7 @@ class ChatRequest(BaseModel): query: str = Field(..., min_length=1, max_length=1000) creator: str | None = None conversation_id: str | None = None + personality_weight: float = Field(default=0.0, ge=0.0, le=1.0) @router.post("") @@ -45,7 +46,7 @@ async def chat( - ``event: done`` — completion metadata with cascade_tier, conversation_id - ``event: error`` — error message (on failure) """ - logger.info("chat_request query=%r creator=%r cid=%r", body.query, body.creator, body.conversation_id) + logger.info("chat_request query=%r creator=%r cid=%r weight=%.2f", body.query, body.creator, body.conversation_id, body.personality_weight) redis = await get_redis() service = ChatService(settings, redis=redis) @@ -56,6 +57,7 @@ async def chat( db=db, creator=body.creator, conversation_id=body.conversation_id, + personality_weight=body.personality_weight, ), media_type="text/event-stream", headers={ diff --git a/backend/tests/test_chat.py b/backend/tests/test_chat.py index d45d68f..6564655 100644 --- a/backend/tests/test_chat.py +++ b/backend/tests/test_chat.py @@ -563,3 +563,307 @@ async def test_single_turn_fallback_no_redis_history(chat_client, mock_redis): assert len(captured_messages) == 2 assert captured_messages[0]["role"] == "system" assert captured_messages[1]["role"] == "user" + + +# ── Personality weight tests ───────────────────────────────────────────────── + + +_FAKE_PERSONALITY_PROFILE = { + "vocabulary": { + "signature_phrases": ["let's gooo", "that's fire"], + "jargon_level": "mixed", + "filler_words": [], + "distinctive_terms": ["sauce", "vibes"], + "sound_descriptions": ["crispy", "punchy"], + }, + "tone": { + "formality": "casual", + "energy": "high", + "humor": "occasional", + "teaching_style": "hands-on demo-driven", + "descriptors": ["enthusiastic", "direct", "encouraging"], + }, + "style_markers": { + "explanation_approach": "example-first", + "uses_analogies": True, + "analogy_examples": ["like cooking a steak"], + "sound_words": ["brrr", "thwack"], + "self_references": "I always", + "audience_engagement": "asks rhetorical questions", + "pacing": "fast", + }, + "summary": "High-energy producer who teaches by doing.", +} + + +def _mock_creator_row(name: str, profile: dict | None): + """Build a mock Creator ORM row with just the fields personality injection needs.""" + row = MagicMock() + row.name = name + row.personality_profile = profile + return row + + +def _mock_db_execute(creator_row): + """Return a mock db.execute that yields a scalars().first() result.""" + mock_scalars = MagicMock() + mock_scalars.first.return_value = creator_row + mock_result = MagicMock() + mock_result.scalars.return_value = mock_scalars + return AsyncMock(return_value=mock_result) + + +@pytest.mark.asyncio +async def test_personality_weight_accepted_and_forwarded(chat_client): + """personality_weight is accepted in the request and forwarded to stream_response.""" + search_result = _fake_search_result() + + captured_kwargs = {} + mock_openai_client = MagicMock() + + async def _capture_create(**kwargs): + captured_kwargs.update(kwargs) + return _mock_openai_stream(["ok"]) + + mock_openai_client.chat.completions.create = AsyncMock(side_effect=_capture_create) + + with ( + patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result), + patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client), + ): + resp = await chat_client.post( + "/api/v1/chat", + json={"query": "test", "creator": "Keota", "personality_weight": 0.7}, + ) + + assert resp.status_code == 200 + events = _parse_sse(resp.text) + event_types = [e["event"] for e in events] + assert "done" in event_types + # Temperature should reflect the weight: 0.3 + 0.7*0.2 = 0.44 + assert captured_kwargs.get("temperature") == pytest.approx(0.44) + + +@pytest.mark.asyncio +async def test_personality_prompt_injected_when_weight_and_profile(chat_client): + """System prompt includes personality context when weight > 0 and profile exists.""" + search_result = _fake_search_result() + creator_row = _mock_creator_row("Keota", _FAKE_PERSONALITY_PROFILE) + + captured_messages = [] + mock_openai_client = MagicMock() + + async def _capture_create(**kwargs): + captured_messages.extend(kwargs.get("messages", [])) + return _mock_openai_stream(["personality answer"]) + + mock_openai_client.chat.completions.create = AsyncMock(side_effect=_capture_create) + + with ( + patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result), + patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client), + ): + # We need to mock db.execute inside the service — override the session + mock_session = AsyncMock() + mock_session.execute = _mock_db_execute(creator_row) + + async def _mock_get_session(): + yield mock_session + + app.dependency_overrides[get_session] = _mock_get_session + + resp = await chat_client.post( + "/api/v1/chat", + json={"query": "snare tips", "creator": "Keota", "personality_weight": 0.7}, + ) + + assert resp.status_code == 200 + assert len(captured_messages) >= 2 + + system_prompt = captured_messages[0]["content"] + # Personality block should be appended + assert "Keota" in system_prompt + assert "let's gooo" in system_prompt + assert "hands-on demo-driven" in system_prompt + assert "casual" in system_prompt + assert "high" in system_prompt + + +@pytest.mark.asyncio +async def test_personality_encyclopedic_fallback_null_profile(chat_client): + """When weight > 0 but personality_profile is null, falls back to encyclopedic prompt.""" + search_result = _fake_search_result() + creator_row = _mock_creator_row("NullCreator", None) + + captured_messages = [] + mock_openai_client = MagicMock() + + async def _capture_create(**kwargs): + captured_messages.extend(kwargs.get("messages", [])) + return _mock_openai_stream(["encyclopedic answer"]) + + mock_openai_client.chat.completions.create = AsyncMock(side_effect=_capture_create) + + with ( + patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result), + patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client), + ): + mock_session = AsyncMock() + mock_session.execute = _mock_db_execute(creator_row) + + async def _mock_get_session(): + yield mock_session + + app.dependency_overrides[get_session] = _mock_get_session + + resp = await chat_client.post( + "/api/v1/chat", + json={"query": "reverb tips", "creator": "NullCreator", "personality_weight": 0.5}, + ) + + assert resp.status_code == 200 + system_prompt = captured_messages[0]["content"] + # Should be the standard encyclopedic prompt, no personality injection + assert "Chrysopedia" in system_prompt + assert "NullCreator" not in system_prompt + + +@pytest.mark.asyncio +async def test_personality_encyclopedic_fallback_missing_creator(chat_client): + """When weight > 0 but creator doesn't exist in DB, falls back to encyclopedic prompt.""" + search_result = _fake_search_result() + + captured_messages = [] + mock_openai_client = MagicMock() + + async def _capture_create(**kwargs): + captured_messages.extend(kwargs.get("messages", [])) + return _mock_openai_stream(["encyclopedic answer"]) + + mock_openai_client.chat.completions.create = AsyncMock(side_effect=_capture_create) + + with ( + patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result), + patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client), + ): + mock_session = AsyncMock() + mock_session.execute = _mock_db_execute(None) # No creator found + + async def _mock_get_session(): + yield mock_session + + app.dependency_overrides[get_session] = _mock_get_session + + resp = await chat_client.post( + "/api/v1/chat", + json={"query": "bass tips", "creator": "GhostCreator", "personality_weight": 0.8}, + ) + + assert resp.status_code == 200 + system_prompt = captured_messages[0]["content"] + assert "Chrysopedia" in system_prompt + assert "GhostCreator" not in system_prompt + + +@pytest.mark.asyncio +async def test_personality_weight_zero_skips_profile_query(chat_client): + """When weight is 0.0, no Creator query is made even if creator is set.""" + search_result = _fake_search_result() + + captured_kwargs = {} + mock_openai_client = MagicMock() + + async def _capture_create(**kwargs): + captured_kwargs.update(kwargs) + return _mock_openai_stream(["ok"]) + + mock_openai_client.chat.completions.create = AsyncMock(side_effect=_capture_create) + + with ( + patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result), + patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client), + ): + mock_session = AsyncMock() + mock_session.execute = AsyncMock() # Should NOT be called + + async def _mock_get_session(): + yield mock_session + + app.dependency_overrides[get_session] = _mock_get_session + + resp = await chat_client.post( + "/api/v1/chat", + json={"query": "test", "creator": "Keota", "personality_weight": 0.0}, + ) + + assert resp.status_code == 200 + # DB execute should not have been called for Creator lookup + mock_session.execute.assert_not_called() + # Temperature should be 0.3 (base) + assert captured_kwargs.get("temperature") == pytest.approx(0.3) + + +@pytest.mark.asyncio +async def test_personality_temperature_scales_with_weight(chat_client): + """Temperature scales: 0.3 at weight=0.0, 0.5 at weight=1.0.""" + search_result = _fake_search_result() + creator_row = _mock_creator_row("Keota", _FAKE_PERSONALITY_PROFILE) + + captured_kwargs = {} + mock_openai_client = MagicMock() + + async def _capture_create(**kwargs): + captured_kwargs.update(kwargs) + return _mock_openai_stream(["warm"]) + + mock_openai_client.chat.completions.create = AsyncMock(side_effect=_capture_create) + + with ( + patch("chat_service.SearchService.search", new_callable=AsyncMock, return_value=search_result), + patch("chat_service.openai.AsyncOpenAI", return_value=mock_openai_client), + ): + mock_session = AsyncMock() + mock_session.execute = _mock_db_execute(creator_row) + + async def _mock_get_session(): + yield mock_session + + app.dependency_overrides[get_session] = _mock_get_session + + resp = await chat_client.post( + "/api/v1/chat", + json={"query": "test", "creator": "Keota", "personality_weight": 1.0}, + ) + + assert resp.status_code == 200 + assert captured_kwargs.get("temperature") == pytest.approx(0.5) + + +@pytest.mark.asyncio +async def test_personality_weight_above_1_returns_422(chat_client): + """personality_weight > 1.0 fails Pydantic validation with 422.""" + resp = await chat_client.post( + "/api/v1/chat", + json={"query": "test", "personality_weight": 1.5}, + ) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +async def test_personality_weight_below_0_returns_422(chat_client): + """personality_weight < 0.0 fails Pydantic validation with 422.""" + resp = await chat_client.post( + "/api/v1/chat", + json={"query": "test", "personality_weight": -0.1}, + ) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +async def test_personality_weight_string_returns_422(chat_client): + """personality_weight as a non-numeric string fails validation with 422.""" + resp = await chat_client.post( + "/api/v1/chat", + json={"query": "test", "personality_weight": "high"}, + ) + assert resp.status_code == 422