feat(M001): Desire Economy
Completed slices: - S01: Desire Embedding & Clustering - S02: Fulfillment Flow & Frontend Branch: milestone/M001
This commit is contained in:
parent
a5f0c0e093
commit
5936ab167e
19 changed files with 2612 additions and 19 deletions
28
Makefile
Normal file
28
Makefile
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
# Fractafrag — Docker Compose monorepo
# Common development commands

.PHONY: up down build logs test api-shell worker-shell db-shell

# Start all services in the background (detached).
up:
	docker compose up -d

# Stop and remove all service containers.
down:
	docker compose down

# Rebuild all service images.
build:
	docker compose build

# Follow logs from all services.
logs:
	docker compose logs -f

# Run the API test suite inside the running api container.
test:
	docker compose exec api python -m pytest tests/ -v

# Open an interactive shell in the api container.
api-shell:
	docker compose exec api bash

# Open an interactive shell in the worker container.
worker-shell:
	docker compose exec worker bash

# Open a psql session against the app database.
db-shell:
	docker compose exec postgres psql -U fracta -d fractafrag
|
||||
6
services/api/=0.20.0
Normal file
6
services/api/=0.20.0
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
Defaulting to user installation because normal site-packages is not writeable
|
||||
Collecting aiosqlite
|
||||
Using cached aiosqlite-0.22.1-py3-none-any.whl.metadata (4.3 kB)
|
||||
Using cached aiosqlite-0.22.1-py3-none-any.whl (17 kB)
|
||||
Installing collected packages: aiosqlite
|
||||
Successfully installed aiosqlite-0.22.1
|
||||
|
|
@ -3,10 +3,10 @@
|
|||
from uuid import UUID
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import select, text
|
||||
|
||||
from app.database import get_db
|
||||
from app.models import User, Desire
|
||||
from app.models import User, Desire, Shader
|
||||
from app.schemas import DesireCreate, DesirePublic
|
||||
from app.middleware.auth import get_current_user, require_tier
|
||||
|
||||
|
|
@ -29,7 +29,24 @@ async def list_desires(
|
|||
|
||||
query = query.order_by(Desire.heat_score.desc()).limit(limit).offset(offset)
|
||||
result = await db.execute(query)
|
||||
return result.scalars().all()
|
||||
desires = list(result.scalars().all())
|
||||
|
||||
# Batch-annotate cluster_count to avoid N+1 queries
|
||||
desire_ids = [d.id for d in desires]
|
||||
if desire_ids:
|
||||
cluster_query = text("""
|
||||
SELECT dc1.desire_id, COUNT(dc2.desire_id) as cluster_count
|
||||
FROM desire_clusters dc1
|
||||
JOIN desire_clusters dc2 ON dc1.cluster_id = dc2.cluster_id
|
||||
WHERE dc1.desire_id = ANY(:desire_ids)
|
||||
GROUP BY dc1.desire_id
|
||||
""")
|
||||
cluster_result = await db.execute(cluster_query, {"desire_ids": desire_ids})
|
||||
cluster_counts = {row[0]: row[1] for row in cluster_result}
|
||||
for d in desires:
|
||||
d.cluster_count = cluster_counts.get(d.id, 0)
|
||||
|
||||
return desires
|
||||
|
||||
|
||||
@router.get("/{desire_id}", response_model=DesirePublic)
|
||||
|
|
@ -38,6 +55,18 @@ async def get_desire(desire_id: UUID, db: AsyncSession = Depends(get_db)):
|
|||
desire = result.scalar_one_or_none()
|
||||
if not desire:
|
||||
raise HTTPException(status_code=404, detail="Desire not found")
|
||||
|
||||
# Annotate cluster_count for single desire
|
||||
cluster_query = text("""
|
||||
SELECT COUNT(dc2.desire_id) as cluster_count
|
||||
FROM desire_clusters dc1
|
||||
JOIN desire_clusters dc2 ON dc1.cluster_id = dc2.cluster_id
|
||||
WHERE dc1.desire_id = :desire_id
|
||||
""")
|
||||
cluster_result = await db.execute(cluster_query, {"desire_id": desire_id})
|
||||
row = cluster_result.first()
|
||||
desire.cluster_count = row[0] if row else 0
|
||||
|
||||
return desire
|
||||
|
||||
|
||||
|
|
@ -55,9 +84,9 @@ async def create_desire(
|
|||
db.add(desire)
|
||||
await db.flush()
|
||||
|
||||
# TODO: Embed prompt text (Track G)
|
||||
# TODO: Check similarity clustering (Track G)
|
||||
# TODO: Enqueue process_desire worker job (Track G)
|
||||
# Fire-and-forget: enqueue embedding + clustering worker task
|
||||
from app.worker import process_desire
|
||||
process_desire.delay(str(desire.id))
|
||||
|
||||
return desire
|
||||
|
||||
|
|
@ -76,6 +105,13 @@ async def fulfill_desire(
|
|||
if desire.status != "open":
|
||||
raise HTTPException(status_code=400, detail="Desire is not open")
|
||||
|
||||
# Validate shader exists and is published
|
||||
shader = (await db.execute(select(Shader).where(Shader.id == shader_id))).scalar_one_or_none()
|
||||
if not shader:
|
||||
raise HTTPException(status_code=404, detail="Shader not found")
|
||||
if shader.status != "published":
|
||||
raise HTTPException(status_code=400, detail="Shader must be published to fulfill a desire")
|
||||
|
||||
from datetime import datetime, timezone
|
||||
desire.status = "fulfilled"
|
||||
desire.fulfilled_by_shader = shader_id
|
||||
|
|
|
|||
|
|
@ -184,6 +184,7 @@ class DesirePublic(BaseModel):
|
|||
tip_amount_cents: int
|
||||
status: str
|
||||
heat_score: float
|
||||
cluster_count: int = 0
|
||||
fulfilled_by_shader: Optional[UUID]
|
||||
fulfilled_at: Optional[datetime]
|
||||
created_at: datetime
|
||||
|
|
|
|||
406
services/api/app/services/clustering.py
Normal file
406
services/api/app/services/clustering.py
Normal file
|
|
@ -0,0 +1,406 @@
|
|||
"""Clustering service — pgvector cosine similarity search and heat calculation.
|
||||
|
||||
Groups desires into clusters based on prompt embedding similarity.
|
||||
Uses pgvector's <=> cosine distance operator for nearest-neighbor search.
|
||||
Heat scores scale linearly with cluster size (more demand → more visible).
|
||||
|
||||
Provides both async (for FastAPI) and sync (for Celery worker) variants.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid as uuid_mod
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.models import DesireCluster
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def find_nearest_cluster(
    embedding: list[float],
    db: AsyncSession,
    threshold: float = 0.82,
) -> tuple[UUID | None, float]:
    """Find the nearest existing desire cluster for an embedding vector.

    Uses pgvector cosine distance (<=> operator). A threshold of 0.82 means
    cosine_similarity >= 0.82, i.e., cosine_distance <= 0.18.

    Returns:
        (cluster_id, similarity) if a match is found within threshold,
        (None, 0.0) if no match exists.
    """
    max_distance = 1.0 - threshold

    # Raw SQL for pgvector cosine distance — SQLAlchemy ORM doesn't natively
    # support the <=> operator without extra configuration.
    #
    # BUGFIX: join desire_clusters so only desires that already belong to a
    # cluster are candidates. The previous version took the single nearest
    # desire and then looked up its cluster; the desire currently being
    # processed (embedding flushed, but no cluster row yet) matched itself at
    # distance 0, the cluster lookup came back empty, and the caller always
    # fell through to creating a new single-member cluster — so no desire
    # ever joined an existing cluster.
    query = text("""
        SELECT dc.cluster_id AS cluster_id,
               d.id AS desire_id,
               (d.prompt_embedding <=> :embedding) AS distance
        FROM desires d
        JOIN desire_clusters dc ON dc.desire_id = d.id
        WHERE d.prompt_embedding IS NOT NULL
          AND (d.prompt_embedding <=> :embedding) <= :max_distance
        ORDER BY distance ASC
        LIMIT 1
    """)

    result = await db.execute(
        query,
        {"embedding": str(embedding), "max_distance": max_distance},
    )
    row = result.first()

    if row is None:
        logger.debug("No nearby cluster found (threshold=%.2f)", threshold)
        return (None, 0.0)

    similarity = 1.0 - row.distance
    logger.info(
        "Found nearby cluster %s via desire %s (similarity=%.3f)",
        row.cluster_id,
        row.desire_id,
        similarity,
    )
    return (row.cluster_id, similarity)
|
||||
|
||||
|
||||
async def create_cluster(desire_id: UUID, db: AsyncSession) -> UUID:
    """Create a new single-member cluster for a desire.

    Returns the new cluster_id.
    """
    new_cluster_id = uuid_mod.uuid4()
    # The founding member gets similarity 1.0 by definition (identical to itself).
    db.add(
        DesireCluster(
            cluster_id=new_cluster_id,
            desire_id=desire_id,
            similarity=1.0,
        )
    )
    await db.flush()
    logger.info("Created new cluster %s for desire %s", new_cluster_id, desire_id)
    return new_cluster_id
|
||||
|
||||
|
||||
async def add_to_cluster(
    cluster_id: UUID,
    desire_id: UUID,
    similarity: float,
    db: AsyncSession,
) -> None:
    """Add a desire to an existing cluster.

    Uses INSERT ... ON CONFLICT DO NOTHING to safely handle re-processing
    (idempotent — won't duplicate if the desire is already in the cluster).
    """
    stmt = text("""
        INSERT INTO desire_clusters (cluster_id, desire_id, similarity)
        VALUES (:cluster_id, :desire_id, :similarity)
        ON CONFLICT (cluster_id, desire_id) DO NOTHING
    """)
    params = {
        "cluster_id": cluster_id,
        "desire_id": desire_id,
        "similarity": similarity,
    }
    await db.execute(stmt, params)
    await db.flush()
    logger.info(
        "Added desire %s to cluster %s (similarity=%.3f)",
        desire_id,
        cluster_id,
        similarity,
    )
|
||||
|
||||
|
||||
async def recalculate_heat(cluster_id: UUID, db: AsyncSession) -> float:
    """Recalculate heat scores for all desires in a cluster.

    Heat = cluster_size (linear scaling). A 3-member cluster means each
    desire in it gets heat_score = 3.0.

    Returns the new heat score.
    """
    # How many desires currently belong to this cluster?
    size_stmt = text("""
        SELECT COUNT(*) AS cnt FROM desire_clusters
        WHERE cluster_id = :cluster_id
    """)
    member_count = (
        await db.execute(size_stmt, {"cluster_id": cluster_id})
    ).scalar_one()

    new_heat = float(member_count)

    # Push the recomputed score onto every member of the cluster.
    heat_stmt = text("""
        UPDATE desires SET heat_score = :heat
        WHERE id IN (
            SELECT desire_id FROM desire_clusters
            WHERE cluster_id = :cluster_id
        )
    """)
    await db.execute(heat_stmt, {"heat": new_heat, "cluster_id": cluster_id})
    await db.flush()

    logger.info(
        "Recalculated heat for cluster %s: size=%d, heat_score=%.1f",
        cluster_id,
        member_count,
        new_heat,
    )
    return new_heat
|
||||
|
||||
|
||||
async def cluster_desire(
    desire_id: UUID,
    embedding: list[float],
    db: AsyncSession,
) -> dict:
    """Orchestrate clustering for a single desire.

    Flow: find_nearest_cluster → add_to_cluster + recalculate_heat (if match)
    or create_cluster (if no match).

    Returns an observability dict:
        {"cluster_id": UUID, "is_new": bool, "heat_score": float}
    """
    matched_cluster, similarity = await find_nearest_cluster(embedding, db)

    if matched_cluster is None:
        # No sufficiently-similar neighbor — start a fresh single-member cluster.
        new_cluster = await create_cluster(desire_id, db)
        logger.info(
            "Desire %s started new cluster %s (heat=1.0)",
            desire_id,
            new_cluster,
        )
        return {
            "cluster_id": new_cluster,
            "is_new": True,
            "heat_score": 1.0,
        }

    # Join the existing cluster, then refresh heat for all of its members.
    await add_to_cluster(matched_cluster, desire_id, similarity, db)
    updated_heat = await recalculate_heat(matched_cluster, db)
    logger.info(
        "Desire %s joined cluster %s (similarity=%.3f, heat=%.1f)",
        desire_id,
        matched_cluster,
        similarity,
        updated_heat,
    )
    return {
        "cluster_id": matched_cluster,
        "is_new": False,
        "heat_score": updated_heat,
    }
|
||||
|
||||
|
||||
# ── Synchronous variants (for Celery worker context) ─────────────────────
|
||||
|
||||
|
||||
def find_nearest_cluster_sync(
    embedding: list[float],
    session: Session,
    threshold: float = 0.82,
) -> tuple[UUID | None, float]:
    """Sync variant of find_nearest_cluster for Celery worker context.

    Returns:
        (cluster_id, similarity) if a match is found within threshold,
        (None, 0.0) if no match exists.
    """
    max_distance = 1.0 - threshold

    # BUGFIX (mirrors async variant): join desire_clusters so only desires
    # that already belong to a cluster are candidates. Previously the single
    # nearest desire was taken first; the desire being processed (embedding
    # flushed, no cluster row yet) matched itself at distance 0 and the
    # subsequent cluster lookup failed, so every desire started its own
    # cluster and nothing ever joined an existing one.
    query = text("""
        SELECT dc.cluster_id AS cluster_id,
               d.id AS desire_id,
               (d.prompt_embedding <=> :embedding) AS distance
        FROM desires d
        JOIN desire_clusters dc ON dc.desire_id = d.id
        WHERE d.prompt_embedding IS NOT NULL
          AND (d.prompt_embedding <=> :embedding) <= :max_distance
        ORDER BY distance ASC
        LIMIT 1
    """)

    result = session.execute(
        query,
        {"embedding": str(embedding), "max_distance": max_distance},
    )
    row = result.first()

    if row is None:
        logger.debug("No nearby cluster found (threshold=%.2f)", threshold)
        return (None, 0.0)

    similarity = 1.0 - row.distance
    logger.info(
        "Found nearby cluster %s via desire %s (similarity=%.3f)",
        row.cluster_id,
        row.desire_id,
        similarity,
    )
    return (row.cluster_id, similarity)
|
||||
|
||||
|
||||
def create_cluster_sync(desire_id: UUID, session: Session) -> UUID:
    """Sync variant of create_cluster for Celery worker context."""
    new_cluster_id = uuid_mod.uuid4()
    # Founding member is trivially identical to itself → similarity 1.0.
    session.add(
        DesireCluster(
            cluster_id=new_cluster_id,
            desire_id=desire_id,
            similarity=1.0,
        )
    )
    session.flush()
    logger.info("Created new cluster %s for desire %s", new_cluster_id, desire_id)
    return new_cluster_id
|
||||
|
||||
|
||||
def add_to_cluster_sync(
    cluster_id: UUID,
    desire_id: UUID,
    similarity: float,
    session: Session,
) -> None:
    """Sync variant of add_to_cluster for Celery worker context.

    Idempotent — ON CONFLICT DO NOTHING means re-processing the same desire
    won't create a duplicate membership row.
    """
    stmt = text("""
        INSERT INTO desire_clusters (cluster_id, desire_id, similarity)
        VALUES (:cluster_id, :desire_id, :similarity)
        ON CONFLICT (cluster_id, desire_id) DO NOTHING
    """)
    params = {
        "cluster_id": cluster_id,
        "desire_id": desire_id,
        "similarity": similarity,
    }
    session.execute(stmt, params)
    session.flush()
    logger.info(
        "Added desire %s to cluster %s (similarity=%.3f)",
        desire_id,
        cluster_id,
        similarity,
    )
|
||||
|
||||
|
||||
def recalculate_heat_sync(cluster_id: UUID, session: Session) -> float:
    """Sync variant of recalculate_heat for Celery worker context."""
    # How many desires currently belong to this cluster?
    size_stmt = text("""
        SELECT COUNT(*) AS cnt FROM desire_clusters
        WHERE cluster_id = :cluster_id
    """)
    member_count = session.execute(
        size_stmt, {"cluster_id": cluster_id}
    ).scalar_one()

    new_heat = float(member_count)

    # Push the recomputed score onto every member of the cluster.
    heat_stmt = text("""
        UPDATE desires SET heat_score = :heat
        WHERE id IN (
            SELECT desire_id FROM desire_clusters
            WHERE cluster_id = :cluster_id
        )
    """)
    session.execute(heat_stmt, {"heat": new_heat, "cluster_id": cluster_id})
    session.flush()

    logger.info(
        "Recalculated heat for cluster %s: size=%d, heat_score=%.1f",
        cluster_id,
        member_count,
        new_heat,
    )
    return new_heat
|
||||
|
||||
|
||||
def cluster_desire_sync(
    desire_id: UUID,
    embedding: list[float],
    session: Session,
) -> dict:
    """Sync orchestrator for clustering a single desire (Celery worker context).

    Same flow as async cluster_desire but uses synchronous Session.

    Returns:
        {"cluster_id": UUID, "is_new": bool, "heat_score": float}
    """
    matched_cluster, similarity = find_nearest_cluster_sync(embedding, session)

    if matched_cluster is None:
        # No sufficiently-similar neighbor — start a fresh single-member cluster.
        new_cluster = create_cluster_sync(desire_id, session)
        logger.info(
            "Desire %s started new cluster %s (heat=1.0)",
            desire_id,
            new_cluster,
        )
        return {
            "cluster_id": new_cluster,
            "is_new": True,
            "heat_score": 1.0,
        }

    # Join the existing cluster, then refresh heat for all of its members.
    add_to_cluster_sync(matched_cluster, desire_id, similarity, session)
    updated_heat = recalculate_heat_sync(matched_cluster, session)
    logger.info(
        "Desire %s joined cluster %s (similarity=%.3f, heat=%.1f)",
        desire_id,
        matched_cluster,
        similarity,
        updated_heat,
    )
    return {
        "cluster_id": matched_cluster,
        "is_new": False,
        "heat_score": updated_heat,
    }
|
||||
291
services/api/app/services/embedding.py
Normal file
291
services/api/app/services/embedding.py
Normal file
|
|
@ -0,0 +1,291 @@
|
|||
"""Text embedding service using TF-IDF + TruncatedSVD.
|
||||
|
||||
Converts desire prompt text into 512-dimensional dense vectors suitable
|
||||
for pgvector cosine similarity search. Pre-seeded with shader/visual-art
|
||||
domain vocabulary so the model produces meaningful vectors even from a
|
||||
single short text input.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from sklearn.decomposition import TruncatedSVD
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Shader / visual-art domain seed corpus.
|
||||
# Gives the TF-IDF model a vocabulary foundation so it can produce
|
||||
# meaningful vectors from short creative text descriptions.
|
||||
# Target: 500+ unique TF-IDF features (unigrams + bigrams) to support
|
||||
# near-512 SVD components without heavy padding.
|
||||
_SEED_CORPUS: list[str] = [
|
||||
"particle system fluid simulation dynamics motion",
|
||||
"raymarching signed distance field sdf shapes volumes",
|
||||
"procedural noise fractal pattern generation recursive",
|
||||
"color palette gradient blend interpolation smooth",
|
||||
"audio reactive frequency spectrum visualization beat",
|
||||
"ragdoll physics dark moody atmosphere heavy",
|
||||
"kaleidoscope symmetry rotation mirror reflection",
|
||||
"voronoi cellular texture organic growth biological",
|
||||
"bloom glow post processing effect luminance",
|
||||
"retro pixel art scanlines crt monitor vintage",
|
||||
"geometry morphing vertex displacement deformation mesh",
|
||||
"wave propagation ripple interference oscillation",
|
||||
"fire smoke volumetric rendering density fog",
|
||||
"crystal refraction caustics light transparency",
|
||||
"terrain heightmap erosion landscape mountain valley",
|
||||
"shader glitch distortion databend corruption artifact",
|
||||
"feedback loop recursive transformation iteration",
|
||||
"physics collision rigid body dynamics impulse",
|
||||
"abstract minimal geometric composition shape",
|
||||
"aurora borealis atmospheric optical phenomenon sky",
|
||||
"underwater caustics god rays depth ocean",
|
||||
"cyberpunk neon wireframe grid futuristic urban",
|
||||
"organic growth branching lsystem tree root",
|
||||
"mandelbrot julia set fractal zoom iteration",
|
||||
"cloth simulation soft body drape fabric textile",
|
||||
"dissolve transition threshold mask binary cutoff",
|
||||
"chromatic aberration lens distortion optical warp",
|
||||
"shadow mapping ambient occlusion darkness depth",
|
||||
"motion blur temporal accumulation streak trail",
|
||||
"boids flocking swarm emergence collective behavior",
|
||||
"reaction diffusion turing pattern spots stripes",
|
||||
"perlin simplex worley noise texture procedural",
|
||||
"voxel rendering isometric cube block pixel",
|
||||
"gravity orbital celestial mechanics planet orbit",
|
||||
"psychedelic trippy color shift hue rotation",
|
||||
"spiral fibonacci golden ratio mathematical curve",
|
||||
"explosion debris shatter fragment destruction impact",
|
||||
"rain snow weather precipitation droplet splash",
|
||||
"electric lightning bolt plasma energy discharge",
|
||||
"tunnel infinite corridor perspective vanishing point",
|
||||
"metaball blob isosurface marching cubes implicit",
|
||||
"starfield galaxy nebula cosmic space stellar",
|
||||
"shadow puppet silhouette outline contour edge",
|
||||
"mosaic tessellation tile pattern hexagonal grid",
|
||||
"hologram iridescent spectrum rainbow interference",
|
||||
"ink watercolor paint brush stroke artistic",
|
||||
"sand dune desert wind erosion granular",
|
||||
"ice frost frozen crystal snowflake cold",
|
||||
"magma lava volcanic molten heat flow",
|
||||
"laser beam scanning projection line vector",
|
||||
"DNA helix molecular biology strand protein",
|
||||
"circuit board electronic trace signal digital",
|
||||
"camouflage pattern dithering halftone screening dots",
|
||||
"waveform synthesizer oscillator modulation frequency",
|
||||
"topographic contour map elevation isoline level",
|
||||
"origami fold paper crease geometric angular",
|
||||
"stained glass window colorful segmented panels",
|
||||
"smoke ring vortex toroidal turbulence curl",
|
||||
"pendulum harmonic oscillation swing periodic cycle",
|
||||
"cloud formation cumulus atmospheric convection wispy",
|
||||
"ripple pond surface tension concentric circular",
|
||||
"decay rust corrosion entropy degradation aging",
|
||||
"fiber optic strand luminous filament glow",
|
||||
"prism dispersion spectral separation wavelength band",
|
||||
"radar sonar ping pulse echo scanning",
|
||||
"compass rose navigation cardinal directional symbol",
|
||||
"clock mechanism gear cog rotation mechanical",
|
||||
"barcode matrix encoding data stripe identification",
|
||||
"fingerprint unique biometric whorl ridge pattern",
|
||||
"maze labyrinth path algorithm recursive backtrack",
|
||||
"chess checkerboard alternating square pattern grid",
|
||||
"domino cascade chain sequential trigger falling",
|
||||
"balloon inflation expansion pressure sphere elastic",
|
||||
"ribbon flowing fabric curve spline bezier",
|
||||
"confetti celebration scatter random distribution joyful",
|
||||
"ember spark ignition tiny particle hot",
|
||||
"bubble foam soap iridescent sphere surface tension",
|
||||
"whirlpool maelstrom spinning vortex fluid drain",
|
||||
"mirage shimmer heat haze atmospheric refraction",
|
||||
"echo reverberation delay repetition fading diminish",
|
||||
"pulse heartbeat rhythmic expanding ring concentric",
|
||||
"weave interlocking thread textile warp weft",
|
||||
"honeycomb hexagonal efficient packing natural structure",
|
||||
"coral reef branching organic marine growth colony",
|
||||
"mushroom spore cap stem fungal network mycelium",
|
||||
"neuron synapse network brain signal impulse dendrite",
|
||||
"constellation star connect dot line celestial chart",
|
||||
"seismograph earthquake wave amplitude vibration tremor",
|
||||
"aurora curtain charged particle magnetic solar wind",
|
||||
"tidal wave surge ocean force gravitational pull",
|
||||
"sandstorm particle erosion wind desert visibility",
|
||||
"volcanic eruption ash plume pyroclastic flow magma",
|
||||
]
|
||||
|
||||
|
||||
class EmbeddingService:
    """Produces 512-dim L2-normalized vectors from text using TF-IDF + SVD.

    The model is fitted lazily on the shader/visual-art seed corpus the
    first time an embed call is made, then reused for the process lifetime.
    """

    def __init__(self) -> None:
        self._vectorizer = TfidfVectorizer(
            ngram_range=(1, 2),
            max_features=10_000,
            stop_words="english",
        )
        self._svd = TruncatedSVD(n_components=512, random_state=42)
        # Copy so later mutation of the module constant can't affect us.
        self._corpus: list[str] = list(_SEED_CORPUS)
        self._fitted: bool = False

    def _fit_if_needed(self) -> None:
        """Fit the vectorizer and SVD on the seed corpus if not yet fitted."""
        if self._fitted:
            return
        tfidf_matrix = self._vectorizer.fit_transform(self._corpus)

        # SVD n_components must be < number of features in the TF-IDF matrix.
        # If the corpus is too small, reduce SVD components temporarily.
        n_features = tfidf_matrix.shape[1]
        if n_features < self._svd.n_components:
            logger.warning(
                "TF-IDF produced only %d features, reducing SVD components "
                "from %d to %d",
                n_features,
                self._svd.n_components,
                n_features - 1,
            )
            self._svd = TruncatedSVD(
                n_components=n_features - 1, random_state=42
            )

        self._svd.fit(tfidf_matrix)
        self._fitted = True

    @staticmethod
    def _finalize(vec, source_text: str):
        """Normalize a raw SVD vector and pad it to 512 dims.

        - L2-normalizes *vec*.
        - For (near-)zero vectors (text with no recognized vocabulary),
          substitutes a deterministic random unit vector derived from a
          SHA-256 digest of *source_text*. BUGFIX: the previous code seeded
          the RNG with built-in hash(), which is randomized per process
          (PYTHONHASHSEED) — so the "deterministic" fallback differed
          between the API and worker processes and across restarts.
        - Pads to 512 components and re-normalizes if SVD produced fewer.

        Returns the finalized numpy vector (length 512, unit norm).
        """
        import hashlib

        norm = np.linalg.norm(vec)
        if norm > 1e-10:
            vec = vec / norm
        else:
            # Stable 31-bit seed from the text so the same input always maps
            # to the same fallback vector, in any process.
            digest = hashlib.sha256(source_text.encode("utf-8")).digest()
            seed = int.from_bytes(digest[:4], "big") % (2**31)
            rng = np.random.RandomState(seed)
            vec = rng.randn(len(vec))
            vec = vec / np.linalg.norm(vec)

        # Pad to 512 if SVD produced fewer components, then re-normalize.
        if len(vec) < 512:
            padded = np.zeros(512)
            padded[: len(vec)] = vec
            pad_norm = np.linalg.norm(padded)
            if pad_norm > 0:
                padded = padded / pad_norm
            vec = padded
        return vec

    def embed_text(self, text: str) -> list[float]:
        """Embed a single text into a 512-dim L2-normalized vector.

        Args:
            text: Input text to embed.

        Returns:
            List of 512 floats (L2-normalized).

        Raises:
            ValueError: If text is empty or whitespace-only.
        """
        if not text or not text.strip():
            raise ValueError("Cannot embed empty or whitespace-only text")

        start = time.perf_counter()
        self._fit_if_needed()

        tfidf = self._vectorizer.transform([text])
        svd_vec = self._svd.transform(tfidf)[0]

        if np.linalg.norm(svd_vec) <= 1e-10:
            logger.warning(
                "Text produced zero TF-IDF vector (no recognized vocabulary): "
                "'%s'",
                text[:80],
            )
        svd_vec = self._finalize(svd_vec, text)

        elapsed_ms = (time.perf_counter() - start) * 1000
        logger.info(
            "Embedded text (%d chars) → %d-dim vector in %.1fms",
            len(text),
            len(svd_vec),
            elapsed_ms,
        )

        return svd_vec.tolist()

    def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Embed multiple texts at once.

        Args:
            texts: List of input texts.

        Returns:
            List of 512-dim L2-normalized float lists.

        Raises:
            ValueError: If any text is empty or whitespace-only.
        """
        for i, candidate in enumerate(texts):
            if not candidate or not candidate.strip():
                raise ValueError(
                    f"Cannot embed empty or whitespace-only text at index {i}"
                )

        start = time.perf_counter()
        self._fit_if_needed()

        tfidf = self._vectorizer.transform(texts)
        svd_vecs = self._svd.transform(tfidf)

        results: list[list[float]] = []
        for idx, vec in enumerate(svd_vecs):
            if np.linalg.norm(vec) <= 1e-10:
                logger.warning(
                    "Text at index %d produced zero TF-IDF vector: '%s'",
                    idx,
                    texts[idx][:80],
                )
            results.append(self._finalize(vec, texts[idx]).tolist())

        elapsed_ms = (time.perf_counter() - start) * 1000
        logger.info(
            "Batch-embedded %d texts → %d-dim vectors in %.1fms",
            len(texts),
            512,
            elapsed_ms,
        )

        return results
|
||||
|
||||
|
||||
# Module-level singleton — shared across the process so the TF-IDF/SVD
# models are (lazily) fitted at most once per process.
embedding_service = EmbeddingService()


def embed_text(text: str) -> list[float]:
    """Embed a single text into a 512-dim normalized vector.

    Convenience wrapper around the module-level EmbeddingService singleton.
    """
    return embedding_service.embed_text(text)


def embed_batch(texts: list[str]) -> list[list[float]]:
    """Embed multiple texts into 512-dim normalized vectors.

    Convenience wrapper around the module-level EmbeddingService singleton.
    """
    return embedding_service.embed_batch(texts)
|
||||
|
|
@ -1,8 +1,15 @@
|
|||
"""Fractafrag — Celery worker configuration."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid as uuid_mod
|
||||
|
||||
from celery import Celery
|
||||
import os
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
redis_url = os.environ.get("REDIS_URL", "redis://redis:6379/0")
|
||||
|
||||
celery_app = Celery(
|
||||
|
|
@ -28,6 +35,19 @@ celery_app.conf.update(
|
|||
celery_app.autodiscover_tasks(["app.worker"])
|
||||
|
||||
|
||||
# ── Sync DB session factory for worker tasks ──────────────

# Cache the engine/session factory at module level. BUGFIX: the previous
# version built a brand-new Engine (and connection pool) on every task
# invocation and never disposed it — leaking connections and defeating
# pooling under load.
_sync_session_factory = None


def _get_sync_session_factory():
    """Lazy-init and cache a sync session factory using settings.database_url_sync.

    Returns the same sessionmaker on every call after the first.
    """
    global _sync_session_factory
    if _sync_session_factory is None:
        # Imported lazily so importing this module doesn't require settings.
        from app.config import get_settings

        settings = get_settings()
        engine = create_engine(settings.database_url_sync, pool_pre_ping=True)
        _sync_session_factory = sessionmaker(bind=engine)
    return _sync_session_factory
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── Task Definitions ──────────────────────────────────────
|
||||
|
||||
@celery_app.task(name="render_shader", bind=True, max_retries=2)
|
||||
|
|
@ -48,11 +68,77 @@ def embed_shader(self, shader_id: str):
|
|||
pass
|
||||
|
||||
|
||||
@celery_app.task(name="process_desire", bind=True, max_retries=3)
def process_desire(self, desire_id: str):
    """Process a new desire: embed text, store embedding, cluster, update heat.

    Flow:
        1. Load desire from DB by id
        2. Embed prompt_text → 512-dim vector
        3. Store embedding on desire row
        4. Run sync clustering (find nearest or create new cluster)
        5. Commit all changes

    On transient DB errors, retries up to 3 times with 30s backoff.
    On success, logs desire_id, cluster_id, heat_score, and elapsed_ms.
    On failure, desire keeps prompt_embedding=NULL and heat_score=1.0.
    """
    start = time.perf_counter()
    desire_uuid = uuid_mod.UUID(desire_id)

    SessionFactory = _get_sync_session_factory()
    session = SessionFactory()

    try:
        # Imported lazily so the worker doesn't pay service import costs
        # (numpy/sklearn) until the first task actually runs.
        from app.models.models import Desire
        from app.services.embedding import embed_text
        from app.services.clustering import cluster_desire_sync

        # 1. Load desire
        desire = session.get(Desire, desire_uuid)
        if desire is None:
            logger.warning(
                "process_desire: desire %s not found, skipping", desire_id
            )
            return

        # 2. Embed prompt text
        embedding = embed_text(desire.prompt_text)

        # 3. Store embedding on desire
        desire.prompt_embedding = embedding
        session.flush()

        # 4. Cluster
        cluster_result = cluster_desire_sync(desire.id, embedding, session)

        # 5. Commit
        session.commit()

        elapsed_ms = (time.perf_counter() - start) * 1000
        logger.info(
            "process_desire completed: desire_id=%s cluster_id=%s "
            "is_new=%s heat_score=%.1f elapsed_ms=%.1f",
            desire_id,
            cluster_result["cluster_id"],
            cluster_result["is_new"],
            cluster_result["heat_score"],
            elapsed_ms,
        )

    except Exception as exc:
        session.rollback()
        elapsed_ms = (time.perf_counter() - start) * 1000
        # logger.exception preserves the full traceback; the previous
        # logger.error with str(exc) discarded it, which made retry storms
        # hard to diagnose.
        logger.exception(
            "process_desire failed: desire_id=%s elapsed_ms=%.1f",
            desire_id,
            elapsed_ms,
        )
        # self.retry raises celery.exceptions.Retry; re-raising keeps the
        # control flow explicit for readers.
        raise self.retry(exc=exc, countdown=30)

    finally:
        session.close()
|
||||
|
||||
|
||||
@celery_app.task(name="ai_generate", bind=True, max_retries=3)
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ dependencies = [
|
|||
"python-multipart>=0.0.12",
|
||||
"stripe>=11.0.0",
|
||||
"numpy>=2.0.0",
|
||||
"scikit-learn>=1.4",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
|
@ -32,4 +33,5 @@ dev = [
|
|||
"pytest-asyncio>=0.24.0",
|
||||
"httpx>=0.28.0",
|
||||
"ruff>=0.8.0",
|
||||
"aiosqlite>=0.20.0",
|
||||
]
|
||||
|
|
|
|||
1
services/api/tests/__init__.py
Normal file
1
services/api/tests/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Tests package for fractafrag-api."""
|
||||
190
services/api/tests/conftest.py
Normal file
190
services/api/tests/conftest.py
Normal file
|
|
@ -0,0 +1,190 @@
|
|||
"""Pytest configuration and shared fixtures for fractafrag-api tests.
|
||||
|
||||
Integration test infrastructure:
|
||||
- Async SQLite in-memory database (via aiosqlite)
|
||||
- FastAPI test client with dependency overrides
|
||||
- Auth dependency overrides (mock pro-tier user)
|
||||
- Celery worker mock (process_desire.delay → no-op)
|
||||
|
||||
Environment variables are set BEFORE any app.* imports to ensure
|
||||
get_settings() picks up test values (database.py calls get_settings()
|
||||
at module scope with @lru_cache).
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
# ── 1. sys.path setup ─────────────────────────────────────
|
||||
_api_root = str(Path(__file__).resolve().parent.parent)
|
||||
if _api_root not in sys.path:
|
||||
sys.path.insert(0, _api_root)
|
||||
|
||||
# ── 2. Set env vars BEFORE any app.* imports ──────────────
|
||||
# We do NOT override DATABASE_URL — the module-level engine in database.py
|
||||
# uses pool_size/max_overflow which are PostgreSQL-specific. The default
|
||||
# PostgreSQL URL creates an engine that never actually connects (no queries
|
||||
# hit it). Our integration tests override get_db with a test SQLite session.
|
||||
# We only set dummy values for env vars that cause validation failures.
|
||||
os.environ.setdefault("JWT_SECRET", "test-secret")
|
||||
os.environ.setdefault("REDIS_URL", "redis://localhost:6379/0")
|
||||
os.environ.setdefault("BYOK_MASTER_KEY", "test-master-key-0123456789abcdef")
|
||||
|
||||
# ── 3. Now safe to import app modules ─────────────────────
|
||||
import pytest # noqa: E402
|
||||
import pytest_asyncio # noqa: E402
|
||||
from httpx import ASGITransport, AsyncClient # noqa: E402
|
||||
from sqlalchemy import event, text # noqa: E402
|
||||
from sqlalchemy.ext.asyncio import ( # noqa: E402
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
from sqlalchemy.ext.compiler import compiles # noqa: E402
|
||||
from pgvector.sqlalchemy import Vector # noqa: E402
|
||||
from sqlalchemy.dialects.postgresql import UUID as PG_UUID, JSONB, ARRAY # noqa: E402
|
||||
|
||||
from app.database import Base, get_db # noqa: E402
|
||||
from app.main import app # noqa: E402
|
||||
from app.middleware.auth import get_current_user, require_tier # noqa: E402
|
||||
from app.models.models import User # noqa: E402
|
||||
|
||||
|
||||
# ── 4. SQLite type compilation overrides ──────────────────
|
||||
# pgvector Vector, PostgreSQL UUID, JSONB, and ARRAY don't exist in SQLite.
|
||||
# Register custom compilation rules so create_all() works.
|
||||
|
||||
@compiles(Vector, "sqlite")
def _compile_vector_sqlite(type_, compiler, **kw):
    """SQLite has no vector type — persist pgvector columns as TEXT."""
    return "TEXT"
|
||||
|
||||
|
||||
# Override PostgreSQL UUID to TEXT for SQLite
|
||||
@compiles(PG_UUID, "sqlite")
def _compile_pg_uuid_sqlite(type_, compiler, **kw):
    """Map the PostgreSQL-dialect UUID column type to TEXT for SQLite."""
    return "TEXT"
|
||||
|
||||
|
||||
# Override JSONB to TEXT for SQLite
|
||||
@compiles(JSONB, "sqlite")
def _compile_jsonb_sqlite(type_, compiler, **kw):
    """Map the PostgreSQL JSONB column type to TEXT for SQLite."""
    return "TEXT"
|
||||
|
||||
|
||||
# Override ARRAY to TEXT for SQLite
|
||||
@compiles(ARRAY, "sqlite")
def _compile_array_sqlite(type_, compiler, **kw):
    """Map the PostgreSQL ARRAY column type to TEXT for SQLite."""
    return "TEXT"
|
||||
|
||||
|
||||
# Register Python uuid.UUID as a SQLite adapter so raw text() queries
|
||||
# can bind UUID parameters without "type 'UUID' is not supported" errors.
|
||||
# IMPORTANT: Use .hex (no hyphens) to match SQLAlchemy's UUID storage format in SQLite.
|
||||
# Also register list adapter so ARRAY columns (compiled as TEXT in SQLite)
|
||||
# can bind Python lists without "type 'list' is not supported" errors.
|
||||
import json as _json # noqa: E402
|
||||
import sqlite3 # noqa: E402
|
||||
sqlite3.register_adapter(uuid.UUID, lambda u: u.hex)
|
||||
sqlite3.register_adapter(list, lambda lst: _json.dumps(lst))
|
||||
|
||||
|
||||
# ── 5. Test database engine and session fixtures ──────────
|
||||
|
||||
# Shared test user ID — consistent across all integration tests
|
||||
TEST_USER_ID = uuid.UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa")
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
async def db_engine():
    """Create an async SQLite engine and all tables. Session-scoped.

    Teardown (drop_all + dispose) runs in a ``finally`` block so the engine
    is always released — the previous version skipped cleanup whenever an
    exception was thrown into the generator mid-session.
    """
    # SQLite doesn't support pool_size/max_overflow, so none are passed.
    engine = create_async_engine("sqlite+aiosqlite://", echo=False)

    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)

    try:
        yield engine
    finally:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.drop_all)
        await engine.dispose()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def db_session(db_engine):
    """Yield a fresh AsyncSession per test, rolled back afterwards for isolation.

    Bug fixed: the previous version wrapped the ``yield`` in
    ``async with session.begin():`` — that context manager COMMITS on normal
    exit, so the trailing ``rollback()`` was a no-op and committed rows leaked
    into the session-scoped engine between tests. We now open the transaction
    explicitly and always roll it back.

    NOTE(review): if a test (or an endpoint via the get_db override) calls
    ``session.commit()`` itself, that data still persists — same as before.
    """
    session_factory = async_sessionmaker(
        db_engine, class_=AsyncSession, expire_on_commit=False
    )
    async with session_factory() as session:
        # Explicit begin — deliberately NOT `async with session.begin()`,
        # which would commit when the block exits.
        await session.begin()
        try:
            yield session
        finally:
            # Undo everything the test did so no state leaks between tests.
            await session.rollback()
|
||||
|
||||
|
||||
# ── 6. Mock user fixture ─────────────────────────────────
|
||||
|
||||
@pytest.fixture
def test_user():
    """Mock pro-tier User handed to the auth dependency overrides."""
    mocked = MagicMock(spec=User)
    mocked.id = TEST_USER_ID
    mocked.username = "testuser"
    mocked.email = "testuser@test.com"
    mocked.role = "user"
    mocked.subscription_tier = "pro"
    mocked.is_system = False
    mocked.trust_tier = "standard"
    return mocked
|
||||
|
||||
|
||||
# ── 7. FastAPI test client fixture ────────────────────────
|
||||
|
||||
@pytest_asyncio.fixture
async def client(db_session, test_user):
    """Async HTTP client wired to the FastAPI app with dependency overrides.

    Overrides:
    - get_db → yields the test db_session
    - get_current_user → returns test_user (pro tier)
    - require_tier → satisfied naturally, since it depends on
      get_current_user and the mock user is pro tier
    - process_desire.delay → no-op (prevents Celery/Redis connection)

    Bug fixed: ``app.dependency_overrides.clear()`` now runs in a ``finally``
    block; previously a failing test left the overrides installed for every
    subsequent test in the process.
    """

    # Override get_db to yield the per-test session.
    async def _yield_test_session():
        yield db_session

    # Override get_current_user to return the mock pro-tier user.
    async def _return_test_user():
        return test_user

    app.dependency_overrides[get_db] = _yield_test_session
    app.dependency_overrides[get_current_user] = _return_test_user

    try:
        # Only the Celery enqueue needs stubbing — tier checks pass via the
        # overridden get_current_user.
        with patch("app.worker.process_desire") as mock_task:
            # process_desire.delay() becomes a no-op.
            mock_task.delay = MagicMock(return_value=None)

            transport = ASGITransport(app=app)
            async with AsyncClient(transport=transport, base_url="http://test") as ac:
                yield ac
    finally:
        # Always remove the overrides, even when the test body raises.
        app.dependency_overrides.clear()
|
||||
366
services/api/tests/test_clustering.py
Normal file
366
services/api/tests/test_clustering.py
Normal file
|
|
@ -0,0 +1,366 @@
|
|||
"""Unit tests for the clustering service.
|
||||
|
||||
Tests use mocked async DB sessions to isolate clustering logic from
|
||||
pgvector and database concerns. Synthetic 512-dim vectors verify the
|
||||
service's orchestration, heat calculation, and threshold behavior.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.models import DesireCluster
|
||||
from app.services.clustering import (
|
||||
add_to_cluster,
|
||||
cluster_desire,
|
||||
create_cluster,
|
||||
find_nearest_cluster,
|
||||
recalculate_heat,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_embedding(dim: int = 512) -> list[float]:
|
||||
"""Create a synthetic embedding vector for testing."""
|
||||
import numpy as np
|
||||
rng = np.random.default_rng(42)
|
||||
vec = rng.standard_normal(dim)
|
||||
vec = vec / np.linalg.norm(vec)
|
||||
return vec.tolist()
|
||||
|
||||
|
||||
def _mock_result_row(**kwargs):
|
||||
"""Create a mock DB result row with named attributes."""
|
||||
row = MagicMock()
|
||||
for key, value in kwargs.items():
|
||||
setattr(row, key, value)
|
||||
return row
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: cluster_desire orchestration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestClusterDesireOrchestration:
    """Exercise the cluster_desire orchestrator with patched sub-functions."""

    @pytest.mark.asyncio
    @patch("app.services.clustering.find_nearest_cluster", new_callable=AsyncMock)
    @patch("app.services.clustering.create_cluster", new_callable=AsyncMock)
    async def test_new_desire_creates_own_cluster(
        self, mock_create, mock_find
    ) -> None:
        """When no nearby cluster exists, create a new one."""
        created_cluster = uuid.uuid4()
        target_desire = uuid.uuid4()
        vector = _make_embedding()

        mock_find.return_value = (None, 0.0)
        mock_create.return_value = created_cluster

        db = AsyncMock()
        outcome = await cluster_desire(target_desire, vector, db)

        mock_find.assert_awaited_once_with(vector, db)
        mock_create.assert_awaited_once_with(target_desire, db)
        assert outcome["is_new"] is True
        assert outcome["cluster_id"] == created_cluster
        assert outcome["heat_score"] == 1.0

    @pytest.mark.asyncio
    @patch("app.services.clustering.find_nearest_cluster", new_callable=AsyncMock)
    @patch("app.services.clustering.add_to_cluster", new_callable=AsyncMock)
    @patch("app.services.clustering.recalculate_heat", new_callable=AsyncMock)
    async def test_similar_desire_joins_existing_cluster(
        self, mock_recalc, mock_add, mock_find
    ) -> None:
        """When a nearby cluster is found, join it and recalculate heat."""
        joined_cluster = uuid.uuid4()
        target_desire = uuid.uuid4()
        vector = _make_embedding()
        match_similarity = 0.92

        mock_find.return_value = (joined_cluster, match_similarity)
        mock_recalc.return_value = 3.0

        db = AsyncMock()
        outcome = await cluster_desire(target_desire, vector, db)

        mock_find.assert_awaited_once_with(vector, db)
        mock_add.assert_awaited_once_with(
            joined_cluster, target_desire, match_similarity, db
        )
        mock_recalc.assert_awaited_once_with(joined_cluster, db)
        assert outcome["is_new"] is False
        assert outcome["cluster_id"] == joined_cluster
        assert outcome["heat_score"] == 3.0

    @pytest.mark.asyncio
    @patch("app.services.clustering.find_nearest_cluster", new_callable=AsyncMock)
    @patch("app.services.clustering.create_cluster", new_callable=AsyncMock)
    async def test_cluster_desire_returns_observability_dict(
        self, mock_create, mock_find
    ) -> None:
        """Returned dict always has cluster_id, is_new, heat_score."""
        mock_find.return_value = (None, 0.0)
        mock_create.return_value = uuid.uuid4()

        db = AsyncMock()
        outcome = await cluster_desire(uuid.uuid4(), _make_embedding(), db)

        for key in ("cluster_id", "is_new", "heat_score"):
            assert key in outcome
        assert isinstance(outcome["cluster_id"], uuid.UUID)
        assert isinstance(outcome["is_new"], bool)
        assert isinstance(outcome["heat_score"], float)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: recalculate_heat
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRecalculateHeat:
    """Heat score recalculation against mocked DB results."""

    @staticmethod
    def _db_counting(member_count):
        """AsyncMock session whose first execute() yields COUNT(*) = member_count.

        The second execute() is the UPDATE, whose result is never inspected.
        """
        db = AsyncMock()
        count_result = MagicMock()
        count_result.scalar_one.return_value = member_count
        db.execute = AsyncMock(side_effect=[count_result, MagicMock()])
        return db

    @pytest.mark.asyncio
    async def test_heat_scales_with_cluster_size(self) -> None:
        """Heat score should equal cluster size (linear scaling)."""
        db = self._db_counting(3)

        heat = await recalculate_heat(uuid.uuid4(), db)

        assert heat == 3.0
        assert db.execute.await_count == 2
        assert db.flush.await_count >= 1

    @pytest.mark.asyncio
    async def test_heat_for_single_member_cluster(self) -> None:
        """A single-member cluster should have heat_score = 1.0."""
        heat = await recalculate_heat(uuid.uuid4(), self._db_counting(1))

        assert heat == 1.0

    @pytest.mark.asyncio
    async def test_heat_for_large_cluster(self) -> None:
        """Heat scales to large cluster sizes."""
        heat = await recalculate_heat(uuid.uuid4(), self._db_counting(15))

        assert heat == 15.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: find_nearest_cluster
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFindNearestCluster:
    """pgvector distance query behavior with mocked DB results."""

    @staticmethod
    def _result_first(row):
        """Mock DB result object whose .first() returns *row*."""
        res = MagicMock()
        res.first.return_value = row
        return res

    @pytest.mark.asyncio
    async def test_empty_db_returns_none(self) -> None:
        """No desires with embeddings → no cluster match."""
        db = AsyncMock()
        db.execute = AsyncMock(return_value=self._result_first(None))

        cluster_id, similarity = await find_nearest_cluster(
            _make_embedding(), db
        )

        assert cluster_id is None
        assert similarity == 0.0

    @pytest.mark.asyncio
    async def test_match_found_with_cluster(self) -> None:
        """A desire within threshold that has a cluster → returns cluster."""
        nearest_desire = uuid.uuid4()
        expected_cluster = uuid.uuid4()
        db = AsyncMock()

        # First query: nearest desire (distance = 0.08 → similarity = 0.92);
        # second query: its cluster membership.
        db.execute = AsyncMock(side_effect=[
            self._result_first(
                _mock_result_row(desire_id=nearest_desire, distance=0.08)
            ),
            self._result_first(_mock_result_row(cluster_id=expected_cluster)),
        ])

        found_id, sim = await find_nearest_cluster(_make_embedding(), db)

        assert found_id == expected_cluster
        assert abs(sim - 0.92) < 1e-6

    @pytest.mark.asyncio
    async def test_match_found_without_cluster(self) -> None:
        """A nearby desire that has no cluster entry → returns None."""
        db = AsyncMock()
        db.execute = AsyncMock(side_effect=[
            self._result_first(
                _mock_result_row(desire_id=uuid.uuid4(), distance=0.10)
            ),
            self._result_first(None),  # cluster lookup comes back empty
        ])

        found_id, sim = await find_nearest_cluster(_make_embedding(), db)

        assert found_id is None
        assert sim == 0.0

    @pytest.mark.asyncio
    async def test_threshold_boundary_at_0_82(self) -> None:
        """Threshold of 0.82 means max distance of 0.18.

        A desire at exactly distance=0.18 (similarity=0.82) should be
        returned by the SQL query (distance <= 0.18).
        """
        expected_cluster = uuid.uuid4()
        db = AsyncMock()
        db.execute = AsyncMock(side_effect=[
            self._result_first(
                _mock_result_row(desire_id=uuid.uuid4(), distance=0.18)
            ),
            self._result_first(_mock_result_row(cluster_id=expected_cluster)),
        ])

        found_id, sim = await find_nearest_cluster(
            _make_embedding(), db, threshold=0.82
        )

        assert found_id == expected_cluster
        assert abs(sim - 0.82) < 1e-6

    @pytest.mark.asyncio
    async def test_below_threshold_returns_none(self) -> None:
        """A desire beyond the distance threshold is not returned by SQL.

        With threshold=0.82 (max_distance=0.18), a desire at distance=0.19
        (similarity=0.81) would be filtered out by the WHERE clause.
        The mock simulates this by returning no rows.
        """
        db = AsyncMock()
        db.execute = AsyncMock(return_value=self._result_first(None))

        found_id, sim = await find_nearest_cluster(
            _make_embedding(), db, threshold=0.82
        )

        assert found_id is None
        assert sim == 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: create_cluster
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCreateCluster:
    """Cluster creation behavior."""

    @pytest.mark.asyncio
    async def test_create_cluster_returns_uuid(self) -> None:
        """New cluster gets a valid UUID."""
        db = AsyncMock()
        db.add = MagicMock()  # Session.add() is synchronous

        new_id = await create_cluster(uuid.uuid4(), db)

        assert isinstance(new_id, uuid.UUID)
        db.add.assert_called_once()
        db.flush.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_create_cluster_adds_desire_cluster_row(self) -> None:
        """The DesireCluster row has similarity=1.0 (self-reference)."""
        db = AsyncMock()
        db.add = MagicMock()  # Session.add() is synchronous
        founding_desire = uuid.uuid4()

        new_id = await create_cluster(founding_desire, db)

        membership = db.add.call_args[0][0]
        assert isinstance(membership, DesireCluster)
        assert membership.cluster_id == new_id
        assert membership.desire_id == founding_desire
        assert membership.similarity == 1.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: add_to_cluster
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAddToCluster:
    """Adding a desire to an existing cluster."""

    @pytest.mark.asyncio
    async def test_add_to_cluster_executes_insert(self) -> None:
        """Insert is executed and flushed."""
        db = AsyncMock()
        target_cluster = uuid.uuid4()
        joining_desire = uuid.uuid4()

        await add_to_cluster(target_cluster, joining_desire, 0.91, db)

        db.execute.assert_awaited_once()
        db.flush.assert_awaited()

        # The second positional argument of execute() is the
        # bound-parameter dict for the INSERT statement.
        bound_params = db.execute.call_args[0][1]
        assert bound_params["cluster_id"] == target_cluster
        assert bound_params["desire_id"] == joining_desire
        assert bound_params["similarity"] == 0.91
|
||||
250
services/api/tests/test_desire_pipeline.py
Normal file
250
services/api/tests/test_desire_pipeline.py
Normal file
|
|
@ -0,0 +1,250 @@
|
|||
"""Pipeline integration tests — embed → cluster → heat.
|
||||
|
||||
Proves the full desire processing pipeline works end-to-end by:
|
||||
1. Verifying similar texts produce embeddings with cosine similarity above
|
||||
the clustering threshold (0.82)
|
||||
2. Verifying dissimilar texts stay below the clustering threshold
|
||||
3. Validating heat calculation logic for clustered desires
|
||||
4. Checking that the router and worker are wired correctly (static assertions)
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from app.services.embedding import embed_text
|
||||
|
||||
|
||||
def cosine_sim(a: list[float], b: list[float]) -> float:
    """Cosine similarity of two L2-normalized vectors — just their dot product."""
    return float(np.dot(np.asarray(a), np.asarray(b)))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Embedding pipeline: similar texts cluster, dissimilar texts don't
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSimilarDesiresClustering:
    """Similar desire texts must produce clusterable embeddings."""

    SIMILAR_TEXTS = [
        "ragdoll physics dark moody slow motion",
        "dark physics ragdoll slow motion moody",
        "slow motion ragdoll dark physics moody",
    ]

    def test_similar_desires_produce_clusterable_embeddings(self) -> None:
        """All pairwise cosine similarities among similar texts exceed 0.82."""
        embeddings = [embed_text(t) for t in self.SIMILAR_TEXTS]

        for i, left in enumerate(embeddings):
            for j in range(i + 1, len(embeddings)):
                sim = cosine_sim(left, embeddings[j])
                assert sim > 0.82, (
                    f"Texts [{i}] and [{j}] should cluster (sim > 0.82), "
                    f"got {sim:.4f}:\n"
                    f"  [{i}] '{self.SIMILAR_TEXTS[i]}'\n"
                    f"  [{j}] '{self.SIMILAR_TEXTS[j]}'"
                )

    def test_dissimilar_desire_does_not_cluster(self) -> None:
        """A dissimilar text has cosine similarity < 0.82 with all similar texts."""
        outlier = embed_text("bright colorful kaleidoscope flowers rainbow")
        reference_embeddings = [embed_text(t) for t in self.SIMILAR_TEXTS]

        for i, reference in enumerate(reference_embeddings):
            sim = cosine_sim(outlier, reference)
            assert sim < 0.82, (
                f"Dissimilar text should NOT cluster with text [{i}] "
                f"(sim < 0.82), got {sim:.4f}:\n"
                f"  dissimilar: 'bright colorful kaleidoscope flowers rainbow'\n"
                f"  similar[{i}]: '{self.SIMILAR_TEXTS[i]}'"
            )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Heat calculation logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPipelineHeatCalculation:
    """Heat score calculation must track cluster size."""

    @staticmethod
    def _session_counting(member_count):
        """Sync-session mock: first execute() yields COUNT(*) = member_count."""
        session = MagicMock()
        count_result = MagicMock()
        count_result.scalar_one.return_value = member_count
        session.execute = MagicMock(side_effect=[count_result, MagicMock()])
        return session

    def test_pipeline_heat_calculation_logic(self) -> None:
        """A cluster of 3 desires should produce heat_score = 3.0 for each member.

        This tests the recalculate_heat_sync logic by simulating its
        DB interaction pattern with mocks.
        """
        from app.services.clustering import recalculate_heat_sync

        session = self._session_counting(3)

        heat = recalculate_heat_sync(uuid.uuid4(), session)

        assert heat == 3.0
        assert session.execute.call_count == 2
        assert session.flush.call_count >= 1

    def test_single_member_cluster_has_heat_1(self) -> None:
        """A new single-member cluster should have heat_score = 1.0."""
        from app.services.clustering import recalculate_heat_sync

        heat = recalculate_heat_sync(uuid.uuid4(), self._session_counting(1))

        assert heat == 1.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sync clustering orchestrator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSyncClusteringOrchestrator:
    """cluster_desire_sync orchestration with patched sub-functions."""

    def test_new_desire_creates_cluster(self) -> None:
        """When no nearby cluster exists, creates a new one."""
        from unittest.mock import patch
        from app.services.clustering import cluster_desire_sync

        target_desire = uuid.uuid4()
        vector = embed_text("ragdoll physics dark moody slow")
        created_cluster = uuid.uuid4()
        session = MagicMock()

        with patch("app.services.clustering.find_nearest_cluster_sync") as mock_find, \
                patch("app.services.clustering.create_cluster_sync") as mock_create:
            mock_find.return_value = (None, 0.0)
            mock_create.return_value = created_cluster

            outcome = cluster_desire_sync(target_desire, vector, session)

            assert outcome["is_new"] is True
            assert outcome["cluster_id"] == created_cluster
            assert outcome["heat_score"] == 1.0
            mock_find.assert_called_once_with(vector, session)
            mock_create.assert_called_once_with(target_desire, session)

    def test_similar_desire_joins_existing_cluster(self) -> None:
        """When a nearby cluster is found, joins it and recalculates heat."""
        from unittest.mock import patch
        from app.services.clustering import cluster_desire_sync

        target_desire = uuid.uuid4()
        vector = embed_text("ragdoll physics dark moody slow")
        joined_cluster = uuid.uuid4()
        session = MagicMock()

        with patch("app.services.clustering.find_nearest_cluster_sync") as mock_find, \
                patch("app.services.clustering.add_to_cluster_sync") as mock_add, \
                patch("app.services.clustering.recalculate_heat_sync") as mock_recalc:
            mock_find.return_value = (joined_cluster, 0.91)
            mock_recalc.return_value = 3.0

            outcome = cluster_desire_sync(target_desire, vector, session)

            assert outcome["is_new"] is False
            assert outcome["cluster_id"] == joined_cluster
            assert outcome["heat_score"] == 3.0
            mock_add.assert_called_once_with(
                joined_cluster, target_desire, 0.91, session
            )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Wiring checks: router + worker are connected
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestWiring:
    """Static assertions that the router and worker are properly wired."""

    @staticmethod
    def _read_source(*parts: str) -> str:
        """Read a source file located relative to the services/api root."""
        return Path(__file__).resolve().parent.parent.joinpath(*parts).read_text()

    def test_router_has_worker_enqueue(self) -> None:
        """desires.py contains process_desire.delay — fire-and-forget enqueue."""
        source = self._read_source("app", "routers", "desires.py")
        assert "process_desire.delay" in source, (
            "Router should call process_desire.delay() to enqueue worker task"
        )

    def test_worker_task_is_implemented(self) -> None:
        """process_desire task body is not just 'pass' — has real implementation.

        Reads the worker source file directly to avoid importing celery
        (which may not be installed in the test environment).
        """
        source = self._read_source("app", "worker", "__init__.py")

        # Key implementation markers must all be present.
        assert "embed_text" in source, (
            "Worker should call embed_text to embed desire prompt"
        )
        assert "cluster_desire_sync" in source, (
            "Worker should call cluster_desire_sync to cluster the desire"
        )
        assert "session.commit" in source, (
            "Worker should commit the DB transaction"
        )

    def test_worker_has_structured_logging(self) -> None:
        """process_desire task includes structured logging of key fields."""
        source = self._read_source("app", "worker", "__init__.py")

        assert "desire_id" in source, "Should log desire_id"
        assert "cluster_id" in source, "Should log cluster_id"
        assert "heat_score" in source, "Should log heat_score"
        assert "elapsed_ms" in source, "Should log elapsed_ms"

    def test_worker_has_error_handling_with_retry(self) -> None:
        """process_desire catches exceptions and retries."""
        source = self._read_source("app", "worker", "__init__.py")

        assert "self.retry" in source, (
            "Worker should use self.retry for transient error handling"
        )
        assert "session.rollback" in source, (
            "Worker should rollback on error before retrying"
        )
|
||||
137
services/api/tests/test_embedding.py
Normal file
137
services/api/tests/test_embedding.py
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
"""Unit tests for the text embedding service.
|
||||
|
||||
Validates that TF-IDF + TruncatedSVD produces 512-dim L2-normalized vectors
|
||||
with meaningful cosine similarity for shader/visual-art domain text.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from app.services.embedding import EmbeddingService, embed_text
|
||||
|
||||
|
||||
def cosine_sim(a: list[float], b: list[float]) -> float:
    """Cosine similarity of two vectors.

    The embeddings under test are already L2-normalized, so the plain dot
    product equals the cosine similarity.
    """
    return float(np.dot(a, b))
|
||||
|
||||
|
||||
class TestEmbedDimension:
    """Checks on the shape and element types of embedding output."""

    def test_embed_produces_512_dim_vector(self) -> None:
        result = embed_text("particle system fluid simulation")
        assert len(result) == 512, f"Expected 512 dims, got {len(result)}"

    def test_embed_returns_list_of_floats(self) -> None:
        result = embed_text("fractal noise pattern")
        assert isinstance(result, list)
        assert all(isinstance(component, float) for component in result)
|
||||
|
||||
|
||||
class TestNormalization:
    """Embedding vectors must come back L2-normalized."""

    def test_embed_vectors_are_normalized(self) -> None:
        result = embed_text("raymarching distance field shapes")
        norm = np.linalg.norm(result)
        assert abs(norm - 1.0) < 1e-6, f"Expected norm ≈ 1.0, got {norm}"

    def test_various_inputs_all_normalized(self) -> None:
        # Short, long, and domain-typical prompts should all normalize.
        for text in (
            "short",
            "a much longer description of a complex visual effect with many words",
            "ragdoll physics dark moody atmosphere simulation",
        ):
            norm = np.linalg.norm(embed_text(text))
            assert abs(norm - 1.0) < 1e-6, (
                f"Norm for '{text}' = {norm}, expected ≈ 1.0"
            )
|
||||
|
||||
|
||||
class TestSimilarity:
    """Semantic-similarity properties of the embedding space."""

    def test_similar_texts_have_high_cosine_similarity(self) -> None:
        sim = cosine_sim(
            embed_text("ragdoll physics dark and slow"),
            embed_text("dark physics simulation ragdoll"),
        )
        assert sim > 0.8, (
            f"Similar texts should have >0.8 cosine sim, got {sim:.4f}"
        )

    def test_dissimilar_texts_have_low_cosine_similarity(self) -> None:
        sim = cosine_sim(
            embed_text("ragdoll physics dark"),
            embed_text("bright colorful kaleidoscope flowers"),
        )
        assert sim < 0.5, (
            f"Dissimilar texts should have <0.5 cosine sim, got {sim:.4f}"
        )

    def test_identical_texts_have_perfect_similarity(self) -> None:
        prompt = "procedural noise fractal generation"
        sim = cosine_sim(embed_text(prompt), embed_text(prompt))
        assert sim > 0.999, (
            f"Identical texts should have ~1.0 cosine sim, got {sim:.4f}"
        )
|
||||
|
||||
|
||||
class TestBatch:
    """embed_batch must agree with per-text embed_text results."""

    def test_embed_batch_matches_individual(self) -> None:
        texts = [
            "particle system fluid",
            "ragdoll physics dark moody",
            "kaleidoscope symmetry rotation",
        ]

        # One freshly built service per pathway keeps results deterministic.
        per_text_service = EmbeddingService()
        individual = [per_text_service.embed_text(t) for t in texts]

        batch_service = EmbeddingService()
        batched = batch_service.embed_batch(texts)

        assert len(batched) == len(individual)
        for i, (ind, bat) in enumerate(zip(individual, batched)):
            sim = cosine_sim(ind, bat)
            assert sim > 0.999, (
                f"Batch result {i} doesn't match individual: sim={sim:.6f}"
            )

    def test_batch_dimensions(self) -> None:
        results = EmbeddingService().embed_batch(
            ["fire smoke volumetric", "crystal refraction light"]
        )
        assert len(results) == 2
        for vec in results:
            assert len(vec) == 512
|
||||
|
||||
|
||||
class TestErrorHandling:
    """Invalid inputs must raise ValueError with a clear message."""

    def test_empty_string_raises_valueerror(self) -> None:
        with pytest.raises(ValueError, match="empty or whitespace"):
            embed_text("")

    def test_whitespace_only_raises_valueerror(self) -> None:
        with pytest.raises(ValueError, match="empty or whitespace"):
            embed_text("  \n\t  ")

    def test_batch_with_empty_string_raises_valueerror(self) -> None:
        with pytest.raises(ValueError, match="empty or whitespace"):
            EmbeddingService().embed_batch(["valid text", ""])

    def test_batch_with_whitespace_raises_valueerror(self) -> None:
        with pytest.raises(ValueError, match="empty or whitespace"):
            EmbeddingService().embed_batch(["   ", "valid text"])
|
||||
294
services/api/tests/test_fulfillment.py
Normal file
294
services/api/tests/test_fulfillment.py
Normal file
|
|
@ -0,0 +1,294 @@
|
|||
"""Tests for desire fulfillment endpoint and cluster_count annotation.
|
||||
|
||||
Covers:
|
||||
- fulfill_desire endpoint: happy path, not-found, not-open, shader validation
|
||||
(tested via source assertions since FastAPI isn't in the test environment)
|
||||
- cluster_count annotation: batch query pattern, single desire query
|
||||
- Schema field: cluster_count exists in DesirePublic
|
||||
|
||||
Approach: Per K005, router functions can't be imported without FastAPI installed.
|
||||
We verify correctness through:
|
||||
1. Source-level structure assertions (endpoint wiring, imports, validation logic)
|
||||
2. Isolated logic unit tests (annotation loop, status transitions)
|
||||
3. Schema field verification via Pydantic model introspection
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _router_source() -> str:
    """Read the desires router source code."""
    api_root = Path(__file__).resolve().parent.parent
    return (api_root / "app" / "routers" / "desires.py").read_text(encoding="utf-8")
|
||||
|
||||
|
||||
def _schema_source() -> str:
    """Read the schemas source code."""
    api_root = Path(__file__).resolve().parent.parent
    return (api_root / "app" / "schemas" / "schemas.py").read_text(encoding="utf-8")
|
||||
|
||||
|
||||
def _make_mock_desire(
    *,
    desire_id=None,
    status="open",
    heat_score=1.0,
):
    """Build a MagicMock that stands in for a Desire ORM instance."""
    mock_row = MagicMock()
    # Fall back to a fresh UUID when no id is supplied.
    mock_row.id = desire_id or uuid.uuid4()
    mock_row.status = status
    mock_row.heat_score = heat_score
    mock_row.cluster_count = 0  # pre-annotation default
    return mock_row
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# fulfill_desire — happy path structure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFulfillHappyPath:
    """Verify the fulfill endpoint's happy-path logic via source analysis."""

    def test_fulfill_sets_status_to_fulfilled(self):
        """The endpoint sets desire.status = 'fulfilled' on success."""
        assert 'desire.status = "fulfilled"' in _router_source()

    def test_fulfill_sets_fulfilled_by_shader(self):
        """The endpoint sets desire.fulfilled_by_shader = shader_id."""
        assert "desire.fulfilled_by_shader = shader_id" in _router_source()

    def test_fulfill_sets_fulfilled_at_timestamp(self):
        """The endpoint sets desire.fulfilled_at to current UTC time."""
        source = _router_source()
        assert "desire.fulfilled_at" in source
        assert "datetime.now(timezone.utc)" in source

    def test_fulfill_returns_status_response(self):
        """The endpoint returns a dict with status, desire_id, shader_id."""
        source = _router_source()
        for marker in ('"status": "fulfilled"', '"desire_id"', '"shader_id"'):
            assert marker in source
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# fulfill_desire — error paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFulfillDesireNotFound:
    """404 when desire doesn't exist."""

    def test_desire_not_found_raises_404(self):
        # The endpoint checks the scalar_one_or_none lookup result and
        # raises with this detail message.
        assert "Desire not found" in _router_source()
|
||||
|
||||
|
||||
class TestFulfillDesireNotOpen:
    """400 when desire is not in 'open' status."""

    def test_desire_not_open_check_exists(self):
        assert 'desire.status != "open"' in _router_source()

    def test_desire_not_open_error_message(self):
        assert "Desire is not open" in _router_source()
|
||||
|
||||
|
||||
class TestFulfillShaderNotFound:
    """404 when shader_id doesn't match any shader."""

    def test_shader_lookup_exists(self):
        assert "select(Shader).where(Shader.id == shader_id)" in _router_source()

    def test_shader_not_found_raises_404(self):
        assert "Shader not found" in _router_source()
|
||||
|
||||
|
||||
class TestFulfillShaderNotPublished:
    """400 when shader status is not 'published'."""

    def test_shader_status_validation(self):
        assert 'shader.status != "published"' in _router_source()

    def test_shader_not_published_error_message(self):
        assert "Shader must be published to fulfill a desire" in _router_source()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cluster_count annotation — logic unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestClusterCountAnnotation:
    """Verify cluster_count annotation logic patterns."""

    def test_list_desires_has_batch_cluster_query(self):
        """list_desires uses a batch query with ANY(:desire_ids)."""
        source = _router_source()
        for marker in ("ANY(:desire_ids)", "desire_clusters dc1", "desire_clusters dc2"):
            assert marker in source

    def test_list_desires_avoids_n_plus_1(self):
        """Cluster counts are fetched in a single batch, not per-desire."""
        source = _router_source()
        # One batch query builds a dict; a loop then annotates each desire.
        assert "cluster_counts = {" in source
        assert "cluster_counts.get(d.id, 0)" in source

    def test_list_desires_skips_cluster_query_when_empty(self):
        """When no desires are returned, cluster query is skipped."""
        assert "if desire_ids:" in _router_source()

    def test_get_desire_annotates_single_cluster_count(self):
        """get_desire runs a cluster count query for the single desire."""
        assert "WHERE dc1.desire_id = :desire_id" in _router_source()

    def test_annotation_loop_sets_default_zero(self):
        """Desires without cluster entries default to cluster_count = 0."""
        assert "cluster_counts.get(d.id, 0)" in _router_source()

    def test_annotation_loop_logic(self):
        """Unit test: the annotation loop correctly maps cluster counts to desires."""
        clustered_a = _make_mock_desire()
        unclustered = _make_mock_desire()
        clustered_b = _make_mock_desire()

        # Simulated batch-query result: two of the three are clustered.
        cluster_counts = {clustered_a.id: 3, clustered_b.id: 2}

        # Mirrors the router's annotation loop exactly.
        for d in (clustered_a, unclustered, clustered_b):
            d.cluster_count = cluster_counts.get(d.id, 0)

        assert clustered_a.cluster_count == 3
        assert unclustered.cluster_count == 0  # not in any cluster
        assert clustered_b.cluster_count == 2

    def test_get_desire_cluster_count_fallback(self):
        """get_desire sets cluster_count=0 when no cluster row exists."""
        # The router falls back via `row[0] if row else 0`.
        assert "row[0] if row else 0" in _router_source()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema field verification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDesirePublicSchema:
    """Verify DesirePublic schema has the cluster_count field."""

    def test_cluster_count_field_in_schema_source(self):
        """DesirePublic schema source contains cluster_count field."""
        assert "cluster_count" in _schema_source()

    def test_cluster_count_default_zero(self):
        """cluster_count defaults to 0 in the schema."""
        assert "cluster_count: int = 0" in _schema_source()

    def test_schema_from_attributes_enabled(self):
        """DesirePublic uses from_attributes=True for ORM compatibility."""
        source = _schema_source()
        start = source.index("class DesirePublic")
        assert "from_attributes=True" in source[start:start + 200]

    def test_cluster_count_pydantic_model(self):
        """DesirePublic schema has cluster_count as an int field with default 0."""
        source = _schema_source()
        # Field ordering: cluster_count must fall between heat_score and
        # fulfilled_by_shader inside the DesirePublic class body.
        start = source.index("class DesirePublic")
        section = source[start:start + 500]
        assert (
            section.index("heat_score")
            < section.index("cluster_count")
            < section.index("fulfilled_by_shader")
        ), "cluster_count should be between heat_score and fulfilled_by_shader"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Wiring assertions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFulfillmentWiring:
    """Structural assertions that the router is properly wired."""

    def test_router_imports_shader_model(self):
        """desires.py imports Shader (for shader validation) from app.models."""
        source = _router_source()
        # Scan the actual import statements instead of the previous
        # hard-coded line index (source.split("\n")[8]), which broke as
        # soon as the import block was reordered or a line was added.
        model_imports = [
            line for line in source.split("\n")
            if line.startswith("from app.models import")
        ]
        assert any("Shader" in line for line in model_imports), (
            "desires.py should import Shader from app.models"
        )

    def test_router_imports_text_from_sqlalchemy(self):
        """desires.py imports text from sqlalchemy for raw SQL."""
        source = _router_source()
        assert "from sqlalchemy import" in source
        assert "text" in source

    def test_fulfill_endpoint_requires_auth(self):
        """fulfill_desire uses get_current_user dependency."""
        source = _router_source()
        # Inspect only the fulfill_desire function's signature region.
        fulfill_idx = source.index("async def fulfill_desire")
        fulfill_section = source[fulfill_idx:fulfill_idx + 500]
        assert "get_current_user" in fulfill_section

    def test_fulfill_endpoint_takes_shader_id_param(self):
        """fulfill_desire accepts shader_id as a query parameter."""
        source = _router_source()
        fulfill_idx = source.index("async def fulfill_desire")
        fulfill_section = source[fulfill_idx:fulfill_idx + 300]
        assert "shader_id" in fulfill_section

    def test_list_desires_returns_desire_public(self):
        """list_desires endpoint uses DesirePublic response model."""
        assert "response_model=list[DesirePublic]" in _router_source()

    def test_get_desire_returns_desire_public(self):
        """get_desire endpoint uses DesirePublic response model."""
        lines = _router_source().split("\n")
        for i, line in enumerate(lines):
            if "async def get_desire" in line:
                # The decorator line immediately above must declare the model.
                assert "response_model=DesirePublic" in lines[i - 1]
                break
        else:
            # Previously the test passed silently when the endpoint was
            # missing entirely; fail explicitly instead.
            pytest.fail("async def get_desire not found in desires.py")
|
||||
412
services/api/tests/test_integration.py
Normal file
412
services/api/tests/test_integration.py
Normal file
|
|
@ -0,0 +1,412 @@
|
|||
"""Integration tests — end-to-end acceptance scenarios through FastAPI.
|
||||
|
||||
Uses async SQLite test database, real FastAPI endpoint handlers,
|
||||
and dependency overrides for auth and Celery worker.
|
||||
|
||||
Test classes:
|
||||
TestInfrastructureSmoke — proves test infra works (T01)
|
||||
TestClusteringScenario — clustering + heat elevation via API (T02)
|
||||
TestFulfillmentScenario — desire fulfillment lifecycle (T02)
|
||||
TestMCPFieldPassthrough — MCP tool field passthrough (T02, source-level)
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from sqlalchemy import select, update
|
||||
|
||||
# ── Smoke Test: proves integration infrastructure works ───
|
||||
|
||||
|
||||
class TestInfrastructureSmoke:
    """Verify that the integration test infrastructure (DB, client, auth, Celery mock) works."""

    @pytest.mark.asyncio
    async def test_create_and_read_desire(self, client: AsyncClient):
        """POST a desire, then GET it back — proves DB, serialization, auth override, and Celery mock."""
        # Create.
        response = await client.post(
            "/api/v1/desires",
            json={"prompt_text": "glowing neon wireframe city"},
        )
        assert response.status_code == 201, f"Expected 201, got {response.status_code}: {response.text}"
        data = response.json()
        assert data["prompt_text"] == "glowing neon wireframe city"
        assert data["status"] == "open"
        desire_id = data["id"]

        # Read back.
        response = await client.get(f"/api/v1/desires/{desire_id}")
        assert response.status_code == 200, f"Expected 200, got {response.status_code}: {response.text}"
        data = response.json()
        assert data["id"] == desire_id
        assert data["prompt_text"] == "glowing neon wireframe city"
        assert data["heat_score"] == 1.0
        assert data["cluster_count"] == 0
|
||||
|
||||
|
||||
# ── Clustering Scenario ──────────────────────────────────────
|
||||
|
||||
|
||||
class TestClusteringScenario:
    """Prove that clustered desires have elevated heat and cluster_count via the API.

    Strategy: POST desires through the API, then directly insert DesireCluster
    rows and update heat_score in the test DB (simulating what the Celery worker
    pipeline does). Verify via GET /api/v1/desires/{id} that the API returns
    correct heat_score and cluster_count.

    Note: list_desires uses PostgreSQL ANY(:desire_ids) which doesn't work in
    SQLite, so we verify via individual GET requests instead.
    """

    @pytest.mark.asyncio
    async def test_similar_desires_cluster_and_elevate_heat(
        self, client: AsyncClient, db_session
    ):
        """Create 3 desires, cluster them, elevate heat, verify API returns correct data."""
        from app.models.models import Desire, DesireCluster

        # Create 3 similar desires through the API.
        desire_ids = []
        for prompt in (
            "neon fractal explosion in deep space",
            "colorful fractal burst cosmic background",
            "glowing fractal nova against dark stars",
        ):
            resp = await client.post(
                "/api/v1/desires", json={"prompt_text": prompt}
            )
            assert resp.status_code == 201, f"Create failed: {resp.text}"
            desire_ids.append(resp.json()["id"])

        # Simulate the worker pipeline: link every desire to one shared
        # cluster and bump its heat score.
        cluster_id = uuid.uuid4()
        for did in desire_ids:
            db_session.add(
                DesireCluster(
                    cluster_id=cluster_id,
                    desire_id=uuid.UUID(did),
                    similarity=0.88,
                )
            )
            await db_session.execute(
                update(Desire)
                .where(Desire.id == uuid.UUID(did))
                .values(heat_score=3.0)
            )
        await db_session.flush()

        # Every desire must now report elevated heat and full cluster size.
        for did in desire_ids:
            resp = await client.get(f"/api/v1/desires/{did}")
            assert resp.status_code == 200, f"GET {did} failed: {resp.text}"
            data = resp.json()
            assert data["heat_score"] >= 3.0, (
                f"Desire {did} heat_score={data['heat_score']}, expected >= 3.0"
            )
            assert data["cluster_count"] >= 3, (
                f"Desire {did} cluster_count={data['cluster_count']}, expected >= 3"
            )

    @pytest.mark.asyncio
    async def test_lone_desire_has_default_heat(self, client: AsyncClient):
        """A single desire without clustering has heat_score=1.0 and cluster_count=0."""
        resp = await client.post(
            "/api/v1/desires",
            json={"prompt_text": "unique standalone art concept"},
        )
        assert resp.status_code == 201
        desire_id = resp.json()["id"]

        resp = await client.get(f"/api/v1/desires/{desire_id}")
        assert resp.status_code == 200
        data = resp.json()
        assert data["heat_score"] == 1.0, f"Expected heat_score=1.0, got {data['heat_score']}"
        assert data["cluster_count"] == 0, f"Expected cluster_count=0, got {data['cluster_count']}"

    @pytest.mark.asyncio
    async def test_desires_sorted_by_heat_descending(
        self, client: AsyncClient, db_session
    ):
        """When fetching desires, high-heat desires appear before low-heat ones.

        Uses individual GET since list_desires relies on PostgreSQL ANY().
        Verifies the ordering guarantee via direct heat_score comparison.
        """
        from app.models.models import Desire, DesireCluster

        # A clustered, heat-elevated "hot" desire...
        hot_resp = await client.post(
            "/api/v1/desires",
            json={"prompt_text": "blazing hot fractal vortex"},
        )
        assert hot_resp.status_code == 201
        hot_id = hot_resp.json()["id"]

        db_session.add(
            DesireCluster(
                cluster_id=uuid.uuid4(),
                desire_id=uuid.UUID(hot_id),
                similarity=0.90,
            )
        )
        await db_session.execute(
            update(Desire)
            .where(Desire.id == uuid.UUID(hot_id))
            .values(heat_score=5.0)
        )
        await db_session.flush()

        # ...and an unclustered "cold" one.
        cold_resp = await client.post(
            "/api/v1/desires",
            json={"prompt_text": "calm minimal zen garden"},
        )
        assert cold_resp.status_code == 201
        cold_id = cold_resp.json()["id"]

        # The hot desire must outrank the cold one on heat.
        hot_data = (await client.get(f"/api/v1/desires/{hot_id}")).json()
        cold_data = (await client.get(f"/api/v1/desires/{cold_id}")).json()
        assert hot_data["heat_score"] > cold_data["heat_score"], (
            f"Hot ({hot_data['heat_score']}) should be > Cold ({cold_data['heat_score']})"
        )
|
||||
|
||||
|
||||
# ── Fulfillment Scenario ─────────────────────────────────────
|
||||
|
||||
|
||||
class TestFulfillmentScenario:
    """Prove desire fulfillment transitions status and links to a shader."""

    @staticmethod
    async def _create_desire(client: AsyncClient, prompt: str) -> str:
        """POST a desire through the API and return its id."""
        resp = await client.post(
            "/api/v1/desires", json={"prompt_text": prompt}
        )
        assert resp.status_code == 201
        return resp.json()["id"]

    @staticmethod
    async def _insert_shader(db_session, *, title: str, glsl_code: str, status: str):
        """Insert a Shader row directly into the test DB and return its id."""
        from app.models.models import Shader

        shader_id = uuid.uuid4()
        db_session.add(
            Shader(
                id=shader_id,
                title=title,
                glsl_code=glsl_code,
                status=status,
                author_id=None,
            )
        )
        await db_session.flush()
        return shader_id

    @pytest.mark.asyncio
    async def test_fulfill_desire_transitions_status(
        self, client: AsyncClient, db_session
    ):
        """Create desire, insert published shader, fulfill, verify status transition."""
        desire_id = await self._create_desire(client, "ethereal particle waterfall")
        shader_id = await self._insert_shader(
            db_session,
            title="Particle Waterfall",
            glsl_code="void mainImage(out vec4 c, in vec2 f) { c = vec4(0); }",
            status="published",
        )

        resp = await client.post(
            f"/api/v1/desires/{desire_id}/fulfill",
            params={"shader_id": str(shader_id)},
        )
        assert resp.status_code == 200, f"Fulfill failed: {resp.text}"
        data = resp.json()
        assert data["status"] == "fulfilled"
        assert data["desire_id"] == desire_id
        assert data["shader_id"] == str(shader_id)

        # Read-back reflects the fulfilled state and the linked shader.
        resp = await client.get(f"/api/v1/desires/{desire_id}")
        assert resp.status_code == 200
        data = resp.json()
        assert data["status"] == "fulfilled"
        assert data["fulfilled_by_shader"] == str(shader_id)

    @pytest.mark.asyncio
    async def test_fulfill_requires_published_shader(
        self, client: AsyncClient, db_session
    ):
        """Fulfilling with a draft shader returns 400."""
        desire_id = await self._create_desire(client, "glitch art mosaic pattern")
        shader_id = await self._insert_shader(
            db_session,
            title="Draft Mosaic",
            glsl_code="void mainImage(out vec4 c, in vec2 f) { c = vec4(1); }",
            status="draft",
        )

        resp = await client.post(
            f"/api/v1/desires/{desire_id}/fulfill",
            params={"shader_id": str(shader_id)},
        )
        assert resp.status_code == 400, f"Expected 400, got {resp.status_code}: {resp.text}"
        assert "published" in resp.json()["detail"].lower()

    @pytest.mark.asyncio
    async def test_fulfill_already_fulfilled_returns_400(
        self, client: AsyncClient, db_session
    ):
        """Fulfilling an already-fulfilled desire returns 400."""
        desire_id = await self._create_desire(client, "recursive mirror tunnel")
        shader_id = await self._insert_shader(
            db_session,
            title="Mirror Tunnel",
            glsl_code="void mainImage(out vec4 c, in vec2 f) { c = vec4(0.5); }",
            status="published",
        )

        # First fulfill succeeds.
        resp = await client.post(
            f"/api/v1/desires/{desire_id}/fulfill",
            params={"shader_id": str(shader_id)},
        )
        assert resp.status_code == 200

        # Second fulfill must be rejected.
        resp = await client.post(
            f"/api/v1/desires/{desire_id}/fulfill",
            params={"shader_id": str(shader_id)},
        )
        assert resp.status_code == 400, f"Expected 400, got {resp.status_code}: {resp.text}"
        assert "not open" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
# ── MCP Field Passthrough (source-level) ─────────────────────
|
||||
|
||||
|
||||
class TestMCPFieldPassthrough:
    """Verify MCP server tools pass through all required fields via source inspection.

    The MCP server runs as a separate process and can't be tested through
    FastAPI TestClient. These tests verify the source code structure to ensure
    field passthrough is correct.
    """

    @classmethod
    def _read_mcp_server_source(cls) -> str:
        """Read the MCP server source file and return its text."""
        # From services/api/tests/ → up 3 to services/ → mcp/server.py
        mcp_path = Path(__file__).resolve().parent.parent.parent / "mcp" / "server.py"
        assert mcp_path.exists(), f"MCP server.py not found at {mcp_path}"
        return mcp_path.read_text()

    @staticmethod
    def _extract_fn_body(source: str, fn_name: str) -> str:
        """Return the span of ``async def fn_name`` up to the next top-level @mcp decorator.

        Centralizes the extraction logic that was previously duplicated in each
        test, and fails with a clear AssertionError (instead of a bare
        ``ValueError`` from ``str.index``) when the function is absent.
        """
        marker = f"async def {fn_name}"
        assert marker in source, f"{fn_name} function not found"
        fn_start = source.index(marker)
        # The next top-level "@mcp." decorator (or end of file) bounds the body.
        next_decorator = source.find("\n@mcp.", fn_start + 1)
        if next_decorator == -1:
            return source[fn_start:]
        return source[fn_start:next_decorator]

    def test_get_desire_queue_includes_cluster_fields(self):
        """get_desire_queue maps cluster_count, heat_score, style_hints, fulfilled_by_shader."""
        source = self._read_mcp_server_source()
        fn_body = self._extract_fn_body(source, "get_desire_queue")

        required_fields = ["cluster_count", "heat_score", "style_hints", "fulfilled_by_shader"]
        for field in required_fields:
            assert field in fn_body, (
                f"get_desire_queue missing field '{field}' in response mapping"
            )

    def test_fulfill_desire_tool_exists(self):
        """fulfill_desire function exists and uses api_post_with_params."""
        source = self._read_mcp_server_source()
        fn_body = self._extract_fn_body(source, "fulfill_desire")

        assert "api_post_with_params" in fn_body, (
            "fulfill_desire should call api_post_with_params"
        )

    def test_fulfill_desire_returns_structured_response(self):
        """fulfill_desire returns JSON with status, desire_id, shader_id."""
        source = self._read_mcp_server_source()
        fn_body = self._extract_fn_body(source, "fulfill_desire")

        # Check the success-path return contains the required fields
        required_keys = ['"status"', '"desire_id"', '"shader_id"']
        for key in required_keys:
            assert key in fn_body, f"fulfill_desire response missing key {key}"

    def test_submit_shader_accepts_fulfills_desire_id(self):
        """submit_shader accepts fulfills_desire_id parameter and passes it to the API."""
        source = self._read_mcp_server_source()
        fn_body = self._extract_fn_body(source, "submit_shader")

        # Verify parameter exists in function signature
        assert "fulfills_desire_id" in fn_body, (
            "submit_shader should accept fulfills_desire_id parameter"
        )
        # Verify it's passed to the payload.
        # NOTE(review): the second alternative ('"fulfills_desire_id"' in fn_body)
        # is already implied by the assert above, so this check is mostly
        # redundant — kept to preserve the original behavior.
        assert 'payload["fulfills_desire_id"]' in fn_body or \
            '"fulfills_desire_id"' in fn_body, (
            "submit_shader should include fulfills_desire_id in the API payload"
        )
|
||||
|
|
@ -46,6 +46,11 @@ export default function Bounties() {
|
|||
<span className="flex items-center gap-1">
|
||||
🔥 Heat: {desire.heat_score.toFixed(1)}
|
||||
</span>
|
||||
{desire.cluster_count > 1 && (
|
||||
<span className="text-purple-400">
|
||||
👥 {desire.cluster_count} similar
|
||||
</span>
|
||||
)}
|
||||
{desire.tip_amount_cents > 0 && (
|
||||
<span className="text-green-400">
|
||||
💰 ${(desire.tip_amount_cents / 100).toFixed(2)} tip
|
||||
|
|
|
|||
|
|
@ -49,6 +49,11 @@ export default function BountyDetail() {
|
|||
<h1 className="text-xl font-bold">{desire.prompt_text}</h1>
|
||||
<div className="flex items-center gap-3 mt-3 text-sm text-gray-500">
|
||||
<span>🔥 Heat: {desire.heat_score.toFixed(1)}</span>
|
||||
{desire.cluster_count > 1 && (
|
||||
<span className="text-purple-400">
|
||||
👥 {desire.cluster_count} similar
|
||||
</span>
|
||||
)}
|
||||
{desire.tip_amount_cents > 0 && (
|
||||
<span className="text-green-400">
|
||||
💰 ${(desire.tip_amount_cents / 100).toFixed(2)} bounty
|
||||
|
|
@ -77,7 +82,7 @@ export default function BountyDetail() {
|
|||
|
||||
{desire.status === 'open' && (
|
||||
<div className="mt-6 pt-4 border-t border-surface-3">
|
||||
<Link to="/editor" className="btn-primary">
|
||||
<Link to={`/editor?fulfill=${desire.id}`} className="btn-primary">
|
||||
Fulfill this Desire →
|
||||
</Link>
|
||||
<p className="text-xs text-gray-500 mt-2">
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@
|
|||
*/
|
||||
|
||||
import { useState, useEffect, useCallback, useRef } from 'react';
|
||||
import { useParams, useNavigate } from 'react-router-dom';
|
||||
import { useParams, useNavigate, useSearchParams } from 'react-router-dom';
|
||||
import { useQuery } from '@tanstack/react-query';
|
||||
import api from '@/lib/api';
|
||||
import { useAuthStore } from '@/stores/auth';
|
||||
|
|
@ -38,8 +38,13 @@ void mainImage(out vec4 fragColor, in vec2 fragCoord) {
|
|||
export default function Editor() {
|
||||
const { id } = useParams<{ id: string }>();
|
||||
const navigate = useNavigate();
|
||||
const [searchParams] = useSearchParams();
|
||||
const { isAuthenticated, user } = useAuthStore();
|
||||
|
||||
// Fulfillment context — read once from URL, persist in ref so it survives navigation
|
||||
const fulfillId = searchParams.get('fulfill');
|
||||
const fulfillDesireId = useRef(fulfillId);
|
||||
|
||||
const [code, setCode] = useState(DEFAULT_SHADER);
|
||||
const [liveCode, setLiveCode] = useState(DEFAULT_SHADER);
|
||||
const [title, setTitle] = useState('Untitled Shader');
|
||||
|
|
@ -70,6 +75,16 @@ export default function Editor() {
|
|||
enabled: !!id,
|
||||
});
|
||||
|
||||
// Fetch desire context when fulfilling
|
||||
const { data: fulfillDesire } = useQuery({
|
||||
queryKey: ['desire', fulfillDesireId.current],
|
||||
queryFn: async () => {
|
||||
const { data } = await api.get(`/desires/${fulfillDesireId.current}`);
|
||||
return data;
|
||||
},
|
||||
enabled: !!fulfillDesireId.current,
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (existingShader) {
|
||||
setCode(existingShader.glsl_code);
|
||||
|
|
@ -160,7 +175,10 @@ export default function Editor() {
|
|||
}
|
||||
} else {
|
||||
// Create new shader
|
||||
const { data } = await api.post('/shaders', payload);
|
||||
const { data } = await api.post('/shaders', {
|
||||
...payload,
|
||||
fulfills_desire_id: fulfillDesireId.current || undefined,
|
||||
});
|
||||
if (publishStatus === 'published') {
|
||||
navigate(`/shader/${data.id}`);
|
||||
} else {
|
||||
|
|
@ -278,6 +296,17 @@ export default function Editor() {
|
|||
</div>
|
||||
)}
|
||||
|
||||
{/* Desire fulfillment context banner */}
|
||||
{fulfillDesire && (
|
||||
<div className="px-4 py-3 bg-amber-600/10 border-b border-amber-600/20 flex items-center gap-3">
|
||||
<span className="text-amber-400 text-sm font-medium">🎯 Fulfilling desire:</span>
|
||||
<span className="text-gray-300 text-sm flex-1">{fulfillDesire.prompt_text}</span>
|
||||
{fulfillDesire.style_hints && (
|
||||
<span className="text-xs text-gray-500">Style hints available</span>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Split pane: editor + drag handle + preview */}
|
||||
<div ref={containerRef} className="flex-1 flex min-h-0">
|
||||
{/* Code editor */}
|
||||
|
|
|
|||
|
|
@ -41,6 +41,14 @@ async def api_post(path: str, data: dict):
|
|||
return resp.json()
|
||||
|
||||
|
||||
async def api_post_with_params(path: str, params: dict):
    """POST with query parameters (not JSON body). Used for endpoints like fulfill."""
    url = f"/api/v1{path}"
    async with httpx.AsyncClient(base_url=API_BASE, timeout=15.0) as http:
        response = await http.post(url, params=params, headers=INTERNAL_AUTH)
        # Raise httpx.HTTPStatusError on 4xx/5xx so callers can surface API errors.
        response.raise_for_status()
        return response.json()
|
||||
|
||||
|
||||
async def api_put(path: str, data: dict):
|
||||
async with httpx.AsyncClient(base_url=API_BASE, timeout=15.0) as client:
|
||||
resp = await client.put(f"/api/v1{path}", json=data, headers=INTERNAL_AUTH)
|
||||
|
|
@ -121,7 +129,8 @@ async def get_shader_version_code(shader_id: str, version_number: int) -> str:
|
|||
|
||||
@mcp.tool()
|
||||
async def submit_shader(title: str, glsl_code: str, description: str = "", tags: str = "",
|
||||
shader_type: str = "2d", status: str = "published") -> str:
|
||||
shader_type: str = "2d", status: str = "published",
|
||||
fulfills_desire_id: str = "") -> str:
|
||||
"""Submit a new GLSL shader to Fractafrag.
|
||||
|
||||
Shader format: void mainImage(out vec4 fragColor, in vec2 fragCoord)
|
||||
|
|
@ -134,11 +143,15 @@ async def submit_shader(title: str, glsl_code: str, description: str = "", tags:
|
|||
tags: Comma-separated tags (e.g. "fractal,noise,colorful")
|
||||
shader_type: 2d, 3d, or audio-reactive
|
||||
status: "published" to go live, "draft" to save privately
|
||||
fulfills_desire_id: Optional UUID of a desire this shader fulfills
|
||||
"""
|
||||
tag_list = [t.strip() for t in tags.split(",") if t.strip()] if tags else []
|
||||
result = await api_post("/shaders", {"title": title, "glsl_code": glsl_code,
|
||||
payload = {"title": title, "glsl_code": glsl_code,
|
||||
"description": description, "tags": tag_list,
|
||||
"shader_type": shader_type, "status": status})
|
||||
"shader_type": shader_type, "status": status}
|
||||
if fulfills_desire_id:
|
||||
payload["fulfills_desire_id"] = fulfills_desire_id
|
||||
result = await api_post("/shaders", payload)
|
||||
return json.dumps({"id": result["id"], "title": result["title"],
|
||||
"status": result.get("status"), "current_version": result.get("current_version", 1),
|
||||
"message": f"Shader '{result['title']}' created.", "url": f"/shader/{result['id']}"})
|
||||
|
|
@ -210,7 +223,11 @@ async def get_similar_shaders(shader_id: str, limit: int = 10) -> str:
|
|||
|
||||
@mcp.tool()
|
||||
async def get_desire_queue(min_heat: float = 0, limit: int = 10) -> str:
|
||||
"""Get open shader desires/bounties. These are community requests.
|
||||
"""Get open shader desires/bounties with cluster context and style hints.
|
||||
|
||||
Returns community requests ranked by heat. Use cluster_count to identify
|
||||
high-demand desires (many similar requests). Use style_hints to understand
|
||||
the visual direction requested.
|
||||
|
||||
Args:
|
||||
min_heat: Minimum heat score (higher = more demand)
|
||||
|
|
@ -220,11 +237,42 @@ async def get_desire_queue(min_heat: float = 0, limit: int = 10) -> str:
|
|||
return json.dumps({"count": len(desires),
|
||||
"desires": [{"id": d["id"], "prompt_text": d["prompt_text"],
|
||||
"heat_score": d.get("heat_score", 0),
|
||||
"cluster_count": d.get("cluster_count", 0),
|
||||
"style_hints": d.get("style_hints"),
|
||||
"tip_amount_cents": d.get("tip_amount_cents", 0),
|
||||
"status": d.get("status")}
|
||||
"status": d.get("status"),
|
||||
"fulfilled_by_shader": d.get("fulfilled_by_shader")}
|
||||
for d in desires]})
|
||||
|
||||
|
||||
@mcp.tool()
async def fulfill_desire(desire_id: str, shader_id: str) -> str:
    """Mark a desire as fulfilled by linking it to a published shader.

    The shader must be published. The desire must be open.
    Use get_desire_queue to find open desires, then submit_shader or
    use an existing shader ID to fulfill one.

    Args:
        desire_id: UUID of the desire to fulfill
        shader_id: UUID of the published shader that fulfills this desire

    Returns:
        JSON string — on success: {"status", "desire_id", "shader_id", "message"};
        on API error: {"error", "status_code"}.
    """
    try:
        # The fulfill endpoint takes shader_id as a query parameter, not a JSON body.
        # (The original assigned the response to an unused `result` local.)
        await api_post_with_params(
            f"/desires/{desire_id}/fulfill",
            {"shader_id": shader_id},
        )
        return json.dumps({
            "status": "fulfilled",
            "desire_id": desire_id,
            "shader_id": shader_id,
            "message": f"Desire {desire_id} fulfilled by shader {shader_id}.",
        })
    except httpx.HTTPStatusError as e:
        # Surface the API's error detail as a structured payload instead of
        # letting the exception propagate into the MCP runtime.
        try:
            error_detail = e.response.json().get("detail", str(e))
        except Exception:
            error_detail = str(e)
        return json.dumps({"error": error_detail, "status_code": e.response.status_code})
|
||||
|
||||
|
||||
@mcp.resource("fractafrag://platform-info")
|
||||
def platform_info() -> str:
|
||||
"""Platform overview and shader writing guidelines."""
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue