class EmbeddingClient:
    """Sync embedding client backed by an OpenAI-compatible /v1/embeddings endpoint.

    Failures never propagate to the caller: ``embed`` returns an empty list
    whenever the API call fails or the response does not line up with the
    request, so the pipeline can continue without embeddings.
    """

    def __init__(self, settings: Settings) -> None:
        """Build the underlying sync OpenAI client from ``settings``.

        Reads ``embedding_api_url`` and ``llm_api_key`` here; ``embed`` later
        reads ``embedding_model`` and ``embedding_dimensions``.
        """
        self.settings = settings
        self._client = openai.OpenAI(
            base_url=settings.embedding_api_url,
            api_key=settings.llm_api_key,
        )

    @staticmethod
    def _ordered_vectors(data: list) -> list[list[float]]:
        """Return the embeddings in ``data`` ordered by each item's ``index``.

        Items that carry no ``index`` (or ``None``) keep their position in the
        response. The sort is keyed on ``(index, position)`` so ties stay in
        response order.
        """
        keyed = []
        for position, item in enumerate(data):
            index = getattr(item, "index", None)
            keyed.append(
                (position if index is None else index, position, item.embedding)
            )
        keyed.sort(key=lambda entry: entry[:2])
        return [embedding for _, _, embedding in keyed]

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Generate embedding vectors for a batch of texts.

        Parameters
        ----------
        texts:
            List of strings to embed.

        Returns
        -------
        list[list[float]]
            One vector per input text, in the same order as ``texts``.
            Returns an empty list when ``texts`` is empty, when the API call
            raises an ``openai.APIError`` (including connection and timeout
            errors), when the number of returned vectors differs from the
            number of texts, or when any vector's length differs from
            ``settings.embedding_dimensions``.
        """
        if not texts:
            return []

        try:
            response = self._client.embeddings.create(
                model=self.settings.embedding_model,
                input=texts,
            )
        except (openai.APIConnectionError, openai.APITimeoutError) as exc:
            logger.warning(
                "Embedding API unavailable (%s: %s). Skipping %d texts.",
                type(exc).__name__,
                exc,
                len(texts),
            )
            return []
        except openai.APIError as exc:
            logger.warning(
                "Embedding API error (%s: %s). Skipping %d texts.",
                type(exc).__name__,
                exc,
                len(texts),
            )
            return []

        # Callers pair vectors with their metadata by position, so put the
        # vectors back into request order before handing them out.
        vectors = self._ordered_vectors(response.data)

        # A short or long response cannot be paired with the inputs safely.
        if len(vectors) != len(texts):
            logger.warning(
                "Embedding count mismatch: expected %d, got %d. "
                "Returning empty list.",
                len(texts),
                len(vectors),
            )
            return []

        # Validate dimensions
        expected_dim = self.settings.embedding_dimensions
        for i, vec in enumerate(vectors):
            if len(vec) != expected_dim:
                logger.warning(
                    "Embedding dimension mismatch at index %d: expected %d, got %d. "
                    "Returning empty list.",
                    i,
                    expected_dim,
                    len(vec),
                )
                return []

        logger.info(
            "Generated %d embeddings (dim=%d) using model=%s",
            len(vectors),
            expected_dim,
            self.settings.embedding_model,
        )
        return vectors
class QdrantManager:
    """Manages a Qdrant collection for Chrysopedia technique-page and key-moment vectors.

    Every Qdrant call made here is best-effort: errors are logged as warnings
    and swallowed, so a search-index failure never raises into the caller.
    """

    def __init__(self, settings: Settings) -> None:
        """Create the Qdrant client from ``settings.qdrant_url``.

        The target collection name is taken from ``settings.qdrant_collection``.
        """
        self.settings = settings
        self._client = QdrantClient(url=settings.qdrant_url)
        self._collection = settings.qdrant_collection

    # ── Point identity ───────────────────────────────────────────────────

    @staticmethod
    def _point_id(kind: str, entity_id: object) -> str:
        """Derive a stable point ID from the payload type and entity ID.

        The same ``(kind, entity_id)`` pair always yields the same UUID
        string, so upserting an entity again overwrites its existing point
        instead of adding a duplicate. ``kind`` is part of the hashed name so
        a page and a moment with the same ID get different points.
        """
        return str(uuid.uuid5(uuid.NAMESPACE_URL, f"chrysopedia:{kind}:{entity_id}"))

    # ── Collection management ────────────────────────────────────────────

    def ensure_collection(self) -> None:
        """Create the collection if it does not already exist.

        Uses cosine distance and the configured embedding dimensions.
        Any error is logged as a warning and swallowed.
        """
        try:
            if self._client.collection_exists(self._collection):
                logger.info("Qdrant collection '%s' already exists.", self._collection)
                return

            self._client.create_collection(
                collection_name=self._collection,
                vectors_config=VectorParams(
                    size=self.settings.embedding_dimensions,
                    distance=Distance.COSINE,
                ),
            )
            logger.info(
                "Created Qdrant collection '%s' (dim=%d, cosine).",
                self._collection,
                self.settings.embedding_dimensions,
            )
        except qdrant_exceptions.UnexpectedResponse as exc:
            logger.warning(
                "Qdrant error during ensure_collection (%s). Skipping.",
                exc,
            )
        except Exception as exc:
            # Deliberately broad: indexing is a non-blocking side effect.
            logger.warning(
                "Qdrant connection failed during ensure_collection (%s: %s). Skipping.",
                type(exc).__name__,
                exc,
            )

    # ── Low-level upsert ─────────────────────────────────────────────────

    def upsert_points(self, points: list[PointStruct]) -> None:
        """Upsert a batch of pre-built PointStruct objects.

        Does nothing for an empty batch. Any error is logged as a warning
        and swallowed.
        """
        if not points:
            return
        try:
            self._client.upsert(
                collection_name=self._collection,
                points=points,
            )
            logger.info(
                "Upserted %d points to Qdrant collection '%s'.",
                len(points),
                self._collection,
            )
        except qdrant_exceptions.UnexpectedResponse as exc:
            logger.warning(
                "Qdrant upsert failed (%s). %d points skipped.",
                exc,
                len(points),
            )
        except Exception as exc:
            # Deliberately broad: indexing is a non-blocking side effect.
            logger.warning(
                "Qdrant upsert connection error (%s: %s). %d points skipped.",
                type(exc).__name__,
                exc,
                len(points),
            )

    # ── High-level upserts ───────────────────────────────────────────────

    def upsert_technique_pages(
        self,
        pages: list[dict],
        vectors: list[list[float]],
    ) -> None:
        """Build and upsert PointStructs for technique pages.

        Each page dict must contain:
            page_id, creator_id, title, topic_category
        and may contain:
            topic_tags (defaults to ``[]``), summary (defaults to ``""``)

        Point IDs are derived from ``page_id``, so upserting the same page
        again replaces its point rather than adding a second one.

        Parameters
        ----------
        pages:
            Metadata dicts, one per technique page.
        vectors:
            Corresponding embedding vectors (same order as pages).
        """
        if len(pages) != len(vectors):
            logger.warning(
                "Technique-page count (%d) != vector count (%d). Skipping upsert.",
                len(pages),
                len(vectors),
            )
            return

        points = [
            PointStruct(
                id=self._point_id("technique_page", page["page_id"]),
                vector=vector,
                payload={
                    "type": "technique_page",
                    "page_id": page["page_id"],
                    "creator_id": page["creator_id"],
                    "title": page["title"],
                    "topic_category": page["topic_category"],
                    "topic_tags": page.get("topic_tags") or [],
                    "summary": page.get("summary") or "",
                },
            )
            for page, vector in zip(pages, vectors)
        ]

        self.upsert_points(points)

    def upsert_key_moments(
        self,
        moments: list[dict],
        vectors: list[list[float]],
    ) -> None:
        """Build and upsert PointStructs for key moments.

        Each moment dict must contain:
            moment_id, source_video_id, title, start_time, end_time, content_type

        Point IDs are derived from ``moment_id``, so upserting the same
        moment again replaces its point rather than adding a second one.

        Parameters
        ----------
        moments:
            Metadata dicts, one per key moment.
        vectors:
            Corresponding embedding vectors (same order as moments).
        """
        if len(moments) != len(vectors):
            logger.warning(
                "Key-moment count (%d) != vector count (%d). Skipping upsert.",
                len(moments),
                len(vectors),
            )
            return

        points = [
            PointStruct(
                id=self._point_id("key_moment", moment["moment_id"]),
                vector=vector,
                payload={
                    "type": "key_moment",
                    "moment_id": moment["moment_id"],
                    "source_video_id": moment["source_video_id"],
                    "title": moment["title"],
                    "start_time": moment["start_time"],
                    "end_time": moment["end_time"],
                    "content_type": moment["content_type"],
                },
            )
            for moment, vector in zip(moments, vectors)
        ]

        self.upsert_points(points)
+ """ + start = time.monotonic() + logger.info("Stage 6 (embed & index) starting for video_id=%s", video_id) + + settings = get_settings() + session = _get_sync_session() + try: + # Load technique pages created for this video's moments + moments = ( + session.execute( + select(KeyMoment) + .where(KeyMoment.source_video_id == video_id) + .order_by(KeyMoment.start_time) + ) + .scalars() + .all() + ) + + # Get unique technique page IDs from moments + page_ids = {m.technique_page_id for m in moments if m.technique_page_id is not None} + pages = [] + if page_ids: + pages = ( + session.execute( + select(TechniquePage).where(TechniquePage.id.in_(page_ids)) + ) + .scalars() + .all() + ) + + if not moments and not pages: + logger.info("Stage 6: No moments or pages for video_id=%s, skipping.", video_id) + return video_id + + embed_client = EmbeddingClient(settings) + qdrant = QdrantManager(settings) + + # Ensure collection exists before upserting + qdrant.ensure_collection() + + # ── Embed & upsert technique pages ─────────────────────────────── + if pages: + page_texts = [] + page_dicts = [] + for p in pages: + text = f"{p.title} {p.summary or ''} {p.topic_category or ''}" + page_texts.append(text.strip()) + page_dicts.append({ + "page_id": str(p.id), + "creator_id": str(p.creator_id), + "title": p.title, + "topic_category": p.topic_category or "", + "topic_tags": p.topic_tags or [], + "summary": p.summary or "", + }) + + page_vectors = embed_client.embed(page_texts) + if page_vectors: + qdrant.upsert_technique_pages(page_dicts, page_vectors) + logger.info( + "Stage 6: Upserted %d technique page vectors for video_id=%s", + len(page_vectors), video_id, + ) + else: + logger.warning( + "Stage 6: Embedding returned empty for %d technique pages (video_id=%s). 
" + "Skipping page upsert.", + len(page_texts), video_id, + ) + + # ── Embed & upsert key moments ─────────────────────────────────── + if moments: + moment_texts = [] + moment_dicts = [] + for m in moments: + text = f"{m.title} {m.summary or ''}" + moment_texts.append(text.strip()) + moment_dicts.append({ + "moment_id": str(m.id), + "source_video_id": str(m.source_video_id), + "title": m.title, + "start_time": m.start_time, + "end_time": m.end_time, + "content_type": m.content_type.value, + }) + + moment_vectors = embed_client.embed(moment_texts) + if moment_vectors: + qdrant.upsert_key_moments(moment_dicts, moment_vectors) + logger.info( + "Stage 6: Upserted %d key moment vectors for video_id=%s", + len(moment_vectors), video_id, + ) + else: + logger.warning( + "Stage 6: Embedding returned empty for %d key moments (video_id=%s). " + "Skipping moment upsert.", + len(moment_texts), video_id, + ) + + elapsed = time.monotonic() - start + logger.info( + "Stage 6 (embed & index) completed for video_id=%s in %.1fs — " + "%d pages, %d moments processed", + video_id, elapsed, len(pages), len(moments), + ) + return video_id + + except Exception as exc: + # Non-blocking: log error but don't fail the pipeline + logger.error( + "Stage 6 failed for video_id=%s: %s. " + "Pipeline continues — embeddings can be regenerated later.", + video_id, exc, + ) + return video_id + finally: + session.close() + + # ── Orchestrator ───────────────────────────────────────────────────────────── @celery_app.task @@ -618,11 +751,13 @@ def run_pipeline(video_id: str) -> str: stage3_extraction.s(), # receives video_id from previous stage4_classification.s(), stage5_synthesis.s(), + stage6_embed_and_index.s(), ] elif status == ProcessingStatus.extracted: stages = [ stage4_classification.s(video_id), stage5_synthesis.s(), + stage6_embed_and_index.s(), ] elif status in (ProcessingStatus.reviewed, ProcessingStatus.published): logger.info(