feat: Truncation detection, batched classification, and pipeline auto-resume
Three resilience improvements to the pipeline:

1. LLMResponse, a str subclass, carries finish_reason metadata from the LLM. _safe_parse_llm_response detects truncation (finish_reason=length) and, when a truncated response fails to parse, raises LLMTruncationError instead of wastefully retrying with a JSON nudge that makes the prompt even longer.
2. Stage 4 classification now batches moments (20 per call) instead of sending all moments in a single LLM call. This prevents context window overflow for videos with many moments. Batch results are merged with reindexed moment_index values.
3. run_pipeline auto-resumes from the stage after the last completed one on error/retry instead of always restarting from stage 2. It queries pipeline_events for the most recent run to find completed stages. The clean_reprocess trigger still forces a full restart.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
5984129e25
commit
e80094dc05
2 changed files with 285 additions and 81 deletions
|
|
@ -28,6 +28,38 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
|
# ── LLM Response wrapper ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class LLMResponse(str):
|
||||||
|
"""String subclass that carries LLM response metadata.
|
||||||
|
|
||||||
|
Backward-compatible with all code that treats the response as a plain
|
||||||
|
string, but callers that know about it can inspect finish_reason and
|
||||||
|
the truncated property.
|
||||||
|
"""
|
||||||
|
finish_reason: str | None
|
||||||
|
prompt_tokens: int | None
|
||||||
|
completion_tokens: int | None
|
||||||
|
|
||||||
|
def __new__(
|
||||||
|
cls,
|
||||||
|
text: str,
|
||||||
|
finish_reason: str | None = None,
|
||||||
|
prompt_tokens: int | None = None,
|
||||||
|
completion_tokens: int | None = None,
|
||||||
|
):
|
||||||
|
obj = super().__new__(cls, text)
|
||||||
|
obj.finish_reason = finish_reason
|
||||||
|
obj.prompt_tokens = prompt_tokens
|
||||||
|
obj.completion_tokens = completion_tokens
|
||||||
|
return obj
|
||||||
|
|
||||||
|
@property
|
||||||
|
def truncated(self) -> bool:
|
||||||
|
"""True if the model hit its token limit before finishing."""
|
||||||
|
return self.finish_reason == "length"
|
||||||
|
|
||||||
# ── Think-tag stripping ──────────────────────────────────────────────────────
|
# ── Think-tag stripping ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
_THINK_PATTERN = re.compile(r"<think>.*?</think>", re.DOTALL)
|
_THINK_PATTERN = re.compile(r"<think>.*?</think>", re.DOTALL)
|
||||||
|
|
@ -149,7 +181,7 @@ class LLMClient:
|
||||||
model_override: str | None = None,
|
model_override: str | None = None,
|
||||||
on_complete: "Callable | None" = None,
|
on_complete: "Callable | None" = None,
|
||||||
max_tokens: int | None = None,
|
max_tokens: int | None = None,
|
||||||
) -> str:
|
) -> "LLMResponse":
|
||||||
"""Send a chat completion request, falling back on connection/timeout errors.
|
"""Send a chat completion request, falling back on connection/timeout errors.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
|
|
@ -174,8 +206,8 @@ class LLMClient:
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
str
|
LLMResponse
|
||||||
Raw completion text from the model (think tags stripped if thinking).
|
Raw completion text (str subclass) with finish_reason metadata.
|
||||||
"""
|
"""
|
||||||
kwargs: dict = {}
|
kwargs: dict = {}
|
||||||
effective_system = system_prompt
|
effective_system = system_prompt
|
||||||
|
|
@ -230,6 +262,7 @@ class LLMClient:
|
||||||
)
|
)
|
||||||
if modality == "thinking":
|
if modality == "thinking":
|
||||||
raw = strip_think_tags(raw)
|
raw = strip_think_tags(raw)
|
||||||
|
finish = response.choices[0].finish_reason if response.choices else None
|
||||||
if on_complete is not None:
|
if on_complete is not None:
|
||||||
try:
|
try:
|
||||||
on_complete(
|
on_complete(
|
||||||
|
|
@ -238,11 +271,16 @@ class LLMClient:
|
||||||
completion_tokens=usage.completion_tokens if usage else None,
|
completion_tokens=usage.completion_tokens if usage else None,
|
||||||
total_tokens=usage.total_tokens if usage else None,
|
total_tokens=usage.total_tokens if usage else None,
|
||||||
content=raw,
|
content=raw,
|
||||||
finish_reason=response.choices[0].finish_reason if response.choices else None,
|
finish_reason=finish,
|
||||||
)
|
)
|
||||||
except Exception as cb_exc:
|
except Exception as cb_exc:
|
||||||
logger.warning("on_complete callback failed: %s", cb_exc)
|
logger.warning("on_complete callback failed: %s", cb_exc)
|
||||||
return raw
|
return LLMResponse(
|
||||||
|
raw,
|
||||||
|
finish_reason=finish,
|
||||||
|
prompt_tokens=usage.prompt_tokens if usage else None,
|
||||||
|
completion_tokens=usage.completion_tokens if usage else None,
|
||||||
|
)
|
||||||
|
|
||||||
except (openai.APIConnectionError, openai.APITimeoutError) as exc:
|
except (openai.APIConnectionError, openai.APITimeoutError) as exc:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|
@ -270,6 +308,7 @@ class LLMClient:
|
||||||
)
|
)
|
||||||
if modality == "thinking":
|
if modality == "thinking":
|
||||||
raw = strip_think_tags(raw)
|
raw = strip_think_tags(raw)
|
||||||
|
finish = response.choices[0].finish_reason if response.choices else None
|
||||||
if on_complete is not None:
|
if on_complete is not None:
|
||||||
try:
|
try:
|
||||||
on_complete(
|
on_complete(
|
||||||
|
|
@ -278,12 +317,17 @@ class LLMClient:
|
||||||
completion_tokens=usage.completion_tokens if usage else None,
|
completion_tokens=usage.completion_tokens if usage else None,
|
||||||
total_tokens=usage.total_tokens if usage else None,
|
total_tokens=usage.total_tokens if usage else None,
|
||||||
content=raw,
|
content=raw,
|
||||||
finish_reason=response.choices[0].finish_reason if response.choices else None,
|
finish_reason=finish,
|
||||||
is_fallback=True,
|
is_fallback=True,
|
||||||
)
|
)
|
||||||
except Exception as cb_exc:
|
except Exception as cb_exc:
|
||||||
logger.warning("on_complete callback failed: %s", cb_exc)
|
logger.warning("on_complete callback failed: %s", cb_exc)
|
||||||
return raw
|
return LLMResponse(
|
||||||
|
raw,
|
||||||
|
finish_reason=finish,
|
||||||
|
prompt_tokens=usage.prompt_tokens if usage else None,
|
||||||
|
completion_tokens=usage.completion_tokens if usage else None,
|
||||||
|
)
|
||||||
|
|
||||||
except (openai.APIConnectionError, openai.APITimeoutError, openai.APIError) as exc:
|
except (openai.APIConnectionError, openai.APITimeoutError, openai.APIError) as exc:
|
||||||
logger.error(
|
logger.error(
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,7 @@ from models import (
|
||||||
TranscriptSegment,
|
TranscriptSegment,
|
||||||
)
|
)
|
||||||
from pipeline.embedding_client import EmbeddingClient
|
from pipeline.embedding_client import EmbeddingClient
|
||||||
from pipeline.llm_client import LLMClient, estimate_max_tokens
|
from pipeline.llm_client import LLMClient, LLMResponse, estimate_max_tokens
|
||||||
from pipeline.qdrant_client import QdrantManager
|
from pipeline.qdrant_client import QdrantManager
|
||||||
from pipeline.schemas import (
|
from pipeline.schemas import (
|
||||||
ClassificationResult,
|
ClassificationResult,
|
||||||
|
|
@ -49,6 +49,11 @@ from worker import celery_app
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LLMTruncationError(RuntimeError):
|
||||||
|
"""Raised when the LLM response was truncated (finish_reason=length)."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
# ── Error status helper ──────────────────────────────────────────────────────
|
# ── Error status helper ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
def _set_error_status(video_id: str, stage_name: str, error: Exception) -> None:
|
def _set_error_status(video_id: str, stage_name: str, error: Exception) -> None:
|
||||||
|
|
@ -256,7 +261,7 @@ def _format_taxonomy_for_prompt(tags_data: dict) -> str:
|
||||||
|
|
||||||
|
|
||||||
def _safe_parse_llm_response(
|
def _safe_parse_llm_response(
|
||||||
raw: str,
|
raw,
|
||||||
model_cls,
|
model_cls,
|
||||||
llm: LLMClient,
|
llm: LLMClient,
|
||||||
system_prompt: str,
|
system_prompt: str,
|
||||||
|
|
@ -265,14 +270,37 @@ def _safe_parse_llm_response(
|
||||||
model_override: str | None = None,
|
model_override: str | None = None,
|
||||||
max_tokens: int | None = None,
|
max_tokens: int | None = None,
|
||||||
):
|
):
|
||||||
"""Parse LLM response with one retry on failure.
|
"""Parse LLM response with truncation detection and one retry on failure.
|
||||||
|
|
||||||
On malformed response: log the raw text, retry once with a JSON nudge,
|
If the response was truncated (finish_reason=length), raises
|
||||||
then raise on second failure.
|
LLMTruncationError immediately — retrying with a JSON nudge would only
|
||||||
|
make things worse by adding tokens to an already-too-large prompt.
|
||||||
|
|
||||||
|
For non-truncation parse failures: retry once with a JSON nudge, then
|
||||||
|
raise on second failure.
|
||||||
"""
|
"""
|
||||||
|
# Check for truncation before attempting parse
|
||||||
|
is_truncated = isinstance(raw, LLMResponse) and raw.truncated
|
||||||
|
if is_truncated:
|
||||||
|
logger.warning(
|
||||||
|
"LLM response truncated (finish=length) for %s. "
|
||||||
|
"prompt_tokens=%s, completion_tokens=%s. Will not retry with nudge.",
|
||||||
|
model_cls.__name__,
|
||||||
|
getattr(raw, "prompt_tokens", "?"),
|
||||||
|
getattr(raw, "completion_tokens", "?"),
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return llm.parse_response(raw, model_cls)
|
return llm.parse_response(raw, model_cls)
|
||||||
except (ValidationError, ValueError, json.JSONDecodeError) as exc:
|
except (ValidationError, ValueError, json.JSONDecodeError) as exc:
|
||||||
|
if is_truncated:
|
||||||
|
raise LLMTruncationError(
|
||||||
|
f"LLM output truncated for {model_cls.__name__}: "
|
||||||
|
f"prompt_tokens={getattr(raw, 'prompt_tokens', '?')}, "
|
||||||
|
f"completion_tokens={getattr(raw, 'completion_tokens', '?')}. "
|
||||||
|
f"Response too large for model context window."
|
||||||
|
) from exc
|
||||||
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"First parse attempt failed for %s (%s). Retrying with JSON nudge. "
|
"First parse attempt failed for %s (%s). Retrying with JSON nudge. "
|
||||||
"Raw response (first 500 chars): %.500s",
|
"Raw response (first 500 chars): %.500s",
|
||||||
|
|
@ -479,6 +507,66 @@ def stage3_extraction(self, video_id: str, run_id: str | None = None) -> str:
|
||||||
|
|
||||||
# ── Stage 4: Classification ─────────────────────────────────────────────────
|
# ── Stage 4: Classification ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Maximum moments per classification batch. Keeps each LLM call well within
|
||||||
|
# context window limits. Batches are classified independently and merged.
|
||||||
|
_STAGE4_BATCH_SIZE = 20
|
||||||
|
|
||||||
|
|
||||||
|
def _classify_moment_batch(
|
||||||
|
moments_batch: list,
|
||||||
|
batch_offset: int,
|
||||||
|
taxonomy_text: str,
|
||||||
|
system_prompt: str,
|
||||||
|
llm: LLMClient,
|
||||||
|
model_override: str | None,
|
||||||
|
modality: str,
|
||||||
|
hard_limit: int,
|
||||||
|
video_id: str,
|
||||||
|
run_id: str | None,
|
||||||
|
) -> ClassificationResult:
|
||||||
|
"""Classify a single batch of moments. Raises on failure."""
|
||||||
|
moments_lines = []
|
||||||
|
for i, m in enumerate(moments_batch):
|
||||||
|
moments_lines.append(
|
||||||
|
f"[{i}] Title: {m.title}\n"
|
||||||
|
f" Summary: {m.summary}\n"
|
||||||
|
f" Content type: {m.content_type.value}\n"
|
||||||
|
f" Plugins: {', '.join(m.plugins) if m.plugins else 'none'}"
|
||||||
|
)
|
||||||
|
moments_text = "\n\n".join(moments_lines)
|
||||||
|
|
||||||
|
user_prompt = (
|
||||||
|
f"<taxonomy>\n{taxonomy_text}\n</taxonomy>\n\n"
|
||||||
|
f"<moments>\n{moments_text}\n</moments>"
|
||||||
|
)
|
||||||
|
|
||||||
|
max_tokens = estimate_max_tokens(
|
||||||
|
system_prompt, user_prompt,
|
||||||
|
stage="stage4_classification", hard_limit=hard_limit,
|
||||||
|
)
|
||||||
|
batch_label = f"batch {batch_offset // _STAGE4_BATCH_SIZE + 1} (moments {batch_offset}-{batch_offset + len(moments_batch) - 1})"
|
||||||
|
logger.info(
|
||||||
|
"Stage 4 classifying %s, max_tokens=%d",
|
||||||
|
batch_label, max_tokens,
|
||||||
|
)
|
||||||
|
raw = llm.complete(
|
||||||
|
system_prompt, user_prompt,
|
||||||
|
response_model=ClassificationResult,
|
||||||
|
on_complete=_make_llm_callback(
|
||||||
|
video_id, "stage4_classification",
|
||||||
|
system_prompt=system_prompt, user_prompt=user_prompt,
|
||||||
|
run_id=run_id, context_label=batch_label,
|
||||||
|
),
|
||||||
|
modality=modality, model_override=model_override,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
)
|
||||||
|
return _safe_parse_llm_response(
|
||||||
|
raw, ClassificationResult, llm, system_prompt, user_prompt,
|
||||||
|
modality=modality, model_override=model_override,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(bind=True, max_retries=3, default_retry_delay=30)
|
@celery_app.task(bind=True, max_retries=3, default_retry_delay=30)
|
||||||
def stage4_classification(self, video_id: str, run_id: str | None = None) -> str:
|
def stage4_classification(self, video_id: str, run_id: str | None = None) -> str:
|
||||||
"""Classify key moments against the canonical tag taxonomy.
|
"""Classify key moments against the canonical tag taxonomy.
|
||||||
|
|
@ -487,6 +575,9 @@ def stage4_classification(self, video_id: str, run_id: str | None = None) -> str
|
||||||
canonical taxonomy, and stores classification results in Redis for
|
canonical taxonomy, and stores classification results in Redis for
|
||||||
stage 5 consumption. Updates content_type if the classifier overrides it.
|
stage 5 consumption. Updates content_type if the classifier overrides it.
|
||||||
|
|
||||||
|
For large moment sets, automatically batches into groups of
|
||||||
|
_STAGE4_BATCH_SIZE to stay within model context window limits.
|
||||||
|
|
||||||
Stage 4 does NOT change processing_status.
|
Stage 4 does NOT change processing_status.
|
||||||
|
|
||||||
Returns the video_id for chain compatibility.
|
Returns the video_id for chain compatibility.
|
||||||
|
|
@ -510,7 +601,6 @@ def stage4_classification(self, video_id: str, run_id: str | None = None) -> str
|
||||||
|
|
||||||
if not moments:
|
if not moments:
|
||||||
logger.info("Stage 4: No moments found for video_id=%s, skipping.", video_id)
|
logger.info("Stage 4: No moments found for video_id=%s, skipping.", video_id)
|
||||||
# Store empty classification data
|
|
||||||
_store_classification_data(video_id, [])
|
_store_classification_data(video_id, [])
|
||||||
return video_id
|
return video_id
|
||||||
|
|
||||||
|
|
@ -518,38 +608,29 @@ def stage4_classification(self, video_id: str, run_id: str | None = None) -> str
|
||||||
tags_data = _load_canonical_tags()
|
tags_data = _load_canonical_tags()
|
||||||
taxonomy_text = _format_taxonomy_for_prompt(tags_data)
|
taxonomy_text = _format_taxonomy_for_prompt(tags_data)
|
||||||
|
|
||||||
# Build moments text for the LLM
|
|
||||||
moments_lines = []
|
|
||||||
for i, m in enumerate(moments):
|
|
||||||
moments_lines.append(
|
|
||||||
f"[{i}] Title: {m.title}\n"
|
|
||||||
f" Summary: {m.summary}\n"
|
|
||||||
f" Content type: {m.content_type.value}\n"
|
|
||||||
f" Plugins: {', '.join(m.plugins) if m.plugins else 'none'}"
|
|
||||||
)
|
|
||||||
moments_text = "\n\n".join(moments_lines)
|
|
||||||
|
|
||||||
system_prompt = _load_prompt("stage4_classification.txt")
|
system_prompt = _load_prompt("stage4_classification.txt")
|
||||||
user_prompt = (
|
|
||||||
f"<taxonomy>\n{taxonomy_text}\n</taxonomy>\n\n"
|
|
||||||
f"<moments>\n{moments_text}\n</moments>"
|
|
||||||
)
|
|
||||||
|
|
||||||
llm = _get_llm_client()
|
llm = _get_llm_client()
|
||||||
model_override, modality = _get_stage_config(4)
|
model_override, modality = _get_stage_config(4)
|
||||||
hard_limit = get_settings().llm_max_tokens_hard_limit
|
hard_limit = get_settings().llm_max_tokens_hard_limit
|
||||||
max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage="stage4_classification", hard_limit=hard_limit)
|
|
||||||
logger.info("Stage 4 using model=%s, modality=%s, max_tokens=%d", model_override or "default", modality, max_tokens)
|
# Batch moments for classification
|
||||||
raw = llm.complete(system_prompt, user_prompt, response_model=ClassificationResult, on_complete=_make_llm_callback(video_id, "stage4_classification", system_prompt=system_prompt, user_prompt=user_prompt, run_id=run_id),
|
all_classifications = []
|
||||||
modality=modality, model_override=model_override, max_tokens=max_tokens)
|
for batch_start in range(0, len(moments), _STAGE4_BATCH_SIZE):
|
||||||
result = _safe_parse_llm_response(raw, ClassificationResult, llm, system_prompt, user_prompt,
|
batch = moments[batch_start:batch_start + _STAGE4_BATCH_SIZE]
|
||||||
modality=modality, model_override=model_override, max_tokens=max_tokens)
|
result = _classify_moment_batch(
|
||||||
|
batch, batch_start, taxonomy_text, system_prompt,
|
||||||
|
llm, model_override, modality, hard_limit,
|
||||||
|
video_id, run_id,
|
||||||
|
)
|
||||||
|
# Reindex: batch uses 0-based indices, remap to global indices
|
||||||
|
for cls in result.classifications:
|
||||||
|
cls.moment_index += batch_start
|
||||||
|
all_classifications.extend(result.classifications)
|
||||||
|
|
||||||
# Apply content_type overrides and prepare classification data for stage 5
|
# Apply content_type overrides and prepare classification data for stage 5
|
||||||
classification_data = []
|
classification_data = []
|
||||||
moment_ids = [str(m.id) for m in moments]
|
|
||||||
|
|
||||||
for cls in result.classifications:
|
for cls in all_classifications:
|
||||||
if 0 <= cls.moment_index < len(moments):
|
if 0 <= cls.moment_index < len(moments):
|
||||||
moment = moments[cls.moment_index]
|
moment = moments[cls.moment_index]
|
||||||
|
|
||||||
|
|
@ -572,10 +653,12 @@ def stage4_classification(self, video_id: str, run_id: str | None = None) -> str
|
||||||
_store_classification_data(video_id, classification_data)
|
_store_classification_data(video_id, classification_data)
|
||||||
|
|
||||||
elapsed = time.monotonic() - start
|
elapsed = time.monotonic() - start
|
||||||
|
num_batches = (len(moments) + _STAGE4_BATCH_SIZE - 1) // _STAGE4_BATCH_SIZE
|
||||||
_emit_event(video_id, "stage4_classification", "complete", run_id=run_id)
|
_emit_event(video_id, "stage4_classification", "complete", run_id=run_id)
|
||||||
logger.info(
|
logger.info(
|
||||||
"Stage 4 (classification) completed for video_id=%s in %.1fs — %d moments classified",
|
"Stage 4 (classification) completed for video_id=%s in %.1fs — "
|
||||||
video_id, elapsed, len(classification_data),
|
"%d moments classified in %d batch(es)",
|
||||||
|
video_id, elapsed, len(classification_data), num_batches,
|
||||||
)
|
)
|
||||||
return video_id
|
return video_id
|
||||||
|
|
||||||
|
|
@ -1110,6 +1193,77 @@ def _load_prior_pages(video_id: str) -> list[str]:
|
||||||
return []
|
return []
|
||||||
return json.loads(raw)
|
return json.loads(raw)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Stage completion detection for auto-resume ───────────────────────────────
|
||||||
|
|
||||||
|
# Ordered list of pipeline stages for resumability logic
|
||||||
|
_PIPELINE_STAGES = [
|
||||||
|
"stage2_segmentation",
|
||||||
|
"stage3_extraction",
|
||||||
|
"stage4_classification",
|
||||||
|
"stage5_synthesis",
|
||||||
|
"stage6_embed_and_index",
|
||||||
|
]
|
||||||
|
|
||||||
|
_STAGE_TASKS = {
|
||||||
|
"stage2_segmentation": stage2_segmentation,
|
||||||
|
"stage3_extraction": stage3_extraction,
|
||||||
|
"stage4_classification": stage4_classification,
|
||||||
|
"stage5_synthesis": stage5_synthesis,
|
||||||
|
"stage6_embed_and_index": stage6_embed_and_index,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_last_completed_stage(video_id: str) -> str | None:
|
||||||
|
"""Find the last stage that completed successfully for this video.
|
||||||
|
|
||||||
|
Queries pipeline_events for the most recent run, looking for 'complete'
|
||||||
|
events. Returns the stage name (e.g. 'stage3_extraction') or None if
|
||||||
|
no stages have completed.
|
||||||
|
"""
|
||||||
|
session = _get_sync_session()
|
||||||
|
try:
|
||||||
|
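# Find the most recent run for this video. NOTE(review): run_pipeline appears to call _create_run before this lookup, so the newest run may be the brand-new one with no events yet — confirm the intended run is selected.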
|
||||||
|
from models import PipelineRun
|
||||||
|
latest_run = session.execute(
|
||||||
|
select(PipelineRun)
|
||||||
|
.where(PipelineRun.video_id == video_id)
|
||||||
|
.order_by(PipelineRun.started_at.desc())
|
||||||
|
.limit(1)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
|
||||||
|
if latest_run is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Get all 'complete' events from that run
|
||||||
|
completed_events = session.execute(
|
||||||
|
select(PipelineEvent.stage)
|
||||||
|
.where(
|
||||||
|
PipelineEvent.run_id == str(latest_run.id),
|
||||||
|
PipelineEvent.event_type == "complete",
|
||||||
|
)
|
||||||
|
).scalars().all()
|
||||||
|
|
||||||
|
completed_set = set(completed_events)
|
||||||
|
|
||||||
|
# Walk forward through the ordered stages; the last stage seen before the first gap is the last completed one
|
||||||
|
last_completed = None
|
||||||
|
for stage_name in _PIPELINE_STAGES:
|
||||||
|
if stage_name in completed_set:
|
||||||
|
last_completed = stage_name
|
||||||
|
else:
|
||||||
|
break # Stop at first gap — stages must be sequential
|
||||||
|
|
||||||
|
if last_completed:
|
||||||
|
logger.info(
|
||||||
|
"Auto-resume: video_id=%s last completed stage=%s (run_id=%s)",
|
||||||
|
video_id, last_completed, latest_run.id,
|
||||||
|
)
|
||||||
|
return last_completed
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
|
||||||
# ── Orchestrator ─────────────────────────────────────────────────────────────
|
# ── Orchestrator ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
@celery_app.task
|
@celery_app.task
|
||||||
|
|
@ -1175,13 +1329,13 @@ def _finish_run(run_id: str, status: str, error_stage: str | None = None) -> Non
|
||||||
|
|
||||||
@celery_app.task
|
@celery_app.task
|
||||||
def run_pipeline(video_id: str, trigger: str = "manual") -> str:
|
def run_pipeline(video_id: str, trigger: str = "manual") -> str:
|
||||||
"""Orchestrate the full pipeline (stages 2-5) with resumability.
|
"""Orchestrate the full pipeline (stages 2-6) with auto-resume.
|
||||||
|
|
||||||
Checks the current processing_status of the video and chains only the
|
For error/processing status, queries pipeline_events to find the last
|
||||||
stages that still need to run. For example:
|
stage that completed successfully and resumes from the next stage.
|
||||||
- queued → stages 2, 3, 4, 5
|
This avoids re-running expensive LLM stages that already succeeded.
|
||||||
- processing/error → re-run full pipeline
|
|
||||||
- complete → no-op
|
For clean_reprocess trigger, always starts from stage 2.
|
||||||
|
|
||||||
Returns the video_id.
|
Returns the video_id.
|
||||||
"""
|
"""
|
||||||
|
|
@ -1204,6 +1358,13 @@ def run_pipeline(video_id: str, trigger: str = "manual") -> str:
|
||||||
finally:
|
finally:
|
||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
|
if status == ProcessingStatus.complete:
|
||||||
|
logger.info(
|
||||||
|
"run_pipeline: video_id=%s already at status=%s, nothing to do.",
|
||||||
|
video_id, status.value,
|
||||||
|
)
|
||||||
|
return video_id
|
||||||
|
|
||||||
# Snapshot prior technique pages before pipeline resets key_moments
|
# Snapshot prior technique pages before pipeline resets key_moments
|
||||||
_snapshot_prior_pages(video_id)
|
_snapshot_prior_pages(video_id)
|
||||||
|
|
||||||
|
|
@ -1211,40 +1372,39 @@ def run_pipeline(video_id: str, trigger: str = "manual") -> str:
|
||||||
run_id = _create_run(video_id, trigger)
|
run_id = _create_run(video_id, trigger)
|
||||||
logger.info("run_pipeline: created run_id=%s for video_id=%s (trigger=%s)", run_id, video_id, trigger)
|
logger.info("run_pipeline: created run_id=%s for video_id=%s (trigger=%s)", run_id, video_id, trigger)
|
||||||
|
|
||||||
# Build the chain based on current status
|
# Determine which stages to run
|
||||||
stages = []
|
resume_from_idx = 0 # Default: start from stage 2
|
||||||
if status in (ProcessingStatus.not_started, ProcessingStatus.queued):
|
|
||||||
stages = [
|
|
||||||
stage2_segmentation.s(video_id, run_id=run_id),
|
|
||||||
stage3_extraction.s(run_id=run_id), # receives video_id from previous
|
|
||||||
stage4_classification.s(run_id=run_id),
|
|
||||||
stage5_synthesis.s(run_id=run_id),
|
|
||||||
stage6_embed_and_index.s(run_id=run_id),
|
|
||||||
]
|
|
||||||
elif status == ProcessingStatus.processing:
|
|
||||||
stages = [
|
|
||||||
stage2_segmentation.s(video_id, run_id=run_id),
|
|
||||||
stage3_extraction.s(run_id=run_id),
|
|
||||||
stage4_classification.s(run_id=run_id),
|
|
||||||
stage5_synthesis.s(run_id=run_id),
|
|
||||||
stage6_embed_and_index.s(run_id=run_id),
|
|
||||||
]
|
|
||||||
elif status == ProcessingStatus.error:
|
|
||||||
stages = [
|
|
||||||
stage2_segmentation.s(video_id, run_id=run_id),
|
|
||||||
stage3_extraction.s(run_id=run_id),
|
|
||||||
stage4_classification.s(run_id=run_id),
|
|
||||||
stage5_synthesis.s(run_id=run_id),
|
|
||||||
stage6_embed_and_index.s(run_id=run_id),
|
|
||||||
]
|
|
||||||
elif status == ProcessingStatus.complete:
|
|
||||||
logger.info(
|
|
||||||
"run_pipeline: video_id=%s already at status=%s, nothing to do.",
|
|
||||||
video_id, status.value,
|
|
||||||
)
|
|
||||||
return video_id
|
|
||||||
|
|
||||||
if stages:
|
if trigger != "clean_reprocess" and status in (ProcessingStatus.processing, ProcessingStatus.error):
|
||||||
|
# Try to resume from where we left off
|
||||||
|
last_completed = _get_last_completed_stage(video_id)
|
||||||
|
if last_completed and last_completed in _PIPELINE_STAGES:
|
||||||
|
completed_idx = _PIPELINE_STAGES.index(last_completed)
|
||||||
|
resume_from_idx = completed_idx + 1
|
||||||
|
if resume_from_idx >= len(_PIPELINE_STAGES):
|
||||||
|
logger.info(
|
||||||
|
"run_pipeline: all stages already completed for video_id=%s",
|
||||||
|
video_id,
|
||||||
|
)
|
||||||
|
return video_id
|
||||||
|
|
||||||
|
stages_to_run = _PIPELINE_STAGES[resume_from_idx:]
|
||||||
|
logger.info(
|
||||||
|
"run_pipeline: video_id=%s will run stages: %s (resume_from_idx=%d)",
|
||||||
|
video_id, stages_to_run, resume_from_idx,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build the Celery chain — first stage gets video_id as arg,
|
||||||
|
# subsequent stages receive it from the previous stage's return value
|
||||||
|
celery_sigs = []
|
||||||
|
for i, stage_name in enumerate(stages_to_run):
|
||||||
|
task_func = _STAGE_TASKS[stage_name]
|
||||||
|
if i == 0:
|
||||||
|
celery_sigs.append(task_func.s(video_id, run_id=run_id))
|
||||||
|
else:
|
||||||
|
celery_sigs.append(task_func.s(run_id=run_id))
|
||||||
|
|
||||||
|
if celery_sigs:
|
||||||
# Mark as processing before dispatching
|
# Mark as processing before dispatching
|
||||||
session = _get_sync_session()
|
session = _get_sync_session()
|
||||||
try:
|
try:
|
||||||
|
|
@ -1256,12 +1416,12 @@ def run_pipeline(video_id: str, trigger: str = "manual") -> str:
|
||||||
finally:
|
finally:
|
||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
pipeline = celery_chain(*stages)
|
pipeline = celery_chain(*celery_sigs)
|
||||||
error_cb = mark_pipeline_error.s(video_id, run_id=run_id)
|
error_cb = mark_pipeline_error.s(video_id, run_id=run_id)
|
||||||
pipeline.apply_async(link_error=error_cb)
|
pipeline.apply_async(link_error=error_cb)
|
||||||
logger.info(
|
logger.info(
|
||||||
"run_pipeline: dispatched %d stages for video_id=%s (run_id=%s)",
|
"run_pipeline: dispatched %d stages for video_id=%s (run_id=%s, starting at %s)",
|
||||||
len(stages), video_id, run_id,
|
len(celery_sigs), video_id, run_id, stages_to_run[0],
|
||||||
)
|
)
|
||||||
|
|
||||||
return video_id
|
return video_id
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue