Compare commits
328 commits
1013a333b1
...
472e8ba660
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
472e8ba660 | ||
|
|
5184856d50 | ||
|
|
39b5d7c0be | ||
|
|
526fd0a58c | ||
|
|
4a3bb8208a | ||
|
|
8f7763d822 | ||
|
|
90bb90e989 | ||
|
|
183d852f31 | ||
|
|
7b048ccbaf | ||
|
|
f2edb1f375 | ||
|
|
0b8dcf2ccf | ||
|
|
3d16a5d9e8 | ||
|
|
ce4bccf292 | ||
|
|
78da2f6585 | ||
|
|
86e31cfa5c | ||
|
|
638477cc8e | ||
|
|
a0e228d5b4 | ||
|
|
102e00f323 | ||
|
|
c5fb352552 | ||
|
|
e3456f66b0 | ||
|
|
6e038f041f | ||
|
|
f02865b829 | ||
|
|
a3f2c4f332 | ||
|
|
067d7ed332 | ||
|
|
ecfdc76ba6 | ||
|
|
2f9e3272d9 | ||
|
|
28bc15b404 | ||
|
|
e40a00eb4f | ||
|
|
899e57c0e1 | ||
|
|
ddb283cc28 | ||
|
|
cf7c8b26a0 | ||
|
|
86d554a56f | ||
|
|
fa7e4983c7 | ||
|
|
3f75f33c2b | ||
|
|
09177b9d36 | ||
|
|
8e27f994db | ||
|
|
d32864de6a | ||
|
|
c3d1afa2ce | ||
|
|
e0a6458bdc | ||
|
|
11d58e09dd | ||
|
|
e9ab6bd9a7 | ||
|
|
0856827b59 | ||
|
|
04630764a6 | ||
|
|
c3ab0cec74 | ||
|
|
2a63ccdbe5 | ||
|
|
c163037a0f | ||
|
|
73736295c1 | ||
|
|
15299232a8 | ||
|
|
442d0ca48b | ||
|
|
7cc9497f3c | ||
|
|
52df9c0dc2 | ||
|
|
6a6305e8d1 | ||
|
|
29e60bbc99 | ||
|
|
0098254fdd | ||
|
|
19a6ff660c | ||
|
|
c9a2d2efbb | ||
|
|
fc85966679 | ||
|
|
50bcd5d2b9 | ||
|
|
ab62045e94 | ||
|
|
f531a4e5fb | ||
|
|
6447c2ec32 | ||
|
|
8ae93daf48 | ||
|
|
8e3ba00aab | ||
|
|
87f09f0192 | ||
|
|
2abeb391bd | ||
|
|
944e152c6f | ||
|
|
5a39850a35 | ||
|
|
5be499d0ad | ||
|
|
fa972a4fbc | ||
|
|
dec6c0b812 | ||
|
|
ee00f288d9 | ||
|
|
2949c93c86 | ||
|
|
4edb96df2b | ||
|
|
a2372788d5 | ||
|
|
91cdc5e0b1 | ||
|
|
c374165865 | ||
|
|
90c24d8bf9 | ||
|
|
5e0ce753a5 | ||
|
|
9530c85b9c | ||
|
|
2568dc3812 | ||
|
|
b3f52cc301 | ||
|
|
4917fd3a32 | ||
|
|
90dea3bde1 | ||
|
|
2162385c6b | ||
|
|
2e7fa224bc | ||
|
|
e665e82c25 | ||
|
|
0888569639 | ||
|
|
501393ee36 | ||
|
|
fe493d0647 | ||
|
|
51e01f8b7c | ||
|
|
7b9f6785cb | ||
|
|
0c99b1a8b7 | ||
|
|
bd2be703a5 | ||
|
|
ea8ddf74f0 | ||
|
|
ab3c723533 | ||
|
|
4b7511d363 | ||
|
|
58865f5634 | ||
|
|
346c60a239 | ||
|
|
f4020251b9 | ||
|
|
ae62c09881 | ||
|
|
33f68ef4cf | ||
|
|
aca1cc1c91 | ||
|
|
415f454105 | ||
|
|
feda9c1446 | ||
|
|
dc8f7ea6cf | ||
|
|
ef52ef6967 | ||
|
|
371494c2f8 | ||
|
|
f4f21b6c36 | ||
|
|
a9e3572573 | ||
|
|
78eca0d546 | ||
|
|
df0ff28f5e | ||
|
|
3807db8028 | ||
|
|
c5b7fcf4b1 | ||
|
|
ea479f093e | ||
|
|
c0e95df9f8 | ||
|
|
76e3543a99 | ||
|
|
74e89e8a3f | ||
|
|
49e986a142 | ||
|
|
0d91e87232 | ||
|
|
953ab008fb | ||
|
|
f5d35a0162 | ||
|
|
4d29ec12ce | ||
|
|
2142efece3 | ||
|
|
0623fcf430 | ||
|
|
149e4fc147 | ||
|
|
a235fd9e36 | ||
|
|
afd079ece7 | ||
|
|
fada426959 | ||
|
|
1b54a86225 | ||
|
|
596331156a | ||
|
|
c910aca7f2 | ||
|
|
4f771fc1f7 | ||
|
|
9acfa9bc20 | ||
|
|
1a7c11cac1 | ||
|
|
828afcc3e7 | ||
|
|
ee3cbb74f7 | ||
|
|
48e1bd163d | ||
|
|
ab11a00765 | ||
|
|
9e31e2b744 | ||
|
|
489ad88893 | ||
|
|
e52506511d | ||
|
|
b4bea10067 | ||
|
|
82998c6d8d | ||
|
|
cf90dbb9f0 | ||
|
|
b297002679 | ||
|
|
051f1046b0 | ||
|
|
90a4a15dd7 | ||
|
|
6a3e4e4955 | ||
|
|
4fbc77d10d | ||
|
|
dfb10f04b4 | ||
|
|
5cf50e84de | ||
|
|
98107c20f5 | ||
|
|
80ac367e23 | ||
|
|
47fe10f3df | ||
|
|
6f3d5b27f9 | ||
|
|
df93f2655a | ||
|
|
ff0d40a466 | ||
|
|
9a8d2ea5c9 | ||
|
|
df1d6af84e | ||
|
|
7bdba76d50 | ||
|
|
fd683e8266 | ||
|
|
edfabb037a | ||
|
|
bd8a928c95 | ||
|
|
304f3bc069 | ||
|
|
dbf3643662 | ||
|
|
943a5102fe | ||
|
|
66b02dd94e | ||
|
|
ae98e4e30e | ||
|
|
cd2d842477 | ||
|
|
9ee9b01af5 | ||
|
|
3433c48681 | ||
|
|
3cf993c019 | ||
|
|
4c952ed96c | ||
|
|
f320b08e0b | ||
|
|
d04b810289 | ||
|
|
14d3663567 | ||
|
|
41eeb69c2d | ||
|
|
da0b4b5fd6 | ||
|
|
37308fd185 | ||
|
|
7cc07d1b2d | ||
|
|
1ed60f712a | ||
|
|
1e31af9c30 | ||
|
|
7311f25509 | ||
|
|
1b54b51922 | ||
|
|
4f4126e0ce | ||
|
|
1be0deeb76 | ||
|
|
03373f263d | ||
|
|
0d82b2b409 | ||
|
|
0086573af5 | ||
|
|
91cae921a4 | ||
|
|
b1b02a9633 | ||
|
|
2cf9ae9bd6 | ||
|
|
18de2b3065 | ||
|
|
c62c6eb644 | ||
|
|
c7ac4be860 | ||
|
|
5f608b8889 | ||
|
|
94da19c05d | ||
|
|
630d3fa477 | ||
|
|
250d7315af | ||
|
|
c1cdba14f2 | ||
|
|
5a484fb27a | ||
|
|
9c0247c830 | ||
|
|
0d538238a6 | ||
|
|
89c3f4fcc4 | ||
|
|
2ea720af5c | ||
|
|
2a128a3804 | ||
|
|
a216f60093 | ||
|
|
0abd11299e | ||
|
|
564eb2a636 | ||
|
|
604e9711d1 | ||
|
|
d7de26b86b | ||
|
|
af3f1a6663 | ||
|
|
a6f4f36a46 | ||
|
|
63e350a882 | ||
|
|
9497d8f0e4 | ||
|
|
04ae6d0703 | ||
|
|
f3e6a9c885 | ||
|
|
6845f5c349 | ||
|
|
85712c15eb | ||
|
|
fea0afdec0 | ||
|
|
fa1fc82d5a | ||
|
|
b01e5949b6 | ||
|
|
a9b65fcea9 | ||
|
|
df559bbca0 | ||
|
|
9e4f10b0af | ||
|
|
d0bdc6f516 | ||
|
|
9107323a66 | ||
|
|
836fcb2304 | ||
|
|
cf0205bd5c | ||
|
|
c481bafc7b | ||
|
|
c25db471f7 | ||
|
|
6a793c2c9a | ||
|
|
a7a038beea | ||
|
|
348089d635 | ||
|
|
810f96a640 | ||
|
|
5a959cc9b2 | ||
|
|
5f7a8a6f77 | ||
|
|
881629b5c3 | ||
|
|
813920a653 | ||
|
|
80439d43cf | ||
|
|
0c8bbb32d6 | ||
|
|
831489cf4f | ||
|
|
d5144307b8 | ||
|
|
fe8b529ad2 | ||
|
|
9c519f22b8 | ||
|
|
c23bc2314d | ||
|
|
1fa590b9e0 | ||
|
|
325ffe588e | ||
|
|
1db556c89e | ||
|
|
4843d12869 | ||
|
|
5b663c1e62 | ||
|
|
21dcf2f517 | ||
|
|
3c4c2f9439 | ||
|
|
4e5b41c46f | ||
|
|
62a94d51c8 | ||
|
|
2e4531f1d7 | ||
|
|
a980965705 | ||
|
|
eae4da48f8 | ||
|
|
031f6d3d5b | ||
|
|
4186c6e208 | ||
|
|
2fe7687596 | ||
|
|
dcadcbb5e2 | ||
|
|
d218a85e4e | ||
|
|
ff1b24867f | ||
|
|
0c4208bc07 | ||
|
|
878a7a45c6 | ||
|
|
ec9e307538 | ||
|
|
96d7bc6e75 | ||
|
|
66bf79ed4a | ||
|
|
7058cba307 | ||
|
|
d00639a2ea | ||
|
|
cd9dd6d8f9 | ||
|
|
515c88cc04 | ||
|
|
8570abd069 | ||
|
|
94da945e4b | ||
|
|
b8f698067d | ||
|
|
93a643f1b8 | ||
|
|
b0038f21f7 | ||
|
|
171775452e | ||
|
|
e57d42582b | ||
|
|
420e1e8342 | ||
|
|
6b099f3663 | ||
|
|
81f0ed916f | ||
|
|
2a07583d6d | ||
|
|
872ffe0543 | ||
|
|
a4d298502c | ||
|
|
2a5e4d7156 | ||
|
|
bb4a5f93d8 | ||
|
|
769746197c | ||
|
|
8dc7097603 | ||
|
|
b4b650a26d | ||
|
|
2de51fb1a7 | ||
|
|
527216b768 | ||
|
|
b3531fadb6 | ||
|
|
017bffe05d | ||
|
|
ef78401806 | ||
|
|
9bb40f57e4 | ||
|
|
e09147760d | ||
|
|
433ab0c1f9 | ||
|
|
f73e602995 | ||
|
|
cb16c28f69 | ||
|
|
c47b6380b1 | ||
|
|
23cb375164 | ||
|
|
a99aac9e75 | ||
|
|
6e29c27852 | ||
|
|
e489e4686d | ||
|
|
0e65adaf54 | ||
|
|
1af8665a14 | ||
|
|
edeb3fdc44 | ||
|
|
ce3a9067fa | ||
|
|
5906f286b3 | ||
|
|
740fb59d9d | ||
|
|
34733b199d | ||
|
|
5fcbb95c58 | ||
|
|
b43e4a079a | ||
|
|
e27a86518d | ||
|
|
aa2ef4e153 | ||
|
|
f59718f8c7 | ||
|
|
7cb78e7427 | ||
|
|
b68751f0db | ||
|
|
a9de7f97ea | ||
|
|
88170a41f6 | ||
|
|
9bddaed5d1 | ||
|
|
a79282ffa9 | ||
|
|
03b1d53fcb | ||
|
|
b2bb23930c | ||
|
|
c404270f49 | ||
|
|
0904939660 |
400 changed files with 83479 additions and 0 deletions
111
.artifacts/feature-synthesis-chunking.md
Normal file
111
.artifacts/feature-synthesis-chunking.md
Normal file
|
|
@ -0,0 +1,111 @@
|
||||||
|
# Feature: Stage 5 Synthesis Chunking for Large Category Groups
|
||||||
|
|
||||||
|
## Problem
|
||||||
|
|
||||||
|
Stage 5 synthesis sends all key moments for a given `(video, topic_category)` group to the LLM in a single call. When a video produces a large number of moments in one category, the prompt exceeds what the model can process into a valid structured response.
|
||||||
|
|
||||||
|
**Concrete failure:** COPYCATT's "Sound Design - Everything In 2 Hours Speedrun" (2,026 transcript segments) produced 198 moments classified as "Sound design" (175) / "Sound Design" (23 — casing inconsistency). The synthesis prompt for that category was ~42k tokens. The model (`fyn-llm-agent-think`, 128k context) accepted the prompt but returned only 5,407 completion tokens with `finish=stop` — valid JSON that was structurally incomplete, failing Pydantic `SynthesisResult` validation. The pipeline retried and failed identically each time.
|
||||||
|
|
||||||
|
The other 37 videos in the corpus (up to 930 segments, ~60 moments per category max) all synthesized successfully.
|
||||||
|
|
||||||
|
## Root Causes
|
||||||
|
|
||||||
|
Two independent issues compound into this failure:
|
||||||
|
|
||||||
|
### 1. No chunking in stage 5 synthesis
|
||||||
|
|
||||||
|
`stage5_synthesis()` in `backend/pipeline/stages.py` iterates over `groups[category]` and builds one prompt containing ALL moments for that category. There's no upper bound on how many moments go into a single LLM call.
|
||||||
|
|
||||||
|
**Location:** `stages.py` lines ~850-875 — the `for category, moment_group in groups.items()` loop builds the full `moments_text` without splitting.
|
||||||
|
|
||||||
|
### 2. Inconsistent category casing from stage 4
|
||||||
|
|
||||||
|
Stage 4 classification produces `"Sound design"` and `"Sound Design"` as separate categories for the same video. Stage 5 groups by exact string match, so these stay separate — but even independently, 175 moments in one group is too many. The casing issue does inflate the problem by preventing natural splitting across categories.
|
||||||
|
|
||||||
|
**Location:** Classification output stored in Redis at `chrysopedia:classification:{video_id}`. The `topic_category` values come directly from the LLM with no normalization.
|
||||||
|
|
||||||
|
## Proposed Changes
|
||||||
|
|
||||||
|
### Change 1: Chunked synthesis with merge pass
|
||||||
|
|
||||||
|
Split large category groups into chunks before sending to the LLM. Each chunk produces technique pages independently, then a lightweight merge step combines pages with overlapping topics.
|
||||||
|
|
||||||
|
**In `stage5_synthesis()` (`backend/pipeline/stages.py`):**
|
||||||
|
|
||||||
|
1. After grouping moments by category, check each group's size against a configurable threshold (e.g., `SYNTHESIS_CHUNK_SIZE = 30` moments).
|
||||||
|
|
||||||
|
2. Groups at or below the threshold: process as today — single LLM call.
|
||||||
|
|
||||||
|
3. Groups above the threshold: split into chunks of `SYNTHESIS_CHUNK_SIZE` moments, ordered by `start_time` (preserving chronological context). Each chunk gets its own synthesis LLM call, producing its own `SynthesisResult` with 1+ pages.
|
||||||
|
|
||||||
|
4. After all chunks for a category are processed, collect the resulting pages. Pages with the same or very similar slugs (e.g., Levenshtein distance < 3, or shared slug prefix before the creator suffix) should be merged. The merge is a second LLM call with a simpler prompt: "Here are N partial technique pages on the same topic from the same creator. Merge them into a single cohesive page, combining body sections, deduplicating signal chains and plugins, and writing a unified summary." This merge prompt is much smaller than the original 198-moment prompt because it takes synthesized prose as input, not raw moment data.
|
||||||
|
|
||||||
|
5. If no pages share slugs across chunks, keep them all — they represent genuinely distinct sub-topics the LLM identified within the category.
|
||||||
|
|
||||||
|
**New config setting in `backend/config.py`:**
|
||||||
|
```python
|
||||||
|
synthesis_chunk_size: int = 30 # Max moments per synthesis LLM call
|
||||||
|
```
|
||||||
|
|
||||||
|
**New prompt file:** `prompts/stage5_merge.txt` — instructions for combining partial technique pages into a unified page. Much simpler than the full synthesis prompt since it operates on already-synthesized prose rather than raw moments.
|
||||||
|
|
||||||
|
**Token budget consideration:** 30 moments × ~200 tokens each (title + summary + metadata + transcript excerpt) = ~6k tokens of moment data + ~2k system prompt = ~8k input tokens. Well within what the model handles reliably. The merge call takes 2-4 partial pages of prose (~3-5k tokens total) — also very manageable.
|
||||||
|
|
||||||
|
### Change 2: Category casing normalization in stage 4
|
||||||
|
|
||||||
|
Normalize `topic_category` values before storing classification results in Redis.
|
||||||
|
|
||||||
|
**In `stage4_classification()` (`backend/pipeline/stages.py`):**
|
||||||
|
|
||||||
|
After parsing the `ClassificationResult` from the LLM, apply title-case normalization to each moment's `topic_category`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
category = cls_result.topic_category.strip().title()
|
||||||
|
# "Sound design" -> "Sound Design"
|
||||||
|
# "sound design" -> "Sound Design"
|
||||||
|
# "SOUND DESIGN" -> "Sound Design"
|
||||||
|
```
|
||||||
|
|
||||||
|
This is a one-line fix. It prevents the "Sound design" / "Sound Design" split that inflated the group sizes and would reduce the COPYCATT video from 198 → 198 moments in a single normalized "Sound Design" group — still too many without chunking, but it eliminates the class of bug where moments scatter across near-duplicate categories.
|
||||||
|
|
||||||
|
**Also apply in stage 5 as a safety net:** When building the `groups` dict, normalize the category key:
|
||||||
|
```python
|
||||||
|
category = cls_info.get("topic_category", "Uncategorized").strip().title()
|
||||||
|
```
|
||||||
|
|
||||||
|
This handles data already in Redis from prior stage 4 runs without requiring reprocessing.
|
||||||
|
|
||||||
|
### Change 3: Estimated token pre-check before LLM call
|
||||||
|
|
||||||
|
Before making the synthesis LLM call, estimate the total tokens (prompt + expected output) and log a warning if it exceeds a safety threshold. This doesn't block the call — chunking handles the splitting — but it provides observability for tuning `SYNTHESIS_CHUNK_SIZE`.
|
||||||
|
|
||||||
|
**In the synthesis loop, after building `user_prompt`:**
|
||||||
|
```python
|
||||||
|
estimated_input = estimate_tokens(system_prompt) + estimate_tokens(user_prompt)
|
||||||
|
if estimated_input > 15000:
|
||||||
|
logger.warning(
|
||||||
|
"Stage 5: Large synthesis input for category '%s' video_id=%s: "
|
||||||
|
"~%d input tokens, %d moments. Consider reducing SYNTHESIS_CHUNK_SIZE.",
|
||||||
|
category, video_id, estimated_input, len(moment_group),
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Files to Modify
|
||||||
|
|
||||||
|
| File | Change |
|
||||||
|
|------|--------|
|
||||||
|
| `backend/pipeline/stages.py` | Chunk logic in `stage5_synthesis()`, casing normalization in `stage4_classification()` and `stage5_synthesis()` grouping |
|
||||||
|
| `backend/pipeline/llm_client.py` | No changes needed — `estimate_max_tokens()` already handles per-call estimation |
|
||||||
|
| `backend/config.py` | Add `synthesis_chunk_size: int = 30` setting |
|
||||||
|
| `prompts/stage5_merge.txt` | New prompt for merging partial technique pages |
|
||||||
|
| `backend/schemas.py` | No changes — `SynthesisResult` schema works for both chunk and merge calls |
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
1. **Unit test:** Mock the LLM and verify that a 90-moment group gets split into 3 chunks of 30, each producing a `SynthesisResult`, followed by a merge call.
|
||||||
|
2. **Integration test:** Retrigger the COPYCATT "Sound Design - Everything In 2 Hours Speedrun" video and confirm it completes stage 5 without `LLMTruncationError`.
|
||||||
|
3. **Regression test:** Retrigger a small video (e.g., Skope "Understanding Waveshapers", 9 moments) and confirm behavior is unchanged — no chunking triggered, same output.
|
||||||
|
|
||||||
|
## Rollback
|
||||||
|
|
||||||
|
`SYNTHESIS_CHUNK_SIZE` can be set very high (e.g., 9999) to effectively disable chunking without a code change. The casing normalization is backward-compatible — it only affects new pipeline runs.
|
||||||
53
.env.example
Normal file
53
.env.example
Normal file
|
|
@ -0,0 +1,53 @@
|
||||||
|
# ─── Chrysopedia Environment Variables ───
|
||||||
|
# Copy to .env and fill in secrets before docker compose up
|
||||||
|
|
||||||
|
# PostgreSQL
|
||||||
|
POSTGRES_USER=chrysopedia
|
||||||
|
POSTGRES_PASSWORD=changeme
|
||||||
|
POSTGRES_DB=chrysopedia
|
||||||
|
|
||||||
|
# Redis (Celery broker) — container-internal, no secret needed
|
||||||
|
REDIS_URL=redis://chrysopedia-redis:6379/0
|
||||||
|
|
||||||
|
# LLM endpoint (OpenAI-compatible — OpenWebUI on FYN DGX)
|
||||||
|
# Use /api (not /api/v1) so calls route through OpenWebUI's tracked proxy for analytics
|
||||||
|
LLM_API_URL=https://chat.forgetyour.name/api
|
||||||
|
LLM_API_KEY=sk-changeme
|
||||||
|
LLM_MODEL=fyn-llm-agent-chat
|
||||||
|
LLM_FALLBACK_URL=https://chat.forgetyour.name/api
|
||||||
|
LLM_FALLBACK_MODEL=fyn-llm-agent-chat
|
||||||
|
|
||||||
|
# Per-stage LLM model overrides (optional — defaults to LLM_MODEL)
|
||||||
|
# Modality: "chat" = standard JSON mode, "thinking" = reasoning model (strips <think> tags)
|
||||||
|
# Stages 2 (segmentation) and 4 (classification) are mechanical — use fast chat model
|
||||||
|
# Stages 3 (extraction) and 5 (synthesis) need reasoning — use thinking model
|
||||||
|
LLM_STAGE2_MODEL=fyn-llm-agent-chat
|
||||||
|
LLM_STAGE2_MODALITY=chat
|
||||||
|
LLM_STAGE3_MODEL=fyn-llm-agent-think
|
||||||
|
LLM_STAGE3_MODALITY=thinking
|
||||||
|
LLM_STAGE4_MODEL=fyn-llm-agent-chat
|
||||||
|
LLM_STAGE4_MODALITY=chat
|
||||||
|
LLM_STAGE5_MODEL=fyn-llm-agent-think
|
||||||
|
LLM_STAGE5_MODALITY=thinking
|
||||||
|
|
||||||
|
# Max tokens for LLM responses (OpenWebUI defaults to 1000 — pipeline needs much more)
|
||||||
|
LLM_MAX_TOKENS=65536
|
||||||
|
|
||||||
|
# Embedding endpoint (Ollama container in the compose stack)
|
||||||
|
EMBEDDING_API_URL=http://chrysopedia-ollama:11434/v1
|
||||||
|
EMBEDDING_MODEL=nomic-embed-text
|
||||||
|
|
||||||
|
# Qdrant (container-internal)
|
||||||
|
QDRANT_URL=http://chrysopedia-qdrant:6333
|
||||||
|
QDRANT_COLLECTION=chrysopedia
|
||||||
|
|
||||||
|
# Application
|
||||||
|
APP_ENV=production
|
||||||
|
APP_LOG_LEVEL=info
|
||||||
|
|
||||||
|
# File storage paths (inside container, bind-mounted to /vmPool/r/services/chrysopedia_data)
|
||||||
|
TRANSCRIPT_STORAGE_PATH=/data/transcripts
|
||||||
|
VIDEO_METADATA_PATH=/data/video_meta
|
||||||
|
|
||||||
|
# Review mode toggle (true = moments require admin review before publishing)
|
||||||
|
REVIEW_MODE=true
|
||||||
27
.gitignore
vendored
27
.gitignore
vendored
|
|
@ -1,2 +1,29 @@
|
||||||
.bg-shell/
|
.bg-shell/
|
||||||
.gsd/
|
.gsd/
|
||||||
|
|
||||||
|
# ── GSD baseline (auto-generated) ──
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
*~
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
*.code-workspace
|
||||||
|
.env
|
||||||
|
.env.*
|
||||||
|
!.env.example
|
||||||
|
node_modules/
|
||||||
|
.next/
|
||||||
|
dist/
|
||||||
|
build/
|
||||||
|
__pycache__/
|
||||||
|
*.pyc
|
||||||
|
.venv/
|
||||||
|
venv/
|
||||||
|
target/
|
||||||
|
vendor/
|
||||||
|
*.log
|
||||||
|
coverage/
|
||||||
|
.cache/
|
||||||
|
tmp/
|
||||||
|
|
|
||||||
7
.mcp.json
Normal file
7
.mcp.json
Normal file
|
|
@ -0,0 +1,7 @@
|
||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
"chrysopedia": {
|
||||||
|
"url": "http://ub01:8101/mcp"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
143
.planning/M016-ux-brand-reading-experience.md
Normal file
143
.planning/M016-ux-brand-reading-experience.md
Normal file
|
|
@ -0,0 +1,143 @@
|
||||||
|
# M016: UX Polish, Brand & Reading Experience
|
||||||
|
|
||||||
|
> **Stream:** Frontend — intended for a dedicated GSD milestone instance
|
||||||
|
> **Conflict zone:** `frontend/src/` only — no backend Python changes
|
||||||
|
> **Deploy cadence:** commit-build-redeploy after each slice completion
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Goal
|
||||||
|
|
||||||
|
Modernize the public site's visual identity and reading experience, fix pipeline admin UI bugs, and establish a brand baseline (logo, favicon, OG tags). Every change in this milestone lives entirely in the frontend — CSS, React components, and static assets.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Slice Breakdown
|
||||||
|
|
||||||
|
### S01: Landing Page Visual Fixes (Quick Wins)
|
||||||
|
**Risk:** Low | **Effort:** Small | **Files:** `App.css`, `Home.tsx`
|
||||||
|
|
||||||
|
Research found 5 concrete bugs/inconsistencies on the homepage:
|
||||||
|
|
||||||
|
| # | Issue | Fix |
|
||||||
|
|---|-------|-----|
|
||||||
|
| 1 | Duplicate `.btn` rule at App.css:3185 overrides CTA sizing (renders 131x38 instead of ~195x48) | Remove or merge the duplicate `.btn` block |
|
||||||
|
| 2 | `.home-featured` uses `border-image` which kills `border-radius` — card renders square | Replace with pseudo-element gradient border technique |
|
||||||
|
| 3 | Three different `max-width` tracks (36rem, 42rem, none) create jagged center column | Unify to 42rem for all content sections |
|
||||||
|
| 4 | Vertical spacing irregular — Random→Featured gap is only 8px vs 24px elsewhere | Normalize section margins to 1.5rem |
|
||||||
|
| 5 | Two `border-radius` values (0.5rem vs 0.625rem) on home cards | Unify to 0.625rem |
|
||||||
|
|
||||||
|
**Verification:** Visual screenshot comparison before/after on desktop (1280px) and mobile (375px).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### S02: Pipeline Admin UI Fixes
|
||||||
|
**Risk:** Low | **Effort:** Small-Medium | **Files:** `AdminPipeline.tsx`, `App.css`
|
||||||
|
|
||||||
|
Four issues identified with root causes already diagnosed:
|
||||||
|
|
||||||
|
| # | Issue | Root Cause | Fix |
|
||||||
|
|---|-------|-----------|-----|
|
||||||
|
| 1 | Most-recent run won't collapse / flickers | `expandedRunId` in `load()` useCallback dependency array (line 729) causes race condition — collapsing sets null, which triggers load recreation, which re-expands | Remove `expandedRunId` from dependency array + use `useRef` for initial-load tracking |
|
||||||
|
| 2 | Mobile job cards show vertical text ("C h e e") | `.pipeline-video__creator` missing overflow rules (App.css:4477) | Add `overflow: hidden; text-overflow: ellipsis; white-space: nowrap` (matches `.pipeline-video__filename` pattern) |
|
||||||
|
| 3 | No stage direction chevrons | Pipeline stages listed without visual flow indicator | Add CSS chevron/arrow between stage indicators using `::after` pseudo-elements or inline SVG |
|
||||||
|
| 4 | Filter text box should be replaced with button group | Current text input for status filter; should be "ALL \| Not Started \| In Progress \| Complete" buttons, end-aligned | Replace `<input>` with `<div className="filter-buttons">` flexbox, `justify-content: flex-end`, verify vertical alignment against adjacent elements |
|
||||||
|
| 5 | Creators dropdown never populates (422 error) | Frontend requests `fetchCreators({ limit: 200 })` but backend validates `le=100` (line 1126) | Change to `limit: 100` |
|
||||||
|
|
||||||
|
**Verification:** Test collapse toggle on most-recent run, resize to 375px and check creator name truncation, confirm chevrons render between stages, confirm filter buttons align right and sit level with row, confirm creator filter dropdown populates.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### S03: Brand Minimum (Favicon, OG Tags, Logo)
|
||||||
|
**Risk:** Low | **Effort:** Small | **Files:** `index.html`, `App.tsx`, `App.css`, new static assets
|
||||||
|
|
||||||
|
The site currently has:
|
||||||
|
- No favicon (browser default icon)
|
||||||
|
- No OG meta tags (no preview image when sharing URL via text/Discord)
|
||||||
|
- No logo next to "Chrysopedia" in the header
|
||||||
|
|
||||||
|
Tasks:
|
||||||
|
1. **Design a simple logo** — something that fits the dark theme + cyan accent aesthetic. Could be a stylized book/page icon, a knowledge/crystal motif matching the "chryso-" (gold) prefix, or an abstract mark. Generate an SVG.
|
||||||
|
2. **Add favicon** — export logo as favicon.ico + apple-touch-icon + 192/512 PNG for PWA manifest
|
||||||
|
3. **Add OG meta tags** — `og:title`, `og:description`, `og:image`, `og:url`, `twitter:card` in index.html. Create a 1200x630 OG image using the logo + brand colors.
|
||||||
|
4. **Place logo in header** — render the SVG inline next to "Chrysopedia" text with appropriate sizing
|
||||||
|
|
||||||
|
**Verification:** Share URL in Discord/iMessage and confirm preview card renders. Check favicon in browser tab. Visual check logo in header.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### S04: ToC Modernization
|
||||||
|
**Risk:** Medium | **Effort:** Medium | **Files:** `TableOfContents.tsx`, `App.css`, `TechniquePage.tsx`
|
||||||
|
|
||||||
|
Research identified these dated elements:
|
||||||
|
- CSS counter numbering ("1.", "1.2") — biggest offender
|
||||||
|
- Boxed card container with solid border
|
||||||
|
- Uppercase "CONTENTS" label
|
||||||
|
- No active-section highlighting
|
||||||
|
- Underline-only hover states
|
||||||
|
|
||||||
|
Modernization plan:
|
||||||
|
1. **Remove numbered counters** — switch to unordered list with clean indentation
|
||||||
|
2. **Replace box border** with left accent bar (`border-left: 2px solid var(--color-accent)`)
|
||||||
|
3. **Change heading** from "CONTENTS" to "On this page" in sentence case
|
||||||
|
4. **Add hover background** (`rgba(34, 211, 238, 0.08)`) instead of underline
|
||||||
|
5. **Add IntersectionObserver** — track which section heading is in the viewport, highlight the corresponding ToC entry with accent left-border + brighter text color
|
||||||
|
6. **Make ToC sticky** — position it at the top of the existing sidebar (above Key Moments), `position: sticky; top: 1.5rem`
|
||||||
|
|
||||||
|
**Verification:** Navigate to a technique page with 4+ sections. Scroll through — ToC should highlight current section. ToC stays visible in sidebar while scrolling. Hover states work. No numbering visible.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### S05: Sticky Reading Header
|
||||||
|
**Risk:** Medium | **Effort:** Medium | **Files:** new `ReadingHeader.tsx`, `App.css`, `TechniquePage.tsx`
|
||||||
|
|
||||||
|
New component that slides in when user scrolls past the article title:
|
||||||
|
- Shows: article title (truncated) + current section name
|
||||||
|
- `position: sticky; top: 0; z-index: 50`
|
||||||
|
- Thin bar (~40px height), `var(--color-bg-header)` background, subtle bottom border
|
||||||
|
- Hidden by default, slides in via `transform: translateY(-100%)` → `translateY(0)` transition
|
||||||
|
- Uses IntersectionObserver on the technique header element as show/hide trigger
|
||||||
|
- Shares the section-tracking observer from S04's ToC work
|
||||||
|
- On mobile: compact single-line with optional dropdown for section jump
|
||||||
|
- Update `scroll-margin-top` values on section anchors to account for new header height
|
||||||
|
|
||||||
|
**Verification:** Open long technique page. Scroll past title — reading header appears. Correct section name updates as you scroll. Works on mobile (375px). Doesn't break existing header.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### S06: Landing Page Personality Pass
|
||||||
|
**Risk:** Low | **Effort:** Small | **Files:** `Home.tsx`, `App.css`
|
||||||
|
|
||||||
|
After S01 fixes the bugs, this slice adds polish:
|
||||||
|
1. **Hero tightening** — reduce hero bottom padding and how-it-works top margin to get content above the fold faster (~40-50px reclaim)
|
||||||
|
2. **Stats scorecard enhancement** — animated count-up on first view (simple `requestAnimationFrame` counter), subtle glow on numbers
|
||||||
|
3. **Random button treatment** — wrap in a small card with "Feeling adventurous?" tagline, or embed as secondary action inside Trending Searches
|
||||||
|
4. **Section heading standardization** — pick one treatment (title-case with left accent bar) and apply consistently to "Recently Added", "Trending Searches", "Popular Topics"
|
||||||
|
5. **Header brand accent** — apply `color: var(--color-accent)` or subtle gradient to "Chrysopedia" text (pairs with S03 logo)
|
||||||
|
|
||||||
|
**Verification:** Visual check desktop + mobile. Stats animate on page load. Section headings consistent. Content peeks above fold on standard viewport.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Dependency Graph
|
||||||
|
|
||||||
|
```
|
||||||
|
S01 (landing fixes) ──→ S06 (personality pass)
|
||||||
|
S02 (pipeline fixes) [independent]
|
||||||
|
S03 (brand minimum) ──→ S06 (header accent uses logo)
|
||||||
|
S04 (ToC modern) ──→ S05 (reading header shares IntersectionObserver pattern)
|
||||||
|
```
|
||||||
|
|
||||||
|
S01, S02, S03, S04 can all start in parallel. S05 depends on S04. S06 depends on S01 + S03.
|
||||||
|
|
||||||
|
Recommended execution order: **S01 → S02 → S03 → S04 → S05 → S06**
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Out of Scope (for this milestone)
|
||||||
|
|
||||||
|
- Creator landing page redesign (depends on backend social links API — see backend session workplan)
|
||||||
|
- Auto-avatar images (backend-gated — see backend session workplan)
|
||||||
|
- Embed tab (needs backend investigation first — see backend session workplan)
|
||||||
|
- Any backend Python changes
|
||||||
|
- M015 S04/S05 leftovers (trending searches block, admin dropdown hover) — should be completed by M015's own GSD session first
|
||||||
160
.planning/backend-perf-creator-features.md
Normal file
160
.planning/backend-perf-creator-features.md
Normal file
|
|
@ -0,0 +1,160 @@
|
||||||
|
# Backend Performance & Creator Features — Session Workplan
|
||||||
|
|
||||||
|
> **Stream:** Backend — run in a separate Claude Code session while GSD executes M016
|
||||||
|
> **Conflict zone:** `backend/` only — zero frontend changes
|
||||||
|
> **Deploy cadence:** commit-build-redeploy after each task group
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Goal
|
||||||
|
|
||||||
|
Fix critical performance bottlenecks in the admin pipeline API, implement auto-avatar fetching for creators, and lay the backend groundwork for creator landing page improvements. Purely backend Python — no frontend changes (the creators 422 fix moved to M016 S02).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Task Groups (execute sequentially)
|
||||||
|
|
||||||
|
### 1. Critical Fixes (do first, immediate impact)
|
||||||
|
|
||||||
|
#### 1a. Fix `worker-status` async event loop blocking
|
||||||
|
**File:** `backend/routers/pipeline.py` lines 1266-1313
|
||||||
|
**Problem:** The endpoint is `async def` but calls three synchronous Celery inspect methods (`inspector.active()`, `.reserved()`, `.stats()`), each with a 1-second timeout. This blocks the entire uvicorn event loop for ~3 seconds, stalling ALL concurrent API requests.
|
||||||
|
**Evidence:** Every parallel API call during page load takes ~3,024ms instead of their natural 6-38ms.
|
||||||
|
**Fix:**
|
||||||
|
```python
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
active = await asyncio.to_thread(inspector.active) or {}
|
||||||
|
reserved = await asyncio.to_thread(inspector.reserved) or {}
|
||||||
|
stats = await asyncio.to_thread(inspector.stats) or {}
|
||||||
|
```
|
||||||
|
Also consider: reduce inspect timeout to 0.5s, add Redis cache with 10-15s TTL to avoid repeated slow calls.
|
||||||
|
**Impact:** Page load drops from ~3s to ~50ms.
|
||||||
|
|
||||||
|
#### ~~1b. Fix creators endpoint 422 error~~ → Moved to M016 S02
|
||||||
|
The frontend `fetchCreators({ limit: 200 })` fix (AdminPipeline.tsx line 1126) is now part of M016's pipeline UI fixes slice, since that slice already owns AdminPipeline.tsx.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 2. Pipeline API Performance
|
||||||
|
|
||||||
|
#### 2a. Rewrite `stale-pages` to eliminate N+1 queries
|
||||||
|
**File:** `backend/routers/pipeline.py` lines 906-973
|
||||||
|
**Problem:** Loads ALL technique pages, then runs a separate query per page for latest version + another for creator name. Currently 44 extra queries for 22 pages. Fast today (~30ms) but degrades linearly.
|
||||||
|
**Fix:** Single query joining `technique_pages` with a lateral/window subquery for latest version + join to creators.
|
||||||
|
|
||||||
|
#### 2b. Add pagination to videos endpoint
|
||||||
|
**File:** `backend/routers/pipeline.py` lines 72-188
|
||||||
|
**Problem:** Returns all 43 videos (23KB) with no offset/limit. Client-side filtering only.
|
||||||
|
**Fix:** Add `offset`, `limit`, `status`, `creator_id` query params. Return paginated response with `total` count. Frontend can adopt server-side filtering later (or the M016 frontend stream can wire it up).
|
||||||
|
|
||||||
|
#### 2c. Optimize `_find_dynamic_related` for technique pages
|
||||||
|
**File:** `backend/routers/techniques.py` lines 33-111
|
||||||
|
**Problem:** Loads ALL technique pages into memory to score relatedness in Python. O(n) in total page count.
|
||||||
|
**Fix:** Move scoring to SQL (keyword overlap via `ts_rank` or simple tag intersection) or cache related links per technique page with invalidation on new page creation.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 3. Auto-Avatar Integration (TheAudioDB)
|
||||||
|
|
||||||
|
Research concluded: **TheAudioDB is the best first source** — free, no OAuth, no caching restrictions, decent coverage for established artists.
|
||||||
|
|
||||||
|
#### 3a. Database migration
|
||||||
|
Add to `Creator` model:
|
||||||
|
- `avatar_url: String | None` — stored image URL or local path
|
||||||
|
- `avatar_source: String` — enum: `"generated"`, `"theaudiodb"`, `"manual"`
|
||||||
|
- `avatar_fetched_at: DateTime | None` — for cache invalidation
|
||||||
|
|
||||||
|
Alembic migration (will be 014 or later depending on M015 state).
|
||||||
|
|
||||||
|
#### 3b. TheAudioDB lookup service
|
||||||
|
New file: `backend/services/avatar.py`
|
||||||
|
- `async def fetch_avatar(creator_name: str, creator_genres: list[str]) -> AvatarResult | None`
|
||||||
|
- Calls `https://www.theaudiodb.com/api/v1/json/{key}/search.php?s={name}`
|
||||||
|
- Confidence scoring: name match via `thefuzz.fuzz.token_sort_ratio` (threshold ≥ 85%), genre overlap as tiebreaker
|
||||||
|
- Returns `strArtistThumb` URL if confident match, None otherwise
|
||||||
|
- Handle: no results, multiple results, missing image fields
|
||||||
|
|
||||||
|
#### 3c. Celery worker task
|
||||||
|
New task: `tasks.fetch_creator_avatar`
|
||||||
|
- Called on creator creation or manually via admin endpoint
|
||||||
|
- Runs TheAudioDB lookup → downloads image → stores locally (or stores URL)
|
||||||
|
- Updates `Creator.avatar_url`, `avatar_source`, `avatar_fetched_at`
|
||||||
|
- Falls back gracefully — if no match, leaves fields null (frontend already renders generated SVG as fallback)
|
||||||
|
|
||||||
|
#### 3d. Admin endpoint for manual trigger
|
||||||
|
`POST /admin/pipeline/creators/{id}/fetch-avatar` — triggers the worker task for a specific creator.
|
||||||
|
`POST /admin/pipeline/creators/fetch-all-avatars` — batch trigger for all creators missing avatars.
|
||||||
|
|
||||||
|
#### 3e. Wire avatar_url into creators API responses
|
||||||
|
Add `avatar_url` to `CreatorBrowseItem` and `CreatorDetail` schemas. The frontend `CreatorAvatar` component already accepts an `imageUrl` prop — it will just work once the API returns it.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 4. Creator Landing Page API Groundwork
|
||||||
|
|
||||||
|
#### 4a. Social links model + migration
|
||||||
|
Add to `Creator` model:
|
||||||
|
- `social_links: JSON | None` — structured as `{"spotify": "url", "instagram": "url", "bandcamp": "url", "website": "url", ...}`
|
||||||
|
- `bio: Text | None` — short creator bio/description
|
||||||
|
- `featured: Boolean` — flag for homepage featuring
|
||||||
|
|
||||||
|
#### 4b. Creator detail endpoint enhancement
|
||||||
|
Expand `GET /api/v1/creators/{slug}` to return:
|
||||||
|
- `social_links`
|
||||||
|
- `bio`
|
||||||
|
- `avatar_url`
|
||||||
|
- `technique_count`
|
||||||
|
- Full technique list with titles, slugs, created_at
|
||||||
|
- Genre breakdown
|
||||||
|
|
||||||
|
#### 4c. Admin endpoint for creator profile editing
|
||||||
|
`PUT /admin/pipeline/creators/{id}` — update `bio`, `social_links`, `featured` flag, manually set `avatar_url`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 5. Embed Tab Investigation
|
||||||
|
The "Embed" tab under pipeline jobs is non-functional. Before building, need to:
|
||||||
|
- Read the existing frontend component to understand what it expects
|
||||||
|
- Determine what "Embed" should show (embedding vectors? embed codes? embedded content?)
|
||||||
|
- If it's about Qdrant vector embeddings: add an endpoint to query embedding status per technique page
|
||||||
|
- If it's about iframe embed codes: generate shareable snippet per technique
|
||||||
|
|
||||||
|
**This task starts as investigation — scope will be defined after reading the code.**
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## General Load Time Optimization (apply throughout)
|
||||||
|
|
||||||
|
As each endpoint is touched, also consider:
|
||||||
|
- Add `Cache-Control` headers for public GET endpoints (technique pages, creators, search suggestions)
|
||||||
|
- Add Redis caching (30s-5min TTL) for expensive or frequently-hit endpoints
|
||||||
|
- Ensure database indexes exist on commonly filtered/sorted columns
|
||||||
|
- Consider adding `select_in_loading` for SQLAlchemy relationships to avoid implicit lazy loads
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Key Files
|
||||||
|
|
||||||
|
| File | What changes |
|
||||||
|
|------|-------------|
|
||||||
|
| `backend/routers/pipeline.py` | worker-status async fix, stale-pages rewrite, videos pagination, avatar admin endpoints |
|
||||||
|
| `backend/routers/creators.py` | Creator detail expansion, social links, admin editing |
|
||||||
|
| `backend/routers/techniques.py` | Related techniques optimization |
|
||||||
|
| `backend/models.py` | Creator model additions (avatar, social_links, bio) |
|
||||||
|
| `backend/schemas.py` | New response schemas |
|
||||||
|
| `backend/services/avatar.py` | New — TheAudioDB integration |
|
||||||
|
| `backend/tasks.py` | New avatar fetch task |
|
||||||
|
| `alembic/versions/014_*.py` | Migration for creator columns |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Merge Coordination with M016
|
||||||
|
|
||||||
|
These two streams have **zero file overlap:**
|
||||||
|
- **M016 touches:** `frontend/src/` only — `App.css`, `Home.tsx`, `TechniquePage.tsx`, `TableOfContents.tsx`, `AdminPipeline.tsx`, new frontend components, static assets
|
||||||
|
- **This session touches:** `backend/` only — routers, models, schemas, services, tasks, `alembic/`
|
||||||
|
|
||||||
|
No merge conflicts expected. The creators 422 fix (the former single overlap point) now lives in M016 S02.
|
||||||
|
|
||||||
|
For avatar/social-links frontend wiring: this session ships the API, M016 (or a follow-up) consumes it. No conflict — just sequencing.
|
||||||
48
CLAUDE.md
Normal file
48
CLAUDE.md
Normal file
|
|
@ -0,0 +1,48 @@
|
||||||
|
# Chrysopedia — Development Reference
|
||||||
|
|
||||||
|
## ⚠️ Canonical Development Directory
|
||||||
|
|
||||||
|
**This is NOT the canonical development directory.**
|
||||||
|
|
||||||
|
The production codebase and all future development happens on **ub01**:
|
||||||
|
|
||||||
|
```
|
||||||
|
ssh ub01
|
||||||
|
cd /vmPool/r/repos/xpltdco/chrysopedia
|
||||||
|
```
|
||||||
|
|
||||||
|
**Git:** https://git.xpltd.co/xpltdco/chrysopedia (Forgejo, xpltdco org)
|
||||||
|
|
||||||
|
## Why?
|
||||||
|
|
||||||
|
The Docker Compose stack runs on ub01 with bind mounts at `/vmPool/r/services/chrysopedia_*`. Development, deployment, and testing all happen from the ub01 clone. This directory (`/home/aux/projects/content-to-kb-automator`) was the initial workspace used during M001 development and should not be used for future work.
|
||||||
|
|
||||||
|
## Stack Info
|
||||||
|
|
||||||
|
- **Web UI:** http://ub01:8096
|
||||||
|
- **API Health:** http://ub01:8096/health
|
||||||
|
- **PostgreSQL:** ub01:5433 (user: chrysopedia)
|
||||||
|
- **Compose project:** xpltd_chrysopedia
|
||||||
|
- **Compose path:** /vmPool/r/compose/xpltd_chrysopedia/docker-compose.yml (symlink to repo)
|
||||||
|
- **Services:** chrysopedia-db, chrysopedia-redis, chrysopedia-qdrant, chrysopedia-ollama, chrysopedia-api, chrysopedia-worker, chrysopedia-web-8096
|
||||||
|
|
||||||
|
## Quick Commands (on ub01)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Check status
|
||||||
|
docker ps --filter name=chrysopedia
|
||||||
|
|
||||||
|
# Rebuild and restart after code changes
|
||||||
|
cd /vmPool/r/repos/xpltdco/chrysopedia
|
||||||
|
git pull
|
||||||
|
docker compose build && docker compose up -d
|
||||||
|
|
||||||
|
# Run Alembic migrations
|
||||||
|
docker exec chrysopedia-api alembic upgrade head
|
||||||
|
|
||||||
|
# View worker logs
|
||||||
|
docker logs -f chrysopedia-worker
|
||||||
|
|
||||||
|
# View API logs
|
||||||
|
docker logs -f chrysopedia-api
|
||||||
|
```
|
||||||
320
README.md
Normal file
320
README.md
Normal file
|
|
@ -0,0 +1,320 @@
|
||||||
|
# Chrysopedia
|
||||||
|
|
||||||
|
> From *chrysopoeia* (alchemical transmutation of base material into gold) + *encyclopedia*.
|
||||||
|
> Chrysopedia transmutes raw video content into refined, searchable production knowledge.
|
||||||
|
|
||||||
|
A self-hosted knowledge extraction system for electronic music production content. Video libraries are transcribed with Whisper, analyzed through a multi-stage LLM pipeline, curated via an admin review workflow, and served through a search-first web UI designed for mid-session retrieval.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Information Flow
|
||||||
|
|
||||||
|
Content moves through six stages from raw video to searchable knowledge:
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────────────────────────────────────────────────────────────┐
|
||||||
|
│ STAGE 1 · Transcription [Desktop / GPU] │
|
||||||
|
│ │
|
||||||
|
│ Video files → Whisper large-v3 (CUDA) → JSON transcripts │
|
||||||
|
│ Output: timestamped segments with speaker text │
|
||||||
|
└────────────────────────────────┬────────────────────────────────────────┘
|
||||||
|
│ JSON files (manual or folder watcher)
|
||||||
|
▼
|
||||||
|
┌─────────────────────────────────────────────────────────────────────────┐
|
||||||
|
│ STAGE 2 · Ingestion [API + Watcher] │
|
||||||
|
│ │
|
||||||
|
│ POST /api/v1/ingest ← watcher auto-submits from /watch folder │
|
||||||
|
│ • Validate JSON structure │
|
||||||
|
│ • Compute content hash (SHA-256) for deduplication │
|
||||||
|
│ • Find-or-create Creator from folder name │
|
||||||
|
│ • Upsert SourceVideo (exact filename → content hash → fuzzy match) │
|
||||||
|
│ • Bulk-insert TranscriptSegment rows │
|
||||||
|
│ • Dispatch pipeline to Celery worker │
|
||||||
|
└────────────────────────────────┬────────────────────────────────────────┘
|
||||||
|
│ Celery task: run_pipeline(video_id)
|
||||||
|
▼
|
||||||
|
┌─────────────────────────────────────────────────────────────────────────┐
|
||||||
|
│ STAGE 3 · LLM Extraction Pipeline [Celery Worker] │
|
||||||
|
│ │
|
||||||
|
│ Four sequential LLM stages, each with its own prompt template: │
|
||||||
|
│ │
|
||||||
|
│ 3a. Segmentation — Split transcript into semantic topic boundaries │
|
||||||
|
│ Model: chat (fast) Prompt: stage2_segmentation.txt │
|
||||||
|
│ │
|
||||||
|
│ 3b. Extraction — Identify key moments (title, summary, timestamps) │
|
||||||
|
│ Model: reasoning (think) Prompt: stage3_extraction.txt │
|
||||||
|
│ │
|
||||||
|
│ 3c. Classification — Assign content types + extract plugin names │
|
||||||
|
│ Model: chat (fast) Prompt: stage4_classification.txt │
|
||||||
|
│ │
|
||||||
|
│ 3d. Synthesis — Compose technique pages from approved moments │
|
||||||
|
│ Model: reasoning (think) Prompt: stage5_synthesis.txt │
|
||||||
|
│ │
|
||||||
|
│ Each stage emits PipelineEvent rows (tokens, duration, model, errors) │
|
||||||
|
└────────────────────────────────┬────────────────────────────────────────┘
|
||||||
|
│ KeyMoment rows (review_status: pending)
|
||||||
|
▼
|
||||||
|
┌─────────────────────────────────────────────────────────────────────────┐
|
||||||
|
│ STAGE 4 · Review & Curation [Admin UI] │
|
||||||
|
│ │
|
||||||
|
│ Admin reviews extracted KeyMoments before they become technique pages: │
|
||||||
|
│ • Approve — moment proceeds to synthesis │
|
||||||
|
│ • Edit — correct title, summary, content type, plugins, then approve │
|
||||||
|
│ • Reject — moment is excluded from knowledge base │
|
||||||
|
│ (When REVIEW_MODE=false, moments auto-approve and skip this stage) │
|
||||||
|
└────────────────────────────────┬────────────────────────────────────────┘
|
||||||
|
│ Approved moments → Stage 3d synthesis
|
||||||
|
▼
|
||||||
|
┌─────────────────────────────────────────────────────────────────────────┐
|
||||||
|
│ STAGE 5 · Knowledge Base [Web UI] │
|
||||||
|
│ │
|
||||||
|
│ TechniquePages — the primary output: │
|
||||||
|
│ • Structured body sections, signal chains, plugin lists │
|
||||||
|
│ • Linked to source KeyMoments with video timestamps │
|
||||||
|
│ • Cross-referenced via RelatedTechniqueLinks │
|
||||||
|
│ • Versioned (snapshots before each re-synthesis) │
|
||||||
|
│ • Organized by topic taxonomy (6 categories from canonical_tags.yaml) │
|
||||||
|
└────────────────────────────────┬────────────────────────────────────────┘
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌─────────────────────────────────────────────────────────────────────────┐
|
||||||
|
│ STAGE 6 · Search & Retrieval [Web UI] │
|
||||||
|
│ │
|
||||||
|
│ • Semantic search: query → embedding → Qdrant vector similarity │
|
||||||
|
│ • Keyword fallback: ILIKE search on title/summary (300ms timeout) │
|
||||||
|
│ • Browse by topic hierarchy, creator, or content type │
|
||||||
|
│ • Typeahead search from home page (debounced, top 5 results) │
|
||||||
|
└─────────────────────────────────────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────────────────────────────────────────────────────────────────────┐
|
||||||
|
│ Desktop (GPU workstation — hal0022) │
|
||||||
|
│ whisper/transcribe.py → JSON transcripts → copy to /watch folder │
|
||||||
|
└────────────────────────────┬─────────────────────────────────────────────┘
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌──────────────────────────────────────────────────────────────────────────┐
|
||||||
|
│ Docker Compose: xpltd_chrysopedia (ub01) │
|
||||||
|
│ Network: chrysopedia (172.32.0.0/24) │
|
||||||
|
│ │
|
||||||
|
│ ┌────────────┐ ┌─────────────┐ ┌───────────────┐ ┌──────────────┐ │
|
||||||
|
│ │ PostgreSQL │ │ Redis │ │ Qdrant │ │ Ollama │ │
|
||||||
|
│ │ :5433 │ │ broker + │ │ vector DB │ │ embeddings │ │
|
||||||
|
│ │ 7 entities │ │ cache │ │ semantic │ │ nomic-embed │ │
|
||||||
|
│ └─────┬───────┘ └──────┬──────┘ └───────┬───────┘ └──────┬───────┘ │
|
||||||
|
│ │ │ │ │ │
|
||||||
|
│ ┌─────┴─────────────────┴─────────────────┴─────────────────┴────────┐ │
|
||||||
|
│ │ FastAPI (API) │ │
|
||||||
|
│ │ Ingest · Pipeline control · Review · Search · CRUD · Reports │ │
|
||||||
|
│ └──────────────────────────────┬────────────────────────────────────┘ │
|
||||||
|
│ │ │
|
||||||
|
│ ┌──────────────┐ ┌────────────┴───┐ ┌──────────────────────────┐ │
|
||||||
|
│ │ Watcher │ │ Celery Worker │ │ Web UI (React) │ │
|
||||||
|
│ │ /watch → │ │ LLM pipeline │ │ nginx → :8096 │ │
|
||||||
|
│ │ auto-ingest │ │ stages 2-5 │ │ search-first interface │ │
|
||||||
|
│ └──────────────┘ └────────────────┘ └──────────────────────────┘ │
|
||||||
|
└──────────────────────────────────────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
### Services
|
||||||
|
|
||||||
|
| Service | Image | Port | Purpose |
|
||||||
|
|---------|-------|------|---------|
|
||||||
|
| `chrysopedia-db` | `postgres:16-alpine` | `5433 → 5432` | Primary data store |
|
||||||
|
| `chrysopedia-redis` | `redis:7-alpine` | — | Celery broker + feature flag cache |
|
||||||
|
| `chrysopedia-qdrant` | `qdrant/qdrant:v1.13.2` | — | Vector DB for semantic search |
|
||||||
|
| `chrysopedia-ollama` | `ollama/ollama` | — | Embedding model server (nomic-embed-text) |
|
||||||
|
| `chrysopedia-api` | `Dockerfile.api` | `8000` | FastAPI REST API |
|
||||||
|
| `chrysopedia-worker` | `Dockerfile.api` | — | Celery worker (LLM pipeline) |
|
||||||
|
| `chrysopedia-watcher` | `Dockerfile.api` | — | Folder monitor → auto-ingest |
|
||||||
|
| `chrysopedia-web` | `Dockerfile.web` | `8096 → 80` | React frontend (nginx) |
|
||||||
|
|
||||||
|
### Data Model
|
||||||
|
|
||||||
|
| Entity | Purpose |
|
||||||
|
|--------|---------|
|
||||||
|
| **Creator** | Artists/producers whose content is indexed |
|
||||||
|
| **SourceVideo** | Video files processed by the pipeline (with content hash dedup) |
|
||||||
|
| **TranscriptSegment** | Timestamped text segments from Whisper |
|
||||||
|
| **KeyMoment** | Discrete insights extracted by LLM analysis |
|
||||||
|
| **TechniquePage** | Synthesized knowledge pages — the primary output |
|
||||||
|
| **TechniquePageVersion** | Snapshots before re-synthesis overwrites |
|
||||||
|
| **RelatedTechniqueLink** | Cross-references between technique pages |
|
||||||
|
| **Tag** | Hierarchical topic taxonomy |
|
||||||
|
| **ContentReport** | User-submitted content issues |
|
||||||
|
| **PipelineEvent** | Structured pipeline execution logs (tokens, timing, errors) |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
- Docker ≥ 24.0 and Docker Compose ≥ 2.20
|
||||||
|
- Python 3.10+ with NVIDIA GPU + CUDA (for Whisper transcription)
|
||||||
|
|
||||||
|
### Setup
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Clone and configure
|
||||||
|
git clone git@github.com:xpltdco/chrysopedia.git
|
||||||
|
cd chrysopedia
|
||||||
|
cp .env.example .env # edit with real values
|
||||||
|
|
||||||
|
# Start the stack
|
||||||
|
docker compose up -d
|
||||||
|
|
||||||
|
# Run database migrations
|
||||||
|
docker exec chrysopedia-api alembic upgrade head
|
||||||
|
|
||||||
|
# Pull the embedding model (first time only)
|
||||||
|
docker exec chrysopedia-ollama ollama pull nomic-embed-text
|
||||||
|
|
||||||
|
# Verify
|
||||||
|
curl http://localhost:8096/health
|
||||||
|
```
|
||||||
|
|
||||||
|
### Transcribe videos
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd whisper && pip install -r requirements.txt
|
||||||
|
|
||||||
|
# Single file
|
||||||
|
python transcribe.py --input "path/to/video.mp4" --output-dir ./transcripts
|
||||||
|
|
||||||
|
# Batch
|
||||||
|
python transcribe.py --input ./videos/ --output-dir ./transcripts
|
||||||
|
```
|
||||||
|
|
||||||
|
See [`whisper/README.md`](whisper/README.md) for full transcription docs.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Environment Variables
|
||||||
|
|
||||||
|
Copy `.env.example` to `.env`. Key groups:
|
||||||
|
|
||||||
|
| Group | Variables | Notes |
|
||||||
|
|-------|-----------|-------|
|
||||||
|
| **Database** | `POSTGRES_USER`, `POSTGRES_PASSWORD`, `POSTGRES_DB` | Default user: `chrysopedia` |
|
||||||
|
| **LLM** | `LLM_API_URL`, `LLM_API_KEY`, `LLM_MODEL` | OpenAI-compatible endpoint |
|
||||||
|
| **LLM Fallback** | `LLM_FALLBACK_URL`, `LLM_FALLBACK_MODEL` | Automatic failover |
|
||||||
|
| **Per-Stage Models** | `LLM_STAGE{2-5}_MODEL`, `LLM_STAGE{2-5}_MODALITY` | `chat` for fast stages, `thinking` for reasoning |
|
||||||
|
| **Embedding** | `EMBEDDING_API_URL`, `EMBEDDING_MODEL` | Ollama nomic-embed-text |
|
||||||
|
| **Vector DB** | `QDRANT_URL`, `QDRANT_COLLECTION` | Container-internal |
|
||||||
|
| **Features** | `REVIEW_MODE`, `DEBUG_MODE` | Review gate + LLM I/O capture |
|
||||||
|
| **Storage** | `TRANSCRIPT_STORAGE_PATH`, `VIDEO_METADATA_PATH` | Container bind mounts |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## API Endpoints
|
||||||
|
|
||||||
|
### Public
|
||||||
|
|
||||||
|
| Method | Path | Description |
|
||||||
|
|--------|------|-------------|
|
||||||
|
| GET | `/health` | Health check (DB connectivity) |
|
||||||
|
| GET | `/api/v1/search?q=&scope=&limit=` | Semantic + keyword search |
|
||||||
|
| GET | `/api/v1/techniques` | List technique pages |
|
||||||
|
| GET | `/api/v1/techniques/{slug}` | Technique detail + key moments |
|
||||||
|
| GET | `/api/v1/techniques/{slug}/versions` | Version history |
|
||||||
|
| GET | `/api/v1/creators` | List creators (sort, genre filter) |
|
||||||
|
| GET | `/api/v1/creators/{slug}` | Creator detail |
|
||||||
|
| GET | `/api/v1/topics` | Topic hierarchy with counts |
|
||||||
|
| GET | `/api/v1/videos` | List source videos |
|
||||||
|
| POST | `/api/v1/reports` | Submit content report |
|
||||||
|
|
||||||
|
### Admin
|
||||||
|
|
||||||
|
| Method | Path | Description |
|
||||||
|
|--------|------|-------------|
|
||||||
|
| GET | `/api/v1/review/queue` | Review queue (status filter) |
|
||||||
|
| POST | `/api/v1/review/moments/{id}/approve` | Approve key moment |
|
||||||
|
| POST | `/api/v1/review/moments/{id}/reject` | Reject key moment |
|
||||||
|
| PUT | `/api/v1/review/moments/{id}` | Edit key moment |
|
||||||
|
| POST | `/api/v1/admin/pipeline/trigger/{video_id}` | Trigger/retrigger pipeline |
|
||||||
|
| GET | `/api/v1/admin/pipeline/events/{video_id}` | Pipeline event log |
|
||||||
|
| GET | `/api/v1/admin/pipeline/token-summary/{video_id}` | Token usage by stage |
|
||||||
|
| GET | `/api/v1/admin/pipeline/worker-status` | Celery worker status |
|
||||||
|
| PUT | `/api/v1/admin/pipeline/debug-mode` | Toggle debug mode |
|
||||||
|
|
||||||
|
### Ingest
|
||||||
|
|
||||||
|
| Method | Path | Description |
|
||||||
|
|--------|------|-------------|
|
||||||
|
| POST | `/api/v1/ingest` | Upload Whisper JSON transcript |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Development
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Local backend (with Docker services)
|
||||||
|
python -m venv .venv && source .venv/bin/activate
|
||||||
|
pip install -r backend/requirements.txt
|
||||||
|
docker compose up -d chrysopedia-db chrysopedia-redis
|
||||||
|
alembic upgrade head
|
||||||
|
cd backend && uvicorn main:app --reload --host 0.0.0.0 --port 8000
|
||||||
|
|
||||||
|
# Database migrations
|
||||||
|
alembic revision --autogenerate -m "describe_change"
|
||||||
|
alembic upgrade head
|
||||||
|
```
|
||||||
|
|
||||||
|
### Project Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
chrysopedia/
|
||||||
|
├── backend/ # FastAPI application
|
||||||
|
│ ├── main.py # Entry point, middleware, router mounting
|
||||||
|
│ ├── config.py # Pydantic Settings (all env vars)
|
||||||
|
│ ├── models.py # SQLAlchemy ORM models
|
||||||
|
│ ├── schemas.py # Pydantic request/response schemas
|
||||||
|
│ ├── worker.py # Celery app configuration
|
||||||
|
│ ├── watcher.py # Transcript folder watcher service
|
||||||
|
│ ├── search_service.py # Semantic search + keyword fallback
|
||||||
|
│ ├── routers/ # API endpoint handlers
|
||||||
|
│ ├── pipeline/ # LLM pipeline stages + clients
|
||||||
|
│ │ ├── stages.py # Stages 2-5 (Celery tasks)
|
||||||
|
│ │ ├── llm_client.py # OpenAI-compatible LLM client
|
||||||
|
│ │ ├── embedding_client.py
|
||||||
|
│ │ └── qdrant_client.py
|
||||||
|
│ └── tests/
|
||||||
|
├── frontend/ # React + TypeScript + Vite
|
||||||
|
│ └── src/
|
||||||
|
│ ├── pages/ # Home, Search, Technique, Creator, Topic, Admin
|
||||||
|
│ ├── components/ # Shared UI components
|
||||||
|
│ └── api/ # Typed API clients
|
||||||
|
├── whisper/ # Desktop transcription (Whisper large-v3)
|
||||||
|
├── docker/ # Dockerfiles + nginx config
|
||||||
|
├── alembic/ # Database migrations
|
||||||
|
├── config/ # canonical_tags.yaml (topic taxonomy)
|
||||||
|
├── prompts/ # LLM prompt templates (editable at runtime)
|
||||||
|
├── docker-compose.yml
|
||||||
|
└── .env.example
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Deployment (ub01)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ssh ub01
|
||||||
|
cd /vmPool/r/repos/xpltdco/chrysopedia
|
||||||
|
git pull && docker compose build && docker compose up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
| Resource | Location |
|
||||||
|
|----------|----------|
|
||||||
|
| Web UI | `http://ub01:8096` |
|
||||||
|
| API | `http://ub01:8096/health` |
|
||||||
|
| PostgreSQL | `ub01:5433` |
|
||||||
|
| Compose config | `/vmPool/r/compose/xpltd_chrysopedia/docker-compose.yml` |
|
||||||
|
| Persistent data | `/vmPool/r/services/chrysopedia_*` |
|
||||||
|
|
||||||
|
XPLTD conventions: `xpltd_chrysopedia` project name, dedicated bridge network (`172.32.0.0/24`), bind mounts under `/vmPool/r/services/`, PostgreSQL on port `5433`.
|
||||||
37
alembic.ini
Normal file
37
alembic.ini
Normal file
|
|
@ -0,0 +1,37 @@
|
||||||
|
# Chrysopedia — Alembic configuration
|
||||||
|
[alembic]
|
||||||
|
script_location = alembic
|
||||||
|
sqlalchemy.url = postgresql+asyncpg://chrysopedia:changeme@localhost:5433/chrysopedia
|
||||||
|
|
||||||
|
[loggers]
|
||||||
|
keys = root,sqlalchemy,alembic
|
||||||
|
|
||||||
|
[handlers]
|
||||||
|
keys = console
|
||||||
|
|
||||||
|
[formatters]
|
||||||
|
keys = generic
|
||||||
|
|
||||||
|
[logger_root]
|
||||||
|
level = WARN
|
||||||
|
handlers = console
|
||||||
|
|
||||||
|
[logger_sqlalchemy]
|
||||||
|
level = WARN
|
||||||
|
handlers =
|
||||||
|
qualname = sqlalchemy.engine
|
||||||
|
|
||||||
|
[logger_alembic]
|
||||||
|
level = INFO
|
||||||
|
handlers =
|
||||||
|
qualname = alembic
|
||||||
|
|
||||||
|
[handler_console]
|
||||||
|
class = StreamHandler
|
||||||
|
args = (sys.stderr,)
|
||||||
|
level = NOTSET
|
||||||
|
formatter = generic
|
||||||
|
|
||||||
|
[formatter_generic]
|
||||||
|
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||||
|
datefmt = %H:%M:%S
|
||||||
72
alembic/env.py
Normal file
72
alembic/env.py
Normal file
|
|
@ -0,0 +1,72 @@
|
||||||
|
"""Alembic env.py — async migration runner for Chrysopedia."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from logging.config import fileConfig
|
||||||
|
|
||||||
|
from alembic import context
|
||||||
|
from sqlalchemy import pool
|
||||||
|
from sqlalchemy.ext.asyncio import async_engine_from_config
|
||||||
|
|
||||||
|
# Ensure the backend package is importable
|
||||||
|
# When running locally: alembic/ sits beside backend/, so ../backend works
|
||||||
|
# When running in Docker: alembic/ is inside /app/ alongside the backend modules
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "backend"))
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||||
|
|
||||||
|
from database import Base # noqa: E402
|
||||||
|
import models # noqa: E402, F401 — registers all tables on Base.metadata
|
||||||
|
|
||||||
|
config = context.config
|
||||||
|
|
||||||
|
if config.config_file_name is not None:
|
||||||
|
fileConfig(config.config_file_name)
|
||||||
|
|
||||||
|
target_metadata = Base.metadata
|
||||||
|
|
||||||
|
# Allow DATABASE_URL env var to override alembic.ini
|
||||||
|
url_override = os.getenv("DATABASE_URL")
|
||||||
|
if url_override:
|
||||||
|
config.set_main_option("sqlalchemy.url", url_override)
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_offline() -> None:
|
||||||
|
"""Run migrations in 'offline' mode — emit SQL to stdout."""
|
||||||
|
url = config.get_main_option("sqlalchemy.url")
|
||||||
|
context.configure(
|
||||||
|
url=url,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
literal_binds=True,
|
||||||
|
dialect_opts={"paramstyle": "named"},
|
||||||
|
)
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
def do_run_migrations(connection):
|
||||||
|
context.configure(connection=connection, target_metadata=target_metadata)
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
async def run_async_migrations() -> None:
|
||||||
|
"""Run migrations in 'online' mode with an async engine."""
|
||||||
|
connectable = async_engine_from_config(
|
||||||
|
config.get_section(config.config_ini_section, {}),
|
||||||
|
prefix="sqlalchemy.",
|
||||||
|
poolclass=pool.NullPool,
|
||||||
|
)
|
||||||
|
async with connectable.connect() as connection:
|
||||||
|
await connection.run_sync(do_run_migrations)
|
||||||
|
await connectable.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_online() -> None:
|
||||||
|
asyncio.run(run_async_migrations())
|
||||||
|
|
||||||
|
|
||||||
|
if context.is_offline_mode():
|
||||||
|
run_migrations_offline()
|
||||||
|
else:
|
||||||
|
run_migrations_online()
|
||||||
25
alembic/script.py.mako
Normal file
25
alembic/script.py.mako
Normal file
|
|
@ -0,0 +1,25 @@
|
||||||
|
"""${message}
|
||||||
|
|
||||||
|
Revision ID: ${up_revision}
|
||||||
|
Revises: ${down_revision | comma,n}
|
||||||
|
Create Date: ${create_date}
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
${imports if imports else ""}
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = ${repr(up_revision)}
|
||||||
|
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||||
|
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
${upgrades if upgrades else "pass"}
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
${downgrades if downgrades else "pass"}
|
||||||
171
alembic/versions/001_initial.py
Normal file
171
alembic/versions/001_initial.py
Normal file
|
|
@ -0,0 +1,171 @@
|
||||||
|
"""initial schema — 7 core entities
|
||||||
|
|
||||||
|
Revision ID: 001_initial
|
||||||
|
Revises:
|
||||||
|
Create Date: 2026-03-29
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects.postgresql import ARRAY, JSONB, UUID
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "001_initial"
|
||||||
|
down_revision: Union[str, None] = None
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ── Enum types ───────────────────────────────────────────────────────
|
||||||
|
content_type = sa.Enum(
|
||||||
|
"tutorial", "livestream", "breakdown", "short_form",
|
||||||
|
name="content_type",
|
||||||
|
)
|
||||||
|
processing_status = sa.Enum(
|
||||||
|
"pending", "transcribed", "extracted", "reviewed", "published",
|
||||||
|
name="processing_status",
|
||||||
|
)
|
||||||
|
key_moment_content_type = sa.Enum(
|
||||||
|
"technique", "settings", "reasoning", "workflow",
|
||||||
|
name="key_moment_content_type",
|
||||||
|
)
|
||||||
|
review_status = sa.Enum(
|
||||||
|
"pending", "approved", "edited", "rejected",
|
||||||
|
name="review_status",
|
||||||
|
)
|
||||||
|
source_quality = sa.Enum(
|
||||||
|
"structured", "mixed", "unstructured",
|
||||||
|
name="source_quality",
|
||||||
|
)
|
||||||
|
page_review_status = sa.Enum(
|
||||||
|
"draft", "reviewed", "published",
|
||||||
|
name="page_review_status",
|
||||||
|
)
|
||||||
|
relationship_type = sa.Enum(
|
||||||
|
"same_technique_other_creator", "same_creator_adjacent", "general_cross_reference",
|
||||||
|
name="relationship_type",
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── creators ─────────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"creators",
|
||||||
|
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("slug", sa.String(255), nullable=False, unique=True),
|
||||||
|
sa.Column("genres", ARRAY(sa.String), nullable=True),
|
||||||
|
sa.Column("folder_name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("view_count", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("created_at", sa.DateTime(), nullable=False, server_default=sa.func.now()),
|
||||||
|
sa.Column("updated_at", sa.DateTime(), nullable=False, server_default=sa.func.now()),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── source_videos ────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"source_videos",
|
||||||
|
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
|
||||||
|
sa.Column("creator_id", UUID(as_uuid=True), sa.ForeignKey("creators.id", ondelete="CASCADE"), nullable=False),
|
||||||
|
sa.Column("filename", sa.String(500), nullable=False),
|
||||||
|
sa.Column("file_path", sa.String(1000), nullable=False),
|
||||||
|
sa.Column("duration_seconds", sa.Integer, nullable=True),
|
||||||
|
sa.Column("content_type", content_type, nullable=False),
|
||||||
|
sa.Column("transcript_path", sa.String(1000), nullable=True),
|
||||||
|
sa.Column("processing_status", processing_status, nullable=False, server_default="pending"),
|
||||||
|
sa.Column("created_at", sa.DateTime(), nullable=False, server_default=sa.func.now()),
|
||||||
|
sa.Column("updated_at", sa.DateTime(), nullable=False, server_default=sa.func.now()),
|
||||||
|
)
|
||||||
|
op.create_index("ix_source_videos_creator_id", "source_videos", ["creator_id"])
|
||||||
|
|
||||||
|
# ── transcript_segments ──────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"transcript_segments",
|
||||||
|
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
|
||||||
|
sa.Column("source_video_id", UUID(as_uuid=True), sa.ForeignKey("source_videos.id", ondelete="CASCADE"), nullable=False),
|
||||||
|
sa.Column("start_time", sa.Float, nullable=False),
|
||||||
|
sa.Column("end_time", sa.Float, nullable=False),
|
||||||
|
sa.Column("text", sa.Text, nullable=False),
|
||||||
|
sa.Column("segment_index", sa.Integer, nullable=False),
|
||||||
|
sa.Column("topic_label", sa.String(255), nullable=True),
|
||||||
|
)
|
||||||
|
op.create_index("ix_transcript_segments_video_id", "transcript_segments", ["source_video_id"])
|
||||||
|
|
||||||
|
# ── technique_pages (must come before key_moments due to FK) ─────────
|
||||||
|
op.create_table(
|
||||||
|
"technique_pages",
|
||||||
|
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
|
||||||
|
sa.Column("creator_id", UUID(as_uuid=True), sa.ForeignKey("creators.id", ondelete="CASCADE"), nullable=False),
|
||||||
|
sa.Column("title", sa.String(500), nullable=False),
|
||||||
|
sa.Column("slug", sa.String(500), nullable=False, unique=True),
|
||||||
|
sa.Column("topic_category", sa.String(255), nullable=False),
|
||||||
|
sa.Column("topic_tags", ARRAY(sa.String), nullable=True),
|
||||||
|
sa.Column("summary", sa.Text, nullable=True),
|
||||||
|
sa.Column("body_sections", JSONB, nullable=True),
|
||||||
|
sa.Column("signal_chains", JSONB, nullable=True),
|
||||||
|
sa.Column("plugins", ARRAY(sa.String), nullable=True),
|
||||||
|
sa.Column("source_quality", source_quality, nullable=True),
|
||||||
|
sa.Column("view_count", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("review_status", page_review_status, nullable=False, server_default="draft"),
|
||||||
|
sa.Column("created_at", sa.DateTime(), nullable=False, server_default=sa.func.now()),
|
||||||
|
sa.Column("updated_at", sa.DateTime(), nullable=False, server_default=sa.func.now()),
|
||||||
|
)
|
||||||
|
op.create_index("ix_technique_pages_creator_id", "technique_pages", ["creator_id"])
|
||||||
|
op.create_index("ix_technique_pages_topic_category", "technique_pages", ["topic_category"])
|
||||||
|
|
||||||
|
# ── key_moments ──────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"key_moments",
|
||||||
|
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
|
||||||
|
sa.Column("source_video_id", UUID(as_uuid=True), sa.ForeignKey("source_videos.id", ondelete="CASCADE"), nullable=False),
|
||||||
|
sa.Column("technique_page_id", UUID(as_uuid=True), sa.ForeignKey("technique_pages.id", ondelete="SET NULL"), nullable=True),
|
||||||
|
sa.Column("title", sa.String(500), nullable=False),
|
||||||
|
sa.Column("summary", sa.Text, nullable=False),
|
||||||
|
sa.Column("start_time", sa.Float, nullable=False),
|
||||||
|
sa.Column("end_time", sa.Float, nullable=False),
|
||||||
|
sa.Column("content_type", key_moment_content_type, nullable=False),
|
||||||
|
sa.Column("plugins", ARRAY(sa.String), nullable=True),
|
||||||
|
sa.Column("review_status", review_status, nullable=False, server_default="pending"),
|
||||||
|
sa.Column("raw_transcript", sa.Text, nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(), nullable=False, server_default=sa.func.now()),
|
||||||
|
sa.Column("updated_at", sa.DateTime(), nullable=False, server_default=sa.func.now()),
|
||||||
|
)
|
||||||
|
op.create_index("ix_key_moments_source_video_id", "key_moments", ["source_video_id"])
|
||||||
|
op.create_index("ix_key_moments_technique_page_id", "key_moments", ["technique_page_id"])
|
||||||
|
|
||||||
|
# ── related_technique_links ──────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"related_technique_links",
|
||||||
|
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
|
||||||
|
sa.Column("source_page_id", UUID(as_uuid=True), sa.ForeignKey("technique_pages.id", ondelete="CASCADE"), nullable=False),
|
||||||
|
sa.Column("target_page_id", UUID(as_uuid=True), sa.ForeignKey("technique_pages.id", ondelete="CASCADE"), nullable=False),
|
||||||
|
sa.Column("relationship", relationship_type, nullable=False),
|
||||||
|
sa.UniqueConstraint("source_page_id", "target_page_id", "relationship", name="uq_technique_link"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── tags ─────────────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"tags",
|
||||||
|
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False, unique=True),
|
||||||
|
sa.Column("category", sa.String(255), nullable=False),
|
||||||
|
sa.Column("aliases", ARRAY(sa.String), nullable=True),
|
||||||
|
)
|
||||||
|
op.create_index("ix_tags_category", "tags", ["category"])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_table("tags")
|
||||||
|
op.drop_table("related_technique_links")
|
||||||
|
op.drop_table("key_moments")
|
||||||
|
op.drop_table("technique_pages")
|
||||||
|
op.drop_table("transcript_segments")
|
||||||
|
op.drop_table("source_videos")
|
||||||
|
op.drop_table("creators")
|
||||||
|
|
||||||
|
# Drop enum types
|
||||||
|
for name in [
|
||||||
|
"relationship_type", "page_review_status", "source_quality",
|
||||||
|
"review_status", "key_moment_content_type", "processing_status",
|
||||||
|
"content_type",
|
||||||
|
]:
|
||||||
|
sa.Enum(name=name).drop(op.get_bind(), checkfirst=True)
|
||||||
39
alembic/versions/002_technique_page_versions.py
Normal file
39
alembic/versions/002_technique_page_versions.py
Normal file
|
|
@ -0,0 +1,39 @@
|
||||||
|
"""technique_page_versions table for article versioning
|
||||||
|
|
||||||
|
Revision ID: 002_technique_page_versions
|
||||||
|
Revises: 001_initial
|
||||||
|
Create Date: 2026-03-30
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "002_technique_page_versions"
|
||||||
|
down_revision: Union[str, None] = "001_initial"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"technique_page_versions",
|
||||||
|
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
|
||||||
|
sa.Column("technique_page_id", UUID(as_uuid=True), sa.ForeignKey("technique_pages.id", ondelete="CASCADE"), nullable=False),
|
||||||
|
sa.Column("version_number", sa.Integer, nullable=False),
|
||||||
|
sa.Column("content_snapshot", JSONB, nullable=False),
|
||||||
|
sa.Column("pipeline_metadata", JSONB, nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(), nullable=False, server_default=sa.func.now()),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_technique_page_versions_page_version",
|
||||||
|
"technique_page_versions",
|
||||||
|
["technique_page_id", "version_number"],
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_table("technique_page_versions")
|
||||||
47
alembic/versions/003_content_reports.py
Normal file
47
alembic/versions/003_content_reports.py
Normal file
|
|
@ -0,0 +1,47 @@
|
||||||
|
"""Create content_reports table.
|
||||||
|
|
||||||
|
Revision ID: 003_content_reports
|
||||||
|
Revises: 002_technique_page_versions
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
|
||||||
|
revision = "003_content_reports"
|
||||||
|
down_revision = "002_technique_page_versions"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"content_reports",
|
||||||
|
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.func.gen_random_uuid()),
|
||||||
|
sa.Column("content_type", sa.String(50), nullable=False),
|
||||||
|
sa.Column("content_id", UUID(as_uuid=True), nullable=True),
|
||||||
|
sa.Column("content_title", sa.String(500), nullable=True),
|
||||||
|
sa.Column("report_type", sa.Enum(
|
||||||
|
"inaccurate", "missing_info", "wrong_attribution", "formatting", "other",
|
||||||
|
name="report_type", create_constraint=True,
|
||||||
|
), nullable=False),
|
||||||
|
sa.Column("description", sa.Text(), nullable=False),
|
||||||
|
sa.Column("status", sa.Enum(
|
||||||
|
"open", "acknowledged", "resolved", "dismissed",
|
||||||
|
name="report_status", create_constraint=True,
|
||||||
|
), nullable=False, server_default="open"),
|
||||||
|
sa.Column("admin_notes", sa.Text(), nullable=True),
|
||||||
|
sa.Column("page_url", sa.String(1000), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(), server_default=sa.func.now(), nullable=False),
|
||||||
|
sa.Column("resolved_at", sa.DateTime(), nullable=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
op.create_index("ix_content_reports_status_created", "content_reports", ["status", "created_at"])
|
||||||
|
op.create_index("ix_content_reports_content", "content_reports", ["content_type", "content_id"])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("ix_content_reports_content")
|
||||||
|
op.drop_index("ix_content_reports_status_created")
|
||||||
|
op.drop_table("content_reports")
|
||||||
|
sa.Enum(name="report_status").drop(op.get_bind(), checkfirst=True)
|
||||||
|
sa.Enum(name="report_type").drop(op.get_bind(), checkfirst=True)
|
||||||
37
alembic/versions/004_pipeline_events.py
Normal file
37
alembic/versions/004_pipeline_events.py
Normal file
|
|
@ -0,0 +1,37 @@
|
||||||
|
"""Create pipeline_events table.
|
||||||
|
|
||||||
|
Revision ID: 004_pipeline_events
|
||||||
|
Revises: 003_content_reports
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||||
|
|
||||||
|
revision = "004_pipeline_events"
|
||||||
|
down_revision = "003_content_reports"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"pipeline_events",
|
||||||
|
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.func.gen_random_uuid()),
|
||||||
|
sa.Column("video_id", UUID(as_uuid=True), nullable=False, index=True),
|
||||||
|
sa.Column("stage", sa.String(50), nullable=False),
|
||||||
|
sa.Column("event_type", sa.String(30), nullable=False),
|
||||||
|
sa.Column("prompt_tokens", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("completion_tokens", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("total_tokens", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("model", sa.String(100), nullable=True),
|
||||||
|
sa.Column("duration_ms", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("payload", JSONB(), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(), server_default=sa.func.now(), nullable=False),
|
||||||
|
)
|
||||||
|
# Composite index for event log queries (video + newest first)
|
||||||
|
op.create_index("ix_pipeline_events_video_created", "pipeline_events", ["video_id", "created_at"])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("ix_pipeline_events_video_created")
|
||||||
|
op.drop_table("pipeline_events")
|
||||||
29
alembic/versions/005_content_hash.py
Normal file
29
alembic/versions/005_content_hash.py
Normal file
|
|
@ -0,0 +1,29 @@
|
||||||
|
"""Add content_hash to source_videos for duplicate detection.
|
||||||
|
|
||||||
|
Revision ID: 005_content_hash
|
||||||
|
Revises: 004_pipeline_events
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
revision = "005_content_hash"
|
||||||
|
down_revision = "004_pipeline_events"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column(
|
||||||
|
"source_videos",
|
||||||
|
sa.Column("content_hash", sa.String(64), nullable=True),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_source_videos_content_hash",
|
||||||
|
"source_videos",
|
||||||
|
["content_hash"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("ix_source_videos_content_hash")
|
||||||
|
op.drop_column("source_videos", "content_hash")
|
||||||
33
alembic/versions/006_debug_columns.py
Normal file
33
alembic/versions/006_debug_columns.py
Normal file
|
|
@ -0,0 +1,33 @@
|
||||||
|
"""Add debug LLM I/O capture columns to pipeline_events.
|
||||||
|
|
||||||
|
Revision ID: 006_debug_columns
|
||||||
|
Revises: 005_content_hash
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
revision = "006_debug_columns"
|
||||||
|
down_revision = "005_content_hash"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column(
|
||||||
|
"pipeline_events",
|
||||||
|
sa.Column("system_prompt_text", sa.Text(), nullable=True),
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"pipeline_events",
|
||||||
|
sa.Column("user_prompt_text", sa.Text(), nullable=True),
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"pipeline_events",
|
||||||
|
sa.Column("response_text", sa.Text(), nullable=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_column("pipeline_events", "response_text")
|
||||||
|
op.drop_column("pipeline_events", "user_prompt_text")
|
||||||
|
op.drop_column("pipeline_events", "system_prompt_text")
|
||||||
30
alembic/versions/007_drop_review_columns.py
Normal file
30
alembic/versions/007_drop_review_columns.py
Normal file
|
|
@ -0,0 +1,30 @@
|
||||||
|
"""Drop review_status columns and enums.
|
||||||
|
|
||||||
|
Revision ID: 007_drop_review_columns
|
||||||
|
Revises: 006_debug_columns
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision = "007_drop_review_columns"
|
||||||
|
down_revision = "006_debug_columns"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.drop_column("key_moments", "review_status")
|
||||||
|
op.drop_column("technique_pages", "review_status")
|
||||||
|
op.execute("DROP TYPE IF EXISTS review_status")
|
||||||
|
op.execute("DROP TYPE IF EXISTS page_review_status")
|
||||||
|
# Collapse 'reviewed' into 'published' for any existing rows
|
||||||
|
op.execute(
|
||||||
|
"UPDATE source_videos SET processing_status = 'published' "
|
||||||
|
"WHERE processing_status = 'reviewed'"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.execute("CREATE TYPE review_status AS ENUM ('pending', 'approved', 'edited', 'rejected')")
|
||||||
|
op.execute("CREATE TYPE page_review_status AS ENUM ('draft', 'reviewed', 'published')")
|
||||||
|
op.add_column("key_moments", op.Column("review_status", op.Enum("pending", "approved", "edited", "rejected", name="review_status"), server_default="pending", nullable=False))
|
||||||
|
op.add_column("technique_pages", op.Column("review_status", op.Enum("draft", "reviewed", "published", name="page_review_status"), server_default="draft", nullable=False))
|
||||||
79
alembic/versions/008_rename_processing_status.py
Normal file
79
alembic/versions/008_rename_processing_status.py
Normal file
|
|
@ -0,0 +1,79 @@
|
||||||
|
"""Rename processing_status values to user-meaningful lifecycle states.
|
||||||
|
|
||||||
|
Old: pending, transcribed, extracted, published
|
||||||
|
New: not_started, queued, processing, error, complete
|
||||||
|
|
||||||
|
Uses text column conversion to avoid PG enum ADD VALUE transaction restriction.
|
||||||
|
|
||||||
|
Revision ID: 008_rename_processing_status
|
||||||
|
Revises: 007_drop_review_columns
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
revision = "008_rename_processing_status"
|
||||||
|
down_revision = "007_drop_review_columns"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# 1. Drop server default (it references the old enum type)
|
||||||
|
op.alter_column("source_videos", "processing_status", server_default=None)
|
||||||
|
|
||||||
|
# 2. Convert column to text to break free of the old enum
|
||||||
|
op.alter_column(
|
||||||
|
"source_videos", "processing_status",
|
||||||
|
type_=sa.Text(),
|
||||||
|
existing_type=sa.Enum(name="processing_status"),
|
||||||
|
postgresql_using="processing_status::text",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Drop old enum type
|
||||||
|
op.execute("DROP TYPE IF EXISTS processing_status")
|
||||||
|
|
||||||
|
# 4. Rename values in the text column
|
||||||
|
op.execute("UPDATE source_videos SET processing_status = 'not_started' WHERE processing_status = 'pending'")
|
||||||
|
op.execute("UPDATE source_videos SET processing_status = 'queued' WHERE processing_status = 'transcribed'")
|
||||||
|
op.execute("UPDATE source_videos SET processing_status = 'processing' WHERE processing_status = 'extracted'")
|
||||||
|
op.execute("UPDATE source_videos SET processing_status = 'complete' WHERE processing_status = 'published'")
|
||||||
|
|
||||||
|
# 5. Create new enum type
|
||||||
|
processing_status = sa.Enum(
|
||||||
|
"not_started", "queued", "processing", "error", "complete",
|
||||||
|
name="processing_status",
|
||||||
|
)
|
||||||
|
processing_status.create(op.get_bind(), checkfirst=True)
|
||||||
|
|
||||||
|
# 6. Convert column back to enum with new default
|
||||||
|
op.alter_column(
|
||||||
|
"source_videos", "processing_status",
|
||||||
|
type_=processing_status,
|
||||||
|
existing_type=sa.Text(),
|
||||||
|
postgresql_using="processing_status::processing_status",
|
||||||
|
server_default="not_started",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.alter_column("source_videos", "processing_status", server_default=None)
|
||||||
|
op.alter_column(
|
||||||
|
"source_videos", "processing_status",
|
||||||
|
type_=sa.Text(),
|
||||||
|
existing_type=sa.Enum(name="processing_status"),
|
||||||
|
postgresql_using="processing_status::text",
|
||||||
|
)
|
||||||
|
op.execute("DROP TYPE IF EXISTS processing_status")
|
||||||
|
op.execute("UPDATE source_videos SET processing_status = 'pending' WHERE processing_status = 'not_started'")
|
||||||
|
op.execute("UPDATE source_videos SET processing_status = 'transcribed' WHERE processing_status = 'queued'")
|
||||||
|
op.execute("UPDATE source_videos SET processing_status = 'extracted' WHERE processing_status = 'processing'")
|
||||||
|
op.execute("UPDATE source_videos SET processing_status = 'published' WHERE processing_status = 'complete'")
|
||||||
|
old_enum = sa.Enum("pending", "transcribed", "extracted", "published", name="processing_status")
|
||||||
|
old_enum.create(op.get_bind(), checkfirst=True)
|
||||||
|
op.alter_column(
|
||||||
|
"source_videos", "processing_status",
|
||||||
|
type_=old_enum,
|
||||||
|
existing_type=sa.Text(),
|
||||||
|
postgresql_using="processing_status::processing_status",
|
||||||
|
server_default="pending",
|
||||||
|
)
|
||||||
28
alembic/versions/009_add_creator_hidden_flag.py
Normal file
28
alembic/versions/009_add_creator_hidden_flag.py
Normal file
|
|
@ -0,0 +1,28 @@
|
||||||
|
"""Add hidden boolean flag to creators table.
|
||||||
|
|
||||||
|
Marks test/internal creators as hidden so they are filtered from
|
||||||
|
public API responses.
|
||||||
|
|
||||||
|
Revision ID: 009_add_creator_hidden_flag
|
||||||
|
Revises: 008_rename_processing_status
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
revision = "009_add_creator_hidden_flag"
|
||||||
|
down_revision = "008_rename_processing_status"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column(
|
||||||
|
"creators",
|
||||||
|
sa.Column("hidden", sa.Boolean(), server_default="false", nullable=False),
|
||||||
|
)
|
||||||
|
# Mark known test creator as hidden
|
||||||
|
op.execute("UPDATE creators SET hidden = true WHERE slug = 'testcreator'")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_column("creators", "hidden")
|
||||||
54
alembic/versions/010_add_pipeline_runs.py
Normal file
54
alembic/versions/010_add_pipeline_runs.py
Normal file
|
|
@ -0,0 +1,54 @@
|
||||||
|
"""Add pipeline_runs table and run_id FK on pipeline_events.
|
||||||
|
|
||||||
|
Each pipeline trigger creates a run. Events are scoped to runs
|
||||||
|
for clean per-execution audit trails.
|
||||||
|
|
||||||
|
Revision ID: 010_add_pipeline_runs
|
||||||
|
Revises: 009_add_creator_hidden_flag
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
|
||||||
|
revision = "010_add_pipeline_runs"
|
||||||
|
down_revision = "009_add_creator_hidden_flag"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Create enums
|
||||||
|
pipeline_run_trigger = sa.Enum(
|
||||||
|
"manual", "clean_reprocess", "auto_ingest", "bulk",
|
||||||
|
name="pipeline_run_trigger",
|
||||||
|
)
|
||||||
|
pipeline_run_status = sa.Enum(
|
||||||
|
"running", "complete", "error", "cancelled",
|
||||||
|
name="pipeline_run_status",
|
||||||
|
)
|
||||||
|
|
||||||
|
op.create_table(
|
||||||
|
"pipeline_runs",
|
||||||
|
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
|
||||||
|
sa.Column("video_id", UUID(as_uuid=True), sa.ForeignKey("source_videos.id", ondelete="CASCADE"), nullable=False, index=True),
|
||||||
|
sa.Column("run_number", sa.Integer, nullable=False),
|
||||||
|
sa.Column("trigger", pipeline_run_trigger, nullable=False),
|
||||||
|
sa.Column("status", pipeline_run_status, nullable=False, server_default="running"),
|
||||||
|
sa.Column("started_at", sa.DateTime, nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("finished_at", sa.DateTime, nullable=True),
|
||||||
|
sa.Column("error_stage", sa.String(50), nullable=True),
|
||||||
|
sa.Column("total_tokens", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add run_id to pipeline_events (nullable for backward compat)
|
||||||
|
op.add_column(
|
||||||
|
"pipeline_events",
|
||||||
|
sa.Column("run_id", UUID(as_uuid=True), sa.ForeignKey("pipeline_runs.id", ondelete="SET NULL"), nullable=True, index=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_column("pipeline_events", "run_id")
|
||||||
|
op.drop_table("pipeline_runs")
|
||||||
|
op.execute("DROP TYPE IF EXISTS pipeline_run_trigger")
|
||||||
|
op.execute("DROP TYPE IF EXISTS pipeline_run_status")
|
||||||
35
alembic/versions/011_classification_cache_and_stage_rerun.py
Normal file
35
alembic/versions/011_classification_cache_and_stage_rerun.py
Normal file
|
|
@ -0,0 +1,35 @@
|
||||||
|
"""Add classification_data JSONB column to source_videos and stage_rerun trigger.
|
||||||
|
|
||||||
|
Persists stage 4 classification data in PostgreSQL alongside Redis cache,
|
||||||
|
eliminating the 24-hour TTL data loss risk. Also adds the 'stage_rerun'
|
||||||
|
trigger value for single-stage re-run support.
|
||||||
|
|
||||||
|
Revision ID: 011_cls_cache_rerun
|
||||||
|
Revises: 010_add_pipeline_runs
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
|
||||||
|
revision = "011_cls_cache_rerun"
|
||||||
|
down_revision = "010_add_pipeline_runs"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Add classification_data column to source_videos
|
||||||
|
op.add_column(
|
||||||
|
"source_videos",
|
||||||
|
sa.Column("classification_data", JSONB, nullable=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add 'stage_rerun' to the pipeline_run_trigger enum
|
||||||
|
# PostgreSQL enums require ALTER TYPE to add values
|
||||||
|
op.execute("ALTER TYPE pipeline_run_trigger ADD VALUE IF NOT EXISTS 'stage_rerun'")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_column("source_videos", "classification_data")
|
||||||
|
# Note: PostgreSQL does not support removing values from enums.
|
||||||
|
# The 'stage_rerun' value will remain but be unused after downgrade.
|
||||||
55
alembic/versions/012_multi_source_format.py
Normal file
55
alembic/versions/012_multi_source_format.py
Normal file
|
|
@ -0,0 +1,55 @@
|
||||||
|
"""Add body_sections_format column and technique_page_videos association table.
|
||||||
|
|
||||||
|
Supports multi-source technique pages: tracks which source videos contributed
|
||||||
|
to a technique page, and marks the body_sections format version for future
|
||||||
|
structured section layouts.
|
||||||
|
|
||||||
|
Revision ID: 012_multi_source_fmt
|
||||||
|
Revises: 011_cls_cache_rerun
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
|
||||||
|
revision = "012_multi_source_fmt"
|
||||||
|
down_revision = "011_cls_cache_rerun"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Add body_sections_format to technique_pages with default for existing rows
|
||||||
|
op.add_column(
|
||||||
|
"technique_pages",
|
||||||
|
sa.Column(
|
||||||
|
"body_sections_format",
|
||||||
|
sa.String(20),
|
||||||
|
nullable=False,
|
||||||
|
server_default="v1",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create technique_page_videos association table
|
||||||
|
op.create_table(
|
||||||
|
"technique_page_videos",
|
||||||
|
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.func.gen_random_uuid()),
|
||||||
|
sa.Column(
|
||||||
|
"technique_page_id",
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
sa.ForeignKey("technique_pages.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"source_video_id",
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
sa.ForeignKey("source_videos.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("added_at", sa.TIMESTAMP(), server_default=sa.func.now(), nullable=False),
|
||||||
|
sa.UniqueConstraint("technique_page_id", "source_video_id", name="uq_page_video"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_table("technique_page_videos")
|
||||||
|
op.drop_column("technique_pages", "body_sections_format")
|
||||||
31
alembic/versions/013_add_search_log.py
Normal file
31
alembic/versions/013_add_search_log.py
Normal file
|
|
@ -0,0 +1,31 @@
|
||||||
|
"""Add search_log table for query analytics and popular searches.
|
||||||
|
|
||||||
|
Revision ID: 013_add_search_log
|
||||||
|
Revises: 012_multi_source_fmt
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
revision = "013_add_search_log"
|
||||||
|
down_revision = "012_multi_source_fmt"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"search_log",
|
||||||
|
sa.Column("id", sa.Integer, primary_key=True, autoincrement=True),
|
||||||
|
sa.Column("query", sa.String(500), nullable=False),
|
||||||
|
sa.Column("scope", sa.String(50), nullable=False),
|
||||||
|
sa.Column("result_count", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("created_at", sa.TIMESTAMP(), server_default=sa.func.now(), nullable=False),
|
||||||
|
)
|
||||||
|
op.create_index("ix_search_log_query", "search_log", ["query"])
|
||||||
|
op.create_index("ix_search_log_created_at", "search_log", ["created_at"])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("ix_search_log_created_at", table_name="search_log")
|
||||||
|
op.drop_index("ix_search_log_query", table_name="search_log")
|
||||||
|
op.drop_table("search_log")
|
||||||
24
alembic/versions/014_add_creator_avatar.py
Normal file
24
alembic/versions/014_add_creator_avatar.py
Normal file
|
|
@ -0,0 +1,24 @@
|
||||||
|
"""Add avatar columns to creators table.
|
||||||
|
|
||||||
|
Revision ID: 014_add_creator_avatar
|
||||||
|
Revises: 013_add_search_log
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
revision = "014_add_creator_avatar"
|
||||||
|
down_revision = "013_add_search_log"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column("creators", sa.Column("avatar_url", sa.String(1000), nullable=True))
|
||||||
|
op.add_column("creators", sa.Column("avatar_source", sa.String(50), nullable=True))
|
||||||
|
op.add_column("creators", sa.Column("avatar_fetched_at", sa.TIMESTAMP(), nullable=True))
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_column("creators", "avatar_fetched_at")
|
||||||
|
op.drop_column("creators", "avatar_source")
|
||||||
|
op.drop_column("creators", "avatar_url")
|
||||||
25
alembic/versions/015_add_creator_profile.py
Normal file
25
alembic/versions/015_add_creator_profile.py
Normal file
|
|
@ -0,0 +1,25 @@
|
||||||
|
"""Add bio, social_links, and featured columns to creators table.
|
||||||
|
|
||||||
|
Revision ID: 015_add_creator_profile
|
||||||
|
Revises: 014_add_creator_avatar
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
|
||||||
|
revision = "015_add_creator_profile"
|
||||||
|
down_revision = "014_add_creator_avatar"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column("creators", sa.Column("bio", sa.Text(), nullable=True))
|
||||||
|
op.add_column("creators", sa.Column("social_links", JSONB(), nullable=True))
|
||||||
|
op.add_column("creators", sa.Column("featured", sa.Boolean(), server_default="false", nullable=False))
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_column("creators", "featured")
|
||||||
|
op.drop_column("creators", "social_links")
|
||||||
|
op.drop_column("creators", "bio")
|
||||||
52
alembic/versions/016_add_users_and_invite_codes.py
Normal file
52
alembic/versions/016_add_users_and_invite_codes.py
Normal file
|
|
@ -0,0 +1,52 @@
|
||||||
|
"""Add users and invite_codes tables for creator authentication.
|
||||||
|
|
||||||
|
Revision ID: 016_add_users_and_invite_codes
|
||||||
|
Revises: 015_add_creator_profile
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision = "016_add_users_and_invite_codes"
|
||||||
|
down_revision = "015_add_creator_profile"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Use raw SQL to avoid SQLAlchemy's Enum double-creation bug with asyncpg
|
||||||
|
op.execute("""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE user_role AS ENUM ('creator', 'admin');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$
|
||||||
|
""")
|
||||||
|
|
||||||
|
op.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS users (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
email VARCHAR(255) NOT NULL UNIQUE,
|
||||||
|
hashed_password VARCHAR(255) NOT NULL,
|
||||||
|
display_name VARCHAR(255) NOT NULL,
|
||||||
|
role user_role NOT NULL DEFAULT 'creator',
|
||||||
|
creator_id UUID REFERENCES creators(id) ON DELETE SET NULL,
|
||||||
|
is_active BOOLEAN NOT NULL DEFAULT TRUE,
|
||||||
|
created_at TIMESTAMP NOT NULL DEFAULT now(),
|
||||||
|
updated_at TIMESTAMP NOT NULL DEFAULT now()
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
op.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS invite_codes (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
code VARCHAR(100) NOT NULL UNIQUE,
|
||||||
|
uses_remaining INTEGER NOT NULL DEFAULT 1,
|
||||||
|
created_by UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||||
|
expires_at TIMESTAMP,
|
||||||
|
created_at TIMESTAMP NOT NULL DEFAULT now()
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.execute("DROP TABLE IF EXISTS invite_codes")
|
||||||
|
op.execute("DROP TABLE IF EXISTS users")
|
||||||
|
op.execute("DROP TYPE IF EXISTS user_role")
|
||||||
51
alembic/versions/017_add_consent_tables.py
Normal file
51
alembic/versions/017_add_consent_tables.py
Normal file
|
|
@ -0,0 +1,51 @@
|
||||||
|
"""Add video_consents and consent_audit_log tables for per-video consent management.
|
||||||
|
|
||||||
|
Revision ID: 017_add_consent_tables
|
||||||
|
Revises: 016_add_users_and_invite_codes
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
|
||||||
|
revision = "017_add_consent_tables"
|
||||||
|
down_revision = "016_add_users_and_invite_codes"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Create video_consents table
|
||||||
|
op.create_table(
|
||||||
|
"video_consents",
|
||||||
|
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.func.gen_random_uuid()),
|
||||||
|
sa.Column("source_video_id", UUID(as_uuid=True), sa.ForeignKey("source_videos.id", ondelete="CASCADE"), nullable=False),
|
||||||
|
sa.Column("creator_id", UUID(as_uuid=True), sa.ForeignKey("creators.id", ondelete="CASCADE"), nullable=False),
|
||||||
|
sa.Column("kb_inclusion", sa.Boolean(), nullable=False, server_default="false"),
|
||||||
|
sa.Column("training_usage", sa.Boolean(), nullable=False, server_default="false"),
|
||||||
|
sa.Column("public_display", sa.Boolean(), nullable=False, server_default="true"),
|
||||||
|
sa.Column("updated_by", UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="RESTRICT"), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(), nullable=False, server_default=sa.func.now()),
|
||||||
|
sa.Column("updated_at", sa.DateTime(), nullable=False, server_default=sa.func.now()),
|
||||||
|
sa.UniqueConstraint("source_video_id", name="uq_video_consent_video"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create consent_audit_log table
|
||||||
|
op.create_table(
|
||||||
|
"consent_audit_log",
|
||||||
|
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.func.gen_random_uuid()),
|
||||||
|
sa.Column("video_consent_id", UUID(as_uuid=True), sa.ForeignKey("video_consents.id", ondelete="CASCADE"), nullable=False),
|
||||||
|
sa.Column("version", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("field_name", sa.String(50), nullable=False),
|
||||||
|
sa.Column("old_value", sa.Boolean(), nullable=True),
|
||||||
|
sa.Column("new_value", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column("changed_by", UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="RESTRICT"), nullable=False),
|
||||||
|
sa.Column("ip_address", sa.String(45), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(), nullable=False, server_default=sa.func.now()),
|
||||||
|
)
|
||||||
|
op.create_index("ix_consent_audit_log_video_consent_id", "consent_audit_log", ["video_consent_id"])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("ix_consent_audit_log_video_consent_id", table_name="consent_audit_log")
|
||||||
|
op.drop_table("consent_audit_log")
|
||||||
|
op.drop_table("video_consents")
|
||||||
37
alembic/versions/018_add_impersonation_log.py
Normal file
37
alembic/versions/018_add_impersonation_log.py
Normal file
|
|
@ -0,0 +1,37 @@
|
||||||
|
"""Add impersonation_log table for admin impersonation audit trail.
|
||||||
|
|
||||||
|
Revision ID: 018_add_impersonation_log
|
||||||
|
Revises: 017_add_consent_tables
|
||||||
|
"""
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
|
||||||
|
|
||||||
|
revision = "018_add_impersonation_log"
|
||||||
|
down_revision = "017_add_consent_tables"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"impersonation_log",
|
||||||
|
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")),
|
||||||
|
sa.Column("admin_user_id", UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
|
||||||
|
sa.Column("target_user_id", UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
|
||||||
|
sa.Column("action", sa.String(10), nullable=False), # 'start' or 'stop'
|
||||||
|
sa.Column("ip_address", sa.String(45), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime, server_default=sa.func.now(), nullable=False),
|
||||||
|
)
|
||||||
|
op.create_index("ix_impersonation_log_admin", "impersonation_log", ["admin_user_id"])
|
||||||
|
op.create_index("ix_impersonation_log_target", "impersonation_log", ["target_user_id"])
|
||||||
|
op.create_index("ix_impersonation_log_created", "impersonation_log", ["created_at"])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("ix_impersonation_log_created")
|
||||||
|
op.drop_index("ix_impersonation_log_target")
|
||||||
|
op.drop_index("ix_impersonation_log_admin")
|
||||||
|
op.drop_table("impersonation_log")
|
||||||
44
alembic/versions/019_add_highlight_candidates.py
Normal file
44
alembic/versions/019_add_highlight_candidates.py
Normal file
|
|
@ -0,0 +1,44 @@
|
||||||
|
"""Add highlight_candidates table for highlight detection scoring.
|
||||||
|
|
||||||
|
Revision ID: 019_add_highlight_candidates
|
||||||
|
Revises: 018_add_impersonation_log
|
||||||
|
"""
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
|
||||||
|
|
||||||
|
revision = "019_add_highlight_candidates"
|
||||||
|
down_revision = "018_add_impersonation_log"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Pure SQL — idempotent with IF NOT EXISTS / exception guards
|
||||||
|
op.execute("DO $$ BEGIN CREATE TYPE highlight_status AS ENUM ('candidate', 'approved', 'rejected'); EXCEPTION WHEN duplicate_object THEN NULL; END $$")
|
||||||
|
op.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS highlight_candidates (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
key_moment_id UUID NOT NULL UNIQUE REFERENCES key_moments(id) ON DELETE CASCADE,
|
||||||
|
source_video_id UUID NOT NULL REFERENCES source_videos(id) ON DELETE CASCADE,
|
||||||
|
score FLOAT NOT NULL,
|
||||||
|
score_breakdown JSONB,
|
||||||
|
duration_secs FLOAT NOT NULL,
|
||||||
|
status highlight_status NOT NULL DEFAULT 'candidate',
|
||||||
|
created_at TIMESTAMP NOT NULL DEFAULT now(),
|
||||||
|
updated_at TIMESTAMP NOT NULL DEFAULT now()
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
op.execute("CREATE INDEX IF NOT EXISTS ix_highlight_candidates_source_video_id ON highlight_candidates (source_video_id)")
|
||||||
|
op.execute("CREATE INDEX IF NOT EXISTS ix_highlight_candidates_score_desc ON highlight_candidates (score DESC)")
|
||||||
|
op.execute("CREATE INDEX IF NOT EXISTS ix_highlight_candidates_status ON highlight_candidates (status)")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("ix_highlight_candidates_status")
|
||||||
|
op.drop_index("ix_highlight_candidates_score_desc")
|
||||||
|
op.drop_index("ix_highlight_candidates_source_video_id")
|
||||||
|
op.drop_table("highlight_candidates")
|
||||||
|
sa.Enum(name="highlight_status").drop(op.get_bind(), checkfirst=True)
|
||||||
29
alembic/versions/020_add_chapter_status_and_sort_order.py
Normal file
29
alembic/versions/020_add_chapter_status_and_sort_order.py
Normal file
|
|
@ -0,0 +1,29 @@
|
||||||
|
"""Add chapter_status and sort_order columns to key_moments.
|
||||||
|
|
||||||
|
Revision ID: 020_add_chapter_status_and_sort_order
|
||||||
|
Revises: 019_add_highlight_candidates
|
||||||
|
"""
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
revision = "020_add_chapter_status_and_sort_order"
|
||||||
|
down_revision = "019_add_highlight_candidates"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Pure SQL to avoid SQLAlchemy enum creation hooks
|
||||||
|
op.execute("DO $$ BEGIN CREATE TYPE chapter_status AS ENUM ('draft', 'approved', 'hidden'); EXCEPTION WHEN duplicate_object THEN NULL; END $$")
|
||||||
|
op.execute("ALTER TABLE key_moments ADD COLUMN IF NOT EXISTS chapter_status chapter_status NOT NULL DEFAULT 'draft'")
|
||||||
|
op.execute("ALTER TABLE key_moments ADD COLUMN IF NOT EXISTS sort_order INTEGER NOT NULL DEFAULT 0")
|
||||||
|
op.execute("CREATE INDEX IF NOT EXISTS ix_key_moments_chapter_status ON key_moments (chapter_status)")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("ix_key_moments_chapter_status")
|
||||||
|
op.drop_column("key_moments", "sort_order")
|
||||||
|
op.drop_column("key_moments", "chapter_status")
|
||||||
|
sa.Enum(name="chapter_status").drop(op.get_bind(), checkfirst=True)
|
||||||
24
alembic/versions/021_add_highlight_trim_columns.py
Normal file
24
alembic/versions/021_add_highlight_trim_columns.py
Normal file
|
|
@ -0,0 +1,24 @@
|
||||||
|
"""Add trim_start and trim_end columns to highlight_candidates.
|
||||||
|
|
||||||
|
Revision ID: 021_add_highlight_trim_columns
|
||||||
|
Revises: 020_add_chapter_status_and_sort_order
|
||||||
|
"""
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
revision = "021_add_highlight_trim_columns"
|
||||||
|
down_revision = "020_add_chapter_status_and_sort_order"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column("highlight_candidates", sa.Column("trim_start", sa.Float(), nullable=True))
|
||||||
|
op.add_column("highlight_candidates", sa.Column("trim_end", sa.Float(), nullable=True))
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_column("highlight_candidates", "trim_end")
|
||||||
|
op.drop_column("highlight_candidates", "trim_start")
|
||||||
31
alembic/versions/022_add_creator_follows.py
Normal file
31
alembic/versions/022_add_creator_follows.py
Normal file
|
|
@ -0,0 +1,31 @@
|
||||||
|
"""Add creator_follows table for user follow system.
|
||||||
|
|
||||||
|
Revision ID: 022_add_creator_follows
|
||||||
|
Revises: 021_add_highlight_trim_columns
|
||||||
|
"""
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
|
||||||
|
revision = "022_add_creator_follows"
|
||||||
|
down_revision = "021_add_highlight_trim_columns"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS creator_follows (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||||
|
creator_id UUID NOT NULL REFERENCES creators(id) ON DELETE CASCADE,
|
||||||
|
created_at TIMESTAMP NOT NULL DEFAULT now(),
|
||||||
|
CONSTRAINT uq_creator_follow_user_creator UNIQUE (user_id, creator_id)
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
op.execute("CREATE INDEX IF NOT EXISTS ix_creator_follows_user_id ON creator_follows (user_id)")
|
||||||
|
op.execute("CREATE INDEX IF NOT EXISTS ix_creator_follows_creator_id ON creator_follows (creator_id)")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.execute("DROP TABLE IF EXISTS creator_follows")
|
||||||
21
alembic/versions/023_add_personality_profile.py
Normal file
21
alembic/versions/023_add_personality_profile.py
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
"""Add personality_profile JSONB column to creators.
|
||||||
|
|
||||||
|
Revision ID: 023_add_personality_profile
|
||||||
|
Revises: 022_add_creator_follows
|
||||||
|
"""
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
|
||||||
|
revision = "023_add_personality_profile"
|
||||||
|
down_revision = "022_add_creator_follows"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.execute("ALTER TABLE creators ADD COLUMN IF NOT EXISTS personality_profile JSONB")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.execute("ALTER TABLE creators DROP COLUMN IF EXISTS personality_profile")
|
||||||
44
alembic/versions/024_add_posts_and_attachments.py
Normal file
44
alembic/versions/024_add_posts_and_attachments.py
Normal file
|
|
@ -0,0 +1,44 @@
|
||||||
|
"""Add posts and post_attachments tables.
|
||||||
|
|
||||||
|
Revision ID: 024_add_posts_and_attachments
|
||||||
|
Revises: 023_add_personality_profile
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision = "024_add_posts_and_attachments"
|
||||||
|
down_revision = "023_add_personality_profile"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"posts",
|
||||||
|
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.func.gen_random_uuid()),
|
||||||
|
sa.Column("creator_id", UUID(as_uuid=True), sa.ForeignKey("creators.id", ondelete="CASCADE"), nullable=False, index=True),
|
||||||
|
sa.Column("title", sa.String(500), nullable=False),
|
||||||
|
sa.Column("body_json", JSONB, nullable=False),
|
||||||
|
sa.Column("is_published", sa.Boolean, nullable=False, server_default="false"),
|
||||||
|
sa.Column("created_at", sa.DateTime, nullable=False, server_default=sa.func.now()),
|
||||||
|
sa.Column("updated_at", sa.DateTime, nullable=False, server_default=sa.func.now()),
|
||||||
|
)
|
||||||
|
|
||||||
|
op.create_table(
|
||||||
|
"post_attachments",
|
||||||
|
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.func.gen_random_uuid()),
|
||||||
|
sa.Column("post_id", UUID(as_uuid=True), sa.ForeignKey("posts.id", ondelete="CASCADE"), nullable=False, index=True),
|
||||||
|
sa.Column("filename", sa.String(500), nullable=False),
|
||||||
|
sa.Column("object_key", sa.String(1000), nullable=False),
|
||||||
|
sa.Column("content_type", sa.String(255), nullable=False),
|
||||||
|
sa.Column("size_bytes", sa.BigInteger, nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime, nullable=False, server_default=sa.func.now()),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_table("post_attachments")
|
||||||
|
op.drop_table("posts")
|
||||||
45
alembic/versions/025_add_generated_shorts.py
Normal file
45
alembic/versions/025_add_generated_shorts.py
Normal file
|
|
@ -0,0 +1,45 @@
|
||||||
|
"""Add generated_shorts table with format_preset and short_status enums.
|
||||||
|
|
||||||
|
Revision ID: 025_add_generated_shorts
|
||||||
|
Revises: 024_add_posts_and_attachments
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision = "025_add_generated_shorts"
|
||||||
|
down_revision = "024_add_posts_and_attachments"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
format_preset_enum = sa.Enum("vertical", "square", "horizontal", name="format_preset")
|
||||||
|
short_status_enum = sa.Enum("pending", "processing", "complete", "failed", name="short_status")
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
format_preset_enum.create(op.get_bind(), checkfirst=True)
|
||||||
|
short_status_enum.create(op.get_bind(), checkfirst=True)
|
||||||
|
|
||||||
|
op.create_table(
|
||||||
|
"generated_shorts",
|
||||||
|
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.func.gen_random_uuid()),
|
||||||
|
sa.Column("highlight_candidate_id", UUID(as_uuid=True), sa.ForeignKey("highlight_candidates.id", ondelete="CASCADE"), nullable=False, index=True),
|
||||||
|
sa.Column("format_preset", format_preset_enum, nullable=False),
|
||||||
|
sa.Column("minio_object_key", sa.String(1000), nullable=True),
|
||||||
|
sa.Column("duration_secs", sa.Float, nullable=True),
|
||||||
|
sa.Column("width", sa.Integer, nullable=False),
|
||||||
|
sa.Column("height", sa.Integer, nullable=False),
|
||||||
|
sa.Column("file_size_bytes", sa.BigInteger, nullable=True),
|
||||||
|
sa.Column("status", short_status_enum, nullable=False, server_default="pending"),
|
||||||
|
sa.Column("error_message", sa.Text, nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime, nullable=False, server_default=sa.func.now()),
|
||||||
|
sa.Column("updated_at", sa.DateTime, nullable=False, server_default=sa.func.now()),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_table("generated_shorts")
|
||||||
|
short_status_enum.drop(op.get_bind(), checkfirst=True)
|
||||||
|
format_preset_enum.drop(op.get_bind(), checkfirst=True)
|
||||||
45
alembic/versions/026_add_share_token.py
Normal file
45
alembic/versions/026_add_share_token.py
Normal file
|
|
@ -0,0 +1,45 @@
|
||||||
|
"""Add share_token column to generated_shorts for public sharing.
|
||||||
|
|
||||||
|
Revision ID: 026_add_share_token
|
||||||
|
Revises: 025_add_generated_shorts
|
||||||
|
"""
|
||||||
|
|
||||||
|
import secrets
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision = "026_add_share_token"
|
||||||
|
down_revision = "025_add_generated_shorts"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Add nullable column first
|
||||||
|
op.add_column(
|
||||||
|
"generated_shorts",
|
||||||
|
sa.Column("share_token", sa.String(16), nullable=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Backfill existing complete shorts with unique tokens
|
||||||
|
conn = op.get_bind()
|
||||||
|
rows = conn.execute(
|
||||||
|
sa.text("SELECT id FROM generated_shorts WHERE status = 'complete' AND share_token IS NULL")
|
||||||
|
).fetchall()
|
||||||
|
for (row_id,) in rows:
|
||||||
|
token = secrets.token_urlsafe(8) # ~11 chars, fits in String(16)
|
||||||
|
conn.execute(
|
||||||
|
sa.text("UPDATE generated_shorts SET share_token = :token WHERE id = :id"),
|
||||||
|
{"token": token, "id": row_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create unique index
|
||||||
|
op.create_index("ix_generated_shorts_share_token", "generated_shorts", ["share_token"], unique=True)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("ix_generated_shorts_share_token", table_name="generated_shorts")
|
||||||
|
op.drop_column("generated_shorts", "share_token")
|
||||||
30
alembic/versions/027_add_captions_enabled.py
Normal file
30
alembic/versions/027_add_captions_enabled.py
Normal file
|
|
@ -0,0 +1,30 @@
|
||||||
|
"""Add captions_enabled boolean to generated_shorts.
|
||||||
|
|
||||||
|
Revision ID: 027_add_captions_enabled
|
||||||
|
Revises: 026_add_share_token
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision = "027_add_captions_enabled"
|
||||||
|
down_revision = "026_add_share_token"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column(
|
||||||
|
"generated_shorts",
|
||||||
|
sa.Column(
|
||||||
|
"captions_enabled",
|
||||||
|
sa.Boolean(),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("false"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_column("generated_shorts", "captions_enabled")
|
||||||
26
alembic/versions/028_add_shorts_template.py
Normal file
26
alembic/versions/028_add_shorts_template.py
Normal file
|
|
@ -0,0 +1,26 @@
|
||||||
|
"""Add shorts_template JSONB column to creators.
|
||||||
|
|
||||||
|
Revision ID: 028_add_shorts_template
|
||||||
|
Revises: 027_add_captions_enabled
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision = "028_add_shorts_template"
|
||||||
|
down_revision = "027_add_captions_enabled"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column(
|
||||||
|
"creators",
|
||||||
|
sa.Column("shorts_template", JSONB, nullable=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_column("creators", "shorts_template")
|
||||||
48
alembic/versions/029_add_email_digest.py
Normal file
48
alembic/versions/029_add_email_digest.py
Normal file
|
|
@ -0,0 +1,48 @@
|
||||||
|
"""Add notification_preferences to users and email_digest_log table.
|
||||||
|
|
||||||
|
Revision ID: 029_add_email_digest
|
||||||
|
Revises: 028_add_shorts_template
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision = "029_add_email_digest"
|
||||||
|
down_revision = "028_add_shorts_template"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# notification_preferences JSONB on users
|
||||||
|
op.add_column(
|
||||||
|
"users",
|
||||||
|
sa.Column(
|
||||||
|
"notification_preferences",
|
||||||
|
JSONB,
|
||||||
|
nullable=False,
|
||||||
|
server_default='{"email_digests": true, "digest_frequency": "daily"}',
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# email_digest_log table
|
||||||
|
op.create_table(
|
||||||
|
"email_digest_log",
|
||||||
|
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.func.gen_random_uuid()),
|
||||||
|
sa.Column("user_id", UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
|
||||||
|
sa.Column("digest_sent_at", sa.DateTime, server_default=sa.func.now(), nullable=False),
|
||||||
|
sa.Column("content_summary", JSONB, nullable=True),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_email_digest_log_user_sent",
|
||||||
|
"email_digest_log",
|
||||||
|
["user_id", "digest_sent_at"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("ix_email_digest_log_user_sent", table_name="email_digest_log")
|
||||||
|
op.drop_table("email_digest_log")
|
||||||
|
op.drop_column("users", "notification_preferences")
|
||||||
31
alembic/versions/030_add_onboarding_completed.py
Normal file
31
alembic/versions/030_add_onboarding_completed.py
Normal file
|
|
@ -0,0 +1,31 @@
|
||||||
|
"""add_onboarding_completed
|
||||||
|
|
||||||
|
Revision ID: 030_onboarding
|
||||||
|
Revises: 029
|
||||||
|
Create Date: 2026-04-04
|
||||||
|
"""
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
# revision identifiers
|
||||||
|
revision = "030_onboarding"
|
||||||
|
down_revision = "029_add_email_digest"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column(
|
||||||
|
"users",
|
||||||
|
sa.Column(
|
||||||
|
"onboarding_completed",
|
||||||
|
sa.Boolean(),
|
||||||
|
server_default="false",
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_column("users", "onboarding_completed")
|
||||||
40
alembic/versions/031_add_chat_usage_log.py
Normal file
40
alembic/versions/031_add_chat_usage_log.py
Normal file
|
|
@ -0,0 +1,40 @@
|
||||||
|
"""add_chat_usage_log
|
||||||
|
|
||||||
|
Revision ID: 031_chat_usage_log
|
||||||
|
Revises: 030_onboarding
|
||||||
|
Create Date: 2026-04-04
|
||||||
|
"""
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
|
||||||
|
# revision identifiers
|
||||||
|
revision = "031_chat_usage_log"
|
||||||
|
down_revision = "030_onboarding"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"chat_usage_log",
|
||||||
|
sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.func.gen_random_uuid()),
|
||||||
|
sa.Column("user_id", UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="SET NULL"), nullable=True),
|
||||||
|
sa.Column("client_ip", sa.String(45), nullable=True),
|
||||||
|
sa.Column("creator_slug", sa.String(255), nullable=True),
|
||||||
|
sa.Column("query", sa.Text(), nullable=False),
|
||||||
|
sa.Column("prompt_tokens", sa.Integer(), nullable=False, server_default="0"),
|
||||||
|
sa.Column("completion_tokens", sa.Integer(), nullable=False, server_default="0"),
|
||||||
|
sa.Column("total_tokens", sa.Integer(), nullable=False, server_default="0"),
|
||||||
|
sa.Column("cascade_tier", sa.String(50), nullable=True),
|
||||||
|
sa.Column("model", sa.String(100), nullable=True),
|
||||||
|
sa.Column("latency_ms", sa.Float(), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(), nullable=False, server_default=sa.func.now()),
|
||||||
|
)
|
||||||
|
op.create_index("ix_chat_usage_log_created_at", "chat_usage_log", ["created_at"])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("ix_chat_usage_log_created_at", table_name="chat_usage_log")
|
||||||
|
op.drop_table("chat_usage_log")
|
||||||
193
backend/auth.py
Normal file
193
backend/auth.py
Normal file
|
|
@ -0,0 +1,193 @@
|
||||||
|
"""Authentication utilities — password hashing, JWT, FastAPI dependencies."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
import bcrypt
|
||||||
|
import jwt
|
||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from config import get_settings
|
||||||
|
from database import get_session
|
||||||
|
from models import User, UserRole
|
||||||
|
|
||||||
|
# ── Password hashing ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def hash_password(plain: str) -> str:
|
||||||
|
"""Hash a plaintext password with bcrypt."""
|
||||||
|
return bcrypt.hashpw(plain.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def verify_password(plain: str, hashed: str) -> bool:
|
||||||
|
"""Verify a plaintext password against a bcrypt hash."""
|
||||||
|
return bcrypt.checkpw(plain.encode("utf-8"), hashed.encode("utf-8"))
|
||||||
|
|
||||||
|
|
||||||
|
# ── JWT ──────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_ALGORITHM = "HS256"
|
||||||
|
_ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 # 24 hours
|
||||||
|
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||||
|
|
||||||
|
|
||||||
|
def create_access_token(
|
||||||
|
user_id: uuid.UUID | str,
|
||||||
|
role: str,
|
||||||
|
*,
|
||||||
|
expires_minutes: int = _ACCESS_TOKEN_EXPIRE_MINUTES,
|
||||||
|
) -> str:
|
||||||
|
"""Create a signed JWT with user_id and role claims."""
|
||||||
|
settings = get_settings()
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
payload = {
|
||||||
|
"sub": str(user_id),
|
||||||
|
"role": role,
|
||||||
|
"iat": now,
|
||||||
|
"exp": now + timedelta(minutes=expires_minutes),
|
||||||
|
}
|
||||||
|
return jwt.encode(payload, settings.app_secret_key, algorithm=_ALGORITHM)
|
||||||
|
|
||||||
|
|
||||||
|
_IMPERSONATION_EXPIRE_MINUTES = 60 # 1 hour
|
||||||
|
|
||||||
|
|
||||||
|
def create_impersonation_token(
|
||||||
|
admin_user_id: uuid.UUID | str,
|
||||||
|
target_user_id: uuid.UUID | str,
|
||||||
|
target_role: str,
|
||||||
|
*,
|
||||||
|
write_mode: bool = False,
|
||||||
|
) -> str:
|
||||||
|
"""Create a scoped JWT for admin impersonation.
|
||||||
|
|
||||||
|
The token has sub=target_user_id so get_current_user loads the target,
|
||||||
|
plus original_user_id so the system knows it's impersonation.
|
||||||
|
When write_mode is True, the token allows write operations.
|
||||||
|
"""
|
||||||
|
settings = get_settings()
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
payload = {
|
||||||
|
"sub": str(target_user_id),
|
||||||
|
"role": target_role,
|
||||||
|
"original_user_id": str(admin_user_id),
|
||||||
|
"type": "impersonation",
|
||||||
|
"iat": now,
|
||||||
|
"exp": now + timedelta(minutes=_IMPERSONATION_EXPIRE_MINUTES),
|
||||||
|
}
|
||||||
|
if write_mode:
|
||||||
|
payload["write_mode"] = True
|
||||||
|
return jwt.encode(payload, settings.app_secret_key, algorithm=_ALGORITHM)
|
||||||
|
|
||||||
|
|
||||||
|
def decode_access_token(token: str) -> dict:
|
||||||
|
"""Decode and validate a JWT. Raises on expiry or malformed tokens."""
|
||||||
|
settings = get_settings()
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token,
|
||||||
|
settings.app_secret_key,
|
||||||
|
algorithms=[_ALGORITHM],
|
||||||
|
options={"require": ["sub", "role", "exp"]},
|
||||||
|
)
|
||||||
|
except jwt.ExpiredSignatureError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Token has expired",
|
||||||
|
)
|
||||||
|
except jwt.InvalidTokenError as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=f"Invalid token: {exc}",
|
||||||
|
)
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
# ── FastAPI dependencies ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def get_current_user(
|
||||||
|
token: Annotated[str, Depends(oauth2_scheme)],
|
||||||
|
session: Annotated[AsyncSession, Depends(get_session)],
|
||||||
|
) -> User:
|
||||||
|
"""Decode JWT, load User from DB, raise 401 if missing or inactive.
|
||||||
|
|
||||||
|
If the token contains an original_user_id claim (impersonation),
|
||||||
|
sets _impersonating_admin_id on the returned user object.
|
||||||
|
"""
|
||||||
|
payload = decode_access_token(token)
|
||||||
|
user_id = payload.get("sub")
|
||||||
|
result = await session.execute(select(User).where(User.id == user_id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
if user is None or not user.is_active:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="User not found or inactive",
|
||||||
|
)
|
||||||
|
# Attach impersonation metadata (non-column runtime attribute)
|
||||||
|
user._impersonating_admin_id = payload.get("original_user_id") # type: ignore[attr-defined]
|
||||||
|
user._impersonation_write_mode = payload.get("write_mode", False) # type: ignore[attr-defined]
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
_optional_oauth2 = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login", auto_error=False)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_optional_user(
|
||||||
|
token: Annotated[str | None, Depends(_optional_oauth2)],
|
||||||
|
session: Annotated[AsyncSession, Depends(get_session)],
|
||||||
|
) -> User | None:
|
||||||
|
"""Like get_current_user but returns None instead of 401 when no token."""
|
||||||
|
if token is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
payload = decode_access_token(token)
|
||||||
|
except HTTPException:
|
||||||
|
return None
|
||||||
|
user_id = payload.get("sub")
|
||||||
|
result = await session.execute(select(User).where(User.id == user_id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
if user is None or not user.is_active:
|
||||||
|
return None
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def require_role(required_role: UserRole):
|
||||||
|
"""Return a dependency that checks the current user has the given role."""
|
||||||
|
|
||||||
|
async def _check(
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
) -> User:
|
||||||
|
if current_user.role != required_role:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail=f"Requires {required_role.value} role",
|
||||||
|
)
|
||||||
|
return current_user
|
||||||
|
|
||||||
|
return _check
|
||||||
|
|
||||||
|
|
||||||
|
async def reject_impersonation(
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
) -> User:
|
||||||
|
"""Dependency that blocks write operations during impersonation.
|
||||||
|
|
||||||
|
If the impersonation token was issued with write_mode=True,
|
||||||
|
writes are permitted.
|
||||||
|
"""
|
||||||
|
admin_id = getattr(current_user, "_impersonating_admin_id", None)
|
||||||
|
if admin_id is not None:
|
||||||
|
write_mode = getattr(current_user, "_impersonation_write_mode", False)
|
||||||
|
if not write_mode:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Write operations are not allowed during impersonation",
|
||||||
|
)
|
||||||
|
return current_user
|
||||||
519
backend/chat_service.py
Normal file
519
backend/chat_service.py
Normal file
|
|
@ -0,0 +1,519 @@
|
||||||
|
"""Chat service: retrieve context via search, stream LLM response as SSE events.
|
||||||
|
|
||||||
|
Assembles a numbered context block from search results, then streams
|
||||||
|
completion tokens from an OpenAI-compatible API. Yields SSE-formatted
|
||||||
|
events: sources, token, done, and error.
|
||||||
|
|
||||||
|
Multi-turn memory: When a conversation_id is provided, prior messages are
|
||||||
|
loaded from Redis, injected into the LLM messages array, and the new
|
||||||
|
user+assistant turn is appended after streaming completes. History is
|
||||||
|
capped at 10 turn pairs (20 messages) and expires after 1 hour of
|
||||||
|
inactivity.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
import uuid
|
||||||
|
from typing import Any, AsyncIterator
|
||||||
|
|
||||||
|
import openai
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from config import Settings
|
||||||
|
from models import Creator
|
||||||
|
from search_service import SearchService
|
||||||
|
|
||||||
|
logger = logging.getLogger("chrysopedia.chat")
|
||||||
|
|
||||||
|
_SYSTEM_PROMPT_TEMPLATE = """\
|
||||||
|
You are Chrysopedia, an expert assistant for music production techniques — \
|
||||||
|
synthesis, sound design, mixing, sampling, and audio processing.
|
||||||
|
|
||||||
|
## Rules
|
||||||
|
- Use ONLY the numbered sources below. Do not invent facts.
|
||||||
|
- Cite every factual claim inline with [N] immediately after the claim \
|
||||||
|
(e.g. "Parallel compression adds sustain [2] while preserving transients [1].").
|
||||||
|
- When sources disagree, present both perspectives with their citations.
|
||||||
|
- If the sources lack enough information, say so honestly.
|
||||||
|
|
||||||
|
## Response format
|
||||||
|
- Aim for 2–4 short paragraphs. Expand only when the question warrants detail.
|
||||||
|
- Use bullet lists for steps, signal chains, or parameter lists.
|
||||||
|
- **Bold** key terms on first mention.
|
||||||
|
- Use audio/synthesis/mixing terminology naturally — do not over-explain \
|
||||||
|
standard concepts (e.g. LFO, sidechain, wet/dry) unless the user asks.
|
||||||
|
|
||||||
|
Sources:
|
||||||
|
{context_block}
|
||||||
|
"""
|
||||||
|
|
||||||
|
_MAX_CONTEXT_SOURCES = 10
|
||||||
|
_MAX_TURN_PAIRS = 10
|
||||||
|
_HISTORY_TTL_SECONDS = 3600 # 1 hour
|
||||||
|
|
||||||
|
|
||||||
|
def _redis_key(conversation_id: str) -> str:
|
||||||
|
return f"chrysopedia:chat:{conversation_id}"
|
||||||
|
|
||||||
|
|
||||||
|
class ChatService:
|
||||||
|
"""Retrieve context from search, stream an LLM response with citations."""
|
||||||
|
|
||||||
|
def __init__(self, settings: Settings, redis=None) -> None:
|
||||||
|
self.settings = settings
|
||||||
|
self._search = SearchService(settings)
|
||||||
|
self._openai = openai.AsyncOpenAI(
|
||||||
|
base_url=settings.llm_api_url,
|
||||||
|
api_key=settings.llm_api_key,
|
||||||
|
)
|
||||||
|
self._fallback_openai = openai.AsyncOpenAI(
|
||||||
|
base_url=settings.llm_fallback_url,
|
||||||
|
api_key=settings.llm_api_key,
|
||||||
|
)
|
||||||
|
self._redis = redis
|
||||||
|
|
||||||
|
async def _load_history(self, conversation_id: str) -> list[dict[str, str]]:
|
||||||
|
"""Load conversation history from Redis. Returns empty list on miss."""
|
||||||
|
if not self._redis:
|
||||||
|
return []
|
||||||
|
try:
|
||||||
|
raw = await self._redis.get(_redis_key(conversation_id))
|
||||||
|
if raw:
|
||||||
|
return json.loads(raw)
|
||||||
|
except Exception:
|
||||||
|
logger.warning("chat_history_load_error cid=%s", conversation_id, exc_info=True)
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def _save_history(
|
||||||
|
self,
|
||||||
|
conversation_id: str,
|
||||||
|
history: list[dict[str, str]],
|
||||||
|
user_msg: str,
|
||||||
|
assistant_msg: str,
|
||||||
|
) -> None:
|
||||||
|
"""Append the new turn pair and persist to Redis with TTL refresh."""
|
||||||
|
if not self._redis:
|
||||||
|
return
|
||||||
|
history.append({"role": "user", "content": user_msg})
|
||||||
|
history.append({"role": "assistant", "content": assistant_msg})
|
||||||
|
# Cap at _MAX_TURN_PAIRS (keep most recent)
|
||||||
|
if len(history) > _MAX_TURN_PAIRS * 2:
|
||||||
|
history = history[-_MAX_TURN_PAIRS * 2:]
|
||||||
|
try:
|
||||||
|
await self._redis.set(
|
||||||
|
_redis_key(conversation_id),
|
||||||
|
json.dumps(history),
|
||||||
|
ex=_HISTORY_TTL_SECONDS,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.warning("chat_history_save_error cid=%s", conversation_id, exc_info=True)
|
||||||
|
|
||||||
|
async def _inject_personality(
|
||||||
|
self,
|
||||||
|
system_prompt: str,
|
||||||
|
db: AsyncSession,
|
||||||
|
creator_name: str,
|
||||||
|
weight: float,
|
||||||
|
) -> str:
|
||||||
|
"""Query creator personality_profile and append a voice block to the system prompt.
|
||||||
|
|
||||||
|
Falls back to the unmodified prompt on DB error, missing creator, or null profile.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await db.execute(
|
||||||
|
select(Creator).where(Creator.name == creator_name)
|
||||||
|
)
|
||||||
|
creator_row = result.scalars().first()
|
||||||
|
except Exception:
|
||||||
|
logger.warning("chat_personality_db_error creator=%r", creator_name, exc_info=True)
|
||||||
|
return system_prompt
|
||||||
|
|
||||||
|
if creator_row is None or creator_row.personality_profile is None:
|
||||||
|
logger.debug("chat_personality_skip creator=%r reason=%s",
|
||||||
|
creator_name,
|
||||||
|
"not_found" if creator_row is None else "null_profile")
|
||||||
|
return system_prompt
|
||||||
|
|
||||||
|
profile = creator_row.personality_profile
|
||||||
|
voice_block = _build_personality_block(creator_name, profile, weight)
|
||||||
|
return system_prompt + "\n\n" + voice_block
|
||||||
|
|
||||||
|
async def _log_usage(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: Any | None,
|
||||||
|
client_ip: str | None,
|
||||||
|
creator_slug: str | None,
|
||||||
|
query: str,
|
||||||
|
usage: dict[str, int],
|
||||||
|
cascade_tier: str,
|
||||||
|
model: str,
|
||||||
|
latency_ms: float,
|
||||||
|
) -> None:
|
||||||
|
"""Insert a ChatUsageLog row. Non-blocking — errors logged, not raised."""
|
||||||
|
try:
|
||||||
|
from models import ChatUsageLog
|
||||||
|
|
||||||
|
log_entry = ChatUsageLog(
|
||||||
|
user_id=user_id,
|
||||||
|
client_ip=client_ip,
|
||||||
|
creator_slug=creator_slug,
|
||||||
|
query=query[:2000], # truncate very long queries
|
||||||
|
prompt_tokens=usage.get("prompt_tokens", 0),
|
||||||
|
completion_tokens=usage.get("completion_tokens", 0),
|
||||||
|
total_tokens=usage.get("total_tokens", 0),
|
||||||
|
cascade_tier=cascade_tier,
|
||||||
|
model=model,
|
||||||
|
latency_ms=latency_ms,
|
||||||
|
)
|
||||||
|
db.add(log_entry)
|
||||||
|
await db.commit()
|
||||||
|
except Exception:
|
||||||
|
logger.error(
|
||||||
|
"chat_usage_log_insert_error user=%s ip=%s",
|
||||||
|
user_id, client_ip, exc_info=True,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await db.rollback()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def stream_response(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
db: AsyncSession,
|
||||||
|
creator: str | None = None,
|
||||||
|
conversation_id: str | None = None,
|
||||||
|
personality_weight: float = 0.0,
|
||||||
|
user_id: Any | None = None,
|
||||||
|
client_ip: str | None = None,
|
||||||
|
) -> AsyncIterator[str]:
|
||||||
|
"""Yield SSE-formatted events for a chat query.
|
||||||
|
|
||||||
|
Protocol:
|
||||||
|
1. ``event: sources\ndata: <json array of citation metadata>\n\n``
|
||||||
|
2. ``event: token\ndata: <text chunk>\n\n`` (repeated)
|
||||||
|
3. ``event: done\ndata: <json with cascade_tier, conversation_id>\n\n``
|
||||||
|
On error: ``event: error\ndata: <json with message>\n\n``
|
||||||
|
"""
|
||||||
|
start = time.monotonic()
|
||||||
|
|
||||||
|
# Assign conversation_id if not provided (single-turn becomes trackable)
|
||||||
|
if conversation_id is None:
|
||||||
|
conversation_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# ── 0. Load conversation history ────────────────────────────────
|
||||||
|
history = await self._load_history(conversation_id)
|
||||||
|
|
||||||
|
# ── 1. Retrieve context via search ──────────────────────────────
|
||||||
|
try:
|
||||||
|
search_result = await self._search.search(
|
||||||
|
query=query,
|
||||||
|
scope="all",
|
||||||
|
limit=_MAX_CONTEXT_SOURCES,
|
||||||
|
db=db,
|
||||||
|
creator=creator,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("chat_search_error query=%r creator=%r", query, creator)
|
||||||
|
yield _sse("error", {"message": "Search failed"})
|
||||||
|
return
|
||||||
|
|
||||||
|
items: list[dict[str, Any]] = search_result.get("items", [])
|
||||||
|
cascade_tier: str = search_result.get("cascade_tier", "")
|
||||||
|
|
||||||
|
# ── 2. Build citation metadata and context block ────────────────
|
||||||
|
sources = _build_sources(items)
|
||||||
|
context_block = _build_context_block(items)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"chat_search query=%r creator=%r cascade_tier=%s source_count=%d cid=%s",
|
||||||
|
query, creator, cascade_tier, len(sources), conversation_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Emit sources event first
|
||||||
|
yield _sse("sources", sources)
|
||||||
|
|
||||||
|
# ── 3. Stream LLM completion ────────────────────────────────────
|
||||||
|
system_prompt = _SYSTEM_PROMPT_TEMPLATE.format(context_block=context_block)
|
||||||
|
|
||||||
|
# Inject creator personality voice when weight > 0
|
||||||
|
if personality_weight > 0 and creator:
|
||||||
|
system_prompt = await self._inject_personality(
|
||||||
|
system_prompt, db, creator, personality_weight,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Scale temperature with personality weight: 0.3 (encyclopedic) → 0.5 (full personality)
|
||||||
|
temperature = 0.3 + (personality_weight * 0.2)
|
||||||
|
|
||||||
|
messages: list[dict[str, str]] = [
|
||||||
|
{"role": "system", "content": system_prompt},
|
||||||
|
]
|
||||||
|
# Inject conversation history between system prompt and current query
|
||||||
|
messages.extend(history)
|
||||||
|
messages.append({"role": "user", "content": query})
|
||||||
|
|
||||||
|
accumulated_response = ""
|
||||||
|
usage_data: dict[str, int] | None = None
|
||||||
|
fallback_used = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
stream = await self._openai.chat.completions.create(
|
||||||
|
model=self.settings.llm_model,
|
||||||
|
messages=messages,
|
||||||
|
stream=True,
|
||||||
|
stream_options={"include_usage": True},
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=2048,
|
||||||
|
)
|
||||||
|
|
||||||
|
async for chunk in stream:
|
||||||
|
# The final chunk with stream_options carries usage in chunk.usage
|
||||||
|
if hasattr(chunk, "usage") and chunk.usage is not None:
|
||||||
|
usage_data = {
|
||||||
|
"prompt_tokens": chunk.usage.prompt_tokens or 0,
|
||||||
|
"completion_tokens": chunk.usage.completion_tokens or 0,
|
||||||
|
"total_tokens": chunk.usage.total_tokens or 0,
|
||||||
|
}
|
||||||
|
choice = chunk.choices[0] if chunk.choices else None
|
||||||
|
if choice and choice.delta and choice.delta.content:
|
||||||
|
text = choice.delta.content
|
||||||
|
accumulated_response += text
|
||||||
|
yield _sse("token", text)
|
||||||
|
|
||||||
|
except (openai.APIConnectionError, openai.APITimeoutError, openai.InternalServerError) as exc:
|
||||||
|
logger.warning(
|
||||||
|
"chat_llm_fallback primary failed (%s: %s), retrying with fallback at %s",
|
||||||
|
type(exc).__name__, exc, self.settings.llm_fallback_url,
|
||||||
|
)
|
||||||
|
fallback_used = True
|
||||||
|
accumulated_response = ""
|
||||||
|
usage_data = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
stream = await self._fallback_openai.chat.completions.create(
|
||||||
|
model=self.settings.llm_fallback_model,
|
||||||
|
messages=messages,
|
||||||
|
stream=True,
|
||||||
|
stream_options={"include_usage": True},
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=2048,
|
||||||
|
)
|
||||||
|
|
||||||
|
async for chunk in stream:
|
||||||
|
if hasattr(chunk, "usage") and chunk.usage is not None:
|
||||||
|
usage_data = {
|
||||||
|
"prompt_tokens": chunk.usage.prompt_tokens or 0,
|
||||||
|
"completion_tokens": chunk.usage.completion_tokens or 0,
|
||||||
|
"total_tokens": chunk.usage.total_tokens or 0,
|
||||||
|
}
|
||||||
|
choice = chunk.choices[0] if chunk.choices else None
|
||||||
|
if choice and choice.delta and choice.delta.content:
|
||||||
|
text = choice.delta.content
|
||||||
|
accumulated_response += text
|
||||||
|
yield _sse("token", text)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
tb = traceback.format_exc()
|
||||||
|
logger.error("chat_llm_error fallback also failed query=%r cid=%s\n%s", query, conversation_id, tb)
|
||||||
|
yield _sse("error", {"message": "LLM generation failed"})
|
||||||
|
return
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
tb = traceback.format_exc()
|
||||||
|
logger.error("chat_llm_error query=%r cid=%s\n%s", query, conversation_id, tb)
|
||||||
|
yield _sse("error", {"message": "LLM generation failed"})
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── 4. Save conversation history ────────────────────────────────
|
||||||
|
await self._save_history(conversation_id, history, query, accumulated_response)
|
||||||
|
|
||||||
|
# ── 5. Log token usage ──────────────────────────────────────────
|
||||||
|
latency_ms = (time.monotonic() - start) * 1000
|
||||||
|
|
||||||
|
# Fallback: estimate tokens from character counts if stream_options not available
|
||||||
|
if usage_data is None:
|
||||||
|
prompt_chars = sum(len(m.get("content", "")) for m in messages)
|
||||||
|
est_prompt = prompt_chars // 4
|
||||||
|
est_completion = len(accumulated_response) // 4
|
||||||
|
usage_data = {
|
||||||
|
"prompt_tokens": est_prompt,
|
||||||
|
"completion_tokens": est_completion,
|
||||||
|
"total_tokens": est_prompt + est_completion,
|
||||||
|
}
|
||||||
|
logger.warning("chat_usage_estimated cid=%s (stream_options usage not available)", conversation_id)
|
||||||
|
|
||||||
|
await self._log_usage(
|
||||||
|
db=db,
|
||||||
|
user_id=user_id,
|
||||||
|
client_ip=client_ip,
|
||||||
|
creator_slug=creator,
|
||||||
|
query=query,
|
||||||
|
usage=usage_data,
|
||||||
|
cascade_tier=cascade_tier,
|
||||||
|
model=self.settings.llm_fallback_model if fallback_used else self.settings.llm_model,
|
||||||
|
latency_ms=latency_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── 6. Done event ───────────────────────────────────────────────
|
||||||
|
logger.info(
|
||||||
|
"chat_done query=%r creator=%r cascade_tier=%s source_count=%d latency_ms=%.1f cid=%s tokens=%d",
|
||||||
|
query, creator, cascade_tier, len(sources), latency_ms, conversation_id,
|
||||||
|
usage_data.get("total_tokens", 0),
|
||||||
|
)
|
||||||
|
yield _sse("done", {"cascade_tier": cascade_tier, "conversation_id": conversation_id, "fallback_used": fallback_used})
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _sse(event: str, data: Any) -> str:
|
||||||
|
"""Format a single SSE event string."""
|
||||||
|
payload = json.dumps(data) if not isinstance(data, str) else data
|
||||||
|
return f"event: {event}\ndata: {payload}\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
def _build_sources(items: list[dict[str, Any]]) -> list[dict[str, str]]:
|
||||||
|
"""Build a numbered citation metadata list from search result items."""
|
||||||
|
sources: list[dict[str, str]] = []
|
||||||
|
for idx, item in enumerate(items, start=1):
|
||||||
|
sources.append({
|
||||||
|
"number": idx,
|
||||||
|
"title": item.get("title", ""),
|
||||||
|
"slug": item.get("technique_page_slug", "") or item.get("slug", ""),
|
||||||
|
"creator_name": item.get("creator_name", ""),
|
||||||
|
"topic_category": item.get("topic_category", ""),
|
||||||
|
"summary": (item.get("summary", "") or "")[:200],
|
||||||
|
"section_anchor": item.get("section_anchor", ""),
|
||||||
|
"section_heading": item.get("section_heading", ""),
|
||||||
|
"source_video_id": item.get("source_video_id", ""),
|
||||||
|
"start_time": item.get("start_time"),
|
||||||
|
"end_time": item.get("end_time"),
|
||||||
|
"video_filename": item.get("video_filename", ""),
|
||||||
|
})
|
||||||
|
return sources
|
||||||
|
|
||||||
|
|
||||||
|
def _build_context_block(items: list[dict[str, Any]]) -> str:
|
||||||
|
"""Build a numbered context block string for the LLM system prompt."""
|
||||||
|
if not items:
|
||||||
|
return "(No sources available)"
|
||||||
|
|
||||||
|
lines: list[str] = []
|
||||||
|
for idx, item in enumerate(items, start=1):
|
||||||
|
title = item.get("title", "Untitled")
|
||||||
|
creator = item.get("creator_name", "")
|
||||||
|
summary = item.get("summary", "")
|
||||||
|
section = item.get("section_heading", "")
|
||||||
|
|
||||||
|
parts = [f"[{idx}] {title}"]
|
||||||
|
if creator:
|
||||||
|
parts.append(f"by {creator}")
|
||||||
|
if section:
|
||||||
|
parts.append(f"— {section}")
|
||||||
|
header = " ".join(parts)
|
||||||
|
|
||||||
|
lines.append(header)
|
||||||
|
if summary:
|
||||||
|
lines.append(f" {summary}")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_personality_block(creator_name: str, profile: dict[str, Any], weight: float) -> str:
|
||||||
|
"""Build a personality voice injection block from a creator's personality_profile JSONB.
|
||||||
|
|
||||||
|
The ``weight`` (0.0–1.0) controls progressive inclusion of personality
|
||||||
|
fields via 5 tiers of continuous interpolation:
|
||||||
|
|
||||||
|
- < 0.2: no personality block (empty string)
|
||||||
|
- 0.2–0.39: basic tone — teaching_style, formality, energy + subtle hint
|
||||||
|
- 0.4–0.59: + descriptors, explanation_approach + adopt-voice instruction
|
||||||
|
- 0.6–0.79: + signature_phrases (count scaled by weight) + creator-voice
|
||||||
|
- 0.8–0.89: + distinctive_terms, sound_descriptions, sound_words,
|
||||||
|
self_references, pacing + fully-embody instruction
|
||||||
|
- >= 0.9: + full summary paragraph
|
||||||
|
"""
|
||||||
|
if weight < 0.2:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
vocab = profile.get("vocabulary", {})
|
||||||
|
tone = profile.get("tone", {})
|
||||||
|
style = profile.get("style_markers", {})
|
||||||
|
|
||||||
|
teaching_style = tone.get("teaching_style", "")
|
||||||
|
energy = tone.get("energy", "moderate")
|
||||||
|
formality = tone.get("formality", "conversational")
|
||||||
|
descriptors = tone.get("descriptors", [])
|
||||||
|
phrases = vocab.get("signature_phrases", [])
|
||||||
|
|
||||||
|
parts: list[str] = []
|
||||||
|
|
||||||
|
# --- Tier 1 (0.2–0.39): basic tone ---
|
||||||
|
if weight < 0.4:
|
||||||
|
parts.append(
|
||||||
|
f"When relevant, subtly reference {creator_name}'s communication style."
|
||||||
|
)
|
||||||
|
elif weight < 0.6:
|
||||||
|
parts.append(f"Adopt {creator_name}'s tone and communication style.")
|
||||||
|
elif weight < 0.8:
|
||||||
|
parts.append(
|
||||||
|
f"Respond as {creator_name} would, using their voice and manner."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
parts.append(
|
||||||
|
f"Fully embody {creator_name} — use their exact phrases, energy, and teaching approach."
|
||||||
|
)
|
||||||
|
|
||||||
|
if teaching_style:
|
||||||
|
parts.append(f"Teaching style: {teaching_style}.")
|
||||||
|
parts.append(f"Match their {formality} {energy} tone.")
|
||||||
|
|
||||||
|
# --- Tier 2 (0.4+): descriptors, explanation_approach, uses_analogies, audience_engagement ---
|
||||||
|
if weight >= 0.4:
|
||||||
|
if descriptors:
|
||||||
|
parts.append(f"Tone: {', '.join(descriptors[:5])}.")
|
||||||
|
explanation = style.get("explanation_approach", "")
|
||||||
|
if explanation:
|
||||||
|
parts.append(f"Explanation approach: {explanation}.")
|
||||||
|
if style.get("uses_analogies"):
|
||||||
|
parts.append("Use analogies when helpful.")
|
||||||
|
if style.get("audience_engagement"):
|
||||||
|
parts.append(f"Audience engagement: {style['audience_engagement']}.")
|
||||||
|
|
||||||
|
# --- Tier 3 (0.6+): signature phrases (count scaled by weight) ---
|
||||||
|
if weight >= 0.6 and phrases:
|
||||||
|
count = max(2, round(weight * len(phrases)))
|
||||||
|
parts.append(f"Use their signature phrases: {', '.join(phrases[:count])}.")
|
||||||
|
|
||||||
|
# --- Tier 4 (0.8+): distinctive_terms, sound_descriptions, sound_words, self_references, pacing ---
|
||||||
|
if weight >= 0.8:
|
||||||
|
distinctive = vocab.get("distinctive_terms", [])
|
||||||
|
if distinctive:
|
||||||
|
parts.append(f"Distinctive terms: {', '.join(distinctive)}.")
|
||||||
|
sound_desc = vocab.get("sound_descriptions", [])
|
||||||
|
if sound_desc:
|
||||||
|
parts.append(f"Sound descriptions: {', '.join(sound_desc)}.")
|
||||||
|
sound_words = style.get("sound_words", [])
|
||||||
|
if sound_words:
|
||||||
|
parts.append(f"Sound words: {', '.join(sound_words)}.")
|
||||||
|
self_refs = style.get("self_references", "")
|
||||||
|
if self_refs:
|
||||||
|
parts.append(f"Self-references: {self_refs}.")
|
||||||
|
pacing = style.get("pacing", "")
|
||||||
|
if pacing:
|
||||||
|
parts.append(f"Pacing: {pacing}.")
|
||||||
|
|
||||||
|
# --- Tier 5 (0.9+): full summary paragraph ---
|
||||||
|
if weight >= 0.9:
|
||||||
|
summary = profile.get("summary", "")
|
||||||
|
if summary:
|
||||||
|
parts.append(summary)
|
||||||
|
|
||||||
|
return " ".join(parts)
|
||||||
115
backend/config.py
Normal file
115
backend/config.py
Normal file
|
|
@ -0,0 +1,115 @@
|
||||||
|
"""Application configuration loaded from environment variables."""
|
||||||
|
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
"""Chrysopedia API settings.
|
||||||
|
|
||||||
|
Values are loaded from environment variables (or .env file via
|
||||||
|
pydantic-settings' dotenv support).
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Database
|
||||||
|
database_url: str = "postgresql+asyncpg://chrysopedia:changeme@localhost:5433/chrysopedia"
|
||||||
|
|
||||||
|
# Redis
|
||||||
|
redis_url: str = "redis://localhost:6379/0"
|
||||||
|
|
||||||
|
# Application
|
||||||
|
app_env: str = "development"
|
||||||
|
app_log_level: str = "info"
|
||||||
|
app_secret_key: str = "changeme-generate-a-real-secret"
|
||||||
|
|
||||||
|
# CORS
|
||||||
|
cors_origins: list[str] = ["*"]
|
||||||
|
|
||||||
|
# LLM endpoint (OpenAI-compatible)
|
||||||
|
llm_api_url: str = "http://localhost:11434/v1"
|
||||||
|
llm_api_key: str = "sk-placeholder"
|
||||||
|
llm_model: str = "fyn-llm-agent-chat"
|
||||||
|
llm_fallback_url: str = "http://localhost:11434/v1"
|
||||||
|
llm_fallback_model: str = "qwen2.5:7b"
|
||||||
|
|
||||||
|
# Per-stage model overrides (optional — falls back to llm_model / "chat")
|
||||||
|
llm_stage2_model: str | None = "fyn-llm-agent-chat" # segmentation — mechanical, fast chat
|
||||||
|
llm_stage2_modality: str = "chat"
|
||||||
|
llm_stage3_model: str | None = "fyn-llm-agent-think" # extraction — reasoning
|
||||||
|
llm_stage3_modality: str = "thinking"
|
||||||
|
llm_stage4_model: str | None = "fyn-llm-agent-chat" # classification — mechanical, fast chat
|
||||||
|
llm_stage4_modality: str = "chat"
|
||||||
|
llm_stage5_model: str | None = "fyn-llm-agent-think" # synthesis — reasoning
|
||||||
|
llm_stage5_modality: str = "thinking"
|
||||||
|
|
||||||
|
# Token limits — static across all stages
|
||||||
|
llm_max_tokens_hard_limit: int = 96000 # Hard ceiling for dynamic estimator
|
||||||
|
llm_max_tokens: int = 96000 # Fallback when no estimate is provided (must not exceed hard_limit)
|
||||||
|
llm_temperature: float = 0.0 # Deterministic output for structured JSON extraction
|
||||||
|
|
||||||
|
# Stage 5 synthesis chunking — max moments per LLM call before splitting
|
||||||
|
synthesis_chunk_size: int = 30
|
||||||
|
|
||||||
|
# Embedding endpoint
|
||||||
|
embedding_api_url: str = "http://localhost:11434/v1"
|
||||||
|
embedding_model: str = "nomic-embed-text"
|
||||||
|
embedding_dimensions: int = 768
|
||||||
|
|
||||||
|
# Qdrant
|
||||||
|
qdrant_url: str = "http://localhost:6333"
|
||||||
|
qdrant_collection: str = "chrysopedia"
|
||||||
|
|
||||||
|
# LightRAG
|
||||||
|
lightrag_url: str = "http://chrysopedia-lightrag:9621"
|
||||||
|
lightrag_search_timeout: float = 2.0
|
||||||
|
lightrag_min_query_length: int = 3
|
||||||
|
|
||||||
|
# Prompt templates
|
||||||
|
prompts_path: str = "./prompts"
|
||||||
|
|
||||||
|
# Debug mode — when True, pipeline captures full LLM prompts and responses
|
||||||
|
debug_mode: bool = False
|
||||||
|
|
||||||
|
# MinIO (file storage for post attachments)
|
||||||
|
minio_url: str = "chrysopedia-minio:9000"
|
||||||
|
minio_access_key: str = "chrysopedia"
|
||||||
|
minio_secret_key: str = "changeme-minio"
|
||||||
|
minio_bucket: str = "chrysopedia"
|
||||||
|
minio_secure: bool = False
|
||||||
|
|
||||||
|
# File storage
|
||||||
|
transcript_storage_path: str = "/data/transcripts"
|
||||||
|
video_metadata_path: str = "/data/video_meta"
|
||||||
|
video_source_path: str = "/videos"
|
||||||
|
|
||||||
|
# SMTP (email digests)
|
||||||
|
smtp_host: str = ""
|
||||||
|
smtp_port: int = 587
|
||||||
|
smtp_user: str = ""
|
||||||
|
smtp_password: str = ""
|
||||||
|
smtp_from_address: str = ""
|
||||||
|
smtp_tls: bool = True
|
||||||
|
|
||||||
|
# Public base URL for links in emails and external references
|
||||||
|
base_url: str = "http://localhost:8096"
|
||||||
|
|
||||||
|
# Rate limiting (per hour)
|
||||||
|
rate_limit_user_per_hour: int = 30
|
||||||
|
rate_limit_ip_per_hour: int = 10
|
||||||
|
rate_limit_creator_per_hour: int = 60
|
||||||
|
|
||||||
|
# Git commit SHA (set at Docker build time or via env var)
|
||||||
|
git_commit_sha: str = "unknown"
|
||||||
|
|
||||||
|
model_config = {
|
||||||
|
"env_file": ".env",
|
||||||
|
"env_file_encoding": "utf-8",
|
||||||
|
"case_sensitive": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def get_settings() -> Settings:
|
||||||
|
"""Return cached application settings (singleton)."""
|
||||||
|
return Settings()
|
||||||
26
backend/database.py
Normal file
26
backend/database.py
Normal file
|
|
@ -0,0 +1,26 @@
|
||||||
|
"""Database engine, session factory, and declarative base for Chrysopedia."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
DATABASE_URL = os.getenv(
|
||||||
|
"DATABASE_URL",
|
||||||
|
"postgresql+asyncpg://chrysopedia:changeme@localhost:5433/chrysopedia",
|
||||||
|
)
|
||||||
|
|
||||||
|
engine = create_async_engine(DATABASE_URL, echo=False, pool_pre_ping=True)
|
||||||
|
|
||||||
|
async_session = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||||
|
|
||||||
|
|
||||||
|
class Base(DeclarativeBase):
|
||||||
|
"""Declarative base for all ORM models."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def get_session() -> AsyncSession: # type: ignore[misc]
|
||||||
|
"""FastAPI dependency that yields an async DB session."""
|
||||||
|
async with async_session() as session:
|
||||||
|
yield session
|
||||||
117
backend/main.py
Normal file
117
backend/main.py
Normal file
|
|
@ -0,0 +1,117 @@
|
||||||
|
"""Chrysopedia API — Knowledge extraction and retrieval system.
|
||||||
|
|
||||||
|
Entry point for the FastAPI application. Configures middleware,
|
||||||
|
structured logging, and mounts versioned API routers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from config import get_settings
|
||||||
|
from routers import admin, auth, chat, consent, creator_chapters, creator_dashboard, creator_highlights, creators, files, follows, health, highlights, ingest, notifications, pipeline, posts, reports, search, shorts, shorts_public, stats, techniques, topics, videos
|
||||||
|
|
||||||
|
|
||||||
|
def _setup_logging() -> None:
|
||||||
|
"""Configure structured logging to stdout."""
|
||||||
|
settings = get_settings()
|
||||||
|
level = getattr(logging, settings.app_log_level.upper(), logging.INFO)
|
||||||
|
|
||||||
|
handler = logging.StreamHandler(sys.stdout)
|
||||||
|
handler.setFormatter(
|
||||||
|
logging.Formatter(
|
||||||
|
fmt="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
|
||||||
|
datefmt="%Y-%m-%dT%H:%M:%S",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
root = logging.getLogger()
|
||||||
|
root.setLevel(level)
|
||||||
|
# Avoid duplicate handlers on reload
|
||||||
|
root.handlers.clear()
|
||||||
|
root.addHandler(handler)
|
||||||
|
|
||||||
|
# Quiet noisy libraries
|
||||||
|
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI): # noqa: ARG001
|
||||||
|
"""Application lifespan: setup on startup, teardown on shutdown."""
|
||||||
|
_setup_logging()
|
||||||
|
logger = logging.getLogger("chrysopedia")
|
||||||
|
settings = get_settings()
|
||||||
|
logger.info(
|
||||||
|
"Chrysopedia API starting (env=%s, log_level=%s)",
|
||||||
|
settings.app_env,
|
||||||
|
settings.app_log_level,
|
||||||
|
)
|
||||||
|
# Ensure MinIO bucket exists (best-effort — API still starts if MinIO is down)
|
||||||
|
try:
|
||||||
|
from minio_client import ensure_bucket
|
||||||
|
ensure_bucket()
|
||||||
|
logger.info("MinIO bucket ready")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("MinIO bucket init failed (will retry on first upload): %s", exc)
|
||||||
|
yield
|
||||||
|
logger.info("Chrysopedia API shutting down")
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="Chrysopedia API",
|
||||||
|
description="Knowledge extraction and retrieval for music production content",
|
||||||
|
version="0.1.0",
|
||||||
|
lifespan=lifespan,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Middleware ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
settings = get_settings()
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=settings.cors_origins,
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Routers ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Root-level health (no prefix)
|
||||||
|
app.include_router(health.router)
|
||||||
|
|
||||||
|
# Versioned API
|
||||||
|
app.include_router(admin.router, prefix="/api/v1")
|
||||||
|
app.include_router(auth.router, prefix="/api/v1")
|
||||||
|
app.include_router(chat.router, prefix="/api/v1")
|
||||||
|
app.include_router(consent.router, prefix="/api/v1")
|
||||||
|
app.include_router(creator_dashboard.router, prefix="/api/v1")
|
||||||
|
app.include_router(creator_chapters.router, prefix="/api/v1")
|
||||||
|
app.include_router(creator_highlights.router, prefix="/api/v1")
|
||||||
|
app.include_router(creators.router, prefix="/api/v1")
|
||||||
|
app.include_router(creators.admin_router, prefix="/api/v1")
|
||||||
|
app.include_router(follows.router, prefix="/api/v1")
|
||||||
|
app.include_router(highlights.router, prefix="/api/v1")
|
||||||
|
app.include_router(ingest.router, prefix="/api/v1")
|
||||||
|
app.include_router(notifications.router, prefix="/api/v1")
|
||||||
|
app.include_router(pipeline.router, prefix="/api/v1")
|
||||||
|
app.include_router(posts.router, prefix="/api/v1")
|
||||||
|
app.include_router(files.router, prefix="/api/v1")
|
||||||
|
app.include_router(reports.router, prefix="/api/v1")
|
||||||
|
app.include_router(search.router, prefix="/api/v1")
|
||||||
|
app.include_router(shorts.router, prefix="/api/v1")
|
||||||
|
app.include_router(shorts_public.router, prefix="/api/v1")
|
||||||
|
app.include_router(stats.router, prefix="/api/v1")
|
||||||
|
app.include_router(techniques.router, prefix="/api/v1")
|
||||||
|
app.include_router(topics.router, prefix="/api/v1")
|
||||||
|
app.include_router(videos.router, prefix="/api/v1")
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/api/v1/health")
|
||||||
|
async def api_health():
|
||||||
|
"""Lightweight version-prefixed health endpoint (no DB check)."""
|
||||||
|
return {"status": "ok", "version": "0.1.0"}
|
||||||
116
backend/minio_client.py
Normal file
116
backend/minio_client.py
Normal file
|
|
@ -0,0 +1,116 @@
|
||||||
|
"""MinIO client singleton with lazy initialization.
|
||||||
|
|
||||||
|
Provides file upload, presigned download URL generation, and automatic
|
||||||
|
bucket creation for the Chrysopedia post attachment storage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
from minio import Minio
|
||||||
|
from minio.error import S3Error
|
||||||
|
|
||||||
|
from config import get_settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_client: Minio | None = None
|
||||||
|
_bucket_ensured: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
def get_minio_client() -> Minio:
|
||||||
|
"""Return the singleton MinIO client, creating it on first call."""
|
||||||
|
global _client
|
||||||
|
if _client is None:
|
||||||
|
settings = get_settings()
|
||||||
|
_client = Minio(
|
||||||
|
settings.minio_url,
|
||||||
|
access_key=settings.minio_access_key,
|
||||||
|
secret_key=settings.minio_secret_key,
|
||||||
|
secure=settings.minio_secure,
|
||||||
|
)
|
||||||
|
logger.info("MinIO client initialized (endpoint=%s)", settings.minio_url)
|
||||||
|
return _client
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_bucket() -> None:
|
||||||
|
"""Create the configured bucket if it doesn't already exist."""
|
||||||
|
global _bucket_ensured
|
||||||
|
if _bucket_ensured:
|
||||||
|
return
|
||||||
|
settings = get_settings()
|
||||||
|
client = get_minio_client()
|
||||||
|
bucket = settings.minio_bucket
|
||||||
|
try:
|
||||||
|
if not client.bucket_exists(bucket):
|
||||||
|
client.make_bucket(bucket)
|
||||||
|
logger.info("Created MinIO bucket: %s", bucket)
|
||||||
|
else:
|
||||||
|
logger.debug("MinIO bucket already exists: %s", bucket)
|
||||||
|
_bucket_ensured = True
|
||||||
|
except S3Error as exc:
|
||||||
|
logger.error("MinIO bucket check/create failed: %s", exc)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def upload_file(
|
||||||
|
object_key: str,
|
||||||
|
data: bytes | io.BytesIO,
|
||||||
|
length: int,
|
||||||
|
content_type: str = "application/octet-stream",
|
||||||
|
) -> None:
|
||||||
|
"""Upload a file to MinIO.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
object_key: The storage path within the bucket.
|
||||||
|
data: File content as bytes or BytesIO stream.
|
||||||
|
length: Size in bytes.
|
||||||
|
content_type: MIME type for the object.
|
||||||
|
"""
|
||||||
|
ensure_bucket()
|
||||||
|
settings = get_settings()
|
||||||
|
client = get_minio_client()
|
||||||
|
stream = io.BytesIO(data) if isinstance(data, bytes) else data
|
||||||
|
client.put_object(
|
||||||
|
settings.minio_bucket,
|
||||||
|
object_key,
|
||||||
|
stream,
|
||||||
|
length,
|
||||||
|
content_type=content_type,
|
||||||
|
)
|
||||||
|
logger.info("Uploaded %s (%d bytes, %s)", object_key, length, content_type)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_download_url(object_key: str, expires: int = 3600) -> str:
|
||||||
|
"""Generate a presigned GET URL for downloading a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
object_key: The storage path within the bucket.
|
||||||
|
expires: URL validity in seconds (default 1 hour).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Presigned URL string.
|
||||||
|
"""
|
||||||
|
settings = get_settings()
|
||||||
|
client = get_minio_client()
|
||||||
|
url: str = client.presigned_get_object(
|
||||||
|
settings.minio_bucket,
|
||||||
|
object_key,
|
||||||
|
expires=timedelta(seconds=expires),
|
||||||
|
)
|
||||||
|
return url
|
||||||
|
|
||||||
|
|
||||||
|
def delete_file(object_key: str) -> None:
|
||||||
|
"""Delete a file from MinIO.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
object_key: The storage path within the bucket.
|
||||||
|
"""
|
||||||
|
settings = get_settings()
|
||||||
|
client = get_minio_client()
|
||||||
|
client.remove_object(settings.minio_bucket, object_key)
|
||||||
|
logger.info("Deleted %s from MinIO", object_key)
|
||||||
932
backend/models.py
Normal file
932
backend/models.py
Normal file
|
|
@ -0,0 +1,932 @@
|
||||||
|
"""SQLAlchemy ORM models for the Chrysopedia knowledge base.
|
||||||
|
|
||||||
|
Seven entities matching chrysopedia-spec.md §6.1:
|
||||||
|
Creator, SourceVideo, TranscriptSegment, KeyMoment,
|
||||||
|
TechniquePage, RelatedTechniqueLink, Tag
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import enum
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from sqlalchemy import (
|
||||||
|
BigInteger,
|
||||||
|
Boolean,
|
||||||
|
Enum,
|
||||||
|
Float,
|
||||||
|
ForeignKey,
|
||||||
|
Index,
|
||||||
|
Integer,
|
||||||
|
String,
|
||||||
|
Text,
|
||||||
|
UniqueConstraint,
|
||||||
|
func,
|
||||||
|
text,
|
||||||
|
)
|
||||||
|
from sqlalchemy.dialects.postgresql import ARRAY, JSONB, UUID
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
from sqlalchemy.orm import relationship as sa_relationship
|
||||||
|
|
||||||
|
from database import Base
|
||||||
|
|
||||||
|
|
||||||
|
# ── Enums ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class ContentType(str, enum.Enum):
|
||||||
|
"""Source video content type."""
|
||||||
|
tutorial = "tutorial"
|
||||||
|
livestream = "livestream"
|
||||||
|
breakdown = "breakdown"
|
||||||
|
short_form = "short_form"
|
||||||
|
|
||||||
|
|
||||||
|
class ProcessingStatus(str, enum.Enum):
|
||||||
|
"""Pipeline processing status for a source video.
|
||||||
|
|
||||||
|
User-facing lifecycle: not_started → queued → processing → complete
|
||||||
|
Error branch: processing → error (retrigger resets to queued)
|
||||||
|
"""
|
||||||
|
not_started = "not_started"
|
||||||
|
queued = "queued"
|
||||||
|
processing = "processing"
|
||||||
|
error = "error"
|
||||||
|
complete = "complete"
|
||||||
|
|
||||||
|
|
||||||
|
class KeyMomentContentType(str, enum.Enum):
|
||||||
|
"""Content classification for a key moment."""
|
||||||
|
technique = "technique"
|
||||||
|
settings = "settings"
|
||||||
|
reasoning = "reasoning"
|
||||||
|
workflow = "workflow"
|
||||||
|
|
||||||
|
|
||||||
|
class SourceQuality(str, enum.Enum):
|
||||||
|
"""Derived source quality for technique pages."""
|
||||||
|
structured = "structured"
|
||||||
|
mixed = "mixed"
|
||||||
|
unstructured = "unstructured"
|
||||||
|
|
||||||
|
|
||||||
|
class RelationshipType(str, enum.Enum):
|
||||||
|
"""Types of links between technique pages."""
|
||||||
|
same_technique_other_creator = "same_technique_other_creator"
|
||||||
|
same_creator_adjacent = "same_creator_adjacent"
|
||||||
|
general_cross_reference = "general_cross_reference"
|
||||||
|
|
||||||
|
|
||||||
|
class UserRole(str, enum.Enum):
|
||||||
|
"""Roles for authenticated users."""
|
||||||
|
creator = "creator"
|
||||||
|
admin = "admin"
|
||||||
|
|
||||||
|
|
||||||
|
class HighlightStatus(str, enum.Enum):
|
||||||
|
"""Triage status for highlight candidates."""
|
||||||
|
candidate = "candidate"
|
||||||
|
approved = "approved"
|
||||||
|
rejected = "rejected"
|
||||||
|
|
||||||
|
|
||||||
|
class ChapterStatus(str, enum.Enum):
|
||||||
|
"""Review status for auto-detected chapters."""
|
||||||
|
draft = "draft"
|
||||||
|
approved = "approved"
|
||||||
|
hidden = "hidden"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _uuid_pk() -> Mapped[uuid.UUID]:
|
||||||
|
return mapped_column(
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
primary_key=True,
|
||||||
|
default=uuid.uuid4,
|
||||||
|
server_default=func.gen_random_uuid(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _now() -> datetime:
|
||||||
|
"""Return current UTC time as a naive datetime (no tzinfo).
|
||||||
|
|
||||||
|
PostgreSQL TIMESTAMP WITHOUT TIME ZONE columns require naive datetimes.
|
||||||
|
asyncpg rejects timezone-aware datetimes for such columns.
|
||||||
|
"""
|
||||||
|
return datetime.now(timezone.utc).replace(tzinfo=None)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Models ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class Creator(Base):
|
||||||
|
__tablename__ = "creators"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = _uuid_pk()
|
||||||
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
slug: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
|
||||||
|
genres: Mapped[list[str] | None] = mapped_column(ARRAY(String), nullable=True)
|
||||||
|
folder_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
avatar_url: Mapped[str | None] = mapped_column(String(1000), nullable=True)
|
||||||
|
avatar_source: Mapped[str | None] = mapped_column(String(50), nullable=True)
|
||||||
|
avatar_fetched_at: Mapped[datetime | None] = mapped_column(nullable=True)
|
||||||
|
bio: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
social_links: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
|
||||||
|
personality_profile: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
|
||||||
|
shorts_template: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
|
||||||
|
featured: Mapped[bool] = mapped_column(default=False, server_default="false")
|
||||||
|
view_count: Mapped[int] = mapped_column(Integer, default=0, server_default="0")
|
||||||
|
hidden: Mapped[bool] = mapped_column(default=False, server_default="false")
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
default=_now, server_default=func.now()
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
default=_now, server_default=func.now(), onupdate=_now
|
||||||
|
)
|
||||||
|
|
||||||
|
# relationships
|
||||||
|
videos: Mapped[list[SourceVideo]] = sa_relationship(back_populates="creator")
|
||||||
|
technique_pages: Mapped[list[TechniquePage]] = sa_relationship(back_populates="creator")
|
||||||
|
posts: Mapped[list[Post]] = sa_relationship(back_populates="creator")
|
||||||
|
|
||||||
|
|
||||||
|
class User(Base):
|
||||||
|
"""Authenticated user account for the creator dashboard."""
|
||||||
|
__tablename__ = "users"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = _uuid_pk()
|
||||||
|
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
|
||||||
|
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
display_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
role: Mapped[UserRole] = mapped_column(
|
||||||
|
Enum(UserRole, name="user_role", create_constraint=True),
|
||||||
|
default=UserRole.creator,
|
||||||
|
server_default="creator",
|
||||||
|
)
|
||||||
|
creator_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||||
|
ForeignKey("creators.id", ondelete="SET NULL"), nullable=True
|
||||||
|
)
|
||||||
|
is_active: Mapped[bool] = mapped_column(
|
||||||
|
Boolean, default=True, server_default="true"
|
||||||
|
)
|
||||||
|
onboarding_completed: Mapped[bool] = mapped_column(
|
||||||
|
Boolean, default=False, server_default="false"
|
||||||
|
)
|
||||||
|
notification_preferences: Mapped[dict] = mapped_column(
|
||||||
|
JSONB, nullable=False,
|
||||||
|
server_default='{"email_digests": true, "digest_frequency": "daily"}',
|
||||||
|
)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
default=_now, server_default=func.now()
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
default=_now, server_default=func.now(), onupdate=_now
|
||||||
|
)
|
||||||
|
|
||||||
|
# relationships
|
||||||
|
creator: Mapped[Creator | None] = sa_relationship()
|
||||||
|
|
||||||
|
|
||||||
|
class EmailDigestLog(Base):
|
||||||
|
"""Record of a digest email sent to a user."""
|
||||||
|
__tablename__ = "email_digest_log"
|
||||||
|
__table_args__ = (
|
||||||
|
Index("ix_email_digest_log_user_sent", "user_id", "digest_sent_at"),
|
||||||
|
)
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = _uuid_pk()
|
||||||
|
user_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
ForeignKey("users.id", ondelete="CASCADE"), nullable=False,
|
||||||
|
)
|
||||||
|
digest_sent_at: Mapped[datetime] = mapped_column(
|
||||||
|
default=_now, server_default=func.now()
|
||||||
|
)
|
||||||
|
content_summary: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
|
||||||
|
|
||||||
|
# relationships
|
||||||
|
user: Mapped[User] = sa_relationship()
|
||||||
|
|
||||||
|
|
||||||
|
class InviteCode(Base):
|
||||||
|
"""Single-use or limited-use invite codes for registration gating."""
|
||||||
|
__tablename__ = "invite_codes"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = _uuid_pk()
|
||||||
|
code: Mapped[str] = mapped_column(String(100), unique=True, nullable=False)
|
||||||
|
uses_remaining: Mapped[int] = mapped_column(Integer, default=1, server_default="1")
|
||||||
|
created_by: Mapped[uuid.UUID | None] = mapped_column(
|
||||||
|
ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||||
|
)
|
||||||
|
expires_at: Mapped[datetime | None] = mapped_column(nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
default=_now, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SourceVideo(Base):
|
||||||
|
__tablename__ = "source_videos"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = _uuid_pk()
|
||||||
|
creator_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
ForeignKey("creators.id", ondelete="CASCADE"), nullable=False
|
||||||
|
)
|
||||||
|
filename: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||||
|
file_path: Mapped[str] = mapped_column(String(1000), nullable=False)
|
||||||
|
duration_seconds: Mapped[int] = mapped_column(Integer, nullable=True)
|
||||||
|
content_type: Mapped[ContentType] = mapped_column(
|
||||||
|
Enum(ContentType, name="content_type", create_constraint=True),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
transcript_path: Mapped[str | None] = mapped_column(String(1000), nullable=True)
|
||||||
|
content_hash: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True)
|
||||||
|
processing_status: Mapped[ProcessingStatus] = mapped_column(
|
||||||
|
Enum(ProcessingStatus, name="processing_status", create_constraint=True),
|
||||||
|
default=ProcessingStatus.not_started,
|
||||||
|
server_default="not_started",
|
||||||
|
)
|
||||||
|
classification_data: Mapped[list | None] = mapped_column(JSONB, nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
default=_now, server_default=func.now()
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
default=_now, server_default=func.now(), onupdate=_now
|
||||||
|
)
|
||||||
|
|
||||||
|
# relationships
|
||||||
|
creator: Mapped[Creator] = sa_relationship(back_populates="videos")
|
||||||
|
segments: Mapped[list[TranscriptSegment]] = sa_relationship(back_populates="source_video")
|
||||||
|
key_moments: Mapped[list[KeyMoment]] = sa_relationship(back_populates="source_video")
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptSegment(Base):
|
||||||
|
__tablename__ = "transcript_segments"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = _uuid_pk()
|
||||||
|
source_video_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
ForeignKey("source_videos.id", ondelete="CASCADE"), nullable=False
|
||||||
|
)
|
||||||
|
start_time: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
|
end_time: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
|
text: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
segment_index: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
topic_label: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
|
|
||||||
|
# relationships
|
||||||
|
source_video: Mapped[SourceVideo] = sa_relationship(back_populates="segments")
|
||||||
|
|
||||||
|
|
||||||
|
class KeyMoment(Base):
|
||||||
|
__tablename__ = "key_moments"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = _uuid_pk()
|
||||||
|
source_video_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
ForeignKey("source_videos.id", ondelete="CASCADE"), nullable=False
|
||||||
|
)
|
||||||
|
technique_page_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||||
|
ForeignKey("technique_pages.id", ondelete="SET NULL"), nullable=True
|
||||||
|
)
|
||||||
|
title: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||||
|
summary: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
start_time: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
|
end_time: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
|
content_type: Mapped[KeyMomentContentType] = mapped_column(
|
||||||
|
Enum(KeyMomentContentType, name="key_moment_content_type", create_constraint=True),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
plugins: Mapped[list[str] | None] = mapped_column(ARRAY(String), nullable=True)
|
||||||
|
raw_transcript: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
chapter_status: Mapped[ChapterStatus] = mapped_column(
|
||||||
|
Enum(ChapterStatus, name="chapter_status", create_constraint=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default="draft",
|
||||||
|
default=ChapterStatus.draft,
|
||||||
|
)
|
||||||
|
sort_order: Mapped[int] = mapped_column(Integer, nullable=False, server_default="0", default=0)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
default=_now, server_default=func.now()
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
default=_now, server_default=func.now(), onupdate=_now
|
||||||
|
)
|
||||||
|
|
||||||
|
# relationships
|
||||||
|
source_video: Mapped[SourceVideo] = sa_relationship(back_populates="key_moments")
|
||||||
|
technique_page: Mapped[TechniquePage | None] = sa_relationship(
|
||||||
|
back_populates="key_moments", foreign_keys=[technique_page_id]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TechniquePage(Base):
|
||||||
|
__tablename__ = "technique_pages"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = _uuid_pk()
|
||||||
|
creator_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
ForeignKey("creators.id", ondelete="CASCADE"), nullable=False
|
||||||
|
)
|
||||||
|
title: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||||
|
slug: Mapped[str] = mapped_column(String(500), unique=True, nullable=False)
|
||||||
|
topic_category: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
topic_tags: Mapped[list[str] | None] = mapped_column(ARRAY(String), nullable=True)
|
||||||
|
summary: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
body_sections: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
|
||||||
|
body_sections_format: Mapped[str] = mapped_column(
|
||||||
|
String(20), nullable=False, default="v1", server_default="v1"
|
||||||
|
)
|
||||||
|
signal_chains: Mapped[list | None] = mapped_column(JSONB, nullable=True)
|
||||||
|
plugins: Mapped[list[str] | None] = mapped_column(ARRAY(String), nullable=True)
|
||||||
|
source_quality: Mapped[SourceQuality | None] = mapped_column(
|
||||||
|
Enum(SourceQuality, name="source_quality", create_constraint=True),
|
||||||
|
nullable=True,
|
||||||
|
)
|
||||||
|
view_count: Mapped[int] = mapped_column(Integer, default=0, server_default="0")
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
default=_now, server_default=func.now()
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
default=_now, server_default=func.now(), onupdate=_now
|
||||||
|
)
|
||||||
|
|
||||||
|
# relationships
|
||||||
|
creator: Mapped[Creator] = sa_relationship(back_populates="technique_pages")
|
||||||
|
key_moments: Mapped[list[KeyMoment]] = sa_relationship(
|
||||||
|
back_populates="technique_page", foreign_keys=[KeyMoment.technique_page_id]
|
||||||
|
)
|
||||||
|
versions: Mapped[list[TechniquePageVersion]] = sa_relationship(
|
||||||
|
back_populates="technique_page", order_by="TechniquePageVersion.version_number"
|
||||||
|
)
|
||||||
|
outgoing_links: Mapped[list[RelatedTechniqueLink]] = sa_relationship(
|
||||||
|
foreign_keys="RelatedTechniqueLink.source_page_id", back_populates="source_page"
|
||||||
|
)
|
||||||
|
incoming_links: Mapped[list[RelatedTechniqueLink]] = sa_relationship(
|
||||||
|
foreign_keys="RelatedTechniqueLink.target_page_id", back_populates="target_page"
|
||||||
|
)
|
||||||
|
source_video_links: Mapped[list[TechniquePageVideo]] = sa_relationship(
|
||||||
|
back_populates="technique_page"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RelatedTechniqueLink(Base):
|
||||||
|
__tablename__ = "related_technique_links"
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint("source_page_id", "target_page_id", "relationship", name="uq_technique_link"),
|
||||||
|
)
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = _uuid_pk()
|
||||||
|
source_page_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
ForeignKey("technique_pages.id", ondelete="CASCADE"), nullable=False
|
||||||
|
)
|
||||||
|
target_page_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
ForeignKey("technique_pages.id", ondelete="CASCADE"), nullable=False
|
||||||
|
)
|
||||||
|
relationship: Mapped[RelationshipType] = mapped_column(
|
||||||
|
Enum(RelationshipType, name="relationship_type", create_constraint=True),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# relationships
|
||||||
|
source_page: Mapped[TechniquePage] = sa_relationship(
|
||||||
|
foreign_keys=[source_page_id], back_populates="outgoing_links"
|
||||||
|
)
|
||||||
|
target_page: Mapped[TechniquePage] = sa_relationship(
|
||||||
|
foreign_keys=[target_page_id], back_populates="incoming_links"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TechniquePageVersion(Base):
|
||||||
|
"""Snapshot of a TechniquePage before a pipeline re-synthesis overwrites it."""
|
||||||
|
__tablename__ = "technique_page_versions"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = _uuid_pk()
|
||||||
|
technique_page_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
ForeignKey("technique_pages.id", ondelete="CASCADE"), nullable=False
|
||||||
|
)
|
||||||
|
version_number: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
content_snapshot: Mapped[dict] = mapped_column(JSONB, nullable=False)
|
||||||
|
pipeline_metadata: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
default=_now, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
# relationships
|
||||||
|
technique_page: Mapped[TechniquePage] = sa_relationship(
|
||||||
|
back_populates="versions"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Tag(Base):
|
||||||
|
__tablename__ = "tags"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = _uuid_pk()
|
||||||
|
name: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
|
||||||
|
category: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
aliases: Mapped[list[str] | None] = mapped_column(ARRAY(String), nullable=True)
|
||||||
|
|
||||||
|
|
||||||
|
class TechniquePageVideo(Base):
|
||||||
|
"""Association linking a technique page to its contributing source videos."""
|
||||||
|
__tablename__ = "technique_page_videos"
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint("technique_page_id", "source_video_id", name="uq_page_video"),
|
||||||
|
)
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = _uuid_pk()
|
||||||
|
technique_page_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
ForeignKey("technique_pages.id", ondelete="CASCADE"), nullable=False
|
||||||
|
)
|
||||||
|
source_video_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
ForeignKey("source_videos.id", ondelete="CASCADE"), nullable=False
|
||||||
|
)
|
||||||
|
added_at: Mapped[datetime] = mapped_column(
|
||||||
|
default=_now, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
# relationships
|
||||||
|
technique_page: Mapped[TechniquePage] = sa_relationship(
|
||||||
|
back_populates="source_video_links"
|
||||||
|
)
|
||||||
|
source_video: Mapped[SourceVideo] = sa_relationship()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Content Report Enums ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class ReportType(str, enum.Enum):
|
||||||
|
"""Classification of user-submitted content reports."""
|
||||||
|
inaccurate = "inaccurate"
|
||||||
|
missing_info = "missing_info"
|
||||||
|
wrong_attribution = "wrong_attribution"
|
||||||
|
formatting = "formatting"
|
||||||
|
other = "other"
|
||||||
|
|
||||||
|
|
||||||
|
class ReportStatus(str, enum.Enum):
|
||||||
|
"""Triage status for content reports."""
|
||||||
|
open = "open"
|
||||||
|
acknowledged = "acknowledged"
|
||||||
|
resolved = "resolved"
|
||||||
|
dismissed = "dismissed"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Content Report ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class ContentReport(Base):
|
||||||
|
"""User-submitted report about a content issue.
|
||||||
|
|
||||||
|
Generic: content_type + content_id can reference any entity
|
||||||
|
(technique_page, key_moment, creator, or general).
|
||||||
|
"""
|
||||||
|
__tablename__ = "content_reports"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = _uuid_pk()
|
||||||
|
content_type: Mapped[str] = mapped_column(
|
||||||
|
String(50), nullable=False, doc="Entity type: technique_page, key_moment, creator, general"
|
||||||
|
)
|
||||||
|
content_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||||
|
UUID(as_uuid=True), nullable=True, doc="FK to the reported entity (null for general reports)"
|
||||||
|
)
|
||||||
|
content_title: Mapped[str | None] = mapped_column(
|
||||||
|
String(500), nullable=True, doc="Snapshot of entity title at report time"
|
||||||
|
)
|
||||||
|
report_type: Mapped[ReportType] = mapped_column(
|
||||||
|
Enum(ReportType, name="report_type", create_constraint=True),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
description: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
status: Mapped[ReportStatus] = mapped_column(
|
||||||
|
Enum(ReportStatus, name="report_status", create_constraint=True),
|
||||||
|
default=ReportStatus.open,
|
||||||
|
server_default="open",
|
||||||
|
)
|
||||||
|
admin_notes: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
page_url: Mapped[str | None] = mapped_column(
|
||||||
|
String(1000), nullable=True, doc="URL the user was on when reporting"
|
||||||
|
)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
default=_now, server_default=func.now()
|
||||||
|
)
|
||||||
|
resolved_at: Mapped[datetime | None] = mapped_column(nullable=True)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Pipeline Event ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class SearchLog(Base):
|
||||||
|
"""Logged search query for analytics and popular searches."""
|
||||||
|
__tablename__ = "search_log"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
query: Mapped[str] = mapped_column(String(500), nullable=False, index=True)
|
||||||
|
scope: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||||
|
result_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
default=_now, server_default=func.now(), index=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineRunStatus(str, enum.Enum):
|
||||||
|
"""Status of a pipeline run."""
|
||||||
|
running = "running"
|
||||||
|
complete = "complete"
|
||||||
|
error = "error"
|
||||||
|
cancelled = "cancelled"
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineRunTrigger(str, enum.Enum):
|
||||||
|
"""What initiated a pipeline run."""
|
||||||
|
manual = "manual"
|
||||||
|
clean_reprocess = "clean_reprocess"
|
||||||
|
auto_ingest = "auto_ingest"
|
||||||
|
bulk = "bulk"
|
||||||
|
stage_rerun = "stage_rerun"
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineRun(Base):
|
||||||
|
"""A single execution of the pipeline for a video.
|
||||||
|
|
||||||
|
Each trigger/retrigger creates a new run. Events are scoped to a run
|
||||||
|
via run_id, giving a clean audit trail per execution.
|
||||||
|
"""
|
||||||
|
__tablename__ = "pipeline_runs"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = _uuid_pk()
|
||||||
|
video_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
ForeignKey("source_videos.id", ondelete="CASCADE"), nullable=False, index=True,
|
||||||
|
)
|
||||||
|
run_number: Mapped[int] = mapped_column(
|
||||||
|
Integer, nullable=False, doc="Auto-increment per video, 1-indexed"
|
||||||
|
)
|
||||||
|
trigger: Mapped[PipelineRunTrigger] = mapped_column(
|
||||||
|
Enum(PipelineRunTrigger, name="pipeline_run_trigger", create_constraint=True),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
status: Mapped[PipelineRunStatus] = mapped_column(
|
||||||
|
Enum(PipelineRunStatus, name="pipeline_run_status", create_constraint=True),
|
||||||
|
default=PipelineRunStatus.running,
|
||||||
|
server_default="running",
|
||||||
|
)
|
||||||
|
started_at: Mapped[datetime] = mapped_column(
|
||||||
|
default=_now, server_default=func.now()
|
||||||
|
)
|
||||||
|
finished_at: Mapped[datetime | None] = mapped_column(nullable=True)
|
||||||
|
error_stage: Mapped[str | None] = mapped_column(String(50), nullable=True)
|
||||||
|
total_tokens: Mapped[int] = mapped_column(Integer, default=0, server_default="0")
|
||||||
|
|
||||||
|
# relationships
|
||||||
|
video: Mapped[SourceVideo] = sa_relationship()
|
||||||
|
events: Mapped[list[PipelineEvent]] = sa_relationship(
|
||||||
|
back_populates="run", foreign_keys="PipelineEvent.run_id"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Pipeline Event ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class PipelineEvent(Base):
|
||||||
|
"""Structured log entry for pipeline execution.
|
||||||
|
|
||||||
|
Captures per-stage start/complete/error/llm_call events with
|
||||||
|
token usage and optional response payloads for debugging.
|
||||||
|
"""
|
||||||
|
__tablename__ = "pipeline_events"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = _uuid_pk()
|
||||||
|
video_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
UUID(as_uuid=True), nullable=False, index=True,
|
||||||
|
)
|
||||||
|
run_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||||
|
ForeignKey("pipeline_runs.id", ondelete="SET NULL"), nullable=True, index=True,
|
||||||
|
)
|
||||||
|
stage: Mapped[str] = mapped_column(
|
||||||
|
String(50), nullable=False, doc="stage2_segmentation, stage3_extraction, etc."
|
||||||
|
)
|
||||||
|
event_type: Mapped[str] = mapped_column(
|
||||||
|
String(30), nullable=False, doc="start, complete, error, llm_call"
|
||||||
|
)
|
||||||
|
prompt_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||||
|
completion_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||||
|
total_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||||
|
model: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
|
duration_ms: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||||
|
payload: Mapped[dict | None] = mapped_column(
|
||||||
|
JSONB, nullable=True, doc="LLM response content, error details, stage metadata"
|
||||||
|
)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
default=_now, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Debug mode — full LLM I/O capture columns
|
||||||
|
system_prompt_text: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
user_prompt_text: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
response_text: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
|
||||||
|
# relationships
|
||||||
|
run: Mapped[PipelineRun | None] = sa_relationship(
|
||||||
|
back_populates="events", foreign_keys=[run_id]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Consent Enums ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class ConsentField(str, enum.Enum):
|
||||||
|
"""Fields that can be individually consented to per video."""
|
||||||
|
kb_inclusion = "kb_inclusion"
|
||||||
|
training_usage = "training_usage"
|
||||||
|
public_display = "public_display"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Consent Models ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class VideoConsent(Base):
|
||||||
|
"""Current consent state for a source video.
|
||||||
|
|
||||||
|
One row per video. Mutable — updated when a creator toggles consent.
|
||||||
|
The full change history lives in ConsentAuditLog.
|
||||||
|
"""
|
||||||
|
__tablename__ = "video_consents"
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint("source_video_id", name="uq_video_consent_video"),
|
||||||
|
)
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = _uuid_pk()
|
||||||
|
source_video_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
ForeignKey("source_videos.id", ondelete="CASCADE"), nullable=False,
|
||||||
|
)
|
||||||
|
creator_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
ForeignKey("creators.id", ondelete="CASCADE"), nullable=False,
|
||||||
|
)
|
||||||
|
kb_inclusion: Mapped[bool] = mapped_column(
|
||||||
|
Boolean, default=False, server_default="false",
|
||||||
|
)
|
||||||
|
training_usage: Mapped[bool] = mapped_column(
|
||||||
|
Boolean, default=False, server_default="false",
|
||||||
|
)
|
||||||
|
public_display: Mapped[bool] = mapped_column(
|
||||||
|
Boolean, default=True, server_default="true",
|
||||||
|
)
|
||||||
|
updated_by: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
ForeignKey("users.id", ondelete="RESTRICT"), nullable=False,
|
||||||
|
)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
default=_now, server_default=func.now()
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
default=_now, server_default=func.now(), onupdate=_now
|
||||||
|
)
|
||||||
|
|
||||||
|
# relationships
|
||||||
|
source_video: Mapped[SourceVideo] = sa_relationship()
|
||||||
|
creator: Mapped[Creator] = sa_relationship()
|
||||||
|
audit_entries: Mapped[list[ConsentAuditLog]] = sa_relationship(
|
||||||
|
back_populates="video_consent", order_by="ConsentAuditLog.version"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ConsentAuditLog(Base):
|
||||||
|
"""Append-only versioned record of per-field consent changes.
|
||||||
|
|
||||||
|
Each row captures a single field change. Version is auto-assigned
|
||||||
|
in application code (max(version) + 1 per video_consent_id).
|
||||||
|
"""
|
||||||
|
__tablename__ = "consent_audit_log"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = _uuid_pk()
|
||||||
|
video_consent_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
ForeignKey("video_consents.id", ondelete="CASCADE"), nullable=False, index=True,
|
||||||
|
)
|
||||||
|
version: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
field_name: Mapped[str] = mapped_column(
|
||||||
|
String(50), nullable=False, doc="ConsentField value: kb_inclusion, training_usage, public_display"
|
||||||
|
)
|
||||||
|
old_value: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||||
|
new_value: Mapped[bool] = mapped_column(Boolean, nullable=False)
|
||||||
|
changed_by: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
ForeignKey("users.id", ondelete="RESTRICT"), nullable=False,
|
||||||
|
)
|
||||||
|
ip_address: Mapped[str | None] = mapped_column(String(45), nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
default=_now, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
# relationships
|
||||||
|
video_consent: Mapped[VideoConsent] = sa_relationship(
|
||||||
|
back_populates="audit_entries"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ImpersonationLog(Base):
|
||||||
|
"""Audit trail for admin impersonation sessions."""
|
||||||
|
__tablename__ = "impersonation_log"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = _uuid_pk()
|
||||||
|
admin_user_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True,
|
||||||
|
)
|
||||||
|
target_user_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True,
|
||||||
|
)
|
||||||
|
action: Mapped[str] = mapped_column(
|
||||||
|
String(10), nullable=False, doc="'start' or 'stop'"
|
||||||
|
)
|
||||||
|
write_mode: Mapped[bool] = mapped_column(
|
||||||
|
default=False, server_default=text("false"),
|
||||||
|
)
|
||||||
|
ip_address: Mapped[str | None] = mapped_column(String(45), nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
default=_now, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Highlight Detection ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class HighlightCandidate(Base):
|
||||||
|
"""Scored candidate for highlight detection, one per KeyMoment."""
|
||||||
|
__tablename__ = "highlight_candidates"
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint("key_moment_id", name="uq_highlight_candidate_moment"),
|
||||||
|
)
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = _uuid_pk()
|
||||||
|
key_moment_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
ForeignKey("key_moments.id", ondelete="CASCADE"), nullable=False, unique=True,
|
||||||
|
)
|
||||||
|
source_video_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
ForeignKey("source_videos.id", ondelete="CASCADE"), nullable=False, index=True,
|
||||||
|
)
|
||||||
|
score: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
|
score_breakdown: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
|
||||||
|
duration_secs: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
|
status: Mapped[HighlightStatus] = mapped_column(
|
||||||
|
Enum(HighlightStatus, name="highlight_status", create_constraint=True),
|
||||||
|
default=HighlightStatus.candidate,
|
||||||
|
server_default="candidate",
|
||||||
|
)
|
||||||
|
trim_start: Mapped[float | None] = mapped_column(Float, nullable=True, default=None)
|
||||||
|
trim_end: Mapped[float | None] = mapped_column(Float, nullable=True, default=None)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
default=_now, server_default=func.now()
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
default=_now, server_default=func.now(), onupdate=_now
|
||||||
|
)
|
||||||
|
|
||||||
|
# relationships
|
||||||
|
key_moment: Mapped[KeyMoment] = sa_relationship()
|
||||||
|
source_video: Mapped[SourceVideo] = sa_relationship()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Follow System ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class CreatorFollow(Base):
|
||||||
|
"""A user following a creator."""
|
||||||
|
__tablename__ = "creator_follows"
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint("user_id", "creator_id", name="uq_creator_follow_user_creator"),
|
||||||
|
)
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = _uuid_pk()
|
||||||
|
user_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True,
|
||||||
|
)
|
||||||
|
creator_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
ForeignKey("creators.id", ondelete="CASCADE"), nullable=False, index=True,
|
||||||
|
)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
default=_now, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
# relationships
|
||||||
|
user: Mapped[User] = sa_relationship()
|
||||||
|
creator: Mapped[Creator] = sa_relationship()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Posts (Creator content feed) ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
class Post(Base):
|
||||||
|
"""A rich text post by a creator, optionally with file attachments."""
|
||||||
|
__tablename__ = "posts"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = _uuid_pk()
|
||||||
|
creator_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
ForeignKey("creators.id", ondelete="CASCADE"), nullable=False, index=True,
|
||||||
|
)
|
||||||
|
title: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||||
|
body_json: Mapped[dict] = mapped_column(JSONB, nullable=False)
|
||||||
|
is_published: Mapped[bool] = mapped_column(
|
||||||
|
Boolean, default=False, server_default="false",
|
||||||
|
)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
default=_now, server_default=func.now()
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
default=_now, server_default=func.now(), onupdate=_now
|
||||||
|
)
|
||||||
|
|
||||||
|
# relationships
|
||||||
|
creator: Mapped[Creator] = sa_relationship(back_populates="posts")
|
||||||
|
attachments: Mapped[list[PostAttachment]] = sa_relationship(
|
||||||
|
back_populates="post", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PostAttachment(Base):
|
||||||
|
"""A file attachment on a post, stored in MinIO."""
|
||||||
|
__tablename__ = "post_attachments"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = _uuid_pk()
|
||||||
|
post_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
ForeignKey("posts.id", ondelete="CASCADE"), nullable=False, index=True,
|
||||||
|
)
|
||||||
|
filename: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||||
|
object_key: Mapped[str] = mapped_column(String(1000), nullable=False)
|
||||||
|
content_type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
default=_now, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
# relationships
|
||||||
|
post: Mapped[Post] = sa_relationship(back_populates="attachments")
|
||||||
|
|
||||||
|
|
||||||
|
# ── Shorts Generation ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class FormatPreset(str, enum.Enum):
|
||||||
|
"""Output format presets for generated shorts."""
|
||||||
|
vertical = "vertical" # 9:16 (1080x1920)
|
||||||
|
square = "square" # 1:1 (1080x1080)
|
||||||
|
horizontal = "horizontal" # 16:9 (1920x1080)
|
||||||
|
|
||||||
|
|
||||||
|
class ShortStatus(str, enum.Enum):
|
||||||
|
"""Processing status for a generated short."""
|
||||||
|
pending = "pending"
|
||||||
|
processing = "processing"
|
||||||
|
complete = "complete"
|
||||||
|
failed = "failed"
|
||||||
|
|
||||||
|
|
||||||
|
class GeneratedShort(Base):
|
||||||
|
"""A video short generated from a highlight candidate in a specific format."""
|
||||||
|
__tablename__ = "generated_shorts"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = _uuid_pk()
|
||||||
|
highlight_candidate_id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
ForeignKey("highlight_candidates.id", ondelete="CASCADE"),
|
||||||
|
nullable=False, index=True,
|
||||||
|
)
|
||||||
|
format_preset: Mapped[FormatPreset] = mapped_column(
|
||||||
|
Enum(FormatPreset, name="format_preset", create_constraint=True),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
minio_object_key: Mapped[str | None] = mapped_column(String(1000), nullable=True)
|
||||||
|
duration_secs: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||||
|
width: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
height: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
file_size_bytes: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
|
||||||
|
status: Mapped[ShortStatus] = mapped_column(
|
||||||
|
Enum(ShortStatus, name="short_status", create_constraint=True),
|
||||||
|
default=ShortStatus.pending,
|
||||||
|
server_default="pending",
|
||||||
|
)
|
||||||
|
error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
share_token: Mapped[str | None] = mapped_column(
|
||||||
|
String(16), nullable=True, unique=True, index=True,
|
||||||
|
)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
default=_now, server_default=func.now()
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
default=_now, server_default=func.now(), onupdate=_now
|
||||||
|
)
|
||||||
|
|
||||||
|
captions_enabled: Mapped[bool] = mapped_column(
|
||||||
|
Boolean, default=False, server_default=text("'false'"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# relationships
|
||||||
|
highlight_candidate: Mapped[HighlightCandidate] = sa_relationship()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Chat Usage Tracking ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class ChatUsageLog(Base):
|
||||||
|
"""Per-request token usage log for chat completions.
|
||||||
|
|
||||||
|
Append-only table — one row per chat request. Used for cost tracking,
|
||||||
|
rate limit analytics, and the admin usage dashboard.
|
||||||
|
"""
|
||||||
|
__tablename__ = "chat_usage_log"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = _uuid_pk()
|
||||||
|
user_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||||
|
ForeignKey("users.id", ondelete="SET NULL"), nullable=True,
|
||||||
|
)
|
||||||
|
client_ip: Mapped[str | None] = mapped_column(String(45), nullable=True)
|
||||||
|
creator_slug: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
|
query: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
prompt_tokens: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
completion_tokens: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
total_tokens: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
cascade_tier: Mapped[str | None] = mapped_column(String(50), nullable=True)
|
||||||
|
model: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
|
latency_ms: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
default=_now, server_default=func.now(), index=True,
|
||||||
|
)
|
||||||
0
backend/pipeline/__init__.py
Normal file
0
backend/pipeline/__init__.py
Normal file
155
backend/pipeline/caption_generator.py
Normal file
155
backend/pipeline/caption_generator.py
Normal file
|
|
@ -0,0 +1,155 @@
|
||||||
|
r"""ASS (Advanced SubStation Alpha) caption generator for shorts.
|
||||||
|
|
||||||
|
Converts word-level timings from Whisper transcripts into ASS subtitle
|
||||||
|
files with word-by-word karaoke highlighting. Each word gets its own
|
||||||
|
Dialogue line with {\k} tags that control highlight duration.
|
||||||
|
|
||||||
|
Pure functions — no DB access, no Celery dependency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Default style configuration ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
DEFAULT_STYLE: dict[str, Any] = {
|
||||||
|
"font_name": "Arial",
|
||||||
|
"font_size": 48,
|
||||||
|
"primary_colour": "&H00FFFFFF", # white (BGR + alpha)
|
||||||
|
"secondary_colour": "&H0000FFFF", # yellow highlight
|
||||||
|
"outline_colour": "&H00000000", # black outline
|
||||||
|
"back_colour": "&H80000000", # semi-transparent black shadow
|
||||||
|
"bold": -1, # bold
|
||||||
|
"outline": 3,
|
||||||
|
"shadow": 1,
|
||||||
|
"alignment": 2, # bottom-center
|
||||||
|
"margin_v": 60, # 60px from bottom (~15% on 1920h)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _format_ass_time(seconds: float) -> str:
|
||||||
|
"""Convert seconds to ASS timestamp format: H:MM:SS.cc (centiseconds).
|
||||||
|
|
||||||
|
>>> _format_ass_time(65.5)
|
||||||
|
'0:01:05.50'
|
||||||
|
>>> _format_ass_time(0.0)
|
||||||
|
'0:00:00.00'
|
||||||
|
"""
|
||||||
|
if seconds < 0:
|
||||||
|
seconds = 0.0
|
||||||
|
h = int(seconds // 3600)
|
||||||
|
m = int((seconds % 3600) // 60)
|
||||||
|
s = seconds % 60
|
||||||
|
return f"{h}:{m:02d}:{s:05.2f}"
|
||||||
|
|
||||||
|
|
||||||
|
def _build_ass_header(style_config: dict[str, Any]) -> str:
|
||||||
|
"""Build ASS file header with script info and style definition."""
|
||||||
|
cfg = {**DEFAULT_STYLE, **(style_config or {})}
|
||||||
|
|
||||||
|
header = (
|
||||||
|
"[Script Info]\n"
|
||||||
|
"Title: Chrysopedia Auto-Captions\n"
|
||||||
|
"ScriptType: v4.00+\n"
|
||||||
|
"PlayResX: 1080\n"
|
||||||
|
"PlayResY: 1920\n"
|
||||||
|
"WrapStyle: 0\n"
|
||||||
|
"ScaledBorderAndShadow: yes\n"
|
||||||
|
"\n"
|
||||||
|
"[V4+ Styles]\n"
|
||||||
|
"Format: Name, Fontname, Fontsize, PrimaryColour, SecondaryColour, "
|
||||||
|
"OutlineColour, BackColour, Bold, Italic, Underline, StrikeOut, "
|
||||||
|
"ScaleX, ScaleY, Spacing, Angle, BorderStyle, Outline, Shadow, "
|
||||||
|
"Alignment, MarginL, MarginR, MarginV, Encoding\n"
|
||||||
|
f"Style: Default,{cfg['font_name']},{cfg['font_size']},"
|
||||||
|
f"{cfg['primary_colour']},{cfg['secondary_colour']},"
|
||||||
|
f"{cfg['outline_colour']},{cfg['back_colour']},"
|
||||||
|
f"{cfg['bold']},0,0,0,"
|
||||||
|
f"100,100,0,0,1,{cfg['outline']},{cfg['shadow']},"
|
||||||
|
f"{cfg['alignment']},20,20,{cfg['margin_v']},1\n"
|
||||||
|
"\n"
|
||||||
|
"[Events]\n"
|
||||||
|
"Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text\n"
|
||||||
|
)
|
||||||
|
return header
|
||||||
|
|
||||||
|
|
||||||
|
def generate_ass_captions(
|
||||||
|
word_timings: list[dict[str, Any]],
|
||||||
|
clip_start: float,
|
||||||
|
style_config: dict[str, Any] | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Generate ASS subtitle content from word-level timings.
|
||||||
|
|
||||||
|
Each word is emitted as a separate Dialogue line with karaoke timing
|
||||||
|
(``{\\k<centiseconds>}``) so that words highlight one-by-one.
|
||||||
|
|
||||||
|
All word timestamps are offset by ``-clip_start`` to make them
|
||||||
|
clip-relative (i.e. the first frame of the clip is t=0).
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
word_timings : list[dict]
|
||||||
|
Word-timing dicts with ``word``, ``start``, ``end`` keys.
|
||||||
|
``start`` and ``end`` are absolute times in seconds.
|
||||||
|
clip_start : float
|
||||||
|
Absolute start time of the clip in seconds. Subtracted from
|
||||||
|
all word timestamps.
|
||||||
|
style_config : dict | None
|
||||||
|
Override style parameters (merged onto DEFAULT_STYLE).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
str — Full ASS file content. Empty dialogue section if no timings.
|
||||||
|
"""
|
||||||
|
header = _build_ass_header(style_config)
|
||||||
|
|
||||||
|
if not word_timings:
|
||||||
|
logger.debug("No word timings provided — returning header-only ASS")
|
||||||
|
return header
|
||||||
|
|
||||||
|
lines: list[str] = [header]
|
||||||
|
|
||||||
|
for w in word_timings:
|
||||||
|
word_text = w.get("word", "").strip()
|
||||||
|
if not word_text:
|
||||||
|
continue
|
||||||
|
|
||||||
|
abs_start = float(w.get("start", 0.0))
|
||||||
|
abs_end = float(w.get("end", abs_start))
|
||||||
|
|
||||||
|
# Make clip-relative
|
||||||
|
rel_start = max(0.0, abs_start - clip_start)
|
||||||
|
rel_end = max(rel_start, abs_end - clip_start)
|
||||||
|
|
||||||
|
# Karaoke duration in centiseconds
|
||||||
|
k_duration = max(1, round((rel_end - rel_start) * 100))
|
||||||
|
|
||||||
|
start_ts = _format_ass_time(rel_start)
|
||||||
|
end_ts = _format_ass_time(rel_end)
|
||||||
|
|
||||||
|
# Dialogue line with karaoke tag
|
||||||
|
line = (
|
||||||
|
f"Dialogue: 0,{start_ts},{end_ts},Default,,0,0,0,,"
|
||||||
|
f"{{\\k{k_duration}}}{word_text}"
|
||||||
|
)
|
||||||
|
lines.append(line)
|
||||||
|
|
||||||
|
return "\n".join(lines) + "\n"
|
||||||
|
|
||||||
|
|
||||||
|
def write_ass_file(ass_content: str, output_path: Path) -> Path:
|
||||||
|
"""Write ASS content to disk.
|
||||||
|
|
||||||
|
Creates parent directories if needed. Returns the output path.
|
||||||
|
"""
|
||||||
|
output_path = Path(output_path)
|
||||||
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
output_path.write_text(ass_content, encoding="utf-8")
|
||||||
|
logger.debug("Wrote ASS captions to %s (%d bytes)", output_path, len(ass_content))
|
||||||
|
return output_path
|
||||||
298
backend/pipeline/card_renderer.py
Normal file
298
backend/pipeline/card_renderer.py
Normal file
|
|
@ -0,0 +1,298 @@
|
||||||
|
"""FFmpeg-based intro/outro card video generation and segment concatenation.
|
||||||
|
|
||||||
|
Generates solid-color card clips with centered text using ffmpeg lavfi
|
||||||
|
(color + drawtext filters). Provides concat demuxer logic to assemble
|
||||||
|
intro + main clip + outro into a final short.
|
||||||
|
|
||||||
|
Pure functions — no DB access, no Celery dependency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
FFMPEG_TIMEOUT_SECS = 120
|
||||||
|
|
||||||
|
# Default template values
|
||||||
|
DEFAULT_ACCENT_COLOR = "#22d3ee"
|
||||||
|
DEFAULT_FONT_FAMILY = "Inter"
|
||||||
|
DEFAULT_INTRO_DURATION = 2.0
|
||||||
|
DEFAULT_OUTRO_DURATION = 2.0
|
||||||
|
|
||||||
|
|
||||||
|
def render_card(
|
||||||
|
text: str,
|
||||||
|
duration_secs: float,
|
||||||
|
width: int,
|
||||||
|
height: int,
|
||||||
|
accent_color: str = DEFAULT_ACCENT_COLOR,
|
||||||
|
font_family: str = DEFAULT_FONT_FAMILY,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Build ffmpeg command args that generate a card mp4 from lavfi input.
|
||||||
|
|
||||||
|
Produces a solid black background with centered white text and a thin
|
||||||
|
accent-color underline bar at the bottom third.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Display text (e.g., creator name or "Thanks for watching").
|
||||||
|
duration_secs: Card duration in seconds.
|
||||||
|
width: Output width in pixels.
|
||||||
|
height: Output height in pixels.
|
||||||
|
accent_color: Hex color for the underline glow bar.
|
||||||
|
font_family: Font family for drawtext (must be available on system).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of ffmpeg command arguments (without the output path — caller appends).
|
||||||
|
"""
|
||||||
|
if duration_secs <= 0:
|
||||||
|
raise ValueError(f"duration_secs must be positive, got {duration_secs}")
|
||||||
|
if width <= 0 or height <= 0:
|
||||||
|
raise ValueError(f"dimensions must be positive, got {width}x{height}")
|
||||||
|
|
||||||
|
# Font size scales with height — ~5% of output height
|
||||||
|
font_size = max(24, int(height * 0.05))
|
||||||
|
# Accent bar: thin horizontal line at ~65% down
|
||||||
|
bar_y = int(height * 0.65)
|
||||||
|
bar_height = max(2, int(height * 0.004))
|
||||||
|
bar_margin = int(width * 0.2)
|
||||||
|
|
||||||
|
# Escape text for ffmpeg drawtext (colons, backslashes, single quotes)
|
||||||
|
escaped_text = (
|
||||||
|
text.replace("\\", "\\\\")
|
||||||
|
.replace(":", "\\:")
|
||||||
|
.replace("'", "'\\''")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build complex filtergraph:
|
||||||
|
# 1. color source for black background
|
||||||
|
# 2. drawtext for centered title
|
||||||
|
# 3. drawbox for accent underline bar
|
||||||
|
filtergraph = (
|
||||||
|
f"color=c=black:s={width}x{height}:d={duration_secs}:r=30,"
|
||||||
|
f"drawtext=text='{escaped_text}'"
|
||||||
|
f":fontcolor=white:fontsize={font_size}"
|
||||||
|
f":fontfile='':font='{font_family}'"
|
||||||
|
f":x=(w-text_w)/2:y=(h-text_h)/2-{font_size},"
|
||||||
|
f"drawbox=x={bar_margin}:y={bar_y}"
|
||||||
|
f":w={width - 2 * bar_margin}:h={bar_height}"
|
||||||
|
f":color='{accent_color}'@0.8:t=fill"
|
||||||
|
)
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
"ffmpeg",
|
||||||
|
"-y",
|
||||||
|
"-f", "lavfi",
|
||||||
|
"-i", filtergraph,
|
||||||
|
"-c:v", "libx264",
|
||||||
|
"-preset", "fast",
|
||||||
|
"-crf", "23",
|
||||||
|
"-pix_fmt", "yuv420p",
|
||||||
|
"-t", str(duration_secs),
|
||||||
|
# Silent audio track so concat with audio segments works
|
||||||
|
"-f", "lavfi",
|
||||||
|
"-i", f"anullsrc=r=44100:cl=stereo:d={duration_secs}",
|
||||||
|
"-c:a", "aac",
|
||||||
|
"-b:a", "128k",
|
||||||
|
"-shortest",
|
||||||
|
"-movflags", "+faststart",
|
||||||
|
]
|
||||||
|
|
||||||
|
return cmd
|
||||||
|
|
||||||
|
|
||||||
|
def render_card_to_file(
|
||||||
|
text: str,
|
||||||
|
duration_secs: float,
|
||||||
|
width: int,
|
||||||
|
height: int,
|
||||||
|
output_path: Path,
|
||||||
|
accent_color: str = DEFAULT_ACCENT_COLOR,
|
||||||
|
font_family: str = DEFAULT_FONT_FAMILY,
|
||||||
|
) -> Path:
|
||||||
|
"""Generate a card mp4 file via ffmpeg.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Display text for the card.
|
||||||
|
duration_secs: Card duration in seconds.
|
||||||
|
width: Output width in pixels.
|
||||||
|
height: Output height in pixels.
|
||||||
|
output_path: Destination mp4 file.
|
||||||
|
accent_color: Hex color for accent elements.
|
||||||
|
font_family: Font family for text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The output_path on success.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
subprocess.CalledProcessError: If ffmpeg exits non-zero.
|
||||||
|
subprocess.TimeoutExpired: If ffmpeg exceeds timeout.
|
||||||
|
"""
|
||||||
|
cmd = render_card(
|
||||||
|
text=text,
|
||||||
|
duration_secs=duration_secs,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
accent_color=accent_color,
|
||||||
|
font_family=font_family,
|
||||||
|
)
|
||||||
|
cmd.append(str(output_path))
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Rendering card: text=%r duration=%.1fs size=%dx%d → %s",
|
||||||
|
text, duration_secs, width, height, output_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
capture_output=True,
|
||||||
|
timeout=FFMPEG_TIMEOUT_SECS,
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
stderr_text = result.stderr.decode("utf-8", errors="replace")[-2000:]
|
||||||
|
logger.error("Card render failed (rc=%d): %s", result.returncode, stderr_text)
|
||||||
|
raise subprocess.CalledProcessError(
|
||||||
|
result.returncode, cmd, output=result.stdout, stderr=result.stderr,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Card rendered: %s (%d bytes)", output_path, output_path.stat().st_size)
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
|
||||||
|
def build_concat_list(segments: list[Path], list_path: Path) -> Path:
|
||||||
|
"""Write an ffmpeg concat demuxer list file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
segments: Ordered list of segment mp4 paths.
|
||||||
|
list_path: Where to write the concat list.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The list_path.
|
||||||
|
"""
|
||||||
|
lines = [f"file '{seg.resolve()}'" for seg in segments]
|
||||||
|
list_path.write_text("\n".join(lines) + "\n", encoding="utf-8")
|
||||||
|
return list_path
|
||||||
|
|
||||||
|
|
||||||
|
def concat_segments(segments: list[Path], output_path: Path) -> Path:
|
||||||
|
"""Concatenate mp4 segments using ffmpeg concat demuxer.
|
||||||
|
|
||||||
|
All segments must share the same codec settings (libx264/aac, same
|
||||||
|
resolution). Uses ``-c copy`` for fast stream-copy concatenation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
segments: Ordered list of segment mp4 paths.
|
||||||
|
output_path: Destination mp4 file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The output_path on success.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If segments list is empty.
|
||||||
|
subprocess.CalledProcessError: If ffmpeg exits non-zero.
|
||||||
|
subprocess.TimeoutExpired: If ffmpeg exceeds timeout.
|
||||||
|
"""
|
||||||
|
if not segments:
|
||||||
|
raise ValueError("segments list cannot be empty")
|
||||||
|
|
||||||
|
# Write concat list to a temp file
|
||||||
|
with tempfile.NamedTemporaryFile(
|
||||||
|
mode="w", suffix=".txt", delete=False, prefix="concat_",
|
||||||
|
) as f:
|
||||||
|
for seg in segments:
|
||||||
|
f.write(f"file '{seg.resolve()}'\n")
|
||||||
|
list_path = Path(f.name)
|
||||||
|
|
||||||
|
try:
|
||||||
|
cmd = [
|
||||||
|
"ffmpeg",
|
||||||
|
"-y",
|
||||||
|
"-f", "concat",
|
||||||
|
"-safe", "0",
|
||||||
|
"-i", str(list_path),
|
||||||
|
"-c", "copy",
|
||||||
|
"-movflags", "+faststart",
|
||||||
|
str(output_path),
|
||||||
|
]
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Concatenating %d segments → %s",
|
||||||
|
len(segments), output_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
capture_output=True,
|
||||||
|
timeout=FFMPEG_TIMEOUT_SECS,
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
stderr_text = result.stderr.decode("utf-8", errors="replace")[-2000:]
|
||||||
|
logger.error(
|
||||||
|
"Concat failed (rc=%d): %s", result.returncode, stderr_text,
|
||||||
|
)
|
||||||
|
raise subprocess.CalledProcessError(
|
||||||
|
result.returncode, cmd, output=result.stdout, stderr=result.stderr,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Concatenated %d segments: %s (%d bytes)",
|
||||||
|
len(segments), output_path, output_path.stat().st_size,
|
||||||
|
)
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Clean up temp list file
|
||||||
|
try:
|
||||||
|
list_path.unlink()
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def parse_template_config(
|
||||||
|
shorts_template: dict | None,
|
||||||
|
) -> dict:
|
||||||
|
"""Parse a creator's shorts_template JSONB into normalized config.
|
||||||
|
|
||||||
|
Expected schema::
|
||||||
|
|
||||||
|
{
|
||||||
|
"show_intro": true,
|
||||||
|
"intro_text": "Creator Name Presents",
|
||||||
|
"intro_duration": 2.0,
|
||||||
|
"show_outro": true,
|
||||||
|
"outro_text": "Thanks for watching!",
|
||||||
|
"outro_duration": 2.0,
|
||||||
|
"accent_color": "#22d3ee",
|
||||||
|
"font_family": "Inter"
|
||||||
|
}
|
||||||
|
|
||||||
|
Missing fields get defaults. Returns a dict with all keys guaranteed.
|
||||||
|
"""
|
||||||
|
if not shorts_template:
|
||||||
|
return {
|
||||||
|
"show_intro": False,
|
||||||
|
"intro_text": "",
|
||||||
|
"intro_duration": DEFAULT_INTRO_DURATION,
|
||||||
|
"show_outro": False,
|
||||||
|
"outro_text": "",
|
||||||
|
"outro_duration": DEFAULT_OUTRO_DURATION,
|
||||||
|
"accent_color": DEFAULT_ACCENT_COLOR,
|
||||||
|
"font_family": DEFAULT_FONT_FAMILY,
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"show_intro": bool(shorts_template.get("show_intro", False)),
|
||||||
|
"intro_text": str(shorts_template.get("intro_text", "")),
|
||||||
|
"intro_duration": float(shorts_template.get("intro_duration", DEFAULT_INTRO_DURATION)),
|
||||||
|
"show_outro": bool(shorts_template.get("show_outro", False)),
|
||||||
|
"outro_text": str(shorts_template.get("outro_text", "")),
|
||||||
|
"outro_duration": float(shorts_template.get("outro_duration", DEFAULT_OUTRO_DURATION)),
|
||||||
|
"accent_color": str(shorts_template.get("accent_color", DEFAULT_ACCENT_COLOR)),
|
||||||
|
"font_family": str(shorts_template.get("font_family", DEFAULT_FONT_FAMILY)),
|
||||||
|
}
|
||||||
64
backend/pipeline/citation_utils.py
Normal file
64
backend/pipeline/citation_utils.py
Normal file
|
|
@ -0,0 +1,64 @@
|
||||||
|
"""Citation extraction and validation utilities for synthesized technique pages.
|
||||||
|
|
||||||
|
Used by stage 5 synthesis and the test harness to verify that [N] citation
|
||||||
|
markers in body sections reference valid source moments.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
from pipeline.schemas import BodySection
|
||||||
|
|
||||||
|
# Matches [N] or [N,M] or [N,M,P] style citation markers where N,M,P are integers.
|
||||||
|
_CITATION_RE = re.compile(r"\[(\d+(?:,\s*\d+)*)\]")
|
||||||
|
|
||||||
|
|
||||||
|
def extract_citations(text: str) -> list[int]:
|
||||||
|
"""Extract all citation indices from ``[N]`` and ``[N,M,...]`` markers in *text*.
|
||||||
|
|
||||||
|
Returns a sorted list of unique integer indices.
|
||||||
|
"""
|
||||||
|
indices: set[int] = set()
|
||||||
|
for match in _CITATION_RE.finditer(text):
|
||||||
|
for part in match.group(1).split(","):
|
||||||
|
indices.add(int(part.strip()))
|
||||||
|
return sorted(indices)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_citations(
|
||||||
|
sections: list[BodySection],
|
||||||
|
moment_count: int,
|
||||||
|
) -> dict:
|
||||||
|
"""Validate citation markers across all *sections* against *moment_count* source moments.
|
||||||
|
|
||||||
|
Moments are expected to be referenced as 0-based indices ``[0]`` through
|
||||||
|
``[moment_count - 1]``.
|
||||||
|
|
||||||
|
Returns a dict with:
|
||||||
|
valid (bool): True when every cited index is in range and every moment is cited.
|
||||||
|
total_citations (int): Count of unique cited indices.
|
||||||
|
invalid_indices (list[int]): Cited indices that are out of range.
|
||||||
|
uncited_moments (list[int]): In-range moment indices that are never cited.
|
||||||
|
coverage_pct (float): Percentage of moments that are cited (0.0–100.0).
|
||||||
|
"""
|
||||||
|
all_indices: set[int] = set()
|
||||||
|
|
||||||
|
for section in sections:
|
||||||
|
all_indices.update(extract_citations(section.content))
|
||||||
|
for sub in section.subsections:
|
||||||
|
all_indices.update(extract_citations(sub.content))
|
||||||
|
|
||||||
|
valid_range = set(range(moment_count))
|
||||||
|
invalid_indices = sorted(all_indices - valid_range)
|
||||||
|
cited_in_range = all_indices & valid_range
|
||||||
|
uncited_moments = sorted(valid_range - cited_in_range)
|
||||||
|
coverage_pct = (len(cited_in_range) / moment_count * 100.0) if moment_count > 0 else 0.0
|
||||||
|
|
||||||
|
return {
|
||||||
|
"valid": len(invalid_indices) == 0 and len(uncited_moments) == 0,
|
||||||
|
"total_citations": len(cited_in_range),
|
||||||
|
"invalid_indices": invalid_indices,
|
||||||
|
"uncited_moments": uncited_moments,
|
||||||
|
"coverage_pct": round(coverage_pct, 1),
|
||||||
|
}
|
||||||
88
backend/pipeline/embedding_client.py
Normal file
88
backend/pipeline/embedding_client.py
Normal file
|
|
@ -0,0 +1,88 @@
|
||||||
|
"""Synchronous embedding client using the OpenAI-compatible /v1/embeddings API.
|
||||||
|
|
||||||
|
Uses ``openai.OpenAI`` (sync) since Celery tasks run synchronously.
|
||||||
|
Handles connection failures gracefully — embedding is non-blocking for the pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import openai
|
||||||
|
|
||||||
|
from config import Settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingClient:
|
||||||
|
"""Sync embedding client backed by an OpenAI-compatible /v1/embeddings endpoint."""
|
||||||
|
|
||||||
|
def __init__(self, settings: Settings) -> None:
|
||||||
|
self.settings = settings
|
||||||
|
self._client = openai.OpenAI(
|
||||||
|
base_url=settings.embedding_api_url,
|
||||||
|
api_key=settings.llm_api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
def embed(self, texts: list[str]) -> list[list[float]]:
|
||||||
|
"""Generate embedding vectors for a batch of texts.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
texts:
|
||||||
|
List of strings to embed.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list[list[float]]
|
||||||
|
Embedding vectors. Returns empty list on connection/timeout errors
|
||||||
|
so the pipeline can continue without embeddings.
|
||||||
|
"""
|
||||||
|
if not texts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = self._client.embeddings.create(
|
||||||
|
model=self.settings.embedding_model,
|
||||||
|
input=texts,
|
||||||
|
)
|
||||||
|
except (openai.APIConnectionError, openai.APITimeoutError) as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Embedding API unavailable (%s: %s). Skipping %d texts.",
|
||||||
|
type(exc).__name__,
|
||||||
|
exc,
|
||||||
|
len(texts),
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
except openai.APIError as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Embedding API error (%s: %s). Skipping %d texts.",
|
||||||
|
type(exc).__name__,
|
||||||
|
exc,
|
||||||
|
len(texts),
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
vectors = [item.embedding for item in response.data]
|
||||||
|
|
||||||
|
# Validate dimensions
|
||||||
|
expected_dim = self.settings.embedding_dimensions
|
||||||
|
for i, vec in enumerate(vectors):
|
||||||
|
if len(vec) != expected_dim:
|
||||||
|
logger.warning(
|
||||||
|
"Embedding dimension mismatch at index %d: expected %d, got %d. "
|
||||||
|
"Returning empty list.",
|
||||||
|
i,
|
||||||
|
expected_dim,
|
||||||
|
len(vec),
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Generated %d embeddings (dim=%d) using model=%s",
|
||||||
|
len(vectors),
|
||||||
|
expected_dim,
|
||||||
|
self.settings.embedding_model,
|
||||||
|
)
|
||||||
|
return vectors
|
||||||
306
backend/pipeline/export_fixture.py
Normal file
306
backend/pipeline/export_fixture.py
Normal file
|
|
@ -0,0 +1,306 @@
|
||||||
|
"""Export pipeline stage inputs for a video as a reusable JSON fixture.
|
||||||
|
|
||||||
|
Connects to the live database, queries KeyMoments and classification data,
|
||||||
|
and writes a fixture file that the test harness can consume offline.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python -m pipeline.export_fixture --video-id <uuid> --output fixtures/video.json
|
||||||
|
python -m pipeline.export_fixture --video-id <uuid> # prints to stdout
|
||||||
|
python -m pipeline.export_fixture --list # list available videos
|
||||||
|
|
||||||
|
Requires: DATABASE_URL, REDIS_URL environment variables (or .env file).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from collections import Counter
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine, select
|
||||||
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
|
|
||||||
|
def _log(tag: str, msg: str, level: str = "INFO") -> None:
|
||||||
|
"""Write structured log line to stderr."""
|
||||||
|
print(f"[EXPORT] [{level}] {tag}: {msg}", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_sync_session(database_url: str) -> Session:
|
||||||
|
"""Create a sync SQLAlchemy session from the database URL."""
|
||||||
|
url = database_url.replace("postgresql+asyncpg://", "postgresql+psycopg2://")
|
||||||
|
engine = create_engine(url, pool_pre_ping=True)
|
||||||
|
factory = sessionmaker(bind=engine)
|
||||||
|
return factory()
|
||||||
|
|
||||||
|
|
||||||
|
def _list_videos(database_url: str) -> int:
|
||||||
|
"""List all videos with their processing status and moment counts."""
|
||||||
|
from models import Creator, KeyMoment, SourceVideo
|
||||||
|
|
||||||
|
session = _get_sync_session(database_url)
|
||||||
|
try:
|
||||||
|
videos = (
|
||||||
|
session.execute(
|
||||||
|
select(SourceVideo).order_by(SourceVideo.created_at.desc())
|
||||||
|
)
|
||||||
|
.scalars()
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
if not videos:
|
||||||
|
_log("LIST", "No videos found in database")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
print(f"\n{'ID':<38s} {'Status':<14s} {'Moments':>7s} {'Creator':<20s} {'Filename'}", file=sys.stderr)
|
||||||
|
print(f"{'─'*38} {'─'*14} {'─'*7} {'─'*20} {'─'*40}", file=sys.stderr)
|
||||||
|
|
||||||
|
for video in videos:
|
||||||
|
moment_count = (
|
||||||
|
session.execute(
|
||||||
|
select(KeyMoment.id).where(KeyMoment.source_video_id == video.id)
|
||||||
|
)
|
||||||
|
.scalars()
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
creator = session.execute(
|
||||||
|
select(Creator).where(Creator.id == video.creator_id)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
creator_name = creator.name if creator else "?"
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"{str(video.id):<38s} {video.processing_status.value:<14s} "
|
||||||
|
f"{len(moment_count):>7d} {creator_name:<20s} {video.filename}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\nTotal: {len(videos)} videos\n", file=sys.stderr)
|
||||||
|
return 0
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
|
||||||
|
def export_fixture(
|
||||||
|
database_url: str,
|
||||||
|
redis_url: str,
|
||||||
|
video_id: str,
|
||||||
|
output_path: str | None = None,
|
||||||
|
) -> int:
|
||||||
|
"""Export stage 5 inputs for a video as a JSON fixture.
|
||||||
|
|
||||||
|
Returns exit code: 0 = success, 1 = error.
|
||||||
|
"""
|
||||||
|
from models import Creator, KeyMoment, SourceVideo
|
||||||
|
|
||||||
|
start = time.monotonic()
|
||||||
|
_log("CONNECT", "Connecting to database...")
|
||||||
|
|
||||||
|
session = _get_sync_session(database_url)
|
||||||
|
try:
|
||||||
|
# ── Load video ──────────────────────────────────────────────────
|
||||||
|
video = session.execute(
|
||||||
|
select(SourceVideo).where(SourceVideo.id == video_id)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
|
||||||
|
if video is None:
|
||||||
|
_log("ERROR", f"Video not found: {video_id}", level="ERROR")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
creator = session.execute(
|
||||||
|
select(Creator).where(Creator.id == video.creator_id)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
creator_name = creator.name if creator else "Unknown"
|
||||||
|
|
||||||
|
_log(
|
||||||
|
"VIDEO",
|
||||||
|
f"Found: {video.filename} by {creator_name} "
|
||||||
|
f"({video.duration_seconds or '?'}s, {video.content_type.value}, "
|
||||||
|
f"status={video.processing_status.value})",
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Load key moments ────────────────────────────────────────────
|
||||||
|
moments = (
|
||||||
|
session.execute(
|
||||||
|
select(KeyMoment)
|
||||||
|
.where(KeyMoment.source_video_id == video_id)
|
||||||
|
.order_by(KeyMoment.start_time)
|
||||||
|
)
|
||||||
|
.scalars()
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not moments:
|
||||||
|
_log("ERROR", f"No key moments found for video_id={video_id}", level="ERROR")
|
||||||
|
_log("HINT", "Pipeline stages 2-3 must complete before export is possible", level="ERROR")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
time_min = min(m.start_time for m in moments)
|
||||||
|
time_max = max(m.end_time for m in moments)
|
||||||
|
_log("MOMENTS", f"Loaded {len(moments)} key moments (time range: {time_min:.1f}s - {time_max:.1f}s)")
|
||||||
|
|
||||||
|
# ── Load classification data ────────────────────────────────────
|
||||||
|
classification_data: list[dict] = []
|
||||||
|
cls_source = "missing"
|
||||||
|
|
||||||
|
# Try Redis first
|
||||||
|
try:
|
||||||
|
import redis as redis_lib
|
||||||
|
|
||||||
|
r = redis_lib.Redis.from_url(redis_url)
|
||||||
|
key = f"chrysopedia:classification:{video_id}"
|
||||||
|
raw = r.get(key)
|
||||||
|
if raw is not None:
|
||||||
|
classification_data = json.loads(raw)
|
||||||
|
cls_source = "redis"
|
||||||
|
ttl = r.ttl(key)
|
||||||
|
_log("CLASSIFY", f"Source: redis ({len(classification_data)} entries, TTL={ttl}s)")
|
||||||
|
except Exception as exc:
|
||||||
|
_log("CLASSIFY", f"Redis unavailable: {exc}", level="WARN")
|
||||||
|
|
||||||
|
# Fallback: check SourceVideo.classification_data column (Phase 2 addition)
|
||||||
|
if not classification_data:
|
||||||
|
video_cls = getattr(video, "classification_data", None)
|
||||||
|
if video_cls:
|
||||||
|
classification_data = video_cls
|
||||||
|
cls_source = "postgresql"
|
||||||
|
_log("CLASSIFY", f"Source: postgresql ({len(classification_data)} entries)")
|
||||||
|
|
||||||
|
if not classification_data:
|
||||||
|
_log("CLASSIFY", "No classification data found in Redis or PostgreSQL", level="WARN")
|
||||||
|
_log("HINT", "Pipeline stage 4 must complete before classification data is available", level="WARN")
|
||||||
|
cls_source = "missing"
|
||||||
|
|
||||||
|
# Build classification lookup by moment_id
|
||||||
|
cls_by_moment_id = {c["moment_id"]: c for c in classification_data}
|
||||||
|
|
||||||
|
# Count moments without classification
|
||||||
|
unclassified = sum(1 for m in moments if str(m.id) not in cls_by_moment_id)
|
||||||
|
if unclassified > 0:
|
||||||
|
_log("CLASSIFY", f"WARNING: {unclassified}/{len(moments)} moments have no classification data", level="WARN")
|
||||||
|
|
||||||
|
# ── Build fixture ───────────────────────────────────────────────
|
||||||
|
fixture_moments = []
|
||||||
|
category_counts: Counter[str] = Counter()
|
||||||
|
|
||||||
|
for m in moments:
|
||||||
|
cls_info = cls_by_moment_id.get(str(m.id), {})
|
||||||
|
topic_category = cls_info.get("topic_category", "Uncategorized")
|
||||||
|
topic_tags = cls_info.get("topic_tags", [])
|
||||||
|
category_counts[topic_category] += 1
|
||||||
|
|
||||||
|
fixture_moments.append({
|
||||||
|
"moment_id": str(m.id),
|
||||||
|
"title": m.title,
|
||||||
|
"summary": m.summary,
|
||||||
|
"content_type": m.content_type.value,
|
||||||
|
"start_time": m.start_time,
|
||||||
|
"end_time": m.end_time,
|
||||||
|
"plugins": m.plugins or [],
|
||||||
|
"raw_transcript": m.raw_transcript or "",
|
||||||
|
# Classification data (stage 4 output)
|
||||||
|
"classification": {
|
||||||
|
"topic_category": topic_category,
|
||||||
|
"topic_tags": topic_tags,
|
||||||
|
},
|
||||||
|
# Compatibility fields for existing quality/scorer format
|
||||||
|
"transcript_excerpt": (m.raw_transcript or "")[:500],
|
||||||
|
"topic_tags": topic_tags,
|
||||||
|
"topic_category": topic_category,
|
||||||
|
})
|
||||||
|
|
||||||
|
fixture = {
|
||||||
|
"video_id": str(video.id),
|
||||||
|
"creator_name": creator_name,
|
||||||
|
"content_type": video.content_type.value,
|
||||||
|
"filename": video.filename,
|
||||||
|
"duration_seconds": video.duration_seconds,
|
||||||
|
"classification_source": cls_source,
|
||||||
|
"export_timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
|
||||||
|
"moments": fixture_moments,
|
||||||
|
}
|
||||||
|
|
||||||
|
fixture_json = json.dumps(fixture, indent=2, ensure_ascii=False)
|
||||||
|
fixture_size_kb = len(fixture_json.encode("utf-8")) / 1024
|
||||||
|
|
||||||
|
# ── Output ──────────────────────────────────────────────────────
|
||||||
|
if output_path:
|
||||||
|
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
Path(output_path).write_text(fixture_json, encoding="utf-8")
|
||||||
|
_log(
|
||||||
|
"OUTPUT",
|
||||||
|
f"Wrote fixture: {output_path} ({fixture_size_kb:.1f} KB, "
|
||||||
|
f"{len(fixture_moments)} moments, {len(category_counts)} categories)",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Print fixture JSON to stdout
|
||||||
|
print(fixture_json)
|
||||||
|
_log(
|
||||||
|
"OUTPUT",
|
||||||
|
f"Printed fixture to stdout ({fixture_size_kb:.1f} KB, "
|
||||||
|
f"{len(fixture_moments)} moments, {len(category_counts)} categories)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Category breakdown
|
||||||
|
for cat, count in category_counts.most_common():
|
||||||
|
_log("CATEGORY", f" {cat}: {count} moments")
|
||||||
|
|
||||||
|
elapsed = time.monotonic() - start
|
||||||
|
_log("DONE", f"Export completed in {elapsed:.1f}s")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
_log("ERROR", f"Export failed: {exc}", level="ERROR")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc(file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
prog="pipeline.export_fixture",
|
||||||
|
description="Export pipeline stage inputs for a video as a reusable JSON fixture",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--video-id",
|
||||||
|
type=str,
|
||||||
|
help="UUID of the video to export",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output", "-o",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Output file path (default: print to stdout)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--list",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="List all videos with status and moment counts",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Load settings
|
||||||
|
from config import get_settings
|
||||||
|
settings = get_settings()
|
||||||
|
|
||||||
|
if args.list:
|
||||||
|
return _list_videos(settings.database_url)
|
||||||
|
|
||||||
|
if not args.video_id:
|
||||||
|
parser.error("--video-id is required (or use --list to see available videos)")
|
||||||
|
|
||||||
|
return export_fixture(
|
||||||
|
database_url=settings.database_url,
|
||||||
|
redis_url=settings.redis_url,
|
||||||
|
video_id=args.video_id,
|
||||||
|
output_path=args.output,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
63
backend/pipeline/highlight_schemas.py
Normal file
63
backend/pipeline/highlight_schemas.py
Normal file
|
|
@ -0,0 +1,63 @@
|
||||||
|
"""Pydantic schemas for highlight detection pipeline.
|
||||||
|
|
||||||
|
Covers scoring breakdown, candidate responses, and batch result summaries.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class HighlightScoreBreakdown(BaseModel):
|
||||||
|
"""Per-dimension score breakdown for a highlight candidate.
|
||||||
|
|
||||||
|
Each field is a float in [0, 1] representing the normalized score
|
||||||
|
for that scoring dimension.
|
||||||
|
"""
|
||||||
|
|
||||||
|
duration_score: float = Field(description="Score based on moment duration (sweet-spot curve)")
|
||||||
|
content_density_score: float = Field(description="Score based on transcript richness / word density")
|
||||||
|
technique_relevance_score: float = Field(description="Score based on content_type and plugin mentions")
|
||||||
|
position_score: float = Field(description="Score based on temporal position within the video")
|
||||||
|
uniqueness_score: float = Field(description="Score based on title/topic distinctness among siblings")
|
||||||
|
engagement_proxy_score: float = Field(description="Proxy engagement signal from summary quality/length")
|
||||||
|
plugin_diversity_score: float = Field(description="Score based on breadth of plugins/tools mentioned")
|
||||||
|
speech_rate_variance_score: float = Field(
|
||||||
|
default=0.5,
|
||||||
|
description="Score based on speech rate variation (emphasis shifts) from word timings",
|
||||||
|
)
|
||||||
|
pause_density_score: float = Field(
|
||||||
|
default=0.5,
|
||||||
|
description="Score based on strategic pause frequency from word timings",
|
||||||
|
)
|
||||||
|
speaking_pace_score: float = Field(
|
||||||
|
default=0.5,
|
||||||
|
description="Score based on words-per-second fitness for teaching pace",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class HighlightCandidateResponse(BaseModel):
|
||||||
|
"""API response schema for a single highlight candidate."""
|
||||||
|
|
||||||
|
id: uuid.UUID
|
||||||
|
key_moment_id: uuid.UUID
|
||||||
|
source_video_id: uuid.UUID
|
||||||
|
score: float = Field(ge=0.0, le=1.0, description="Composite highlight score")
|
||||||
|
score_breakdown: HighlightScoreBreakdown
|
||||||
|
duration_secs: float = Field(ge=0.0, description="Duration of the key moment in seconds")
|
||||||
|
status: str = Field(description="One of: candidate, approved, rejected")
|
||||||
|
created_at: datetime
|
||||||
|
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class HighlightBatchResult(BaseModel):
|
||||||
|
"""Summary of a highlight scoring batch run for one video."""
|
||||||
|
|
||||||
|
video_id: uuid.UUID
|
||||||
|
candidates_created: int = Field(ge=0, description="Number of new candidates inserted")
|
||||||
|
candidates_updated: int = Field(ge=0, description="Number of existing candidates re-scored")
|
||||||
|
top_score: float = Field(ge=0.0, le=1.0, description="Highest score in this batch")
|
||||||
413
backend/pipeline/highlight_scorer.py
Normal file
413
backend/pipeline/highlight_scorer.py
Normal file
|
|
@ -0,0 +1,413 @@
|
||||||
|
"""Heuristic scoring engine for highlight candidate detection.
|
||||||
|
|
||||||
|
Takes KeyMoment data + context (source quality, video content type) and
|
||||||
|
returns a composite score in [0, 1] with a 10-dimension breakdown.
|
||||||
|
|
||||||
|
The breakdown fields align with HighlightScoreBreakdown in highlight_schemas.py:
|
||||||
|
duration_score, content_density_score, technique_relevance_score,
|
||||||
|
position_score, uniqueness_score, engagement_proxy_score, plugin_diversity_score,
|
||||||
|
speech_rate_variance_score, pause_density_score, speaking_pace_score
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
import re
|
||||||
|
import statistics
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
# ── Weights per dimension (must sum to 1.0) ──────────────────────────────────
|
||||||
|
|
||||||
|
_WEIGHTS: dict[str, float] = {
|
||||||
|
"duration_score": 0.20,
|
||||||
|
"content_density_score": 0.15,
|
||||||
|
"technique_relevance_score": 0.15,
|
||||||
|
"plugin_diversity_score": 0.08,
|
||||||
|
"engagement_proxy_score": 0.08,
|
||||||
|
"position_score": 0.08, # mapped from source_quality
|
||||||
|
"uniqueness_score": 0.04, # mapped from video_type
|
||||||
|
"speech_rate_variance_score": 0.08,
|
||||||
|
"pause_density_score": 0.07,
|
||||||
|
"speaking_pace_score": 0.07,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert abs(sum(_WEIGHTS.values()) - 1.0) < 1e-9, "Weights must sum to 1.0"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Individual scoring functions ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _duration_fitness(duration_secs: float) -> float:
|
||||||
|
"""Bell-curve around 30-60s sweet spot.
|
||||||
|
|
||||||
|
Peak at 30-60s (score 1.0), penalty below 15s and above 120s,
|
||||||
|
zero above 300s.
|
||||||
|
"""
|
||||||
|
if duration_secs <= 0:
|
||||||
|
return 0.0
|
||||||
|
if duration_secs >= 300:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
# Sweet spot: 30-60s → 1.0
|
||||||
|
if 30 <= duration_secs <= 60:
|
||||||
|
return 1.0
|
||||||
|
|
||||||
|
# Below sweet spot: linear ramp from 0 at 0s to 1.0 at 30s
|
||||||
|
# with steeper penalty below 15s
|
||||||
|
if duration_secs < 30:
|
||||||
|
if duration_secs < 15:
|
||||||
|
return duration_secs / 30.0 # 0→0.5 over 0-15s
|
||||||
|
return 0.5 + (duration_secs - 15) / 30.0 # 0.5→1.0 over 15-30s
|
||||||
|
|
||||||
|
# Above sweet spot: gradual decay from 1.0 at 60s to 0.0 at 300s
|
||||||
|
return max(0.0, 1.0 - (duration_secs - 60) / 240.0)
|
||||||
|
|
||||||
|
|
||||||
|
def _content_type_weight(content_type: str | None) -> float:
|
||||||
|
"""Score based on KeyMoment content_type.
|
||||||
|
|
||||||
|
technique=1.0, settings=0.8, workflow=0.6, reasoning=0.4
|
||||||
|
"""
|
||||||
|
mapping = {
|
||||||
|
"technique": 1.0,
|
||||||
|
"settings": 0.8,
|
||||||
|
"workflow": 0.6,
|
||||||
|
"reasoning": 0.4,
|
||||||
|
}
|
||||||
|
return mapping.get(content_type or "", 0.5)
|
||||||
|
|
||||||
|
|
||||||
|
def _specificity_density(summary: str | None) -> float:
|
||||||
|
"""Score based on specificity signals in the summary.
|
||||||
|
|
||||||
|
Counts specific values (numbers, plugin names, dB, Hz, ms, %, ratios)
|
||||||
|
normalized by summary length.
|
||||||
|
"""
|
||||||
|
if not summary:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
words = summary.split()
|
||||||
|
word_count = len(words)
|
||||||
|
if word_count == 0:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
# Patterns that indicate specificity
|
||||||
|
specificity_patterns = [
|
||||||
|
r"\b\d+\.?\d*\s*(?:dB|Hz|kHz|ms|sec|bpm|%)\b", # units
|
||||||
|
r"\b\d+\.?\d*\s*/\s*\d+\.?\d*\b", # ratios like 3/4
|
||||||
|
r"\b\d{2,}\b", # multi-digit numbers
|
||||||
|
r"\b\d+\.\d+\b", # decimal numbers
|
||||||
|
]
|
||||||
|
|
||||||
|
hits = 0
|
||||||
|
for pattern in specificity_patterns:
|
||||||
|
hits += len(re.findall(pattern, summary, re.IGNORECASE))
|
||||||
|
|
||||||
|
# Normalize: ~1 specific value per 10 words is high density
|
||||||
|
density = hits / (word_count / 10.0)
|
||||||
|
return min(density, 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
def _plugin_richness(plugins: list[str] | None) -> float:
|
||||||
|
"""Score based on number of plugins mentioned.
|
||||||
|
|
||||||
|
min(len(plugins) / 3, 1.0)
|
||||||
|
"""
|
||||||
|
if not plugins:
|
||||||
|
return 0.0
|
||||||
|
return min(len(plugins) / 3.0, 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
def _transcript_energy(raw_transcript: str | None) -> float:
|
||||||
|
"""Score based on teaching/engagement phrases in transcript.
|
||||||
|
|
||||||
|
Counts teaching phrases ('the trick is', 'notice how', 'because',
|
||||||
|
'I always', 'the key is', 'what I do') normalized by transcript
|
||||||
|
word count.
|
||||||
|
"""
|
||||||
|
if not raw_transcript:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
words = raw_transcript.split()
|
||||||
|
word_count = len(words)
|
||||||
|
if word_count == 0:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
teaching_phrases = [
|
||||||
|
"the trick is",
|
||||||
|
"notice how",
|
||||||
|
"because",
|
||||||
|
"i always",
|
||||||
|
"the key is",
|
||||||
|
"what i do",
|
||||||
|
"important thing",
|
||||||
|
"you want to",
|
||||||
|
"make sure",
|
||||||
|
"here's why",
|
||||||
|
]
|
||||||
|
|
||||||
|
text_lower = raw_transcript.lower()
|
||||||
|
hits = sum(text_lower.count(phrase) for phrase in teaching_phrases)
|
||||||
|
|
||||||
|
# Normalize: ~1 phrase per 50 words is high energy
|
||||||
|
energy = hits / (word_count / 50.0)
|
||||||
|
return min(energy, 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
def _source_quality_weight(source_quality: str | None) -> float:
|
||||||
|
"""Score based on TechniquePage source_quality.
|
||||||
|
|
||||||
|
structured=1.0, mixed=0.7, unstructured=0.4, None=0.5
|
||||||
|
"""
|
||||||
|
mapping = {
|
||||||
|
"structured": 1.0,
|
||||||
|
"mixed": 0.7,
|
||||||
|
"unstructured": 0.4,
|
||||||
|
}
|
||||||
|
return mapping.get(source_quality or "", 0.5)
|
||||||
|
|
||||||
|
|
||||||
|
def _video_type_weight(video_content_type: str | None) -> float:
|
||||||
|
"""Score based on SourceVideo content_type.
|
||||||
|
|
||||||
|
tutorial=1.0, breakdown=0.9, livestream=0.5, short_form=0.3
|
||||||
|
"""
|
||||||
|
mapping = {
|
||||||
|
"tutorial": 1.0,
|
||||||
|
"breakdown": 0.9,
|
||||||
|
"livestream": 0.5,
|
||||||
|
"short_form": 0.3,
|
||||||
|
}
|
||||||
|
return mapping.get(video_content_type or "", 0.5)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Audio proxy scoring functions ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
def extract_word_timings(
|
||||||
|
transcript_data: list[dict[str, Any]],
|
||||||
|
start_time: float,
|
||||||
|
end_time: float,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Extract word-level timing dicts from transcript segments within a time window.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
transcript_data : list[dict]
|
||||||
|
Parsed transcript JSON — list of segments, each with a ``words`` array.
|
||||||
|
Each word dict must have ``start`` and ``end`` float fields (seconds).
|
||||||
|
start_time : float
|
||||||
|
Window start in seconds (inclusive).
|
||||||
|
end_time : float
|
||||||
|
Window end in seconds (inclusive).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list[dict] — word-timing dicts whose ``start`` falls within [start_time, end_time].
|
||||||
|
"""
|
||||||
|
if not transcript_data:
|
||||||
|
return []
|
||||||
|
|
||||||
|
words: list[dict[str, Any]] = []
|
||||||
|
for segment in transcript_data:
|
||||||
|
seg_words = segment.get("words")
|
||||||
|
if not seg_words:
|
||||||
|
continue
|
||||||
|
for w in seg_words:
|
||||||
|
w_start = w.get("start")
|
||||||
|
if w_start is None:
|
||||||
|
continue
|
||||||
|
if start_time <= w_start <= end_time:
|
||||||
|
words.append(w)
|
||||||
|
return words
|
||||||
|
|
||||||
|
|
||||||
|
def _speech_rate_variance(word_timings: list[dict[str, Any]] | None) -> float:
|
||||||
|
"""Compute normalized stdev of words-per-second in sliding windows.
|
||||||
|
|
||||||
|
High variance indicates emphasis shifts (speeding up / slowing down),
|
||||||
|
which correlates with engaging highlights.
|
||||||
|
|
||||||
|
Uses 5-second sliding windows with 2.5-second step.
|
||||||
|
Returns 0.5 (neutral) when word_timings is None or insufficient data.
|
||||||
|
"""
|
||||||
|
if not word_timings or len(word_timings) < 4:
|
||||||
|
return 0.5
|
||||||
|
|
||||||
|
# Determine time span
|
||||||
|
first_start = word_timings[0].get("start", 0.0)
|
||||||
|
last_start = word_timings[-1].get("start", 0.0)
|
||||||
|
span = last_start - first_start
|
||||||
|
if span < 5.0:
|
||||||
|
return 0.5
|
||||||
|
|
||||||
|
# Compute WPS in 5s sliding windows with 2.5s step
|
||||||
|
window_size = 5.0
|
||||||
|
step = 2.5
|
||||||
|
wps_values: list[float] = []
|
||||||
|
|
||||||
|
t = first_start
|
||||||
|
while t + window_size <= last_start + 0.01:
|
||||||
|
count = sum(
|
||||||
|
1 for w in word_timings
|
||||||
|
if t <= w.get("start", 0.0) < t + window_size
|
||||||
|
)
|
||||||
|
wps_values.append(count / window_size)
|
||||||
|
t += step
|
||||||
|
|
||||||
|
if len(wps_values) < 2:
|
||||||
|
return 0.5
|
||||||
|
|
||||||
|
mean_wps = statistics.mean(wps_values)
|
||||||
|
if mean_wps < 0.01:
|
||||||
|
return 0.5
|
||||||
|
|
||||||
|
stdev = statistics.stdev(wps_values)
|
||||||
|
# Normalize: coefficient of variation, capped at 1.0
|
||||||
|
# CV of ~0.3-0.5 is typical for varied speech; >0.5 is high variance
|
||||||
|
cv = stdev / mean_wps
|
||||||
|
return min(cv / 0.6, 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
def _pause_density(word_timings: list[dict[str, Any]] | None) -> float:
|
||||||
|
"""Count strategic pauses normalized by duration.
|
||||||
|
|
||||||
|
Inter-word gaps >0.5s and inter-segment gaps >1.0s indicate deliberate
|
||||||
|
pauses for emphasis, which correlate with better highlights.
|
||||||
|
|
||||||
|
Returns 0.5 (neutral) when word_timings is None or insufficient data.
|
||||||
|
"""
|
||||||
|
if not word_timings or len(word_timings) < 2:
|
||||||
|
return 0.5
|
||||||
|
|
||||||
|
first_start = word_timings[0].get("start", 0.0)
|
||||||
|
last_end = word_timings[-1].get("end", word_timings[-1].get("start", 0.0))
|
||||||
|
duration = last_end - first_start
|
||||||
|
if duration < 1.0:
|
||||||
|
return 0.5
|
||||||
|
|
||||||
|
short_pauses = 0 # >0.5s gaps
|
||||||
|
long_pauses = 0 # >1.0s gaps
|
||||||
|
|
||||||
|
for i in range(1, len(word_timings)):
|
||||||
|
prev_end = word_timings[i - 1].get("end", word_timings[i - 1].get("start", 0.0))
|
||||||
|
curr_start = word_timings[i].get("start", 0.0)
|
||||||
|
gap = curr_start - prev_end
|
||||||
|
|
||||||
|
if gap > 1.0:
|
||||||
|
long_pauses += 1
|
||||||
|
elif gap > 0.5:
|
||||||
|
short_pauses += 1
|
||||||
|
|
||||||
|
# Weight long pauses more heavily
|
||||||
|
weighted_pauses = short_pauses + long_pauses * 2.0
|
||||||
|
# Normalize: ~2-4 weighted pauses per 30s is good density
|
||||||
|
density = weighted_pauses / (duration / 15.0)
|
||||||
|
return min(density, 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
def _speaking_pace_fitness(word_timings: list[dict[str, Any]] | None) -> float:
|
||||||
|
"""Bell-curve score around 3-5 words-per-second optimal teaching pace.
|
||||||
|
|
||||||
|
3-5 WPS is the sweet spot for tutorial content — fast enough to be
|
||||||
|
engaging, slow enough for comprehension. Returns 0.5 (neutral) when
|
||||||
|
word_timings is None or insufficient data.
|
||||||
|
"""
|
||||||
|
if not word_timings or len(word_timings) < 2:
|
||||||
|
return 0.5
|
||||||
|
|
||||||
|
first_start = word_timings[0].get("start", 0.0)
|
||||||
|
last_end = word_timings[-1].get("end", word_timings[-1].get("start", 0.0))
|
||||||
|
duration = last_end - first_start
|
||||||
|
if duration < 1.0:
|
||||||
|
return 0.5
|
||||||
|
|
||||||
|
wps = len(word_timings) / duration
|
||||||
|
|
||||||
|
# Sweet spot: 3-5 WPS → 1.0
|
||||||
|
if 3.0 <= wps <= 5.0:
|
||||||
|
return 1.0
|
||||||
|
|
||||||
|
# Below sweet spot: linear ramp from 0 at 0 WPS to 1.0 at 3 WPS
|
||||||
|
if wps < 3.0:
|
||||||
|
return max(0.0, wps / 3.0)
|
||||||
|
|
||||||
|
# Above sweet spot: decay from 1.0 at 5 WPS to 0.0 at 10 WPS
|
||||||
|
if wps > 5.0:
|
||||||
|
return max(0.0, 1.0 - (wps - 5.0) / 5.0)
|
||||||
|
|
||||||
|
return 0.5 # unreachable, but defensive
|
||||||
|
|
||||||
|
|
||||||
|
# ── Main scoring function ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def score_moment(
|
||||||
|
*,
|
||||||
|
start_time: float,
|
||||||
|
end_time: float,
|
||||||
|
content_type: str | None = None,
|
||||||
|
summary: str | None = None,
|
||||||
|
plugins: list[str] | None = None,
|
||||||
|
raw_transcript: str | None = None,
|
||||||
|
source_quality: str | None = None,
|
||||||
|
video_content_type: str | None = None,
|
||||||
|
word_timings: list[dict[str, Any]] | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Score a KeyMoment for highlight potential.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
start_time : float
|
||||||
|
Moment start in seconds.
|
||||||
|
end_time : float
|
||||||
|
Moment end in seconds.
|
||||||
|
content_type : str | None
|
||||||
|
KeyMoment content type (technique, settings, workflow, reasoning).
|
||||||
|
summary : str | None
|
||||||
|
KeyMoment summary text.
|
||||||
|
plugins : list[str] | None
|
||||||
|
Plugins mentioned in the moment.
|
||||||
|
raw_transcript : str | None
|
||||||
|
Raw transcript text of the moment.
|
||||||
|
source_quality : str | None
|
||||||
|
TechniquePage source quality (structured, mixed, unstructured).
|
||||||
|
video_content_type : str | None
|
||||||
|
SourceVideo content type (tutorial, breakdown, livestream, short_form).
|
||||||
|
word_timings : list[dict] | None
|
||||||
|
Word-level timing dicts with ``start`` and ``end`` keys (seconds).
|
||||||
|
When None, audio proxy dimensions score 0.5 (neutral).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
dict with keys:
|
||||||
|
score : float in [0.0, 1.0]
|
||||||
|
score_breakdown : dict mapping dimension names to float scores
|
||||||
|
duration_secs : float
|
||||||
|
"""
|
||||||
|
duration_secs = max(0.0, end_time - start_time)
|
||||||
|
|
||||||
|
breakdown = {
|
||||||
|
"duration_score": _duration_fitness(duration_secs),
|
||||||
|
"content_density_score": _specificity_density(summary),
|
||||||
|
"technique_relevance_score": _content_type_weight(content_type),
|
||||||
|
"plugin_diversity_score": _plugin_richness(plugins),
|
||||||
|
"engagement_proxy_score": _transcript_energy(raw_transcript),
|
||||||
|
"position_score": _source_quality_weight(source_quality),
|
||||||
|
"uniqueness_score": _video_type_weight(video_content_type),
|
||||||
|
"speech_rate_variance_score": _speech_rate_variance(word_timings),
|
||||||
|
"pause_density_score": _pause_density(word_timings),
|
||||||
|
"speaking_pace_score": _speaking_pace_fitness(word_timings),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Weighted composite
|
||||||
|
composite = sum(
|
||||||
|
breakdown[dim] * weight for dim, weight in _WEIGHTS.items()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clamp to [0, 1] for safety
|
||||||
|
composite = max(0.0, min(1.0, composite))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"score": composite,
|
||||||
|
"score_breakdown": breakdown,
|
||||||
|
"duration_secs": duration_secs,
|
||||||
|
}
|
||||||
357
backend/pipeline/llm_client.py
Normal file
357
backend/pipeline/llm_client.py
Normal file
|
|
@ -0,0 +1,357 @@
|
||||||
|
"""Synchronous LLM client with primary/fallback endpoint logic.
|
||||||
|
|
||||||
|
Uses the OpenAI-compatible API (works with Ollama, vLLM, OpenWebUI, etc.).
|
||||||
|
Celery tasks run synchronously, so this uses ``openai.OpenAI`` (not Async).
|
||||||
|
|
||||||
|
Supports two modalities:
|
||||||
|
- **chat**: Standard JSON mode with ``response_format: {"type": "json_object"}``
|
||||||
|
- **thinking**: For reasoning models that emit ``<think>...</think>`` blocks
|
||||||
|
before their answer. Skips ``response_format``, appends JSON instructions to
|
||||||
|
the system prompt, and strips think tags from the response.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import TYPE_CHECKING, TypeVar
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
import openai
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from config import Settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
|
# ── LLM Response wrapper ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class LLMResponse(str):
|
||||||
|
"""String subclass that carries LLM response metadata.
|
||||||
|
|
||||||
|
Backward-compatible with all code that treats the response as a plain
|
||||||
|
string, but callers that know about it can inspect finish_reason and
|
||||||
|
the truncated property.
|
||||||
|
"""
|
||||||
|
finish_reason: str | None
|
||||||
|
prompt_tokens: int | None
|
||||||
|
completion_tokens: int | None
|
||||||
|
|
||||||
|
def __new__(
|
||||||
|
cls,
|
||||||
|
text: str,
|
||||||
|
finish_reason: str | None = None,
|
||||||
|
prompt_tokens: int | None = None,
|
||||||
|
completion_tokens: int | None = None,
|
||||||
|
):
|
||||||
|
obj = super().__new__(cls, text)
|
||||||
|
obj.finish_reason = finish_reason
|
||||||
|
obj.prompt_tokens = prompt_tokens
|
||||||
|
obj.completion_tokens = completion_tokens
|
||||||
|
return obj
|
||||||
|
|
||||||
|
@property
|
||||||
|
def truncated(self) -> bool:
|
||||||
|
"""True if the model hit its token limit before finishing."""
|
||||||
|
return self.finish_reason == "length"
|
||||||
|
|
||||||
|
# ── Think-tag stripping ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_THINK_PATTERN = re.compile(r"<think>.*?</think>", re.DOTALL)
|
||||||
|
|
||||||
|
|
||||||
|
def strip_think_tags(text: str) -> str:
|
||||||
|
"""Remove ``<think>...</think>`` blocks from LLM output.
|
||||||
|
|
||||||
|
Thinking/reasoning models often prefix their JSON with a reasoning trace
|
||||||
|
wrapped in ``<think>`` tags. This strips all such blocks (including
|
||||||
|
multiline and multiple occurrences) and returns the cleaned text.
|
||||||
|
|
||||||
|
Handles:
|
||||||
|
- Single ``<think>...</think>`` block
|
||||||
|
- Multiple blocks in one response
|
||||||
|
- Multiline content inside think tags
|
||||||
|
- Responses with no think tags (passthrough)
|
||||||
|
- Empty input (passthrough)
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return text
|
||||||
|
cleaned = _THINK_PATTERN.sub("", text)
|
||||||
|
return cleaned.strip()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# ── Token estimation ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Stage-specific output multipliers: estimated output tokens as a ratio of input tokens.
|
||||||
|
# Tuned from actual pipeline data (KCL Ep 31 audit, April 2026):
|
||||||
|
# stage2: actual compl/prompt = 680/39312 = 0.017 → use 0.05 with buffer
|
||||||
|
# stage3: actual compl/prompt ≈ 1000/7000 = 0.14 → use 0.3 with buffer
|
||||||
|
# stage4: actual compl/prompt = 740/3736 = 0.20 → use 0.3 with buffer
|
||||||
|
# stage5: actual compl/prompt ≈ 2500/7000 = 0.36 → use 0.8 with buffer
|
||||||
|
_STAGE_OUTPUT_RATIOS: dict[str, float] = {
|
||||||
|
"stage2_segmentation": 0.05, # Compact topic groups — much smaller than input
|
||||||
|
"stage3_extraction": 0.3, # Key moments with summaries — moderate
|
||||||
|
"stage4_classification": 0.3, # Tags + categories per moment
|
||||||
|
"stage5_synthesis": 0.8, # Full prose technique pages — heaviest output
|
||||||
|
}
|
||||||
|
|
||||||
|
# Minimum floor so we never send a trivially small max_tokens
|
||||||
|
_MIN_MAX_TOKENS = 4096
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_tokens(text: str) -> int:
|
||||||
|
"""Estimate token count from text using a chars-per-token heuristic.
|
||||||
|
|
||||||
|
Uses 3.5 chars/token which is conservative for English + JSON markup.
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return 0
|
||||||
|
return max(1, int(len(text) / 3.5))
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_max_tokens(
|
||||||
|
system_prompt: str,
|
||||||
|
user_prompt: str,
|
||||||
|
stage: str | None = None,
|
||||||
|
hard_limit: int = 32768,
|
||||||
|
) -> int:
|
||||||
|
"""Return the hard_limit as max_tokens for all stages.
|
||||||
|
|
||||||
|
Previously used dynamic estimation based on input size and stage-specific
|
||||||
|
multipliers, but thinking models consume unpredictable token budgets for
|
||||||
|
internal reasoning. A static ceiling avoids truncation errors.
|
||||||
|
|
||||||
|
The hard_limit value comes from Settings.llm_max_tokens_hard_limit (96000).
|
||||||
|
"""
|
||||||
|
input_tokens = estimate_tokens(system_prompt) + estimate_tokens(user_prompt)
|
||||||
|
logger.info(
|
||||||
|
"Token estimate: input≈%d, stage=%s, max_tokens=%d (static hard_limit)",
|
||||||
|
input_tokens, stage or "default", hard_limit,
|
||||||
|
)
|
||||||
|
return hard_limit
|
||||||
|
|
||||||
|
|
||||||
|
class LLMClient:
|
||||||
|
"""Sync LLM client that tries a primary endpoint and falls back on failure."""
|
||||||
|
|
||||||
|
def __init__(self, settings: Settings) -> None:
|
||||||
|
self.settings = settings
|
||||||
|
self._primary = openai.OpenAI(
|
||||||
|
base_url=settings.llm_api_url,
|
||||||
|
api_key=settings.llm_api_key,
|
||||||
|
)
|
||||||
|
self._fallback = openai.OpenAI(
|
||||||
|
base_url=settings.llm_fallback_url,
|
||||||
|
api_key=settings.llm_api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Core completion ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def complete(
|
||||||
|
self,
|
||||||
|
system_prompt: str,
|
||||||
|
user_prompt: str,
|
||||||
|
response_model: type[BaseModel] | None = None,
|
||||||
|
modality: str = "chat",
|
||||||
|
model_override: str | None = None,
|
||||||
|
on_complete: "Callable | None" = None,
|
||||||
|
max_tokens: int | None = None,
|
||||||
|
) -> "LLMResponse":
|
||||||
|
"""Send a chat completion request, falling back on connection/timeout errors.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
system_prompt:
|
||||||
|
System message content.
|
||||||
|
user_prompt:
|
||||||
|
User message content.
|
||||||
|
response_model:
|
||||||
|
If provided and modality is "chat", ``response_format`` is set to
|
||||||
|
``{"type": "json_object"}``. For "thinking" modality, JSON
|
||||||
|
instructions are appended to the system prompt instead.
|
||||||
|
modality:
|
||||||
|
Either "chat" (default) or "thinking". Thinking modality skips
|
||||||
|
response_format and strips ``<think>`` tags from output.
|
||||||
|
model_override:
|
||||||
|
Model name to use instead of the default. If None, uses the
|
||||||
|
configured default for the endpoint.
|
||||||
|
max_tokens:
|
||||||
|
Override for max_tokens on this call. If None, falls back to
|
||||||
|
the configured ``llm_max_tokens`` from settings.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
LLMResponse
|
||||||
|
Raw completion text (str subclass) with finish_reason metadata.
|
||||||
|
"""
|
||||||
|
kwargs: dict = {}
|
||||||
|
effective_system = system_prompt
|
||||||
|
|
||||||
|
if modality == "thinking":
|
||||||
|
# Thinking models often don't support response_format: json_object.
|
||||||
|
# Instead, append explicit JSON instructions to the system prompt.
|
||||||
|
if response_model is not None:
|
||||||
|
json_schema_hint = (
|
||||||
|
"\n\nYou MUST respond with ONLY valid JSON. "
|
||||||
|
"No markdown code fences, no explanation, no preamble — "
|
||||||
|
"just the raw JSON object."
|
||||||
|
)
|
||||||
|
effective_system = system_prompt + json_schema_hint
|
||||||
|
else:
|
||||||
|
# Chat modality — use standard JSON mode
|
||||||
|
if response_model is not None:
|
||||||
|
kwargs["response_format"] = {"type": "json_object"}
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": effective_system},
|
||||||
|
{"role": "user", "content": user_prompt},
|
||||||
|
]
|
||||||
|
|
||||||
|
primary_model = model_override or self.settings.llm_model
|
||||||
|
fallback_model = self.settings.llm_fallback_model
|
||||||
|
effective_max_tokens = max_tokens if max_tokens is not None else self.settings.llm_max_tokens
|
||||||
|
effective_temperature = self.settings.llm_temperature
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"LLM request: model=%s, modality=%s, response_model=%s, max_tokens=%d, temperature=%.1f",
|
||||||
|
primary_model,
|
||||||
|
modality,
|
||||||
|
response_model.__name__ if response_model else None,
|
||||||
|
effective_max_tokens,
|
||||||
|
effective_temperature,
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Try primary endpoint ---
|
||||||
|
try:
|
||||||
|
response = self._primary.chat.completions.create(
|
||||||
|
model=primary_model,
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=effective_max_tokens,
|
||||||
|
temperature=effective_temperature,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
raw = response.choices[0].message.content or ""
|
||||||
|
usage = getattr(response, "usage", None)
|
||||||
|
if usage:
|
||||||
|
logger.info(
|
||||||
|
"LLM response: prompt_tokens=%s, completion_tokens=%s, total=%s, content_len=%d, finish=%s",
|
||||||
|
usage.prompt_tokens, usage.completion_tokens, usage.total_tokens,
|
||||||
|
len(raw), response.choices[0].finish_reason,
|
||||||
|
)
|
||||||
|
if modality == "thinking":
|
||||||
|
raw = strip_think_tags(raw)
|
||||||
|
finish = response.choices[0].finish_reason if response.choices else None
|
||||||
|
if on_complete is not None:
|
||||||
|
try:
|
||||||
|
on_complete(
|
||||||
|
model=primary_model,
|
||||||
|
prompt_tokens=usage.prompt_tokens if usage else None,
|
||||||
|
completion_tokens=usage.completion_tokens if usage else None,
|
||||||
|
total_tokens=usage.total_tokens if usage else None,
|
||||||
|
content=raw,
|
||||||
|
finish_reason=finish,
|
||||||
|
)
|
||||||
|
except Exception as cb_exc:
|
||||||
|
logger.warning("on_complete callback failed: %s", cb_exc)
|
||||||
|
return LLMResponse(
|
||||||
|
raw,
|
||||||
|
finish_reason=finish,
|
||||||
|
prompt_tokens=usage.prompt_tokens if usage else None,
|
||||||
|
completion_tokens=usage.completion_tokens if usage else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
except (openai.APIConnectionError, openai.APITimeoutError) as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Primary LLM endpoint failed (%s: %s), trying fallback at %s",
|
||||||
|
type(exc).__name__,
|
||||||
|
exc,
|
||||||
|
self.settings.llm_fallback_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Try fallback endpoint ---
|
||||||
|
try:
|
||||||
|
response = self._fallback.chat.completions.create(
|
||||||
|
model=fallback_model,
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=effective_max_tokens,
|
||||||
|
temperature=effective_temperature,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
raw = response.choices[0].message.content or ""
|
||||||
|
usage = getattr(response, "usage", None)
|
||||||
|
if usage:
|
||||||
|
logger.info(
|
||||||
|
"LLM response (fallback): prompt_tokens=%s, completion_tokens=%s, total=%s, content_len=%d, finish=%s",
|
||||||
|
usage.prompt_tokens, usage.completion_tokens, usage.total_tokens,
|
||||||
|
len(raw), response.choices[0].finish_reason,
|
||||||
|
)
|
||||||
|
if modality == "thinking":
|
||||||
|
raw = strip_think_tags(raw)
|
||||||
|
finish = response.choices[0].finish_reason if response.choices else None
|
||||||
|
if on_complete is not None:
|
||||||
|
try:
|
||||||
|
on_complete(
|
||||||
|
model=fallback_model,
|
||||||
|
prompt_tokens=usage.prompt_tokens if usage else None,
|
||||||
|
completion_tokens=usage.completion_tokens if usage else None,
|
||||||
|
total_tokens=usage.total_tokens if usage else None,
|
||||||
|
content=raw,
|
||||||
|
finish_reason=finish,
|
||||||
|
is_fallback=True,
|
||||||
|
)
|
||||||
|
except Exception as cb_exc:
|
||||||
|
logger.warning("on_complete callback failed: %s", cb_exc)
|
||||||
|
return LLMResponse(
|
||||||
|
raw,
|
||||||
|
finish_reason=finish,
|
||||||
|
prompt_tokens=usage.prompt_tokens if usage else None,
|
||||||
|
completion_tokens=usage.completion_tokens if usage else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
except (openai.APIConnectionError, openai.APITimeoutError, openai.APIError) as exc:
|
||||||
|
logger.error(
|
||||||
|
"Fallback LLM endpoint also failed (%s: %s). Giving up.",
|
||||||
|
type(exc).__name__,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
# ── Response parsing ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def parse_response(self, text: str, model: type[T]) -> T:
|
||||||
|
"""Parse raw LLM output as JSON and validate against a Pydantic model.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
text:
|
||||||
|
Raw JSON string from the LLM.
|
||||||
|
model:
|
||||||
|
Pydantic model class to validate against.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
T
|
||||||
|
Validated Pydantic model instance.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
pydantic.ValidationError
|
||||||
|
If the JSON doesn't match the schema.
|
||||||
|
ValueError
|
||||||
|
If the text is not valid JSON.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return model.model_validate_json(text)
|
||||||
|
except Exception:
|
||||||
|
logger.error(
|
||||||
|
"Failed to parse LLM response as %s. Response text: %.500s",
|
||||||
|
model.__name__,
|
||||||
|
text,
|
||||||
|
)
|
||||||
|
raise
|
||||||
320
backend/pipeline/qdrant_client.py
Normal file
320
backend/pipeline/qdrant_client.py
Normal file
|
|
@ -0,0 +1,320 @@
|
||||||
|
"""Qdrant vector database manager for collection lifecycle and point upserts.
|
||||||
|
|
||||||
|
Handles collection creation (idempotent) and batch upserts for technique pages
|
||||||
|
and key moments. Connection failures are non-blocking — the pipeline continues
|
||||||
|
without search indexing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from qdrant_client import QdrantClient
|
||||||
|
from qdrant_client.http import exceptions as qdrant_exceptions
|
||||||
|
from qdrant_client.models import Distance, PointStruct, VectorParams
|
||||||
|
|
||||||
|
from config import Settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Namespace UUID for deterministic point IDs
|
||||||
|
_QDRANT_NAMESPACE = uuid.UUID("a1b2c3d4-e5f6-7890-abcd-ef1234567890")
|
||||||
|
|
||||||
|
|
||||||
|
class QdrantManager:
|
||||||
|
"""Manages a Qdrant collection for Chrysopedia technique-page and key-moment vectors."""
|
||||||
|
|
||||||
|
def __init__(self, settings: Settings) -> None:
|
||||||
|
self.settings = settings
|
||||||
|
self._client = QdrantClient(url=settings.qdrant_url)
|
||||||
|
self._collection = settings.qdrant_collection
|
||||||
|
|
||||||
|
# ── Collection management ────────────────────────────────────────────
|
||||||
|
|
||||||
|
def ensure_collection(self) -> None:
|
||||||
|
"""Create the collection if it does not already exist.
|
||||||
|
|
||||||
|
Uses cosine distance and the configured embedding dimensions.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if self._client.collection_exists(self._collection):
|
||||||
|
logger.info("Qdrant collection '%s' already exists.", self._collection)
|
||||||
|
return
|
||||||
|
|
||||||
|
self._client.create_collection(
|
||||||
|
collection_name=self._collection,
|
||||||
|
vectors_config=VectorParams(
|
||||||
|
size=self.settings.embedding_dimensions,
|
||||||
|
distance=Distance.COSINE,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"Created Qdrant collection '%s' (dim=%d, cosine).",
|
||||||
|
self._collection,
|
||||||
|
self.settings.embedding_dimensions,
|
||||||
|
)
|
||||||
|
except qdrant_exceptions.UnexpectedResponse as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Qdrant error during ensure_collection (%s). Skipping.",
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Qdrant connection failed during ensure_collection (%s: %s). Skipping.",
|
||||||
|
type(exc).__name__,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Deletion ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def delete_by_video_id(self, video_id: str) -> int:
|
||||||
|
"""Delete all points (key moments + technique pages) associated with a video.
|
||||||
|
|
||||||
|
Key moments have source_video_id in payload.
|
||||||
|
Technique pages don't have direct video linkage, so only moments are deleted.
|
||||||
|
|
||||||
|
Returns the count of deleted points (best-effort — Qdrant may not report exact counts).
|
||||||
|
"""
|
||||||
|
from qdrant_client.models import Filter, FieldCondition, MatchValue
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = self._client.delete(
|
||||||
|
collection_name=self._collection,
|
||||||
|
points_selector=Filter(
|
||||||
|
must=[
|
||||||
|
FieldCondition(
|
||||||
|
key="source_video_id",
|
||||||
|
match=MatchValue(value=video_id),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"Deleted Qdrant points for video_id=%s from collection '%s'.",
|
||||||
|
video_id,
|
||||||
|
self._collection,
|
||||||
|
)
|
||||||
|
return 0 # Qdrant delete doesn't return count
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Qdrant delete for video_id=%s failed (%s: %s). Skipping.",
|
||||||
|
video_id,
|
||||||
|
type(exc).__name__,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# ── Low-level upsert ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def upsert_points(self, points: list[PointStruct]) -> None:
|
||||||
|
"""Upsert a batch of pre-built PointStruct objects."""
|
||||||
|
if not points:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
self._client.upsert(
|
||||||
|
collection_name=self._collection,
|
||||||
|
points=points,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"Upserted %d points to Qdrant collection '%s'.",
|
||||||
|
len(points),
|
||||||
|
self._collection,
|
||||||
|
)
|
||||||
|
except qdrant_exceptions.UnexpectedResponse as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Qdrant upsert failed (%s). %d points skipped.",
|
||||||
|
exc,
|
||||||
|
len(points),
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Qdrant upsert connection error (%s: %s). %d points skipped.",
|
||||||
|
type(exc).__name__,
|
||||||
|
exc,
|
||||||
|
len(points),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── High-level upserts ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
def upsert_technique_pages(
|
||||||
|
self,
|
||||||
|
pages: list[dict],
|
||||||
|
vectors: list[list[float]],
|
||||||
|
) -> None:
|
||||||
|
"""Build and upsert PointStructs for technique pages.
|
||||||
|
|
||||||
|
Each page dict must contain:
|
||||||
|
page_id, creator_id, title, topic_category, topic_tags, summary
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
pages:
|
||||||
|
Metadata dicts, one per technique page.
|
||||||
|
vectors:
|
||||||
|
Corresponding embedding vectors (same order as pages).
|
||||||
|
"""
|
||||||
|
if len(pages) != len(vectors):
|
||||||
|
logger.warning(
|
||||||
|
"Technique-page count (%d) != vector count (%d). Skipping upsert.",
|
||||||
|
len(pages),
|
||||||
|
len(vectors),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
points = []
|
||||||
|
for page, vector in zip(pages, vectors):
|
||||||
|
# Deterministic UUID: same page always gets the same point ID
|
||||||
|
point_id = str(uuid.uuid5(_QDRANT_NAMESPACE, f"tp:{page['page_id']}"))
|
||||||
|
point = PointStruct(
|
||||||
|
id=point_id,
|
||||||
|
vector=vector,
|
||||||
|
payload={
|
||||||
|
"type": "technique_page",
|
||||||
|
"page_id": page["page_id"],
|
||||||
|
"creator_id": page["creator_id"],
|
||||||
|
"creator_name": page.get("creator_name", ""),
|
||||||
|
"title": page["title"],
|
||||||
|
"slug": page.get("slug", ""),
|
||||||
|
"topic_category": page["topic_category"],
|
||||||
|
"topic_tags": page.get("topic_tags") or [],
|
||||||
|
"summary": page.get("summary") or "",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
points.append(point)
|
||||||
|
|
||||||
|
self.upsert_points(points)
|
||||||
|
|
||||||
|
def upsert_key_moments(
|
||||||
|
self,
|
||||||
|
moments: list[dict],
|
||||||
|
vectors: list[list[float]],
|
||||||
|
) -> None:
|
||||||
|
"""Build and upsert PointStructs for key moments.
|
||||||
|
|
||||||
|
Each moment dict must contain:
|
||||||
|
moment_id, source_video_id, title, start_time, end_time, content_type
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
moments:
|
||||||
|
Metadata dicts, one per key moment.
|
||||||
|
vectors:
|
||||||
|
Corresponding embedding vectors (same order as moments).
|
||||||
|
"""
|
||||||
|
if len(moments) != len(vectors):
|
||||||
|
logger.warning(
|
||||||
|
"Key-moment count (%d) != vector count (%d). Skipping upsert.",
|
||||||
|
len(moments),
|
||||||
|
len(vectors),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
points = []
|
||||||
|
for moment, vector in zip(moments, vectors):
|
||||||
|
# Deterministic UUID: same moment always gets the same point ID
|
||||||
|
point_id = str(uuid.uuid5(_QDRANT_NAMESPACE, f"km:{moment['moment_id']}"))
|
||||||
|
point = PointStruct(
|
||||||
|
id=point_id,
|
||||||
|
vector=vector,
|
||||||
|
payload={
|
||||||
|
"type": "key_moment",
|
||||||
|
"moment_id": moment["moment_id"],
|
||||||
|
"source_video_id": moment["source_video_id"],
|
||||||
|
"creator_id": moment.get("creator_id", ""),
|
||||||
|
"technique_page_id": moment.get("technique_page_id", ""),
|
||||||
|
"technique_page_slug": moment.get("technique_page_slug", ""),
|
||||||
|
"title": moment["title"],
|
||||||
|
"creator_name": moment.get("creator_name", ""),
|
||||||
|
"start_time": moment["start_time"],
|
||||||
|
"end_time": moment["end_time"],
|
||||||
|
"content_type": moment["content_type"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
points.append(point)
|
||||||
|
|
||||||
|
self.upsert_points(points)
|
||||||
|
|
||||||
|
# ── Technique section operations ─────────────────────────────────────
|
||||||
|
|
||||||
|
def delete_sections_by_page_id(self, page_id: str) -> None:
|
||||||
|
"""Delete all technique_section points for a given page_id.
|
||||||
|
|
||||||
|
Called before re-upserting sections to prevent orphan points when
|
||||||
|
headings are renamed or sections removed. Non-blocking — logs warning
|
||||||
|
on failure.
|
||||||
|
"""
|
||||||
|
from qdrant_client.models import FieldCondition, Filter, MatchValue
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._client.delete(
|
||||||
|
collection_name=self._collection,
|
||||||
|
points_selector=Filter(
|
||||||
|
must=[
|
||||||
|
FieldCondition(
|
||||||
|
key="page_id",
|
||||||
|
match=MatchValue(value=page_id),
|
||||||
|
),
|
||||||
|
FieldCondition(
|
||||||
|
key="type",
|
||||||
|
match=MatchValue(value="technique_section"),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"Deleted technique_section points for page_id=%s from '%s'.",
|
||||||
|
page_id, self._collection,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Qdrant delete sections for page_id=%s failed (%s: %s). Skipping.",
|
||||||
|
page_id, type(exc).__name__, exc,
|
||||||
|
)
|
||||||
|
|
||||||
|
def upsert_technique_sections(
|
||||||
|
self,
|
||||||
|
sections: list[dict],
|
||||||
|
vectors: list[list[float]],
|
||||||
|
) -> None:
|
||||||
|
"""Build and upsert PointStructs for technique page sections.
|
||||||
|
|
||||||
|
Each section dict must contain:
|
||||||
|
page_id, section_anchor, section_heading, creator_id, creator_name,
|
||||||
|
title (page title), slug (page slug), topic_category, topic_tags, summary
|
||||||
|
|
||||||
|
Uses deterministic UUIDs: ``uuid5(namespace, 'ts:{page_id}:{section_anchor}')``.
|
||||||
|
"""
|
||||||
|
if len(sections) != len(vectors):
|
||||||
|
logger.warning(
|
||||||
|
"Technique-section count (%d) != vector count (%d). Skipping upsert.",
|
||||||
|
len(sections), len(vectors),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
points = []
|
||||||
|
for sec, vector in zip(sections, vectors):
|
||||||
|
point_id = str(uuid.uuid5(
|
||||||
|
_QDRANT_NAMESPACE,
|
||||||
|
f"ts:{sec['page_id']}:{sec['section_anchor']}",
|
||||||
|
))
|
||||||
|
point = PointStruct(
|
||||||
|
id=point_id,
|
||||||
|
vector=vector,
|
||||||
|
payload={
|
||||||
|
"type": "technique_section",
|
||||||
|
"page_id": sec["page_id"],
|
||||||
|
"creator_id": sec.get("creator_id", ""),
|
||||||
|
"creator_name": sec.get("creator_name", ""),
|
||||||
|
"title": sec.get("title", ""),
|
||||||
|
"slug": sec.get("slug", ""),
|
||||||
|
"section_heading": sec["section_heading"],
|
||||||
|
"section_anchor": sec["section_anchor"],
|
||||||
|
"topic_category": sec.get("topic_category", ""),
|
||||||
|
"topic_tags": sec.get("topic_tags") or [],
|
||||||
|
"summary": (sec.get("summary") or "")[:200],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
points.append(point)
|
||||||
|
|
||||||
|
self.upsert_points(points)
|
||||||
11
backend/pipeline/quality/__init__.py
Normal file
11
backend/pipeline/quality/__init__.py
Normal file
|
|
@ -0,0 +1,11 @@
|
||||||
|
"""FYN-LLM quality assurance toolkit."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# Ensure backend/ is on sys.path so sibling modules (config, pipeline.llm_client)
|
||||||
|
# resolve when running from the project root via symlink.
|
||||||
|
_backend_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..")
|
||||||
|
_backend_abs = os.path.normpath(os.path.abspath(_backend_dir))
|
||||||
|
if _backend_abs not in sys.path:
|
||||||
|
sys.path.insert(0, _backend_abs)
|
||||||
646
backend/pipeline/quality/__main__.py
Normal file
646
backend/pipeline/quality/__main__.py
Normal file
|
|
@ -0,0 +1,646 @@
|
||||||
|
"""FYN-LLM quality assurance toolkit.
|
||||||
|
|
||||||
|
Subcommands:
|
||||||
|
fitness — Run LLM fitness tests across four categories
|
||||||
|
score — Score a Stage 5 technique page across 5 quality dimensions
|
||||||
|
optimize — Automated prompt optimization loop with leaderboard output
|
||||||
|
|
||||||
|
Run with: python -m pipeline.quality <command>
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from config import get_settings
|
||||||
|
from pipeline.llm_client import LLMClient
|
||||||
|
|
||||||
|
from .chat_eval import ChatEvalRunner
|
||||||
|
from .chat_scorer import ChatScoreRunner
|
||||||
|
from .fitness import FitnessRunner
|
||||||
|
from .optimizer import OptimizationLoop, OptimizationResult
|
||||||
|
from .scorer import DIMENSIONS, STAGE_CONFIGS, ScoreRunner
|
||||||
|
|
||||||
|
|
||||||
|
# ── Reporting helpers ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def print_leaderboard(result: OptimizationResult, stage: int = 5) -> None:
|
||||||
|
"""Print a formatted leaderboard of top 5 variants by composite score."""
|
||||||
|
dims = STAGE_CONFIGS[stage].dimensions if stage in STAGE_CONFIGS else DIMENSIONS
|
||||||
|
|
||||||
|
# Filter to entries that actually scored (no errors)
|
||||||
|
scored = [h for h in result.history if not h.get("error")]
|
||||||
|
if not scored:
|
||||||
|
print("\n No successfully scored variants to rank.\n")
|
||||||
|
return
|
||||||
|
|
||||||
|
ranked = sorted(scored, key=lambda h: h["composite"], reverse=True)[:5]
|
||||||
|
|
||||||
|
print(f"\n{'='*72}")
|
||||||
|
print(f" LEADERBOARD — Top 5 Variants by Composite Score (Stage {stage})")
|
||||||
|
print(f"{'='*72}")
|
||||||
|
|
||||||
|
# Header
|
||||||
|
dim_headers = " ".join(f"{d[:5]:>5s}" for d in dims)
|
||||||
|
sep_segments = " ".join("─" * 5 for _ in dims)
|
||||||
|
print(f" {'#':>2s} {'Label':<16s} {'Comp':>5s} {dim_headers}")
|
||||||
|
print(f" {'─'*2} {'─'*16} {'─'*5} {sep_segments}")
|
||||||
|
|
||||||
|
for i, entry in enumerate(ranked, 1):
|
||||||
|
label = entry.get("label", "?")[:16]
|
||||||
|
comp = entry["composite"]
|
||||||
|
dim_vals = " ".join(
|
||||||
|
f"{entry['scores'].get(d, 0.0):5.2f}" for d in dims
|
||||||
|
)
|
||||||
|
bar = "█" * int(comp * 20) + "░" * (20 - int(comp * 20))
|
||||||
|
print(f" {i:>2d} {label:<16s} {comp:5.3f} {dim_vals} {bar}")
|
||||||
|
|
||||||
|
print(f"{'='*72}\n")
|
||||||
|
|
||||||
|
|
||||||
|
def print_trajectory(result: OptimizationResult) -> None:
|
||||||
|
"""Print an ASCII chart of composite score across iterations."""
|
||||||
|
scored = [h for h in result.history if not h.get("error")]
|
||||||
|
if len(scored) < 2:
|
||||||
|
print(" (Not enough data points for trajectory chart)\n")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Get the best composite per iteration
|
||||||
|
iter_best: dict[int, float] = {}
|
||||||
|
for h in scored:
|
||||||
|
it = h["iteration"]
|
||||||
|
if it not in iter_best or h["composite"] > iter_best[it]:
|
||||||
|
iter_best[it] = h["composite"]
|
||||||
|
|
||||||
|
iterations = sorted(iter_best.keys())
|
||||||
|
values = [iter_best[it] for it in iterations]
|
||||||
|
|
||||||
|
# Chart dimensions
|
||||||
|
chart_height = 15
|
||||||
|
min_val = max(0.0, min(values) - 0.05)
|
||||||
|
max_val = min(1.0, max(values) + 0.05)
|
||||||
|
val_range = max_val - min_val
|
||||||
|
if val_range < 0.01:
|
||||||
|
val_range = 0.1
|
||||||
|
min_val = max(0.0, values[0] - 0.05)
|
||||||
|
max_val = min_val + val_range
|
||||||
|
|
||||||
|
print(f" {'─'*50}")
|
||||||
|
print(" SCORE TRAJECTORY — Best Composite per Iteration")
|
||||||
|
print(f" {'─'*50}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Render rows top to bottom
|
||||||
|
for row in range(chart_height, -1, -1):
|
||||||
|
threshold = min_val + (row / chart_height) * val_range
|
||||||
|
# Y-axis label every 5 rows
|
||||||
|
if row % 5 == 0:
|
||||||
|
label = f"{threshold:.2f}"
|
||||||
|
else:
|
||||||
|
label = " "
|
||||||
|
line = f" {label} │"
|
||||||
|
|
||||||
|
for vi, val in enumerate(values):
|
||||||
|
normalized = (val - min_val) / val_range
|
||||||
|
filled_rows = int(normalized * chart_height)
|
||||||
|
if filled_rows >= row:
|
||||||
|
line += " ● "
|
||||||
|
else:
|
||||||
|
line += " · "
|
||||||
|
|
||||||
|
print(line)
|
||||||
|
|
||||||
|
# X-axis
|
||||||
|
print(f" ───── ┼{'───' * len(values)}")
|
||||||
|
x_labels = " " + " "
|
||||||
|
for it in iterations:
|
||||||
|
x_labels += f"{it:>2d} "
|
||||||
|
print(x_labels)
|
||||||
|
print(" " + " iteration →")
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
def write_results_json(
|
||||||
|
result: OptimizationResult,
|
||||||
|
output_dir: str,
|
||||||
|
stage: int,
|
||||||
|
iterations: int,
|
||||||
|
variants_per_iter: int,
|
||||||
|
fixture_path: str,
|
||||||
|
) -> str:
|
||||||
|
"""Write optimization results to a timestamped JSON file. Returns the path."""
|
||||||
|
out_path = Path(output_dir)
|
||||||
|
out_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
|
||||||
|
filename = f"optimize_stage{stage}_{timestamp}.json"
|
||||||
|
filepath = out_path / filename
|
||||||
|
|
||||||
|
dims = STAGE_CONFIGS[stage].dimensions if stage in STAGE_CONFIGS else DIMENSIONS
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"config": {
|
||||||
|
"stage": stage,
|
||||||
|
"iterations": iterations,
|
||||||
|
"variants_per_iter": variants_per_iter,
|
||||||
|
"fixture_path": fixture_path,
|
||||||
|
},
|
||||||
|
"best_prompt": result.best_prompt,
|
||||||
|
"best_scores": {
|
||||||
|
"composite": result.best_score.composite,
|
||||||
|
**{d: result.best_score.scores.get(d, 0.0) for d in dims},
|
||||||
|
},
|
||||||
|
"elapsed_seconds": result.elapsed_seconds,
|
||||||
|
"history": result.history,
|
||||||
|
}
|
||||||
|
|
||||||
|
filepath.write_text(json.dumps(payload, indent=2), encoding="utf-8")
|
||||||
|
return str(filepath)
|
||||||
|
|
||||||
|
|
||||||
|
# ── CLI ──────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
prog="pipeline.quality",
|
||||||
|
description="FYN-LLM quality assurance toolkit",
|
||||||
|
)
|
||||||
|
sub = parser.add_subparsers(dest="command")
|
||||||
|
|
||||||
|
# -- fitness subcommand --
|
||||||
|
sub.add_parser("fitness", help="Run LLM fitness tests across four categories")
|
||||||
|
|
||||||
|
# -- score subcommand --
|
||||||
|
score_parser = sub.add_parser(
|
||||||
|
"score",
|
||||||
|
help="Score a Stage 5 technique page across 5 quality dimensions",
|
||||||
|
)
|
||||||
|
source_group = score_parser.add_mutually_exclusive_group(required=True)
|
||||||
|
source_group.add_argument(
|
||||||
|
"--file",
|
||||||
|
type=str,
|
||||||
|
help="Path to a moments JSON file (creator_name, moments array)",
|
||||||
|
)
|
||||||
|
source_group.add_argument(
|
||||||
|
"--slug",
|
||||||
|
type=str,
|
||||||
|
help="Technique slug to load from the database",
|
||||||
|
)
|
||||||
|
score_parser.add_argument(
|
||||||
|
"--voice-level",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Voice preservation dial (0.0=clinical, 1.0=maximum voice). Triggers re-synthesis before scoring.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# -- optimize subcommand --
|
||||||
|
opt_parser = sub.add_parser(
|
||||||
|
"optimize",
|
||||||
|
help="Automated prompt optimization loop with leaderboard output",
|
||||||
|
)
|
||||||
|
|
||||||
|
# -- apply subcommand --
|
||||||
|
apply_parser = sub.add_parser(
|
||||||
|
"apply",
|
||||||
|
help="Apply a winning prompt from optimization results to the stage's prompt file",
|
||||||
|
)
|
||||||
|
apply_parser.add_argument(
|
||||||
|
"results_file",
|
||||||
|
type=str,
|
||||||
|
help="Path to an optimization results JSON file",
|
||||||
|
)
|
||||||
|
apply_parser.add_argument(
|
||||||
|
"--dry-run",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Show what would change without writing",
|
||||||
|
)
|
||||||
|
opt_parser.add_argument(
|
||||||
|
"--stage",
|
||||||
|
type=int,
|
||||||
|
default=5,
|
||||||
|
help="Pipeline stage to optimize (default: 5)",
|
||||||
|
)
|
||||||
|
opt_parser.add_argument(
|
||||||
|
"--iterations",
|
||||||
|
type=int,
|
||||||
|
default=10,
|
||||||
|
help="Number of optimization iterations (default: 10)",
|
||||||
|
)
|
||||||
|
opt_parser.add_argument(
|
||||||
|
"--variants-per-iter",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="Variants generated per iteration (default: 2)",
|
||||||
|
)
|
||||||
|
opt_source = opt_parser.add_mutually_exclusive_group(required=True)
|
||||||
|
opt_source.add_argument(
|
||||||
|
"--file",
|
||||||
|
type=str,
|
||||||
|
help="Path to moments JSON fixture file",
|
||||||
|
)
|
||||||
|
opt_source.add_argument(
|
||||||
|
"--video-id",
|
||||||
|
type=str,
|
||||||
|
help="Video UUID — exports fixture from DB automatically (requires DATABASE_URL, REDIS_URL)",
|
||||||
|
)
|
||||||
|
opt_parser.add_argument(
|
||||||
|
"--output-dir",
|
||||||
|
type=str,
|
||||||
|
default="backend/pipeline/quality/results/",
|
||||||
|
help="Directory to write result JSON (default: backend/pipeline/quality/results/)",
|
||||||
|
)
|
||||||
|
opt_parser.add_argument(
|
||||||
|
"--apply",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Write the winning prompt back to the stage's prompt file (backs up the original first)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# -- chat_eval subcommand --
|
||||||
|
chat_parser = sub.add_parser(
|
||||||
|
"chat_eval",
|
||||||
|
help="Evaluate chat quality across a test suite of queries",
|
||||||
|
)
|
||||||
|
chat_parser.add_argument(
|
||||||
|
"--suite",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to a chat test suite YAML/JSON file",
|
||||||
|
)
|
||||||
|
chat_parser.add_argument(
|
||||||
|
"--base-url",
|
||||||
|
type=str,
|
||||||
|
default="http://localhost:8096",
|
||||||
|
help="Chat API base URL (default: http://localhost:8096)",
|
||||||
|
)
|
||||||
|
chat_parser.add_argument(
|
||||||
|
"--output",
|
||||||
|
type=str,
|
||||||
|
default="backend/pipeline/quality/results/",
|
||||||
|
help="Output path for results JSON (default: backend/pipeline/quality/results/)",
|
||||||
|
)
|
||||||
|
chat_parser.add_argument(
|
||||||
|
"--timeout",
|
||||||
|
type=float,
|
||||||
|
default=120.0,
|
||||||
|
help="Request timeout in seconds (default: 120)",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.command is None:
|
||||||
|
parser.print_help()
|
||||||
|
return 1
|
||||||
|
|
||||||
|
if args.command == "fitness":
|
||||||
|
settings = get_settings()
|
||||||
|
client = LLMClient(settings)
|
||||||
|
runner = FitnessRunner(client)
|
||||||
|
return runner.run_all()
|
||||||
|
|
||||||
|
if args.command == "score":
|
||||||
|
return _run_score(args)
|
||||||
|
|
||||||
|
if args.command == "optimize":
|
||||||
|
return _run_optimize(args)
|
||||||
|
|
||||||
|
if args.command == "apply":
|
||||||
|
return _run_apply(args)
|
||||||
|
|
||||||
|
if args.command == "chat_eval":
|
||||||
|
return _run_chat_eval(args)
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def _run_score(args: argparse.Namespace) -> int:
|
||||||
|
"""Execute the score subcommand."""
|
||||||
|
# -- Load source data --
|
||||||
|
if args.slug:
|
||||||
|
print("DB loading not yet implemented", file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(args.file) as f:
|
||||||
|
data = json.load(f)
|
||||||
|
except FileNotFoundError:
|
||||||
|
print(f"File not found: {args.file}", file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
print(f"Invalid JSON in {args.file}: {exc}", file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
moments = data.get("moments", [])
|
||||||
|
creator_name = data.get("creator_name", "Unknown")
|
||||||
|
|
||||||
|
if not moments:
|
||||||
|
print("No moments found in input file", file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
settings = get_settings()
|
||||||
|
client = LLMClient(settings)
|
||||||
|
runner = ScoreRunner(client)
|
||||||
|
|
||||||
|
# -- Voice-level mode: re-synthesize then score --
|
||||||
|
if args.voice_level is not None:
|
||||||
|
voice_level = args.voice_level
|
||||||
|
if not (0.0 <= voice_level <= 1.0):
|
||||||
|
print("--voice-level must be between 0.0 and 1.0", file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
print(f"\nRe-synthesizing + scoring for '{creator_name}' ({len(moments)} moments, voice_level={voice_level})...")
|
||||||
|
result = runner.synthesize_and_score(moments, creator_name, voice_level)
|
||||||
|
|
||||||
|
if result.error:
|
||||||
|
runner.print_report(result)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
runner.print_report(result)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# -- Standard mode: build page stub from moments, score directly --
|
||||||
|
page_json = {
|
||||||
|
"title": f"{creator_name} — Technique Page",
|
||||||
|
"creator_name": creator_name,
|
||||||
|
"summary": f"Technique page synthesized from {len(moments)} key moments.",
|
||||||
|
"body_sections": [
|
||||||
|
{
|
||||||
|
"heading": m.get("topic_tags", ["Technique"])[0] if m.get("topic_tags") else "Technique",
|
||||||
|
"content": m.get("summary", "") + "\n\n" + m.get("transcript_excerpt", ""),
|
||||||
|
}
|
||||||
|
for m in moments
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f"\nScoring page for '{creator_name}' ({len(moments)} moments)...")
|
||||||
|
|
||||||
|
result = runner.score_page(page_json, moments)
|
||||||
|
|
||||||
|
if result.error:
|
||||||
|
runner.print_report(result)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
runner.print_report(result)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def _run_optimize(args: argparse.Namespace) -> int:
|
||||||
|
"""Execute the optimize subcommand."""
|
||||||
|
# Stage validation — stages 2-5 are supported
|
||||||
|
if args.stage not in STAGE_CONFIGS:
|
||||||
|
print(
|
||||||
|
f"Error: unsupported stage {args.stage}. Valid stages: {sorted(STAGE_CONFIGS)}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
# Resolve fixture: either from --file or auto-export from --video-id
|
||||||
|
fixture_path: str
|
||||||
|
if args.file:
|
||||||
|
fixture_path = args.file
|
||||||
|
else:
|
||||||
|
# Auto-export from database
|
||||||
|
print(f"\n[OPTIMIZE] Exporting fixture from video_id={args.video_id}...", file=sys.stderr)
|
||||||
|
import tempfile
|
||||||
|
from pipeline.export_fixture import export_fixture
|
||||||
|
|
||||||
|
settings = get_settings()
|
||||||
|
tmp = tempfile.NamedTemporaryFile(suffix=".json", prefix="optimize_fixture_", delete=False)
|
||||||
|
tmp.close()
|
||||||
|
exit_code = export_fixture(
|
||||||
|
database_url=settings.database_url,
|
||||||
|
redis_url=settings.redis_url,
|
||||||
|
video_id=args.video_id,
|
||||||
|
output_path=tmp.name,
|
||||||
|
)
|
||||||
|
if exit_code != 0:
|
||||||
|
print(f"Error: fixture export failed (exit code {exit_code})", file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
fixture_path = tmp.name
|
||||||
|
print(f"[OPTIMIZE] Fixture exported to: {fixture_path}", file=sys.stderr)
|
||||||
|
|
||||||
|
fixture = Path(fixture_path)
|
||||||
|
if not fixture.exists():
|
||||||
|
print(f"Error: fixture file not found: {fixture_path}", file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
# Ensure output dir
|
||||||
|
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
settings = get_settings()
|
||||||
|
client = LLMClient(settings)
|
||||||
|
|
||||||
|
loop = OptimizationLoop(
|
||||||
|
client=client,
|
||||||
|
stage=args.stage,
|
||||||
|
fixture_path=fixture_path,
|
||||||
|
iterations=args.iterations,
|
||||||
|
variants_per_iter=args.variants_per_iter,
|
||||||
|
output_dir=args.output_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = loop.run()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n Optimization interrupted by user.", file=sys.stderr)
|
||||||
|
return 130
|
||||||
|
except Exception as exc:
|
||||||
|
print(f"\nError: optimization failed: {exc}", file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
# If the loop returned an error on baseline, report and exit
|
||||||
|
if result.best_score.error and not result.history:
|
||||||
|
print(f"\nError: {result.best_score.error}", file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
# Reporting
|
||||||
|
print_leaderboard(result, stage=args.stage)
|
||||||
|
print_trajectory(result)
|
||||||
|
|
||||||
|
# Write results JSON
|
||||||
|
try:
|
||||||
|
json_path = write_results_json(
|
||||||
|
result=result,
|
||||||
|
output_dir=args.output_dir,
|
||||||
|
stage=args.stage,
|
||||||
|
iterations=args.iterations,
|
||||||
|
variants_per_iter=args.variants_per_iter,
|
||||||
|
fixture_path=fixture_path,
|
||||||
|
)
|
||||||
|
print(f" Results written to: {json_path}")
|
||||||
|
except OSError as exc:
|
||||||
|
print(f" Warning: failed to write results JSON: {exc}", file=sys.stderr)
|
||||||
|
|
||||||
|
# Apply winning prompt if requested
|
||||||
|
if args.apply:
|
||||||
|
baseline_composite = 0.0
|
||||||
|
for h in result.history:
|
||||||
|
if h.get("label") == "baseline" and not h.get("error"):
|
||||||
|
baseline_composite = h["composite"]
|
||||||
|
break
|
||||||
|
|
||||||
|
if result.best_score.composite <= baseline_composite:
|
||||||
|
print("\n --apply: Best prompt did not beat baseline — skipping apply.")
|
||||||
|
elif result.best_score.error:
|
||||||
|
print("\n --apply: Best result has an error — skipping apply.")
|
||||||
|
else:
|
||||||
|
print("\n --apply: Winning prompt beat baseline — applying...")
|
||||||
|
success, msg = apply_prompt(args.stage, result.best_prompt)
|
||||||
|
print(f" {msg}")
|
||||||
|
if not success:
|
||||||
|
return 1
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def apply_prompt(stage: int, new_prompt: str, dry_run: bool = False) -> tuple[bool, str]:
|
||||||
|
"""Apply a new prompt to a stage's prompt file. Returns (success, message).
|
||||||
|
|
||||||
|
Creates a timestamped backup of the original before overwriting.
|
||||||
|
"""
|
||||||
|
if stage not in STAGE_CONFIGS:
|
||||||
|
return False, f"Unsupported stage {stage}. Valid: {sorted(STAGE_CONFIGS)}"
|
||||||
|
|
||||||
|
config = STAGE_CONFIGS[stage]
|
||||||
|
settings = get_settings()
|
||||||
|
prompt_path = Path(settings.prompts_path) / config.prompt_file
|
||||||
|
|
||||||
|
if not prompt_path.exists():
|
||||||
|
return False, f"Prompt file not found: {prompt_path}"
|
||||||
|
|
||||||
|
original = prompt_path.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
if original.strip() == new_prompt.strip():
|
||||||
|
return True, "No change — winning prompt is identical to current prompt."
|
||||||
|
|
||||||
|
# Show diff summary
|
||||||
|
orig_lines = original.strip().splitlines()
|
||||||
|
new_lines = new_prompt.strip().splitlines()
|
||||||
|
print(f"\n Prompt file: {prompt_path}")
|
||||||
|
print(f" Original: {len(orig_lines)} lines, {len(original)} chars")
|
||||||
|
print(f" New: {len(new_lines)} lines, {len(new_prompt)} chars")
|
||||||
|
|
||||||
|
# Simple line-level diff summary
|
||||||
|
import difflib
|
||||||
|
diff = list(difflib.unified_diff(orig_lines, new_lines, lineterm="", n=2))
|
||||||
|
added = sum(1 for l in diff if l.startswith("+") and not l.startswith("+++"))
|
||||||
|
removed = sum(1 for l in diff if l.startswith("-") and not l.startswith("---"))
|
||||||
|
print(f" Changes: +{added} lines, -{removed} lines")
|
||||||
|
|
||||||
|
if dry_run:
|
||||||
|
print("\n [DRY RUN] Would write to:", prompt_path)
|
||||||
|
if len(diff) <= 40:
|
||||||
|
print()
|
||||||
|
for line in diff:
|
||||||
|
print(f" {line}")
|
||||||
|
else:
|
||||||
|
print(f"\n (diff is {len(diff)} lines — showing first 30)")
|
||||||
|
for line in diff[:30]:
|
||||||
|
print(f" {line}")
|
||||||
|
print(" ...")
|
||||||
|
return True, "Dry run — no files modified."
|
||||||
|
|
||||||
|
# Backup original
|
||||||
|
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
|
||||||
|
backup_path = prompt_path.with_suffix(f".{timestamp}.bak")
|
||||||
|
backup_path.write_text(original, encoding="utf-8")
|
||||||
|
print(f" Backup: {backup_path}")
|
||||||
|
|
||||||
|
# Write new prompt
|
||||||
|
prompt_path.write_text(new_prompt, encoding="utf-8")
|
||||||
|
print(f" ✓ Written: {prompt_path}")
|
||||||
|
|
||||||
|
return True, f"Prompt applied. Backup at {backup_path}"
|
||||||
|
|
||||||
|
|
||||||
|
def _run_apply(args: argparse.Namespace) -> int:
|
||||||
|
"""Execute the apply subcommand — read a results JSON and apply the winning prompt."""
|
||||||
|
results_path = Path(args.results_file)
|
||||||
|
if not results_path.exists():
|
||||||
|
print(f"Error: results file not found: {args.results_file}", file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = json.loads(results_path.read_text(encoding="utf-8"))
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
print(f"Error: invalid JSON in {args.results_file}: {exc}", file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
stage = data.get("config", {}).get("stage")
|
||||||
|
best_prompt = data.get("best_prompt", "")
|
||||||
|
best_scores = data.get("best_scores", {})
|
||||||
|
|
||||||
|
if not stage:
|
||||||
|
print("Error: results JSON missing config.stage", file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
if not best_prompt:
|
||||||
|
print("Error: results JSON missing best_prompt or it's empty", file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
composite = best_scores.get("composite", 0)
|
||||||
|
print(f"\n Applying results from: {results_path}")
|
||||||
|
print(f" Stage: {stage}")
|
||||||
|
print(f" Best composite score: {composite:.3f}")
|
||||||
|
|
||||||
|
success, msg = apply_prompt(stage, best_prompt, dry_run=args.dry_run)
|
||||||
|
print(f"\n {msg}")
|
||||||
|
return 0 if success else 1
|
||||||
|
|
||||||
|
|
||||||
|
def _run_chat_eval(args: argparse.Namespace) -> int:
|
||||||
|
"""Execute the chat_eval subcommand — evaluate chat quality across a test suite."""
|
||||||
|
suite_path = Path(args.suite)
|
||||||
|
if not suite_path.exists():
|
||||||
|
print(f"Error: suite file not found: {args.suite}", file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
# Load test cases
|
||||||
|
try:
|
||||||
|
cases = ChatEvalRunner.load_suite(suite_path)
|
||||||
|
except Exception as exc:
|
||||||
|
print(f"Error loading test suite: {exc}", file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
if not cases:
|
||||||
|
print("Error: test suite contains no queries", file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
print(f"\n Chat Evaluation: {len(cases)} queries from {suite_path}")
|
||||||
|
print(f" Endpoint: {args.base_url}")
|
||||||
|
|
||||||
|
# Build scorer and runner
|
||||||
|
settings = get_settings()
|
||||||
|
client = LLMClient(settings)
|
||||||
|
scorer = ChatScoreRunner(client)
|
||||||
|
runner = ChatEvalRunner(
|
||||||
|
scorer=scorer,
|
||||||
|
base_url=args.base_url,
|
||||||
|
timeout=args.timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute
|
||||||
|
results = runner.run_suite(cases)
|
||||||
|
|
||||||
|
# Print summary
|
||||||
|
runner.print_summary(results)
|
||||||
|
|
||||||
|
# Write results
|
||||||
|
try:
|
||||||
|
json_path = runner.write_results(results, args.output)
|
||||||
|
print(f" Results written to: {json_path}")
|
||||||
|
except OSError as exc:
|
||||||
|
print(f" Warning: failed to write results: {exc}", file=sys.stderr)
|
||||||
|
|
||||||
|
# Exit code: 0 if at least one scored, 1 if all errored
|
||||||
|
scored = [r for r in results if r.score and not r.score.error and not r.request_error]
|
||||||
|
return 0 if scored else 1
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
352
backend/pipeline/quality/chat_eval.py
Normal file
352
backend/pipeline/quality/chat_eval.py
Normal file
|
|
@ -0,0 +1,352 @@
|
||||||
|
"""Chat evaluation harness — sends queries to the live chat endpoint, scores responses.
|
||||||
|
|
||||||
|
Loads a test suite (YAML or JSON), calls the chat HTTP endpoint for each query,
|
||||||
|
parses SSE events to collect response text and sources, then scores each using
|
||||||
|
ChatScoreRunner. Writes results to a JSON file.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python -m pipeline.quality chat_eval --suite fixtures/chat_test_suite.yaml
|
||||||
|
python -m pipeline.quality chat_eval --suite fixtures/chat_test_suite.yaml --base-url http://ub01:8096
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from pipeline.llm_client import LLMClient
|
||||||
|
from pipeline.quality.chat_scorer import CHAT_DIMENSIONS, ChatScoreResult, ChatScoreRunner
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_DEFAULT_BASE_URL = "http://localhost:8096"
|
||||||
|
_CHAT_ENDPOINT = "/api/chat"
|
||||||
|
_REQUEST_TIMEOUT = 120.0 # seconds — LLM streaming can be slow
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChatTestCase:
|
||||||
|
"""A single test case from the test suite."""
|
||||||
|
|
||||||
|
query: str
|
||||||
|
creator: str | None = None
|
||||||
|
personality_weight: float = 0.0
|
||||||
|
category: str = "general"
|
||||||
|
description: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChatEvalResult:
|
||||||
|
"""Result of evaluating a single test case."""
|
||||||
|
|
||||||
|
test_case: ChatTestCase
|
||||||
|
response: str = ""
|
||||||
|
sources: list[dict] = field(default_factory=list)
|
||||||
|
cascade_tier: str = ""
|
||||||
|
score: ChatScoreResult | None = None
|
||||||
|
request_error: str | None = None
|
||||||
|
latency_seconds: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class ChatEvalRunner:
|
||||||
|
"""Runs a chat evaluation suite against a live endpoint."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
scorer: ChatScoreRunner,
|
||||||
|
base_url: str = _DEFAULT_BASE_URL,
|
||||||
|
timeout: float = _REQUEST_TIMEOUT,
|
||||||
|
) -> None:
|
||||||
|
self.scorer = scorer
|
||||||
|
self.base_url = base_url.rstrip("/")
|
||||||
|
self.timeout = timeout
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_suite(path: str | Path) -> list[ChatTestCase]:
|
||||||
|
"""Load test cases from a YAML or JSON file.
|
||||||
|
|
||||||
|
Expected format (YAML):
|
||||||
|
queries:
|
||||||
|
- query: "How do I sidechain a bass?"
|
||||||
|
creator: null
|
||||||
|
personality_weight: 0.0
|
||||||
|
category: technical
|
||||||
|
description: "Basic sidechain compression question"
|
||||||
|
"""
|
||||||
|
filepath = Path(path)
|
||||||
|
text = filepath.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
if filepath.suffix in (".yaml", ".yml"):
|
||||||
|
try:
|
||||||
|
import yaml
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"PyYAML is required to load YAML test suites. "
|
||||||
|
"Install with: pip install pyyaml"
|
||||||
|
)
|
||||||
|
data = yaml.safe_load(text)
|
||||||
|
else:
|
||||||
|
data = json.loads(text)
|
||||||
|
|
||||||
|
queries = data.get("queries", [])
|
||||||
|
cases: list[ChatTestCase] = []
|
||||||
|
for q in queries:
|
||||||
|
cases.append(ChatTestCase(
|
||||||
|
query=q["query"],
|
||||||
|
creator=q.get("creator"),
|
||||||
|
personality_weight=float(q.get("personality_weight", 0.0)),
|
||||||
|
category=q.get("category", "general"),
|
||||||
|
description=q.get("description", ""),
|
||||||
|
))
|
||||||
|
return cases
|
||||||
|
|
||||||
|
def run_suite(self, cases: list[ChatTestCase]) -> list[ChatEvalResult]:
|
||||||
|
"""Execute all test cases sequentially, scoring each response."""
|
||||||
|
results: list[ChatEvalResult] = []
|
||||||
|
|
||||||
|
for i, case in enumerate(cases, 1):
|
||||||
|
print(f"\n [{i}/{len(cases)}] {case.category}: {case.query[:60]}...")
|
||||||
|
result = self._run_single(case)
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
if result.request_error:
|
||||||
|
print(f" ✗ Request error: {result.request_error}")
|
||||||
|
elif result.score and result.score.error:
|
||||||
|
print(f" ✗ Scoring error: {result.score.error}")
|
||||||
|
elif result.score:
|
||||||
|
print(f" ✓ Composite: {result.score.composite:.3f} "
|
||||||
|
f"(latency: {result.latency_seconds:.1f}s)")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _run_single(self, case: ChatTestCase) -> ChatEvalResult:
|
||||||
|
"""Execute a single test case: call endpoint, parse SSE, score."""
|
||||||
|
eval_result = ChatEvalResult(test_case=case)
|
||||||
|
|
||||||
|
# Call the chat endpoint
|
||||||
|
t0 = time.monotonic()
|
||||||
|
try:
|
||||||
|
response_text, sources, cascade_tier = self._call_chat_endpoint(case)
|
||||||
|
eval_result.latency_seconds = round(time.monotonic() - t0, 2)
|
||||||
|
except Exception as exc:
|
||||||
|
eval_result.latency_seconds = round(time.monotonic() - t0, 2)
|
||||||
|
eval_result.request_error = str(exc)
|
||||||
|
logger.error("chat_eval_request_error query=%r error=%s", case.query, exc)
|
||||||
|
return eval_result
|
||||||
|
|
||||||
|
eval_result.response = response_text
|
||||||
|
eval_result.sources = sources
|
||||||
|
eval_result.cascade_tier = cascade_tier
|
||||||
|
|
||||||
|
if not response_text:
|
||||||
|
eval_result.request_error = "Empty response from chat endpoint"
|
||||||
|
return eval_result
|
||||||
|
|
||||||
|
# Score the response
|
||||||
|
eval_result.score = self.scorer.score_response(
|
||||||
|
query=case.query,
|
||||||
|
response=response_text,
|
||||||
|
sources=sources,
|
||||||
|
personality_weight=case.personality_weight,
|
||||||
|
creator_name=case.creator,
|
||||||
|
)
|
||||||
|
|
||||||
|
return eval_result
|
||||||
|
|
||||||
|
def _call_chat_endpoint(
|
||||||
|
self, case: ChatTestCase
|
||||||
|
) -> tuple[str, list[dict], str]:
|
||||||
|
"""Call the chat SSE endpoint and parse the event stream.
|
||||||
|
|
||||||
|
Returns (accumulated_text, sources_list, cascade_tier).
|
||||||
|
"""
|
||||||
|
url = f"{self.base_url}{_CHAT_ENDPOINT}"
|
||||||
|
payload: dict[str, Any] = {"query": case.query}
|
||||||
|
if case.creator:
|
||||||
|
payload["creator"] = case.creator
|
||||||
|
if case.personality_weight > 0:
|
||||||
|
payload["personality_weight"] = case.personality_weight
|
||||||
|
|
||||||
|
sources: list[dict] = []
|
||||||
|
accumulated = ""
|
||||||
|
cascade_tier = ""
|
||||||
|
|
||||||
|
with httpx.Client(timeout=self.timeout) as client:
|
||||||
|
with client.stream("POST", url, json=payload) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
|
||||||
|
buffer = ""
|
||||||
|
for chunk in resp.iter_text():
|
||||||
|
buffer += chunk
|
||||||
|
# Parse SSE events from buffer
|
||||||
|
while "\n\n" in buffer:
|
||||||
|
event_block, buffer = buffer.split("\n\n", 1)
|
||||||
|
event_type, event_data = self._parse_sse_event(event_block)
|
||||||
|
|
||||||
|
if event_type == "sources":
|
||||||
|
sources = event_data if isinstance(event_data, list) else []
|
||||||
|
elif event_type == "token":
|
||||||
|
accumulated += event_data if isinstance(event_data, str) else str(event_data)
|
||||||
|
elif event_type == "done":
|
||||||
|
if isinstance(event_data, dict):
|
||||||
|
cascade_tier = event_data.get("cascade_tier", "")
|
||||||
|
elif event_type == "error":
|
||||||
|
msg = event_data.get("message", str(event_data)) if isinstance(event_data, dict) else str(event_data)
|
||||||
|
raise RuntimeError(f"Chat endpoint returned error: {msg}")
|
||||||
|
|
||||||
|
return accumulated, sources, cascade_tier
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_sse_event(block: str) -> tuple[str, Any]:
|
||||||
|
"""Parse a single SSE event block into (event_type, data)."""
|
||||||
|
event_type = ""
|
||||||
|
data_lines: list[str] = []
|
||||||
|
|
||||||
|
for line in block.strip().splitlines():
|
||||||
|
if line.startswith("event: "):
|
||||||
|
event_type = line[7:].strip()
|
||||||
|
elif line.startswith("data: "):
|
||||||
|
data_lines.append(line[6:])
|
||||||
|
elif line.startswith("data:"):
|
||||||
|
data_lines.append(line[5:])
|
||||||
|
|
||||||
|
raw_data = "\n".join(data_lines)
|
||||||
|
try:
|
||||||
|
parsed = json.loads(raw_data)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
parsed = raw_data # plain text token
|
||||||
|
|
||||||
|
return event_type, parsed
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def write_results(
|
||||||
|
results: list[ChatEvalResult],
|
||||||
|
output_path: str | Path,
|
||||||
|
) -> str:
|
||||||
|
"""Write evaluation results to a JSON file. Returns the path."""
|
||||||
|
out = Path(output_path)
|
||||||
|
out.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
|
||||||
|
if out.is_dir():
|
||||||
|
filepath = out / f"chat_eval_{timestamp}.json"
|
||||||
|
else:
|
||||||
|
filepath = out
|
||||||
|
|
||||||
|
# Build serializable payload
|
||||||
|
entries: list[dict] = []
|
||||||
|
for r in results:
|
||||||
|
entry: dict[str, Any] = {
|
||||||
|
"query": r.test_case.query,
|
||||||
|
"creator": r.test_case.creator,
|
||||||
|
"personality_weight": r.test_case.personality_weight,
|
||||||
|
"category": r.test_case.category,
|
||||||
|
"description": r.test_case.description,
|
||||||
|
"response_length": len(r.response),
|
||||||
|
"source_count": len(r.sources),
|
||||||
|
"cascade_tier": r.cascade_tier,
|
||||||
|
"latency_seconds": r.latency_seconds,
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.request_error:
|
||||||
|
entry["error"] = r.request_error
|
||||||
|
elif r.score:
|
||||||
|
entry["scores"] = r.score.scores
|
||||||
|
entry["composite"] = r.score.composite
|
||||||
|
entry["justifications"] = r.score.justifications
|
||||||
|
entry["scoring_time"] = r.score.elapsed_seconds
|
||||||
|
if r.score.error:
|
||||||
|
entry["scoring_error"] = r.score.error
|
||||||
|
|
||||||
|
entries.append(entry)
|
||||||
|
|
||||||
|
# Summary stats
|
||||||
|
scored = [e for e in entries if "composite" in e]
|
||||||
|
avg_composite = (
|
||||||
|
sum(e["composite"] for e in scored) / len(scored) if scored else 0.0
|
||||||
|
)
|
||||||
|
dim_avgs: dict[str, float] = {}
|
||||||
|
for dim in CHAT_DIMENSIONS:
|
||||||
|
vals = [e["scores"][dim] for e in scored if dim in e.get("scores", {})]
|
||||||
|
dim_avgs[dim] = round(sum(vals) / len(vals), 3) if vals else 0.0
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"timestamp": timestamp,
|
||||||
|
"total_queries": len(results),
|
||||||
|
"scored_queries": len(scored),
|
||||||
|
"errors": len(results) - len(scored),
|
||||||
|
"average_composite": round(avg_composite, 3),
|
||||||
|
"dimension_averages": dim_avgs,
|
||||||
|
"results": entries,
|
||||||
|
}
|
||||||
|
|
||||||
|
filepath.write_text(json.dumps(payload, indent=2), encoding="utf-8")
|
||||||
|
return str(filepath)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def print_summary(results: list[ChatEvalResult]) -> None:
|
||||||
|
"""Print a summary table of evaluation results."""
|
||||||
|
print("\n" + "=" * 72)
|
||||||
|
print(" CHAT EVALUATION SUMMARY")
|
||||||
|
print("=" * 72)
|
||||||
|
|
||||||
|
scored = [r for r in results if r.score and not r.score.error and not r.request_error]
|
||||||
|
errored = [r for r in results if r.request_error or (r.score and r.score.error)]
|
||||||
|
|
||||||
|
if not scored:
|
||||||
|
print("\n No successfully scored responses.\n")
|
||||||
|
if errored:
|
||||||
|
print(f" Errors: {len(errored)}")
|
||||||
|
for r in errored:
|
||||||
|
err = r.request_error or (r.score.error if r.score else "unknown")
|
||||||
|
print(f" - {r.test_case.query[:50]}: {err}")
|
||||||
|
print("=" * 72 + "\n")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Header
|
||||||
|
print(f"\n {'Category':<12s} {'Query':<30s} {'Comp':>5s} {'Cite':>5s} {'Struct':>6s} {'Domain':>6s} {'Ground':>6s} {'Person':>6s}")
|
||||||
|
print(f" {'─'*12} {'─'*30} {'─'*5} {'─'*5} {'─'*6} {'─'*6} {'─'*6} {'─'*6}")
|
||||||
|
|
||||||
|
for r in scored:
|
||||||
|
s = r.score
|
||||||
|
assert s is not None
|
||||||
|
q = r.test_case.query[:30]
|
||||||
|
cat = r.test_case.category[:12]
|
||||||
|
print(
|
||||||
|
f" {cat:<12s} {q:<30s} "
|
||||||
|
f"{s.composite:5.2f} "
|
||||||
|
f"{s.citation_accuracy:5.2f} "
|
||||||
|
f"{s.response_structure:6.2f} "
|
||||||
|
f"{s.domain_expertise:6.2f} "
|
||||||
|
f"{s.source_grounding:6.2f} "
|
||||||
|
f"{s.personality_fidelity:6.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Averages
|
||||||
|
avg_comp = sum(r.score.composite for r in scored) / len(scored)
|
||||||
|
avg_dims = {}
|
||||||
|
for dim in CHAT_DIMENSIONS:
|
||||||
|
vals = [r.score.scores.get(dim, 0.0) for r in scored]
|
||||||
|
avg_dims[dim] = sum(vals) / len(vals)
|
||||||
|
|
||||||
|
print(f"\n {'AVERAGE':<12s} {'':30s} "
|
||||||
|
f"{avg_comp:5.2f} "
|
||||||
|
f"{avg_dims['citation_accuracy']:5.2f} "
|
||||||
|
f"{avg_dims['response_structure']:6.2f} "
|
||||||
|
f"{avg_dims['domain_expertise']:6.2f} "
|
||||||
|
f"{avg_dims['source_grounding']:6.2f} "
|
||||||
|
f"{avg_dims['personality_fidelity']:6.2f}")
|
||||||
|
|
||||||
|
if errored:
|
||||||
|
print(f"\n Errors: {len(errored)}")
|
||||||
|
for r in errored:
|
||||||
|
err = r.request_error or (r.score.error if r.score else "unknown")
|
||||||
|
print(f" - {r.test_case.query[:50]}: {err}")
|
||||||
|
|
||||||
|
print("=" * 72 + "\n")
|
||||||
271
backend/pipeline/quality/chat_scorer.py
Normal file
271
backend/pipeline/quality/chat_scorer.py
Normal file
|
|
@ -0,0 +1,271 @@
|
||||||
|
"""Chat-specific quality scorer — LLM-as-judge evaluation for chat responses.
|
||||||
|
|
||||||
|
Scores chat responses across 5 dimensions:
|
||||||
|
- citation_accuracy: Are citations real and correctly numbered?
|
||||||
|
- response_structure: Concise, well-organized, uses appropriate formatting?
|
||||||
|
- domain_expertise: Music production terminology used naturally?
|
||||||
|
- source_grounding: Claims backed by provided sources, no fabrication?
|
||||||
|
- personality_fidelity: At weight>0, response reflects creator voice?
|
||||||
|
|
||||||
|
Run via: python -m pipeline.quality chat_eval --suite <path>
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
import openai
|
||||||
|
|
||||||
|
from pipeline.llm_client import LLMClient
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
CHAT_DIMENSIONS = [
|
||||||
|
"citation_accuracy",
|
||||||
|
"response_structure",
|
||||||
|
"domain_expertise",
|
||||||
|
"source_grounding",
|
||||||
|
"personality_fidelity",
|
||||||
|
]
|
||||||
|
|
||||||
|
CHAT_RUBRIC = """\
|
||||||
|
You are an expert evaluator of AI chat response quality for a music production knowledge base.
|
||||||
|
|
||||||
|
You will be given:
|
||||||
|
1. The user's query
|
||||||
|
2. The assistant's response
|
||||||
|
3. The numbered source citations that were provided to the assistant
|
||||||
|
4. The personality_weight (0.0 = encyclopedic, >0 = creator voice expected)
|
||||||
|
5. The creator_name (if any)
|
||||||
|
|
||||||
|
Evaluate the response across these 5 dimensions, scoring each 0.0 to 1.0:
|
||||||
|
|
||||||
|
**citation_accuracy** — Citations are real, correctly numbered, and point to relevant sources
|
||||||
|
- 0.9-1.0: Every [N] citation references a real source number, citations are placed next to the claim they support, no phantom citations
|
||||||
|
- 0.5-0.7: Most citations are valid but some are misplaced or reference non-existent source numbers
|
||||||
|
- 0.0-0.3: Many phantom citations, wrong numbers, or citations placed randomly without connection to claims
|
||||||
|
|
||||||
|
**response_structure** — Response is concise, well-organized, uses appropriate formatting
|
||||||
|
- 0.9-1.0: Clear paragraphs, uses bullet lists for steps/lists, bold for key terms, appropriate length (not padded)
|
||||||
|
- 0.5-0.7: Readable but could be better organized — wall of text, missing formatting where it would help
|
||||||
|
- 0.0-0.3: Disorganized, excessively long or too terse, no formatting, hard to scan
|
||||||
|
|
||||||
|
**domain_expertise** — Music production terminology used naturally and correctly
|
||||||
|
- 0.9-1.0: Uses correct audio/synth/mixing terminology, explains technical terms when appropriate, sounds like a knowledgeable producer
|
||||||
|
- 0.5-0.7: Generally correct but some terminology is vague ("adjust the sound" vs "shape the transient") or misused
|
||||||
|
- 0.0-0.3: Generic language, avoids domain terminology, or uses terms incorrectly
|
||||||
|
|
||||||
|
**source_grounding** — Claims are backed by provided sources, no fabrication
|
||||||
|
- 0.9-1.0: Every factual claim traces to a provided source, no invented details (plugin names, settings, frequencies not in sources)
|
||||||
|
- 0.5-0.7: Mostly grounded but 1-2 claims seem embellished or not directly from sources
|
||||||
|
- 0.0-0.3: Contains hallucinated specifics — settings, plugin names, or techniques not present in any source
|
||||||
|
|
||||||
|
**personality_fidelity** — When personality_weight > 0, response reflects the creator's voice proportional to the weight
|
||||||
|
- If personality_weight == 0: Score based on neutral encyclopedic tone (should NOT show personality). Neutral, informative = 1.0. Forced personality = 0.5.
|
||||||
|
- If personality_weight > 0 and personality_weight < 0.5: Subtle personality hints expected. Score higher if tone is lightly flavored but still mainly encyclopedic.
|
||||||
|
- If personality_weight >= 0.5: Clear creator voice expected. Score higher for signature phrases, teaching style, energy matching the named creator.
|
||||||
|
- If no creator_name is provided: Score 1.0 if response is neutral/encyclopedic, lower if it adopts an unexplained persona.
|
||||||
|
|
||||||
|
Return ONLY a JSON object with this exact structure:
|
||||||
|
{
|
||||||
|
"citation_accuracy": <float 0.0-1.0>,
|
||||||
|
"response_structure": <float 0.0-1.0>,
|
||||||
|
"domain_expertise": <float 0.0-1.0>,
|
||||||
|
"source_grounding": <float 0.0-1.0>,
|
||||||
|
"personality_fidelity": <float 0.0-1.0>,
|
||||||
|
"justifications": {
|
||||||
|
"citation_accuracy": "<1-2 sentence justification>",
|
||||||
|
"response_structure": "<1-2 sentence justification>",
|
||||||
|
"domain_expertise": "<1-2 sentence justification>",
|
||||||
|
"source_grounding": "<1-2 sentence justification>",
|
||||||
|
"personality_fidelity": "<1-2 sentence justification>"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChatScoreResult:
|
||||||
|
"""Outcome of scoring a chat response across quality dimensions."""
|
||||||
|
|
||||||
|
scores: dict[str, float] = field(default_factory=dict)
|
||||||
|
composite: float = 0.0
|
||||||
|
justifications: dict[str, str] = field(default_factory=dict)
|
||||||
|
elapsed_seconds: float = 0.0
|
||||||
|
error: str | None = None
|
||||||
|
|
||||||
|
# Convenience properties
|
||||||
|
@property
|
||||||
|
def citation_accuracy(self) -> float:
|
||||||
|
return self.scores.get("citation_accuracy", 0.0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def response_structure(self) -> float:
|
||||||
|
return self.scores.get("response_structure", 0.0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def domain_expertise(self) -> float:
|
||||||
|
return self.scores.get("domain_expertise", 0.0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def source_grounding(self) -> float:
|
||||||
|
return self.scores.get("source_grounding", 0.0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def personality_fidelity(self) -> float:
|
||||||
|
return self.scores.get("personality_fidelity", 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatScoreRunner:
|
||||||
|
"""Scores chat responses using LLM-as-judge evaluation."""
|
||||||
|
|
||||||
|
def __init__(self, client: LLMClient) -> None:
|
||||||
|
self.client = client
|
||||||
|
|
||||||
|
def score_response(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
response: str,
|
||||||
|
sources: list[dict],
|
||||||
|
personality_weight: float = 0.0,
|
||||||
|
creator_name: str | None = None,
|
||||||
|
) -> ChatScoreResult:
|
||||||
|
"""Score a single chat response against the 5 chat quality dimensions.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
query:
|
||||||
|
The user's original query.
|
||||||
|
response:
|
||||||
|
The assistant's accumulated response text.
|
||||||
|
sources:
|
||||||
|
List of source citation dicts (as emitted by the SSE sources event).
|
||||||
|
personality_weight:
|
||||||
|
0.0 = encyclopedic mode, >0 = personality mode.
|
||||||
|
creator_name:
|
||||||
|
Creator name, if this was a creator-scoped query.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
ChatScoreResult with per-dimension scores.
|
||||||
|
"""
|
||||||
|
sources_block = json.dumps(sources, indent=2) if sources else "(no sources)"
|
||||||
|
|
||||||
|
user_prompt = (
|
||||||
|
f"## User Query\n\n{query}\n\n"
|
||||||
|
f"## Assistant Response\n\n{response}\n\n"
|
||||||
|
f"## Sources Provided\n\n```json\n{sources_block}\n```\n\n"
|
||||||
|
f"## Metadata\n\n"
|
||||||
|
f"- personality_weight: {personality_weight}\n"
|
||||||
|
f"- creator_name: {creator_name or '(none)'}\n\n"
|
||||||
|
f"Score this chat response across all 5 dimensions."
|
||||||
|
)
|
||||||
|
|
||||||
|
t0 = time.monotonic()
|
||||||
|
try:
|
||||||
|
from pydantic import BaseModel as _BM
|
||||||
|
resp = self.client.complete(
|
||||||
|
system_prompt=CHAT_RUBRIC,
|
||||||
|
user_prompt=user_prompt,
|
||||||
|
response_model=_BM,
|
||||||
|
modality="chat",
|
||||||
|
)
|
||||||
|
elapsed = round(time.monotonic() - t0, 2)
|
||||||
|
except (openai.APIConnectionError, openai.APITimeoutError) as exc:
|
||||||
|
elapsed = round(time.monotonic() - t0, 2)
|
||||||
|
return ChatScoreResult(
|
||||||
|
elapsed_seconds=elapsed,
|
||||||
|
error=f"Cannot reach LLM judge. Error: {exc}",
|
||||||
|
)
|
||||||
|
|
||||||
|
raw_text = str(resp).strip()
|
||||||
|
try:
|
||||||
|
parsed = json.loads(raw_text)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.error("Malformed chat judge response (not JSON): %.300s", raw_text)
|
||||||
|
return ChatScoreResult(
|
||||||
|
elapsed_seconds=elapsed,
|
||||||
|
error=f"Malformed judge response. Raw excerpt: {raw_text[:200]}",
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._parse_scores(parsed, elapsed)
|
||||||
|
|
||||||
|
def _parse_scores(self, parsed: dict, elapsed: float) -> ChatScoreResult:
|
||||||
|
"""Extract and validate scores from parsed JSON judge response."""
|
||||||
|
scores: dict[str, float] = {}
|
||||||
|
justifications: dict[str, str] = {}
|
||||||
|
|
||||||
|
raw_justifications = parsed.get("justifications", {})
|
||||||
|
if not isinstance(raw_justifications, dict):
|
||||||
|
raw_justifications = {}
|
||||||
|
|
||||||
|
for dim in CHAT_DIMENSIONS:
|
||||||
|
raw = parsed.get(dim)
|
||||||
|
if raw is None:
|
||||||
|
logger.warning("Missing dimension '%s' in chat judge response", dim)
|
||||||
|
scores[dim] = 0.0
|
||||||
|
justifications[dim] = "(missing from judge response)"
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
val = float(raw)
|
||||||
|
scores[dim] = max(0.0, min(1.0, val))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
logger.warning("Invalid value for '%s': %r", dim, raw)
|
||||||
|
scores[dim] = 0.0
|
||||||
|
justifications[dim] = f"(invalid value: {raw!r})"
|
||||||
|
continue
|
||||||
|
|
||||||
|
justifications[dim] = str(raw_justifications.get(dim, ""))
|
||||||
|
|
||||||
|
composite = sum(scores.values()) / len(CHAT_DIMENSIONS) if CHAT_DIMENSIONS else 0.0
|
||||||
|
|
||||||
|
return ChatScoreResult(
|
||||||
|
scores=scores,
|
||||||
|
composite=round(composite, 3),
|
||||||
|
justifications=justifications,
|
||||||
|
elapsed_seconds=elapsed,
|
||||||
|
)
|
||||||
|
|
||||||
|
def print_report(self, result: ChatScoreResult, query: str = "") -> None:
|
||||||
|
"""Print a formatted chat scoring report to stdout."""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print(" CHAT QUALITY SCORE REPORT")
|
||||||
|
if query:
|
||||||
|
print(f" Query: {query[:60]}{'...' if len(query) > 60 else ''}")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
if result.error:
|
||||||
|
print(f"\n ✗ Error: {result.error}\n")
|
||||||
|
print("=" * 60 + "\n")
|
||||||
|
return
|
||||||
|
|
||||||
|
for dim in CHAT_DIMENSIONS:
|
||||||
|
score = result.scores.get(dim, 0.0)
|
||||||
|
filled = int(score * 20)
|
||||||
|
bar = "█" * filled + "░" * (20 - filled)
|
||||||
|
justification = result.justifications.get(dim, "")
|
||||||
|
print(f"\n {dim.replace('_', ' ').title()}")
|
||||||
|
print(f" Score: {score:.2f} {bar}")
|
||||||
|
if justification:
|
||||||
|
# Simple word wrap at ~56 chars
|
||||||
|
words = justification.split()
|
||||||
|
lines: list[str] = []
|
||||||
|
current = ""
|
||||||
|
for word in words:
|
||||||
|
if current and len(current) + len(word) + 1 > 56:
|
||||||
|
lines.append(current)
|
||||||
|
current = word
|
||||||
|
else:
|
||||||
|
current = f"{current} {word}" if current else word
|
||||||
|
if current:
|
||||||
|
lines.append(current)
|
||||||
|
for line in lines:
|
||||||
|
print(f" {line}")
|
||||||
|
|
||||||
|
print("\n" + "-" * 60)
|
||||||
|
print(f" Composite: {result.composite:.3f}")
|
||||||
|
print(f" Time: {result.elapsed_seconds}s")
|
||||||
|
print("=" * 60 + "\n")
|
||||||
489
backend/pipeline/quality/fitness.py
Normal file
489
backend/pipeline/quality/fitness.py
Normal file
|
|
@ -0,0 +1,489 @@
|
||||||
|
"""FYN-LLM fitness test runner.
|
||||||
|
|
||||||
|
Tests four categories:
|
||||||
|
1. Mandelbrot reasoning — factual knowledge / reasoning depth
|
||||||
|
2. JSON compliance — simple and nested structured output
|
||||||
|
3. Instruction following — bullet count, keyword inclusion, casing
|
||||||
|
4. Diverse prompt battery — summarization, classification, extraction
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
import openai
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from pipeline.llm_client import LLMClient
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Result types ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TestResult:
|
||||||
|
"""Outcome of a single fitness test."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
passed: bool
|
||||||
|
elapsed_seconds: float
|
||||||
|
token_count: int | None = None
|
||||||
|
detail: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CategoryReport:
|
||||||
|
"""Results for one test category."""
|
||||||
|
|
||||||
|
category: str
|
||||||
|
results: list[TestResult] = field(default_factory=list)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def all_passed(self) -> bool:
|
||||||
|
return all(r.passed for r in self.results)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Pydantic models for JSON compliance tests ────────────────────────────────
|
||||||
|
|
||||||
|
class SimpleItem(BaseModel):
|
||||||
|
name: str
|
||||||
|
count: int
|
||||||
|
|
||||||
|
|
||||||
|
class Address(BaseModel):
|
||||||
|
street: str
|
||||||
|
city: str
|
||||||
|
zip_code: str
|
||||||
|
|
||||||
|
|
||||||
|
class PersonWithAddress(BaseModel):
|
||||||
|
name: str
|
||||||
|
age: int
|
||||||
|
address: Address
|
||||||
|
|
||||||
|
|
||||||
|
# ── Runner ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class FitnessRunner:
|
||||||
|
"""Runs all fitness tests against the configured LLM endpoint."""
|
||||||
|
|
||||||
|
def __init__(self, client: LLMClient) -> None:
|
||||||
|
self.client = client
|
||||||
|
|
||||||
|
# ── Public entry point ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
def run_all(self) -> int:
|
||||||
|
"""Run all fitness tests, print report, return exit code (0=pass, 1=fail)."""
|
||||||
|
categories: list[CategoryReport] = []
|
||||||
|
|
||||||
|
# Connectivity pre-check — fail fast with a clear message
|
||||||
|
try:
|
||||||
|
self._probe_connectivity()
|
||||||
|
except (openai.APIConnectionError, openai.APITimeoutError) as exc:
|
||||||
|
url = self.client.settings.llm_api_url
|
||||||
|
fallback = self.client.settings.llm_fallback_url
|
||||||
|
print(
|
||||||
|
f"\n✗ Cannot reach LLM endpoint at {url} (fallback {fallback})\n"
|
||||||
|
f" Error: {exc}\n"
|
||||||
|
)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
categories.append(self._run_mandelbrot())
|
||||||
|
categories.append(self._run_json_compliance())
|
||||||
|
categories.append(self._run_instruction_following())
|
||||||
|
categories.append(self._run_diverse_battery())
|
||||||
|
|
||||||
|
self._print_report(categories)
|
||||||
|
|
||||||
|
return 0 if all(c.all_passed for c in categories) else 1
|
||||||
|
|
||||||
|
# ── Connectivity probe ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _probe_connectivity(self) -> None:
|
||||||
|
"""Quick completion to verify the endpoint is reachable."""
|
||||||
|
self.client.complete(
|
||||||
|
system_prompt="You are a test probe.",
|
||||||
|
user_prompt="Respond with the single word: ok",
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Category 1: Mandelbrot reasoning ─────────────────────────────────
|
||||||
|
|
||||||
|
def _run_mandelbrot(self) -> CategoryReport:
|
||||||
|
cat = CategoryReport(category="Mandelbrot Reasoning")
|
||||||
|
cat.results.append(self._test_mandelbrot())
|
||||||
|
return cat
|
||||||
|
|
||||||
|
def _test_mandelbrot(self) -> TestResult:
|
||||||
|
name = "mandelbrot_area_knowledge"
|
||||||
|
t0 = time.monotonic()
|
||||||
|
try:
|
||||||
|
resp = self.client.complete(
|
||||||
|
system_prompt="You are a mathematics expert. Answer precisely and concisely.",
|
||||||
|
user_prompt=(
|
||||||
|
"What is the approximate area of the Mandelbrot set? "
|
||||||
|
"Include the numerical value and mention whether the exact area is known."
|
||||||
|
),
|
||||||
|
modality="thinking",
|
||||||
|
)
|
||||||
|
elapsed = time.monotonic() - t0
|
||||||
|
text = resp.lower()
|
||||||
|
# Check for key concepts
|
||||||
|
has_area = any(kw in text for kw in ["1.506", "1.507", "1.50659"])
|
||||||
|
has_uncertainty = any(
|
||||||
|
kw in text
|
||||||
|
for kw in ["not exactly known", "not known exactly", "approximate", "estimated", "conjecture"]
|
||||||
|
)
|
||||||
|
passed = has_area and has_uncertainty
|
||||||
|
detail = "" if passed else f"Missing: area={has_area}, uncertainty={has_uncertainty}. Response: {resp[:200]}"
|
||||||
|
return TestResult(
|
||||||
|
name=name,
|
||||||
|
passed=passed,
|
||||||
|
elapsed_seconds=round(elapsed, 2),
|
||||||
|
token_count=resp.completion_tokens,
|
||||||
|
detail=detail,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
return TestResult(
|
||||||
|
name=name,
|
||||||
|
passed=False,
|
||||||
|
elapsed_seconds=round(time.monotonic() - t0, 2),
|
||||||
|
detail=f"Exception: {exc}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Category 2: JSON compliance ──────────────────────────────────────
|
||||||
|
|
||||||
|
def _run_json_compliance(self) -> CategoryReport:
|
||||||
|
cat = CategoryReport(category="JSON Compliance")
|
||||||
|
cat.results.append(self._test_json_simple())
|
||||||
|
cat.results.append(self._test_json_nested())
|
||||||
|
return cat
|
||||||
|
|
||||||
|
def _test_json_simple(self) -> TestResult:
|
||||||
|
name = "json_simple_object"
|
||||||
|
t0 = time.monotonic()
|
||||||
|
try:
|
||||||
|
resp = self.client.complete(
|
||||||
|
system_prompt="You are a JSON generator. Output ONLY valid JSON, nothing else.",
|
||||||
|
user_prompt=(
|
||||||
|
'Generate a JSON object with exactly two keys: "name" (a string) '
|
||||||
|
'and "count" (an integer). Example structure: {"name": "...", "count": N}'
|
||||||
|
),
|
||||||
|
response_model=SimpleItem,
|
||||||
|
modality="chat",
|
||||||
|
)
|
||||||
|
elapsed = time.monotonic() - t0
|
||||||
|
return self._validate_json(name, resp, SimpleItem, elapsed)
|
||||||
|
except Exception as exc:
|
||||||
|
return TestResult(
|
||||||
|
name=name,
|
||||||
|
passed=False,
|
||||||
|
elapsed_seconds=round(time.monotonic() - t0, 2),
|
||||||
|
detail=f"Exception: {exc}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _test_json_nested(self) -> TestResult:
|
||||||
|
name = "json_nested_object"
|
||||||
|
t0 = time.monotonic()
|
||||||
|
try:
|
||||||
|
resp = self.client.complete(
|
||||||
|
system_prompt="You are a JSON generator. Output ONLY valid JSON, nothing else.",
|
||||||
|
user_prompt=(
|
||||||
|
'Generate a JSON object with keys "name" (string), "age" (integer), '
|
||||||
|
'and "address" (object with "street", "city", "zip_code" string fields).'
|
||||||
|
),
|
||||||
|
response_model=PersonWithAddress,
|
||||||
|
modality="chat",
|
||||||
|
)
|
||||||
|
elapsed = time.monotonic() - t0
|
||||||
|
return self._validate_json(name, resp, PersonWithAddress, elapsed)
|
||||||
|
except Exception as exc:
|
||||||
|
return TestResult(
|
||||||
|
name=name,
|
||||||
|
passed=False,
|
||||||
|
elapsed_seconds=round(time.monotonic() - t0, 2),
|
||||||
|
detail=f"Exception: {exc}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _validate_json(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
resp: str,
|
||||||
|
model: type[BaseModel],
|
||||||
|
elapsed: float,
|
||||||
|
) -> TestResult:
|
||||||
|
"""Parse response as JSON, validate against Pydantic model."""
|
||||||
|
text = str(resp).strip()
|
||||||
|
if not text:
|
||||||
|
return TestResult(
|
||||||
|
name=name, passed=False, elapsed_seconds=round(elapsed, 2),
|
||||||
|
token_count=getattr(resp, "completion_tokens", None),
|
||||||
|
detail="Empty response from LLM",
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
parsed = json.loads(text)
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
return TestResult(
|
||||||
|
name=name, passed=False, elapsed_seconds=round(elapsed, 2),
|
||||||
|
token_count=getattr(resp, "completion_tokens", None),
|
||||||
|
detail=f"Invalid JSON: {exc}. Raw: {text[:200]}",
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
model.model_validate(parsed)
|
||||||
|
except Exception as exc:
|
||||||
|
return TestResult(
|
||||||
|
name=name, passed=False, elapsed_seconds=round(elapsed, 2),
|
||||||
|
token_count=getattr(resp, "completion_tokens", None),
|
||||||
|
detail=f"Schema validation failed: {exc}",
|
||||||
|
)
|
||||||
|
return TestResult(
|
||||||
|
name=name, passed=True, elapsed_seconds=round(elapsed, 2),
|
||||||
|
token_count=getattr(resp, "completion_tokens", None),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Category 3: Instruction following ────────────────────────────────
|
||||||
|
|
||||||
|
def _run_instruction_following(self) -> CategoryReport:
|
||||||
|
cat = CategoryReport(category="Instruction Following")
|
||||||
|
cat.results.append(self._test_bullet_count())
|
||||||
|
cat.results.append(self._test_keyword_inclusion())
|
||||||
|
cat.results.append(self._test_lowercase_only())
|
||||||
|
return cat
|
||||||
|
|
||||||
|
def _test_bullet_count(self) -> TestResult:
|
||||||
|
name = "instruction_bullet_count"
|
||||||
|
t0 = time.monotonic()
|
||||||
|
try:
|
||||||
|
resp = self.client.complete(
|
||||||
|
system_prompt="Follow instructions exactly.",
|
||||||
|
user_prompt="List exactly 3 benefits of exercise. Use bullet points starting with '- '.",
|
||||||
|
)
|
||||||
|
elapsed = time.monotonic() - t0
|
||||||
|
lines = [l.strip() for l in str(resp).strip().splitlines() if l.strip().startswith("- ")]
|
||||||
|
passed = len(lines) == 3
|
||||||
|
detail = "" if passed else f"Expected 3 bullets, got {len(lines)}: {str(resp)[:200]}"
|
||||||
|
return TestResult(
|
||||||
|
name=name, passed=passed, elapsed_seconds=round(elapsed, 2),
|
||||||
|
token_count=resp.completion_tokens,
|
||||||
|
detail=detail,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
return TestResult(
|
||||||
|
name=name, passed=False,
|
||||||
|
elapsed_seconds=round(time.monotonic() - t0, 2),
|
||||||
|
detail=f"Exception: {exc}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _test_keyword_inclusion(self) -> TestResult:
|
||||||
|
name = "instruction_keyword_inclusion"
|
||||||
|
t0 = time.monotonic()
|
||||||
|
try:
|
||||||
|
resp = self.client.complete(
|
||||||
|
system_prompt="Follow instructions exactly.",
|
||||||
|
user_prompt=(
|
||||||
|
"Write one sentence about the weather. "
|
||||||
|
'You MUST include the word "elephant" somewhere in your sentence.'
|
||||||
|
),
|
||||||
|
)
|
||||||
|
elapsed = time.monotonic() - t0
|
||||||
|
passed = "elephant" in str(resp).lower()
|
||||||
|
detail = "" if passed else f"Missing keyword 'elephant'. Response: {str(resp)[:200]}"
|
||||||
|
return TestResult(
|
||||||
|
name=name, passed=passed, elapsed_seconds=round(elapsed, 2),
|
||||||
|
token_count=resp.completion_tokens,
|
||||||
|
detail=detail,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
return TestResult(
|
||||||
|
name=name, passed=False,
|
||||||
|
elapsed_seconds=round(time.monotonic() - t0, 2),
|
||||||
|
detail=f"Exception: {exc}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _test_lowercase_only(self) -> TestResult:
|
||||||
|
name = "instruction_lowercase_only"
|
||||||
|
t0 = time.monotonic()
|
||||||
|
try:
|
||||||
|
resp = self.client.complete(
|
||||||
|
system_prompt="Follow instructions exactly.",
|
||||||
|
user_prompt=(
|
||||||
|
"Write a short sentence about the ocean. "
|
||||||
|
"Use ONLY lowercase letters — no uppercase at all, not even at the start."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
elapsed = time.monotonic() - t0
|
||||||
|
text = str(resp).strip()
|
||||||
|
# Allow non-alpha chars (punctuation, spaces, numbers) but no uppercase letters
|
||||||
|
has_upper = any(c.isupper() for c in text)
|
||||||
|
passed = not has_upper and len(text) > 5
|
||||||
|
detail = "" if passed else f"Contains uppercase or too short. Response: {text[:200]}"
|
||||||
|
return TestResult(
|
||||||
|
name=name, passed=passed, elapsed_seconds=round(elapsed, 2),
|
||||||
|
token_count=resp.completion_tokens,
|
||||||
|
detail=detail,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
return TestResult(
|
||||||
|
name=name, passed=False,
|
||||||
|
elapsed_seconds=round(time.monotonic() - t0, 2),
|
||||||
|
detail=f"Exception: {exc}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Category 4: Diverse prompt battery ───────────────────────────────
|
||||||
|
|
||||||
|
def _run_diverse_battery(self) -> CategoryReport:
|
||||||
|
cat = CategoryReport(category="Diverse Prompt Battery")
|
||||||
|
cat.results.append(self._test_summarization())
|
||||||
|
cat.results.append(self._test_classification())
|
||||||
|
cat.results.append(self._test_extraction())
|
||||||
|
return cat
|
||||||
|
|
||||||
|
def _test_summarization(self) -> TestResult:
|
||||||
|
name = "battery_summarization"
|
||||||
|
paragraph = (
|
||||||
|
"The James Webb Space Telescope (JWST) is the largest optical telescope in space. "
|
||||||
|
"Launched in December 2021, it is designed to conduct infrared astronomy. Its high "
|
||||||
|
"resolution and sensitivity allow it to view objects too old and distant for the Hubble "
|
||||||
|
"Space Telescope. Among its goals are observing the first stars and the formation of "
|
||||||
|
"the first galaxies, and detailed atmospheric characterization of exoplanets."
|
||||||
|
)
|
||||||
|
t0 = time.monotonic()
|
||||||
|
try:
|
||||||
|
resp = self.client.complete(
|
||||||
|
system_prompt="You are a concise summarizer.",
|
||||||
|
user_prompt=f"Summarize the following in exactly 2 sentences:\n\n{paragraph}",
|
||||||
|
)
|
||||||
|
elapsed = time.monotonic() - t0
|
||||||
|
text = str(resp).strip()
|
||||||
|
# Rough sentence count: split on period followed by space or end
|
||||||
|
sentences = [s.strip() for s in text.replace("! ", ". ").split(". ") if s.strip()]
|
||||||
|
# Be generous: 1-3 sentences is acceptable
|
||||||
|
passed = 1 <= len(sentences) <= 3 and len(text) > 20
|
||||||
|
detail = "" if passed else f"Expected ~2 sentences, got {len(sentences)}. Response: {text[:200]}"
|
||||||
|
return TestResult(
|
||||||
|
name=name, passed=passed, elapsed_seconds=round(elapsed, 2),
|
||||||
|
token_count=resp.completion_tokens,
|
||||||
|
detail=detail,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
return TestResult(
|
||||||
|
name=name, passed=False,
|
||||||
|
elapsed_seconds=round(time.monotonic() - t0, 2),
|
||||||
|
detail=f"Exception: {exc}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _test_classification(self) -> TestResult:
|
||||||
|
name = "battery_classification"
|
||||||
|
categories = ["technology", "sports", "politics", "science", "entertainment"]
|
||||||
|
t0 = time.monotonic()
|
||||||
|
try:
|
||||||
|
resp = self.client.complete(
|
||||||
|
system_prompt=(
|
||||||
|
"You are a text classifier. Respond with ONLY one word from the given categories."
|
||||||
|
),
|
||||||
|
user_prompt=(
|
||||||
|
f"Classify the following text into one of these categories: {', '.join(categories)}\n\n"
|
||||||
|
"Text: \"NASA's Perseverance rover has discovered organic molecules on Mars, "
|
||||||
|
"suggesting the planet may have once harbored microbial life.\"\n\n"
|
||||||
|
"Category:"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
elapsed = time.monotonic() - t0
|
||||||
|
answer = str(resp).strip().lower().rstrip(".")
|
||||||
|
passed = answer in categories
|
||||||
|
detail = "" if passed else f"Response '{answer}' not in {categories}"
|
||||||
|
return TestResult(
|
||||||
|
name=name, passed=passed, elapsed_seconds=round(elapsed, 2),
|
||||||
|
token_count=resp.completion_tokens,
|
||||||
|
detail=detail,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
return TestResult(
|
||||||
|
name=name, passed=False,
|
||||||
|
elapsed_seconds=round(time.monotonic() - t0, 2),
|
||||||
|
detail=f"Exception: {exc}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _test_extraction(self) -> TestResult:
|
||||||
|
name = "battery_extraction"
|
||||||
|
t0 = time.monotonic()
|
||||||
|
try:
|
||||||
|
resp = self.client.complete(
|
||||||
|
system_prompt="You are a data extractor. Output ONLY valid JSON, nothing else.",
|
||||||
|
user_prompt=(
|
||||||
|
"Extract the following fields as a JSON object: "
|
||||||
|
'"event_name", "date", "location"\n\n'
|
||||||
|
"Text: \"The annual Tech Summit 2026 will be held on March 15, 2026 "
|
||||||
|
'in San Francisco, California."\n\n'
|
||||||
|
"JSON:"
|
||||||
|
),
|
||||||
|
response_model=BaseModel, # triggers json mode
|
||||||
|
modality="chat",
|
||||||
|
)
|
||||||
|
elapsed = time.monotonic() - t0
|
||||||
|
text = str(resp).strip()
|
||||||
|
if not text:
|
||||||
|
return TestResult(
|
||||||
|
name=name, passed=False, elapsed_seconds=round(elapsed, 2),
|
||||||
|
token_count=getattr(resp, "completion_tokens", None),
|
||||||
|
detail="Empty response from LLM",
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
parsed = json.loads(text)
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
return TestResult(
|
||||||
|
name=name, passed=False, elapsed_seconds=round(elapsed, 2),
|
||||||
|
token_count=getattr(resp, "completion_tokens", None),
|
||||||
|
detail=f"Invalid JSON: {exc}. Raw: {text[:200]}",
|
||||||
|
)
|
||||||
|
required_keys = {"event_name", "date", "location"}
|
||||||
|
present = set(parsed.keys()) & required_keys
|
||||||
|
passed = present == required_keys
|
||||||
|
detail = "" if passed else f"Missing keys: {required_keys - present}"
|
||||||
|
return TestResult(
|
||||||
|
name=name, passed=passed, elapsed_seconds=round(elapsed, 2),
|
||||||
|
token_count=getattr(resp, "completion_tokens", None),
|
||||||
|
detail=detail,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
return TestResult(
|
||||||
|
name=name, passed=False,
|
||||||
|
elapsed_seconds=round(time.monotonic() - t0, 2),
|
||||||
|
detail=f"Exception: {exc}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Report formatting ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _print_report(self, categories: list[CategoryReport]) -> None:
|
||||||
|
"""Print a formatted pass/fail report to stdout."""
|
||||||
|
total = 0
|
||||||
|
passed_count = 0
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print(" FYN-LLM FITNESS REPORT")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
for cat in categories:
|
||||||
|
status = "✓ PASS" if cat.all_passed else "✗ FAIL"
|
||||||
|
print(f"\n [{status}] {cat.category}")
|
||||||
|
for r in cat.results:
|
||||||
|
total += 1
|
||||||
|
icon = "✓" if r.passed else "✗"
|
||||||
|
tokens = f" ({r.token_count} tok)" if r.token_count else ""
|
||||||
|
print(f" {icon} {r.name} [{r.elapsed_seconds}s{tokens}]")
|
||||||
|
if r.detail:
|
||||||
|
# Indent detail lines
|
||||||
|
for line in r.detail.splitlines():
|
||||||
|
print(f" {line}")
|
||||||
|
if r.passed:
|
||||||
|
passed_count += 1
|
||||||
|
|
||||||
|
print("\n" + "-" * 60)
|
||||||
|
print(f" Total: {passed_count}/{total} passed")
|
||||||
|
if passed_count == total:
|
||||||
|
print(" Result: ✓ ALL PASS")
|
||||||
|
else:
|
||||||
|
print(f" Result: ✗ {total - passed_count} FAILED")
|
||||||
|
print("=" * 60 + "\n")
|
||||||
0
backend/pipeline/quality/fixtures/__init__.py
Normal file
0
backend/pipeline/quality/fixtures/__init__.py
Normal file
72
backend/pipeline/quality/fixtures/chat_test_suite.yaml
Normal file
72
backend/pipeline/quality/fixtures/chat_test_suite.yaml
Normal file
|
|
@ -0,0 +1,72 @@
|
||||||
|
# Chat quality evaluation test suite
|
||||||
|
# 10 representative queries across 4 categories:
|
||||||
|
# - technical: How-to questions about specific production techniques
|
||||||
|
# - conceptual: Broader understanding questions about audio concepts
|
||||||
|
# - creator: Creator-scoped queries at different personality weights
|
||||||
|
# - cross_creator: Queries spanning multiple creators' approaches
|
||||||
|
|
||||||
|
queries:
|
||||||
|
# ── Technical how-to (2) ────────────────────────────────────────────
|
||||||
|
- query: "How do I set up sidechain compression on a bass synth using a kick drum as the trigger?"
|
||||||
|
creator: null
|
||||||
|
personality_weight: 0.0
|
||||||
|
category: technical
|
||||||
|
description: "Common sidechain compression setup — expects specific settings (ratio, attack, release)"
|
||||||
|
|
||||||
|
- query: "What are the best EQ settings for cleaning up a muddy vocal recording?"
|
||||||
|
creator: null
|
||||||
|
personality_weight: 0.0
|
||||||
|
category: technical
|
||||||
|
description: "Vocal EQ technique — expects frequency ranges, Q values, cut/boost guidance"
|
||||||
|
|
||||||
|
# ── Conceptual (2) ─────────────────────────────────────────────────
|
||||||
|
- query: "What is the difference between parallel compression and serial compression, and when should I use each?"
|
||||||
|
creator: null
|
||||||
|
personality_weight: 0.0
|
||||||
|
category: conceptual
|
||||||
|
description: "Conceptual comparison — expects clear definitions, use cases, pros/cons"
|
||||||
|
|
||||||
|
- query: "How does sample rate affect sound quality in music production?"
|
||||||
|
creator: null
|
||||||
|
personality_weight: 0.0
|
||||||
|
category: conceptual
|
||||||
|
description: "Audio fundamentals — expects Nyquist, aliasing, practical guidance"
|
||||||
|
|
||||||
|
# ── Creator-specific: encyclopedic (2) ──────────────────────────────
|
||||||
|
- query: "How does this creator approach sound design for bass sounds?"
|
||||||
|
creator: "KEOTA"
|
||||||
|
personality_weight: 0.0
|
||||||
|
category: creator_encyclopedic
|
||||||
|
description: "Creator-scoped query at weight=0 — should be neutral/encyclopedic about KEOTA's techniques"
|
||||||
|
|
||||||
|
- query: "What mixing techniques does this creator recommend for achieving width in a mix?"
|
||||||
|
creator: "Mr. Bill"
|
||||||
|
personality_weight: 0.0
|
||||||
|
category: creator_encyclopedic
|
||||||
|
description: "Creator-scoped query at weight=0 — neutral tone about Mr. Bill's approach"
|
||||||
|
|
||||||
|
# ── Creator-specific: personality (2) ───────────────────────────────
|
||||||
|
- query: "How does this creator approach sound design for bass sounds?"
|
||||||
|
creator: "KEOTA"
|
||||||
|
personality_weight: 0.7
|
||||||
|
category: creator_personality
|
||||||
|
description: "Same query as above but at weight=0.7 — should reflect KEOTA's voice and teaching style"
|
||||||
|
|
||||||
|
- query: "What mixing techniques does this creator recommend for achieving width in a mix?"
|
||||||
|
creator: "Mr. Bill"
|
||||||
|
personality_weight: 0.7
|
||||||
|
category: creator_personality
|
||||||
|
description: "Same query as above but at weight=0.7 — should reflect Mr. Bill's voice"
|
||||||
|
|
||||||
|
# ── Cross-creator (2) ──────────────────────────────────────────────
|
||||||
|
- query: "What are the different approaches to layering synth sounds across creators?"
|
||||||
|
creator: null
|
||||||
|
personality_weight: 0.0
|
||||||
|
category: cross_creator
|
||||||
|
description: "Cross-creator comparison — should cite multiple creators' techniques"
|
||||||
|
|
||||||
|
- query: "How do different producers approach drum processing and what plugins do they prefer?"
|
||||||
|
creator: null
|
||||||
|
personality_weight: 0.0
|
||||||
|
category: cross_creator
|
||||||
|
description: "Cross-creator comparison on drums — expects multiple perspectives with citations"
|
||||||
|
|
@ -0,0 +1,29 @@
|
||||||
|
{
|
||||||
|
"extracted_moments": [
|
||||||
|
{
|
||||||
|
"title": "Frequency-specific sidechain with Trackspacer",
|
||||||
|
"summary": "Using Trackspacer plugin for frequency-band sidechain compression targeting 100-300Hz, allowing bass high-end to remain present while clearing low-mid mud under the kick.",
|
||||||
|
"content_type": "technique",
|
||||||
|
"plugins": ["Trackspacer"],
|
||||||
|
"start_time": 15.2,
|
||||||
|
"end_time": 52.1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"title": "Parallel drum compression chain",
|
||||||
|
"summary": "Setting up Ableton's Drum Buss at 40% drive into a return track with Valhalla Room at 1.2s decay, mixed at -12dB for room sound without wash.",
|
||||||
|
"content_type": "settings",
|
||||||
|
"plugins": ["Drum Buss", "Valhalla Room"],
|
||||||
|
"start_time": 52.1,
|
||||||
|
"end_time": 89.3
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"title": "Mono compatibility checking workflow",
|
||||||
|
"summary": "Using Ableton's Utility plugin on the sub bus to constantly check mono compatibility of layered bass patches, catching phase cancellation before mixdown.",
|
||||||
|
"content_type": "workflow",
|
||||||
|
"plugins": ["Utility"],
|
||||||
|
"start_time": 89.3,
|
||||||
|
"end_time": 110.0
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"taxonomy": "Sound Design > Mixing & Processing"
|
||||||
|
}
|
||||||
54
backend/pipeline/quality/fixtures/sample_moments.json
Normal file
54
backend/pipeline/quality/fixtures/sample_moments.json
Normal file
|
|
@ -0,0 +1,54 @@
|
||||||
|
{
|
||||||
|
"creator_name": "KOAN Sound",
|
||||||
|
"topic_category": "Sound design",
|
||||||
|
"moments": [
|
||||||
|
{
|
||||||
|
"summary": "Layering snare transients by combining a high-frequency click from a Popcorn Snare with a mid-body from a pitched-down 808 rim shot, blending at -6dB relative offset.",
|
||||||
|
"transcript_excerpt": "So what I'll do is take the Popcorn Snare — that's got this really sharp click at like 4k — and then I layer underneath it a rim shot pitched down maybe 3 semitones. You blend those together and suddenly you've got this snare that cuts through everything but still has weight.",
|
||||||
|
"topic_tags": ["snare layering", "transient design", "sample stacking"],
|
||||||
|
"topic_category": "Sound design",
|
||||||
|
"start_time": 124.5,
|
||||||
|
"end_time": 158.2
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"summary": "Using Serum's noise oscillator with the 'Analog_Crackle' wavetable at 12% mix to add organic texture to bass patches, followed by OTT at 30% depth for glue.",
|
||||||
|
"transcript_excerpt": "One trick I always come back to is Serum's noise osc with Analog_Crackle. You don't want it loud — like 12 percent mix — just enough that the bass feels alive. Then slap OTT on there at maybe 30 percent depth and it glues the whole thing together without squashing it.",
|
||||||
|
"topic_tags": ["bass design", "Serum", "OTT", "texture"],
|
||||||
|
"topic_category": "Sound design",
|
||||||
|
"start_time": 203.1,
|
||||||
|
"end_time": 241.7
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"summary": "Resampling technique: bounce a bass patch to audio, chop the best 2 bars, then re-pitch in Simpler with warp off for tighter timing and consistent tone.",
|
||||||
|
"transcript_excerpt": "I'll resample everything. Bounce it down, find the two bars that sound best, throw it in Simpler with warp completely off. Now you've got this tight, consistent thing where every hit is exactly the same energy. The pitch tracking is way more predictable too.",
|
||||||
|
"topic_tags": ["resampling", "Ableton", "Simpler", "bass production"],
|
||||||
|
"topic_category": "Sound design",
|
||||||
|
"start_time": 312.0,
|
||||||
|
"end_time": 349.8
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"summary": "Parallel compression chain for drums using Ableton's Drum Buss at 40% drive into a return track with Valhalla Room at 1.2s decay, mixed at -12dB.",
|
||||||
|
"transcript_excerpt": "The parallel chain is dead simple — Drum Buss, crank the drive to about 40 percent, send that to a return with Valhalla Room. Keep the decay short, like 1.2 seconds. Mix it in at minus 12 and your drums just... breathe. They've got this room sound without getting washy.",
|
||||||
|
"topic_tags": ["parallel compression", "drum processing", "Valhalla Room", "Drum Buss"],
|
||||||
|
"topic_category": "Sound design",
|
||||||
|
"start_time": 421.3,
|
||||||
|
"end_time": 462.1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"summary": "Frequency-specific sidechain using Trackspacer plugin instead of volume ducking, targeting only 100-300Hz so the bass ducks under the kick without losing high-end presence.",
|
||||||
|
"transcript_excerpt": "Everyone does volume sidechain but honestly Trackspacer changed everything for me. You set it to only affect 100 to 300 Hz so when the kick hits, the bass ducks just in that low-mid range. The top end of the bass stays right there — you keep all the character and harmonics, you just clear the mud.",
|
||||||
|
"topic_tags": ["sidechaining", "Trackspacer", "frequency ducking", "mixing"],
|
||||||
|
"topic_category": "Sound design",
|
||||||
|
"start_time": 498.7,
|
||||||
|
"end_time": 534.2
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"summary": "Using Ableton's Utility plugin to check mono compatibility at every stage, specifically toggling mono on the sub bus to catch phase cancellation from layered bass patches.",
|
||||||
|
"transcript_excerpt": "I'm almost paranoid about mono. I've got Utility on the sub bus and I'm flipping to mono constantly. If your layered bass sounds thin in mono you've got phase issues — doesn't matter how fat it sounds in stereo, it'll collapse on a club system.",
|
||||||
|
"topic_tags": ["mono compatibility", "phase checking", "club mixing", "Utility"],
|
||||||
|
"topic_category": "Sound design",
|
||||||
|
"start_time": 567.0,
|
||||||
|
"end_time": 598.4
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
40
backend/pipeline/quality/fixtures/sample_segments.json
Normal file
40
backend/pipeline/quality/fixtures/sample_segments.json
Normal file
|
|
@ -0,0 +1,40 @@
|
||||||
|
{
|
||||||
|
"transcript_segments": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"start_time": 0.0,
|
||||||
|
"end_time": 15.2,
|
||||||
|
"text": "Hey everyone, today we're going to talk about sidechain compression and how I use it in my productions."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"index": 1,
|
||||||
|
"start_time": 15.2,
|
||||||
|
"end_time": 34.8,
|
||||||
|
"text": "So the basic idea is you take the kick drum signal and use it to duck the bass. Most people use a compressor for this but I actually prefer Trackspacer because it gives you frequency-specific ducking."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"index": 2,
|
||||||
|
"start_time": 34.8,
|
||||||
|
"end_time": 52.1,
|
||||||
|
"text": "With Trackspacer you can set it to only affect 100 to 300 Hz so when the kick hits, the bass ducks just in that low-mid range. The top end stays right there."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"index": 3,
|
||||||
|
"start_time": 52.1,
|
||||||
|
"end_time": 71.5,
|
||||||
|
"text": "Now let me show you another technique — parallel compression on drums. I use Drum Buss with the drive at about 40 percent, then send that to a return track."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"index": 4,
|
||||||
|
"start_time": 71.5,
|
||||||
|
"end_time": 89.3,
|
||||||
|
"text": "On the return I put Valhalla Room with a short decay, like 1.2 seconds. Mix it in at minus 12 dB. Your drums just breathe — they get this room sound without getting washy."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"index": 5,
|
||||||
|
"start_time": 89.3,
|
||||||
|
"end_time": 110.0,
|
||||||
|
"text": "One more thing about mono compatibility. I always have Utility on the sub bus and I flip to mono constantly. If your layered bass sounds thin in mono you've got phase issues."
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
18
backend/pipeline/quality/fixtures/sample_topic_group.json
Normal file
18
backend/pipeline/quality/fixtures/sample_topic_group.json
Normal file
|
|
@ -0,0 +1,18 @@
|
||||||
|
{
|
||||||
|
"topic_segments": [
|
||||||
|
{
|
||||||
|
"start_index": 0,
|
||||||
|
"end_index": 2,
|
||||||
|
"topic_label": "Frequency-specific sidechain compression with Trackspacer",
|
||||||
|
"summary": "Using Trackspacer for frequency-band sidechain ducking instead of traditional volume compression",
|
||||||
|
"transcript_text": "Hey everyone, today we're going to talk about sidechain compression and how I use it in my productions. So the basic idea is you take the kick drum signal and use it to duck the bass. Most people use a compressor for this but I actually prefer Trackspacer because it gives you frequency-specific ducking. With Trackspacer you can set it to only affect 100 to 300 Hz so when the kick hits, the bass ducks just in that low-mid range. The top end stays right there."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"start_index": 3,
|
||||||
|
"end_index": 4,
|
||||||
|
"topic_label": "Parallel drum compression with Drum Buss and Valhalla Room",
|
||||||
|
"summary": "Setting up a parallel compression chain using Ableton's Drum Buss and Valhalla Room reverb for drum processing",
|
||||||
|
"transcript_text": "Now let me show you another technique — parallel compression on drums. I use Drum Buss with the drive at about 40 percent, then send that to a return track. On the return I put Valhalla Room with a short decay, like 1.2 seconds. Mix it in at minus 12 dB. Your drums just breathe — they get this room sound without getting washy."
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
522
backend/pipeline/quality/optimizer.py
Normal file
522
backend/pipeline/quality/optimizer.py
Normal file
|
|
@ -0,0 +1,522 @@
|
||||||
|
"""Automated prompt optimization loop for pipeline stages 2-5.
|
||||||
|
|
||||||
|
Orchestrates a generate→score→select cycle:
|
||||||
|
1. Score the current best prompt against reference fixtures
|
||||||
|
2. Generate N variants targeting weak dimensions
|
||||||
|
3. Score each variant
|
||||||
|
4. Keep the best scorer as the new baseline
|
||||||
|
5. Repeat for K iterations
|
||||||
|
|
||||||
|
Usage (via CLI):
|
||||||
|
python -m pipeline.quality optimize --stage 5 --iterations 10
|
||||||
|
python -m pipeline.quality optimize --stage 3 --iterations 5 --file fixtures/sample_topic_group.json
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from pipeline.llm_client import LLMClient
|
||||||
|
from pipeline.quality.scorer import STAGE_CONFIGS, ScoreResult, ScoreRunner
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OptimizationResult:
|
||||||
|
"""Full result of an optimization run."""
|
||||||
|
|
||||||
|
best_prompt: str = ""
|
||||||
|
best_score: ScoreResult = field(default_factory=ScoreResult)
|
||||||
|
history: list[dict] = field(default_factory=list)
|
||||||
|
elapsed_seconds: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class OptimizationLoop:
|
||||||
|
"""Runs iterative prompt optimization for a pipeline stage.
|
||||||
|
|
||||||
|
Each iteration generates *variants_per_iter* prompt mutations,
|
||||||
|
scores each against reference fixture data, and keeps the
|
||||||
|
highest-composite-scoring variant as the new baseline.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
client:
|
||||||
|
LLMClient instance for LLM calls (synthesis + scoring + variant gen).
|
||||||
|
stage:
|
||||||
|
Pipeline stage number (2-5).
|
||||||
|
fixture_path:
|
||||||
|
Path to a JSON fixture file matching the stage's expected keys.
|
||||||
|
iterations:
|
||||||
|
Number of generate→score→select cycles.
|
||||||
|
variants_per_iter:
|
||||||
|
Number of variant prompts to generate per iteration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
client: LLMClient,
|
||||||
|
stage: int,
|
||||||
|
fixture_path: str,
|
||||||
|
iterations: int = 5,
|
||||||
|
variants_per_iter: int = 2,
|
||||||
|
output_dir: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
if stage not in STAGE_CONFIGS:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported stage {stage}. Valid stages: {sorted(STAGE_CONFIGS)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.client = client
|
||||||
|
self.stage = stage
|
||||||
|
self.fixture_path = fixture_path
|
||||||
|
self.iterations = iterations
|
||||||
|
self.variants_per_iter = variants_per_iter
|
||||||
|
self.config = STAGE_CONFIGS[stage]
|
||||||
|
self.output_dir = output_dir
|
||||||
|
|
||||||
|
self.scorer = ScoreRunner(client)
|
||||||
|
self.generator = PromptVariantGenerator(client)
|
||||||
|
|
||||||
|
def run(self) -> OptimizationResult:
|
||||||
|
"""Execute the full optimization loop.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
OptimizationResult
|
||||||
|
Contains the best prompt, its scores, full iteration history,
|
||||||
|
and wall-clock elapsed time.
|
||||||
|
"""
|
||||||
|
from pipeline.stages import _load_prompt
|
||||||
|
|
||||||
|
t0 = time.monotonic()
|
||||||
|
dimensions = self.config.dimensions
|
||||||
|
|
||||||
|
# Load base prompt using the stage's configured prompt file
|
||||||
|
prompt_file = self.config.prompt_file
|
||||||
|
try:
|
||||||
|
base_prompt = _load_prompt(prompt_file)
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.error("Prompt file not found: %s", prompt_file)
|
||||||
|
return OptimizationResult(
|
||||||
|
best_prompt="",
|
||||||
|
best_score=ScoreResult(error=f"Prompt file not found: {prompt_file}"),
|
||||||
|
elapsed_seconds=round(time.monotonic() - t0, 2),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load fixture data
|
||||||
|
try:
|
||||||
|
fixture = self._load_fixture()
|
||||||
|
except (FileNotFoundError, json.JSONDecodeError, KeyError) as exc:
|
||||||
|
logger.error("Failed to load fixture: %s", exc)
|
||||||
|
return OptimizationResult(
|
||||||
|
best_prompt=base_prompt,
|
||||||
|
best_score=ScoreResult(error=f"Fixture load error: {exc}"),
|
||||||
|
elapsed_seconds=round(time.monotonic() - t0, 2),
|
||||||
|
)
|
||||||
|
|
||||||
|
history: list[dict] = []
|
||||||
|
|
||||||
|
# Score the baseline
|
||||||
|
print(f"\n{'='*60}")
|
||||||
|
print(f" PROMPT OPTIMIZATION — Stage {self.stage}")
|
||||||
|
print(f" Iterations: {self.iterations}, Variants/iter: {self.variants_per_iter}")
|
||||||
|
print(f"{'='*60}\n")
|
||||||
|
|
||||||
|
print(" Scoring baseline prompt...")
|
||||||
|
best_score = self._score_variant(base_prompt, fixture)
|
||||||
|
best_prompt = base_prompt
|
||||||
|
|
||||||
|
history.append({
|
||||||
|
"iteration": 0,
|
||||||
|
"variant_index": 0,
|
||||||
|
"prompt_text": base_prompt[:200] + "..." if len(base_prompt) > 200 else base_prompt,
|
||||||
|
"prompt_length": len(base_prompt),
|
||||||
|
"composite": best_score.composite,
|
||||||
|
"scores": {d: best_score.scores.get(d, 0.0) for d in dimensions},
|
||||||
|
"error": best_score.error,
|
||||||
|
"label": "baseline",
|
||||||
|
})
|
||||||
|
|
||||||
|
if best_score.error:
|
||||||
|
print(f" ✗ Baseline scoring failed: {best_score.error}")
|
||||||
|
print(" Aborting optimization — fix the baseline first.\n")
|
||||||
|
return OptimizationResult(
|
||||||
|
best_prompt=best_prompt,
|
||||||
|
best_score=best_score,
|
||||||
|
history=history,
|
||||||
|
elapsed_seconds=round(time.monotonic() - t0, 2),
|
||||||
|
)
|
||||||
|
|
||||||
|
baseline_composite = best_score.composite
|
||||||
|
total_variants_scored = 0
|
||||||
|
|
||||||
|
self._write_progress(
|
||||||
|
phase="baseline_scored",
|
||||||
|
iteration=0, variant=0,
|
||||||
|
total_variants_scored=0,
|
||||||
|
best_composite=best_score.composite,
|
||||||
|
baseline_composite=baseline_composite,
|
||||||
|
elapsed_seconds=round(time.monotonic() - t0, 2),
|
||||||
|
best_label="baseline",
|
||||||
|
)
|
||||||
|
|
||||||
|
self._print_iteration_summary(0, best_score, is_baseline=True)
|
||||||
|
|
||||||
|
# Iterate
|
||||||
|
best_label = "baseline"
|
||||||
|
for iteration in range(1, self.iterations + 1):
|
||||||
|
print(f"\n ── Iteration {iteration}/{self.iterations} ──")
|
||||||
|
|
||||||
|
# Generate variants with stage-appropriate markers
|
||||||
|
variants = self.generator.generate(
|
||||||
|
base_prompt=best_prompt,
|
||||||
|
scores=best_score,
|
||||||
|
n=self.variants_per_iter,
|
||||||
|
stage=self.stage,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not variants:
|
||||||
|
print(" ⚠ No valid variants generated — skipping iteration")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Score each variant
|
||||||
|
iteration_best_score = best_score
|
||||||
|
iteration_best_prompt = best_prompt
|
||||||
|
|
||||||
|
for vi, variant_prompt in enumerate(variants):
|
||||||
|
print(f" Scoring variant {vi + 1}/{len(variants)}...")
|
||||||
|
|
||||||
|
score = self._score_variant(variant_prompt, fixture)
|
||||||
|
|
||||||
|
history.append({
|
||||||
|
"iteration": iteration,
|
||||||
|
"variant_index": vi + 1,
|
||||||
|
"prompt_text": variant_prompt[:200] + "..." if len(variant_prompt) > 200 else variant_prompt,
|
||||||
|
"prompt_length": len(variant_prompt),
|
||||||
|
"composite": score.composite,
|
||||||
|
"scores": {d: score.scores.get(d, 0.0) for d in dimensions},
|
||||||
|
"error": score.error,
|
||||||
|
"label": f"iter{iteration}_v{vi+1}",
|
||||||
|
})
|
||||||
|
|
||||||
|
if score.error:
|
||||||
|
print(f" ✗ Variant {vi + 1} errored: {score.error}")
|
||||||
|
total_variants_scored += 1
|
||||||
|
self._write_progress(
|
||||||
|
phase="variant_scored",
|
||||||
|
iteration=iteration, variant=vi + 1,
|
||||||
|
total_variants_scored=total_variants_scored,
|
||||||
|
best_composite=best_score.composite,
|
||||||
|
baseline_composite=baseline_composite,
|
||||||
|
elapsed_seconds=round(time.monotonic() - t0, 2),
|
||||||
|
best_label=best_label,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
total_variants_scored += 1
|
||||||
|
|
||||||
|
if score.composite > iteration_best_score.composite:
|
||||||
|
iteration_best_score = score
|
||||||
|
iteration_best_prompt = variant_prompt
|
||||||
|
print(f" ✓ New best: {score.composite:.3f} (was {best_score.composite:.3f})")
|
||||||
|
else:
|
||||||
|
print(f" · Score {score.composite:.3f} ≤ current best {iteration_best_score.composite:.3f}")
|
||||||
|
|
||||||
|
self._write_progress(
|
||||||
|
phase="variant_scored",
|
||||||
|
iteration=iteration, variant=vi + 1,
|
||||||
|
total_variants_scored=total_variants_scored,
|
||||||
|
best_composite=max(best_score.composite, iteration_best_score.composite),
|
||||||
|
baseline_composite=baseline_composite,
|
||||||
|
elapsed_seconds=round(time.monotonic() - t0, 2),
|
||||||
|
best_label=best_label if iteration_best_score.composite <= best_score.composite
|
||||||
|
else f"iter{iteration}_v{vi+1}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update global best if this iteration improved
|
||||||
|
if iteration_best_score.composite > best_score.composite:
|
||||||
|
best_score = iteration_best_score
|
||||||
|
best_prompt = iteration_best_prompt
|
||||||
|
best_label = f"iter{iteration}"
|
||||||
|
print(f" ★ Iteration {iteration} improved: {best_score.composite:.3f}")
|
||||||
|
else:
|
||||||
|
print(f" · No improvement in iteration {iteration}")
|
||||||
|
|
||||||
|
self._print_iteration_summary(iteration, best_score)
|
||||||
|
|
||||||
|
# Final report
|
||||||
|
elapsed = round(time.monotonic() - t0, 2)
|
||||||
|
self._print_final_report(best_score, history, elapsed)
|
||||||
|
|
||||||
|
self._write_progress(
|
||||||
|
phase="complete",
|
||||||
|
iteration=self.iterations,
|
||||||
|
variant=self.variants_per_iter,
|
||||||
|
total_variants_scored=total_variants_scored,
|
||||||
|
best_composite=best_score.composite,
|
||||||
|
baseline_composite=baseline_composite,
|
||||||
|
elapsed_seconds=elapsed,
|
||||||
|
best_label=best_label,
|
||||||
|
)
|
||||||
|
|
||||||
|
return OptimizationResult(
|
||||||
|
best_prompt=best_prompt,
|
||||||
|
best_score=best_score,
|
||||||
|
history=history,
|
||||||
|
elapsed_seconds=elapsed,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Internal helpers ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _write_progress(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
phase: str,
|
||||||
|
iteration: int,
|
||||||
|
variant: int,
|
||||||
|
total_variants_scored: int,
|
||||||
|
best_composite: float,
|
||||||
|
baseline_composite: float,
|
||||||
|
elapsed_seconds: float,
|
||||||
|
best_label: str = "",
|
||||||
|
) -> None:
|
||||||
|
"""Write a progress.json file to the output directory for external monitoring.
|
||||||
|
|
||||||
|
File is atomic-replaced so readers never see partial writes.
|
||||||
|
"""
|
||||||
|
if not self.output_dir:
|
||||||
|
return
|
||||||
|
|
||||||
|
out_dir = Path(self.output_dir)
|
||||||
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
progress_path = out_dir / f"progress_stage{self.stage}.json"
|
||||||
|
|
||||||
|
total_expected = self.iterations * self.variants_per_iter
|
||||||
|
pct = (total_variants_scored / total_expected * 100) if total_expected else 0
|
||||||
|
|
||||||
|
# ETA: average time per variant × remaining
|
||||||
|
remaining = total_expected - total_variants_scored
|
||||||
|
avg_per_variant = (elapsed_seconds / total_variants_scored) if total_variants_scored > 0 else 0
|
||||||
|
eta_seconds = round(avg_per_variant * remaining, 1)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"stage": self.stage,
|
||||||
|
"phase": phase,
|
||||||
|
"iteration": iteration,
|
||||||
|
"total_iterations": self.iterations,
|
||||||
|
"variant": variant,
|
||||||
|
"variants_per_iter": self.variants_per_iter,
|
||||||
|
"total_variants_scored": total_variants_scored,
|
||||||
|
"total_expected": total_expected,
|
||||||
|
"percent_complete": round(pct, 1),
|
||||||
|
"baseline_composite": round(baseline_composite, 4),
|
||||||
|
"best_composite": round(best_composite, 4),
|
||||||
|
"improvement": round(best_composite - baseline_composite, 4),
|
||||||
|
"best_label": best_label,
|
||||||
|
"elapsed_seconds": round(elapsed_seconds, 1),
|
||||||
|
"eta_seconds": eta_seconds,
|
||||||
|
"updated_at": datetime.now(timezone.utc).isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Atomic write via temp file + rename
|
||||||
|
tmp_path = progress_path.with_suffix(".tmp")
|
||||||
|
tmp_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
|
||||||
|
tmp_path.rename(progress_path)
|
||||||
|
|
||||||
|
def _load_fixture(self) -> dict:
|
||||||
|
"""Load and validate the fixture JSON file against stage-specific keys."""
|
||||||
|
path = Path(self.fixture_path)
|
||||||
|
if not path.exists():
|
||||||
|
raise FileNotFoundError(f"Fixture not found: {path}")
|
||||||
|
data = json.loads(path.read_text(encoding="utf-8"))
|
||||||
|
|
||||||
|
for key in self.config.fixture_keys:
|
||||||
|
if key not in data:
|
||||||
|
raise KeyError(
|
||||||
|
f"Stage {self.stage} fixture must contain '{key}' key "
|
||||||
|
f"(required: {self.config.fixture_keys})"
|
||||||
|
)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _score_variant(
|
||||||
|
self,
|
||||||
|
variant_prompt: str,
|
||||||
|
fixture: dict,
|
||||||
|
) -> ScoreResult:
|
||||||
|
"""Score a variant prompt by running LLM completion + scoring.
|
||||||
|
|
||||||
|
Dispatches to stage-specific synthesis logic:
|
||||||
|
- Stages 2-4: call LLM with the variant prompt and fixture input,
|
||||||
|
parse with the stage's schema, then score via score_stage_output()
|
||||||
|
- Stage 5: original flow (synthesis + page scoring)
|
||||||
|
"""
|
||||||
|
from pipeline.stages import _get_stage_config
|
||||||
|
|
||||||
|
import json as _json
|
||||||
|
import openai as _openai
|
||||||
|
|
||||||
|
model_override, modality = _get_stage_config(self.stage)
|
||||||
|
schema_class = self.config.get_schema()
|
||||||
|
|
||||||
|
# Build user prompt from fixture data — stage-specific formatting
|
||||||
|
user_prompt = self._build_user_prompt(fixture)
|
||||||
|
|
||||||
|
t0 = time.monotonic()
|
||||||
|
try:
|
||||||
|
raw = self.client.complete(
|
||||||
|
system_prompt=variant_prompt,
|
||||||
|
user_prompt=user_prompt,
|
||||||
|
response_model=schema_class,
|
||||||
|
modality=modality,
|
||||||
|
model_override=model_override,
|
||||||
|
)
|
||||||
|
elapsed_synth = round(time.monotonic() - t0, 2)
|
||||||
|
except (_openai.APIConnectionError, _openai.APITimeoutError) as exc:
|
||||||
|
elapsed_synth = round(time.monotonic() - t0, 2)
|
||||||
|
return ScoreResult(
|
||||||
|
elapsed_seconds=elapsed_synth,
|
||||||
|
error=f"LLM error (stage {self.stage}): {exc}",
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
elapsed_synth = round(time.monotonic() - t0, 2)
|
||||||
|
logger.exception("Unexpected error during variant synthesis (stage %d)", self.stage)
|
||||||
|
return ScoreResult(
|
||||||
|
elapsed_seconds=elapsed_synth,
|
||||||
|
error=f"Unexpected synthesis error: {exc}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse the LLM response into the stage schema
|
||||||
|
raw_text = str(raw).strip()
|
||||||
|
try:
|
||||||
|
parsed = self.client.parse_response(raw_text, schema_class)
|
||||||
|
except Exception as exc:
|
||||||
|
return ScoreResult(
|
||||||
|
elapsed_seconds=elapsed_synth,
|
||||||
|
error=f"Variant parse error (stage {self.stage}): {exc}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert parsed output to JSON for the scorer
|
||||||
|
output_json = self._schema_to_output_json(parsed)
|
||||||
|
if output_json is None:
|
||||||
|
return ScoreResult(
|
||||||
|
elapsed_seconds=elapsed_synth,
|
||||||
|
error=f"Stage {self.stage} produced empty output",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Score using the generic stage scorer
|
||||||
|
result = self.scorer.score_stage_output(
|
||||||
|
stage=self.stage,
|
||||||
|
output_json=output_json,
|
||||||
|
input_json=self._fixture_to_input_json(fixture),
|
||||||
|
)
|
||||||
|
result.elapsed_seconds = round(result.elapsed_seconds + elapsed_synth, 2)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _build_user_prompt(self, fixture: dict) -> str:
|
||||||
|
"""Build a stage-appropriate user prompt from fixture data."""
|
||||||
|
if self.stage == 2:
|
||||||
|
segments_json = json.dumps(fixture["transcript_segments"], indent=2)
|
||||||
|
return f"<transcript_segments>\n{segments_json}\n</transcript_segments>"
|
||||||
|
|
||||||
|
elif self.stage == 3:
|
||||||
|
segments_json = json.dumps(fixture["topic_segments"], indent=2)
|
||||||
|
return f"<topic_segments>\n{segments_json}\n</topic_segments>"
|
||||||
|
|
||||||
|
elif self.stage == 4:
|
||||||
|
moments_json = json.dumps(fixture["extracted_moments"], indent=2)
|
||||||
|
taxonomy = fixture.get("taxonomy", "")
|
||||||
|
prompt = f"<moments>\n{moments_json}\n</moments>"
|
||||||
|
if taxonomy:
|
||||||
|
prompt += f"\n<taxonomy>{taxonomy}</taxonomy>"
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
elif self.stage == 5:
|
||||||
|
moments_json = json.dumps(fixture["moments"], indent=2)
|
||||||
|
creator = fixture.get("creator_name", "Unknown")
|
||||||
|
return f"<creator>{creator}</creator>\n<moments>\n{moments_json}\n</moments>"
|
||||||
|
|
||||||
|
else:
|
||||||
|
return json.dumps(fixture, indent=2)
|
||||||
|
|
||||||
|
def _schema_to_output_json(self, parsed: object) -> dict | list | None:
|
||||||
|
"""Convert a parsed Pydantic schema instance to JSON-serializable dict."""
|
||||||
|
if hasattr(parsed, "model_dump"):
|
||||||
|
return parsed.model_dump()
|
||||||
|
elif hasattr(parsed, "dict"):
|
||||||
|
return parsed.dict()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _fixture_to_input_json(self, fixture: dict) -> dict | list:
|
||||||
|
"""Extract the primary input data from the fixture for scorer context."""
|
||||||
|
if self.stage == 2:
|
||||||
|
return fixture["transcript_segments"]
|
||||||
|
elif self.stage == 3:
|
||||||
|
return fixture["topic_segments"]
|
||||||
|
elif self.stage == 4:
|
||||||
|
return fixture["extracted_moments"]
|
||||||
|
elif self.stage == 5:
|
||||||
|
return fixture["moments"]
|
||||||
|
return fixture
|
||||||
|
|
||||||
|
def _print_iteration_summary(
|
||||||
|
self,
|
||||||
|
iteration: int,
|
||||||
|
score: ScoreResult,
|
||||||
|
is_baseline: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""Print a compact one-line summary of the current best scores."""
|
||||||
|
label = "BASELINE" if is_baseline else f"ITER {iteration}"
|
||||||
|
dimensions = self.config.dimensions
|
||||||
|
dims = " ".join(
|
||||||
|
f"{d[:4]}={score.scores.get(d, 0.0):.2f}" for d in dimensions
|
||||||
|
)
|
||||||
|
print(f" [{label}] composite={score.composite:.3f} {dims}")
|
||||||
|
|
||||||
|
def _print_final_report(
|
||||||
|
self,
|
||||||
|
best_score: ScoreResult,
|
||||||
|
history: list[dict],
|
||||||
|
elapsed: float,
|
||||||
|
) -> None:
|
||||||
|
"""Print the final optimization summary."""
|
||||||
|
dimensions = self.config.dimensions
|
||||||
|
|
||||||
|
print(f"\n{'='*60}")
|
||||||
|
print(" OPTIMIZATION COMPLETE")
|
||||||
|
print(f"{'='*60}")
|
||||||
|
print(f" Total time: {elapsed}s")
|
||||||
|
print(f" Iterations: {self.iterations}")
|
||||||
|
print(f" Variants scored: {len(history) - 1}") # minus baseline
|
||||||
|
|
||||||
|
baseline_composite = history[0]["composite"] if history else 0.0
|
||||||
|
improvement = best_score.composite - baseline_composite
|
||||||
|
|
||||||
|
print(f"\n Baseline composite: {baseline_composite:.3f}")
|
||||||
|
print(f" Best composite: {best_score.composite:.3f}")
|
||||||
|
if improvement > 0:
|
||||||
|
print(f" Improvement: +{improvement:.3f}")
|
||||||
|
else:
|
||||||
|
print(f" Improvement: {improvement:.3f} (no gain)")
|
||||||
|
|
||||||
|
print(f"\n Per-dimension best scores:")
|
||||||
|
for d in dimensions:
|
||||||
|
val = best_score.scores.get(d, 0.0)
|
||||||
|
bar = "█" * int(val * 20) + "░" * (20 - int(val * 20))
|
||||||
|
print(f" {d.replace('_', ' ').title():25s} {val:.2f} {bar}")
|
||||||
|
|
||||||
|
errored = sum(1 for h in history if h.get("error"))
|
||||||
|
if errored:
|
||||||
|
print(f"\n ⚠ {errored} variant(s) errored during scoring")
|
||||||
|
|
||||||
|
print(f"{'='*60}\n")
|
||||||
|
|
||||||
|
|
||||||
|
# Late import to avoid circular dependency (scorer imports at module level,
|
||||||
|
# variant_generator imports scorer)
|
||||||
|
from pipeline.quality.variant_generator import PromptVariantGenerator # noqa: E402
|
||||||
0
backend/pipeline/quality/results/.gitkeep
Normal file
0
backend/pipeline/quality/results/.gitkeep
Normal file
91
backend/pipeline/quality/results/chat_eval_baseline.json
Normal file
91
backend/pipeline/quality/results/chat_eval_baseline.json
Normal file
|
|
@ -0,0 +1,91 @@
|
||||||
|
{
|
||||||
|
"timestamp": "20260404_043200",
|
||||||
|
"evaluation_method": "manual_curl",
|
||||||
|
"llm_status": "unavailable (upstream 502 Bad Gateway at chat.forgetyour.name)",
|
||||||
|
"api_health": "ok",
|
||||||
|
"total_queries": 6,
|
||||||
|
"scored_queries": 0,
|
||||||
|
"errors_llm": 6,
|
||||||
|
"note": "LLM completions unavailable — only source retrieval quality assessed. Re-run with automated eval when LLM proxy is restored.",
|
||||||
|
"source_retrieval_results": [
|
||||||
|
{
|
||||||
|
"query": "How do I set up sidechain compression on a bass synth using a kick drum as the trigger?",
|
||||||
|
"creator": null,
|
||||||
|
"personality_weight": 0.0,
|
||||||
|
"category": "technical",
|
||||||
|
"source_count": 10,
|
||||||
|
"unique_creators": ["Break", "Caracal Project, The", "Chee", "KOAN Sound"],
|
||||||
|
"creator_distribution": {"Break": 3, "Caracal Project, The": 2, "Chee": 2, "KOAN Sound": 1},
|
||||||
|
"relevance_assessment": "highly_relevant",
|
||||||
|
"notes": "All 10 sources directly about sidechain compression. Good creator diversity."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"query": "What are the different approaches to layering synth sounds across creators?",
|
||||||
|
"creator": null,
|
||||||
|
"personality_weight": 0.0,
|
||||||
|
"category": "cross_creator",
|
||||||
|
"source_count": 10,
|
||||||
|
"unique_creators": ["Chee", "COPYCATT", "Caracal Project, The", "Current Value", "Emperor"],
|
||||||
|
"creator_distribution": {"Chee": 5, "COPYCATT": 2, "Caracal Project, The": 1, "Current Value": 1, "Emperor": 1},
|
||||||
|
"relevance_assessment": "relevant_but_skewed",
|
||||||
|
"notes": "50% of sources from Chee — cross-creator diversity could be improved."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"query": "How does this creator approach sound design for bass sounds?",
|
||||||
|
"creator": "Keota",
|
||||||
|
"personality_weight": 0.0,
|
||||||
|
"category": "creator_encyclopedic",
|
||||||
|
"source_count": 10,
|
||||||
|
"unique_creators": ["COPYCATT", "Break", "Chee", "Caracal Project, The"],
|
||||||
|
"creator_distribution": {"COPYCATT": 2, "Break": 2, "Chee": 3, "Caracal Project, The": 3},
|
||||||
|
"relevance_assessment": "creator_scope_failure",
|
||||||
|
"notes": "Zero sources from Keota despite creator-scoped query. Cascade fell through to global tier."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"query": "What mixing techniques does this creator recommend for achieving width in a mix?",
|
||||||
|
"creator": "Mr. Bill",
|
||||||
|
"personality_weight": 0.0,
|
||||||
|
"category": "creator_encyclopedic",
|
||||||
|
"source_count": 10,
|
||||||
|
"unique_creators": ["Break", "Frequent", "Caracal Project, The", "COPYCATT", "Chee"],
|
||||||
|
"creator_distribution": {"Break": 2, "Frequent": 1, "Caracal Project, The": 2, "COPYCATT": 2, "Chee": 3},
|
||||||
|
"relevance_assessment": "creator_scope_failure",
|
||||||
|
"notes": "Zero sources from Mr. Bill despite creator-scoped query."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"query": "How does this creator approach sound design for bass sounds? (personality)",
|
||||||
|
"creator": "Keota",
|
||||||
|
"personality_weight": 0.7,
|
||||||
|
"category": "creator_personality",
|
||||||
|
"source_count": 10,
|
||||||
|
"personality_profile_exists": false,
|
||||||
|
"notes": "Personality weight=0.7 accepted but no profile data exists — falls back to encyclopedic mode silently."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"query": "What mixing techniques does this creator recommend for width? (personality)",
|
||||||
|
"creator": "Mr. Bill",
|
||||||
|
"personality_weight": 0.7,
|
||||||
|
"category": "creator_personality",
|
||||||
|
"source_count": 10,
|
||||||
|
"personality_profile_exists": false,
|
||||||
|
"notes": "Personality weight=0.7 accepted but no profile data exists — falls back to encyclopedic mode silently."
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"personality_profiles_status": {
|
||||||
|
"total_creators": 25,
|
||||||
|
"creators_with_profile": 0,
|
||||||
|
"assessment": "No personality profiles populated. The 5-tier progressive injection system is architecturally complete (26 unit tests pass) but functionally inert on the live system."
|
||||||
|
},
|
||||||
|
"prompt_changes": {
|
||||||
|
"before_lines": 4,
|
||||||
|
"after_lines": 18,
|
||||||
|
"changes": [
|
||||||
|
"Added structured citation guidance with inline example",
|
||||||
|
"Added response format section (2-4 paragraphs, bullet lists, bold terms)",
|
||||||
|
"Added domain awareness (music production subdomain list)",
|
||||||
|
"Added conflicting source handling instruction",
|
||||||
|
"Added response length guidance"
|
||||||
|
],
|
||||||
|
"test_impact": "Zero test modifications needed — all 26 tests pass unchanged"
|
||||||
|
}
|
||||||
|
}
|
||||||
File diff suppressed because one or more lines are too long
18
backend/pipeline/quality/results/progress_stage5.json
Normal file
18
backend/pipeline/quality/results/progress_stage5.json
Normal file
|
|
@ -0,0 +1,18 @@
|
||||||
|
{
|
||||||
|
"stage": 5,
|
||||||
|
"phase": "variant_scored",
|
||||||
|
"iteration": 3,
|
||||||
|
"total_iterations": 5,
|
||||||
|
"variant": 2,
|
||||||
|
"variants_per_iter": 3,
|
||||||
|
"total_variants_scored": 4,
|
||||||
|
"total_expected": 15,
|
||||||
|
"percent_complete": 26.7,
|
||||||
|
"baseline_composite": 1.0,
|
||||||
|
"best_composite": 1.0,
|
||||||
|
"improvement": 0.0,
|
||||||
|
"best_label": "baseline",
|
||||||
|
"elapsed_seconds": 1303.4,
|
||||||
|
"eta_seconds": 3584.3,
|
||||||
|
"updated_at": "2026-04-01T10:37:26.971865+00:00"
|
||||||
|
}
|
||||||
614
backend/pipeline/quality/scorer.py
Normal file
614
backend/pipeline/quality/scorer.py
Normal file
|
|
@ -0,0 +1,614 @@
|
||||||
|
"""Multi-stage quality scorer — LLM-as-judge evaluation with per-stage rubrics.
|
||||||
|
|
||||||
|
Supports stages 2-5, each with its own scoring dimensions, rubric, format
|
||||||
|
markers, fixture key requirements, prompt file name, and output schema.
|
||||||
|
|
||||||
|
Run via: python -m pipeline.quality score --file <path>
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import openai
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from pipeline.llm_client import LLMClient
|
||||||
|
from pipeline.quality.voice_dial import VoiceDial
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Per-stage configuration registry ─────────────────────────────────────────
|
||||||
|
|
||||||
|
class StageConfig:
|
||||||
|
"""Configuration for scoring a specific pipeline stage."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
stage: int,
|
||||||
|
dimensions: list[str],
|
||||||
|
rubric: str,
|
||||||
|
format_markers: list[str],
|
||||||
|
fixture_keys: list[str],
|
||||||
|
prompt_file: str,
|
||||||
|
schema_class: str,
|
||||||
|
) -> None:
|
||||||
|
self.stage = stage
|
||||||
|
self.dimensions = dimensions
|
||||||
|
self.rubric = rubric
|
||||||
|
self.format_markers = format_markers
|
||||||
|
self.fixture_keys = fixture_keys
|
||||||
|
self.prompt_file = prompt_file
|
||||||
|
self.schema_class = schema_class
|
||||||
|
|
||||||
|
def get_schema(self) -> type[BaseModel]:
|
||||||
|
"""Import and return the Pydantic schema class for this stage."""
|
||||||
|
from pipeline import schemas
|
||||||
|
return getattr(schemas, self.schema_class)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Stage rubrics ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_STAGE_2_RUBRIC = """\
|
||||||
|
You are an expert evaluator of transcript segmentation quality for educational content.
|
||||||
|
|
||||||
|
You will be given:
|
||||||
|
1. A segmentation result (JSON with segments, each having start_index, end_index, topic_label, summary)
|
||||||
|
2. The source transcript segments used as input
|
||||||
|
|
||||||
|
Evaluate the segmentation across these 4 dimensions, scoring each 0.0 to 1.0:
|
||||||
|
|
||||||
|
**coverage_completeness** — All transcript content accounted for
|
||||||
|
- 0.9-1.0: Every transcript segment is covered by exactly one topic segment, no gaps or overlaps
|
||||||
|
- 0.5-0.7: Minor gaps or overlaps, but most content is covered
|
||||||
|
- 0.0-0.3: Large gaps — significant transcript segments are not assigned to any topic
|
||||||
|
|
||||||
|
**topic_specificity** — Topic labels are descriptive and useful
|
||||||
|
- 0.9-1.0: Labels are specific and descriptive (e.g., "Sidechain compression on kick-bass" not "Audio processing")
|
||||||
|
- 0.5-0.7: Labels are somewhat specific but could be more descriptive
|
||||||
|
- 0.0-0.3: Labels are generic or meaningless ("Topic 1", "Discussion", "Audio")
|
||||||
|
|
||||||
|
**boundary_accuracy** — Segment boundaries align with actual topic transitions
|
||||||
|
- 0.9-1.0: Boundaries fall at natural topic transitions, segments are coherent units
|
||||||
|
- 0.5-0.7: Most boundaries are reasonable but some segments mix distinct topics
|
||||||
|
- 0.0-0.3: Boundaries seem arbitrary, segments contain unrelated content
|
||||||
|
|
||||||
|
**summary_quality** — Summaries accurately describe segment content
|
||||||
|
- 0.9-1.0: Summaries capture the key points of each segment concisely and accurately
|
||||||
|
- 0.5-0.7: Summaries are acceptable but miss some key points or are too vague
|
||||||
|
- 0.0-0.3: Summaries are inaccurate, too generic, or missing
|
||||||
|
|
||||||
|
Return ONLY a JSON object with this exact structure:
|
||||||
|
{
|
||||||
|
"coverage_completeness": <float 0.0-1.0>,
|
||||||
|
"topic_specificity": <float 0.0-1.0>,
|
||||||
|
"boundary_accuracy": <float 0.0-1.0>,
|
||||||
|
"summary_quality": <float 0.0-1.0>,
|
||||||
|
"justifications": {
|
||||||
|
"coverage_completeness": "<1-2 sentence justification>",
|
||||||
|
"topic_specificity": "<1-2 sentence justification>",
|
||||||
|
"boundary_accuracy": "<1-2 sentence justification>",
|
||||||
|
"summary_quality": "<1-2 sentence justification>"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
_STAGE_3_RUBRIC = """\
|
||||||
|
You are an expert evaluator of key moment extraction quality for educational content.
|
||||||
|
|
||||||
|
You will be given:
|
||||||
|
1. An extraction result (JSON with moments, each having title, summary, start_time, end_time, content_type, plugins, raw_transcript)
|
||||||
|
2. The source topic segments used as input
|
||||||
|
|
||||||
|
Evaluate the extraction across these 5 dimensions, scoring each 0.0 to 1.0:
|
||||||
|
|
||||||
|
**moment_richness** — Extracted moments capture substantial, distinct insights
|
||||||
|
- 0.9-1.0: Each moment represents a meaningful, distinct technique or concept with detailed summary
|
||||||
|
- 0.5-0.7: Moments are valid but some are thin or overlap significantly with others
|
||||||
|
- 0.0-0.3: Moments are trivial, redundant, or miss the main techniques discussed
|
||||||
|
|
||||||
|
**timestamp_accuracy** — Time ranges are plausible and well-bounded
|
||||||
|
- 0.9-1.0: Start/end times form reasonable ranges, no zero-length or absurdly long spans
|
||||||
|
- 0.5-0.7: Most timestamps are reasonable but some spans seem too wide or narrow
|
||||||
|
- 0.0-0.3: Timestamps appear arbitrary or many are zero/identical
|
||||||
|
|
||||||
|
**content_type_correctness** — Content types match the actual moment content
|
||||||
|
- 0.9-1.0: Each moment's content_type (technique/settings/reasoning/workflow) accurately describes it
|
||||||
|
- 0.5-0.7: Most are correct but 1-2 are miscategorized
|
||||||
|
- 0.0-0.3: Content types seem randomly assigned or all the same
|
||||||
|
|
||||||
|
**summary_actionability** — Summaries provide actionable, specific information
|
||||||
|
- 0.9-1.0: Summaries contain concrete details (values, settings, steps) that a practitioner could follow
|
||||||
|
- 0.5-0.7: Summaries describe the topic but lack specific actionable details
|
||||||
|
- 0.0-0.3: Summaries are vague ("discusses compression") with no actionable information
|
||||||
|
|
||||||
|
**plugin_normalization** — Plugin/tool names are correctly identified and normalized
|
||||||
|
- 0.9-1.0: Plugin names match standard names, no duplicates, captures all mentioned tools
|
||||||
|
- 0.5-0.7: Most plugins captured but some are misspelled, duplicated, or missed
|
||||||
|
- 0.0-0.3: Plugin list is mostly empty, contains non-plugins, or has many errors
|
||||||
|
|
||||||
|
Return ONLY a JSON object with this exact structure:
|
||||||
|
{
|
||||||
|
"moment_richness": <float 0.0-1.0>,
|
||||||
|
"timestamp_accuracy": <float 0.0-1.0>,
|
||||||
|
"content_type_correctness": <float 0.0-1.0>,
|
||||||
|
"summary_actionability": <float 0.0-1.0>,
|
||||||
|
"plugin_normalization": <float 0.0-1.0>,
|
||||||
|
"justifications": {
|
||||||
|
"moment_richness": "<1-2 sentence justification>",
|
||||||
|
"timestamp_accuracy": "<1-2 sentence justification>",
|
||||||
|
"content_type_correctness": "<1-2 sentence justification>",
|
||||||
|
"summary_actionability": "<1-2 sentence justification>",
|
||||||
|
"plugin_normalization": "<1-2 sentence justification>"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
_STAGE_4_RUBRIC = """\
|
||||||
|
You are an expert evaluator of content classification quality for educational content.
|
||||||
|
|
||||||
|
You will be given:
|
||||||
|
1. A classification result (JSON with classifications, each having moment_index, topic_category, topic_tags)
|
||||||
|
2. The source extracted moments used as input
|
||||||
|
|
||||||
|
Evaluate the classification across these 4 dimensions, scoring each 0.0 to 1.0:
|
||||||
|
|
||||||
|
**category_accuracy** — Topic categories are appropriate and meaningful
|
||||||
|
- 0.9-1.0: Categories accurately reflect the primary topic of each moment, using domain-appropriate labels
|
||||||
|
- 0.5-0.7: Most categories are reasonable but some are too broad or slightly off
|
||||||
|
- 0.0-0.3: Categories are generic ("Music"), incorrect, or all the same
|
||||||
|
|
||||||
|
**tag_completeness** — All relevant tags are captured
|
||||||
|
- 0.9-1.0: Tags capture the key concepts, tools, and techniques in each moment comprehensively
|
||||||
|
- 0.5-0.7: Main tags are present but secondary concepts or tools are missed
|
||||||
|
- 0.0-0.3: Tags are sparse, missing major concepts mentioned in the moments
|
||||||
|
|
||||||
|
**tag_specificity** — Tags are specific enough to be useful for search/filtering
|
||||||
|
- 0.9-1.0: Tags are specific ("sidechain compression", "Pro-Q 3") not generic ("audio", "mixing")
|
||||||
|
- 0.5-0.7: Mix of specific and generic tags
|
||||||
|
- 0.0-0.3: Tags are too generic to meaningfully distinguish moments
|
||||||
|
|
||||||
|
**coverage** — All moments are classified
|
||||||
|
- 0.9-1.0: Every moment_index from the input has a corresponding classification entry
|
||||||
|
- 0.5-0.7: Most moments classified but 1-2 are missing
|
||||||
|
- 0.0-0.3: Many moments are not classified
|
||||||
|
|
||||||
|
Return ONLY a JSON object with this exact structure:
|
||||||
|
{
|
||||||
|
"category_accuracy": <float 0.0-1.0>,
|
||||||
|
"tag_completeness": <float 0.0-1.0>,
|
||||||
|
"tag_specificity": <float 0.0-1.0>,
|
||||||
|
"coverage": <float 0.0-1.0>,
|
||||||
|
"justifications": {
|
||||||
|
"category_accuracy": "<1-2 sentence justification>",
|
||||||
|
"tag_completeness": "<1-2 sentence justification>",
|
||||||
|
"tag_specificity": "<1-2 sentence justification>",
|
||||||
|
"coverage": "<1-2 sentence justification>"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
_STAGE_5_RUBRIC = """\
|
||||||
|
You are an expert evaluator of synthesized technique articles for music production education.
|
||||||
|
|
||||||
|
You will be given:
|
||||||
|
1. A synthesized technique page (JSON with title, summary, body_sections)
|
||||||
|
2. The source key moments (transcript excerpts, summaries, tags) used to create it
|
||||||
|
|
||||||
|
Evaluate the page across these 5 dimensions, scoring each 0.0 to 1.0:
|
||||||
|
|
||||||
|
**structural** — Section naming and organization
|
||||||
|
- 0.9-1.0: Well-named specific sections (not generic "Overview"/"Tips"), appropriate count (3-6), 2-5 paragraphs per section
|
||||||
|
- 0.5-0.7: Acceptable structure but some generic section names or uneven depth
|
||||||
|
- 0.0-0.3: Poor structure — too few/many sections, generic names, single-paragraph sections
|
||||||
|
|
||||||
|
**content_specificity** — Concrete technical details
|
||||||
|
- 0.9-1.0: Rich in frequencies (Hz), time values (ms), ratios, plugin names, specific settings, dB values
|
||||||
|
- 0.5-0.7: Some specific details but padded with vague statements ("adjust to taste", "experiment with settings")
|
||||||
|
- 0.0-0.3: Mostly vague generalities with few concrete values from the source material
|
||||||
|
|
||||||
|
**voice_preservation** — Creator's authentic voice
|
||||||
|
- 0.9-1.0: Direct quotes preserved, opinions attributed to creator by name, personality and strong views retained
|
||||||
|
- 0.5-0.7: Some paraphrased references to creator's views but few direct quotes
|
||||||
|
- 0.0-0.3: Encyclopedia style — creator's voice completely smoothed out, no attribution
|
||||||
|
|
||||||
|
**readability** — Synthesis quality and flow
|
||||||
|
- 0.9-1.0: Reads as a cohesive article, related info merged, logical flow, no redundancy or contradiction
|
||||||
|
- 0.5-0.7: Generally readable but some awkward transitions or minor repetition
|
||||||
|
- 0.0-0.3: Feels like concatenated bullet points, disjointed, redundant passages
|
||||||
|
|
||||||
|
**factual_fidelity** — Grounded in source material
|
||||||
|
- 0.9-1.0: Every claim traceable to source moments, no invented plugin names/settings/techniques
|
||||||
|
- 0.5-0.7: Mostly grounded but 1-2 details seem embellished or not directly from sources
|
||||||
|
- 0.0-0.3: Contains hallucinated specifics — plugin names, settings, or techniques not in sources
|
||||||
|
|
||||||
|
Return ONLY a JSON object with this exact structure:
|
||||||
|
{
|
||||||
|
"structural": <float 0.0-1.0>,
|
||||||
|
"content_specificity": <float 0.0-1.0>,
|
||||||
|
"voice_preservation": <float 0.0-1.0>,
|
||||||
|
"readability": <float 0.0-1.0>,
|
||||||
|
"factual_fidelity": <float 0.0-1.0>,
|
||||||
|
"justifications": {
|
||||||
|
"structural": "<1-2 sentence justification>",
|
||||||
|
"content_specificity": "<1-2 sentence justification>",
|
||||||
|
"voice_preservation": "<1-2 sentence justification>",
|
||||||
|
"readability": "<1-2 sentence justification>",
|
||||||
|
"factual_fidelity": "<1-2 sentence justification>"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Backward-compat alias used by synthesize_and_score and external references
|
||||||
|
SCORING_RUBRIC = _STAGE_5_RUBRIC
|
||||||
|
|
||||||
|
# Build the stage configs registry
|
||||||
|
STAGE_CONFIGS: dict[int, StageConfig] = {
|
||||||
|
2: StageConfig(
|
||||||
|
stage=2,
|
||||||
|
dimensions=["coverage_completeness", "topic_specificity", "boundary_accuracy", "summary_quality"],
|
||||||
|
rubric=_STAGE_2_RUBRIC,
|
||||||
|
format_markers=["segments", "start_index", "end_index", "topic_label"],
|
||||||
|
fixture_keys=["transcript_segments"],
|
||||||
|
prompt_file="stage2_segmentation.txt",
|
||||||
|
schema_class="SegmentationResult",
|
||||||
|
),
|
||||||
|
3: StageConfig(
|
||||||
|
stage=3,
|
||||||
|
dimensions=["moment_richness", "timestamp_accuracy", "content_type_correctness", "summary_actionability", "plugin_normalization"],
|
||||||
|
rubric=_STAGE_3_RUBRIC,
|
||||||
|
format_markers=["moments", "content_type", "raw_transcript", "plugins"],
|
||||||
|
fixture_keys=["topic_segments"],
|
||||||
|
prompt_file="stage3_extraction.txt",
|
||||||
|
schema_class="ExtractionResult",
|
||||||
|
),
|
||||||
|
4: StageConfig(
|
||||||
|
stage=4,
|
||||||
|
dimensions=["category_accuracy", "tag_completeness", "tag_specificity", "coverage"],
|
||||||
|
rubric=_STAGE_4_RUBRIC,
|
||||||
|
format_markers=["classifications", "moment_index", "topic_category", "topic_tags"],
|
||||||
|
fixture_keys=["extracted_moments"],
|
||||||
|
prompt_file="stage4_classification.txt",
|
||||||
|
schema_class="ClassificationResult",
|
||||||
|
),
|
||||||
|
5: StageConfig(
|
||||||
|
stage=5,
|
||||||
|
dimensions=["structural", "content_specificity", "voice_preservation", "readability", "factual_fidelity"],
|
||||||
|
rubric=SCORING_RUBRIC,
|
||||||
|
format_markers=["SynthesisResult", '"pages"', "body_sections", "title", "summary"],
|
||||||
|
fixture_keys=["moments", "creator_name"],
|
||||||
|
prompt_file="stage5_synthesis.txt",
|
||||||
|
schema_class="SynthesisResult",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Backward-compatible alias: stage 5 dimensions list
|
||||||
|
DIMENSIONS = STAGE_CONFIGS[5].dimensions
|
||||||
|
|
||||||
|
|
||||||
|
# ── Result type ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ScoreResult:
|
||||||
|
"""Outcome of scoring a stage output across quality dimensions.
|
||||||
|
|
||||||
|
Uses a generic ``scores`` dict keyed by dimension name. Stage 5's
|
||||||
|
original named fields (structural, content_specificity, …) are
|
||||||
|
preserved as properties for backward compatibility.
|
||||||
|
"""
|
||||||
|
|
||||||
|
scores: dict[str, float] = field(default_factory=dict)
|
||||||
|
composite: float = 0.0
|
||||||
|
justifications: dict[str, str] = field(default_factory=dict)
|
||||||
|
elapsed_seconds: float = 0.0
|
||||||
|
error: str | None = None
|
||||||
|
|
||||||
|
# ── Backward-compat properties for stage 5 named dimensions ──────
|
||||||
|
@property
|
||||||
|
def structural(self) -> float:
|
||||||
|
return self.scores.get("structural", 0.0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def content_specificity(self) -> float:
|
||||||
|
return self.scores.get("content_specificity", 0.0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def voice_preservation(self) -> float:
|
||||||
|
return self.scores.get("voice_preservation", 0.0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def readability(self) -> float:
|
||||||
|
return self.scores.get("readability", 0.0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def factual_fidelity(self) -> float:
|
||||||
|
return self.scores.get("factual_fidelity", 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Runner ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class ScoreRunner:
|
||||||
|
"""Scores pipeline stage outputs using LLM-as-judge evaluation."""
|
||||||
|
|
||||||
|
def __init__(self, client: LLMClient) -> None:
|
||||||
|
self.client = client
|
||||||
|
|
||||||
|
# ── Generic stage scorer ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
def score_stage_output(
|
||||||
|
self,
|
||||||
|
stage: int,
|
||||||
|
output_json: dict | list,
|
||||||
|
input_json: dict | list,
|
||||||
|
) -> ScoreResult:
|
||||||
|
"""Score an arbitrary stage's output against its input.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
stage:
|
||||||
|
Pipeline stage number (2-5).
|
||||||
|
output_json:
|
||||||
|
The stage output to evaluate (parsed JSON).
|
||||||
|
input_json:
|
||||||
|
The stage input / source material.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
ScoreResult with per-dimension scores for the requested stage.
|
||||||
|
"""
|
||||||
|
if stage not in STAGE_CONFIGS:
|
||||||
|
return ScoreResult(error=f"No config for stage {stage}. Valid: {sorted(STAGE_CONFIGS)}")
|
||||||
|
|
||||||
|
cfg = STAGE_CONFIGS[stage]
|
||||||
|
|
||||||
|
user_prompt = (
|
||||||
|
"## Stage Output\n\n"
|
||||||
|
f"```json\n{json.dumps(output_json, indent=2)}\n```\n\n"
|
||||||
|
"## Stage Input\n\n"
|
||||||
|
f"```json\n{json.dumps(input_json, indent=2)}\n```\n\n"
|
||||||
|
f"Score this stage {stage} output across all {len(cfg.dimensions)} dimensions."
|
||||||
|
)
|
||||||
|
|
||||||
|
t0 = time.monotonic()
|
||||||
|
try:
|
||||||
|
resp = self.client.complete(
|
||||||
|
system_prompt=cfg.rubric,
|
||||||
|
user_prompt=user_prompt,
|
||||||
|
response_model=BaseModel,
|
||||||
|
modality="chat",
|
||||||
|
)
|
||||||
|
elapsed = round(time.monotonic() - t0, 2)
|
||||||
|
except (openai.APIConnectionError, openai.APITimeoutError) as exc:
|
||||||
|
elapsed = round(time.monotonic() - t0, 2)
|
||||||
|
url = self.client.settings.llm_api_url
|
||||||
|
fallback = self.client.settings.llm_fallback_url
|
||||||
|
return ScoreResult(
|
||||||
|
elapsed_seconds=elapsed,
|
||||||
|
error=f"Cannot reach LLM endpoint at {url} (fallback {fallback}). Error: {exc}",
|
||||||
|
)
|
||||||
|
|
||||||
|
raw_text = str(resp).strip()
|
||||||
|
try:
|
||||||
|
parsed = json.loads(raw_text)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.error("Malformed judge response (not JSON): %.300s", raw_text)
|
||||||
|
return ScoreResult(
|
||||||
|
elapsed_seconds=elapsed,
|
||||||
|
error=f"Malformed judge response (not valid JSON). Raw excerpt: {raw_text[:200]}",
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._parse_scores(parsed, elapsed, cfg.dimensions)
|
||||||
|
|
||||||
|
# ── Stage 5 convenience (backward compat) ────────────────────────────
|
||||||
|
|
||||||
|
def score_page(
|
||||||
|
self,
|
||||||
|
page_json: dict,
|
||||||
|
moments: list[dict],
|
||||||
|
) -> ScoreResult:
|
||||||
|
"""Evaluate a stage 5 technique page against source moments."""
|
||||||
|
return self.score_stage_output(
|
||||||
|
stage=5,
|
||||||
|
output_json=page_json,
|
||||||
|
input_json=moments,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._parse_scores(parsed, elapsed)
|
||||||
|
|
||||||
|
def _parse_scores(self, parsed: dict, elapsed: float, dimensions: list[str] | None = None) -> ScoreResult:
|
||||||
|
"""Extract and validate scores from parsed JSON response."""
|
||||||
|
dims = dimensions or DIMENSIONS
|
||||||
|
scores: dict[str, float] = {}
|
||||||
|
justifications: dict[str, str] = {}
|
||||||
|
|
||||||
|
raw_justifications = parsed.get("justifications", {})
|
||||||
|
if not isinstance(raw_justifications, dict):
|
||||||
|
raw_justifications = {}
|
||||||
|
|
||||||
|
for dim in dims:
|
||||||
|
raw = parsed.get(dim)
|
||||||
|
if raw is None:
|
||||||
|
logger.warning("Missing dimension '%s' in judge response", dim)
|
||||||
|
scores[dim] = 0.0
|
||||||
|
justifications[dim] = "(missing from judge response)"
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
val = float(raw)
|
||||||
|
scores[dim] = max(0.0, min(1.0, val)) # clamp
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
logger.warning("Invalid value for '%s': %r", dim, raw)
|
||||||
|
scores[dim] = 0.0
|
||||||
|
justifications[dim] = f"(invalid value: {raw!r})"
|
||||||
|
continue
|
||||||
|
|
||||||
|
justifications[dim] = str(raw_justifications.get(dim, ""))
|
||||||
|
|
||||||
|
composite = sum(scores.values()) / len(dims) if dims else 0.0
|
||||||
|
|
||||||
|
return ScoreResult(
|
||||||
|
scores=scores,
|
||||||
|
composite=round(composite, 3),
|
||||||
|
justifications=justifications,
|
||||||
|
elapsed_seconds=elapsed,
|
||||||
|
)
|
||||||
|
|
||||||
|
def synthesize_and_score(
|
||||||
|
self,
|
||||||
|
moments: list[dict],
|
||||||
|
creator_name: str,
|
||||||
|
voice_level: float,
|
||||||
|
) -> ScoreResult:
|
||||||
|
"""Re-synthesize from source moments with a voice-dialed prompt, then score.
|
||||||
|
|
||||||
|
Loads the stage 5 synthesis prompt from disk, applies the VoiceDial
|
||||||
|
modifier at the given voice_level, calls the LLM to produce a
|
||||||
|
SynthesisResult, then scores the first page.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
moments:
|
||||||
|
Source key moments (dicts with summary, transcript_excerpt, etc.)
|
||||||
|
creator_name:
|
||||||
|
Creator name to inject into the synthesis prompt.
|
||||||
|
voice_level:
|
||||||
|
Float 0.0–1.0 controlling voice preservation intensity.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
ScoreResult with per-dimension scores after voice-dialed re-synthesis.
|
||||||
|
"""
|
||||||
|
from pipeline.schemas import SynthesisResult
|
||||||
|
from pipeline.stages import _get_stage_config, _load_prompt
|
||||||
|
|
||||||
|
# Load and modify the stage 5 system prompt
|
||||||
|
try:
|
||||||
|
base_prompt = _load_prompt("stage5_synthesis.txt")
|
||||||
|
except FileNotFoundError as exc:
|
||||||
|
return ScoreResult(error=f"Prompt file not found: {exc}")
|
||||||
|
|
||||||
|
dial = VoiceDial(base_prompt)
|
||||||
|
modified_prompt = dial.modify(voice_level)
|
||||||
|
band = dial.band_name(voice_level)
|
||||||
|
|
||||||
|
# Build user prompt in the same format as _synthesize_chunk
|
||||||
|
moments_json = json.dumps(moments, indent=2)
|
||||||
|
user_prompt = f"<creator>{creator_name}</creator>\n<moments>\n{moments_json}\n</moments>"
|
||||||
|
|
||||||
|
model_override, modality = _get_stage_config(5)
|
||||||
|
|
||||||
|
print(f" Re-synthesizing at voice_level={voice_level} (band={band})...")
|
||||||
|
|
||||||
|
t0 = time.monotonic()
|
||||||
|
try:
|
||||||
|
raw = self.client.complete(
|
||||||
|
system_prompt=modified_prompt,
|
||||||
|
user_prompt=user_prompt,
|
||||||
|
response_model=SynthesisResult,
|
||||||
|
modality=modality,
|
||||||
|
model_override=model_override,
|
||||||
|
)
|
||||||
|
elapsed_synth = round(time.monotonic() - t0, 2)
|
||||||
|
except (openai.APIConnectionError, openai.APITimeoutError) as exc:
|
||||||
|
elapsed_synth = round(time.monotonic() - t0, 2)
|
||||||
|
url = self.client.settings.llm_api_url
|
||||||
|
fallback = self.client.settings.llm_fallback_url
|
||||||
|
return ScoreResult(
|
||||||
|
elapsed_seconds=elapsed_synth,
|
||||||
|
error=(
|
||||||
|
f"Cannot reach LLM endpoint at {url} (fallback {fallback}). "
|
||||||
|
f"Error: {exc}"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse synthesis response
|
||||||
|
raw_text = str(raw).strip()
|
||||||
|
try:
|
||||||
|
synthesis = self.client.parse_response(raw_text, SynthesisResult)
|
||||||
|
except (json.JSONDecodeError, ValueError, Exception) as exc:
|
||||||
|
logger.error("Malformed synthesis response: %.300s", raw_text)
|
||||||
|
return ScoreResult(
|
||||||
|
elapsed_seconds=elapsed_synth,
|
||||||
|
error=f"Malformed synthesis response: {exc}. Raw excerpt: {raw_text[:200]}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not synthesis.pages:
|
||||||
|
return ScoreResult(
|
||||||
|
elapsed_seconds=elapsed_synth,
|
||||||
|
error="Synthesis returned no pages.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Score the first page
|
||||||
|
page = synthesis.pages[0]
|
||||||
|
page_json = {
|
||||||
|
"title": page.title,
|
||||||
|
"creator_name": creator_name,
|
||||||
|
"summary": page.summary,
|
||||||
|
"body_sections": [
|
||||||
|
{"heading": heading, "content": content}
|
||||||
|
for heading, content in page.body_sections.items()
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f" Synthesis complete ({elapsed_synth}s). Scoring...")
|
||||||
|
result = self.score_page(page_json, moments)
|
||||||
|
# Include synthesis time in total
|
||||||
|
result.elapsed_seconds = round(result.elapsed_seconds + elapsed_synth, 2)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def print_report(self, result: ScoreResult, stage: int = 5) -> None:
|
||||||
|
"""Print a formatted scoring report to stdout."""
|
||||||
|
dims = STAGE_CONFIGS[stage].dimensions if stage in STAGE_CONFIGS else list(result.scores.keys())
|
||||||
|
stage_label = f"STAGE {stage}" if stage in STAGE_CONFIGS else "QUALITY"
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print(f" {stage_label} QUALITY SCORE REPORT")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
if result.error:
|
||||||
|
print(f"\n ✗ Error: {result.error}\n")
|
||||||
|
print("=" * 60 + "\n")
|
||||||
|
return
|
||||||
|
|
||||||
|
for dim in dims:
|
||||||
|
score = result.scores.get(dim, 0.0)
|
||||||
|
bar = self._score_bar(score)
|
||||||
|
justification = result.justifications.get(dim, "")
|
||||||
|
print(f"\n {dim.replace('_', ' ').title()}")
|
||||||
|
print(f" Score: {score:.2f} {bar}")
|
||||||
|
if justification:
|
||||||
|
# Wrap justification at ~60 chars
|
||||||
|
for line in self._wrap(justification, 56):
|
||||||
|
print(f" {line}")
|
||||||
|
|
||||||
|
print("\n" + "-" * 60)
|
||||||
|
print(f" Composite: {result.composite:.3f}")
|
||||||
|
print(f" Time: {result.elapsed_seconds}s")
|
||||||
|
print("=" * 60 + "\n")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _score_bar(score: float, width: int = 20) -> str:
|
||||||
|
"""Render a visual bar for a 0-1 score."""
|
||||||
|
filled = int(score * width)
|
||||||
|
return "█" * filled + "░" * (width - filled)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _wrap(text: str, width: int) -> list[str]:
|
||||||
|
"""Simple word wrap."""
|
||||||
|
words = text.split()
|
||||||
|
lines: list[str] = []
|
||||||
|
current = ""
|
||||||
|
for word in words:
|
||||||
|
if current and len(current) + len(word) + 1 > width:
|
||||||
|
lines.append(current)
|
||||||
|
current = word
|
||||||
|
else:
|
||||||
|
current = f"{current} {word}" if current else word
|
||||||
|
if current:
|
||||||
|
lines.append(current)
|
||||||
|
return lines
|
||||||
247
backend/pipeline/quality/variant_generator.py
Normal file
247
backend/pipeline/quality/variant_generator.py
Normal file
|
|
@ -0,0 +1,247 @@
|
||||||
|
"""LLM-powered prompt variant generator for automated optimization.
|
||||||
|
|
||||||
|
Uses a meta-prompt to instruct the LLM to act as a prompt engineer,
|
||||||
|
analyzing per-dimension scores and producing targeted prompt mutations
|
||||||
|
that improve the weakest scoring dimensions while preserving the JSON
|
||||||
|
output format required by downstream parsing.
|
||||||
|
|
||||||
|
Supports any pipeline stage (2-5) — callers pass the stage's dimensions
|
||||||
|
and format markers so the meta-prompt and validation adapt automatically.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Sequence
|
||||||
|
|
||||||
|
from pipeline.llm_client import LLMClient
|
||||||
|
from pipeline.quality.scorer import DIMENSIONS, STAGE_CONFIGS, ScoreResult
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Meta-prompt for variant generation ────────────────────────────────────────
|
||||||
|
|
||||||
|
VARIANT_META_PROMPT = """\
|
||||||
|
You are an expert prompt engineer specializing in LLM-powered content processing pipelines.
|
||||||
|
|
||||||
|
Your task: given a pipeline stage prompt and its quality evaluation scores, produce an
|
||||||
|
improved variant of the prompt that targets the weakest-scoring dimensions while
|
||||||
|
maintaining or improving the others.
|
||||||
|
|
||||||
|
## Scoring Dimensions (each 0.0–1.0)
|
||||||
|
|
||||||
|
{dimension_descriptions}
|
||||||
|
|
||||||
|
## Rules
|
||||||
|
|
||||||
|
1. Focus your changes on the weakest 1-2 dimensions. Don't dilute the prompt by trying to fix everything.
|
||||||
|
2. Add specific, actionable instructions — not vague encouragements.
|
||||||
|
3. **CRITICAL: You MUST preserve the JSON output format section of the prompt EXACTLY as-is.**
|
||||||
|
The prompt contains instructions about outputting a JSON object with a specific schema.
|
||||||
|
Do NOT modify, remove, or rephrase any part of the JSON format instructions.
|
||||||
|
Your changes should target the processing/analysis guidelines only.
|
||||||
|
4. Keep the overall prompt length within 2x of the original. Don't bloat it.
|
||||||
|
5. Make substantive changes — rewording a sentence or adding one adjective is not enough.
|
||||||
|
|
||||||
|
## Output
|
||||||
|
|
||||||
|
Return ONLY the full modified prompt text. No explanation, no markdown fences, no preamble.
|
||||||
|
Just the complete prompt that could be used directly as a system prompt.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Dimension descriptions per stage, used to fill the meta-prompt template.
|
||||||
|
_DIMENSION_DESCRIPTIONS: dict[int, str] = {
|
||||||
|
2: (
|
||||||
|
"- **coverage_completeness** — All transcript content accounted for, no gaps or overlaps\n"
|
||||||
|
"- **topic_specificity** — Topic labels are descriptive and useful, not generic\n"
|
||||||
|
"- **boundary_accuracy** — Segment boundaries align with actual topic transitions\n"
|
||||||
|
"- **summary_quality** — Summaries accurately describe segment content"
|
||||||
|
),
|
||||||
|
3: (
|
||||||
|
"- **moment_richness** — Extracted moments capture substantial, distinct insights\n"
|
||||||
|
"- **timestamp_accuracy** — Time ranges are plausible and well-bounded\n"
|
||||||
|
"- **content_type_correctness** — Content types match the actual moment content\n"
|
||||||
|
"- **summary_actionability** — Summaries provide actionable, specific information\n"
|
||||||
|
"- **plugin_normalization** — Plugin/tool names are correctly identified and normalized"
|
||||||
|
),
|
||||||
|
4: (
|
||||||
|
"- **category_accuracy** — Topic categories are appropriate and meaningful\n"
|
||||||
|
"- **tag_completeness** — All relevant tags are captured\n"
|
||||||
|
"- **tag_specificity** — Tags are specific enough to be useful for search/filtering\n"
|
||||||
|
"- **coverage** — All moments are classified"
|
||||||
|
),
|
||||||
|
5: (
|
||||||
|
"- **structural** — Section naming, count (3-6), paragraph depth (2-5 per section)\n"
|
||||||
|
"- **content_specificity** — Concrete details: frequencies, time values, ratios, plugin names, dB values\n"
|
||||||
|
"- **voice_preservation** — Direct quotes preserved, opinions attributed to creator by name, personality retained\n"
|
||||||
|
"- **readability** — Cohesive article flow, related info merged, no redundancy or contradiction\n"
|
||||||
|
"- **factual_fidelity** — Every claim traceable to source material, no hallucinated specifics"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Legacy default format markers for stage 5
|
||||||
|
_FORMAT_MARKERS = ["SynthesisResult", '"pages"', "body_sections", "title", "summary"]
|
||||||
|
|
||||||
|
|
||||||
|
class PromptVariantGenerator:
|
||||||
|
"""Generates prompt variants by asking an LLM to act as a prompt engineer.
|
||||||
|
|
||||||
|
Given a base prompt and its evaluation scores, produces N mutated
|
||||||
|
variants targeting the weakest dimensions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, client: LLMClient) -> None:
|
||||||
|
self.client = client
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
base_prompt: str,
|
||||||
|
scores: ScoreResult,
|
||||||
|
n: int = 2,
|
||||||
|
*,
|
||||||
|
format_markers: Sequence[str] | None = None,
|
||||||
|
stage: int = 5,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Generate up to *n* valid prompt variants.
|
||||||
|
|
||||||
|
Each variant is produced by a separate LLM call with the meta-prompt.
|
||||||
|
Variants are validated: they must differ from the base by ≥50 characters
|
||||||
|
and must contain the JSON format instruction markers found in the base.
|
||||||
|
|
||||||
|
Invalid variants are logged and skipped.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
base_prompt:
|
||||||
|
The current best prompt text for the target stage.
|
||||||
|
scores:
|
||||||
|
ScoreResult from the most recent evaluation of *base_prompt*.
|
||||||
|
n:
|
||||||
|
Number of variants to attempt generating.
|
||||||
|
format_markers:
|
||||||
|
Override format markers for validation. When *None*, uses the
|
||||||
|
markers from ``STAGE_CONFIGS[stage]`` (falling back to stage 5
|
||||||
|
defaults for backward compat).
|
||||||
|
stage:
|
||||||
|
Pipeline stage number (2-5), used to select dimension
|
||||||
|
descriptions for the meta-prompt and default format markers.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list[str]
|
||||||
|
Valid variant prompt strings (may be fewer than *n*).
|
||||||
|
"""
|
||||||
|
# Resolve format markers and dimensions for the target stage
|
||||||
|
if format_markers is not None:
|
||||||
|
markers = list(format_markers)
|
||||||
|
elif stage in STAGE_CONFIGS:
|
||||||
|
markers = STAGE_CONFIGS[stage].format_markers
|
||||||
|
else:
|
||||||
|
markers = _FORMAT_MARKERS
|
||||||
|
|
||||||
|
dimensions = STAGE_CONFIGS[stage].dimensions if stage in STAGE_CONFIGS else DIMENSIONS
|
||||||
|
|
||||||
|
# Build the system prompt with stage-appropriate dimension descriptions
|
||||||
|
dim_desc = _DIMENSION_DESCRIPTIONS.get(stage, _DIMENSION_DESCRIPTIONS[5])
|
||||||
|
system_prompt = VARIANT_META_PROMPT.format(dimension_descriptions=dim_desc)
|
||||||
|
|
||||||
|
user_prompt = self._build_user_prompt(base_prompt, scores, dimensions)
|
||||||
|
# Identify which format markers are actually present in the base
|
||||||
|
required_markers = [m for m in markers if m in base_prompt]
|
||||||
|
|
||||||
|
variants: list[str] = []
|
||||||
|
for i in range(n):
|
||||||
|
logger.info("Generating variant %d/%d (stage %d)...", i + 1, n, stage)
|
||||||
|
try:
|
||||||
|
raw = self.client.complete(
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
user_prompt=user_prompt,
|
||||||
|
response_model=None, # free-form text, not JSON
|
||||||
|
modality="chat",
|
||||||
|
)
|
||||||
|
variant = str(raw).strip()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("LLM error generating variant %d/%d", i + 1, n)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Validate the variant
|
||||||
|
if not self._validate(variant, base_prompt, required_markers, i + 1):
|
||||||
|
continue
|
||||||
|
|
||||||
|
variants.append(variant)
|
||||||
|
logger.info("Variant %d/%d accepted (%d chars)", i + 1, n, len(variant))
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Generated %d valid variant(s) out of %d attempts", len(variants), n
|
||||||
|
)
|
||||||
|
return variants
|
||||||
|
|
||||||
|
# ── Internal helpers ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _build_user_prompt(self, base_prompt: str, scores: ScoreResult, dimensions: list[str] | None = None) -> str:
|
||||||
|
"""Build the user message describing the current prompt and its scores."""
|
||||||
|
dims = dimensions or DIMENSIONS
|
||||||
|
# Build per-dimension score lines, sorted worst-first
|
||||||
|
dim_lines: list[str] = []
|
||||||
|
dim_scores = [(d, scores.scores.get(d, 0.0)) for d in dims]
|
||||||
|
dim_scores.sort(key=lambda x: x[1])
|
||||||
|
|
||||||
|
for dim, val in dim_scores:
|
||||||
|
justification = scores.justifications.get(dim, "")
|
||||||
|
label = dim.replace("_", " ").title()
|
||||||
|
line = f" {label}: {val:.2f}"
|
||||||
|
if justification:
|
||||||
|
line += f" — {justification}"
|
||||||
|
dim_lines.append(line)
|
||||||
|
|
||||||
|
weakest = dim_scores[0][0].replace("_", " ").title()
|
||||||
|
second_weakest = dim_scores[1][0].replace("_", " ").title() if len(dim_scores) > 1 else weakest
|
||||||
|
|
||||||
|
return (
|
||||||
|
f"## Current Prompt\n\n{base_prompt}\n\n"
|
||||||
|
f"## Evaluation Scores (sorted weakest → strongest)\n\n"
|
||||||
|
+ "\n".join(dim_lines)
|
||||||
|
+ f"\n\n Composite: {scores.composite:.3f}\n\n"
|
||||||
|
f"## Priority\n\n"
|
||||||
|
f"The weakest dimensions are **{weakest}** and **{second_weakest}**. "
|
||||||
|
f"Focus your prompt modifications on improving these.\n\n"
|
||||||
|
f"Return the full modified prompt now."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _validate(
|
||||||
|
self,
|
||||||
|
variant: str,
|
||||||
|
base_prompt: str,
|
||||||
|
required_markers: list[str],
|
||||||
|
index: int,
|
||||||
|
) -> bool:
|
||||||
|
"""Check a variant meets minimum quality gates."""
|
||||||
|
if not variant:
|
||||||
|
logger.warning("Variant %d is empty — skipping", index)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Must differ meaningfully from base
|
||||||
|
diff = abs(len(variant) - len(base_prompt))
|
||||||
|
# Also check actual content difference via set-symmetric-difference of lines
|
||||||
|
base_lines = set(base_prompt.splitlines())
|
||||||
|
variant_lines = set(variant.splitlines())
|
||||||
|
changed_lines = len(base_lines.symmetric_difference(variant_lines))
|
||||||
|
|
||||||
|
if diff < 50 and changed_lines < 3:
|
||||||
|
logger.warning(
|
||||||
|
"Variant %d too similar to base (len_diff=%d, changed_lines=%d) — skipping",
|
||||||
|
index, diff, changed_lines,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Must preserve format markers
|
||||||
|
missing = [m for m in required_markers if m not in variant]
|
||||||
|
if missing:
|
||||||
|
logger.warning(
|
||||||
|
"Variant %d missing format markers %s — skipping",
|
||||||
|
index, missing,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
91
backend/pipeline/quality/voice_dial.py
Normal file
91
backend/pipeline/quality/voice_dial.py
Normal file
|
|
@ -0,0 +1,91 @@
|
||||||
|
"""Voice preservation dial — modifies Stage 5 synthesis prompt by intensity band.
|
||||||
|
|
||||||
|
Three bands control how much of the creator's original voice is preserved:
|
||||||
|
- Low (0.0–0.33): Clinical, encyclopedic tone — suppress direct quotes
|
||||||
|
- Mid (0.34–0.66): Base prompt unchanged (already ~0.6 voice preservation)
|
||||||
|
- High (0.67–1.0): Maximum voice — prioritize exact words, strong opinions
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
|
# ── Band modifier text ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_LOW_BAND_MODIFIER = """
|
||||||
|
|
||||||
|
## Voice Suppression Override
|
||||||
|
|
||||||
|
IMPORTANT — override the voice/tone guidelines above. For this synthesis:
|
||||||
|
|
||||||
|
- Do NOT include any direct quotes from the creator. Rephrase all insights in neutral third-person encyclopedic style.
|
||||||
|
- Do NOT attribute opinions or preferences to the creator by name (avoid "he recommends", "she prefers").
|
||||||
|
- Remove all personality markers, humor, strong opinions, and conversational tone.
|
||||||
|
- Write as a reference manual: factual, impersonal, technically precise.
|
||||||
|
- Replace phrases like "he warns against" with neutral statements like "this approach is generally avoided because."
|
||||||
|
- Suppress colloquialisms and informal language entirely.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_HIGH_BAND_MODIFIER = """
|
||||||
|
|
||||||
|
## Maximum Voice Preservation Override
|
||||||
|
|
||||||
|
IMPORTANT — amplify the voice/tone guidelines above. For this synthesis:
|
||||||
|
|
||||||
|
- Maximize the use of direct quotes from the transcript. Every memorable phrase, vivid metaphor, or strong opinion should be quoted verbatim with quotation marks.
|
||||||
|
- Attribute all insights, preferences, and techniques to the creator by name — use their name frequently.
|
||||||
|
- Preserve personality, humor, strong opinions, and conversational tone. If the creator is emphatic, the prose should feel emphatic.
|
||||||
|
- Prioritize the creator's exact words over paraphrase. When a transcript excerpt contains a usable phrase, quote it rather than summarizing it.
|
||||||
|
- Include warnings, caveats, and opinionated asides in the creator's own voice.
|
||||||
|
- The resulting page should feel like the creator is speaking directly to the reader through the text.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# ── VoiceDial class ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class VoiceDial:
|
||||||
|
"""Modifies a Stage 5 synthesis prompt based on a voice_level parameter.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
base_prompt:
|
||||||
|
The original stage5_synthesis.txt system prompt content.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Band boundaries
|
||||||
|
LOW_UPPER = 0.33
|
||||||
|
HIGH_LOWER = 0.67
|
||||||
|
|
||||||
|
def __init__(self, base_prompt: str) -> None:
|
||||||
|
self.base_prompt = base_prompt
|
||||||
|
|
||||||
|
def modify(self, voice_level: float) -> str:
|
||||||
|
"""Return the system prompt modified for the given voice_level.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
voice_level:
|
||||||
|
Float 0.0–1.0. Values outside this range are clamped.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
str
|
||||||
|
Modified system prompt with band-appropriate instructions appended.
|
||||||
|
"""
|
||||||
|
voice_level = max(0.0, min(1.0, voice_level))
|
||||||
|
|
||||||
|
if voice_level <= self.LOW_UPPER:
|
||||||
|
return self.base_prompt + _LOW_BAND_MODIFIER
|
||||||
|
elif voice_level >= self.HIGH_LOWER:
|
||||||
|
return self.base_prompt + _HIGH_BAND_MODIFIER
|
||||||
|
else:
|
||||||
|
# Mid band — base prompt is already moderate voice preservation
|
||||||
|
return self.base_prompt
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def band_name(voice_level: float) -> str:
|
||||||
|
"""Return the human-readable band name for a voice_level value."""
|
||||||
|
voice_level = max(0.0, min(1.0, voice_level))
|
||||||
|
if voice_level <= VoiceDial.LOW_UPPER:
|
||||||
|
return "low"
|
||||||
|
elif voice_level >= VoiceDial.HIGH_LOWER:
|
||||||
|
return "high"
|
||||||
|
return "mid"
|
||||||
125
backend/pipeline/schemas.py
Normal file
125
backend/pipeline/schemas.py
Normal file
|
|
@ -0,0 +1,125 @@
|
||||||
|
"""Pydantic schemas for pipeline stage inputs and outputs.
|
||||||
|
|
||||||
|
Stage 2 — Segmentation: groups transcript segments by topic.
|
||||||
|
Stage 3 — Extraction: extracts key moments from segments.
|
||||||
|
Stage 4 — Classification: classifies moments by category/tags.
|
||||||
|
Stage 5 — Synthesis: generates technique pages from classified moments.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
# ── Stage 2: Segmentation ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TopicSegment(BaseModel):
|
||||||
|
"""A contiguous group of transcript segments sharing a topic."""
|
||||||
|
|
||||||
|
start_index: int = Field(description="First transcript segment index in this group")
|
||||||
|
end_index: int = Field(description="Last transcript segment index in this group (inclusive)")
|
||||||
|
topic_label: str = Field(description="Short label describing the topic")
|
||||||
|
summary: str = Field(description="Brief summary of what is discussed")
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentationResult(BaseModel):
|
||||||
|
"""Full output of stage 2 (segmentation)."""
|
||||||
|
|
||||||
|
segments: list[TopicSegment]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Stage 3: Extraction ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class ExtractedMoment(BaseModel):
|
||||||
|
"""A single key moment extracted from a topic segment group."""
|
||||||
|
|
||||||
|
title: str = Field(description="Concise title for the moment")
|
||||||
|
summary: str = Field(description="Detailed summary of the technique/concept")
|
||||||
|
start_time: float = Field(description="Start time in seconds")
|
||||||
|
end_time: float = Field(description="End time in seconds")
|
||||||
|
content_type: str = Field(description="One of: technique, settings, reasoning, workflow")
|
||||||
|
plugins: list[str] = Field(default_factory=list, description="Plugins/tools mentioned")
|
||||||
|
raw_transcript: str = Field(default="", description="Raw transcript text for this moment")
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractionResult(BaseModel):
|
||||||
|
"""Full output of stage 3 (extraction)."""
|
||||||
|
|
||||||
|
moments: list[ExtractedMoment]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Stage 4: Classification ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class ClassifiedMoment(BaseModel):
|
||||||
|
"""Classification metadata for a single extracted moment."""
|
||||||
|
|
||||||
|
moment_index: int = Field(description="Index into ExtractionResult.moments")
|
||||||
|
topic_category: str = Field(description="High-level topic category")
|
||||||
|
topic_tags: list[str] = Field(default_factory=list, description="Specific topic tags")
|
||||||
|
content_type_override: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Override for content_type if classification disagrees with extraction",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationResult(BaseModel):
|
||||||
|
"""Full output of stage 4 (classification)."""
|
||||||
|
|
||||||
|
classifications: list[ClassifiedMoment]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Stage 5: Synthesis ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class BodySubSection(BaseModel):
|
||||||
|
"""An H3-level subsection within a body section."""
|
||||||
|
|
||||||
|
heading: str = Field(description="H3 subsection heading")
|
||||||
|
content: str = Field(description="Subsection body text (may contain [N] citation markers)")
|
||||||
|
|
||||||
|
|
||||||
|
class BodySection(BaseModel):
|
||||||
|
"""An H2-level section of a technique page body."""
|
||||||
|
|
||||||
|
heading: str = Field(description="H2 section heading")
|
||||||
|
content: str = Field(description="Section body text (may contain [N] citation markers)")
|
||||||
|
subsections: list[BodySubSection] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="Optional H3-level subsections",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SynthesizedPage(BaseModel):
|
||||||
|
"""A technique page synthesized from classified moments."""
|
||||||
|
|
||||||
|
title: str = Field(description="Page title")
|
||||||
|
slug: str = Field(description="URL-safe slug")
|
||||||
|
topic_category: str = Field(description="Primary topic category")
|
||||||
|
topic_tags: list[str] = Field(default_factory=list, description="Associated tags")
|
||||||
|
summary: str = Field(description="Page summary / overview paragraph")
|
||||||
|
body_sections: list[BodySection] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="Structured body content as H2 sections with optional H3 subsections",
|
||||||
|
)
|
||||||
|
body_sections_format: str = Field(
|
||||||
|
default="v2",
|
||||||
|
description="Schema version for body_sections ('v2' = list[BodySection])",
|
||||||
|
)
|
||||||
|
signal_chains: list[dict] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="Signal chain descriptions (for audio/music production contexts)",
|
||||||
|
)
|
||||||
|
plugins: list[str] = Field(default_factory=list, description="Plugins/tools referenced")
|
||||||
|
source_quality: str = Field(
|
||||||
|
default="mixed",
|
||||||
|
description="One of: structured, mixed, unstructured",
|
||||||
|
)
|
||||||
|
moment_indices: list[int] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="Indices of source moments (from the input list) that this page covers",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SynthesisResult(BaseModel):
|
||||||
|
"""Full output of stage 5 (synthesis)."""
|
||||||
|
|
||||||
|
pages: list[SynthesizedPage]
|
||||||
222
backend/pipeline/shorts_generator.py
Normal file
222
backend/pipeline/shorts_generator.py
Normal file
|
|
@ -0,0 +1,222 @@
|
||||||
|
"""FFmpeg clip extraction with format presets for shorts generation.
|
||||||
|
|
||||||
|
Pure functions — no DB access, no Celery dependency. Tested independently.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import subprocess
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from models import FormatPreset
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
FFMPEG_TIMEOUT_SECS = 300
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class PresetSpec:
|
||||||
|
"""Resolution and ffmpeg video filter for a format preset."""
|
||||||
|
width: int
|
||||||
|
height: int
|
||||||
|
vf_filter: str
|
||||||
|
|
||||||
|
|
||||||
|
PRESETS: dict[FormatPreset, PresetSpec] = {
|
||||||
|
FormatPreset.vertical: PresetSpec(
|
||||||
|
width=1080,
|
||||||
|
height=1920,
|
||||||
|
vf_filter="scale=1080:-2,pad=1080:1920:(ow-iw)/2:(oh-ih)/2:black",
|
||||||
|
),
|
||||||
|
FormatPreset.square: PresetSpec(
|
||||||
|
width=1080,
|
||||||
|
height=1080,
|
||||||
|
vf_filter="crop=min(iw\\,ih):min(iw\\,ih),scale=1080:1080",
|
||||||
|
),
|
||||||
|
FormatPreset.horizontal: PresetSpec(
|
||||||
|
width=1920,
|
||||||
|
height=1080,
|
||||||
|
vf_filter="scale=1920:1080:force_original_aspect_ratio=decrease,pad=1920:1080:(ow-iw)/2:(oh-ih)/2:black",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_video_path(video_source_root: str, file_path: str) -> Path:
|
||||||
|
"""Join root + relative path and validate the file exists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_source_root: Base directory for video files (e.g. /videos).
|
||||||
|
file_path: Relative path stored in SourceVideo.file_path.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Resolved absolute Path.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If the resolved path doesn't exist or isn't a file.
|
||||||
|
"""
|
||||||
|
resolved = Path(video_source_root) / file_path
|
||||||
|
if not resolved.is_file():
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Video file not found: {resolved} "
|
||||||
|
f"(root={video_source_root!r}, relative={file_path!r})"
|
||||||
|
)
|
||||||
|
return resolved
|
||||||
|
|
||||||
|
|
||||||
|
def extract_clip(
|
||||||
|
input_path: Path | str,
|
||||||
|
output_path: Path | str,
|
||||||
|
start_secs: float,
|
||||||
|
end_secs: float,
|
||||||
|
vf_filter: str,
|
||||||
|
ass_path: Path | str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Extract a clip from a video file using ffmpeg.
|
||||||
|
|
||||||
|
Seeks to *start_secs*, encodes until *end_secs*, and applies *vf_filter*.
|
||||||
|
Uses ``-c:v libx264 -preset fast -crf 23`` for reasonable quality/speed.
|
||||||
|
|
||||||
|
When *ass_path* is provided, the ASS subtitle filter is appended to the
|
||||||
|
video filter chain so that captions are burned into the output video.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_path: Source video file.
|
||||||
|
output_path: Destination mp4 file (parent dir must exist).
|
||||||
|
start_secs: Start time in seconds.
|
||||||
|
end_secs: End time in seconds.
|
||||||
|
vf_filter: ffmpeg ``-vf`` filter string.
|
||||||
|
ass_path: Optional path to an ASS subtitle file. When provided,
|
||||||
|
``ass=<path>`` is appended to the filter chain.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
subprocess.CalledProcessError: If ffmpeg exits non-zero.
|
||||||
|
subprocess.TimeoutExpired: If ffmpeg exceeds the timeout.
|
||||||
|
ValueError: If start >= end.
|
||||||
|
"""
|
||||||
|
duration = end_secs - start_secs
|
||||||
|
if duration <= 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid clip range: start={start_secs}s end={end_secs}s "
|
||||||
|
f"(duration={duration}s)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build the video filter chain — ASS burn-in comes after scale/pad
|
||||||
|
effective_vf = vf_filter
|
||||||
|
if ass_path is not None:
|
||||||
|
# Escape colons and backslashes in the path for ffmpeg filter syntax
|
||||||
|
escaped = str(ass_path).replace("\\", "\\\\").replace(":", "\\:")
|
||||||
|
effective_vf = f"{vf_filter},ass={escaped}"
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
"ffmpeg",
|
||||||
|
"-y", # overwrite output
|
||||||
|
"-ss", str(start_secs), # seek before input (fast)
|
||||||
|
"-i", str(input_path),
|
||||||
|
"-t", str(duration),
|
||||||
|
"-vf", effective_vf,
|
||||||
|
"-c:v", "libx264",
|
||||||
|
"-preset", "fast",
|
||||||
|
"-crf", "23",
|
||||||
|
"-c:a", "aac",
|
||||||
|
"-b:a", "128k",
|
||||||
|
"-movflags", "+faststart", # web-friendly mp4
|
||||||
|
str(output_path),
|
||||||
|
]
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"ffmpeg: extracting %.1fs clip from %s → %s",
|
||||||
|
duration, input_path, output_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
capture_output=True,
|
||||||
|
timeout=FFMPEG_TIMEOUT_SECS,
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
stderr_text = result.stderr.decode("utf-8", errors="replace")[-2000:]
|
||||||
|
logger.error("ffmpeg failed (rc=%d): %s", result.returncode, stderr_text)
|
||||||
|
raise subprocess.CalledProcessError(
|
||||||
|
result.returncode, cmd, output=result.stdout, stderr=result.stderr,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_clip_with_template(
|
||||||
|
input_path: Path | str,
|
||||||
|
output_path: Path | str,
|
||||||
|
start_secs: float,
|
||||||
|
end_secs: float,
|
||||||
|
vf_filter: str,
|
||||||
|
ass_path: Path | str | None = None,
|
||||||
|
intro_path: Path | str | None = None,
|
||||||
|
outro_path: Path | str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Extract a clip and optionally prepend/append intro/outro cards.
|
||||||
|
|
||||||
|
If neither intro nor outro is provided, delegates directly to
|
||||||
|
:func:`extract_clip`. When cards are provided, the main clip is
|
||||||
|
extracted to a temp file, then all segments are concatenated via
|
||||||
|
:func:`~pipeline.card_renderer.concat_segments`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_path: Source video file.
|
||||||
|
output_path: Final destination mp4 file.
|
||||||
|
start_secs: Start time in seconds.
|
||||||
|
end_secs: End time in seconds.
|
||||||
|
vf_filter: ffmpeg ``-vf`` filter string.
|
||||||
|
ass_path: Optional ASS subtitle file path.
|
||||||
|
intro_path: Optional intro card mp4 path.
|
||||||
|
outro_path: Optional outro card mp4 path.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
subprocess.CalledProcessError: If any ffmpeg command fails.
|
||||||
|
ValueError: If clip range is invalid.
|
||||||
|
"""
|
||||||
|
has_cards = intro_path is not None or outro_path is not None
|
||||||
|
|
||||||
|
if not has_cards:
|
||||||
|
# No template cards — simple extraction
|
||||||
|
extract_clip(
|
||||||
|
input_path=input_path,
|
||||||
|
output_path=output_path,
|
||||||
|
start_secs=start_secs,
|
||||||
|
end_secs=end_secs,
|
||||||
|
vf_filter=vf_filter,
|
||||||
|
ass_path=ass_path,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Extract main clip to a temp file for concatenation
|
||||||
|
main_clip_path = Path(str(output_path) + ".main.mp4")
|
||||||
|
try:
|
||||||
|
extract_clip(
|
||||||
|
input_path=input_path,
|
||||||
|
output_path=main_clip_path,
|
||||||
|
start_secs=start_secs,
|
||||||
|
end_secs=end_secs,
|
||||||
|
vf_filter=vf_filter,
|
||||||
|
ass_path=ass_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build segment list in order: intro → main → outro
|
||||||
|
segments: list[Path] = []
|
||||||
|
if intro_path is not None:
|
||||||
|
segments.append(Path(intro_path))
|
||||||
|
segments.append(main_clip_path)
|
||||||
|
if outro_path is not None:
|
||||||
|
segments.append(Path(outro_path))
|
||||||
|
|
||||||
|
from pipeline.card_renderer import concat_segments
|
||||||
|
concat_segments(segments=segments, output_path=Path(output_path))
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Clean up temp main clip
|
||||||
|
if main_clip_path.exists():
|
||||||
|
try:
|
||||||
|
main_clip_path.unlink()
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
3194
backend/pipeline/stages.py
Normal file
3194
backend/pipeline/stages.py
Normal file
File diff suppressed because it is too large
Load diff
159
backend/pipeline/test_caption_generator.py
Normal file
159
backend/pipeline/test_caption_generator.py
Normal file
|
|
@ -0,0 +1,159 @@
|
||||||
|
"""Unit tests for caption_generator module."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from pipeline.caption_generator import (
|
||||||
|
DEFAULT_STYLE,
|
||||||
|
_format_ass_time,
|
||||||
|
generate_ass_captions,
|
||||||
|
write_ass_file,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fixtures ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_word_timings() -> list[dict]:
|
||||||
|
"""Realistic word timings as produced by extract_word_timings."""
|
||||||
|
return [
|
||||||
|
{"word": "This", "start": 10.0, "end": 10.3},
|
||||||
|
{"word": "is", "start": 10.3, "end": 10.5},
|
||||||
|
{"word": "a", "start": 10.5, "end": 10.6},
|
||||||
|
{"word": "test", "start": 10.6, "end": 11.0},
|
||||||
|
{"word": "sentence", "start": 11.1, "end": 11.6},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Time formatting ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestFormatAssTime:
|
||||||
|
def test_zero(self):
|
||||||
|
assert _format_ass_time(0.0) == "0:00:00.00"
|
||||||
|
|
||||||
|
def test_sub_second(self):
|
||||||
|
assert _format_ass_time(0.5) == "0:00:00.50"
|
||||||
|
|
||||||
|
def test_minutes(self):
|
||||||
|
assert _format_ass_time(65.5) == "0:01:05.50"
|
||||||
|
|
||||||
|
def test_hours(self):
|
||||||
|
assert _format_ass_time(3661.25) == "1:01:01.25"
|
||||||
|
|
||||||
|
def test_negative_clamps_to_zero(self):
|
||||||
|
assert _format_ass_time(-5.0) == "0:00:00.00"
|
||||||
|
|
||||||
|
|
||||||
|
# ── ASS generation ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestGenerateAssCaptions:
|
||||||
|
def test_empty_timings_returns_header_only(self):
|
||||||
|
result = generate_ass_captions([], clip_start=0.0)
|
||||||
|
assert "[Script Info]" in result
|
||||||
|
assert "[Events]" in result
|
||||||
|
# No Dialogue lines
|
||||||
|
assert "Dialogue:" not in result
|
||||||
|
|
||||||
|
def test_structure_has_required_sections(self, sample_word_timings):
|
||||||
|
result = generate_ass_captions(sample_word_timings, clip_start=10.0)
|
||||||
|
assert "[Script Info]" in result
|
||||||
|
assert "[V4+ Styles]" in result
|
||||||
|
assert "[Events]" in result
|
||||||
|
assert "Dialogue:" in result
|
||||||
|
|
||||||
|
def test_clip_offset_applied(self, sample_word_timings):
|
||||||
|
"""Word at t=10.5 with clip_start=10.0 should become t=0.5 in ASS."""
|
||||||
|
result = generate_ass_captions(sample_word_timings, clip_start=10.0)
|
||||||
|
lines = result.strip().split("\n")
|
||||||
|
dialogue_lines = [l for l in lines if l.startswith("Dialogue:")]
|
||||||
|
|
||||||
|
# First word "This" starts at 10.0, clip_start=10.0 → relative 0.0
|
||||||
|
assert dialogue_lines[0].startswith("Dialogue: 0,0:00:00.00,")
|
||||||
|
|
||||||
|
# Third word "a" starts at 10.5, clip_start=10.0 → relative 0.5
|
||||||
|
assert "0:00:00.50" in dialogue_lines[2]
|
||||||
|
|
||||||
|
def test_karaoke_tags_present(self, sample_word_timings):
|
||||||
|
result = generate_ass_captions(sample_word_timings, clip_start=10.0)
|
||||||
|
lines = result.strip().split("\n")
|
||||||
|
dialogue_lines = [l for l in lines if l.startswith("Dialogue:")]
|
||||||
|
|
||||||
|
for line in dialogue_lines:
|
||||||
|
# Each line should have a \kN tag
|
||||||
|
assert re.search(r"\{\\k\d+\}", line), f"Missing karaoke tag in: {line}"
|
||||||
|
|
||||||
|
def test_karaoke_duration_math(self, sample_word_timings):
|
||||||
|
"""Word "This" at [10.0, 10.3] → 0.3s → k30 (30 centiseconds)."""
|
||||||
|
result = generate_ass_captions(sample_word_timings, clip_start=10.0)
|
||||||
|
lines = result.strip().split("\n")
|
||||||
|
dialogue_lines = [l for l in lines if l.startswith("Dialogue:")]
|
||||||
|
|
||||||
|
# "This" duration: 10.3 - 10.0 = 0.3s = 30cs
|
||||||
|
assert "{\\k30}This" in dialogue_lines[0]
|
||||||
|
|
||||||
|
# "test" duration: 11.0 - 10.6 = 0.4s = 40cs
|
||||||
|
assert "{\\k40}test" in dialogue_lines[3]
|
||||||
|
|
||||||
|
def test_word_count_matches(self, sample_word_timings):
|
||||||
|
result = generate_ass_captions(sample_word_timings, clip_start=10.0)
|
||||||
|
lines = result.strip().split("\n")
|
||||||
|
dialogue_lines = [l for l in lines if l.startswith("Dialogue:")]
|
||||||
|
assert len(dialogue_lines) == 5
|
||||||
|
|
||||||
|
def test_empty_word_text_skipped(self):
|
||||||
|
timings = [
|
||||||
|
{"word": "hello", "start": 0.0, "end": 0.5},
|
||||||
|
{"word": " ", "start": 0.5, "end": 0.7}, # whitespace-only
|
||||||
|
{"word": "", "start": 0.7, "end": 0.8}, # empty
|
||||||
|
{"word": "world", "start": 0.8, "end": 1.2},
|
||||||
|
]
|
||||||
|
result = generate_ass_captions(timings, clip_start=0.0)
|
||||||
|
lines = result.strip().split("\n")
|
||||||
|
dialogue_lines = [l for l in lines if l.startswith("Dialogue:")]
|
||||||
|
assert len(dialogue_lines) == 2 # only "hello" and "world"
|
||||||
|
|
||||||
|
def test_custom_style_overrides(self, sample_word_timings):
|
||||||
|
result = generate_ass_captions(
|
||||||
|
sample_word_timings,
|
||||||
|
clip_start=10.0,
|
||||||
|
style_config={"font_size": 72, "font_name": "Roboto"},
|
||||||
|
)
|
||||||
|
assert "Roboto" in result
|
||||||
|
assert ",72," in result
|
||||||
|
|
||||||
|
def test_negative_relative_time_clamped(self):
|
||||||
|
"""Words before clip_start should clamp to 0."""
|
||||||
|
timings = [{"word": "early", "start": 5.0, "end": 5.5}]
|
||||||
|
result = generate_ass_captions(timings, clip_start=10.0)
|
||||||
|
lines = [l for l in result.strip().split("\n") if l.startswith("Dialogue:")]
|
||||||
|
# Both start and end clamped to 0
|
||||||
|
assert lines[0].startswith("Dialogue: 0,0:00:00.00,0:00:00.00,")
|
||||||
|
|
||||||
|
|
||||||
|
# ── File writing ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestWriteAssFile:
|
||||||
|
def test_writes_content(self):
|
||||||
|
content = "[Script Info]\ntest content\n"
|
||||||
|
with tempfile.TemporaryDirectory() as td:
|
||||||
|
out = write_ass_file(content, Path(td) / "sub.ass")
|
||||||
|
assert out.exists()
|
||||||
|
assert out.read_text() == content
|
||||||
|
|
||||||
|
def test_creates_parent_dirs(self):
|
||||||
|
content = "test"
|
||||||
|
with tempfile.TemporaryDirectory() as td:
|
||||||
|
out = write_ass_file(content, Path(td) / "nested" / "deep" / "sub.ass")
|
||||||
|
assert out.exists()
|
||||||
|
|
||||||
|
def test_returns_path(self):
|
||||||
|
content = "test"
|
||||||
|
with tempfile.TemporaryDirectory() as td:
|
||||||
|
target = Path(td) / "sub.ass"
|
||||||
|
result = write_ass_file(content, target)
|
||||||
|
assert result == target
|
||||||
365
backend/pipeline/test_card_renderer.py
Normal file
365
backend/pipeline/test_card_renderer.py
Normal file
|
|
@ -0,0 +1,365 @@
|
||||||
|
"""Tests for card_renderer: ffmpeg card generation and concat pipeline.
|
||||||
|
|
||||||
|
Tests verify command construction, concat list file format, and template
|
||||||
|
config parsing — no actual ffmpeg execution required.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
import textwrap
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from pipeline.card_renderer import (
|
||||||
|
DEFAULT_ACCENT_COLOR,
|
||||||
|
DEFAULT_FONT_FAMILY,
|
||||||
|
DEFAULT_INTRO_DURATION,
|
||||||
|
DEFAULT_OUTRO_DURATION,
|
||||||
|
build_concat_list,
|
||||||
|
concat_segments,
|
||||||
|
parse_template_config,
|
||||||
|
render_card,
|
||||||
|
render_card_to_file,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── render_card tests ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestRenderCard:
|
||||||
|
"""Tests for render_card() ffmpeg command generation."""
|
||||||
|
|
||||||
|
def test_returns_list_of_strings(self):
|
||||||
|
cmd = render_card("Hello", 2.0, 1080, 1920)
|
||||||
|
assert isinstance(cmd, list)
|
||||||
|
assert all(isinstance(s, str) for s in cmd)
|
||||||
|
|
||||||
|
def test_contains_ffmpeg_and_lavfi(self):
|
||||||
|
cmd = render_card("Test", 3.0, 1080, 1920)
|
||||||
|
assert cmd[0] == "ffmpeg"
|
||||||
|
assert "-f" in cmd
|
||||||
|
lavfi_idx = cmd.index("-f")
|
||||||
|
assert cmd[lavfi_idx + 1] == "lavfi"
|
||||||
|
|
||||||
|
def test_contains_correct_dimensions_in_filtergraph(self):
|
||||||
|
cmd = render_card("Test", 2.0, 1920, 1080)
|
||||||
|
# The filtergraph is the arg after -i for lavfi
|
||||||
|
filtergraph = None
|
||||||
|
for i, arg in enumerate(cmd):
|
||||||
|
if arg == "-i" and i > 0 and cmd[i - 1] == "lavfi":
|
||||||
|
filtergraph = cmd[i + 1]
|
||||||
|
break
|
||||||
|
assert filtergraph is not None
|
||||||
|
assert "s=1920x1080" in filtergraph
|
||||||
|
|
||||||
|
def test_contains_duration_in_filtergraph(self):
|
||||||
|
cmd = render_card("Test", 5.5, 1080, 1920)
|
||||||
|
filtergraph = None
|
||||||
|
for i, arg in enumerate(cmd):
|
||||||
|
if arg == "-i" and i > 0 and cmd[i - 1] == "lavfi":
|
||||||
|
filtergraph = cmd[i + 1]
|
||||||
|
break
|
||||||
|
assert "d=5.5" in filtergraph
|
||||||
|
|
||||||
|
def test_contains_drawtext_with_text(self):
|
||||||
|
cmd = render_card("My Creator", 2.0, 1080, 1920)
|
||||||
|
filtergraph = None
|
||||||
|
for i, arg in enumerate(cmd):
|
||||||
|
if arg == "-i" and i > 0 and cmd[i - 1] == "lavfi":
|
||||||
|
filtergraph = cmd[i + 1]
|
||||||
|
break
|
||||||
|
assert "drawtext=" in filtergraph
|
||||||
|
assert "My Creator" in filtergraph
|
||||||
|
|
||||||
|
def test_codec_settings(self):
|
||||||
|
cmd = render_card("Test", 2.0, 1080, 1920)
|
||||||
|
assert "-c:v" in cmd
|
||||||
|
assert "libx264" in cmd
|
||||||
|
assert "-c:a" in cmd
|
||||||
|
assert "aac" in cmd
|
||||||
|
|
||||||
|
def test_silent_audio_track(self):
|
||||||
|
"""Card includes anullsrc so concat with audio segments works."""
|
||||||
|
cmd = render_card("Test", 2.0, 1080, 1920)
|
||||||
|
# Should have a second -f lavfi -i anullsrc input
|
||||||
|
cmd_str = " ".join(cmd)
|
||||||
|
assert "anullsrc" in cmd_str
|
||||||
|
|
||||||
|
def test_rejects_zero_duration(self):
|
||||||
|
with pytest.raises(ValueError, match="positive"):
|
||||||
|
render_card("Test", 0, 1080, 1920)
|
||||||
|
|
||||||
|
def test_rejects_negative_duration(self):
|
||||||
|
with pytest.raises(ValueError, match="positive"):
|
||||||
|
render_card("Test", -1.0, 1080, 1920)
|
||||||
|
|
||||||
|
def test_rejects_zero_dimensions(self):
|
||||||
|
with pytest.raises(ValueError, match="positive"):
|
||||||
|
render_card("Test", 2.0, 0, 1920)
|
||||||
|
|
||||||
|
def test_custom_accent_color(self):
|
||||||
|
cmd = render_card("Test", 2.0, 1080, 1920, accent_color="#ff0000")
|
||||||
|
filtergraph = None
|
||||||
|
for i, arg in enumerate(cmd):
|
||||||
|
if arg == "-i" and i > 0 and cmd[i - 1] == "lavfi":
|
||||||
|
filtergraph = cmd[i + 1]
|
||||||
|
break
|
||||||
|
assert "#ff0000" in filtergraph
|
||||||
|
|
||||||
|
def test_escapes_colons_in_text(self):
|
||||||
|
cmd = render_card("Hello: World", 2.0, 1080, 1920)
|
||||||
|
filtergraph = None
|
||||||
|
for i, arg in enumerate(cmd):
|
||||||
|
if arg == "-i" and i > 0 and cmd[i - 1] == "lavfi":
|
||||||
|
filtergraph = cmd[i + 1]
|
||||||
|
break
|
||||||
|
# Colons should be escaped for ffmpeg
|
||||||
|
assert "Hello\\: World" in filtergraph
|
||||||
|
|
||||||
|
|
||||||
|
# ── render_card_to_file tests ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestRenderCardToFile:
|
||||||
|
"""Tests for render_card_to_file() — mocked ffmpeg execution."""
|
||||||
|
|
||||||
|
@patch("pipeline.card_renderer.subprocess.run")
|
||||||
|
def test_calls_ffmpeg_and_returns_path(self, mock_run, tmp_path):
|
||||||
|
mock_run.return_value = MagicMock(returncode=0)
|
||||||
|
out = tmp_path / "card.mp4"
|
||||||
|
out.write_bytes(b"fake") # stat().st_size needs the file
|
||||||
|
|
||||||
|
result = render_card_to_file("Hi", 2.0, 1080, 1920, out)
|
||||||
|
assert result == out
|
||||||
|
mock_run.assert_called_once()
|
||||||
|
# Output path should be the last arg
|
||||||
|
call_args = mock_run.call_args[0][0]
|
||||||
|
assert call_args[-1] == str(out)
|
||||||
|
|
||||||
|
@patch("pipeline.card_renderer.subprocess.run")
|
||||||
|
def test_raises_on_ffmpeg_failure(self, mock_run, tmp_path):
|
||||||
|
mock_run.return_value = MagicMock(
|
||||||
|
returncode=1,
|
||||||
|
stderr=b"error: something failed",
|
||||||
|
)
|
||||||
|
out = tmp_path / "card.mp4"
|
||||||
|
with pytest.raises(subprocess.CalledProcessError):
|
||||||
|
render_card_to_file("Hi", 2.0, 1080, 1920, out)
|
||||||
|
|
||||||
|
|
||||||
|
# ── build_concat_list tests ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestBuildConcatList:
|
||||||
|
"""Tests for build_concat_list() file content."""
|
||||||
|
|
||||||
|
def test_writes_correct_format(self, tmp_path):
|
||||||
|
seg1 = tmp_path / "intro.mp4"
|
||||||
|
seg2 = tmp_path / "main.mp4"
|
||||||
|
seg1.touch()
|
||||||
|
seg2.touch()
|
||||||
|
|
||||||
|
list_file = tmp_path / "concat.txt"
|
||||||
|
result = build_concat_list([seg1, seg2], list_file)
|
||||||
|
|
||||||
|
assert result == list_file
|
||||||
|
content = list_file.read_text()
|
||||||
|
lines = content.strip().split("\n")
|
||||||
|
assert len(lines) == 2
|
||||||
|
assert lines[0] == f"file '{seg1.resolve()}'"
|
||||||
|
assert lines[1] == f"file '{seg2.resolve()}'"
|
||||||
|
|
||||||
|
def test_three_segments(self, tmp_path):
|
||||||
|
segs = [tmp_path / f"seg{i}.mp4" for i in range(3)]
|
||||||
|
for s in segs:
|
||||||
|
s.touch()
|
||||||
|
|
||||||
|
list_file = tmp_path / "list.txt"
|
||||||
|
build_concat_list(segs, list_file)
|
||||||
|
|
||||||
|
content = list_file.read_text()
|
||||||
|
lines = content.strip().split("\n")
|
||||||
|
assert len(lines) == 3
|
||||||
|
|
||||||
|
|
||||||
|
# ── concat_segments tests ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestConcatSegments:
|
||||||
|
"""Tests for concat_segments() — mocked ffmpeg execution."""
|
||||||
|
|
||||||
|
@patch("pipeline.card_renderer.subprocess.run")
|
||||||
|
def test_calls_ffmpeg_concat_demuxer(self, mock_run, tmp_path):
|
||||||
|
mock_run.return_value = MagicMock(returncode=0)
|
||||||
|
seg1 = tmp_path / "seg1.mp4"
|
||||||
|
seg2 = tmp_path / "seg2.mp4"
|
||||||
|
seg1.touch()
|
||||||
|
seg2.touch()
|
||||||
|
out = tmp_path / "output.mp4"
|
||||||
|
out.write_bytes(b"fakemp4")
|
||||||
|
|
||||||
|
result = concat_segments([seg1, seg2], out)
|
||||||
|
assert result == out
|
||||||
|
mock_run.assert_called_once()
|
||||||
|
|
||||||
|
call_args = mock_run.call_args[0][0]
|
||||||
|
assert "concat" in call_args
|
||||||
|
assert "-safe" in call_args
|
||||||
|
assert "0" in call_args
|
||||||
|
assert "-c" in call_args
|
||||||
|
assert "copy" in call_args
|
||||||
|
|
||||||
|
def test_rejects_empty_segments(self):
|
||||||
|
with pytest.raises(ValueError, match="empty"):
|
||||||
|
concat_segments([], Path("/tmp/out.mp4"))
|
||||||
|
|
||||||
|
@patch("pipeline.card_renderer.subprocess.run")
|
||||||
|
def test_raises_on_ffmpeg_failure(self, mock_run, tmp_path):
|
||||||
|
mock_run.return_value = MagicMock(
|
||||||
|
returncode=1, stderr=b"concat error",
|
||||||
|
)
|
||||||
|
seg1 = tmp_path / "s.mp4"
|
||||||
|
seg1.touch()
|
||||||
|
out = tmp_path / "out.mp4"
|
||||||
|
|
||||||
|
with pytest.raises(subprocess.CalledProcessError):
|
||||||
|
concat_segments([seg1], out)
|
||||||
|
|
||||||
|
|
||||||
|
# ── parse_template_config tests ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestParseTemplateConfig:
|
||||||
|
"""Tests for parse_template_config() defaults and overrides."""
|
||||||
|
|
||||||
|
def test_none_returns_all_defaults_disabled(self):
|
||||||
|
cfg = parse_template_config(None)
|
||||||
|
assert cfg["show_intro"] is False
|
||||||
|
assert cfg["show_outro"] is False
|
||||||
|
assert cfg["accent_color"] == DEFAULT_ACCENT_COLOR
|
||||||
|
assert cfg["font_family"] == DEFAULT_FONT_FAMILY
|
||||||
|
assert cfg["intro_duration"] == DEFAULT_INTRO_DURATION
|
||||||
|
assert cfg["outro_duration"] == DEFAULT_OUTRO_DURATION
|
||||||
|
|
||||||
|
def test_empty_dict_returns_defaults_disabled(self):
|
||||||
|
cfg = parse_template_config({})
|
||||||
|
assert cfg["show_intro"] is False
|
||||||
|
assert cfg["show_outro"] is False
|
||||||
|
|
||||||
|
def test_full_config_preserves_values(self):
|
||||||
|
raw = {
|
||||||
|
"show_intro": True,
|
||||||
|
"intro_text": "Welcome!",
|
||||||
|
"intro_duration": 3.0,
|
||||||
|
"show_outro": True,
|
||||||
|
"outro_text": "Bye!",
|
||||||
|
"outro_duration": 1.5,
|
||||||
|
"accent_color": "#ff0000",
|
||||||
|
"font_family": "Roboto",
|
||||||
|
}
|
||||||
|
cfg = parse_template_config(raw)
|
||||||
|
assert cfg["show_intro"] is True
|
||||||
|
assert cfg["intro_text"] == "Welcome!"
|
||||||
|
assert cfg["intro_duration"] == 3.0
|
||||||
|
assert cfg["show_outro"] is True
|
||||||
|
assert cfg["outro_text"] == "Bye!"
|
||||||
|
assert cfg["outro_duration"] == 1.5
|
||||||
|
assert cfg["accent_color"] == "#ff0000"
|
||||||
|
assert cfg["font_family"] == "Roboto"
|
||||||
|
|
||||||
|
def test_partial_config_fills_defaults(self):
|
||||||
|
raw = {"show_intro": True, "intro_text": "Hi"}
|
||||||
|
cfg = parse_template_config(raw)
|
||||||
|
assert cfg["show_intro"] is True
|
||||||
|
assert cfg["intro_text"] == "Hi"
|
||||||
|
assert cfg["intro_duration"] == DEFAULT_INTRO_DURATION
|
||||||
|
assert cfg["show_outro"] is False
|
||||||
|
assert cfg["outro_text"] == ""
|
||||||
|
assert cfg["accent_color"] == DEFAULT_ACCENT_COLOR
|
||||||
|
|
||||||
|
def test_truthy_coercion(self):
|
||||||
|
"""Non-bool truthy values should coerce to bool."""
|
||||||
|
cfg = parse_template_config({"show_intro": 1, "show_outro": 0})
|
||||||
|
assert cfg["show_intro"] is True
|
||||||
|
assert cfg["show_outro"] is False
|
||||||
|
|
||||||
|
def test_duration_coercion_from_int(self):
|
||||||
|
cfg = parse_template_config({"intro_duration": 5})
|
||||||
|
assert cfg["intro_duration"] == 5.0
|
||||||
|
assert isinstance(cfg["intro_duration"], float)
|
||||||
|
|
||||||
|
|
||||||
|
# ── extract_clip_with_template tests ─────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestExtractClipWithTemplate:
|
||||||
|
"""Tests for the shorts_generator.extract_clip_with_template function."""
|
||||||
|
|
||||||
|
@patch("pipeline.shorts_generator.extract_clip")
|
||||||
|
def test_no_cards_delegates_to_extract_clip(self, mock_extract):
|
||||||
|
from pipeline.shorts_generator import extract_clip_with_template
|
||||||
|
extract_clip_with_template(
|
||||||
|
input_path=Path("/fake/input.mp4"),
|
||||||
|
output_path=Path("/fake/output.mp4"),
|
||||||
|
start_secs=10.0,
|
||||||
|
end_secs=20.0,
|
||||||
|
vf_filter="scale=1080:-2",
|
||||||
|
)
|
||||||
|
mock_extract.assert_called_once()
|
||||||
|
|
||||||
|
@patch("pipeline.card_renderer.concat_segments")
|
||||||
|
@patch("pipeline.shorts_generator.extract_clip")
|
||||||
|
def test_with_intro_concats_two_segments(self, mock_extract, mock_concat, tmp_path):
|
||||||
|
from pipeline.shorts_generator import extract_clip_with_template
|
||||||
|
|
||||||
|
intro = tmp_path / "intro.mp4"
|
||||||
|
intro.touch()
|
||||||
|
out = tmp_path / "final.mp4"
|
||||||
|
main_tmp = Path(str(out) + ".main.mp4")
|
||||||
|
# Create the main clip temp file so cleanup doesn't error
|
||||||
|
main_tmp.touch()
|
||||||
|
|
||||||
|
mock_concat.return_value = out
|
||||||
|
|
||||||
|
extract_clip_with_template(
|
||||||
|
input_path=Path("/fake/input.mp4"),
|
||||||
|
output_path=out,
|
||||||
|
start_secs=10.0,
|
||||||
|
end_secs=20.0,
|
||||||
|
vf_filter="scale=1080:-2",
|
||||||
|
intro_path=intro,
|
||||||
|
)
|
||||||
|
mock_extract.assert_called_once()
|
||||||
|
mock_concat.assert_called_once()
|
||||||
|
# Segments should be [intro, main_clip]
|
||||||
|
segments = mock_concat.call_args[1]["segments"]
|
||||||
|
assert len(segments) == 2
|
||||||
|
assert segments[0] == intro
|
||||||
|
|
||||||
|
@patch("pipeline.card_renderer.concat_segments")
|
||||||
|
@patch("pipeline.shorts_generator.extract_clip")
|
||||||
|
def test_with_intro_and_outro_concats_three_segments(
|
||||||
|
self, mock_extract, mock_concat, tmp_path,
|
||||||
|
):
|
||||||
|
from pipeline.shorts_generator import extract_clip_with_template
|
||||||
|
|
||||||
|
intro = tmp_path / "intro.mp4"
|
||||||
|
outro = tmp_path / "outro.mp4"
|
||||||
|
intro.touch()
|
||||||
|
outro.touch()
|
||||||
|
out = tmp_path / "final.mp4"
|
||||||
|
main_tmp = Path(str(out) + ".main.mp4")
|
||||||
|
main_tmp.touch()
|
||||||
|
|
||||||
|
mock_concat.return_value = out
|
||||||
|
|
||||||
|
extract_clip_with_template(
|
||||||
|
input_path=Path("/fake/input.mp4"),
|
||||||
|
output_path=out,
|
||||||
|
start_secs=10.0,
|
||||||
|
end_secs=20.0,
|
||||||
|
vf_filter="scale=1080:-2",
|
||||||
|
intro_path=intro,
|
||||||
|
outro_path=outro,
|
||||||
|
)
|
||||||
|
segments = mock_concat.call_args[1]["segments"]
|
||||||
|
assert len(segments) == 3
|
||||||
|
assert segments[0] == intro
|
||||||
|
assert segments[2] == outro
|
||||||
108
backend/pipeline/test_citation_utils.py
Normal file
108
backend/pipeline/test_citation_utils.py
Normal file
|
|
@ -0,0 +1,108 @@
|
||||||
|
"""Unit tests for citation extraction and validation utilities."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from pipeline.citation_utils import extract_citations, validate_citations
|
||||||
|
from pipeline.schemas import BodySection, BodySubSection
|
||||||
|
|
||||||
|
|
||||||
|
# ── extract_citations ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractCitations:
|
||||||
|
def test_single_markers(self):
|
||||||
|
assert extract_citations("This uses reverb [0] and delay [2].") == [0, 2]
|
||||||
|
|
||||||
|
def test_multi_marker(self):
|
||||||
|
assert extract_citations("Combined approach [0,2] works well.") == [0, 2]
|
||||||
|
|
||||||
|
def test_multi_marker_with_spaces(self):
|
||||||
|
assert extract_citations("See [1, 3, 5] for details.") == [1, 3, 5]
|
||||||
|
|
||||||
|
def test_no_citations(self):
|
||||||
|
assert extract_citations("Plain text without citations.") == []
|
||||||
|
|
||||||
|
def test_duplicate_indices_deduplicated(self):
|
||||||
|
assert extract_citations("[1] and again [1] and [1,2]") == [1, 2]
|
||||||
|
|
||||||
|
def test_returns_sorted(self):
|
||||||
|
assert extract_citations("[5] then [1] then [3]") == [1, 3, 5]
|
||||||
|
|
||||||
|
def test_adjacent_markers(self):
|
||||||
|
assert extract_citations("[0][1][2]") == [0, 1, 2]
|
||||||
|
|
||||||
|
def test_does_not_match_non_numeric_brackets(self):
|
||||||
|
assert extract_citations("[abc] and [N] but [7] works") == [7]
|
||||||
|
|
||||||
|
|
||||||
|
# ── validate_citations ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _make_sections(texts: list[str], sub_texts: list[list[str]] | None = None) -> list[BodySection]:
|
||||||
|
"""Helper to build BodySection lists for testing."""
|
||||||
|
sections = []
|
||||||
|
for i, text in enumerate(texts):
|
||||||
|
subs = []
|
||||||
|
if sub_texts and i < len(sub_texts):
|
||||||
|
subs = [BodySubSection(heading=f"Sub {j}", content=t) for j, t in enumerate(sub_texts[i])]
|
||||||
|
sections.append(BodySection(heading=f"Section {i}", content=text, subsections=subs))
|
||||||
|
return sections
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateCitations:
|
||||||
|
def test_all_moments_cited(self):
|
||||||
|
sections = _make_sections(["Uses [0] and [1].", "Also [2]."])
|
||||||
|
result = validate_citations(sections, moment_count=3)
|
||||||
|
assert result["valid"] is True
|
||||||
|
assert result["total_citations"] == 3
|
||||||
|
assert result["invalid_indices"] == []
|
||||||
|
assert result["uncited_moments"] == []
|
||||||
|
assert result["coverage_pct"] == 100.0
|
||||||
|
|
||||||
|
def test_out_of_range_index(self):
|
||||||
|
sections = _make_sections(["Reference [0] and [5]."])
|
||||||
|
result = validate_citations(sections, moment_count=3)
|
||||||
|
assert result["valid"] is False
|
||||||
|
assert result["invalid_indices"] == [5]
|
||||||
|
assert result["uncited_moments"] == [1, 2]
|
||||||
|
|
||||||
|
def test_multi_citation_markers(self):
|
||||||
|
sections = _make_sections(["Combined [0,2] technique."])
|
||||||
|
result = validate_citations(sections, moment_count=3)
|
||||||
|
assert result["valid"] is False # moment 1 uncited
|
||||||
|
assert result["total_citations"] == 2
|
||||||
|
assert result["uncited_moments"] == [1]
|
||||||
|
assert result["coverage_pct"] == pytest.approx(66.7, abs=0.1)
|
||||||
|
|
||||||
|
def test_no_citations_at_all(self):
|
||||||
|
sections = _make_sections(["Plain text with no markers."])
|
||||||
|
result = validate_citations(sections, moment_count=2)
|
||||||
|
assert result["valid"] is False
|
||||||
|
assert result["total_citations"] == 0
|
||||||
|
assert result["uncited_moments"] == [0, 1]
|
||||||
|
assert result["coverage_pct"] == 0.0
|
||||||
|
|
||||||
|
def test_empty_sections(self):
|
||||||
|
result = validate_citations([], moment_count=0)
|
||||||
|
assert result["valid"] is True
|
||||||
|
assert result["total_citations"] == 0
|
||||||
|
assert result["coverage_pct"] == 0.0
|
||||||
|
|
||||||
|
def test_subsection_citations_counted(self):
|
||||||
|
sections = _make_sections(
|
||||||
|
["Section text [0]."],
|
||||||
|
sub_texts=[["Subsection cites [1] and [2]."]],
|
||||||
|
)
|
||||||
|
result = validate_citations(sections, moment_count=3)
|
||||||
|
assert result["valid"] is True
|
||||||
|
assert result["total_citations"] == 3
|
||||||
|
|
||||||
|
def test_zero_moment_count_with_citations(self):
|
||||||
|
"""Citations exist but moment_count is 0 — all indices are out of range."""
|
||||||
|
sections = _make_sections(["References [0] and [1]."])
|
||||||
|
result = validate_citations(sections, moment_count=0)
|
||||||
|
assert result["valid"] is False
|
||||||
|
assert result["invalid_indices"] == [0, 1]
|
||||||
|
assert result["coverage_pct"] == 0.0
|
||||||
360
backend/pipeline/test_compose_pipeline.py
Normal file
360
backend/pipeline/test_compose_pipeline.py
Normal file
|
|
@ -0,0 +1,360 @@
|
||||||
|
"""Unit tests for compose pipeline logic in stage5_synthesis.
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
- _build_compose_user_prompt(): XML structure, offset indices, empty existing, page JSON
|
||||||
|
- Compose-or-create branching: compose triggered vs create fallback
|
||||||
|
- body_sections_format='v2' on persisted pages
|
||||||
|
- TechniquePageVideo insertion via pg_insert with on_conflict_do_nothing
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from collections import namedtuple
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
# ── Lightweight mock objects ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class _MockContentType:
|
||||||
|
"""Mimics ContentType enum with .value."""
|
||||||
|
def __init__(self, value: str) -> None:
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
|
||||||
|
MockKeyMoment = namedtuple("MockKeyMoment", [
|
||||||
|
"id", "title", "summary", "content_type", "start_time", "end_time",
|
||||||
|
"plugins", "raw_transcript", "technique_page_id", "source_video_id",
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
def _moment(
|
||||||
|
title: str = "Test Moment",
|
||||||
|
summary: str = "A moment.",
|
||||||
|
content_type: str = "technique_demo",
|
||||||
|
start_time: float = 0.0,
|
||||||
|
end_time: float = 10.0,
|
||||||
|
plugins: list[str] | None = None,
|
||||||
|
raw_transcript: str | None = "Some transcript text",
|
||||||
|
technique_page_id: uuid.UUID | None = None,
|
||||||
|
source_video_id: uuid.UUID | None = None,
|
||||||
|
) -> MockKeyMoment:
|
||||||
|
return MockKeyMoment(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
title=title,
|
||||||
|
summary=summary,
|
||||||
|
content_type=_MockContentType(content_type),
|
||||||
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
|
plugins=plugins or [],
|
||||||
|
raw_transcript=raw_transcript or "",
|
||||||
|
technique_page_id=technique_page_id,
|
||||||
|
source_video_id=source_video_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _MockSourceQuality:
|
||||||
|
"""Mimics source_quality enum with .value."""
|
||||||
|
def __init__(self, value: str = "high") -> None:
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
|
||||||
|
class MockTechniquePage:
|
||||||
|
"""Lightweight stand-in for the ORM TechniquePage."""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
title: str = "Reverb Techniques",
|
||||||
|
slug: str = "reverb-techniques",
|
||||||
|
topic_category: str = "Sound Design",
|
||||||
|
summary: str = "A page about reverb.",
|
||||||
|
body_sections: list | None = None,
|
||||||
|
signal_chains: list | None = None,
|
||||||
|
plugins: list[str] | None = None,
|
||||||
|
source_quality: str = "high",
|
||||||
|
creator_id: uuid.UUID | None = None,
|
||||||
|
body_sections_format: str | None = None,
|
||||||
|
):
|
||||||
|
self.id = uuid.uuid4()
|
||||||
|
self.title = title
|
||||||
|
self.slug = slug
|
||||||
|
self.topic_category = topic_category
|
||||||
|
self.summary = summary
|
||||||
|
self.body_sections = body_sections or [{"heading": "Overview", "content": "Intro text."}]
|
||||||
|
self.signal_chains = signal_chains or []
|
||||||
|
self.plugins = plugins or ["Valhalla VintageVerb"]
|
||||||
|
self.source_quality = _MockSourceQuality(source_quality)
|
||||||
|
self.creator_id = creator_id or uuid.uuid4()
|
||||||
|
self.body_sections_format = body_sections_format
|
||||||
|
|
||||||
|
|
||||||
|
def _cls_info(tags: list[str] | None = None) -> dict:
|
||||||
|
return {"topic_category": "Sound Design", "topic_tags": tags or ["reverb", "delay"]}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Import the function under test ───────────────────────────────────────────
|
||||||
|
# We need to patch modules before importing stages in some tests.
|
||||||
|
# For _build_compose_user_prompt we can import directly since it's a pure function
|
||||||
|
# that only depends on _build_moments_text.
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def build_compose_prompt():
|
||||||
|
"""Import _build_compose_user_prompt from stages."""
|
||||||
|
from pipeline.stages import _build_compose_user_prompt
|
||||||
|
return _build_compose_user_prompt
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests for _build_compose_user_prompt ─────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildComposeUserPrompt:
|
||||||
|
"""Tests for _build_compose_user_prompt XML structure and offset math."""
|
||||||
|
|
||||||
|
def test_compose_prompt_xml_structure(self, build_compose_prompt):
|
||||||
|
"""Verify output contains all required XML tags."""
|
||||||
|
page = MockTechniquePage()
|
||||||
|
existing = [_moment(title="Existing 1")]
|
||||||
|
new = [(_moment(title="New 1"), _cls_info())]
|
||||||
|
|
||||||
|
result = build_compose_prompt(page, existing, new, "COPYCATT")
|
||||||
|
|
||||||
|
assert "<existing_page>" in result
|
||||||
|
assert "</existing_page>" in result
|
||||||
|
assert "<existing_moments>" in result
|
||||||
|
assert "</existing_moments>" in result
|
||||||
|
assert "<new_moments>" in result
|
||||||
|
assert "</new_moments>" in result
|
||||||
|
assert "<creator>" in result
|
||||||
|
assert "</creator>" in result
|
||||||
|
assert "COPYCATT" in result
|
||||||
|
|
||||||
|
def test_compose_prompt_offset_indices(self, build_compose_prompt):
|
||||||
|
"""With 3 existing + 2 new moments, new moments should use [3] and [4]."""
|
||||||
|
page = MockTechniquePage()
|
||||||
|
existing = [
|
||||||
|
_moment(title=f"Existing {i}") for i in range(3)
|
||||||
|
]
|
||||||
|
new = [
|
||||||
|
(_moment(title=f"New {i}"), _cls_info()) for i in range(2)
|
||||||
|
]
|
||||||
|
|
||||||
|
result = build_compose_prompt(page, existing, new, "COPYCATT")
|
||||||
|
|
||||||
|
# New moments section should have [3] and [4]
|
||||||
|
new_section_start = result.index("<new_moments>")
|
||||||
|
new_section_end = result.index("</new_moments>")
|
||||||
|
new_section = result[new_section_start:new_section_end]
|
||||||
|
|
||||||
|
assert "[3]" in new_section
|
||||||
|
assert "[4]" in new_section
|
||||||
|
# Should NOT have [0], [1], [2] in the new section
|
||||||
|
assert "[0]" not in new_section
|
||||||
|
assert "[1]" not in new_section
|
||||||
|
assert "[2]" not in new_section
|
||||||
|
|
||||||
|
def test_compose_prompt_empty_existing_moments(self, build_compose_prompt):
|
||||||
|
"""0 existing moments → new moments start at [0]."""
|
||||||
|
page = MockTechniquePage()
|
||||||
|
existing = []
|
||||||
|
new = [
|
||||||
|
(_moment(title="New A"), _cls_info()),
|
||||||
|
(_moment(title="New B"), _cls_info()),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = build_compose_prompt(page, existing, new, "COPYCATT")
|
||||||
|
|
||||||
|
new_section_start = result.index("<new_moments>")
|
||||||
|
new_section_end = result.index("</new_moments>")
|
||||||
|
new_section = result[new_section_start:new_section_end]
|
||||||
|
|
||||||
|
assert "[0]" in new_section
|
||||||
|
assert "[1]" in new_section
|
||||||
|
|
||||||
|
def test_compose_prompt_page_json(self, build_compose_prompt):
|
||||||
|
"""Existing page should be serialized as JSON within <existing_page> tags."""
|
||||||
|
page = MockTechniquePage(title="My Page", slug="my-page", topic_category="Mixing")
|
||||||
|
|
||||||
|
result = build_compose_prompt(page, [], [(_moment(), _cls_info())], "Creator")
|
||||||
|
|
||||||
|
page_section_start = result.index("<existing_page>") + len("<existing_page>")
|
||||||
|
page_section_end = result.index("</existing_page>")
|
||||||
|
page_json_str = result[page_section_start:page_section_end].strip()
|
||||||
|
|
||||||
|
page_dict = json.loads(page_json_str)
|
||||||
|
assert page_dict["title"] == "My Page"
|
||||||
|
assert page_dict["slug"] == "my-page"
|
||||||
|
assert page_dict["topic_category"] == "Mixing"
|
||||||
|
assert "summary" in page_dict
|
||||||
|
assert "body_sections" in page_dict
|
||||||
|
|
||||||
|
def test_compose_prompt_new_moment_content(self, build_compose_prompt):
|
||||||
|
"""New moments section includes title, summary, time range, and tags."""
|
||||||
|
page = MockTechniquePage()
|
||||||
|
m = _moment(title="Sidechain Pump", summary="How to create a sidechain pump",
|
||||||
|
start_time=30.0, end_time=45.5, plugins=["FabFilter Pro-C 2"])
|
||||||
|
new = [(m, _cls_info(tags=["compression", "sidechain"]))]
|
||||||
|
|
||||||
|
result = build_compose_prompt(page, [], new, "Creator")
|
||||||
|
|
||||||
|
new_section_start = result.index("<new_moments>")
|
||||||
|
new_section_end = result.index("</new_moments>")
|
||||||
|
new_section = result[new_section_start:new_section_end]
|
||||||
|
|
||||||
|
assert "Sidechain Pump" in new_section
|
||||||
|
assert "How to create a sidechain pump" in new_section
|
||||||
|
assert "30.0s" in new_section
|
||||||
|
assert "45.5s" in new_section
|
||||||
|
assert "FabFilter Pro-C 2" in new_section
|
||||||
|
assert "compression" in new_section
|
||||||
|
assert "sidechain" in new_section
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests for compose-or-create branching ────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestComposeOrCreateBranching:
|
||||||
|
"""Tests for the compose-or-create detection and branching in stage5_synthesis.
|
||||||
|
|
||||||
|
Full integration-level mocking of stage5_synthesis is fragile (many DB queries).
|
||||||
|
Instead, we verify:
|
||||||
|
1. The code structure has correct branching (compose_target check → two paths)
|
||||||
|
2. _compose_into_existing calls the LLM with compose prompt and returns parsed result
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_compose_branch_exists_in_source(self):
|
||||||
|
"""Verify stage5 has compose detection → _compose_into_existing call path."""
|
||||||
|
from pathlib import Path
|
||||||
|
src = Path("backend/pipeline/stages.py").read_text()
|
||||||
|
|
||||||
|
# The compose detection block
|
||||||
|
assert "compose_matches = session.execute(" in src
|
||||||
|
assert "compose_target = compose_matches[0] if compose_matches else None" in src
|
||||||
|
|
||||||
|
# The compose branch calls _compose_into_existing
|
||||||
|
assert "if compose_target is not None:" in src
|
||||||
|
assert "_compose_into_existing(" in src
|
||||||
|
|
||||||
|
# The create branch calls _synthesize_chunk
|
||||||
|
assert "elif len(moment_group) <= chunk_size:" in src
|
||||||
|
|
||||||
|
def test_create_branch_when_no_compose_target(self):
|
||||||
|
"""Verify the else/elif branches call _synthesize_chunk, not _compose_into_existing."""
|
||||||
|
from pathlib import Path
|
||||||
|
src = Path("backend/pipeline/stages.py").read_text()
|
||||||
|
|
||||||
|
# Find the compose branch and the create branch — they're mutually exclusive
|
||||||
|
compose_branch_idx = src.index("if compose_target is not None:")
|
||||||
|
create_branch_idx = src.index("elif len(moment_group) <= chunk_size:")
|
||||||
|
|
||||||
|
# The create branch must come after the compose branch (same if/elif chain)
|
||||||
|
assert create_branch_idx > compose_branch_idx
|
||||||
|
|
||||||
|
# _synthesize_chunk should appear in the create branch, not compose
|
||||||
|
create_block = src[create_branch_idx:create_branch_idx + 500]
|
||||||
|
assert "_synthesize_chunk(" in create_block
|
||||||
|
|
||||||
|
@patch("pipeline.stages._safe_parse_llm_response")
|
||||||
|
@patch("pipeline.stages._make_llm_callback", return_value=lambda **kw: None)
|
||||||
|
@patch("pipeline.stages.estimate_max_tokens", return_value=4000)
|
||||||
|
@patch("pipeline.stages._load_prompt", return_value="compose system prompt")
|
||||||
|
def test_compose_into_existing_calls_llm(
|
||||||
|
self, mock_load_prompt, mock_estimate, mock_callback, mock_parse,
|
||||||
|
):
|
||||||
|
"""_compose_into_existing calls LLM with compose prompt and returns parsed result."""
|
||||||
|
from pipeline.schemas import SynthesisResult, SynthesizedPage
|
||||||
|
from pipeline.stages import _compose_into_existing
|
||||||
|
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.complete.return_value = "raw response"
|
||||||
|
|
||||||
|
synth_page = SynthesizedPage(
|
||||||
|
title="Merged Page", slug="merged-page", topic_category="Sound Design",
|
||||||
|
summary="Merged", body_sections=[], signal_chains=[], plugins=[],
|
||||||
|
source_quality="high", moment_indices=[0, 1],
|
||||||
|
)
|
||||||
|
mock_parse.return_value = SynthesisResult(pages=[synth_page])
|
||||||
|
|
||||||
|
page = MockTechniquePage()
|
||||||
|
existing_moments = [_moment(title="Old Moment")]
|
||||||
|
new_moments = [(_moment(title="New Moment"), _cls_info())]
|
||||||
|
|
||||||
|
result = _compose_into_existing(
|
||||||
|
page, existing_moments, new_moments,
|
||||||
|
"Sound Design", "COPYCATT", "system prompt",
|
||||||
|
mock_llm, None, "text", 8000, str(uuid.uuid4()), None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# LLM was called
|
||||||
|
mock_llm.complete.assert_called_once()
|
||||||
|
# The compose prompt template was loaded
|
||||||
|
mock_load_prompt.assert_called_once()
|
||||||
|
call_args = mock_load_prompt.call_args
|
||||||
|
assert "stage5_compose" in call_args[0][0]
|
||||||
|
|
||||||
|
# Result has the expected page
|
||||||
|
assert len(result.pages) == 1
|
||||||
|
assert result.pages[0].title == "Merged Page"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests for body_sections_format and TechniquePageVideo ────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestBodySectionsFormatAndTracking:
|
||||||
|
"""Tests for body_sections_format='v2' and TechniquePageVideo insertion."""
|
||||||
|
|
||||||
|
def test_body_sections_format_v2_set_on_page(self):
|
||||||
|
"""Verify the persist section sets body_sections_format='v2' on pages."""
|
||||||
|
# Read stages.py source and verify the assignment exists
|
||||||
|
from pathlib import Path
|
||||||
|
stages_src = Path("backend/pipeline/stages.py").read_text()
|
||||||
|
|
||||||
|
# The line `page.body_sections_format = "v2"` must appear in the persist block
|
||||||
|
assert 'page.body_sections_format = "v2"' in stages_src, (
|
||||||
|
"body_sections_format = 'v2' assignment not found in stages.py"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_technique_page_video_pg_insert(self):
|
||||||
|
"""Verify TechniquePageVideo insertion uses pg_insert with on_conflict_do_nothing."""
|
||||||
|
from pathlib import Path
|
||||||
|
stages_src = Path("backend/pipeline/stages.py").read_text()
|
||||||
|
|
||||||
|
assert "pg_insert(TechniquePageVideo.__table__)" in stages_src, (
|
||||||
|
"pg_insert(TechniquePageVideo.__table__) not found in stages.py"
|
||||||
|
)
|
||||||
|
assert "on_conflict_do_nothing()" in stages_src, (
|
||||||
|
"on_conflict_do_nothing() not found in stages.py"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_technique_page_video_values(self):
|
||||||
|
"""Verify TechniquePageVideo INSERT includes technique_page_id and source_video_id."""
|
||||||
|
from pathlib import Path
|
||||||
|
stages_src = Path("backend/pipeline/stages.py").read_text()
|
||||||
|
|
||||||
|
# Find the pg_insert block
|
||||||
|
idx = stages_src.index("pg_insert(TechniquePageVideo.__table__)")
|
||||||
|
block = stages_src[idx:idx + 200]
|
||||||
|
|
||||||
|
assert "technique_page_id" in block
|
||||||
|
assert "source_video_id" in block
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests for category case-insensitivity ────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestCategoryCaseInsensitive:
|
||||||
|
"""Verify the compose detection query uses func.lower for category matching."""
|
||||||
|
|
||||||
|
def test_compose_detection_uses_func_lower(self):
|
||||||
|
"""The compose detection query must use func.lower on both sides."""
|
||||||
|
from pathlib import Path
|
||||||
|
stages_src = Path("backend/pipeline/stages.py").read_text()
|
||||||
|
|
||||||
|
# Find the compose detection block — need enough window to capture the full query
|
||||||
|
idx = stages_src.index("Compose-or-create detection")
|
||||||
|
block = stages_src[idx:idx + 600]
|
||||||
|
|
||||||
|
assert "func.lower(TechniquePage.topic_category)" in block
|
||||||
|
assert "func.lower(category)" in block
|
||||||
830
backend/pipeline/test_harness.py
Normal file
830
backend/pipeline/test_harness.py
Normal file
|
|
@ -0,0 +1,830 @@
|
||||||
|
"""Offline prompt test harness for Chrysopedia synthesis.
|
||||||
|
|
||||||
|
Loads a fixture JSON (exported by export_fixture.py) and a prompt file,
|
||||||
|
calls the LLM, and outputs the synthesized result. No Docker, no database,
|
||||||
|
no Redis, no Celery — just prompt + fixture + LLM endpoint.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python -m pipeline.test_harness \\
|
||||||
|
--fixture fixtures/real_video_xyz.json \\
|
||||||
|
--prompt prompts/stage5_synthesis.txt \\
|
||||||
|
--output /tmp/result.json
|
||||||
|
|
||||||
|
# Run all categories in a fixture:
|
||||||
|
python -m pipeline.test_harness --fixture fixtures/video.json
|
||||||
|
|
||||||
|
# Run a specific category only:
|
||||||
|
python -m pipeline.test_harness --fixture fixtures/video.json --category "Sound Design"
|
||||||
|
|
||||||
|
Exit codes: 0=success, 1=LLM error, 2=parse error, 3=fixture error
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from collections import Counter, defaultdict
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import NamedTuple
|
||||||
|
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from config import get_settings
|
||||||
|
from pipeline.citation_utils import validate_citations
|
||||||
|
from pipeline.llm_client import LLMClient, estimate_max_tokens
|
||||||
|
from pipeline.schemas import SynthesizedPage, SynthesisResult
|
||||||
|
|
||||||
|
|
||||||
|
# ── Lightweight stand-in for KeyMoment ORM model ───────────────────────────
|
||||||
|
|
||||||
|
class _MockContentType:
|
||||||
|
"""Mimics KeyMomentContentType enum with a .value property."""
|
||||||
|
def __init__(self, value: str) -> None:
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
|
||||||
|
class MockKeyMoment(NamedTuple):
|
||||||
|
"""Lightweight stand-in for the ORM KeyMoment.
|
||||||
|
|
||||||
|
Has the same attributes that _build_moments_text() accesses:
|
||||||
|
title, summary, content_type, start_time, end_time, plugins, raw_transcript.
|
||||||
|
"""
|
||||||
|
title: str
|
||||||
|
summary: str
|
||||||
|
content_type: object # _MockContentType
|
||||||
|
start_time: float
|
||||||
|
end_time: float
|
||||||
|
plugins: list[str]
|
||||||
|
raw_transcript: str
|
||||||
|
|
||||||
|
|
||||||
|
def _log(tag: str, msg: str, level: str = "INFO") -> None:
|
||||||
|
"""Write structured log line to stderr."""
|
||||||
|
print(f"[HARNESS] [{level}] {tag}: {msg}", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Moment text builder (mirrors stages.py _build_moments_text) ────────────
|
||||||
|
|
||||||
|
def build_moments_text(
|
||||||
|
moment_group: list[tuple[MockKeyMoment, dict]],
|
||||||
|
category: str,
|
||||||
|
) -> tuple[str, set[str]]:
|
||||||
|
"""Build the moments prompt text — matches _build_moments_text in stages.py."""
|
||||||
|
moments_lines = []
|
||||||
|
all_tags: set[str] = set()
|
||||||
|
for i, (m, cls_info) in enumerate(moment_group):
|
||||||
|
tags = cls_info.get("topic_tags", [])
|
||||||
|
all_tags.update(tags)
|
||||||
|
moments_lines.append(
|
||||||
|
f"[{i}] Title: {m.title}\n"
|
||||||
|
f" Summary: {m.summary}\n"
|
||||||
|
f" Content type: {m.content_type.value}\n"
|
||||||
|
f" Time: {m.start_time:.1f}s - {m.end_time:.1f}s\n"
|
||||||
|
f" Plugins: {', '.join(m.plugins) if m.plugins else 'none'}\n"
|
||||||
|
f" Category: {category}\n"
|
||||||
|
f" Tags: {', '.join(tags) if tags else 'none'}\n"
|
||||||
|
f" Transcript excerpt: {(m.raw_transcript or '')[:300]}"
|
||||||
|
)
|
||||||
|
return "\n\n".join(moments_lines), all_tags
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fixture loading ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FixtureData:
|
||||||
|
"""Parsed fixture with moments grouped by category."""
|
||||||
|
creator_name: str
|
||||||
|
video_id: str
|
||||||
|
content_type: str
|
||||||
|
filename: str
|
||||||
|
# Groups: category -> list of (MockKeyMoment, cls_info_dict)
|
||||||
|
groups: dict[str, list[tuple[MockKeyMoment, dict]]]
|
||||||
|
total_moments: int
|
||||||
|
|
||||||
|
|
||||||
|
def load_fixture(path: str) -> FixtureData:
|
||||||
|
"""Load and parse a fixture JSON file into grouped moments."""
|
||||||
|
fixture_path = Path(path)
|
||||||
|
if not fixture_path.exists():
|
||||||
|
raise FileNotFoundError(f"Fixture not found: {path}")
|
||||||
|
|
||||||
|
raw = fixture_path.read_text(encoding="utf-8")
|
||||||
|
size_kb = len(raw.encode("utf-8")) / 1024
|
||||||
|
data = json.loads(raw)
|
||||||
|
|
||||||
|
moments_raw = data.get("moments", [])
|
||||||
|
if not moments_raw:
|
||||||
|
raise ValueError(f"Fixture has no moments: {path}")
|
||||||
|
|
||||||
|
_log("FIXTURE", f"Loading: {path} ({size_kb:.1f} KB, {len(moments_raw)} moments)")
|
||||||
|
|
||||||
|
# Build MockKeyMoment objects and group by category
|
||||||
|
groups: dict[str, list[tuple[MockKeyMoment, dict]]] = defaultdict(list)
|
||||||
|
|
||||||
|
for m in moments_raw:
|
||||||
|
cls = m.get("classification", {})
|
||||||
|
category = cls.get("topic_category", m.get("topic_category", "Uncategorized"))
|
||||||
|
tags = cls.get("topic_tags", m.get("topic_tags", []))
|
||||||
|
|
||||||
|
mock = MockKeyMoment(
|
||||||
|
title=m.get("title", m.get("summary", "")[:80]),
|
||||||
|
summary=m.get("summary", ""),
|
||||||
|
content_type=_MockContentType(m.get("content_type", "technique")),
|
||||||
|
start_time=m.get("start_time", 0.0),
|
||||||
|
end_time=m.get("end_time", 0.0),
|
||||||
|
plugins=m.get("plugins", []),
|
||||||
|
raw_transcript=m.get("raw_transcript", m.get("transcript_excerpt", "")),
|
||||||
|
)
|
||||||
|
cls_info = {"topic_category": category, "topic_tags": tags}
|
||||||
|
groups[category].append((mock, cls_info))
|
||||||
|
|
||||||
|
# Log breakdown
|
||||||
|
cat_counts = {cat: len(moms) for cat, moms in groups.items()}
|
||||||
|
counts = list(cat_counts.values())
|
||||||
|
_log(
|
||||||
|
"FIXTURE",
|
||||||
|
f"Breakdown: {len(groups)} categories, "
|
||||||
|
f"moments per category: min={min(counts)}, max={max(counts)}, "
|
||||||
|
f"avg={sum(counts)/len(counts):.1f}",
|
||||||
|
)
|
||||||
|
for cat, count in sorted(cat_counts.items(), key=lambda x: -x[1]):
|
||||||
|
_log("FIXTURE", f" {cat}: {count} moments")
|
||||||
|
|
||||||
|
return FixtureData(
|
||||||
|
creator_name=data.get("creator_name", "Unknown"),
|
||||||
|
video_id=data.get("video_id", "unknown"),
|
||||||
|
content_type=data.get("content_type", "tutorial"),
|
||||||
|
filename=data.get("filename", "unknown"),
|
||||||
|
groups=dict(groups),
|
||||||
|
total_moments=len(moments_raw),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Synthesis runner ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def run_synthesis(
|
||||||
|
fixture: FixtureData,
|
||||||
|
prompt_path: str,
|
||||||
|
category_filter: str | None = None,
|
||||||
|
model_override: str | None = None,
|
||||||
|
modality: str | None = None,
|
||||||
|
) -> tuple[list[dict], int]:
|
||||||
|
"""Run synthesis on fixture data, returns (pages, exit_code).
|
||||||
|
|
||||||
|
Returns all synthesized pages as dicts plus an exit code.
|
||||||
|
"""
|
||||||
|
# Load prompt
|
||||||
|
prompt_file = Path(prompt_path)
|
||||||
|
if not prompt_file.exists():
|
||||||
|
_log("ERROR", f"Prompt file not found: {prompt_path}", level="ERROR")
|
||||||
|
return [], 3
|
||||||
|
|
||||||
|
system_prompt = prompt_file.read_text(encoding="utf-8")
|
||||||
|
_log("PROMPT", f"Loading: {prompt_path} ({len(system_prompt)} chars)")
|
||||||
|
|
||||||
|
# Setup LLM
|
||||||
|
settings = get_settings()
|
||||||
|
llm = LLMClient(settings)
|
||||||
|
|
||||||
|
stage_model = model_override or settings.llm_stage5_model or settings.llm_model
|
||||||
|
stage_modality = modality or settings.llm_stage5_modality or "thinking"
|
||||||
|
hard_limit = settings.llm_max_tokens_hard_limit
|
||||||
|
|
||||||
|
_log("LLM", f"Model: {stage_model}, modality: {stage_modality}, hard_limit: {hard_limit}")
|
||||||
|
|
||||||
|
# Filter categories if requested
|
||||||
|
categories = fixture.groups
|
||||||
|
if category_filter:
|
||||||
|
if category_filter not in categories:
|
||||||
|
_log("ERROR", f"Category '{category_filter}' not found. Available: {list(categories.keys())}", level="ERROR")
|
||||||
|
return [], 3
|
||||||
|
categories = {category_filter: categories[category_filter]}
|
||||||
|
|
||||||
|
all_pages: list[dict] = []
|
||||||
|
total_prompt_tokens = 0
|
||||||
|
total_completion_tokens = 0
|
||||||
|
total_duration_ms = 0
|
||||||
|
exit_code = 0
|
||||||
|
|
||||||
|
for cat_idx, (category, moment_group) in enumerate(categories.items(), 1):
|
||||||
|
_log("SYNTH", f"Category {cat_idx}/{len(categories)}: '{category}' ({len(moment_group)} moments)")
|
||||||
|
|
||||||
|
# Build user prompt (same format as stages.py _synthesize_chunk)
|
||||||
|
moments_text, all_tags = build_moments_text(moment_group, category)
|
||||||
|
user_prompt = f"<creator>{fixture.creator_name}</creator>\n<moments>\n{moments_text}\n</moments>"
|
||||||
|
|
||||||
|
estimated_tokens = estimate_max_tokens(
|
||||||
|
system_prompt, user_prompt,
|
||||||
|
stage="stage5_synthesis",
|
||||||
|
hard_limit=hard_limit,
|
||||||
|
)
|
||||||
|
_log(
|
||||||
|
"SYNTH",
|
||||||
|
f" Building prompt: {len(moment_group)} moments, "
|
||||||
|
f"max_tokens={estimated_tokens}, tags={sorted(all_tags)[:5]}{'...' if len(all_tags) > 5 else ''}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Call LLM
|
||||||
|
call_start = time.monotonic()
|
||||||
|
_log("LLM", f" Calling: model={stage_model}, max_tokens={estimated_tokens}, modality={stage_modality}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
raw = llm.complete(
|
||||||
|
system_prompt,
|
||||||
|
user_prompt,
|
||||||
|
response_model=SynthesisResult,
|
||||||
|
modality=stage_modality,
|
||||||
|
model_override=stage_model,
|
||||||
|
max_tokens=estimated_tokens,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
_log("ERROR", f" LLM call failed: {exc}", level="ERROR")
|
||||||
|
exit_code = 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
call_duration_ms = int((time.monotonic() - call_start) * 1000)
|
||||||
|
prompt_tokens = getattr(raw, "prompt_tokens", None) or 0
|
||||||
|
completion_tokens = getattr(raw, "completion_tokens", None) or 0
|
||||||
|
finish_reason = getattr(raw, "finish_reason", "unknown")
|
||||||
|
|
||||||
|
total_prompt_tokens += prompt_tokens
|
||||||
|
total_completion_tokens += completion_tokens
|
||||||
|
total_duration_ms += call_duration_ms
|
||||||
|
|
||||||
|
_log(
|
||||||
|
"LLM",
|
||||||
|
f" Response: {prompt_tokens} prompt + {completion_tokens} completion tokens, "
|
||||||
|
f"{call_duration_ms}ms, finish_reason={finish_reason}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if finish_reason == "length":
|
||||||
|
_log(
|
||||||
|
"WARN",
|
||||||
|
" finish_reason=length — output likely truncated! "
|
||||||
|
"Consider reducing fixture size or increasing max_tokens.",
|
||||||
|
level="WARN",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse response
|
||||||
|
try:
|
||||||
|
result = SynthesisResult.model_validate_json(str(raw))
|
||||||
|
except (ValidationError, json.JSONDecodeError) as exc:
|
||||||
|
_log("ERROR", f" Parse failed: {exc}", level="ERROR")
|
||||||
|
_log("ERROR", f" Raw response (first 2000 chars): {str(raw)[:2000]}", level="ERROR")
|
||||||
|
exit_code = 2
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Log per-page summary
|
||||||
|
_log("SYNTH", f" Parsed: {len(result.pages)} pages synthesized")
|
||||||
|
total_words = 0
|
||||||
|
for page in result.pages:
|
||||||
|
sections = page.body_sections or []
|
||||||
|
section_count = len(sections)
|
||||||
|
subsection_count = sum(len(s.subsections) for s in sections)
|
||||||
|
word_count = sum(
|
||||||
|
len(s.content.split()) + sum(len(sub.content.split()) for sub in s.subsections)
|
||||||
|
for s in sections
|
||||||
|
)
|
||||||
|
total_words += word_count
|
||||||
|
_log(
|
||||||
|
"PAGE",
|
||||||
|
f" '{page.title}' ({page.slug}): "
|
||||||
|
f"{section_count} sections ({subsection_count} subsections), "
|
||||||
|
f"{word_count} words, "
|
||||||
|
f"{len(page.moment_indices)} moments linked, "
|
||||||
|
f"quality={page.source_quality}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Citation coverage reporting
|
||||||
|
cit = validate_citations(sections, len(page.moment_indices))
|
||||||
|
_log(
|
||||||
|
"CITE",
|
||||||
|
f" Citations: {cit['total_citations']}/{len(page.moment_indices)} moments cited "
|
||||||
|
f"({cit['coverage_pct']}% coverage)"
|
||||||
|
+ (f", invalid indices: {cit['invalid_indices']}" if cit['invalid_indices'] else "")
|
||||||
|
+ (f", uncited: {cit['uncited_moments']}" if cit['uncited_moments'] else ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
all_pages.append(page.model_dump())
|
||||||
|
|
||||||
|
# Summary
|
||||||
|
_log("SUMMARY", f"Total: {len(all_pages)} pages across {len(categories)} categories")
|
||||||
|
_log("SUMMARY", f"Tokens: {total_prompt_tokens} prompt + {total_completion_tokens} completion = {total_prompt_tokens + total_completion_tokens} total")
|
||||||
|
_log("SUMMARY", f"Duration: {total_duration_ms}ms ({total_duration_ms / 1000:.1f}s)")
|
||||||
|
|
||||||
|
return all_pages, exit_code
|
||||||
|
|
||||||
|
|
||||||
|
# ── Compose: merge new moments into existing page ──────────────────────────
|
||||||
|
|
||||||
|
def _count_page_words(page_dict: dict) -> int:
|
||||||
|
"""Count total words in a page's body sections."""
|
||||||
|
return sum(
|
||||||
|
len(s.get("content", "").split())
|
||||||
|
+ sum(len(sub.get("content", "").split()) for sub in s.get("subsections", []))
|
||||||
|
for s in page_dict.get("body_sections", [])
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_compose_prompt(
|
||||||
|
existing_page: dict,
|
||||||
|
existing_moments: list[tuple[MockKeyMoment, dict]],
|
||||||
|
new_moments: list[tuple[MockKeyMoment, dict]],
|
||||||
|
creator_name: str,
|
||||||
|
) -> str:
|
||||||
|
"""Build the user prompt for composition (merging new moments into an existing page).
|
||||||
|
|
||||||
|
Existing moments keep indices [0]-[N-1].
|
||||||
|
New moments get indices [N]-[N+M-1].
|
||||||
|
Uses build_moments_text() for formatting, with index offsets applied for new moments.
|
||||||
|
"""
|
||||||
|
category = existing_page.get("topic_category", "Uncategorized")
|
||||||
|
|
||||||
|
# Format existing moments [0]-[N-1]
|
||||||
|
existing_text, _ = build_moments_text(existing_moments, category)
|
||||||
|
|
||||||
|
# Format new moments with offset indices [N]-[N+M-1]
|
||||||
|
n = len(existing_moments)
|
||||||
|
new_lines = []
|
||||||
|
for i, (m, cls_info) in enumerate(new_moments):
|
||||||
|
tags = cls_info.get("topic_tags", [])
|
||||||
|
new_lines.append(
|
||||||
|
f"[{n + i}] Title: {m.title}\n"
|
||||||
|
f" Summary: {m.summary}\n"
|
||||||
|
f" Content type: {m.content_type.value}\n"
|
||||||
|
f" Time: {m.start_time:.1f}s - {m.end_time:.1f}s\n"
|
||||||
|
f" Plugins: {', '.join(m.plugins) if m.plugins else 'none'}\n"
|
||||||
|
f" Category: {category}\n"
|
||||||
|
f" Tags: {', '.join(tags) if tags else 'none'}\n"
|
||||||
|
f" Transcript excerpt: {(m.raw_transcript or '')[:300]}"
|
||||||
|
)
|
||||||
|
new_text = "\n\n".join(new_lines)
|
||||||
|
|
||||||
|
page_json = json.dumps(existing_page, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
|
return (
|
||||||
|
f"<existing_page>\n{page_json}\n</existing_page>\n"
|
||||||
|
f"<existing_moments>\n{existing_text}\n</existing_moments>\n"
|
||||||
|
f"<new_moments>\n{new_text}\n</new_moments>\n"
|
||||||
|
f"<creator>{creator_name}</creator>"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def run_compose(
|
||||||
|
existing_page_path: str,
|
||||||
|
existing_fixture_path: str,
|
||||||
|
new_fixture_path: str,
|
||||||
|
prompt_path: str,
|
||||||
|
category_filter: str | None = None,
|
||||||
|
model_override: str | None = None,
|
||||||
|
modality: str | None = None,
|
||||||
|
) -> tuple[list[dict], int]:
|
||||||
|
"""Run composition: merge new fixture moments into an existing page.
|
||||||
|
|
||||||
|
Returns (pages, exit_code) — same shape as run_synthesis().
|
||||||
|
"""
|
||||||
|
# Load existing page JSON
|
||||||
|
existing_page_file = Path(existing_page_path)
|
||||||
|
if not existing_page_file.exists():
|
||||||
|
_log("ERROR", f"Existing page not found: {existing_page_path}", level="ERROR")
|
||||||
|
return [], 3
|
||||||
|
|
||||||
|
try:
|
||||||
|
existing_raw = json.loads(existing_page_file.read_text(encoding="utf-8"))
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
_log("ERROR", f"Invalid JSON in existing page: {exc}", level="ERROR")
|
||||||
|
return [], 3
|
||||||
|
|
||||||
|
# The existing page file might be a harness output (with .pages[]) or a raw SynthesizedPage
|
||||||
|
if "pages" in existing_raw and isinstance(existing_raw["pages"], list):
|
||||||
|
page_dicts = existing_raw["pages"]
|
||||||
|
_log("COMPOSE", f"Loaded harness output with {len(page_dicts)} pages")
|
||||||
|
elif "title" in existing_raw and "body_sections" in existing_raw:
|
||||||
|
page_dicts = [existing_raw]
|
||||||
|
_log("COMPOSE", "Loaded single SynthesizedPage")
|
||||||
|
else:
|
||||||
|
_log("ERROR", "Existing page JSON must be a SynthesizedPage or harness output with 'pages' key", level="ERROR")
|
||||||
|
return [], 3
|
||||||
|
|
||||||
|
# Validate each page against SynthesizedPage
|
||||||
|
validated_pages: list[dict] = []
|
||||||
|
for pd in page_dicts:
|
||||||
|
try:
|
||||||
|
SynthesizedPage.model_validate(pd)
|
||||||
|
validated_pages.append(pd)
|
||||||
|
except ValidationError as exc:
|
||||||
|
_log("WARN", f"Skipping invalid page '{pd.get('title', '?')}': {exc}", level="WARN")
|
||||||
|
|
||||||
|
if not validated_pages:
|
||||||
|
_log("ERROR", "No valid SynthesizedPage found in existing page file", level="ERROR")
|
||||||
|
return [], 3
|
||||||
|
|
||||||
|
# Apply category filter
|
||||||
|
if category_filter:
|
||||||
|
validated_pages = [p for p in validated_pages if p.get("topic_category") == category_filter]
|
||||||
|
if not validated_pages:
|
||||||
|
_log("ERROR", f"No pages match category '{category_filter}'", level="ERROR")
|
||||||
|
return [], 3
|
||||||
|
|
||||||
|
# Load existing moments fixture (the original moments the page was built from)
|
||||||
|
try:
|
||||||
|
existing_fixture = load_fixture(existing_fixture_path)
|
||||||
|
except (FileNotFoundError, ValueError, json.JSONDecodeError) as exc:
|
||||||
|
_log("ERROR", f"Existing fixture error: {exc}", level="ERROR")
|
||||||
|
return [], 3
|
||||||
|
|
||||||
|
# Load new moments fixture
|
||||||
|
try:
|
||||||
|
new_fixture = load_fixture(new_fixture_path)
|
||||||
|
except (FileNotFoundError, ValueError, json.JSONDecodeError) as exc:
|
||||||
|
_log("ERROR", f"New fixture error: {exc}", level="ERROR")
|
||||||
|
return [], 3
|
||||||
|
|
||||||
|
# Load prompt
|
||||||
|
prompt_file = Path(prompt_path)
|
||||||
|
if not prompt_file.exists():
|
||||||
|
_log("ERROR", f"Prompt file not found: {prompt_path}", level="ERROR")
|
||||||
|
return [], 3
|
||||||
|
system_prompt = prompt_file.read_text(encoding="utf-8")
|
||||||
|
_log("PROMPT", f"Loading compose prompt: {prompt_path} ({len(system_prompt)} chars)")
|
||||||
|
|
||||||
|
# Setup LLM
|
||||||
|
settings = get_settings()
|
||||||
|
llm = LLMClient(settings)
|
||||||
|
stage_model = model_override or settings.llm_stage5_model or settings.llm_model
|
||||||
|
stage_modality = modality or settings.llm_stage5_modality or "thinking"
|
||||||
|
hard_limit = settings.llm_max_tokens_hard_limit
|
||||||
|
_log("LLM", f"Model: {stage_model}, modality: {stage_modality}, hard_limit: {hard_limit}")
|
||||||
|
|
||||||
|
all_pages: list[dict] = []
|
||||||
|
exit_code = 0
|
||||||
|
|
||||||
|
for page_idx, existing_page in enumerate(validated_pages, 1):
|
||||||
|
page_category = existing_page.get("topic_category", "Uncategorized")
|
||||||
|
page_title = existing_page.get("title", "Untitled")
|
||||||
|
_log("COMPOSE", f"Page {page_idx}/{len(validated_pages)}: '{page_title}' ({page_category})")
|
||||||
|
|
||||||
|
# Get existing moments for this page's category
|
||||||
|
existing_moments = existing_fixture.groups.get(page_category, [])
|
||||||
|
if not existing_moments:
|
||||||
|
_log("WARN", f" No existing moments found for category '{page_category}' — skipping", level="WARN")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get new moments for this page's category
|
||||||
|
new_moments = new_fixture.groups.get(page_category, [])
|
||||||
|
if not new_moments:
|
||||||
|
_log("WARN", f" No new moments for category '{page_category}' — nothing to compose", level="WARN")
|
||||||
|
all_pages.append(existing_page)
|
||||||
|
continue
|
||||||
|
|
||||||
|
n_existing = len(existing_moments)
|
||||||
|
n_new = len(new_moments)
|
||||||
|
total_moments = n_existing + n_new
|
||||||
|
|
||||||
|
# Before metrics
|
||||||
|
before_words = _count_page_words(existing_page)
|
||||||
|
before_sections = len(existing_page.get("body_sections", []))
|
||||||
|
|
||||||
|
_log(
|
||||||
|
"COMPOSE",
|
||||||
|
f" Existing: {n_existing} moments, {before_sections} sections, {before_words} words | "
|
||||||
|
f"New: {n_new} moments | Total citation space: [0]-[{total_moments - 1}]",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build compose prompt
|
||||||
|
user_prompt = build_compose_prompt(
|
||||||
|
existing_page=existing_page,
|
||||||
|
existing_moments=existing_moments,
|
||||||
|
new_moments=new_moments,
|
||||||
|
creator_name=existing_fixture.creator_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
estimated_tokens = estimate_max_tokens(
|
||||||
|
system_prompt, user_prompt,
|
||||||
|
stage="stage5_synthesis",
|
||||||
|
hard_limit=hard_limit,
|
||||||
|
)
|
||||||
|
_log("COMPOSE", f" Prompt built: {len(user_prompt)} chars, max_tokens={estimated_tokens}")
|
||||||
|
|
||||||
|
# Call LLM
|
||||||
|
call_start = time.monotonic()
|
||||||
|
_log("LLM", f" Calling: model={stage_model}, max_tokens={estimated_tokens}, modality={stage_modality}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
raw = llm.complete(
|
||||||
|
system_prompt,
|
||||||
|
user_prompt,
|
||||||
|
response_model=SynthesisResult,
|
||||||
|
modality=stage_modality,
|
||||||
|
model_override=stage_model,
|
||||||
|
max_tokens=estimated_tokens,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
_log("ERROR", f" LLM call failed: {exc}", level="ERROR")
|
||||||
|
exit_code = 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
call_duration_ms = int((time.monotonic() - call_start) * 1000)
|
||||||
|
prompt_tokens = getattr(raw, "prompt_tokens", None) or 0
|
||||||
|
completion_tokens = getattr(raw, "completion_tokens", None) or 0
|
||||||
|
finish_reason = getattr(raw, "finish_reason", "unknown")
|
||||||
|
|
||||||
|
_log(
|
||||||
|
"LLM",
|
||||||
|
f" Response: {prompt_tokens} prompt + {completion_tokens} completion tokens, "
|
||||||
|
f"{call_duration_ms}ms, finish_reason={finish_reason}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if finish_reason == "length":
|
||||||
|
_log("WARN", " finish_reason=length — output likely truncated!", level="WARN")
|
||||||
|
|
||||||
|
# Parse response
|
||||||
|
try:
|
||||||
|
result = SynthesisResult.model_validate_json(str(raw))
|
||||||
|
except (ValidationError, json.JSONDecodeError) as exc:
|
||||||
|
_log("ERROR", f" Parse failed: {exc}", level="ERROR")
|
||||||
|
_log("ERROR", f" Raw response (first 2000 chars): {str(raw)[:2000]}", level="ERROR")
|
||||||
|
exit_code = 2
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Log compose-specific metrics
|
||||||
|
for page in result.pages:
|
||||||
|
page_dict = page.model_dump()
|
||||||
|
after_words = _count_page_words(page_dict)
|
||||||
|
after_sections = len(page.body_sections or [])
|
||||||
|
|
||||||
|
# Identify new sections (headings not in the original)
|
||||||
|
existing_headings = {s.get("heading", "") for s in existing_page.get("body_sections", [])}
|
||||||
|
new_section_headings = [
|
||||||
|
s.heading for s in (page.body_sections or []) if s.heading not in existing_headings
|
||||||
|
]
|
||||||
|
|
||||||
|
_log(
|
||||||
|
"COMPOSE",
|
||||||
|
f" Result: '{page.title}' — "
|
||||||
|
f"words {before_words}→{after_words} ({after_words - before_words:+d}), "
|
||||||
|
f"sections {before_sections}→{after_sections} ({after_sections - before_sections:+d})"
|
||||||
|
+ (f", new sections: {new_section_headings}" if new_section_headings else ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Citation validation with unified moment count
|
||||||
|
cit = validate_citations(page.body_sections or [], total_moments)
|
||||||
|
_log(
|
||||||
|
"CITE",
|
||||||
|
f" Citations: {cit['total_citations']}/{total_moments} moments cited "
|
||||||
|
f"({cit['coverage_pct']}% coverage)"
|
||||||
|
+ (f", invalid indices: {cit['invalid_indices']}" if cit['invalid_indices'] else "")
|
||||||
|
+ (f", uncited: {cit['uncited_moments']}" if cit['uncited_moments'] else ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
all_pages.append(page_dict)
|
||||||
|
|
||||||
|
_log("SUMMARY", f"Compose complete: {len(all_pages)} pages")
|
||||||
|
return all_pages, exit_code
|
||||||
|
|
||||||
|
|
||||||
|
# ── Promote: deploy a prompt to production ─────────────────────────────────
|
||||||
|
|
||||||
|
_STAGE_PROMPT_MAP = {
|
||||||
|
2: "stage2_segmentation.txt",
|
||||||
|
3: "stage3_extraction.txt",
|
||||||
|
4: "stage4_classification.txt",
|
||||||
|
5: "stage5_synthesis.txt",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def promote_prompt(prompt_path: str, stage: int, reason: str, commit: bool = False) -> int:
|
||||||
|
"""Copy a winning prompt to the canonical path and create a backup.
|
||||||
|
|
||||||
|
The worker reads prompts from disk at runtime — no restart needed.
|
||||||
|
"""
|
||||||
|
import hashlib
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
if stage not in _STAGE_PROMPT_MAP:
|
||||||
|
_log("ERROR", f"Invalid stage {stage}. Valid: {sorted(_STAGE_PROMPT_MAP)}", level="ERROR")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
settings = get_settings()
|
||||||
|
template_name = _STAGE_PROMPT_MAP[stage]
|
||||||
|
canonical = Path(settings.prompts_path) / template_name
|
||||||
|
source = Path(prompt_path)
|
||||||
|
|
||||||
|
if not source.exists():
|
||||||
|
_log("ERROR", f"Source prompt not found: {prompt_path}", level="ERROR")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
new_prompt = source.read_text(encoding="utf-8")
|
||||||
|
new_hash = hashlib.sha256(new_prompt.encode()).hexdigest()[:12]
|
||||||
|
|
||||||
|
# Backup current prompt
|
||||||
|
old_prompt = ""
|
||||||
|
old_hash = "none"
|
||||||
|
if canonical.exists():
|
||||||
|
old_prompt = canonical.read_text(encoding="utf-8")
|
||||||
|
old_hash = hashlib.sha256(old_prompt.encode()).hexdigest()[:12]
|
||||||
|
|
||||||
|
if old_prompt.strip() == new_prompt.strip():
|
||||||
|
_log("PROMOTE", "No change — new prompt is identical to current prompt")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
archive_dir = Path(settings.prompts_path) / "archive"
|
||||||
|
archive_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
ts = time.strftime("%Y%m%d_%H%M%S", time.gmtime())
|
||||||
|
backup = archive_dir / f"{template_name.replace('.txt', '')}_{ts}.txt"
|
||||||
|
shutil.copy2(canonical, backup)
|
||||||
|
_log("PROMOTE", f"Backed up current prompt: {old_hash} -> {backup}")
|
||||||
|
|
||||||
|
# Write new prompt
|
||||||
|
canonical.write_text(new_prompt, encoding="utf-8")
|
||||||
|
|
||||||
|
old_lines = old_prompt.strip().splitlines()
|
||||||
|
new_lines = new_prompt.strip().splitlines()
|
||||||
|
_log("PROMOTE", f"Installed new prompt: {new_hash} ({len(new_prompt)} chars, {len(new_lines)} lines)")
|
||||||
|
_log("PROMOTE", f"Previous: {old_hash} ({len(old_prompt)} chars, {len(old_lines)} lines)")
|
||||||
|
_log("PROMOTE", f"Reason: {reason}")
|
||||||
|
_log("PROMOTE", "Worker reads prompts from disk at runtime — no restart needed")
|
||||||
|
|
||||||
|
if commit:
|
||||||
|
import subprocess
|
||||||
|
try:
|
||||||
|
subprocess.run(
|
||||||
|
["git", "add", str(canonical)],
|
||||||
|
cwd=str(canonical.parent.parent),
|
||||||
|
check=True, capture_output=True,
|
||||||
|
)
|
||||||
|
msg = f"prompt: promote stage{stage} — {reason}"
|
||||||
|
subprocess.run(
|
||||||
|
["git", "commit", "-m", msg],
|
||||||
|
cwd=str(canonical.parent.parent),
|
||||||
|
check=True, capture_output=True,
|
||||||
|
)
|
||||||
|
_log("PROMOTE", f"Git commit created: {msg}")
|
||||||
|
except subprocess.CalledProcessError as exc:
|
||||||
|
_log("PROMOTE", f"Git commit failed: {exc}", level="WARN")
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
# ── CLI ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
prog="pipeline.test_harness",
|
||||||
|
description="Offline prompt test harness for Chrysopedia synthesis",
|
||||||
|
)
|
||||||
|
sub = parser.add_subparsers(dest="command")
|
||||||
|
|
||||||
|
# -- run subcommand (default behavior) --
|
||||||
|
run_parser = sub.add_parser("run", help="Run synthesis against a fixture")
|
||||||
|
run_parser.add_argument("--fixture", "-f", type=str, required=True, help="Fixture JSON file")
|
||||||
|
run_parser.add_argument("--prompt", "-p", type=str, default=None, help="Prompt file (default: stage5_synthesis.txt)")
|
||||||
|
run_parser.add_argument("--output", "-o", type=str, default=None, help="Output file path")
|
||||||
|
run_parser.add_argument("--category", "-c", type=str, default=None, help="Filter to a specific category")
|
||||||
|
run_parser.add_argument("--model", type=str, default=None, help="Override LLM model")
|
||||||
|
run_parser.add_argument("--modality", type=str, default=None, choices=["chat", "thinking"])
|
||||||
|
|
||||||
|
# -- promote subcommand --
|
||||||
|
promo_parser = sub.add_parser("promote", help="Deploy a winning prompt to production")
|
||||||
|
promo_parser.add_argument("--prompt", "-p", type=str, required=True, help="Path to the winning prompt file")
|
||||||
|
promo_parser.add_argument("--stage", "-s", type=int, default=5, help="Stage number (default: 5)")
|
||||||
|
promo_parser.add_argument("--reason", "-r", type=str, required=True, help="Why this prompt is being promoted")
|
||||||
|
promo_parser.add_argument("--commit", action="store_true", help="Also create a git commit")
|
||||||
|
|
||||||
|
# -- compose subcommand --
|
||||||
|
compose_parser = sub.add_parser("compose", help="Merge new moments into an existing page")
|
||||||
|
compose_parser.add_argument("--existing-page", type=str, required=True, help="Existing page JSON (harness output or raw SynthesizedPage)")
|
||||||
|
compose_parser.add_argument("--fixture", "-f", type=str, required=True, help="New moments fixture JSON")
|
||||||
|
compose_parser.add_argument("--existing-fixture", type=str, required=True, help="Original moments fixture JSON (for citation context)")
|
||||||
|
compose_parser.add_argument("--prompt", "-p", type=str, default=None, help="Compose prompt file (default: stage5_compose.txt)")
|
||||||
|
compose_parser.add_argument("--output", "-o", type=str, default=None, help="Output file path")
|
||||||
|
compose_parser.add_argument("--category", "-c", type=str, default=None, help="Filter to a specific category")
|
||||||
|
compose_parser.add_argument("--model", type=str, default=None, help="Override LLM model")
|
||||||
|
compose_parser.add_argument("--modality", type=str, default=None, choices=["chat", "thinking"])
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# If no subcommand, check for --fixture for backward compat
|
||||||
|
if args.command is None:
|
||||||
|
# Support running without subcommand for backward compat
|
||||||
|
parser.print_help()
|
||||||
|
return 1
|
||||||
|
|
||||||
|
if args.command == "promote":
|
||||||
|
return promote_prompt(args.prompt, args.stage, args.reason, args.commit)
|
||||||
|
|
||||||
|
if args.command == "compose":
|
||||||
|
# Resolve default compose prompt
|
||||||
|
prompt_path = args.prompt
|
||||||
|
if prompt_path is None:
|
||||||
|
settings = get_settings()
|
||||||
|
prompt_path = str(Path(settings.prompts_path) / "stage5_compose.txt")
|
||||||
|
|
||||||
|
overall_start = time.monotonic()
|
||||||
|
pages, exit_code = run_compose(
|
||||||
|
existing_page_path=args.existing_page,
|
||||||
|
existing_fixture_path=args.existing_fixture,
|
||||||
|
new_fixture_path=args.fixture,
|
||||||
|
prompt_path=prompt_path,
|
||||||
|
category_filter=args.category,
|
||||||
|
model_override=args.model,
|
||||||
|
modality=args.modality,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not pages and exit_code != 0:
|
||||||
|
return exit_code
|
||||||
|
|
||||||
|
output = {
|
||||||
|
"existing_page_source": args.existing_page,
|
||||||
|
"existing_fixture_source": args.existing_fixture,
|
||||||
|
"new_fixture_source": args.fixture,
|
||||||
|
"prompt_source": prompt_path,
|
||||||
|
"category_filter": args.category,
|
||||||
|
"pages": pages,
|
||||||
|
"metadata": {
|
||||||
|
"page_count": len(pages),
|
||||||
|
"total_words": sum(_count_page_words(p) for p in pages),
|
||||||
|
"elapsed_seconds": round(time.monotonic() - overall_start, 1),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
output_json = json.dumps(output, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
|
if args.output:
|
||||||
|
Path(args.output).parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
Path(args.output).write_text(output_json, encoding="utf-8")
|
||||||
|
_log("OUTPUT", f"Written to: {args.output} ({len(output_json) / 1024:.1f} KB)")
|
||||||
|
else:
|
||||||
|
print(output_json)
|
||||||
|
_log("OUTPUT", f"Printed to stdout ({len(output_json) / 1024:.1f} KB)")
|
||||||
|
|
||||||
|
total_elapsed = time.monotonic() - overall_start
|
||||||
|
_log("DONE", f"Compose completed in {total_elapsed:.1f}s (exit_code={exit_code})")
|
||||||
|
return exit_code
|
||||||
|
|
||||||
|
# -- run command --
|
||||||
|
prompt_path = args.prompt
|
||||||
|
if prompt_path is None:
|
||||||
|
settings = get_settings()
|
||||||
|
prompt_path = str(Path(settings.prompts_path) / "stage5_synthesis.txt")
|
||||||
|
|
||||||
|
overall_start = time.monotonic()
|
||||||
|
try:
|
||||||
|
fixture = load_fixture(args.fixture)
|
||||||
|
except (FileNotFoundError, ValueError, json.JSONDecodeError) as exc:
|
||||||
|
_log("ERROR", f"Fixture error: {exc}", level="ERROR")
|
||||||
|
return 3
|
||||||
|
|
||||||
|
pages, exit_code = run_synthesis(
|
||||||
|
fixture=fixture,
|
||||||
|
prompt_path=prompt_path,
|
||||||
|
category_filter=args.category,
|
||||||
|
model_override=args.model,
|
||||||
|
modality=args.modality,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not pages and exit_code != 0:
|
||||||
|
return exit_code
|
||||||
|
|
||||||
|
output = {
|
||||||
|
"fixture_source": args.fixture,
|
||||||
|
"prompt_source": prompt_path,
|
||||||
|
"creator_name": fixture.creator_name,
|
||||||
|
"video_id": fixture.video_id,
|
||||||
|
"category_filter": args.category,
|
||||||
|
"pages": pages,
|
||||||
|
"metadata": {
|
||||||
|
"page_count": len(pages),
|
||||||
|
"total_words": sum(
|
||||||
|
sum(
|
||||||
|
len(s.get("content", "").split())
|
||||||
|
+ sum(len(sub.get("content", "").split()) for sub in s.get("subsections", []))
|
||||||
|
for s in p.get("body_sections", [])
|
||||||
|
)
|
||||||
|
for p in pages
|
||||||
|
),
|
||||||
|
"elapsed_seconds": round(time.monotonic() - overall_start, 1),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
output_json = json.dumps(output, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
|
if args.output:
|
||||||
|
Path(args.output).parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
Path(args.output).write_text(output_json, encoding="utf-8")
|
||||||
|
_log("OUTPUT", f"Written to: {args.output} ({len(output_json) / 1024:.1f} KB)")
|
||||||
|
else:
|
||||||
|
print(output_json)
|
||||||
|
_log("OUTPUT", f"Printed to stdout ({len(output_json) / 1024:.1f} KB)")
|
||||||
|
|
||||||
|
total_elapsed = time.monotonic() - overall_start
|
||||||
|
_log("DONE", f"Completed in {total_elapsed:.1f}s (exit_code={exit_code})")
|
||||||
|
|
||||||
|
return exit_code
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
389
backend/pipeline/test_harness_compose.py
Normal file
389
backend/pipeline/test_harness_compose.py
Normal file
|
|
@ -0,0 +1,389 @@
|
||||||
|
"""Tests for compose-mode prompt building and validation.
|
||||||
|
|
||||||
|
Covers prompt construction, citation re-indexing math, category filtering,
|
||||||
|
and edge cases — no LLM calls required.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from pipeline.citation_utils import validate_citations
|
||||||
|
from pipeline.schemas import BodySection, BodySubSection, SynthesizedPage
|
||||||
|
from pipeline.test_harness import (
|
||||||
|
MockKeyMoment,
|
||||||
|
_MockContentType,
|
||||||
|
build_compose_prompt,
|
||||||
|
build_moments_text,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fixtures / helpers ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _moment(
|
||||||
|
title: str = "Test Moment",
|
||||||
|
summary: str = "A moment.",
|
||||||
|
content_type: str = "technique_demo",
|
||||||
|
start_time: float = 0.0,
|
||||||
|
end_time: float = 10.0,
|
||||||
|
plugins: list[str] | None = None,
|
||||||
|
raw_transcript: str | None = "Some transcript text",
|
||||||
|
) -> MockKeyMoment:
|
||||||
|
return MockKeyMoment(
|
||||||
|
title=title,
|
||||||
|
summary=summary,
|
||||||
|
content_type=_MockContentType(content_type),
|
||||||
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
|
plugins=plugins or [],
|
||||||
|
raw_transcript=raw_transcript or "",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _cls_info(
|
||||||
|
category: str = "Sound Design",
|
||||||
|
tags: list[str] | None = None,
|
||||||
|
) -> dict:
|
||||||
|
return {
|
||||||
|
"topic_category": category,
|
||||||
|
"topic_tags": tags or ["reverb", "delay"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_page(
|
||||||
|
title: str = "Reverb Techniques",
|
||||||
|
slug: str = "reverb-techniques",
|
||||||
|
category: str = "Sound Design",
|
||||||
|
sections: list[BodySection] | None = None,
|
||||||
|
moment_indices: list[int] | None = None,
|
||||||
|
) -> dict:
|
||||||
|
"""Build a SynthesizedPage dict (as it would appear in harness output)."""
|
||||||
|
if sections is None:
|
||||||
|
sections = [
|
||||||
|
BodySection(
|
||||||
|
heading="Overview",
|
||||||
|
content="Reverb is essential [0]. Basics of space [1].",
|
||||||
|
subsections=[
|
||||||
|
BodySubSection(
|
||||||
|
heading="Room Types",
|
||||||
|
content="Rooms vary in character [2].",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
page = SynthesizedPage(
|
||||||
|
title=title,
|
||||||
|
slug=slug,
|
||||||
|
topic_category=category,
|
||||||
|
summary="A page about reverb.",
|
||||||
|
body_sections=sections,
|
||||||
|
moment_indices=moment_indices or [0, 1, 2],
|
||||||
|
)
|
||||||
|
return json.loads(page.model_dump_json())
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestBuildComposePrompt ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildComposePrompt:
|
||||||
|
"""Verify prompt construction for compose mode."""
|
||||||
|
|
||||||
|
def test_prompt_contains_xml_tags(self):
|
||||||
|
"""Existing page + 3 old + 2 new → prompt has all required XML tags."""
|
||||||
|
existing_moments = [(_moment(title=f"Old {i}"), _cls_info()) for i in range(3)]
|
||||||
|
new_moments = [(_moment(title=f"New {i}"), _cls_info()) for i in range(2)]
|
||||||
|
page = _make_page(moment_indices=[0, 1, 2])
|
||||||
|
|
||||||
|
prompt = build_compose_prompt(
|
||||||
|
existing_page=page,
|
||||||
|
existing_moments=existing_moments,
|
||||||
|
new_moments=new_moments,
|
||||||
|
creator_name="DJ Test",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "<existing_page>" in prompt
|
||||||
|
assert "</existing_page>" in prompt
|
||||||
|
assert "<existing_moments>" in prompt
|
||||||
|
assert "</existing_moments>" in prompt
|
||||||
|
assert "<new_moments>" in prompt
|
||||||
|
assert "</new_moments>" in prompt
|
||||||
|
assert "<creator>" in prompt
|
||||||
|
assert "</creator>" in prompt
|
||||||
|
|
||||||
|
def test_old_moments_indexed_0_to_n(self):
|
||||||
|
"""3 old moments are indexed [0], [1], [2]."""
|
||||||
|
existing_moments = [(_moment(title=f"Old {i}"), _cls_info()) for i in range(3)]
|
||||||
|
new_moments = [(_moment(title=f"New {i}"), _cls_info()) for i in range(2)]
|
||||||
|
page = _make_page()
|
||||||
|
|
||||||
|
prompt = build_compose_prompt(
|
||||||
|
existing_page=page,
|
||||||
|
existing_moments=existing_moments,
|
||||||
|
new_moments=new_moments,
|
||||||
|
creator_name="DJ Test",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Old moments section uses [0], [1], [2]
|
||||||
|
existing_block = prompt.split("<existing_moments>")[1].split("</existing_moments>")[0]
|
||||||
|
assert "[0] Title:" in existing_block
|
||||||
|
assert "[1] Title:" in existing_block
|
||||||
|
assert "[2] Title:" in existing_block
|
||||||
|
|
||||||
|
def test_new_moments_indexed_n_to_n_plus_m(self):
|
||||||
|
"""2 new moments after 3 old → indexed [3] and [4]."""
|
||||||
|
existing_moments = [(_moment(title=f"Old {i}"), _cls_info()) for i in range(3)]
|
||||||
|
new_moments = [(_moment(title=f"New {i}"), _cls_info()) for i in range(2)]
|
||||||
|
page = _make_page()
|
||||||
|
|
||||||
|
prompt = build_compose_prompt(
|
||||||
|
existing_page=page,
|
||||||
|
existing_moments=existing_moments,
|
||||||
|
new_moments=new_moments,
|
||||||
|
creator_name="DJ Test",
|
||||||
|
)
|
||||||
|
|
||||||
|
new_block = prompt.split("<new_moments>")[1].split("</new_moments>")[0]
|
||||||
|
assert "[3] Title:" in new_block
|
||||||
|
assert "[4] Title:" in new_block
|
||||||
|
# Should NOT contain [0]-[2] in new moments block
|
||||||
|
assert "[0] Title:" not in new_block
|
||||||
|
|
||||||
|
def test_creator_name_in_prompt(self):
|
||||||
|
page = _make_page()
|
||||||
|
prompt = build_compose_prompt(
|
||||||
|
existing_page=page,
|
||||||
|
existing_moments=[(_moment(), _cls_info())],
|
||||||
|
new_moments=[(_moment(), _cls_info())],
|
||||||
|
creator_name="Keota",
|
||||||
|
)
|
||||||
|
assert "<creator>Keota</creator>" in prompt
|
||||||
|
|
||||||
|
def test_existing_page_json_valid(self):
|
||||||
|
"""Existing page JSON in the prompt is valid and parseable."""
|
||||||
|
page = _make_page()
|
||||||
|
prompt = build_compose_prompt(
|
||||||
|
existing_page=page,
|
||||||
|
existing_moments=[(_moment(), _cls_info())],
|
||||||
|
new_moments=[(_moment(), _cls_info())],
|
||||||
|
creator_name="Test",
|
||||||
|
)
|
||||||
|
page_block = prompt.split("<existing_page>")[1].split("</existing_page>")[0].strip()
|
||||||
|
parsed = json.loads(page_block)
|
||||||
|
assert parsed["title"] == "Reverb Techniques"
|
||||||
|
assert parsed["slug"] == "reverb-techniques"
|
||||||
|
|
||||||
|
def test_moment_format_matches_build_moments_text(self):
|
||||||
|
"""Existing moments format matches build_moments_text output."""
|
||||||
|
moments = [
|
||||||
|
(_moment(title="Delay Basics", plugins=["Valhalla"]), _cls_info(tags=["delay"])),
|
||||||
|
]
|
||||||
|
page = _make_page()
|
||||||
|
|
||||||
|
prompt = build_compose_prompt(
|
||||||
|
existing_page=page,
|
||||||
|
existing_moments=moments,
|
||||||
|
new_moments=[(_moment(), _cls_info())],
|
||||||
|
creator_name="Test",
|
||||||
|
)
|
||||||
|
|
||||||
|
# build_moments_text produces the same format for existing moments
|
||||||
|
expected_text, _ = build_moments_text(moments, "Sound Design")
|
||||||
|
existing_block = prompt.split("<existing_moments>")[1].split("</existing_moments>")[0].strip()
|
||||||
|
assert expected_text.strip() == existing_block
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestCitationReindexing ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestCitationReindexing:
|
||||||
|
"""Verify citation index math for compose mode."""
|
||||||
|
|
||||||
|
def test_5_old_3_new_valid_range(self):
|
||||||
|
"""5 old + 3 new → valid range is [0]-[7], moment_count=8."""
|
||||||
|
# Build content that references all 8 indices
|
||||||
|
sections = [
|
||||||
|
BodySection(
|
||||||
|
heading="Section",
|
||||||
|
content="Refs [0] [1] [2] [3] [4] [5] [6] [7].",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
result = validate_citations(sections, moment_count=8)
|
||||||
|
assert result["valid"] is True
|
||||||
|
assert result["total_citations"] == 8
|
||||||
|
assert result["invalid_indices"] == []
|
||||||
|
|
||||||
|
def test_accepts_citations_in_valid_range(self):
|
||||||
|
"""validate_citations with moment_count=8 accepts [0]-[7]."""
|
||||||
|
sections = [
|
||||||
|
BodySection(
|
||||||
|
heading="S1",
|
||||||
|
content="See [0] and [3] and [7].",
|
||||||
|
subsections=[
|
||||||
|
BodySubSection(heading="Sub", content="Also [1] [2] [4] [5] [6].")
|
||||||
|
],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
result = validate_citations(sections, moment_count=8)
|
||||||
|
assert result["valid"] is True
|
||||||
|
assert result["invalid_indices"] == []
|
||||||
|
|
||||||
|
def test_rejects_out_of_range_citation(self):
|
||||||
|
"""validate_citations with moment_count=8 rejects [8]."""
|
||||||
|
sections = [
|
||||||
|
BodySection(
|
||||||
|
heading="S1",
|
||||||
|
content="Bad ref [8] and valid [0].",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
result = validate_citations(sections, moment_count=8)
|
||||||
|
assert result["valid"] is False
|
||||||
|
assert 8 in result["invalid_indices"]
|
||||||
|
|
||||||
|
def test_compose_offset_arithmetic(self):
|
||||||
|
"""Verify the offset math: N existing → new moments start at [N]."""
|
||||||
|
n_existing = 5
|
||||||
|
n_new = 3
|
||||||
|
existing = [(_moment(title=f"E{i}"), _cls_info()) for i in range(n_existing)]
|
||||||
|
new = [(_moment(title=f"N{i}"), _cls_info()) for i in range(n_new)]
|
||||||
|
page = _make_page()
|
||||||
|
|
||||||
|
prompt = build_compose_prompt(
|
||||||
|
existing_page=page,
|
||||||
|
existing_moments=existing,
|
||||||
|
new_moments=new,
|
||||||
|
creator_name="Test",
|
||||||
|
)
|
||||||
|
|
||||||
|
new_block = prompt.split("<new_moments>")[1].split("</new_moments>")[0]
|
||||||
|
# First new moment should be [5], last should be [7]
|
||||||
|
assert "[5] Title:" in new_block
|
||||||
|
assert "[6] Title:" in new_block
|
||||||
|
assert "[7] Title:" in new_block
|
||||||
|
assert "[4] Title:" not in new_block # last old moment, not in new block
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestCategoryFiltering ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestCategoryFiltering:
|
||||||
|
"""Verify that compose filters moments by category to match existing page."""
|
||||||
|
|
||||||
|
def test_only_matching_category_moments_used(self):
|
||||||
|
"""Moments from category B are excluded when composing a category A page."""
|
||||||
|
page = _make_page(category="Sound Design")
|
||||||
|
existing = [(_moment(title="E0"), _cls_info(category="Sound Design"))]
|
||||||
|
|
||||||
|
# Mix of matching and non-matching new moments
|
||||||
|
new_sound = [(_moment(title="New SD"), _cls_info(category="Sound Design"))]
|
||||||
|
new_mixing = [(_moment(title="New Mix"), _cls_info(category="Mixing"))]
|
||||||
|
|
||||||
|
# build_compose_prompt doesn't filter by category — that's run_compose's job.
|
||||||
|
# But we can verify the prompt only contains what we pass in.
|
||||||
|
prompt = build_compose_prompt(
|
||||||
|
existing_page=page,
|
||||||
|
existing_moments=existing,
|
||||||
|
new_moments=new_sound, # Only Sound Design moments
|
||||||
|
creator_name="Test",
|
||||||
|
)
|
||||||
|
|
||||||
|
new_block = prompt.split("<new_moments>")[1].split("</new_moments>")[0]
|
||||||
|
assert "New SD" in new_block
|
||||||
|
assert "New Mix" not in new_block
|
||||||
|
|
||||||
|
def test_category_from_page_used_in_moments_text(self):
|
||||||
|
"""The page's topic_category is used in the moment formatting."""
|
||||||
|
page = _make_page(category="Mixing")
|
||||||
|
existing = [(_moment(), _cls_info(category="Mixing"))]
|
||||||
|
new = [(_moment(), _cls_info(category="Mixing"))]
|
||||||
|
|
||||||
|
prompt = build_compose_prompt(
|
||||||
|
existing_page=page,
|
||||||
|
existing_moments=existing,
|
||||||
|
new_moments=new,
|
||||||
|
creator_name="Test",
|
||||||
|
)
|
||||||
|
|
||||||
|
# The category in the formatted moments comes from the page's topic_category
|
||||||
|
assert "Category: Mixing" in prompt
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestEdgeCases ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestEdgeCases:
|
||||||
|
"""Edge cases for compose prompt construction."""
|
||||||
|
|
||||||
|
def test_empty_new_moments(self):
|
||||||
|
"""Empty new moments → prompt still valid with empty new_moments block."""
|
||||||
|
page = _make_page()
|
||||||
|
existing = [(_moment(), _cls_info())]
|
||||||
|
|
||||||
|
prompt = build_compose_prompt(
|
||||||
|
existing_page=page,
|
||||||
|
existing_moments=existing,
|
||||||
|
new_moments=[],
|
||||||
|
creator_name="Test",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "<new_moments>" in prompt
|
||||||
|
assert "</new_moments>" in prompt
|
||||||
|
# Existing moments still present
|
||||||
|
assert "[0] Title:" in prompt
|
||||||
|
|
||||||
|
def test_single_new_moment_at_offset_n(self):
|
||||||
|
"""Single new moment after 2 existing → indexed [2]."""
|
||||||
|
existing = [(_moment(title=f"E{i}"), _cls_info()) for i in range(2)]
|
||||||
|
new = [(_moment(title="Single New"), _cls_info())]
|
||||||
|
page = _make_page()
|
||||||
|
|
||||||
|
prompt = build_compose_prompt(
|
||||||
|
existing_page=page,
|
||||||
|
existing_moments=existing,
|
||||||
|
new_moments=new,
|
||||||
|
creator_name="Test",
|
||||||
|
)
|
||||||
|
|
||||||
|
new_block = prompt.split("<new_moments>")[1].split("</new_moments>")[0]
|
||||||
|
assert "[2] Title: Single New" in new_block
|
||||||
|
|
||||||
|
def test_existing_page_no_subsections(self):
|
||||||
|
"""Page with sections but no subsections → handled correctly."""
|
||||||
|
sections = [
|
||||||
|
BodySection(heading="Flat Section", content="Content [0]."),
|
||||||
|
]
|
||||||
|
page = _make_page(sections=sections, moment_indices=[0])
|
||||||
|
existing = [(_moment(), _cls_info())]
|
||||||
|
new = [(_moment(title="New One"), _cls_info())]
|
||||||
|
|
||||||
|
prompt = build_compose_prompt(
|
||||||
|
existing_page=page,
|
||||||
|
existing_moments=existing,
|
||||||
|
new_moments=new,
|
||||||
|
creator_name="Test",
|
||||||
|
)
|
||||||
|
|
||||||
|
page_block = prompt.split("<existing_page>")[1].split("</existing_page>")[0].strip()
|
||||||
|
parsed = json.loads(page_block)
|
||||||
|
assert len(parsed["body_sections"]) == 1
|
||||||
|
assert parsed["body_sections"][0]["subsections"] == []
|
||||||
|
|
||||||
|
def test_large_offset_indices(self):
|
||||||
|
"""10 existing + 5 new → new moments indexed [10]-[14]."""
|
||||||
|
existing = [(_moment(title=f"E{i}"), _cls_info()) for i in range(10)]
|
||||||
|
new = [(_moment(title=f"N{i}"), _cls_info()) for i in range(5)]
|
||||||
|
page = _make_page()
|
||||||
|
|
||||||
|
prompt = build_compose_prompt(
|
||||||
|
existing_page=page,
|
||||||
|
existing_moments=existing,
|
||||||
|
new_moments=new,
|
||||||
|
creator_name="Test",
|
||||||
|
)
|
||||||
|
|
||||||
|
new_block = prompt.split("<new_moments>")[1].split("</new_moments>")[0]
|
||||||
|
assert "[10] Title:" in new_block
|
||||||
|
assert "[14] Title:" in new_block
|
||||||
|
assert "[9] Title:" not in new_block # last existing, not in new block
|
||||||
213
backend/pipeline/test_harness_v2_format.py
Normal file
213
backend/pipeline/test_harness_v2_format.py
Normal file
|
|
@ -0,0 +1,213 @@
|
||||||
|
"""Tests for test_harness compatibility with v2 body_sections format.
|
||||||
|
|
||||||
|
Validates that word-counting and citation integration work correctly
|
||||||
|
with the list[BodySection] structure (v2) instead of the old dict format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from pipeline.citation_utils import validate_citations
|
||||||
|
from pipeline.schemas import BodySection, BodySubSection, SynthesizedPage, SynthesisResult
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _make_page(
|
||||||
|
body_sections: list[BodySection],
|
||||||
|
moment_indices: list[int] | None = None,
|
||||||
|
title: str = "Test Page",
|
||||||
|
slug: str = "test-page",
|
||||||
|
) -> SynthesizedPage:
|
||||||
|
return SynthesizedPage(
|
||||||
|
title=title,
|
||||||
|
slug=slug,
|
||||||
|
topic_category="Testing",
|
||||||
|
summary="A test page.",
|
||||||
|
body_sections=body_sections,
|
||||||
|
moment_indices=moment_indices or [],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _count_words_v2(sections: list[BodySection]) -> int:
|
||||||
|
"""Replicate the word-counting logic from the updated test_harness."""
|
||||||
|
return sum(
|
||||||
|
len(s.content.split()) + sum(len(sub.content.split()) for sub in s.subsections)
|
||||||
|
for s in sections
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _count_words_metadata(pages_dicts: list[dict]) -> int:
|
||||||
|
"""Replicate the metadata total_words logic (operates on dicts after model_dump)."""
|
||||||
|
return sum(
|
||||||
|
sum(
|
||||||
|
len(s.get("content", "").split())
|
||||||
|
+ sum(len(sub.get("content", "").split()) for sub in s.get("subsections", []))
|
||||||
|
for s in p.get("body_sections", [])
|
||||||
|
)
|
||||||
|
for p in pages_dicts
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Word counting tests ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestWordCounting:
|
||||||
|
def test_flat_sections_no_subsections(self):
|
||||||
|
sections = [
|
||||||
|
BodySection(heading="Intro", content="one two three"),
|
||||||
|
BodySection(heading="Details", content="four five"),
|
||||||
|
]
|
||||||
|
assert _count_words_v2(sections) == 5
|
||||||
|
|
||||||
|
def test_sections_with_subsections(self):
|
||||||
|
sections = [
|
||||||
|
BodySection(
|
||||||
|
heading="Main",
|
||||||
|
content="alpha beta", # 2 words
|
||||||
|
subsections=[
|
||||||
|
BodySubSection(heading="Sub A", content="gamma delta epsilon"), # 3 words
|
||||||
|
BodySubSection(heading="Sub B", content="zeta"), # 1 word
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
assert _count_words_v2(sections) == 6
|
||||||
|
|
||||||
|
def test_empty_sections_list(self):
|
||||||
|
assert _count_words_v2([]) == 0
|
||||||
|
|
||||||
|
def test_section_with_empty_content(self):
|
||||||
|
sections = [
|
||||||
|
BodySection(heading="Empty", content=""),
|
||||||
|
]
|
||||||
|
# "".split() returns [], len([]) == 0
|
||||||
|
assert _count_words_v2(sections) == 0
|
||||||
|
|
||||||
|
def test_metadata_word_count_matches(self):
|
||||||
|
"""Metadata total_words (from model_dump dicts) matches Pydantic object counting."""
|
||||||
|
sections = [
|
||||||
|
BodySection(
|
||||||
|
heading="H2",
|
||||||
|
content="one two three",
|
||||||
|
subsections=[
|
||||||
|
BodySubSection(heading="H3", content="four five six seven"),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
BodySection(heading="Another", content="eight nine"),
|
||||||
|
]
|
||||||
|
page = _make_page(sections, moment_indices=[0, 1])
|
||||||
|
pages_dicts = [page.model_dump()]
|
||||||
|
|
||||||
|
assert _count_words_v2(sections) == 9
|
||||||
|
assert _count_words_metadata(pages_dicts) == 9
|
||||||
|
|
||||||
|
|
||||||
|
# ── Section/subsection counting ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestSectionCounting:
|
||||||
|
def test_section_and_subsection_counts(self):
|
||||||
|
sections = [
|
||||||
|
BodySection(heading="A", content="text", subsections=[
|
||||||
|
BodySubSection(heading="A.1", content="sub text"),
|
||||||
|
]),
|
||||||
|
BodySection(heading="B", content="more text"),
|
||||||
|
BodySection(heading="C", content="even more", subsections=[
|
||||||
|
BodySubSection(heading="C.1", content="sub1"),
|
||||||
|
BodySubSection(heading="C.2", content="sub2"),
|
||||||
|
]),
|
||||||
|
]
|
||||||
|
section_count = len(sections)
|
||||||
|
subsection_count = sum(len(s.subsections) for s in sections)
|
||||||
|
assert section_count == 3
|
||||||
|
assert subsection_count == 3
|
||||||
|
|
||||||
|
|
||||||
|
# ── Citation integration ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestCitationIntegration:
|
||||||
|
def test_full_coverage(self):
|
||||||
|
sections = [
|
||||||
|
BodySection(heading="Intro", content="First point [0]. Second point [1]."),
|
||||||
|
BodySection(heading="Details", content="More on [0] and [2]."),
|
||||||
|
]
|
||||||
|
result = validate_citations(sections, moment_count=3)
|
||||||
|
assert result["valid"] is True
|
||||||
|
assert result["coverage_pct"] == 100.0
|
||||||
|
assert result["invalid_indices"] == []
|
||||||
|
assert result["uncited_moments"] == []
|
||||||
|
|
||||||
|
def test_partial_coverage(self):
|
||||||
|
sections = [
|
||||||
|
BodySection(heading="Intro", content="Only cites [0]."),
|
||||||
|
]
|
||||||
|
result = validate_citations(sections, moment_count=3)
|
||||||
|
assert result["valid"] is False
|
||||||
|
assert result["coverage_pct"] == pytest.approx(33.3, abs=0.1)
|
||||||
|
assert result["uncited_moments"] == [1, 2]
|
||||||
|
|
||||||
|
def test_invalid_index(self):
|
||||||
|
sections = [
|
||||||
|
BodySection(heading="Bad", content="Cites [0] and [99]."),
|
||||||
|
]
|
||||||
|
result = validate_citations(sections, moment_count=2)
|
||||||
|
assert result["invalid_indices"] == [99]
|
||||||
|
|
||||||
|
def test_citations_in_subsections(self):
|
||||||
|
sections = [
|
||||||
|
BodySection(
|
||||||
|
heading="Main",
|
||||||
|
content="See [0].",
|
||||||
|
subsections=[
|
||||||
|
BodySubSection(heading="Sub", content="Also [1] and [2]."),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
result = validate_citations(sections, moment_count=3)
|
||||||
|
assert result["valid"] is True
|
||||||
|
assert result["total_citations"] == 3
|
||||||
|
|
||||||
|
def test_multi_citation_markers(self):
|
||||||
|
sections = [
|
||||||
|
BodySection(heading="X", content="Both sources agree [0,1]."),
|
||||||
|
]
|
||||||
|
result = validate_citations(sections, moment_count=2)
|
||||||
|
assert result["valid"] is True
|
||||||
|
assert result["total_citations"] == 2
|
||||||
|
|
||||||
|
def test_no_sections(self):
|
||||||
|
result = validate_citations([], moment_count=0)
|
||||||
|
assert result["valid"] is True
|
||||||
|
assert result["coverage_pct"] == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
# ── End-to-end: SynthesisResult with v2 body_sections ───────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestSynthesisResultV2:
|
||||||
|
def test_round_trip_model_dump(self):
|
||||||
|
"""SynthesisResult with v2 body_sections round-trips through model_dump/validate."""
|
||||||
|
sections = [
|
||||||
|
BodySection(
|
||||||
|
heading="Overview",
|
||||||
|
content="This technique [0] is fundamental.",
|
||||||
|
subsections=[
|
||||||
|
BodySubSection(heading="Key Concept", content="Detail [1]."),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
page = _make_page(sections, moment_indices=[0, 1])
|
||||||
|
result = SynthesisResult(pages=[page])
|
||||||
|
|
||||||
|
dumped = result.model_dump()
|
||||||
|
restored = SynthesisResult.model_validate(dumped)
|
||||||
|
|
||||||
|
assert len(restored.pages) == 1
|
||||||
|
restored_page = restored.pages[0]
|
||||||
|
assert len(restored_page.body_sections) == 1
|
||||||
|
assert restored_page.body_sections[0].heading == "Overview"
|
||||||
|
assert len(restored_page.body_sections[0].subsections) == 1
|
||||||
|
assert restored_page.body_sections_format == "v2"
|
||||||
521
backend/pipeline/test_highlight_scorer.py
Normal file
521
backend/pipeline/test_highlight_scorer.py
Normal file
|
|
@ -0,0 +1,521 @@
|
||||||
|
"""Tests for the highlight scoring engine.
|
||||||
|
|
||||||
|
Verifies heuristic scoring produces sensible orderings and handles
|
||||||
|
edge cases gracefully.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from backend.pipeline.highlight_scorer import (
|
||||||
|
_content_type_weight,
|
||||||
|
_duration_fitness,
|
||||||
|
_pause_density,
|
||||||
|
_plugin_richness,
|
||||||
|
_source_quality_weight,
|
||||||
|
_speaking_pace_fitness,
|
||||||
|
_specificity_density,
|
||||||
|
_speech_rate_variance,
|
||||||
|
_transcript_energy,
|
||||||
|
_video_type_weight,
|
||||||
|
extract_word_timings,
|
||||||
|
score_moment,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fixture helpers ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _ideal_moment() -> dict:
|
||||||
|
"""45s technique moment, 3 plugins, specific summary, structured source."""
|
||||||
|
return dict(
|
||||||
|
start_time=10.0,
|
||||||
|
end_time=55.0, # 45s duration
|
||||||
|
content_type="technique",
|
||||||
|
summary=(
|
||||||
|
"Set the compressor threshold to -18 dB with a 4:1 ratio, "
|
||||||
|
"then boost the high shelf at 12 kHz by 3.5 dB using FabFilter Pro-Q 3."
|
||||||
|
),
|
||||||
|
plugins=["FabFilter Pro-Q 3", "SSL G-Bus Compressor", "Valhalla Room"],
|
||||||
|
raw_transcript=(
|
||||||
|
"The trick is to set the threshold low enough. Notice how "
|
||||||
|
"the compressor grabs the transients. Because we want to preserve "
|
||||||
|
"the dynamics, I always back off the ratio. The key is finding "
|
||||||
|
"that sweet spot where it's controlling but not squashing."
|
||||||
|
),
|
||||||
|
source_quality="structured",
|
||||||
|
video_content_type="tutorial",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _mediocre_moment() -> dict:
|
||||||
|
"""90s settings moment, 1 plugin, decent summary, mixed source."""
|
||||||
|
return dict(
|
||||||
|
start_time=120.0,
|
||||||
|
end_time=210.0, # 90s duration
|
||||||
|
content_type="settings",
|
||||||
|
summary="Adjust the EQ settings for the vocal track to get a clearer sound.",
|
||||||
|
plugins=["FabFilter Pro-Q 3"],
|
||||||
|
raw_transcript=(
|
||||||
|
"So here we're just going to adjust this. I think it sounds "
|
||||||
|
"better when we cut some of the low end. Let me show you what "
|
||||||
|
"I mean. Yeah, that's better."
|
||||||
|
),
|
||||||
|
source_quality="mixed",
|
||||||
|
video_content_type="breakdown",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _poor_moment() -> dict:
|
||||||
|
"""300s reasoning moment, 0 plugins, vague summary, unstructured source."""
|
||||||
|
return dict(
|
||||||
|
start_time=0.0,
|
||||||
|
end_time=300.0, # 300s duration → zero for duration_fitness
|
||||||
|
content_type="reasoning",
|
||||||
|
summary="General discussion about mixing philosophy and approach.",
|
||||||
|
plugins=[],
|
||||||
|
raw_transcript=(
|
||||||
|
"I think mixing is really about taste. Everyone has their own "
|
||||||
|
"approach. Some people like it loud, some people like it quiet. "
|
||||||
|
"There's no right or wrong way to do it really."
|
||||||
|
),
|
||||||
|
source_quality="unstructured",
|
||||||
|
video_content_type="livestream",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_word_timings(
|
||||||
|
start: float = 0.0,
|
||||||
|
count: int = 40,
|
||||||
|
wps: float = 4.0,
|
||||||
|
pause_every: int | None = None,
|
||||||
|
pause_duration: float = 0.8,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Generate synthetic word-timing dicts for testing.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
start : float
|
||||||
|
Start time in seconds.
|
||||||
|
count : int
|
||||||
|
Number of words to generate.
|
||||||
|
wps : float
|
||||||
|
Words per second (base rate).
|
||||||
|
pause_every : int | None
|
||||||
|
Insert a pause every N words. None = no pauses.
|
||||||
|
pause_duration : float
|
||||||
|
Duration of each pause in seconds.
|
||||||
|
"""
|
||||||
|
timings = []
|
||||||
|
t = start
|
||||||
|
word_dur = 1.0 / wps * 0.7 # 70% speaking, 30% normal gap
|
||||||
|
gap = 1.0 / wps * 0.3
|
||||||
|
|
||||||
|
for i in range(count):
|
||||||
|
timings.append({"word": f"word{i}", "start": t, "end": t + word_dur})
|
||||||
|
t += word_dur + gap
|
||||||
|
if pause_every and (i + 1) % pause_every == 0:
|
||||||
|
t += pause_duration
|
||||||
|
return timings
|
||||||
|
|
||||||
|
|
||||||
|
def _make_transcript_segments(word_timings: list[dict], words_per_segment: int = 10) -> list[dict]:
|
||||||
|
"""Group word timings into transcript segments for extract_word_timings tests."""
|
||||||
|
segments = []
|
||||||
|
for i in range(0, len(word_timings), words_per_segment):
|
||||||
|
chunk = word_timings[i : i + words_per_segment]
|
||||||
|
segments.append({"words": chunk})
|
||||||
|
return segments
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestScoreMoment:
|
||||||
|
def test_ideal_moment_scores_high(self):
|
||||||
|
result = score_moment(**_ideal_moment())
|
||||||
|
assert result["score"] > 0.7, f"Ideal moment scored {result['score']}, expected > 0.7"
|
||||||
|
|
||||||
|
def test_poor_moment_scores_low(self):
|
||||||
|
result = score_moment(**_poor_moment())
|
||||||
|
assert result["score"] < 0.4, f"Poor moment scored {result['score']}, expected < 0.4"
|
||||||
|
|
||||||
|
def test_ordering_is_sensible(self):
|
||||||
|
ideal = score_moment(**_ideal_moment())
|
||||||
|
mediocre = score_moment(**_mediocre_moment())
|
||||||
|
poor = score_moment(**_poor_moment())
|
||||||
|
|
||||||
|
assert ideal["score"] > mediocre["score"] > poor["score"], (
|
||||||
|
f"Expected ideal ({ideal['score']:.3f}) > "
|
||||||
|
f"mediocre ({mediocre['score']:.3f}) > "
|
||||||
|
f"poor ({poor['score']:.3f})"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_score_bounds(self):
|
||||||
|
"""All scores in [0.0, 1.0] for edge cases."""
|
||||||
|
edge_cases = [
|
||||||
|
dict(start_time=0, end_time=0, summary="", plugins=None, raw_transcript=None),
|
||||||
|
dict(start_time=0, end_time=500, summary=None, plugins=[], raw_transcript=""),
|
||||||
|
dict(start_time=0, end_time=45, summary="x" * 10000, plugins=["a"] * 100),
|
||||||
|
dict(start_time=100, end_time=100), # zero duration
|
||||||
|
]
|
||||||
|
for kwargs in edge_cases:
|
||||||
|
result = score_moment(**kwargs)
|
||||||
|
assert 0.0 <= result["score"] <= 1.0, f"Score {result['score']} out of bounds for {kwargs}"
|
||||||
|
for dim, val in result["score_breakdown"].items():
|
||||||
|
assert 0.0 <= val <= 1.0, f"{dim}={val} out of bounds for {kwargs}"
|
||||||
|
|
||||||
|
def test_missing_optional_fields(self):
|
||||||
|
"""None raw_transcript and None plugins don't crash."""
|
||||||
|
result = score_moment(
|
||||||
|
start_time=10.0,
|
||||||
|
end_time=55.0,
|
||||||
|
content_type="technique",
|
||||||
|
summary="A summary.",
|
||||||
|
plugins=None,
|
||||||
|
raw_transcript=None,
|
||||||
|
source_quality=None,
|
||||||
|
video_content_type=None,
|
||||||
|
)
|
||||||
|
assert 0.0 <= result["score"] <= 1.0
|
||||||
|
assert result["duration_secs"] == 45.0
|
||||||
|
assert len(result["score_breakdown"]) == 10
|
||||||
|
|
||||||
|
def test_returns_duration_secs(self):
|
||||||
|
result = score_moment(start_time=10.0, end_time=55.0)
|
||||||
|
assert result["duration_secs"] == 45.0
|
||||||
|
|
||||||
|
def test_breakdown_has_ten_dimensions(self):
|
||||||
|
result = score_moment(**_ideal_moment())
|
||||||
|
assert len(result["score_breakdown"]) == 10
|
||||||
|
expected_keys = {
|
||||||
|
"duration_score", "content_density_score", "technique_relevance_score",
|
||||||
|
"plugin_diversity_score", "engagement_proxy_score", "position_score",
|
||||||
|
"uniqueness_score", "speech_rate_variance_score", "pause_density_score",
|
||||||
|
"speaking_pace_score",
|
||||||
|
}
|
||||||
|
assert set(result["score_breakdown"].keys()) == expected_keys
|
||||||
|
|
||||||
|
def test_without_word_timings_audio_dims_are_neutral(self):
|
||||||
|
"""When word_timings is None, audio proxy dimensions score 0.5."""
|
||||||
|
result = score_moment(start_time=10.0, end_time=55.0)
|
||||||
|
bd = result["score_breakdown"]
|
||||||
|
assert bd["speech_rate_variance_score"] == 0.5
|
||||||
|
assert bd["pause_density_score"] == 0.5
|
||||||
|
assert bd["speaking_pace_score"] == 0.5
|
||||||
|
|
||||||
|
def test_with_word_timings_changes_score(self):
|
||||||
|
"""Providing word_timings should shift the composite score vs without."""
|
||||||
|
base = _ideal_moment()
|
||||||
|
without = score_moment(**base)
|
||||||
|
# Add word timings at a good teaching pace (~4 WPS) with some pauses
|
||||||
|
timings = _make_word_timings(start=10.0, count=120, wps=4.0, pause_every=15)
|
||||||
|
with_timings = score_moment(**base, word_timings=timings)
|
||||||
|
# Scores should differ since audio dims are no longer neutral
|
||||||
|
assert with_timings["score"] != without["score"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestDurationFitness:
|
||||||
|
def test_bell_curve_peak(self):
|
||||||
|
"""45s scores higher than 10s, 10s scores higher than 400s."""
|
||||||
|
assert _duration_fitness(45) > _duration_fitness(10)
|
||||||
|
assert _duration_fitness(10) > _duration_fitness(400)
|
||||||
|
|
||||||
|
def test_sweet_spot(self):
|
||||||
|
assert _duration_fitness(30) == 1.0
|
||||||
|
assert _duration_fitness(45) == 1.0
|
||||||
|
assert _duration_fitness(60) == 1.0
|
||||||
|
|
||||||
|
def test_zero_at_extremes(self):
|
||||||
|
assert _duration_fitness(0) == 0.0
|
||||||
|
assert _duration_fitness(300) == 0.0
|
||||||
|
assert _duration_fitness(500) == 0.0
|
||||||
|
|
||||||
|
def test_negative_duration(self):
|
||||||
|
assert _duration_fitness(-10) == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestContentTypeWeight:
|
||||||
|
def test_technique_highest(self):
|
||||||
|
assert _content_type_weight("technique") == 1.0
|
||||||
|
|
||||||
|
def test_reasoning_lowest_known(self):
|
||||||
|
assert _content_type_weight("reasoning") == 0.4
|
||||||
|
|
||||||
|
def test_unknown_gets_default(self):
|
||||||
|
assert _content_type_weight("unknown") == 0.5
|
||||||
|
assert _content_type_weight(None) == 0.5
|
||||||
|
|
||||||
|
|
||||||
|
class TestSpecificityDensity:
|
||||||
|
def test_specific_summary_scores_high(self):
|
||||||
|
summary = "Set threshold to -18 dB with 4:1 ratio, boost 12 kHz by 3.5 dB"
|
||||||
|
score = _specificity_density(summary)
|
||||||
|
assert score > 0.5
|
||||||
|
|
||||||
|
def test_vague_summary_scores_low(self):
|
||||||
|
score = _specificity_density("General discussion about mixing philosophy.")
|
||||||
|
assert score < 0.3
|
||||||
|
|
||||||
|
def test_empty_returns_zero(self):
|
||||||
|
assert _specificity_density("") == 0.0
|
||||||
|
assert _specificity_density(None) == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestPluginRichness:
|
||||||
|
def test_three_plugins_maxes_out(self):
|
||||||
|
assert _plugin_richness(["a", "b", "c"]) == 1.0
|
||||||
|
|
||||||
|
def test_more_than_three_capped(self):
|
||||||
|
assert _plugin_richness(["a", "b", "c", "d"]) == 1.0
|
||||||
|
|
||||||
|
def test_empty(self):
|
||||||
|
assert _plugin_richness([]) == 0.0
|
||||||
|
assert _plugin_richness(None) == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestTranscriptEnergy:
|
||||||
|
def test_teaching_phrases_score_high(self):
|
||||||
|
transcript = (
|
||||||
|
"The trick is to notice how the compressor behaves. "
|
||||||
|
"Because we want dynamics, I always set it gently. The key is balance."
|
||||||
|
)
|
||||||
|
score = _transcript_energy(transcript)
|
||||||
|
assert score > 0.5
|
||||||
|
|
||||||
|
def test_bland_transcript_scores_low(self):
|
||||||
|
transcript = "And then we adjust this slider here. Okay that sounds fine."
|
||||||
|
score = _transcript_energy(transcript)
|
||||||
|
assert score < 0.3
|
||||||
|
|
||||||
|
def test_empty(self):
|
||||||
|
assert _transcript_energy("") == 0.0
|
||||||
|
assert _transcript_energy(None) == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestSourceQualityWeight:
|
||||||
|
def test_structured_highest(self):
|
||||||
|
assert _source_quality_weight("structured") == 1.0
|
||||||
|
|
||||||
|
def test_none_default(self):
|
||||||
|
assert _source_quality_weight(None) == 0.5
|
||||||
|
|
||||||
|
|
||||||
|
class TestVideoTypeWeight:
|
||||||
|
def test_tutorial_highest(self):
|
||||||
|
assert _video_type_weight("tutorial") == 1.0
|
||||||
|
|
||||||
|
def test_short_form_lowest(self):
|
||||||
|
assert _video_type_weight("short_form") == 0.3
|
||||||
|
|
||||||
|
def test_none_default(self):
|
||||||
|
assert _video_type_weight(None) == 0.5
|
||||||
|
|
||||||
|
|
||||||
|
# ── Audio proxy function tests ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractWordTimings:
|
||||||
|
def test_filters_by_time_window(self):
|
||||||
|
words = _make_word_timings(start=0.0, count=40, wps=4.0)
|
||||||
|
segments = _make_transcript_segments(words)
|
||||||
|
# Extract window 2.0–5.0s
|
||||||
|
result = extract_word_timings(segments, start_time=2.0, end_time=5.0)
|
||||||
|
for w in result:
|
||||||
|
assert 2.0 <= w["start"] <= 5.0
|
||||||
|
|
||||||
|
def test_returns_all_when_window_covers_entire_range(self):
|
||||||
|
words = _make_word_timings(start=0.0, count=20, wps=4.0)
|
||||||
|
segments = _make_transcript_segments(words)
|
||||||
|
result = extract_word_timings(segments, start_time=0.0, end_time=100.0)
|
||||||
|
assert len(result) == 20
|
||||||
|
|
||||||
|
def test_empty_transcript_data(self):
|
||||||
|
assert extract_word_timings([], start_time=0.0, end_time=10.0) == []
|
||||||
|
|
||||||
|
def test_no_words_in_window(self):
|
||||||
|
words = _make_word_timings(start=0.0, count=10, wps=4.0)
|
||||||
|
segments = _make_transcript_segments(words)
|
||||||
|
# Window far beyond the word timings
|
||||||
|
result = extract_word_timings(segments, start_time=100.0, end_time=200.0)
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
def test_segments_without_words_key(self):
|
||||||
|
"""Segments missing 'words' are skipped gracefully."""
|
||||||
|
segments = [{"text": "hello"}, {"words": [{"start": 1.0, "end": 1.2, "word": "a"}]}]
|
||||||
|
result = extract_word_timings(segments, start_time=0.0, end_time=10.0)
|
||||||
|
assert len(result) == 1
|
||||||
|
|
||||||
|
def test_words_without_start_are_skipped(self):
|
||||||
|
segments = [{"words": [{"end": 1.2, "word": "a"}, {"start": 2.0, "end": 2.2, "word": "b"}]}]
|
||||||
|
result = extract_word_timings(segments, start_time=0.0, end_time=10.0)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["word"] == "b"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSpeechRateVariance:
|
||||||
|
def test_none_returns_neutral(self):
|
||||||
|
assert _speech_rate_variance(None) == 0.5
|
||||||
|
|
||||||
|
def test_too_few_words_returns_neutral(self):
|
||||||
|
timings = _make_word_timings(count=3, wps=4.0)
|
||||||
|
assert _speech_rate_variance(timings) == 0.5
|
||||||
|
|
||||||
|
def test_short_span_returns_neutral(self):
|
||||||
|
"""Words spanning <5s should return neutral."""
|
||||||
|
timings = _make_word_timings(count=10, wps=4.0, start=0.0)
|
||||||
|
# 10 words at 4 WPS = 2.5s span → too short
|
||||||
|
assert _speech_rate_variance(timings) == 0.5
|
||||||
|
|
||||||
|
def test_uniform_pace_scores_low(self):
|
||||||
|
"""Steady 4 WPS for 30s → low variance."""
|
||||||
|
timings = _make_word_timings(start=0.0, count=120, wps=4.0)
|
||||||
|
score = _speech_rate_variance(timings)
|
||||||
|
assert score < 0.4, f"Uniform pace scored {score}, expected < 0.4"
|
||||||
|
|
||||||
|
def test_varied_pace_scores_higher(self):
|
||||||
|
"""Alternating fast/slow sections → higher variance."""
|
||||||
|
timings = []
|
||||||
|
t = 0.0
|
||||||
|
# Fast section: 6 WPS for 10s
|
||||||
|
for i in range(60):
|
||||||
|
dur = 0.12
|
||||||
|
timings.append({"word": f"w{i}", "start": t, "end": t + dur})
|
||||||
|
t += 1.0 / 6.0
|
||||||
|
# Slow section: 2 WPS for 10s
|
||||||
|
for i in range(20):
|
||||||
|
dur = 0.3
|
||||||
|
timings.append({"word": f"w{60+i}", "start": t, "end": t + dur})
|
||||||
|
t += 0.5
|
||||||
|
score = _speech_rate_variance(timings)
|
||||||
|
uniform_score = _speech_rate_variance(
|
||||||
|
_make_word_timings(start=0.0, count=80, wps=4.0)
|
||||||
|
)
|
||||||
|
assert score > uniform_score, (
|
||||||
|
f"Varied pace ({score:.3f}) should be > uniform ({uniform_score:.3f})"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_score_bounded(self):
|
||||||
|
timings = _make_word_timings(start=0.0, count=200, wps=4.0)
|
||||||
|
score = _speech_rate_variance(timings)
|
||||||
|
assert 0.0 <= score <= 1.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestPauseDensity:
|
||||||
|
def test_none_returns_neutral(self):
|
||||||
|
assert _pause_density(None) == 0.5
|
||||||
|
|
||||||
|
def test_single_word_returns_neutral(self):
|
||||||
|
assert _pause_density([{"start": 0.0, "end": 0.2}]) == 0.5
|
||||||
|
|
||||||
|
def test_no_pauses_scores_zero(self):
|
||||||
|
"""Continuous speech with no gaps >0.5s → 0."""
|
||||||
|
timings = _make_word_timings(start=0.0, count=60, wps=4.0)
|
||||||
|
score = _pause_density(timings)
|
||||||
|
assert score == 0.0
|
||||||
|
|
||||||
|
def test_frequent_pauses_scores_high(self):
|
||||||
|
"""Pauses every 5 words → high density."""
|
||||||
|
timings = _make_word_timings(start=0.0, count=60, wps=4.0, pause_every=5, pause_duration=0.8)
|
||||||
|
score = _pause_density(timings)
|
||||||
|
assert score > 0.5, f"Frequent pauses scored {score}, expected > 0.5"
|
||||||
|
|
||||||
|
def test_long_pauses_weighted_more(self):
|
||||||
|
"""One 1.5s pause should score higher than one 0.6s pause in a longer segment."""
|
||||||
|
# Build timings with one long pause at midpoint — 60 words for longer duration
|
||||||
|
long_pause = []
|
||||||
|
t = 0.0
|
||||||
|
for i in range(60):
|
||||||
|
long_pause.append({"word": f"w{i}", "start": t, "end": t + 0.15})
|
||||||
|
t += 0.25
|
||||||
|
if i == 29:
|
||||||
|
t += 1.5 # long pause >1.0s
|
||||||
|
# Build timings with one short pause — same word count
|
||||||
|
short_pause = []
|
||||||
|
t = 0.0
|
||||||
|
for i in range(60):
|
||||||
|
short_pause.append({"word": f"w{i}", "start": t, "end": t + 0.15})
|
||||||
|
t += 0.25
|
||||||
|
if i == 29:
|
||||||
|
t += 0.6 # short pause >0.5s but <1.0s
|
||||||
|
assert _pause_density(long_pause) > _pause_density(short_pause)
|
||||||
|
|
||||||
|
def test_score_bounded(self):
|
||||||
|
timings = _make_word_timings(start=0.0, count=60, wps=4.0, pause_every=3, pause_duration=1.5)
|
||||||
|
score = _pause_density(timings)
|
||||||
|
assert 0.0 <= score <= 1.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestSpeakingPaceFitness:
|
||||||
|
def test_none_returns_neutral(self):
|
||||||
|
assert _speaking_pace_fitness(None) == 0.5
|
||||||
|
|
||||||
|
def test_single_word_returns_neutral(self):
|
||||||
|
assert _speaking_pace_fitness([{"start": 0.0, "end": 0.2}]) == 0.5
|
||||||
|
|
||||||
|
def test_optimal_pace_scores_high(self):
|
||||||
|
"""4 WPS (optimal teaching pace) → 1.0."""
|
||||||
|
timings = _make_word_timings(start=0.0, count=40, wps=4.0)
|
||||||
|
score = _speaking_pace_fitness(timings)
|
||||||
|
assert score == 1.0, f"4 WPS scored {score}, expected 1.0"
|
||||||
|
|
||||||
|
def test_three_wps_is_sweet_spot_edge(self):
|
||||||
|
timings = _make_word_timings(start=0.0, count=30, wps=3.0)
|
||||||
|
score = _speaking_pace_fitness(timings)
|
||||||
|
assert score == 1.0
|
||||||
|
|
||||||
|
def test_five_wps_is_sweet_spot_edge(self):
|
||||||
|
timings = _make_word_timings(start=0.0, count=50, wps=5.0)
|
||||||
|
score = _speaking_pace_fitness(timings)
|
||||||
|
assert score > 0.95, f"5 WPS scored {score}, expected near 1.0"
|
||||||
|
|
||||||
|
def test_too_slow_scores_lower(self):
|
||||||
|
"""1.5 WPS → below sweet spot."""
|
||||||
|
timings = _make_word_timings(start=0.0, count=15, wps=1.5)
|
||||||
|
score = _speaking_pace_fitness(timings)
|
||||||
|
assert 0.4 < score < 0.6, f"1.5 WPS scored {score}, expected ~0.5"
|
||||||
|
|
||||||
|
def test_too_fast_scores_lower(self):
|
||||||
|
"""8 WPS → above sweet spot."""
|
||||||
|
timings = _make_word_timings(start=0.0, count=80, wps=8.0)
|
||||||
|
score = _speaking_pace_fitness(timings)
|
||||||
|
assert 0.0 < score < 1.0
|
||||||
|
|
||||||
|
def test_very_fast_scores_zero(self):
|
||||||
|
"""10+ WPS → 0."""
|
||||||
|
timings = _make_word_timings(start=0.0, count=110, wps=11.0)
|
||||||
|
score = _speaking_pace_fitness(timings)
|
||||||
|
assert score == 0.0
|
||||||
|
|
||||||
|
def test_zero_wps_scores_zero(self):
|
||||||
|
"""Very short duration → neutral."""
|
||||||
|
timings = [{"start": 0.0, "end": 0.01}, {"start": 0.005, "end": 0.015}]
|
||||||
|
score = _speaking_pace_fitness(timings)
|
||||||
|
# Duration ~0.015s → too short → 0.5 (neutral)
|
||||||
|
assert score == 0.5
|
||||||
|
|
||||||
|
def test_score_bounded(self):
|
||||||
|
for wps in [0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 8.0, 10.0]:
|
||||||
|
timings = _make_word_timings(start=0.0, count=max(10, int(wps * 10)), wps=wps)
|
||||||
|
score = _speaking_pace_fitness(timings)
|
||||||
|
assert 0.0 <= score <= 1.0, f"WPS {wps} scored {score} out of bounds"
|
||||||
|
|
||||||
|
|
||||||
|
class TestBackwardCompatibility:
|
||||||
|
"""Ensure the weight rebalancing doesn't break existing relative orderings."""
|
||||||
|
|
||||||
|
def test_ideal_still_beats_poor(self):
|
||||||
|
ideal = score_moment(**_ideal_moment())
|
||||||
|
poor = score_moment(**_poor_moment())
|
||||||
|
assert ideal["score"] > poor["score"]
|
||||||
|
|
||||||
|
def test_ideal_still_above_threshold(self):
|
||||||
|
result = score_moment(**_ideal_moment())
|
||||||
|
assert result["score"] > 0.6, f"Ideal scored {result['score']}, expected > 0.6"
|
||||||
|
|
||||||
|
def test_poor_still_below_threshold(self):
|
||||||
|
result = score_moment(**_poor_moment())
|
||||||
|
assert result["score"] < 0.45, f"Poor scored {result['score']}, expected < 0.45"
|
||||||
|
|
||||||
|
def test_weights_sum_to_one(self):
|
||||||
|
from backend.pipeline.highlight_scorer import _WEIGHTS
|
||||||
|
assert abs(sum(_WEIGHTS.values()) - 1.0) < 1e-9
|
||||||
328
backend/pipeline/test_section_embedding.py
Normal file
328
backend/pipeline/test_section_embedding.py
Normal file
|
|
@ -0,0 +1,328 @@
|
||||||
|
"""Unit tests for per-section embedding in stage 6.
|
||||||
|
|
||||||
|
Tests _slugify_heading, section embed text construction, delete-before-upsert
|
||||||
|
ordering, v1 page skipping, upsert payload correctness, and deterministic UUIDs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import MagicMock, call, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# ── slugify tests ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
from pipeline.stages import _slugify_heading
|
||||||
|
|
||||||
|
|
||||||
|
class TestSlugifyHeading:
|
||||||
|
"""Verify _slugify_heading matches frontend TableOfContents.tsx slugify."""
|
||||||
|
|
||||||
|
def test_simple_heading(self):
|
||||||
|
assert _slugify_heading("Grain Position Control") == "grain-position-control"
|
||||||
|
|
||||||
|
def test_ampersand_and_special_chars(self):
|
||||||
|
# Consecutive non-alphanumeric chars collapse to a single hyphen
|
||||||
|
assert _slugify_heading("LFO Routing & Modulation") == "lfo-routing-modulation"
|
||||||
|
|
||||||
|
def test_leading_trailing_special(self):
|
||||||
|
assert _slugify_heading(" —Hello World! ") == "hello-world"
|
||||||
|
|
||||||
|
def test_numbers_preserved(self):
|
||||||
|
assert _slugify_heading("Step 1: Setup") == "step-1-setup"
|
||||||
|
|
||||||
|
def test_empty_string(self):
|
||||||
|
assert _slugify_heading("") == ""
|
||||||
|
|
||||||
|
def test_only_special_chars(self):
|
||||||
|
assert _slugify_heading("!@#$%") == ""
|
||||||
|
|
||||||
|
def test_unicode_stripped(self):
|
||||||
|
assert _slugify_heading("Café Sounds") == "caf-sounds"
|
||||||
|
|
||||||
|
def test_multiple_hyphens_collapse(self):
|
||||||
|
assert _slugify_heading("A -- B --- C") == "a-b-c"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Deterministic UUID tests ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_QDRANT_NAMESPACE = uuid.UUID("a1b2c3d4-e5f6-7890-abcd-ef1234567890")
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeterministicUUIDs:
|
||||||
|
"""Verify same page+section always produces the same point ID."""
|
||||||
|
|
||||||
|
def test_same_input_same_uuid(self):
|
||||||
|
id1 = str(uuid.uuid5(_QDRANT_NAMESPACE, "ts:page-abc:grain-position-control"))
|
||||||
|
id2 = str(uuid.uuid5(_QDRANT_NAMESPACE, "ts:page-abc:grain-position-control"))
|
||||||
|
assert id1 == id2
|
||||||
|
|
||||||
|
def test_different_section_different_uuid(self):
|
||||||
|
id1 = str(uuid.uuid5(_QDRANT_NAMESPACE, "ts:page-abc:section-a"))
|
||||||
|
id2 = str(uuid.uuid5(_QDRANT_NAMESPACE, "ts:page-abc:section-b"))
|
||||||
|
assert id1 != id2
|
||||||
|
|
||||||
|
|
||||||
|
# ── QdrantManager section methods ────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestQdrantManagerSections:
|
||||||
|
"""Test upsert_technique_sections and delete_sections_by_page_id."""
|
||||||
|
|
||||||
|
def _make_manager(self):
|
||||||
|
"""Create a QdrantManager with a mocked client."""
|
||||||
|
with patch("pipeline.qdrant_client.QdrantClient") as MockClient:
|
||||||
|
mock_client = MockClient.return_value
|
||||||
|
from pipeline.qdrant_client import QdrantManager
|
||||||
|
settings = MagicMock()
|
||||||
|
settings.qdrant_url = "http://localhost:6333"
|
||||||
|
settings.qdrant_collection = "test_collection"
|
||||||
|
settings.embedding_dimensions = 768
|
||||||
|
mgr = QdrantManager(settings)
|
||||||
|
mgr._client = mock_client
|
||||||
|
return mgr, mock_client
|
||||||
|
|
||||||
|
def test_upsert_builds_correct_payloads(self):
|
||||||
|
mgr, mock_client = self._make_manager()
|
||||||
|
sections = [
|
||||||
|
{
|
||||||
|
"page_id": "p1",
|
||||||
|
"creator_id": "c1",
|
||||||
|
"creator_name": "Keota",
|
||||||
|
"title": "Granular Synthesis",
|
||||||
|
"slug": "granular-synthesis",
|
||||||
|
"section_heading": "Grain Position Control",
|
||||||
|
"section_anchor": "grain-position-control",
|
||||||
|
"topic_category": "Sound Design",
|
||||||
|
"topic_tags": ["granular", "synthesis"],
|
||||||
|
"summary": "Control the grain position parameter.",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
vectors = [[0.1] * 768]
|
||||||
|
|
||||||
|
mgr.upsert_technique_sections(sections, vectors)
|
||||||
|
|
||||||
|
# Verify upsert was called
|
||||||
|
assert mock_client.upsert.called
|
||||||
|
points = mock_client.upsert.call_args[1]["points"]
|
||||||
|
assert len(points) == 1
|
||||||
|
|
||||||
|
payload = points[0].payload
|
||||||
|
assert payload["type"] == "technique_section"
|
||||||
|
assert payload["page_id"] == "p1"
|
||||||
|
assert payload["section_heading"] == "Grain Position Control"
|
||||||
|
assert payload["section_anchor"] == "grain-position-control"
|
||||||
|
assert payload["slug"] == "granular-synthesis"
|
||||||
|
|
||||||
|
# Verify deterministic UUID
|
||||||
|
expected_id = str(uuid.uuid5(_QDRANT_NAMESPACE, "ts:p1:grain-position-control"))
|
||||||
|
assert points[0].id == expected_id
|
||||||
|
|
||||||
|
def test_upsert_count_mismatch_skips(self):
|
||||||
|
mgr, mock_client = self._make_manager()
|
||||||
|
mgr.upsert_technique_sections([{"page_id": "p1"}], [[0.1], [0.2]])
|
||||||
|
assert not mock_client.upsert.called
|
||||||
|
|
||||||
|
def test_upsert_empty_list_skips(self):
|
||||||
|
mgr, mock_client = self._make_manager()
|
||||||
|
mgr.upsert_technique_sections([], [])
|
||||||
|
assert not mock_client.upsert.called
|
||||||
|
|
||||||
|
def test_summary_truncated_to_200_chars(self):
|
||||||
|
mgr, mock_client = self._make_manager()
|
||||||
|
long_summary = "x" * 500
|
||||||
|
sections = [{
|
||||||
|
"page_id": "p1", "section_heading": "H", "section_anchor": "h",
|
||||||
|
"summary": long_summary,
|
||||||
|
}]
|
||||||
|
vectors = [[0.1] * 768]
|
||||||
|
mgr.upsert_technique_sections(sections, vectors)
|
||||||
|
payload = mock_client.upsert.call_args[1]["points"][0].payload
|
||||||
|
assert len(payload["summary"]) == 200
|
||||||
|
|
||||||
|
def test_delete_sections_by_page_id(self):
|
||||||
|
mgr, mock_client = self._make_manager()
|
||||||
|
mgr.delete_sections_by_page_id("p1")
|
||||||
|
assert mock_client.delete.called
|
||||||
|
filter_arg = mock_client.delete.call_args[1]["points_selector"]
|
||||||
|
# Verify filter has both page_id and type conditions
|
||||||
|
must_conditions = filter_arg.must
|
||||||
|
assert len(must_conditions) == 2
|
||||||
|
keys = {c.key for c in must_conditions}
|
||||||
|
assert keys == {"page_id", "type"}
|
||||||
|
|
||||||
|
def test_delete_sections_logs_on_failure(self):
|
||||||
|
mgr, mock_client = self._make_manager()
|
||||||
|
mock_client.delete.side_effect = Exception("connection refused")
|
||||||
|
# Should not raise
|
||||||
|
mgr.delete_sections_by_page_id("p1")
|
||||||
|
|
||||||
|
|
||||||
|
# ── Stage 6 section embedding logic ─────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestStage6SectionEmbedding:
|
||||||
|
"""Test the section embedding block within stage6_embed_and_index.
|
||||||
|
|
||||||
|
Uses mocked DB, embedding client, and QdrantManager to verify:
|
||||||
|
- v2 pages produce section points
|
||||||
|
- v1 pages are skipped
|
||||||
|
- delete is called before upsert
|
||||||
|
- embed text includes creator/page/section context
|
||||||
|
- sections with empty headings are skipped
|
||||||
|
- subsection content is included in embed text
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _make_page(self, page_id="p1", creator_id="c1", format_="v2",
|
||||||
|
body_sections=None, title="Granular Synthesis",
|
||||||
|
slug="granular-synthesis"):
|
||||||
|
"""Create a mock TechniquePage-like object."""
|
||||||
|
page = MagicMock()
|
||||||
|
page.id = page_id
|
||||||
|
page.creator_id = creator_id
|
||||||
|
page.body_sections_format = format_
|
||||||
|
page.body_sections = body_sections
|
||||||
|
page.title = title
|
||||||
|
page.slug = slug
|
||||||
|
page.topic_category = "Sound Design"
|
||||||
|
page.topic_tags = ["granular"]
|
||||||
|
page.summary = "Page summary"
|
||||||
|
return page
|
||||||
|
|
||||||
|
def test_v1_page_produces_zero_sections(self):
|
||||||
|
"""Pages with body_sections_format != 'v2' should be skipped."""
|
||||||
|
page = self._make_page(format_="v1", body_sections=[
|
||||||
|
{"heading": "Section A", "content": "Content A"},
|
||||||
|
])
|
||||||
|
v2_pages = [p for p in [page] if getattr(p, "body_sections_format", "v1") == "v2"]
|
||||||
|
assert len(v2_pages) == 0
|
||||||
|
|
||||||
|
def test_v2_page_none_body_sections(self):
|
||||||
|
"""Page with body_sections=None → skipped (not a list)."""
|
||||||
|
page = self._make_page(format_="v2", body_sections=None)
|
||||||
|
v2_pages = [p for p in [page] if getattr(p, "body_sections_format", "v1") == "v2"]
|
||||||
|
assert len(v2_pages) == 1
|
||||||
|
# body_sections is None → not a list → skipped in the loop
|
||||||
|
assert not isinstance(page.body_sections, list)
|
||||||
|
|
||||||
|
def test_section_empty_heading_skipped(self):
|
||||||
|
"""Sections with empty heading should be skipped."""
|
||||||
|
page = self._make_page(body_sections=[
|
||||||
|
{"heading": "", "content": "Orphan content"},
|
||||||
|
{"heading": "Valid", "content": "Real content"},
|
||||||
|
])
|
||||||
|
sections_with_heading = [
|
||||||
|
s for s in page.body_sections
|
||||||
|
if isinstance(s, dict) and s.get("heading", "").strip()
|
||||||
|
]
|
||||||
|
assert len(sections_with_heading) == 1
|
||||||
|
assert sections_with_heading[0]["heading"] == "Valid"
|
||||||
|
|
||||||
|
def test_subsection_content_included_in_embed_text(self):
|
||||||
|
"""Section with subsections should include subsection content."""
|
||||||
|
section = {
|
||||||
|
"heading": "Grain Position Control",
|
||||||
|
"content": "Main content",
|
||||||
|
"subsections": [
|
||||||
|
{"heading": "Fine Tuning", "content": "Fine tune the position."},
|
||||||
|
{"heading": "Automation", "content": "Automate grain pos."},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Reproduce the embed text construction from stage 6
|
||||||
|
creator_name = "Keota"
|
||||||
|
page_title = "Granular Synthesis"
|
||||||
|
heading = section["heading"]
|
||||||
|
section_content = section.get("content", "")
|
||||||
|
subsection_parts = []
|
||||||
|
for sub in section.get("subsections", []):
|
||||||
|
if isinstance(sub, dict):
|
||||||
|
sub_heading = sub.get("heading", "")
|
||||||
|
sub_content = sub.get("content", "")
|
||||||
|
if sub_heading:
|
||||||
|
subsection_parts.append(f"{sub_heading}: {sub_content}")
|
||||||
|
elif sub_content:
|
||||||
|
subsection_parts.append(sub_content)
|
||||||
|
|
||||||
|
embed_text = (
|
||||||
|
f"{creator_name} {page_title} — {heading}: "
|
||||||
|
f"{section_content} {' '.join(subsection_parts)}"
|
||||||
|
).strip()
|
||||||
|
|
||||||
|
assert "Fine Tuning: Fine tune the position." in embed_text
|
||||||
|
assert "Automation: Automate grain pos." in embed_text
|
||||||
|
assert "Keota Granular Synthesis" in embed_text
|
||||||
|
|
||||||
|
def test_subsection_no_direct_content(self):
|
||||||
|
"""Section with subsections but no direct content still embeds subsection text."""
|
||||||
|
section = {
|
||||||
|
"heading": "Advanced Techniques",
|
||||||
|
"content": "",
|
||||||
|
"subsections": [
|
||||||
|
{"heading": "Sub A", "content": "Content A"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
heading = section["heading"]
|
||||||
|
section_content = section.get("content", "")
|
||||||
|
subsection_parts = []
|
||||||
|
for sub in section.get("subsections", []):
|
||||||
|
if isinstance(sub, dict):
|
||||||
|
sub_heading = sub.get("heading", "")
|
||||||
|
sub_content = sub.get("content", "")
|
||||||
|
if sub_heading:
|
||||||
|
subsection_parts.append(f"{sub_heading}: {sub_content}")
|
||||||
|
elif sub_content:
|
||||||
|
subsection_parts.append(sub_content)
|
||||||
|
|
||||||
|
embed_text = (
|
||||||
|
f"Creator Page — {heading}: "
|
||||||
|
f"{section_content} {' '.join(subsection_parts)}"
|
||||||
|
).strip()
|
||||||
|
|
||||||
|
assert "Sub A: Content A" in embed_text
|
||||||
|
|
||||||
|
def test_delete_called_before_upsert_ordering(self):
|
||||||
|
"""Verify delete_sections_by_page_id is called before upsert_technique_sections."""
|
||||||
|
call_order = []
|
||||||
|
mock_qdrant = MagicMock()
|
||||||
|
mock_qdrant.delete_sections_by_page_id.side_effect = lambda pid: call_order.append(("delete", pid))
|
||||||
|
mock_qdrant.upsert_technique_sections.side_effect = lambda s, v: call_order.append(("upsert", len(s)))
|
||||||
|
|
||||||
|
mock_embed = MagicMock()
|
||||||
|
mock_embed.embed.return_value = [[0.1] * 768] # One vector
|
||||||
|
|
||||||
|
page = self._make_page(body_sections=[
|
||||||
|
{"heading": "Section A", "content": "Content A"},
|
||||||
|
])
|
||||||
|
|
||||||
|
creator_map = {str(page.creator_id): "TestCreator"}
|
||||||
|
v2_pages = [page]
|
||||||
|
page_id_str = str(page.id)
|
||||||
|
|
||||||
|
# Simulate the section embedding block
|
||||||
|
for p in v2_pages:
|
||||||
|
body_sections = p.body_sections
|
||||||
|
if not isinstance(body_sections, list):
|
||||||
|
continue
|
||||||
|
creator_name = creator_map.get(str(p.creator_id), "")
|
||||||
|
mock_qdrant.delete_sections_by_page_id(str(p.id))
|
||||||
|
|
||||||
|
section_texts = []
|
||||||
|
section_dicts = []
|
||||||
|
for section in body_sections:
|
||||||
|
if not isinstance(section, dict):
|
||||||
|
continue
|
||||||
|
heading = section.get("heading", "")
|
||||||
|
if not heading or not heading.strip():
|
||||||
|
continue
|
||||||
|
section_anchor = _slugify_heading(heading)
|
||||||
|
section_texts.append(f"{creator_name} {p.title} — {heading}")
|
||||||
|
section_dicts.append({"page_id": str(p.id), "section_anchor": section_anchor})
|
||||||
|
|
||||||
|
if section_texts:
|
||||||
|
vectors = mock_embed.embed(section_texts)
|
||||||
|
if vectors:
|
||||||
|
mock_qdrant.upsert_technique_sections(section_dicts, vectors)
|
||||||
|
|
||||||
|
assert call_order[0][0] == "delete"
|
||||||
|
assert call_order[1][0] == "upsert"
|
||||||
3
backend/pytest.ini
Normal file
3
backend/pytest.ini
Normal file
|
|
@ -0,0 +1,3 @@
|
||||||
|
[pytest]
|
||||||
|
asyncio_mode = auto
|
||||||
|
testpaths = tests
|
||||||
116
backend/rate_limiter.py
Normal file
116
backend/rate_limiter.py
Normal file
|
|
@ -0,0 +1,116 @@
|
||||||
|
"""Redis sliding-window rate limiter using sorted sets.
|
||||||
|
|
||||||
|
Each rate limit key is a Redis sorted set where members are unique
|
||||||
|
request identifiers (timestamps with microseconds) and scores are
|
||||||
|
Unix timestamps. On each check, expired entries are pruned, the
|
||||||
|
current request is added, and the count determines whether the
|
||||||
|
request is allowed.
|
||||||
|
|
||||||
|
Fail-open: If Redis is unavailable, requests are allowed through
|
||||||
|
with a WARNING log.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import redis.asyncio as aioredis
|
||||||
|
|
||||||
|
logger = logging.getLogger("chrysopedia.rate_limiter")
|
||||||
|
|
||||||
|
_KEY_PREFIX = "chrysopedia:ratelimit"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RateLimitResult:
|
||||||
|
"""Result of a rate limit check."""
|
||||||
|
|
||||||
|
allowed: bool
|
||||||
|
remaining: int
|
||||||
|
retry_after: int # seconds until the window slides enough to allow a request; 0 if allowed
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimiter:
|
||||||
|
"""Sliding-window rate limiter backed by Redis sorted sets.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
limiter = RateLimiter(redis)
|
||||||
|
result = await limiter.check_rate_limit("user:abc123", limit=30, window_seconds=3600)
|
||||||
|
if not result.allowed:
|
||||||
|
return 429, result.retry_after
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, redis: aioredis.Redis) -> None:
|
||||||
|
self._redis = redis
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def key(scope: str, identifier: str) -> str:
|
||||||
|
"""Build a namespaced Redis key for a rate limit bucket."""
|
||||||
|
return f"{_KEY_PREFIX}:{scope}:{identifier}"
|
||||||
|
|
||||||
|
async def check_rate_limit(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
limit: int,
|
||||||
|
window_seconds: int = 3600,
|
||||||
|
) -> RateLimitResult:
|
||||||
|
"""Check whether a request is within the rate limit.
|
||||||
|
|
||||||
|
Uses a sorted set where:
|
||||||
|
- ZREMRANGEBYSCORE prunes entries older than the window
|
||||||
|
- ZCARD counts current entries
|
||||||
|
- ZADD adds the current request if under limit
|
||||||
|
|
||||||
|
Returns a RateLimitResult with allowed/remaining/retry_after.
|
||||||
|
On Redis errors, fails open (allowed=True).
|
||||||
|
"""
|
||||||
|
now = time.time()
|
||||||
|
window_start = now - window_seconds
|
||||||
|
|
||||||
|
try:
|
||||||
|
pipe = self._redis.pipeline(transaction=True)
|
||||||
|
# Remove expired entries
|
||||||
|
pipe.zremrangebyscore(key, "-inf", window_start)
|
||||||
|
# Count remaining entries
|
||||||
|
pipe.zcard(key)
|
||||||
|
results = await pipe.execute()
|
||||||
|
|
||||||
|
current_count: int = results[1]
|
||||||
|
|
||||||
|
if current_count >= limit:
|
||||||
|
# Over limit — calculate retry_after from oldest entry
|
||||||
|
oldest = await self._redis.zrange(key, 0, 0, withscores=True)
|
||||||
|
if oldest:
|
||||||
|
oldest_score = oldest[0][1]
|
||||||
|
retry_after = int(oldest_score + window_seconds - now) + 1
|
||||||
|
retry_after = max(retry_after, 1)
|
||||||
|
else:
|
||||||
|
retry_after = window_seconds
|
||||||
|
|
||||||
|
return RateLimitResult(
|
||||||
|
allowed=False,
|
||||||
|
remaining=0,
|
||||||
|
retry_after=retry_after,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Under limit — add this request
|
||||||
|
member = f"{now}:{id(key)}" # unique member per call
|
||||||
|
await self._redis.zadd(key, {member: now})
|
||||||
|
# Set TTL on the key so it auto-expires after the window
|
||||||
|
await self._redis.expire(key, window_seconds + 60)
|
||||||
|
|
||||||
|
remaining = limit - current_count - 1
|
||||||
|
return RateLimitResult(
|
||||||
|
allowed=True,
|
||||||
|
remaining=max(remaining, 0),
|
||||||
|
retry_after=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"rate_limit_redis_error key=%s — failing open", key, exc_info=True
|
||||||
|
)
|
||||||
|
return RateLimitResult(allowed=True, remaining=limit, retry_after=0)
|
||||||
15
backend/redis_client.py
Normal file
15
backend/redis_client.py
Normal file
|
|
@ -0,0 +1,15 @@
|
||||||
|
"""Async Redis client helper for Chrysopedia."""
|
||||||
|
|
||||||
|
import redis.asyncio as aioredis
|
||||||
|
|
||||||
|
from config import get_settings
|
||||||
|
|
||||||
|
|
||||||
|
async def get_redis() -> aioredis.Redis:
|
||||||
|
"""Return an async Redis client from the configured URL.
|
||||||
|
|
||||||
|
Callers should close the connection when done, or use it
|
||||||
|
as a short-lived client within a request handler.
|
||||||
|
"""
|
||||||
|
settings = get_settings()
|
||||||
|
return aioredis.from_url(settings.redis_url, decode_responses=True)
|
||||||
23
backend/requirements.txt
Normal file
23
backend/requirements.txt
Normal file
|
|
@ -0,0 +1,23 @@
|
||||||
|
fastapi>=0.115.0,<1.0
|
||||||
|
uvicorn[standard]>=0.32.0,<1.0
|
||||||
|
sqlalchemy[asyncio]>=2.0,<3.0
|
||||||
|
asyncpg>=0.30.0,<1.0
|
||||||
|
alembic>=1.14.0,<2.0
|
||||||
|
pydantic>=2.0,<3.0
|
||||||
|
pydantic-settings>=2.0,<3.0
|
||||||
|
celery[redis]>=5.4.0,<6.0
|
||||||
|
redis>=5.0,<6.0
|
||||||
|
python-dotenv>=1.0,<2.0
|
||||||
|
python-multipart>=0.0.9,<1.0
|
||||||
|
httpx>=0.27.0,<1.0
|
||||||
|
openai>=1.0,<2.0
|
||||||
|
qdrant-client>=1.9,<2.0
|
||||||
|
pyyaml>=6.0,<7.0
|
||||||
|
psycopg2-binary>=2.9,<3.0
|
||||||
|
watchdog>=4.0,<5.0
|
||||||
|
PyJWT>=2.8,<3.0
|
||||||
|
bcrypt>=4.0,<6.0
|
||||||
|
minio>=7.2,<8.0
|
||||||
|
# Test dependencies
|
||||||
|
pytest>=8.0,<10.0
|
||||||
|
pytest-asyncio>=0.24,<1.0
|
||||||
1
backend/routers/__init__.py
Normal file
1
backend/routers/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
"""Chrysopedia API routers package."""
|
||||||
417
backend/routers/admin.py
Normal file
417
backend/routers/admin.py
Normal file
|
|
@ -0,0 +1,417 @@
|
||||||
|
"""Admin router — user management, impersonation, and usage analytics."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Annotated
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import aliased
|
||||||
|
|
||||||
|
from auth import (
|
||||||
|
create_impersonation_token,
|
||||||
|
decode_access_token,
|
||||||
|
get_current_user,
|
||||||
|
require_role,
|
||||||
|
)
|
||||||
|
from database import get_session
|
||||||
|
from models import ChatUsageLog, ImpersonationLog, User, UserRole
|
||||||
|
|
||||||
|
logger = logging.getLogger("chrysopedia.admin")
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/admin", tags=["admin"])
|
||||||
|
|
||||||
|
_require_admin = require_role(UserRole.admin)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Schemas ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class UserListItem(BaseModel):
|
||||||
|
id: str
|
||||||
|
email: str
|
||||||
|
display_name: str
|
||||||
|
role: str
|
||||||
|
creator_id: str | None
|
||||||
|
is_active: bool
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
|
class ImpersonateResponse(BaseModel):
|
||||||
|
access_token: str
|
||||||
|
token_type: str = "bearer"
|
||||||
|
target_user: UserListItem
|
||||||
|
|
||||||
|
|
||||||
|
class StopImpersonateResponse(BaseModel):
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
class StartImpersonationRequest(BaseModel):
|
||||||
|
write_mode: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class ImpersonationLogItem(BaseModel):
|
||||||
|
id: str
|
||||||
|
admin_name: str
|
||||||
|
target_name: str
|
||||||
|
action: str
|
||||||
|
write_mode: bool
|
||||||
|
ip_address: str | None
|
||||||
|
created_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _client_ip(request: Request) -> str | None:
|
||||||
|
"""Best-effort client IP from X-Forwarded-For or direct connection."""
|
||||||
|
forwarded = request.headers.get("x-forwarded-for")
|
||||||
|
if forwarded:
|
||||||
|
return forwarded.split(",")[0].strip()
|
||||||
|
if request.client:
|
||||||
|
return request.client.host
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Endpoints ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/users", response_model=list[UserListItem])
|
||||||
|
async def list_users(
|
||||||
|
_admin: Annotated[User, Depends(_require_admin)],
|
||||||
|
session: Annotated[AsyncSession, Depends(get_session)],
|
||||||
|
):
|
||||||
|
"""List all users. Admin only."""
|
||||||
|
result = await session.execute(
|
||||||
|
select(User).order_by(User.display_name)
|
||||||
|
)
|
||||||
|
users = result.scalars().all()
|
||||||
|
return [
|
||||||
|
UserListItem(
|
||||||
|
id=str(u.id),
|
||||||
|
email=u.email,
|
||||||
|
display_name=u.display_name,
|
||||||
|
role=u.role.value,
|
||||||
|
creator_id=str(u.creator_id) if u.creator_id else None,
|
||||||
|
is_active=u.is_active,
|
||||||
|
)
|
||||||
|
for u in users
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/impersonate/{user_id}", response_model=ImpersonateResponse)
|
||||||
|
async def start_impersonation(
|
||||||
|
user_id: UUID,
|
||||||
|
request: Request,
|
||||||
|
admin: Annotated[User, Depends(_require_admin)],
|
||||||
|
session: Annotated[AsyncSession, Depends(get_session)],
|
||||||
|
body: StartImpersonationRequest | None = None,
|
||||||
|
):
|
||||||
|
"""Start impersonating a user. Admin only. Returns a scoped JWT."""
|
||||||
|
if body is None:
|
||||||
|
body = StartImpersonationRequest()
|
||||||
|
|
||||||
|
# Cannot impersonate yourself
|
||||||
|
if admin.id == user_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Cannot impersonate yourself",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load target user
|
||||||
|
result = await session.execute(select(User).where(User.id == user_id))
|
||||||
|
target = result.scalar_one_or_none()
|
||||||
|
if target is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Target user not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create impersonation token
|
||||||
|
token = create_impersonation_token(
|
||||||
|
admin_user_id=admin.id,
|
||||||
|
target_user_id=target.id,
|
||||||
|
target_role=target.role.value,
|
||||||
|
write_mode=body.write_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Audit log
|
||||||
|
session.add(ImpersonationLog(
|
||||||
|
admin_user_id=admin.id,
|
||||||
|
target_user_id=target.id,
|
||||||
|
action="start",
|
||||||
|
write_mode=body.write_mode,
|
||||||
|
ip_address=_client_ip(request),
|
||||||
|
))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Impersonation started: admin=%s target=%s write_mode=%s",
|
||||||
|
admin.id, target.id, body.write_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ImpersonateResponse(
|
||||||
|
access_token=token,
|
||||||
|
target_user=UserListItem(
|
||||||
|
id=str(target.id),
|
||||||
|
email=target.email,
|
||||||
|
display_name=target.display_name,
|
||||||
|
role=target.role.value,
|
||||||
|
creator_id=str(target.creator_id) if target.creator_id else None,
|
||||||
|
is_active=target.is_active,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/impersonate/stop", response_model=StopImpersonateResponse)
|
||||||
|
async def stop_impersonation(
|
||||||
|
request: Request,
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
session: Annotated[AsyncSession, Depends(get_session)],
|
||||||
|
):
|
||||||
|
"""Stop impersonation. Requires a valid impersonation token."""
|
||||||
|
admin_id = getattr(current_user, "_impersonating_admin_id", None)
|
||||||
|
if admin_id is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Not currently impersonating",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Audit log
|
||||||
|
session.add(ImpersonationLog(
|
||||||
|
admin_user_id=admin_id,
|
||||||
|
target_user_id=current_user.id,
|
||||||
|
action="stop",
|
||||||
|
ip_address=_client_ip(request),
|
||||||
|
))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Impersonation stopped: admin=%s target=%s",
|
||||||
|
admin_id, current_user.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return StopImpersonateResponse(message="Impersonation ended")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/impersonation-log", response_model=list[ImpersonationLogItem])
|
||||||
|
async def get_impersonation_log(
|
||||||
|
_admin: Annotated[User, Depends(_require_admin)],
|
||||||
|
session: Annotated[AsyncSession, Depends(get_session)],
|
||||||
|
page: int = Query(1, ge=1),
|
||||||
|
page_size: int = Query(50, ge=1, le=200),
|
||||||
|
):
|
||||||
|
"""Paginated impersonation audit log. Admin only."""
|
||||||
|
AdminUser = aliased(User, name="admin_user")
|
||||||
|
TargetUser = aliased(User, name="target_user")
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
select(ImpersonationLog, AdminUser.display_name, TargetUser.display_name)
|
||||||
|
.join(AdminUser, ImpersonationLog.admin_user_id == AdminUser.id)
|
||||||
|
.join(TargetUser, ImpersonationLog.target_user_id == TargetUser.id)
|
||||||
|
.order_by(ImpersonationLog.created_at.desc())
|
||||||
|
.offset((page - 1) * page_size)
|
||||||
|
.limit(page_size)
|
||||||
|
)
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
rows = result.all()
|
||||||
|
|
||||||
|
return [
|
||||||
|
ImpersonationLogItem(
|
||||||
|
id=str(log.id),
|
||||||
|
admin_name=admin_name,
|
||||||
|
target_name=target_name,
|
||||||
|
action=log.action,
|
||||||
|
write_mode=log.write_mode,
|
||||||
|
ip_address=log.ip_address,
|
||||||
|
created_at=log.created_at,
|
||||||
|
)
|
||||||
|
for log, admin_name, target_name in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/creators/{slug}/extract-profile")
|
||||||
|
async def extract_creator_profile(
|
||||||
|
slug: str,
|
||||||
|
_admin: Annotated[User, Depends(_require_admin)],
|
||||||
|
session: Annotated[AsyncSession, Depends(get_session)],
|
||||||
|
):
|
||||||
|
"""Queue personality profile extraction for a creator. Admin only."""
|
||||||
|
from models import Creator
|
||||||
|
|
||||||
|
result = await session.execute(
|
||||||
|
select(Creator).where(Creator.slug == slug)
|
||||||
|
)
|
||||||
|
creator = result.scalar_one_or_none()
|
||||||
|
if creator is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Creator not found: {slug}",
|
||||||
|
)
|
||||||
|
|
||||||
|
from pipeline.stages import extract_personality_profile
|
||||||
|
extract_personality_profile.delay(str(creator.id))
|
||||||
|
|
||||||
|
logger.info("Queued personality extraction for creator=%s (%s)", slug, creator.id)
|
||||||
|
return {"status": "queued", "creator_id": str(creator.id)}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Usage Analytics ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _PeriodStats(BaseModel):
|
||||||
|
request_count: int
|
||||||
|
total_tokens: int
|
||||||
|
prompt_tokens: int
|
||||||
|
completion_tokens: int
|
||||||
|
|
||||||
|
|
||||||
|
class _CreatorUsage(BaseModel):
|
||||||
|
creator_slug: str
|
||||||
|
request_count: int
|
||||||
|
total_tokens: int
|
||||||
|
|
||||||
|
|
||||||
|
class _UserUsage(BaseModel):
|
||||||
|
identifier: str # display_name or IP
|
||||||
|
request_count: int
|
||||||
|
total_tokens: int
|
||||||
|
|
||||||
|
|
||||||
|
class _DailyCount(BaseModel):
|
||||||
|
date: str # ISO date YYYY-MM-DD
|
||||||
|
request_count: int
|
||||||
|
|
||||||
|
|
||||||
|
class UsageStatsResponse(BaseModel):
|
||||||
|
today: _PeriodStats
|
||||||
|
week: _PeriodStats
|
||||||
|
month: _PeriodStats
|
||||||
|
top_creators: list[_CreatorUsage]
|
||||||
|
top_users: list[_UserUsage]
|
||||||
|
daily_counts: list[_DailyCount]
|
||||||
|
|
||||||
|
|
||||||
|
async def _period_stats(
|
||||||
|
session: AsyncSession, since: datetime,
|
||||||
|
) -> _PeriodStats:
|
||||||
|
"""Aggregate token stats for chat usage since a given timestamp."""
|
||||||
|
stmt = select(
|
||||||
|
func.count().label("cnt"),
|
||||||
|
func.coalesce(func.sum(ChatUsageLog.total_tokens), 0).label("total"),
|
||||||
|
func.coalesce(func.sum(ChatUsageLog.prompt_tokens), 0).label("prompt"),
|
||||||
|
func.coalesce(func.sum(ChatUsageLog.completion_tokens), 0).label("completion"),
|
||||||
|
).where(ChatUsageLog.created_at >= since)
|
||||||
|
row = (await session.execute(stmt)).one()
|
||||||
|
return _PeriodStats(
|
||||||
|
request_count=row.cnt,
|
||||||
|
total_tokens=row.total,
|
||||||
|
prompt_tokens=row.prompt,
|
||||||
|
completion_tokens=row.completion,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/usage", response_model=UsageStatsResponse)
|
||||||
|
async def get_usage_stats(
|
||||||
|
_admin: Annotated[User, Depends(_require_admin)],
|
||||||
|
session: Annotated[AsyncSession, Depends(get_session)],
|
||||||
|
):
|
||||||
|
"""Aggregated chat usage statistics. Admin only."""
|
||||||
|
now = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||||
|
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
|
week_start = today_start - timedelta(days=today_start.weekday()) # Monday
|
||||||
|
month_start = today_start.replace(day=1)
|
||||||
|
|
||||||
|
today = await _period_stats(session, today_start)
|
||||||
|
week = await _period_stats(session, week_start)
|
||||||
|
month = await _period_stats(session, month_start)
|
||||||
|
|
||||||
|
# Top 10 creators by total tokens (this month)
|
||||||
|
creator_stmt = (
|
||||||
|
select(
|
||||||
|
ChatUsageLog.creator_slug,
|
||||||
|
func.count().label("cnt"),
|
||||||
|
func.coalesce(func.sum(ChatUsageLog.total_tokens), 0).label("total"),
|
||||||
|
)
|
||||||
|
.where(
|
||||||
|
ChatUsageLog.created_at >= month_start,
|
||||||
|
ChatUsageLog.creator_slug.isnot(None),
|
||||||
|
)
|
||||||
|
.group_by(ChatUsageLog.creator_slug)
|
||||||
|
.order_by(func.sum(ChatUsageLog.total_tokens).desc())
|
||||||
|
.limit(10)
|
||||||
|
)
|
||||||
|
creator_rows = (await session.execute(creator_stmt)).all()
|
||||||
|
top_creators = [
|
||||||
|
_CreatorUsage(creator_slug=r.creator_slug, request_count=r.cnt, total_tokens=r.total)
|
||||||
|
for r in creator_rows
|
||||||
|
]
|
||||||
|
|
||||||
|
# Top 10 users by request count (this month)
|
||||||
|
# Join with users table to get display_name; fall back to IP for anonymous
|
||||||
|
user_stmt = (
|
||||||
|
select(
|
||||||
|
ChatUsageLog.user_id,
|
||||||
|
ChatUsageLog.client_ip,
|
||||||
|
func.count().label("cnt"),
|
||||||
|
func.coalesce(func.sum(ChatUsageLog.total_tokens), 0).label("total"),
|
||||||
|
)
|
||||||
|
.where(ChatUsageLog.created_at >= month_start)
|
||||||
|
.group_by(ChatUsageLog.user_id, ChatUsageLog.client_ip)
|
||||||
|
.order_by(func.count().desc())
|
||||||
|
.limit(10)
|
||||||
|
)
|
||||||
|
user_rows = (await session.execute(user_stmt)).all()
|
||||||
|
|
||||||
|
# Resolve user display names
|
||||||
|
user_ids = [r.user_id for r in user_rows if r.user_id is not None]
|
||||||
|
name_map: dict[str, str] = {}
|
||||||
|
if user_ids:
|
||||||
|
name_result = await session.execute(
|
||||||
|
select(User.id, User.display_name).where(User.id.in_(user_ids))
|
||||||
|
)
|
||||||
|
for uid, name in name_result.all():
|
||||||
|
name_map[str(uid)] = name
|
||||||
|
|
||||||
|
top_users = [
|
||||||
|
_UserUsage(
|
||||||
|
identifier=name_map.get(str(r.user_id), r.client_ip or "anonymous")
|
||||||
|
if r.user_id
|
||||||
|
else (r.client_ip or "anonymous"),
|
||||||
|
request_count=r.cnt,
|
||||||
|
total_tokens=r.total,
|
||||||
|
)
|
||||||
|
for r in user_rows
|
||||||
|
]
|
||||||
|
|
||||||
|
# Daily request counts for last 7 days
|
||||||
|
seven_days_ago = today_start - timedelta(days=6)
|
||||||
|
day_col = func.date_trunc("day", ChatUsageLog.created_at).label("day")
|
||||||
|
daily_stmt = (
|
||||||
|
select(day_col, func.count().label("cnt"))
|
||||||
|
.where(ChatUsageLog.created_at >= seven_days_ago)
|
||||||
|
.group_by(day_col)
|
||||||
|
.order_by(day_col)
|
||||||
|
)
|
||||||
|
daily_rows = (await session.execute(daily_stmt)).all()
|
||||||
|
daily_counts = [
|
||||||
|
_DailyCount(date=r.day.strftime("%Y-%m-%d"), request_count=r.cnt)
|
||||||
|
for r in daily_rows
|
||||||
|
]
|
||||||
|
|
||||||
|
return UsageStatsResponse(
|
||||||
|
today=today,
|
||||||
|
week=week,
|
||||||
|
month=month,
|
||||||
|
top_creators=top_creators,
|
||||||
|
top_users=top_users,
|
||||||
|
daily_counts=daily_counts,
|
||||||
|
)
|
||||||
189
backend/routers/auth.py
Normal file
189
backend/routers/auth.py
Normal file
|
|
@ -0,0 +1,189 @@
|
||||||
|
"""Auth router — registration, login, profile management."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from auth import (
|
||||||
|
create_access_token,
|
||||||
|
get_current_user,
|
||||||
|
hash_password,
|
||||||
|
reject_impersonation,
|
||||||
|
verify_password,
|
||||||
|
)
|
||||||
|
from database import get_session
|
||||||
|
from models import Creator, InviteCode, User
|
||||||
|
from schemas import (
|
||||||
|
LoginRequest,
|
||||||
|
RegisterRequest,
|
||||||
|
TokenResponse,
|
||||||
|
UpdateProfileRequest,
|
||||||
|
UserResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger("chrysopedia.auth")
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
|
# ── Registration ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def register(
|
||||||
|
body: RegisterRequest,
|
||||||
|
session: Annotated[AsyncSession, Depends(get_session)],
|
||||||
|
):
|
||||||
|
"""Register a new user with a valid invite code."""
|
||||||
|
# 1. Validate invite code
|
||||||
|
result = await session.execute(
|
||||||
|
select(InviteCode).where(InviteCode.code == body.invite_code)
|
||||||
|
)
|
||||||
|
invite = result.scalar_one_or_none()
|
||||||
|
if invite is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid invite code")
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||||
|
if invite.expires_at is not None and invite.expires_at < now:
|
||||||
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invite code has expired")
|
||||||
|
|
||||||
|
if invite.uses_remaining <= 0:
|
||||||
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invite code exhausted")
|
||||||
|
|
||||||
|
# 2. Check email uniqueness
|
||||||
|
existing = await session.execute(select(User).where(User.email == body.email))
|
||||||
|
if existing.scalar_one_or_none() is not None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Email already registered")
|
||||||
|
|
||||||
|
# 3. Optionally resolve creator_id from slug
|
||||||
|
creator_id = None
|
||||||
|
if body.creator_slug:
|
||||||
|
creator_result = await session.execute(
|
||||||
|
select(Creator).where(Creator.slug == body.creator_slug)
|
||||||
|
)
|
||||||
|
creator = creator_result.scalar_one_or_none()
|
||||||
|
if creator is not None:
|
||||||
|
creator_id = creator.id
|
||||||
|
|
||||||
|
# 4. Create user
|
||||||
|
user = User(
|
||||||
|
email=body.email,
|
||||||
|
hashed_password=hash_password(body.password),
|
||||||
|
display_name=body.display_name,
|
||||||
|
creator_id=creator_id,
|
||||||
|
)
|
||||||
|
session.add(user)
|
||||||
|
|
||||||
|
# 5. Decrement invite code uses
|
||||||
|
invite.uses_remaining -= 1
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(user)
|
||||||
|
|
||||||
|
logger.info("User registered: %s (email=%s)", user.id, user.email)
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
# ── Login ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/login", response_model=TokenResponse)
|
||||||
|
async def login(
|
||||||
|
body: LoginRequest,
|
||||||
|
session: Annotated[AsyncSession, Depends(get_session)],
|
||||||
|
):
|
||||||
|
"""Authenticate with email + password, return JWT."""
|
||||||
|
result = await session.execute(select(User).where(User.email == body.email))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if user is None or not verify_password(body.password, user.hashed_password):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid email or password",
|
||||||
|
)
|
||||||
|
|
||||||
|
token = create_access_token(user.id, user.role.value)
|
||||||
|
logger.info("User logged in: %s", user.id)
|
||||||
|
return TokenResponse(access_token=token)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Profile ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me", response_model=UserResponse)
|
||||||
|
async def get_profile(
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
):
|
||||||
|
"""Return the current user's profile."""
|
||||||
|
resp = UserResponse.model_validate(current_user)
|
||||||
|
admin_id = getattr(current_user, "_impersonating_admin_id", None)
|
||||||
|
if admin_id is not None:
|
||||||
|
resp.impersonating = True
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/me", response_model=UserResponse)
|
||||||
|
async def update_profile(
|
||||||
|
body: UpdateProfileRequest,
|
||||||
|
current_user: Annotated[User, Depends(reject_impersonation)],
|
||||||
|
session: Annotated[AsyncSession, Depends(get_session)],
|
||||||
|
):
|
||||||
|
"""Update the current user's display name and/or password."""
|
||||||
|
if body.display_name is not None:
|
||||||
|
current_user.display_name = body.display_name
|
||||||
|
|
||||||
|
if body.new_password is not None:
|
||||||
|
if body.current_password is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Current password required to set new password",
|
||||||
|
)
|
||||||
|
if not verify_password(body.current_password, current_user.hashed_password):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Current password is incorrect",
|
||||||
|
)
|
||||||
|
current_user.hashed_password = hash_password(body.new_password)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(current_user)
|
||||||
|
|
||||||
|
logger.info("Profile updated: %s", current_user.id)
|
||||||
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
|
# ── Seed ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def seed_invite_codes(session: AsyncSession) -> None:
|
||||||
|
"""Create default invite code if none exist. Call from lifespan or CLI."""
|
||||||
|
result = await session.execute(select(InviteCode))
|
||||||
|
if result.scalar_one_or_none() is None:
|
||||||
|
session.add(InviteCode(
|
||||||
|
code="CHRYSOPEDIA-ALPHA-2026",
|
||||||
|
uses_remaining=100,
|
||||||
|
))
|
||||||
|
await session.commit()
|
||||||
|
logger.info("Seeded default invite code: CHRYSOPEDIA-ALPHA-2026")
|
||||||
|
|
||||||
|
|
||||||
|
# ── Onboarding ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/onboarding-complete", response_model=UserResponse)
|
||||||
|
async def complete_onboarding(
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
session: Annotated[AsyncSession, Depends(get_session)],
|
||||||
|
):
|
||||||
|
"""Mark the current user's onboarding as completed."""
|
||||||
|
current_user.onboarding_completed = True
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(current_user)
|
||||||
|
logger.info("Onboarding completed: %s", current_user.id)
|
||||||
|
return UserResponse.model_validate(current_user)
|
||||||
145
backend/routers/chat.py
Normal file
145
backend/routers/chat.py
Normal file
|
|
@ -0,0 +1,145 @@
|
||||||
|
"""Chat endpoint — POST /api/v1/chat with SSE streaming response.
|
||||||
|
|
||||||
|
Accepts a query and optional creator filter, returns a Server-Sent Events
|
||||||
|
stream with sources, token, done, and error events.
|
||||||
|
|
||||||
|
Rate limiting: per-user (authenticated), per-IP (anonymous), and per-creator.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Request
|
||||||
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from auth import get_optional_user
|
||||||
|
from chat_service import ChatService
|
||||||
|
from config import Settings, get_settings
|
||||||
|
from database import get_session
|
||||||
|
from models import User
|
||||||
|
from rate_limiter import RateLimiter
|
||||||
|
from redis_client import get_redis
|
||||||
|
|
||||||
|
logger = logging.getLogger("chrysopedia.chat.router")
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||||
|
|
||||||
|
|
||||||
|
class ChatRequest(BaseModel):
|
||||||
|
"""Request body for the chat endpoint."""
|
||||||
|
|
||||||
|
query: str = Field(..., min_length=1, max_length=1000)
|
||||||
|
creator: str | None = None
|
||||||
|
conversation_id: str | None = None
|
||||||
|
personality_weight: float = Field(default=0.0, ge=0.0, le=1.0)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_client_ip(request: Request) -> str:
|
||||||
|
"""Extract client IP, preferring X-Forwarded-For behind a reverse proxy."""
|
||||||
|
forwarded = request.headers.get("x-forwarded-for")
|
||||||
|
if forwarded:
|
||||||
|
return forwarded.split(",")[0].strip()
|
||||||
|
return request.client.host if request.client else "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", response_model=None)
|
||||||
|
async def chat(
|
||||||
|
body: ChatRequest,
|
||||||
|
request: Request,
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
settings: Settings = Depends(get_settings),
|
||||||
|
user: User | None = Depends(get_optional_user),
|
||||||
|
):
|
||||||
|
"""Stream a chat response as Server-Sent Events.
|
||||||
|
|
||||||
|
Rate limits are checked before processing:
|
||||||
|
- Authenticated users: ``rate_limit_user_per_hour`` requests/hour
|
||||||
|
- Anonymous (IP-based): ``rate_limit_ip_per_hour`` requests/hour
|
||||||
|
- Per-creator (if creator filter set): ``rate_limit_creator_per_hour`` requests/hour
|
||||||
|
|
||||||
|
SSE protocol:
|
||||||
|
- ``event: sources`` — citation metadata array (sent first)
|
||||||
|
- ``event: token`` — streamed text chunk (repeated)
|
||||||
|
- ``event: done`` — completion metadata with cascade_tier, conversation_id
|
||||||
|
- ``event: error`` — error message (on failure)
|
||||||
|
"""
|
||||||
|
client_ip = _get_client_ip(request)
|
||||||
|
user_id = user.id if user else None
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"chat_request query=%r creator=%r cid=%r weight=%.2f user=%s ip=%s",
|
||||||
|
body.query, body.creator, body.conversation_id,
|
||||||
|
body.personality_weight, user_id, client_ip,
|
||||||
|
)
|
||||||
|
|
||||||
|
redis = await get_redis()
|
||||||
|
|
||||||
|
# ── Rate limiting ───────────────────────────────────────────────────
|
||||||
|
limiter = RateLimiter(redis)
|
||||||
|
|
||||||
|
# User-based limit (authenticated) or IP-based limit (anonymous)
|
||||||
|
if user_id:
|
||||||
|
identity_key = RateLimiter.key("user", str(user_id))
|
||||||
|
identity_limit = settings.rate_limit_user_per_hour
|
||||||
|
else:
|
||||||
|
identity_key = RateLimiter.key("ip", client_ip)
|
||||||
|
identity_limit = settings.rate_limit_ip_per_hour
|
||||||
|
|
||||||
|
result = await limiter.check_rate_limit(identity_key, identity_limit, window_seconds=3600)
|
||||||
|
if not result.allowed:
|
||||||
|
scope = "user" if user_id else "ip"
|
||||||
|
logger.warning(
|
||||||
|
"rate_limit_exceeded scope=%s key=%s remaining=%d retry_after=%d",
|
||||||
|
scope, identity_key, result.remaining, result.retry_after,
|
||||||
|
)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=429,
|
||||||
|
content={
|
||||||
|
"error": "Rate limit exceeded",
|
||||||
|
"retry_after": result.retry_after,
|
||||||
|
},
|
||||||
|
headers={"Retry-After": str(result.retry_after)},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Per-creator limit (if creator filter is provided)
|
||||||
|
if body.creator:
|
||||||
|
creator_key = RateLimiter.key("creator", body.creator)
|
||||||
|
creator_result = await limiter.check_rate_limit(
|
||||||
|
creator_key, settings.rate_limit_creator_per_hour, window_seconds=3600,
|
||||||
|
)
|
||||||
|
if not creator_result.allowed:
|
||||||
|
logger.warning(
|
||||||
|
"rate_limit_exceeded scope=creator key=%s retry_after=%d",
|
||||||
|
creator_key, creator_result.retry_after,
|
||||||
|
)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=429,
|
||||||
|
content={
|
||||||
|
"error": "Creator rate limit exceeded",
|
||||||
|
"retry_after": creator_result.retry_after,
|
||||||
|
},
|
||||||
|
headers={"Retry-After": str(creator_result.retry_after)},
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Stream response ─────────────────────────────────────────────────
|
||||||
|
service = ChatService(settings, redis=redis)
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
service.stream_response(
|
||||||
|
query=body.query,
|
||||||
|
db=db,
|
||||||
|
creator=body.creator,
|
||||||
|
conversation_id=body.conversation_id,
|
||||||
|
personality_weight=body.personality_weight,
|
||||||
|
user_id=user_id,
|
||||||
|
client_ip=client_ip,
|
||||||
|
),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
},
|
||||||
|
)
|
||||||
322
backend/routers/consent.py
Normal file
322
backend/routers/consent.py
Normal file
|
|
@ -0,0 +1,322 @@
|
||||||
|
"""Consent router — per-video consent toggles with versioned audit trail.
|
||||||
|
|
||||||
|
Creator endpoints (ownership-gated):
|
||||||
|
GET /consent/videos List consent for the current creator's videos
|
||||||
|
GET /consent/videos/{video_id} Single video consent status
|
||||||
|
PUT /consent/videos/{video_id} Upsert consent (partial update, audit logged)
|
||||||
|
GET /consent/videos/{video_id}/history Audit trail for a video
|
||||||
|
|
||||||
|
Admin endpoint:
|
||||||
|
GET /consent/admin/summary Aggregate consent flag counts
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
from auth import get_current_user, reject_impersonation, require_role
|
||||||
|
from database import get_session
|
||||||
|
from models import (
|
||||||
|
ConsentAuditLog,
|
||||||
|
ConsentField,
|
||||||
|
SourceVideo,
|
||||||
|
User,
|
||||||
|
UserRole,
|
||||||
|
VideoConsent,
|
||||||
|
)
|
||||||
|
from schemas import (
|
||||||
|
ConsentAuditEntry,
|
||||||
|
ConsentListResponse,
|
||||||
|
ConsentSummary,
|
||||||
|
VideoConsentRead,
|
||||||
|
VideoConsentUpdate,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger("chrysopedia.consent")
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/consent", tags=["consent"])
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _verify_video_ownership(
|
||||||
|
video_id: uuid.UUID,
|
||||||
|
user: User,
|
||||||
|
session: AsyncSession,
|
||||||
|
) -> SourceVideo:
|
||||||
|
"""Load a SourceVideo and verify the user owns it (or is admin).
|
||||||
|
|
||||||
|
Returns the SourceVideo on success.
|
||||||
|
Raises 403 if user has no creator_id or doesn't own the video.
|
||||||
|
Raises 404 if video doesn't exist.
|
||||||
|
"""
|
||||||
|
result = await session.execute(
|
||||||
|
select(SourceVideo).where(SourceVideo.id == video_id)
|
||||||
|
)
|
||||||
|
video = result.scalar_one_or_none()
|
||||||
|
if video is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Video not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Admin bypasses ownership check
|
||||||
|
if user.role == UserRole.admin:
|
||||||
|
return video
|
||||||
|
|
||||||
|
if user.creator_id is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="User is not linked to a creator profile",
|
||||||
|
)
|
||||||
|
|
||||||
|
if video.creator_id != user.creator_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="You do not own this video",
|
||||||
|
)
|
||||||
|
|
||||||
|
return video
|
||||||
|
|
||||||
|
|
||||||
|
def _consent_to_read(consent: VideoConsent, filename: str) -> VideoConsentRead:
|
||||||
|
"""Map a VideoConsent ORM instance to the read schema."""
|
||||||
|
return VideoConsentRead(
|
||||||
|
source_video_id=consent.source_video_id,
|
||||||
|
video_filename=filename,
|
||||||
|
creator_id=consent.creator_id,
|
||||||
|
kb_inclusion=consent.kb_inclusion,
|
||||||
|
training_usage=consent.training_usage,
|
||||||
|
public_display=consent.public_display,
|
||||||
|
updated_at=consent.updated_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Endpoints ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/videos", response_model=ConsentListResponse)
|
||||||
|
async def list_video_consents(
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
session: Annotated[AsyncSession, Depends(get_session)],
|
||||||
|
offset: Annotated[int, Query(ge=0)] = 0,
|
||||||
|
limit: Annotated[int, Query(ge=1, le=100)] = 50,
|
||||||
|
):
|
||||||
|
"""List consent records for the current creator's videos."""
|
||||||
|
if current_user.creator_id is None and current_user.role != UserRole.admin:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="User is not linked to a creator profile",
|
||||||
|
)
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
select(VideoConsent)
|
||||||
|
.join(SourceVideo, VideoConsent.source_video_id == SourceVideo.id)
|
||||||
|
.options(selectinload(VideoConsent.source_video))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Non-admin sees only their own videos
|
||||||
|
if current_user.role != UserRole.admin:
|
||||||
|
stmt = stmt.where(VideoConsent.creator_id == current_user.creator_id)
|
||||||
|
|
||||||
|
# Count
|
||||||
|
count_stmt = select(func.count()).select_from(stmt.subquery())
|
||||||
|
total = (await session.execute(count_stmt)).scalar() or 0
|
||||||
|
|
||||||
|
# Fetch page
|
||||||
|
stmt = stmt.order_by(VideoConsent.updated_at.desc())
|
||||||
|
stmt = stmt.offset(offset).limit(limit)
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
consents = result.scalars().all()
|
||||||
|
|
||||||
|
items = [
|
||||||
|
_consent_to_read(c, c.source_video.filename) for c in consents
|
||||||
|
]
|
||||||
|
return ConsentListResponse(items=items, total=total)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/videos/{video_id}", response_model=VideoConsentRead)
|
||||||
|
async def get_video_consent(
|
||||||
|
video_id: uuid.UUID,
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
session: Annotated[AsyncSession, Depends(get_session)],
|
||||||
|
):
|
||||||
|
"""Get consent status for a single video."""
|
||||||
|
video = await _verify_video_ownership(video_id, current_user, session)
|
||||||
|
|
||||||
|
result = await session.execute(
|
||||||
|
select(VideoConsent).where(VideoConsent.source_video_id == video_id)
|
||||||
|
)
|
||||||
|
consent = result.scalar_one_or_none()
|
||||||
|
if consent is None:
|
||||||
|
# No consent record yet — return defaults
|
||||||
|
return VideoConsentRead(
|
||||||
|
source_video_id=video_id,
|
||||||
|
video_filename=video.filename,
|
||||||
|
creator_id=video.creator_id,
|
||||||
|
kb_inclusion=False,
|
||||||
|
training_usage=False,
|
||||||
|
public_display=True,
|
||||||
|
updated_at=video.created_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
return _consent_to_read(consent, video.filename)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/videos/{video_id}", response_model=VideoConsentRead)
|
||||||
|
async def update_video_consent(
|
||||||
|
video_id: uuid.UUID,
|
||||||
|
body: VideoConsentUpdate,
|
||||||
|
current_user: Annotated[User, Depends(reject_impersonation)],
|
||||||
|
session: Annotated[AsyncSession, Depends(get_session)],
|
||||||
|
request: Request,
|
||||||
|
):
|
||||||
|
"""Upsert consent for a video. Only non-None fields are changed.
|
||||||
|
|
||||||
|
Creates audit log entries for each changed field with incrementing
|
||||||
|
version numbers.
|
||||||
|
"""
|
||||||
|
video = await _verify_video_ownership(video_id, current_user, session)
|
||||||
|
|
||||||
|
# Load or create consent record
|
||||||
|
result = await session.execute(
|
||||||
|
select(VideoConsent).where(VideoConsent.source_video_id == video_id)
|
||||||
|
)
|
||||||
|
consent = result.scalar_one_or_none()
|
||||||
|
is_new = consent is None
|
||||||
|
|
||||||
|
if is_new:
|
||||||
|
consent = VideoConsent(
|
||||||
|
source_video_id=video_id,
|
||||||
|
creator_id=video.creator_id,
|
||||||
|
updated_by=current_user.id,
|
||||||
|
)
|
||||||
|
session.add(consent)
|
||||||
|
await session.flush() # get consent.id for audit entries
|
||||||
|
|
||||||
|
# Determine the next version number
|
||||||
|
max_version_result = await session.execute(
|
||||||
|
select(func.coalesce(func.max(ConsentAuditLog.version), 0)).where(
|
||||||
|
ConsentAuditLog.video_consent_id == consent.id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
next_version = (max_version_result.scalar() or 0) + 1
|
||||||
|
|
||||||
|
# Collect client IP for audit
|
||||||
|
client_ip = request.client.host if request.client else None
|
||||||
|
|
||||||
|
# Apply changes and build audit entries
|
||||||
|
fields_changed: list[str] = []
|
||||||
|
update_data = body.model_dump(exclude_none=True)
|
||||||
|
|
||||||
|
for field_name, new_value in update_data.items():
|
||||||
|
# Validate field name against the enum
|
||||||
|
try:
|
||||||
|
ConsentField(field_name)
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
old_value = getattr(consent, field_name)
|
||||||
|
|
||||||
|
# Skip if no actual change
|
||||||
|
if old_value == new_value:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Update the consent record
|
||||||
|
setattr(consent, field_name, new_value)
|
||||||
|
fields_changed.append(field_name)
|
||||||
|
|
||||||
|
# Create audit entry
|
||||||
|
audit_entry = ConsentAuditLog(
|
||||||
|
video_consent_id=consent.id,
|
||||||
|
version=next_version,
|
||||||
|
field_name=field_name,
|
||||||
|
old_value=old_value if not is_new else None,
|
||||||
|
new_value=new_value,
|
||||||
|
changed_by=current_user.id,
|
||||||
|
ip_address=client_ip,
|
||||||
|
)
|
||||||
|
session.add(audit_entry)
|
||||||
|
next_version += 1
|
||||||
|
|
||||||
|
if fields_changed:
|
||||||
|
consent.updated_by = current_user.id
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(consent)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Consent updated: video_id=%s fields_changed=%s user=%s",
|
||||||
|
video_id,
|
||||||
|
fields_changed,
|
||||||
|
current_user.id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# No actual changes — still commit if we created a new record
|
||||||
|
if is_new:
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(consent)
|
||||||
|
else:
|
||||||
|
# Nothing changed, no audit entries
|
||||||
|
pass
|
||||||
|
|
||||||
|
return _consent_to_read(consent, video.filename)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/videos/{video_id}/history", response_model=list[ConsentAuditEntry])
|
||||||
|
async def get_consent_history(
|
||||||
|
video_id: uuid.UUID,
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
session: Annotated[AsyncSession, Depends(get_session)],
|
||||||
|
):
|
||||||
|
"""Get the audit trail for a video's consent changes."""
|
||||||
|
await _verify_video_ownership(video_id, current_user, session)
|
||||||
|
|
||||||
|
# Find the consent record
|
||||||
|
result = await session.execute(
|
||||||
|
select(VideoConsent).where(VideoConsent.source_video_id == video_id)
|
||||||
|
)
|
||||||
|
consent = result.scalar_one_or_none()
|
||||||
|
if consent is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Fetch audit entries ordered by version
|
||||||
|
audit_result = await session.execute(
|
||||||
|
select(ConsentAuditLog)
|
||||||
|
.where(ConsentAuditLog.video_consent_id == consent.id)
|
||||||
|
.order_by(ConsentAuditLog.version.asc())
|
||||||
|
)
|
||||||
|
return audit_result.scalars().all()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/admin/summary",
|
||||||
|
response_model=ConsentSummary,
|
||||||
|
dependencies=[Depends(require_role(UserRole.admin))],
|
||||||
|
)
|
||||||
|
async def consent_admin_summary(
|
||||||
|
session: Annotated[AsyncSession, Depends(get_session)],
|
||||||
|
):
|
||||||
|
"""Aggregate consent flag counts across all videos (admin only)."""
|
||||||
|
result = await session.execute(
|
||||||
|
select(
|
||||||
|
func.count().label("total"),
|
||||||
|
func.sum(VideoConsent.kb_inclusion.cast(int)).label("kb"),
|
||||||
|
func.sum(VideoConsent.training_usage.cast(int)).label("tu"),
|
||||||
|
func.sum(VideoConsent.public_display.cast(int)).label("pd"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
row = result.one()
|
||||||
|
return ConsentSummary(
|
||||||
|
total_videos=row.total or 0,
|
||||||
|
kb_inclusion_granted=row.kb or 0,
|
||||||
|
training_usage_granted=row.tu or 0,
|
||||||
|
public_display_granted=row.pd or 0,
|
||||||
|
)
|
||||||
172
backend/routers/creator_chapters.py
Normal file
172
backend/routers/creator_chapters.py
Normal file
|
|
@ -0,0 +1,172 @@
|
||||||
|
"""Creator chapter management endpoints — review, edit, reorder, approve chapters.
|
||||||
|
|
||||||
|
Auth-guarded endpoints for creators to manage auto-detected chapters for
|
||||||
|
their videos before publication.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from sqlalchemy import select, update
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from auth import get_current_user
|
||||||
|
from database import get_session
|
||||||
|
from models import ChapterStatus, KeyMoment, SourceVideo, User
|
||||||
|
from schemas import (
|
||||||
|
ChapterBulkApproveRequest,
|
||||||
|
ChapterMarkerRead,
|
||||||
|
ChapterReorderRequest,
|
||||||
|
ChapterUpdate,
|
||||||
|
ChaptersResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger("chrysopedia.creator_chapters")
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/creator", tags=["creator-chapters"])
|
||||||
|
|
||||||
|
|
||||||
|
async def _verify_creator_owns_video(
|
||||||
|
current_user: User,
|
||||||
|
video_id: uuid.UUID,
|
||||||
|
db: AsyncSession,
|
||||||
|
) -> None:
|
||||||
|
"""Verify the user is a creator and owns the specified video."""
|
||||||
|
if current_user.creator_id is None:
|
||||||
|
raise HTTPException(status_code=403, detail="No creator profile linked")
|
||||||
|
|
||||||
|
video = (await db.execute(
|
||||||
|
select(SourceVideo).where(
|
||||||
|
SourceVideo.id == video_id,
|
||||||
|
SourceVideo.creator_id == current_user.creator_id,
|
||||||
|
)
|
||||||
|
)).scalar_one_or_none()
|
||||||
|
if video is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Video not found or not owned by you")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{video_id}/chapters", response_model=ChaptersResponse)
|
||||||
|
async def get_creator_chapters(
|
||||||
|
video_id: uuid.UUID,
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> ChaptersResponse:
|
||||||
|
"""Return all chapters for a creator's video (all statuses)."""
|
||||||
|
await _verify_creator_owns_video(current_user, video_id, db)
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
select(KeyMoment)
|
||||||
|
.where(KeyMoment.source_video_id == video_id)
|
||||||
|
.order_by(KeyMoment.sort_order, KeyMoment.start_time)
|
||||||
|
)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
moments = result.scalars().all()
|
||||||
|
logger.debug("Creator chapters for %s: %d", video_id, len(moments))
|
||||||
|
return ChaptersResponse(
|
||||||
|
video_id=video_id,
|
||||||
|
chapters=[ChapterMarkerRead.model_validate(m) for m in moments],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch("/chapters/{chapter_id}", response_model=ChapterMarkerRead)
|
||||||
|
async def update_chapter(
|
||||||
|
chapter_id: uuid.UUID,
|
||||||
|
body: ChapterUpdate,
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> ChapterMarkerRead:
|
||||||
|
"""Update a single chapter (title, times, status)."""
|
||||||
|
if current_user.creator_id is None:
|
||||||
|
raise HTTPException(status_code=403, detail="No creator profile linked")
|
||||||
|
|
||||||
|
# Fetch the chapter and verify ownership via the video
|
||||||
|
chapter = (await db.execute(
|
||||||
|
select(KeyMoment).where(KeyMoment.id == chapter_id)
|
||||||
|
)).scalar_one_or_none()
|
||||||
|
if chapter is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Chapter not found")
|
||||||
|
|
||||||
|
await _verify_creator_owns_video(current_user, chapter.source_video_id, db)
|
||||||
|
|
||||||
|
# Apply partial updates
|
||||||
|
update_data = body.model_dump(exclude_unset=True)
|
||||||
|
if "chapter_status" in update_data:
|
||||||
|
update_data["chapter_status"] = ChapterStatus(update_data["chapter_status"])
|
||||||
|
for field, value in update_data.items():
|
||||||
|
setattr(chapter, field, value)
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(chapter)
|
||||||
|
logger.info("Updated chapter %s: %s", chapter_id, list(update_data.keys()))
|
||||||
|
return ChapterMarkerRead.model_validate(chapter)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{video_id}/chapters/reorder", response_model=ChaptersResponse)
|
||||||
|
async def reorder_chapters(
|
||||||
|
video_id: uuid.UUID,
|
||||||
|
body: ChapterReorderRequest,
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> ChaptersResponse:
|
||||||
|
"""Reorder chapters for a video by setting sort_order values."""
|
||||||
|
await _verify_creator_owns_video(current_user, video_id, db)
|
||||||
|
|
||||||
|
for item in body.chapters:
|
||||||
|
await db.execute(
|
||||||
|
update(KeyMoment)
|
||||||
|
.where(KeyMoment.id == item.id, KeyMoment.source_video_id == video_id)
|
||||||
|
.values(sort_order=item.sort_order)
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# Return updated list
|
||||||
|
stmt = (
|
||||||
|
select(KeyMoment)
|
||||||
|
.where(KeyMoment.source_video_id == video_id)
|
||||||
|
.order_by(KeyMoment.sort_order, KeyMoment.start_time)
|
||||||
|
)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
moments = result.scalars().all()
|
||||||
|
logger.info("Reordered %d chapters for video %s", len(body.chapters), video_id)
|
||||||
|
return ChaptersResponse(
|
||||||
|
video_id=video_id,
|
||||||
|
chapters=[ChapterMarkerRead.model_validate(m) for m in moments],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{video_id}/chapters/approve", response_model=ChaptersResponse)
|
||||||
|
async def bulk_approve_chapters(
|
||||||
|
video_id: uuid.UUID,
|
||||||
|
body: ChapterBulkApproveRequest,
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> ChaptersResponse:
|
||||||
|
"""Bulk-approve chapters by ID list."""
|
||||||
|
await _verify_creator_owns_video(current_user, video_id, db)
|
||||||
|
|
||||||
|
if body.chapter_ids:
|
||||||
|
await db.execute(
|
||||||
|
update(KeyMoment)
|
||||||
|
.where(
|
||||||
|
KeyMoment.id.in_(body.chapter_ids),
|
||||||
|
KeyMoment.source_video_id == video_id,
|
||||||
|
)
|
||||||
|
.values(chapter_status=ChapterStatus.approved)
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
logger.info("Bulk-approved %d chapters for video %s", len(body.chapter_ids), video_id)
|
||||||
|
|
||||||
|
# Return updated list
|
||||||
|
stmt = (
|
||||||
|
select(KeyMoment)
|
||||||
|
.where(KeyMoment.source_video_id == video_id)
|
||||||
|
.order_by(KeyMoment.sort_order, KeyMoment.start_time)
|
||||||
|
)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
moments = result.scalars().all()
|
||||||
|
return ChaptersResponse(
|
||||||
|
video_id=video_id,
|
||||||
|
chapters=[ChapterMarkerRead.model_validate(m) for m in moments],
|
||||||
|
)
|
||||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue