feat: Truncation detection, batched classification, and pipeline auto-resume
Three resilience improvements to the pipeline: 1. LLMResponse(str) subclass carries finish_reason metadata from the LLM. _safe_parse_llm_response detects truncation (finish=length) and raises LLMTruncationError instead of wastefully retrying with a JSON nudge that makes the prompt even longer. 2. Stage 4 classification now batches moments (20 per call) instead of sending all moments in a single LLM call. Prevents context window overflow for videos with many moments. Batch results are merged with reindexed moment_index values. 3. run_pipeline auto-resumes from the last completed stage on error/retry instead of always restarting from stage 2. Queries pipeline_events for the most recent run to find completed stages. clean_reprocess trigger still forces a full restart. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
5984129e25
commit
e80094dc05
2 changed files with 285 additions and 81 deletions
|
|
@ -28,6 +28,38 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
# ── LLM Response wrapper ─────────────────────────────────────────────────────
|
||||
|
||||
class LLMResponse(str):
|
||||
"""String subclass that carries LLM response metadata.
|
||||
|
||||
Backward-compatible with all code that treats the response as a plain
|
||||
string, but callers that know about it can inspect finish_reason and
|
||||
the truncated property.
|
||||
"""
|
||||
finish_reason: str | None
|
||||
prompt_tokens: int | None
|
||||
completion_tokens: int | None
|
||||
|
||||
def __new__(
|
||||
cls,
|
||||
text: str,
|
||||
finish_reason: str | None = None,
|
||||
prompt_tokens: int | None = None,
|
||||
completion_tokens: int | None = None,
|
||||
):
|
||||
obj = super().__new__(cls, text)
|
||||
obj.finish_reason = finish_reason
|
||||
obj.prompt_tokens = prompt_tokens
|
||||
obj.completion_tokens = completion_tokens
|
||||
return obj
|
||||
|
||||
@property
|
||||
def truncated(self) -> bool:
|
||||
"""True if the model hit its token limit before finishing."""
|
||||
return self.finish_reason == "length"
|
||||
|
||||
# ── Think-tag stripping ──────────────────────────────────────────────────────
|
||||
|
||||
_THINK_PATTERN = re.compile(r"<think>.*?</think>", re.DOTALL)
|
||||
|
|
@ -149,7 +181,7 @@ class LLMClient:
|
|||
model_override: str | None = None,
|
||||
on_complete: "Callable | None" = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> str:
|
||||
) -> "LLMResponse":
|
||||
"""Send a chat completion request, falling back on connection/timeout errors.
|
||||
|
||||
Parameters
|
||||
|
|
@ -174,8 +206,8 @@ class LLMClient:
|
|||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
Raw completion text from the model (think tags stripped if thinking).
|
||||
LLMResponse
|
||||
Raw completion text (str subclass) with finish_reason metadata.
|
||||
"""
|
||||
kwargs: dict = {}
|
||||
effective_system = system_prompt
|
||||
|
|
@ -230,6 +262,7 @@ class LLMClient:
|
|||
)
|
||||
if modality == "thinking":
|
||||
raw = strip_think_tags(raw)
|
||||
finish = response.choices[0].finish_reason if response.choices else None
|
||||
if on_complete is not None:
|
||||
try:
|
||||
on_complete(
|
||||
|
|
@ -238,11 +271,16 @@ class LLMClient:
|
|||
completion_tokens=usage.completion_tokens if usage else None,
|
||||
total_tokens=usage.total_tokens if usage else None,
|
||||
content=raw,
|
||||
finish_reason=response.choices[0].finish_reason if response.choices else None,
|
||||
finish_reason=finish,
|
||||
)
|
||||
except Exception as cb_exc:
|
||||
logger.warning("on_complete callback failed: %s", cb_exc)
|
||||
return raw
|
||||
return LLMResponse(
|
||||
raw,
|
||||
finish_reason=finish,
|
||||
prompt_tokens=usage.prompt_tokens if usage else None,
|
||||
completion_tokens=usage.completion_tokens if usage else None,
|
||||
)
|
||||
|
||||
except (openai.APIConnectionError, openai.APITimeoutError) as exc:
|
||||
logger.warning(
|
||||
|
|
@ -270,6 +308,7 @@ class LLMClient:
|
|||
)
|
||||
if modality == "thinking":
|
||||
raw = strip_think_tags(raw)
|
||||
finish = response.choices[0].finish_reason if response.choices else None
|
||||
if on_complete is not None:
|
||||
try:
|
||||
on_complete(
|
||||
|
|
@ -278,12 +317,17 @@ class LLMClient:
|
|||
completion_tokens=usage.completion_tokens if usage else None,
|
||||
total_tokens=usage.total_tokens if usage else None,
|
||||
content=raw,
|
||||
finish_reason=response.choices[0].finish_reason if response.choices else None,
|
||||
finish_reason=finish,
|
||||
is_fallback=True,
|
||||
)
|
||||
except Exception as cb_exc:
|
||||
logger.warning("on_complete callback failed: %s", cb_exc)
|
||||
return raw
|
||||
return LLMResponse(
|
||||
raw,
|
||||
finish_reason=finish,
|
||||
prompt_tokens=usage.prompt_tokens if usage else None,
|
||||
completion_tokens=usage.completion_tokens if usage else None,
|
||||
)
|
||||
|
||||
except (openai.APIConnectionError, openai.APITimeoutError, openai.APIError) as exc:
|
||||
logger.error(
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ from models import (
|
|||
TranscriptSegment,
|
||||
)
|
||||
from pipeline.embedding_client import EmbeddingClient
|
||||
from pipeline.llm_client import LLMClient, estimate_max_tokens
|
||||
from pipeline.llm_client import LLMClient, LLMResponse, estimate_max_tokens
|
||||
from pipeline.qdrant_client import QdrantManager
|
||||
from pipeline.schemas import (
|
||||
ClassificationResult,
|
||||
|
|
@ -49,6 +49,11 @@ from worker import celery_app
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMTruncationError(RuntimeError):
|
||||
"""Raised when the LLM response was truncated (finish_reason=length)."""
|
||||
pass
|
||||
|
||||
|
||||
# ── Error status helper ──────────────────────────────────────────────────────
|
||||
|
||||
def _set_error_status(video_id: str, stage_name: str, error: Exception) -> None:
|
||||
|
|
@ -256,7 +261,7 @@ def _format_taxonomy_for_prompt(tags_data: dict) -> str:
|
|||
|
||||
|
||||
def _safe_parse_llm_response(
|
||||
raw: str,
|
||||
raw,
|
||||
model_cls,
|
||||
llm: LLMClient,
|
||||
system_prompt: str,
|
||||
|
|
@ -265,14 +270,37 @@ def _safe_parse_llm_response(
|
|||
model_override: str | None = None,
|
||||
max_tokens: int | None = None,
|
||||
):
|
||||
"""Parse LLM response with one retry on failure.
|
||||
"""Parse LLM response with truncation detection and one retry on failure.
|
||||
|
||||
On malformed response: log the raw text, retry once with a JSON nudge,
|
||||
then raise on second failure.
|
||||
If the response was truncated (finish_reason=length), raises
|
||||
LLMTruncationError immediately — retrying with a JSON nudge would only
|
||||
make things worse by adding tokens to an already-too-large prompt.
|
||||
|
||||
For non-truncation parse failures: retry once with a JSON nudge, then
|
||||
raise on second failure.
|
||||
"""
|
||||
# Check for truncation before attempting parse
|
||||
is_truncated = isinstance(raw, LLMResponse) and raw.truncated
|
||||
if is_truncated:
|
||||
logger.warning(
|
||||
"LLM response truncated (finish=length) for %s. "
|
||||
"prompt_tokens=%s, completion_tokens=%s. Will not retry with nudge.",
|
||||
model_cls.__name__,
|
||||
getattr(raw, "prompt_tokens", "?"),
|
||||
getattr(raw, "completion_tokens", "?"),
|
||||
)
|
||||
|
||||
try:
|
||||
return llm.parse_response(raw, model_cls)
|
||||
except (ValidationError, ValueError, json.JSONDecodeError) as exc:
|
||||
if is_truncated:
|
||||
raise LLMTruncationError(
|
||||
f"LLM output truncated for {model_cls.__name__}: "
|
||||
f"prompt_tokens={getattr(raw, 'prompt_tokens', '?')}, "
|
||||
f"completion_tokens={getattr(raw, 'completion_tokens', '?')}. "
|
||||
f"Response too large for model context window."
|
||||
) from exc
|
||||
|
||||
logger.warning(
|
||||
"First parse attempt failed for %s (%s). Retrying with JSON nudge. "
|
||||
"Raw response (first 500 chars): %.500s",
|
||||
|
|
@ -479,6 +507,66 @@ def stage3_extraction(self, video_id: str, run_id: str | None = None) -> str:
|
|||
|
||||
# ── Stage 4: Classification ─────────────────────────────────────────────────
|
||||
|
||||
# Maximum moments per classification batch. Keeps each LLM call well within
|
||||
# context window limits. Batches are classified independently and merged.
|
||||
_STAGE4_BATCH_SIZE = 20
|
||||
|
||||
|
||||
def _classify_moment_batch(
|
||||
moments_batch: list,
|
||||
batch_offset: int,
|
||||
taxonomy_text: str,
|
||||
system_prompt: str,
|
||||
llm: LLMClient,
|
||||
model_override: str | None,
|
||||
modality: str,
|
||||
hard_limit: int,
|
||||
video_id: str,
|
||||
run_id: str | None,
|
||||
) -> ClassificationResult:
|
||||
"""Classify a single batch of moments. Raises on failure."""
|
||||
moments_lines = []
|
||||
for i, m in enumerate(moments_batch):
|
||||
moments_lines.append(
|
||||
f"[{i}] Title: {m.title}\n"
|
||||
f" Summary: {m.summary}\n"
|
||||
f" Content type: {m.content_type.value}\n"
|
||||
f" Plugins: {', '.join(m.plugins) if m.plugins else 'none'}"
|
||||
)
|
||||
moments_text = "\n\n".join(moments_lines)
|
||||
|
||||
user_prompt = (
|
||||
f"<taxonomy>\n{taxonomy_text}\n</taxonomy>\n\n"
|
||||
f"<moments>\n{moments_text}\n</moments>"
|
||||
)
|
||||
|
||||
max_tokens = estimate_max_tokens(
|
||||
system_prompt, user_prompt,
|
||||
stage="stage4_classification", hard_limit=hard_limit,
|
||||
)
|
||||
batch_label = f"batch {batch_offset // _STAGE4_BATCH_SIZE + 1} (moments {batch_offset}-{batch_offset + len(moments_batch) - 1})"
|
||||
logger.info(
|
||||
"Stage 4 classifying %s, max_tokens=%d",
|
||||
batch_label, max_tokens,
|
||||
)
|
||||
raw = llm.complete(
|
||||
system_prompt, user_prompt,
|
||||
response_model=ClassificationResult,
|
||||
on_complete=_make_llm_callback(
|
||||
video_id, "stage4_classification",
|
||||
system_prompt=system_prompt, user_prompt=user_prompt,
|
||||
run_id=run_id, context_label=batch_label,
|
||||
),
|
||||
modality=modality, model_override=model_override,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
return _safe_parse_llm_response(
|
||||
raw, ClassificationResult, llm, system_prompt, user_prompt,
|
||||
modality=modality, model_override=model_override,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(bind=True, max_retries=3, default_retry_delay=30)
|
||||
def stage4_classification(self, video_id: str, run_id: str | None = None) -> str:
|
||||
"""Classify key moments against the canonical tag taxonomy.
|
||||
|
|
@ -487,6 +575,9 @@ def stage4_classification(self, video_id: str, run_id: str | None = None) -> str
|
|||
canonical taxonomy, and stores classification results in Redis for
|
||||
stage 5 consumption. Updates content_type if the classifier overrides it.
|
||||
|
||||
For large moment sets, automatically batches into groups of
|
||||
_STAGE4_BATCH_SIZE to stay within model context window limits.
|
||||
|
||||
Stage 4 does NOT change processing_status.
|
||||
|
||||
Returns the video_id for chain compatibility.
|
||||
|
|
@ -510,7 +601,6 @@ def stage4_classification(self, video_id: str, run_id: str | None = None) -> str
|
|||
|
||||
if not moments:
|
||||
logger.info("Stage 4: No moments found for video_id=%s, skipping.", video_id)
|
||||
# Store empty classification data
|
||||
_store_classification_data(video_id, [])
|
||||
return video_id
|
||||
|
||||
|
|
@ -518,38 +608,29 @@ def stage4_classification(self, video_id: str, run_id: str | None = None) -> str
|
|||
tags_data = _load_canonical_tags()
|
||||
taxonomy_text = _format_taxonomy_for_prompt(tags_data)
|
||||
|
||||
# Build moments text for the LLM
|
||||
moments_lines = []
|
||||
for i, m in enumerate(moments):
|
||||
moments_lines.append(
|
||||
f"[{i}] Title: {m.title}\n"
|
||||
f" Summary: {m.summary}\n"
|
||||
f" Content type: {m.content_type.value}\n"
|
||||
f" Plugins: {', '.join(m.plugins) if m.plugins else 'none'}"
|
||||
)
|
||||
moments_text = "\n\n".join(moments_lines)
|
||||
|
||||
system_prompt = _load_prompt("stage4_classification.txt")
|
||||
user_prompt = (
|
||||
f"<taxonomy>\n{taxonomy_text}\n</taxonomy>\n\n"
|
||||
f"<moments>\n{moments_text}\n</moments>"
|
||||
)
|
||||
|
||||
llm = _get_llm_client()
|
||||
model_override, modality = _get_stage_config(4)
|
||||
hard_limit = get_settings().llm_max_tokens_hard_limit
|
||||
max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage="stage4_classification", hard_limit=hard_limit)
|
||||
logger.info("Stage 4 using model=%s, modality=%s, max_tokens=%d", model_override or "default", modality, max_tokens)
|
||||
raw = llm.complete(system_prompt, user_prompt, response_model=ClassificationResult, on_complete=_make_llm_callback(video_id, "stage4_classification", system_prompt=system_prompt, user_prompt=user_prompt, run_id=run_id),
|
||||
modality=modality, model_override=model_override, max_tokens=max_tokens)
|
||||
result = _safe_parse_llm_response(raw, ClassificationResult, llm, system_prompt, user_prompt,
|
||||
modality=modality, model_override=model_override, max_tokens=max_tokens)
|
||||
|
||||
# Batch moments for classification
|
||||
all_classifications = []
|
||||
for batch_start in range(0, len(moments), _STAGE4_BATCH_SIZE):
|
||||
batch = moments[batch_start:batch_start + _STAGE4_BATCH_SIZE]
|
||||
result = _classify_moment_batch(
|
||||
batch, batch_start, taxonomy_text, system_prompt,
|
||||
llm, model_override, modality, hard_limit,
|
||||
video_id, run_id,
|
||||
)
|
||||
# Reindex: batch uses 0-based indices, remap to global indices
|
||||
for cls in result.classifications:
|
||||
cls.moment_index += batch_start
|
||||
all_classifications.extend(result.classifications)
|
||||
|
||||
# Apply content_type overrides and prepare classification data for stage 5
|
||||
classification_data = []
|
||||
moment_ids = [str(m.id) for m in moments]
|
||||
|
||||
for cls in result.classifications:
|
||||
for cls in all_classifications:
|
||||
if 0 <= cls.moment_index < len(moments):
|
||||
moment = moments[cls.moment_index]
|
||||
|
||||
|
|
@ -572,10 +653,12 @@ def stage4_classification(self, video_id: str, run_id: str | None = None) -> str
|
|||
_store_classification_data(video_id, classification_data)
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
num_batches = (len(moments) + _STAGE4_BATCH_SIZE - 1) // _STAGE4_BATCH_SIZE
|
||||
_emit_event(video_id, "stage4_classification", "complete", run_id=run_id)
|
||||
logger.info(
|
||||
"Stage 4 (classification) completed for video_id=%s in %.1fs — %d moments classified",
|
||||
video_id, elapsed, len(classification_data),
|
||||
"Stage 4 (classification) completed for video_id=%s in %.1fs — "
|
||||
"%d moments classified in %d batch(es)",
|
||||
video_id, elapsed, len(classification_data), num_batches,
|
||||
)
|
||||
return video_id
|
||||
|
||||
|
|
@ -1110,6 +1193,77 @@ def _load_prior_pages(video_id: str) -> list[str]:
|
|||
return []
|
||||
return json.loads(raw)
|
||||
|
||||
|
||||
# ── Stage completion detection for auto-resume ───────────────────────────────
|
||||
|
||||
# Ordered list of pipeline stages for resumability logic
|
||||
_PIPELINE_STAGES = [
|
||||
"stage2_segmentation",
|
||||
"stage3_extraction",
|
||||
"stage4_classification",
|
||||
"stage5_synthesis",
|
||||
"stage6_embed_and_index",
|
||||
]
|
||||
|
||||
_STAGE_TASKS = {
|
||||
"stage2_segmentation": stage2_segmentation,
|
||||
"stage3_extraction": stage3_extraction,
|
||||
"stage4_classification": stage4_classification,
|
||||
"stage5_synthesis": stage5_synthesis,
|
||||
"stage6_embed_and_index": stage6_embed_and_index,
|
||||
}
|
||||
|
||||
|
||||
def _get_last_completed_stage(video_id: str) -> str | None:
|
||||
"""Find the last stage that completed successfully for this video.
|
||||
|
||||
Queries pipeline_events for the most recent run, looking for 'complete'
|
||||
events. Returns the stage name (e.g. 'stage3_extraction') or None if
|
||||
no stages have completed.
|
||||
"""
|
||||
session = _get_sync_session()
|
||||
try:
|
||||
# Find the most recent run for this video
|
||||
from models import PipelineRun
|
||||
latest_run = session.execute(
|
||||
select(PipelineRun)
|
||||
.where(PipelineRun.video_id == video_id)
|
||||
.order_by(PipelineRun.started_at.desc())
|
||||
.limit(1)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if latest_run is None:
|
||||
return None
|
||||
|
||||
# Get all 'complete' events from that run
|
||||
completed_events = session.execute(
|
||||
select(PipelineEvent.stage)
|
||||
.where(
|
||||
PipelineEvent.run_id == str(latest_run.id),
|
||||
PipelineEvent.event_type == "complete",
|
||||
)
|
||||
).scalars().all()
|
||||
|
||||
completed_set = set(completed_events)
|
||||
|
||||
# Walk backwards through the ordered stages to find the last completed one
|
||||
last_completed = None
|
||||
for stage_name in _PIPELINE_STAGES:
|
||||
if stage_name in completed_set:
|
||||
last_completed = stage_name
|
||||
else:
|
||||
break # Stop at first gap — stages must be sequential
|
||||
|
||||
if last_completed:
|
||||
logger.info(
|
||||
"Auto-resume: video_id=%s last completed stage=%s (run_id=%s)",
|
||||
video_id, last_completed, latest_run.id,
|
||||
)
|
||||
return last_completed
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
# ── Orchestrator ─────────────────────────────────────────────────────────────
|
||||
|
||||
@celery_app.task
|
||||
|
|
@ -1175,13 +1329,13 @@ def _finish_run(run_id: str, status: str, error_stage: str | None = None) -> Non
|
|||
|
||||
@celery_app.task
|
||||
def run_pipeline(video_id: str, trigger: str = "manual") -> str:
|
||||
"""Orchestrate the full pipeline (stages 2-5) with resumability.
|
||||
"""Orchestrate the full pipeline (stages 2-6) with auto-resume.
|
||||
|
||||
Checks the current processing_status of the video and chains only the
|
||||
stages that still need to run. For example:
|
||||
- queued → stages 2, 3, 4, 5
|
||||
- processing/error → re-run full pipeline
|
||||
- complete → no-op
|
||||
For error/processing status, queries pipeline_events to find the last
|
||||
stage that completed successfully and resumes from the next stage.
|
||||
This avoids re-running expensive LLM stages that already succeeded.
|
||||
|
||||
For clean_reprocess trigger, always starts from stage 2.
|
||||
|
||||
Returns the video_id.
|
||||
"""
|
||||
|
|
@ -1204,6 +1358,13 @@ def run_pipeline(video_id: str, trigger: str = "manual") -> str:
|
|||
finally:
|
||||
session.close()
|
||||
|
||||
if status == ProcessingStatus.complete:
|
||||
logger.info(
|
||||
"run_pipeline: video_id=%s already at status=%s, nothing to do.",
|
||||
video_id, status.value,
|
||||
)
|
||||
return video_id
|
||||
|
||||
# Snapshot prior technique pages before pipeline resets key_moments
|
||||
_snapshot_prior_pages(video_id)
|
||||
|
||||
|
|
@ -1211,40 +1372,39 @@ def run_pipeline(video_id: str, trigger: str = "manual") -> str:
|
|||
run_id = _create_run(video_id, trigger)
|
||||
logger.info("run_pipeline: created run_id=%s for video_id=%s (trigger=%s)", run_id, video_id, trigger)
|
||||
|
||||
# Build the chain based on current status
|
||||
stages = []
|
||||
if status in (ProcessingStatus.not_started, ProcessingStatus.queued):
|
||||
stages = [
|
||||
stage2_segmentation.s(video_id, run_id=run_id),
|
||||
stage3_extraction.s(run_id=run_id), # receives video_id from previous
|
||||
stage4_classification.s(run_id=run_id),
|
||||
stage5_synthesis.s(run_id=run_id),
|
||||
stage6_embed_and_index.s(run_id=run_id),
|
||||
]
|
||||
elif status == ProcessingStatus.processing:
|
||||
stages = [
|
||||
stage2_segmentation.s(video_id, run_id=run_id),
|
||||
stage3_extraction.s(run_id=run_id),
|
||||
stage4_classification.s(run_id=run_id),
|
||||
stage5_synthesis.s(run_id=run_id),
|
||||
stage6_embed_and_index.s(run_id=run_id),
|
||||
]
|
||||
elif status == ProcessingStatus.error:
|
||||
stages = [
|
||||
stage2_segmentation.s(video_id, run_id=run_id),
|
||||
stage3_extraction.s(run_id=run_id),
|
||||
stage4_classification.s(run_id=run_id),
|
||||
stage5_synthesis.s(run_id=run_id),
|
||||
stage6_embed_and_index.s(run_id=run_id),
|
||||
]
|
||||
elif status == ProcessingStatus.complete:
|
||||
logger.info(
|
||||
"run_pipeline: video_id=%s already at status=%s, nothing to do.",
|
||||
video_id, status.value,
|
||||
)
|
||||
return video_id
|
||||
# Determine which stages to run
|
||||
resume_from_idx = 0 # Default: start from stage 2
|
||||
|
||||
if stages:
|
||||
if trigger != "clean_reprocess" and status in (ProcessingStatus.processing, ProcessingStatus.error):
|
||||
# Try to resume from where we left off
|
||||
last_completed = _get_last_completed_stage(video_id)
|
||||
if last_completed and last_completed in _PIPELINE_STAGES:
|
||||
completed_idx = _PIPELINE_STAGES.index(last_completed)
|
||||
resume_from_idx = completed_idx + 1
|
||||
if resume_from_idx >= len(_PIPELINE_STAGES):
|
||||
logger.info(
|
||||
"run_pipeline: all stages already completed for video_id=%s",
|
||||
video_id,
|
||||
)
|
||||
return video_id
|
||||
|
||||
stages_to_run = _PIPELINE_STAGES[resume_from_idx:]
|
||||
logger.info(
|
||||
"run_pipeline: video_id=%s will run stages: %s (resume_from_idx=%d)",
|
||||
video_id, stages_to_run, resume_from_idx,
|
||||
)
|
||||
|
||||
# Build the Celery chain — first stage gets video_id as arg,
|
||||
# subsequent stages receive it from the previous stage's return value
|
||||
celery_sigs = []
|
||||
for i, stage_name in enumerate(stages_to_run):
|
||||
task_func = _STAGE_TASKS[stage_name]
|
||||
if i == 0:
|
||||
celery_sigs.append(task_func.s(video_id, run_id=run_id))
|
||||
else:
|
||||
celery_sigs.append(task_func.s(run_id=run_id))
|
||||
|
||||
if celery_sigs:
|
||||
# Mark as processing before dispatching
|
||||
session = _get_sync_session()
|
||||
try:
|
||||
|
|
@ -1256,12 +1416,12 @@ def run_pipeline(video_id: str, trigger: str = "manual") -> str:
|
|||
finally:
|
||||
session.close()
|
||||
|
||||
pipeline = celery_chain(*stages)
|
||||
pipeline = celery_chain(*celery_sigs)
|
||||
error_cb = mark_pipeline_error.s(video_id, run_id=run_id)
|
||||
pipeline.apply_async(link_error=error_cb)
|
||||
logger.info(
|
||||
"run_pipeline: dispatched %d stages for video_id=%s (run_id=%s)",
|
||||
len(stages), video_id, run_id,
|
||||
"run_pipeline: dispatched %d stages for video_id=%s (run_id=%s, starting at %s)",
|
||||
len(celery_sigs), video_id, run_id, stages_to_run[0],
|
||||
)
|
||||
|
||||
return video_id
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue