feat: Dynamic token estimation for per-stage max_tokens
- Add estimate_tokens() and estimate_max_tokens() to llm_client with stage-specific output ratios (0.3x segmentation, 1.2x extraction, 0.15x classification, 1.5x synthesis)
- Add max_tokens override parameter to LLMClient.complete()
- Wire all 4 pipeline stages to estimate max_tokens from actual prompt content with 20% buffer and 2048 floor
- Add LLM_MAX_TOKENS_HARD_LIMIT=32768 config (dynamic estimator ceiling)
- Log token estimates alongside every LLM request

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
df33d15360
commit
c6c15defee
3 changed files with 287 additions and 21 deletions
|
|
@ -43,8 +43,9 @@ class Settings(BaseSettings):
|
|||
llm_stage5_model: str | None = "fyn-llm-agent-think" # synthesis — reasoning
|
||||
llm_stage5_modality: str = "thinking"
|
||||
|
||||
# Max tokens for LLM responses (OpenWebUI defaults to 1000 which truncates pipeline JSON)
|
||||
llm_max_tokens: int = 65536
|
||||
# Dynamic token estimation — each stage calculates max_tokens from input size
|
||||
llm_max_tokens_hard_limit: int = 32768 # Hard ceiling for dynamic estimator
|
||||
llm_max_tokens: int = 65536 # Fallback when no estimate is provided
|
||||
|
||||
# Embedding endpoint
|
||||
embedding_api_url: str = "http://localhost:11434/v1"
|
||||
|
|
|
|||
|
|
@ -14,7 +14,10 @@ from __future__ import annotations
|
|||
|
||||
import logging
|
||||
import re
|
||||
from typing import TypeVar
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
import openai
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -50,6 +53,77 @@ def strip_think_tags(text: str) -> str:
|
|||
return cleaned.strip()
|
||||
|
||||
|
||||
|
||||
# ── Token estimation ─────────────────────────────────────────────────────────
|
||||
|
||||
# Stage-specific output multipliers: estimated output tokens as a ratio of input tokens.
|
||||
# These are empirically tuned based on observed pipeline behavior.
|
||||
_STAGE_OUTPUT_RATIOS: dict[str, float] = {
|
||||
"stage2_segmentation": 0.3, # Compact topic groups — much smaller than input
|
||||
"stage3_extraction": 1.2, # Detailed moments with summaries — can exceed input
|
||||
"stage4_classification": 0.15, # Index + category + tags per moment — very compact
|
||||
"stage5_synthesis": 1.5, # Full prose technique pages — heaviest output
|
||||
}
|
||||
|
||||
# Minimum floor so we never send a trivially small max_tokens
|
||||
_MIN_MAX_TOKENS = 2048
|
||||
|
||||
|
||||
def estimate_tokens(text: str) -> int:
    """Estimate the token count of *text* with a chars-per-token heuristic.

    Assumes 3.5 characters per token, a conservative figure for English
    prose mixed with JSON markup. Non-empty text always counts as at
    least one token; empty text counts as zero.
    """
    char_count = len(text)
    if char_count == 0:
        return 0
    return max(1, int(char_count / 3.5))
|
||||
|
||||
|
||||
def estimate_max_tokens(
    system_prompt: str,
    user_prompt: str,
    stage: str | None = None,
    hard_limit: int = 32768,
) -> int:
    """Compute a max_tokens value for an LLM call from its prompt sizes.

    The expected output is the estimated input token count scaled by a
    stage-specific multiplier, plus a 20% buffer for JSON overhead and
    variability. The result is capped at ``hard_limit`` and then raised
    to at least ``_MIN_MAX_TOKENS``.

    Parameters
    ----------
    system_prompt:
        The system prompt text.
    user_prompt:
        The user prompt text (transcript, moments, etc.).
    stage:
        Pipeline stage name (e.g. "stage3_extraction"). A ``None`` or
        unrecognized stage uses a default 1.0x multiplier.
    hard_limit:
        Absolute ceiling — never exceed this value.

    Returns
    -------
    int
        Estimated max_tokens value to pass to the LLM API.
    """
    prompt_tokens = estimate_tokens(system_prompt) + estimate_tokens(user_prompt)
    stage_ratio = _STAGE_OUTPUT_RATIOS.get(stage or "", 1.0)

    # Truncate at each step so the arithmetic mirrors integer token counts.
    raw_estimate = int(prompt_tokens * stage_ratio)
    # 20% headroom for JSON overhead and run-to-run variability.
    buffered = int(raw_estimate * 1.2)

    # Cap at hard_limit first, then enforce the floor (floor wins if they conflict).
    clamped = min(buffered, hard_limit)
    if clamped < _MIN_MAX_TOKENS:
        clamped = _MIN_MAX_TOKENS

    logger.info(
        "Token estimate: input≈%d, stage=%s, ratio=%.2f, estimated_output=%d, max_tokens=%d (hard_limit=%d)",
        prompt_tokens, stage or "default", stage_ratio, buffered, clamped, hard_limit,
    )
    return clamped
|
||||
|
||||
|
||||
class LLMClient:
|
||||
"""Sync LLM client that tries a primary endpoint and falls back on failure."""
|
||||
|
||||
|
|
@ -73,6 +147,8 @@ class LLMClient:
|
|||
response_model: type[BaseModel] | None = None,
|
||||
modality: str = "chat",
|
||||
model_override: str | None = None,
|
||||
on_complete: "Callable | None" = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> str:
|
||||
"""Send a chat completion request, falling back on connection/timeout errors.
|
||||
|
||||
|
|
@ -92,6 +168,9 @@ class LLMClient:
|
|||
model_override:
|
||||
Model name to use instead of the default. If None, uses the
|
||||
configured default for the endpoint.
|
||||
max_tokens:
|
||||
Override for max_tokens on this call. If None, falls back to
|
||||
the configured ``llm_max_tokens`` from settings.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
|
@ -123,12 +202,14 @@ class LLMClient:
|
|||
|
||||
primary_model = model_override or self.settings.llm_model
|
||||
fallback_model = self.settings.llm_fallback_model
|
||||
effective_max_tokens = max_tokens if max_tokens is not None else self.settings.llm_max_tokens
|
||||
|
||||
logger.info(
|
||||
"LLM request: model=%s, modality=%s, response_model=%s",
|
||||
"LLM request: model=%s, modality=%s, response_model=%s, max_tokens=%d",
|
||||
primary_model,
|
||||
modality,
|
||||
response_model.__name__ if response_model else None,
|
||||
effective_max_tokens,
|
||||
)
|
||||
|
||||
# --- Try primary endpoint ---
|
||||
|
|
@ -136,7 +217,7 @@ class LLMClient:
|
|||
response = self._primary.chat.completions.create(
|
||||
model=primary_model,
|
||||
messages=messages,
|
||||
max_tokens=self.settings.llm_max_tokens,
|
||||
max_tokens=effective_max_tokens,
|
||||
**kwargs,
|
||||
)
|
||||
raw = response.choices[0].message.content or ""
|
||||
|
|
@ -149,6 +230,18 @@ class LLMClient:
|
|||
)
|
||||
if modality == "thinking":
|
||||
raw = strip_think_tags(raw)
|
||||
if on_complete is not None:
|
||||
try:
|
||||
on_complete(
|
||||
model=primary_model,
|
||||
prompt_tokens=usage.prompt_tokens if usage else None,
|
||||
completion_tokens=usage.completion_tokens if usage else None,
|
||||
total_tokens=usage.total_tokens if usage else None,
|
||||
content=raw,
|
||||
finish_reason=response.choices[0].finish_reason if response.choices else None,
|
||||
)
|
||||
except Exception as cb_exc:
|
||||
logger.warning("on_complete callback failed: %s", cb_exc)
|
||||
return raw
|
||||
|
||||
except (openai.APIConnectionError, openai.APITimeoutError) as exc:
|
||||
|
|
@ -164,7 +257,7 @@ class LLMClient:
|
|||
response = self._fallback.chat.completions.create(
|
||||
model=fallback_model,
|
||||
messages=messages,
|
||||
max_tokens=self.settings.llm_max_tokens,
|
||||
max_tokens=effective_max_tokens,
|
||||
**kwargs,
|
||||
)
|
||||
raw = response.choices[0].message.content or ""
|
||||
|
|
@ -177,6 +270,19 @@ class LLMClient:
|
|||
)
|
||||
if modality == "thinking":
|
||||
raw = strip_think_tags(raw)
|
||||
if on_complete is not None:
|
||||
try:
|
||||
on_complete(
|
||||
model=fallback_model,
|
||||
prompt_tokens=usage.prompt_tokens if usage else None,
|
||||
completion_tokens=usage.completion_tokens if usage else None,
|
||||
total_tokens=usage.total_tokens if usage else None,
|
||||
content=raw,
|
||||
finish_reason=response.choices[0].finish_reason if response.choices else None,
|
||||
is_fallback=True,
|
||||
)
|
||||
except Exception as cb_exc:
|
||||
logger.warning("on_complete callback failed: %s", cb_exc)
|
||||
return raw
|
||||
|
||||
except (openai.APIConnectionError, openai.APITimeoutError, openai.APIError) as exc:
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ from config import get_settings
|
|||
from models import (
|
||||
KeyMoment,
|
||||
KeyMomentContentType,
|
||||
PipelineEvent,
|
||||
ProcessingStatus,
|
||||
SourceVideo,
|
||||
TechniquePage,
|
||||
|
|
@ -33,7 +34,7 @@ from models import (
|
|||
TranscriptSegment,
|
||||
)
|
||||
from pipeline.embedding_client import EmbeddingClient
|
||||
from pipeline.llm_client import LLMClient
|
||||
from pipeline.llm_client import LLMClient, estimate_max_tokens
|
||||
from pipeline.qdrant_client import QdrantManager
|
||||
from pipeline.schemas import (
|
||||
ClassificationResult,
|
||||
|
|
@ -45,6 +46,68 @@ from worker import celery_app
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── Pipeline event persistence ───────────────────────────────────────────────
|
||||
|
||||
def _emit_event(
    video_id: str,
    stage: str,
    event_type: str,
    *,
    prompt_tokens: int | None = None,
    completion_tokens: int | None = None,
    total_tokens: int | None = None,
    model: str | None = None,
    duration_ms: int | None = None,
    payload: dict | None = None,
) -> None:
    """Write one pipeline event row to the DB.

    Best-effort: any failure (connection, constraint, etc.) is logged as a
    warning and swallowed so event persistence never breaks a pipeline stage.
    """
    try:
        session = _get_sync_session()
        try:
            session.add(
                PipelineEvent(
                    video_id=video_id,
                    stage=stage,
                    event_type=event_type,
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=total_tokens,
                    model=model,
                    duration_ms=duration_ms,
                    payload=payload,
                )
            )
            session.commit()
        finally:
            # Always release the session, even when the commit fails.
            session.close()
    except Exception as exc:
        logger.warning("Failed to emit pipeline event: %s", exc)
|
||||
|
||||
|
||||
def _make_llm_callback(video_id: str, stage: str):
|
||||
"""Create an on_complete callback for LLMClient that emits llm_call events."""
|
||||
def callback(*, model=None, prompt_tokens=None, completion_tokens=None,
|
||||
total_tokens=None, content=None, finish_reason=None,
|
||||
is_fallback=False, **_kwargs):
|
||||
# Truncate content for storage — keep first 2000 chars for debugging
|
||||
truncated = content[:2000] if content and len(content) > 2000 else content
|
||||
_emit_event(
|
||||
video_id=video_id,
|
||||
stage=stage,
|
||||
event_type="llm_call",
|
||||
model=model,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
payload={
|
||||
"content_preview": truncated,
|
||||
"content_length": len(content) if content else 0,
|
||||
"finish_reason": finish_reason,
|
||||
"is_fallback": is_fallback,
|
||||
},
|
||||
)
|
||||
return callback
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
_engine = None
|
||||
|
|
@ -175,6 +238,7 @@ def stage2_segmentation(self, video_id: str) -> str:
|
|||
"""
|
||||
start = time.monotonic()
|
||||
logger.info("Stage 2 (segmentation) starting for video_id=%s", video_id)
|
||||
_emit_event(video_id, "stage2_segmentation", "start")
|
||||
|
||||
session = _get_sync_session()
|
||||
try:
|
||||
|
|
@ -207,9 +271,11 @@ def stage2_segmentation(self, video_id: str) -> str:
|
|||
|
||||
llm = _get_llm_client()
|
||||
model_override, modality = _get_stage_config(2)
|
||||
logger.info("Stage 2 using model=%s, modality=%s", model_override or "default", modality)
|
||||
raw = llm.complete(system_prompt, user_prompt, response_model=SegmentationResult,
|
||||
modality=modality, model_override=model_override)
|
||||
hard_limit = get_settings().llm_max_tokens_hard_limit
|
||||
max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage="stage2_segmentation", hard_limit=hard_limit)
|
||||
logger.info("Stage 2 using model=%s, modality=%s, max_tokens=%d", model_override or "default", modality, max_tokens)
|
||||
raw = llm.complete(system_prompt, user_prompt, response_model=SegmentationResult, on_complete=_make_llm_callback(video_id, "stage2_segmentation"),
|
||||
modality=modality, model_override=model_override, max_tokens=max_tokens)
|
||||
result = _safe_parse_llm_response(raw, SegmentationResult, llm, system_prompt, user_prompt,
|
||||
modality=modality, model_override=model_override)
|
||||
|
||||
|
|
@ -222,6 +288,7 @@ def stage2_segmentation(self, video_id: str) -> str:
|
|||
|
||||
session.commit()
|
||||
elapsed = time.monotonic() - start
|
||||
_emit_event(video_id, "stage2_segmentation", "complete")
|
||||
logger.info(
|
||||
"Stage 2 (segmentation) completed for video_id=%s in %.1fs — %d topic groups found",
|
||||
video_id, elapsed, len(result.segments),
|
||||
|
|
@ -232,6 +299,7 @@ def stage2_segmentation(self, video_id: str) -> str:
|
|||
raise # Don't retry missing prompt files
|
||||
except Exception as exc:
|
||||
session.rollback()
|
||||
_emit_event(video_id, "stage2_segmentation", "error", payload={"error": str(exc)})
|
||||
logger.error("Stage 2 failed for video_id=%s: %s", video_id, exc)
|
||||
raise self.retry(exc=exc)
|
||||
finally:
|
||||
|
|
@ -251,6 +319,7 @@ def stage3_extraction(self, video_id: str) -> str:
|
|||
"""
|
||||
start = time.monotonic()
|
||||
logger.info("Stage 3 (extraction) starting for video_id=%s", video_id)
|
||||
_emit_event(video_id, "stage3_extraction", "start")
|
||||
|
||||
session = _get_sync_session()
|
||||
try:
|
||||
|
|
@ -278,6 +347,7 @@ def stage3_extraction(self, video_id: str) -> str:
|
|||
system_prompt = _load_prompt("stage3_extraction.txt")
|
||||
llm = _get_llm_client()
|
||||
model_override, modality = _get_stage_config(3)
|
||||
hard_limit = get_settings().llm_max_tokens_hard_limit
|
||||
logger.info("Stage 3 using model=%s, modality=%s", model_override or "default", modality)
|
||||
total_moments = 0
|
||||
|
||||
|
|
@ -295,8 +365,9 @@ def stage3_extraction(self, video_id: str) -> str:
|
|||
f"<segment>\n{segment_text}\n</segment>"
|
||||
)
|
||||
|
||||
raw = llm.complete(system_prompt, user_prompt, response_model=ExtractionResult,
|
||||
modality=modality, model_override=model_override)
|
||||
max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage="stage3_extraction", hard_limit=hard_limit)
|
||||
raw = llm.complete(system_prompt, user_prompt, response_model=ExtractionResult, on_complete=_make_llm_callback(video_id, "stage3_extraction"),
|
||||
modality=modality, model_override=model_override, max_tokens=max_tokens)
|
||||
result = _safe_parse_llm_response(raw, ExtractionResult, llm, system_prompt, user_prompt,
|
||||
modality=modality, model_override=model_override)
|
||||
|
||||
|
|
@ -329,6 +400,7 @@ def stage3_extraction(self, video_id: str) -> str:
|
|||
|
||||
session.commit()
|
||||
elapsed = time.monotonic() - start
|
||||
_emit_event(video_id, "stage3_extraction", "complete")
|
||||
logger.info(
|
||||
"Stage 3 (extraction) completed for video_id=%s in %.1fs — %d moments created",
|
||||
video_id, elapsed, total_moments,
|
||||
|
|
@ -339,6 +411,7 @@ def stage3_extraction(self, video_id: str) -> str:
|
|||
raise
|
||||
except Exception as exc:
|
||||
session.rollback()
|
||||
_emit_event(video_id, "stage3_extraction", "error", payload={"error": str(exc)})
|
||||
logger.error("Stage 3 failed for video_id=%s: %s", video_id, exc)
|
||||
raise self.retry(exc=exc)
|
||||
finally:
|
||||
|
|
@ -361,6 +434,7 @@ def stage4_classification(self, video_id: str) -> str:
|
|||
"""
|
||||
start = time.monotonic()
|
||||
logger.info("Stage 4 (classification) starting for video_id=%s", video_id)
|
||||
_emit_event(video_id, "stage4_classification", "start")
|
||||
|
||||
session = _get_sync_session()
|
||||
try:
|
||||
|
|
@ -404,9 +478,11 @@ def stage4_classification(self, video_id: str) -> str:
|
|||
|
||||
llm = _get_llm_client()
|
||||
model_override, modality = _get_stage_config(4)
|
||||
logger.info("Stage 4 using model=%s, modality=%s", model_override or "default", modality)
|
||||
raw = llm.complete(system_prompt, user_prompt, response_model=ClassificationResult,
|
||||
modality=modality, model_override=model_override)
|
||||
hard_limit = get_settings().llm_max_tokens_hard_limit
|
||||
max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage="stage4_classification", hard_limit=hard_limit)
|
||||
logger.info("Stage 4 using model=%s, modality=%s, max_tokens=%d", model_override or "default", modality, max_tokens)
|
||||
raw = llm.complete(system_prompt, user_prompt, response_model=ClassificationResult, on_complete=_make_llm_callback(video_id, "stage4_classification"),
|
||||
modality=modality, model_override=model_override, max_tokens=max_tokens)
|
||||
result = _safe_parse_llm_response(raw, ClassificationResult, llm, system_prompt, user_prompt,
|
||||
modality=modality, model_override=model_override)
|
||||
|
||||
|
|
@ -437,6 +513,7 @@ def stage4_classification(self, video_id: str) -> str:
|
|||
_store_classification_data(video_id, classification_data)
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
_emit_event(video_id, "stage4_classification", "complete")
|
||||
logger.info(
|
||||
"Stage 4 (classification) completed for video_id=%s in %.1fs — %d moments classified",
|
||||
video_id, elapsed, len(classification_data),
|
||||
|
|
@ -447,6 +524,7 @@ def stage4_classification(self, video_id: str) -> str:
|
|||
raise
|
||||
except Exception as exc:
|
||||
session.rollback()
|
||||
_emit_event(video_id, "stage4_classification", "error", payload={"error": str(exc)})
|
||||
logger.error("Stage 4 failed for video_id=%s: %s", video_id, exc)
|
||||
raise self.retry(exc=exc)
|
||||
finally:
|
||||
|
|
@ -539,6 +617,7 @@ def stage5_synthesis(self, video_id: str) -> str:
|
|||
"""
|
||||
start = time.monotonic()
|
||||
logger.info("Stage 5 (synthesis) starting for video_id=%s", video_id)
|
||||
_emit_event(video_id, "stage5_synthesis", "start")
|
||||
|
||||
settings = get_settings()
|
||||
session = _get_sync_session()
|
||||
|
|
@ -576,6 +655,7 @@ def stage5_synthesis(self, video_id: str) -> str:
|
|||
system_prompt = _load_prompt("stage5_synthesis.txt")
|
||||
llm = _get_llm_client()
|
||||
model_override, modality = _get_stage_config(5)
|
||||
hard_limit = get_settings().llm_max_tokens_hard_limit
|
||||
logger.info("Stage 5 using model=%s, modality=%s", model_override or "default", modality)
|
||||
pages_created = 0
|
||||
|
||||
|
|
@ -600,17 +680,39 @@ def stage5_synthesis(self, video_id: str) -> str:
|
|||
|
||||
user_prompt = f"<moments>\n{moments_text}\n</moments>"
|
||||
|
||||
raw = llm.complete(system_prompt, user_prompt, response_model=SynthesisResult,
|
||||
modality=modality, model_override=model_override)
|
||||
max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage="stage5_synthesis", hard_limit=hard_limit)
|
||||
raw = llm.complete(system_prompt, user_prompt, response_model=SynthesisResult, on_complete=_make_llm_callback(video_id, "stage5_synthesis"),
|
||||
modality=modality, model_override=model_override, max_tokens=max_tokens)
|
||||
result = _safe_parse_llm_response(raw, SynthesisResult, llm, system_prompt, user_prompt,
|
||||
modality=modality, model_override=model_override)
|
||||
|
||||
# Load prior pages from this video (snapshot taken before pipeline reset)
|
||||
prior_page_ids = _load_prior_pages(video_id)
|
||||
|
||||
# Create/update TechniquePage rows
|
||||
for page_data in result.pages:
|
||||
# Check if page with this slug already exists
|
||||
existing = session.execute(
|
||||
select(TechniquePage).where(TechniquePage.slug == page_data.slug)
|
||||
).scalar_one_or_none()
|
||||
existing = None
|
||||
|
||||
# First: check prior pages from this video by creator + category
|
||||
if prior_page_ids:
|
||||
existing = session.execute(
|
||||
select(TechniquePage).where(
|
||||
TechniquePage.id.in_(prior_page_ids),
|
||||
TechniquePage.creator_id == video.creator_id,
|
||||
TechniquePage.topic_category == (page_data.topic_category or category),
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if existing:
|
||||
logger.info(
|
||||
"Stage 5: Matched prior page '%s' (id=%s) by creator+category for video_id=%s",
|
||||
existing.slug, existing.id, video_id,
|
||||
)
|
||||
|
||||
# Fallback: check by slug (handles cross-video dedup)
|
||||
if existing is None:
|
||||
existing = session.execute(
|
||||
select(TechniquePage).where(TechniquePage.slug == page_data.slug)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
# Snapshot existing content before overwriting
|
||||
|
|
@ -690,6 +792,7 @@ def stage5_synthesis(self, video_id: str) -> str:
|
|||
|
||||
session.commit()
|
||||
elapsed = time.monotonic() - start
|
||||
_emit_event(video_id, "stage5_synthesis", "complete")
|
||||
logger.info(
|
||||
"Stage 5 (synthesis) completed for video_id=%s in %.1fs — %d pages created/updated",
|
||||
video_id, elapsed, pages_created,
|
||||
|
|
@ -700,6 +803,7 @@ def stage5_synthesis(self, video_id: str) -> str:
|
|||
raise
|
||||
except Exception as exc:
|
||||
session.rollback()
|
||||
_emit_event(video_id, "stage5_synthesis", "error", payload={"error": str(exc)})
|
||||
logger.error("Stage 5 failed for video_id=%s: %s", video_id, exc)
|
||||
raise self.retry(exc=exc)
|
||||
finally:
|
||||
|
|
@ -837,6 +941,58 @@ def stage6_embed_and_index(self, video_id: str) -> str:
|
|||
session.close()
|
||||
|
||||
|
||||
|
||||
|
||||
def _snapshot_prior_pages(video_id: str) -> None:
    """Record the technique-page IDs currently linked to *video_id* in Redis.

    Reprocessing a video makes stage 3 delete and recreate its key_moments,
    severing their links to technique pages. Stashing the page IDs here
    (24h TTL) lets stage 5 locate and update the prior pages instead of
    creating duplicates.
    """
    import redis

    session = _get_sync_session()
    try:
        # Distinct technique pages referenced by this video's key moments.
        stmt = (
            select(KeyMoment.technique_page_id)
            .where(
                KeyMoment.source_video_id == video_id,
                KeyMoment.technique_page_id.isnot(None),
            )
            .distinct()
        )
        page_ids = [str(pid) for pid in session.execute(stmt).scalars().all()]

        if not page_ids:
            logger.info("No prior technique pages for video_id=%s", video_id)
            return

        settings = get_settings()
        client = redis.Redis.from_url(settings.redis_url)
        redis_key = f"chrysopedia:prior_pages:{video_id}"
        client.set(redis_key, json.dumps(page_ids), ex=86400)
        logger.info(
            "Snapshot %d prior technique pages for video_id=%s: %s",
            len(page_ids), video_id, page_ids,
        )
    finally:
        session.close()
|
||||
|
||||
|
||||
def _load_prior_pages(video_id: str) -> list[str]:
    """Return the technique-page IDs snapshotted for *video_id*, or [] if none."""
    import redis

    client = redis.Redis.from_url(get_settings().redis_url)
    cached = client.get(f"chrysopedia:prior_pages:{video_id}")
    if cached is None:
        return []
    return json.loads(cached)
|
||||
|
||||
# ── Orchestrator ─────────────────────────────────────────────────────────────
|
||||
|
||||
@celery_app.task
|
||||
|
|
@ -870,6 +1026,9 @@ def run_pipeline(video_id: str) -> str:
|
|||
finally:
|
||||
session.close()
|
||||
|
||||
# Snapshot prior technique pages before pipeline resets key_moments
|
||||
_snapshot_prior_pages(video_id)
|
||||
|
||||
# Build the chain based on current status
|
||||
stages = []
|
||||
if status in (ProcessingStatus.pending, ProcessingStatus.transcribed):
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue