feat(M001): Desire Economy

Completed slices:
- S01: Desire Embedding & Clustering
- S02: Fulfillment Flow & Frontend

Branch: milestone/M001
This commit is contained in:
John Lightner 2026-03-25 02:22:50 -05:00
parent a5f0c0e093
commit 5936ab167e
19 changed files with 2612 additions and 19 deletions

28
Makefile Normal file
View file

@ -0,0 +1,28 @@
# Fractafrag — Docker Compose monorepo
# Common development commands
.PHONY: up down build logs test api-shell worker-shell db-shell

# Start the full stack in the background.
up:
	docker compose up -d

# Stop and remove all containers.
down:
	docker compose down

# Rebuild service images.
build:
	docker compose build

# Follow logs from all services.
logs:
	docker compose logs -f

# Run the API test suite inside the api container.
test:
	docker compose exec api python -m pytest tests/ -v

# Interactive shells for debugging.
api-shell:
	docker compose exec api bash

worker-shell:
	docker compose exec worker bash

# psql session against the app database.
db-shell:
	docker compose exec postgres psql -U fracta -d fractafrag

6
services/api/=0.20.0 Normal file
View file

@ -0,0 +1,6 @@
Defaulting to user installation because normal site-packages is not writeable
Collecting aiosqlite
Using cached aiosqlite-0.22.1-py3-none-any.whl.metadata (4.3 kB)
Using cached aiosqlite-0.22.1-py3-none-any.whl (17 kB)
Installing collected packages: aiosqlite
Successfully installed aiosqlite-0.22.1

View file

@ -3,10 +3,10 @@
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy import select, text
from app.database import get_db
from app.models import User, Desire
from app.models import User, Desire, Shader
from app.schemas import DesireCreate, DesirePublic
from app.middleware.auth import get_current_user, require_tier
@ -29,7 +29,24 @@ async def list_desires(
query = query.order_by(Desire.heat_score.desc()).limit(limit).offset(offset)
result = await db.execute(query)
return result.scalars().all()
desires = list(result.scalars().all())
# Batch-annotate cluster_count to avoid N+1 queries
desire_ids = [d.id for d in desires]
if desire_ids:
cluster_query = text("""
SELECT dc1.desire_id, COUNT(dc2.desire_id) as cluster_count
FROM desire_clusters dc1
JOIN desire_clusters dc2 ON dc1.cluster_id = dc2.cluster_id
WHERE dc1.desire_id = ANY(:desire_ids)
GROUP BY dc1.desire_id
""")
cluster_result = await db.execute(cluster_query, {"desire_ids": desire_ids})
cluster_counts = {row[0]: row[1] for row in cluster_result}
for d in desires:
d.cluster_count = cluster_counts.get(d.id, 0)
return desires
@router.get("/{desire_id}", response_model=DesirePublic)
@ -38,6 +55,18 @@ async def get_desire(desire_id: UUID, db: AsyncSession = Depends(get_db)):
desire = result.scalar_one_or_none()
if not desire:
raise HTTPException(status_code=404, detail="Desire not found")
# Annotate cluster_count for single desire
cluster_query = text("""
SELECT COUNT(dc2.desire_id) as cluster_count
FROM desire_clusters dc1
JOIN desire_clusters dc2 ON dc1.cluster_id = dc2.cluster_id
WHERE dc1.desire_id = :desire_id
""")
cluster_result = await db.execute(cluster_query, {"desire_id": desire_id})
row = cluster_result.first()
desire.cluster_count = row[0] if row else 0
return desire
@ -55,9 +84,9 @@ async def create_desire(
db.add(desire)
await db.flush()
# TODO: Embed prompt text (Track G)
# TODO: Check similarity clustering (Track G)
# TODO: Enqueue process_desire worker job (Track G)
# Fire-and-forget: enqueue embedding + clustering worker task
from app.worker import process_desire
process_desire.delay(str(desire.id))
return desire
@ -76,6 +105,13 @@ async def fulfill_desire(
if desire.status != "open":
raise HTTPException(status_code=400, detail="Desire is not open")
# Validate shader exists and is published
shader = (await db.execute(select(Shader).where(Shader.id == shader_id))).scalar_one_or_none()
if not shader:
raise HTTPException(status_code=404, detail="Shader not found")
if shader.status != "published":
raise HTTPException(status_code=400, detail="Shader must be published to fulfill a desire")
from datetime import datetime, timezone
desire.status = "fulfilled"
desire.fulfilled_by_shader = shader_id

View file

@ -184,6 +184,7 @@ class DesirePublic(BaseModel):
tip_amount_cents: int
status: str
heat_score: float
cluster_count: int = 0
fulfilled_by_shader: Optional[UUID]
fulfilled_at: Optional[datetime]
created_at: datetime

View file

@ -0,0 +1,406 @@
"""Clustering service — pgvector cosine similarity search and heat calculation.
Groups desires into clusters based on prompt embedding similarity.
Uses pgvector's <=> cosine distance operator for nearest-neighbor search.
Heat scores scale linearly with cluster size (more demand more visible).
Provides both async (for FastAPI) and sync (for Celery worker) variants.
"""
import logging
import uuid as uuid_mod
from uuid import UUID
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.models.models import DesireCluster
logger = logging.getLogger(__name__)
async def find_nearest_cluster(
    embedding: list[float],
    db: AsyncSession,
    threshold: float = 0.82,
) -> tuple[UUID | None, float]:
    """Find the nearest existing desire cluster for an embedding vector.

    Uses pgvector cosine distance (<=> operator). A threshold of 0.82 means
    cosine_similarity >= 0.82, i.e., cosine_distance <= 0.18.

    Only desires that already belong to a cluster are considered. This is
    essential: the worker flushes the new desire's embedding to the DB
    *before* clustering runs, so a bare nearest-neighbor search would
    always match the desire against itself at distance 0 — its cluster
    lookup would then fail and every desire would found its own cluster.
    The JOIN excludes the not-yet-clustered query desire and also removes
    the second round-trip of the previous two-query implementation.

    Args:
        embedding: Query vector (expected 512-dim, L2-normalized).
        db: Async database session.
        threshold: Minimum cosine similarity to count as a match.

    Returns:
        (cluster_id, similarity) if a clustered desire is within threshold,
        (None, 0.0) if no match exists.
    """
    max_distance = 1.0 - threshold
    # Raw SQL for pgvector cosine distance — SQLAlchemy ORM doesn't natively
    # support the <=> operator without extra configuration.
    query = text("""
        SELECT dc.cluster_id AS cluster_id,
               d.id AS desire_id,
               (d.prompt_embedding <=> :embedding) AS distance
        FROM desires d
        JOIN desire_clusters dc ON dc.desire_id = d.id
        WHERE d.prompt_embedding IS NOT NULL
          AND (d.prompt_embedding <=> :embedding) <= :max_distance
        ORDER BY distance ASC
        LIMIT 1
    """)
    result = await db.execute(
        query,
        {"embedding": str(embedding), "max_distance": max_distance},
    )
    row = result.first()
    if row is None:
        logger.debug("No nearby cluster found (threshold=%.2f)", threshold)
        return (None, 0.0)
    similarity = 1.0 - row.distance
    logger.info(
        "Found nearby cluster %s via desire %s (similarity=%.3f)",
        row.cluster_id,
        row.desire_id,
        similarity,
    )
    return (row.cluster_id, similarity)
async def create_cluster(desire_id: UUID, db: AsyncSession) -> UUID:
    """Start a brand-new cluster whose only member is *desire_id*.

    The founding member is recorded with similarity 1.0 (identical to
    itself). Returns the freshly generated cluster_id.
    """
    new_cluster_id = uuid_mod.uuid4()
    db.add(
        DesireCluster(
            cluster_id=new_cluster_id,
            desire_id=desire_id,
            similarity=1.0,
        )
    )
    await db.flush()
    logger.info("Created new cluster %s for desire %s", new_cluster_id, desire_id)
    return new_cluster_id
async def add_to_cluster(
    cluster_id: UUID,
    desire_id: UUID,
    similarity: float,
    db: AsyncSession,
) -> None:
    """Record *desire_id* as a member of an existing cluster.

    The INSERT uses ON CONFLICT DO NOTHING so re-processing the same desire
    is idempotent — a duplicate membership row is never created.
    """
    stmt = text("""
        INSERT INTO desire_clusters (cluster_id, desire_id, similarity)
        VALUES (:cluster_id, :desire_id, :similarity)
        ON CONFLICT (cluster_id, desire_id) DO NOTHING
    """)
    params = {
        "cluster_id": cluster_id,
        "desire_id": desire_id,
        "similarity": similarity,
    }
    await db.execute(stmt, params)
    await db.flush()
    logger.info(
        "Added desire %s to cluster %s (similarity=%.3f)",
        desire_id,
        cluster_id,
        similarity,
    )
async def recalculate_heat(cluster_id: UUID, db: AsyncSession) -> float:
    """Recompute heat scores for every desire in a cluster.

    Heat scales linearly with cluster size: each member of an n-desire
    cluster receives heat_score = float(n).

    Returns the new heat score.
    """
    # Cluster size is the sole input to the score.
    size_result = await db.execute(
        text("""
        SELECT COUNT(*) AS cnt FROM desire_clusters
        WHERE cluster_id = :cluster_id
    """),
        {"cluster_id": cluster_id},
    )
    member_count = size_result.scalar_one()
    new_heat = float(member_count)

    # Propagate the score to every member desire in a single UPDATE.
    await db.execute(
        text("""
        UPDATE desires SET heat_score = :heat
        WHERE id IN (
            SELECT desire_id FROM desire_clusters
            WHERE cluster_id = :cluster_id
        )
    """),
        {"heat": new_heat, "cluster_id": cluster_id},
    )
    await db.flush()

    logger.info(
        "Recalculated heat for cluster %s: size=%d, heat_score=%.1f",
        cluster_id,
        member_count,
        new_heat,
    )
    return new_heat
async def cluster_desire(
    desire_id: UUID,
    embedding: list[float],
    db: AsyncSession,
) -> dict:
    """Orchestrate clustering for a single desire.

    Either joins the nearest existing cluster (within the similarity
    threshold) and refreshes its heat, or founds a new single-member
    cluster when nothing is close enough.

    Returns an observability dict:
        {"cluster_id": UUID, "is_new": bool, "heat_score": float}
    """
    matched_cluster, similarity = await find_nearest_cluster(embedding, db)

    if matched_cluster is None:
        # No sufficiently similar cluster exists — found a new one.
        new_cluster = await create_cluster(desire_id, db)
        logger.info(
            "Desire %s started new cluster %s (heat=1.0)",
            desire_id,
            new_cluster,
        )
        return {
            "cluster_id": new_cluster,
            "is_new": True,
            "heat_score": 1.0,
        }

    # Join the existing cluster and recompute heat for all its members.
    await add_to_cluster(matched_cluster, desire_id, similarity, db)
    heat = await recalculate_heat(matched_cluster, db)
    logger.info(
        "Desire %s joined cluster %s (similarity=%.3f, heat=%.1f)",
        desire_id,
        matched_cluster,
        similarity,
        heat,
    )
    return {
        "cluster_id": matched_cluster,
        "is_new": False,
        "heat_score": heat,
    }
# ── Synchronous variants (for Celery worker context) ─────────────────────
def find_nearest_cluster_sync(
    embedding: list[float],
    session: Session,
    threshold: float = 0.82,
) -> tuple[UUID | None, float]:
    """Sync variant of find_nearest_cluster for Celery worker context.

    Only desires that already belong to a cluster are considered: the
    worker flushes the new desire's embedding *before* clustering, so a
    bare nearest-neighbor search would always match the desire against
    itself at distance 0 and clustering could never join an existing
    cluster. The JOIN excludes the not-yet-clustered query desire and
    collapses the old two-query lookup into one round-trip.

    Returns:
        (cluster_id, similarity) if a clustered desire is within threshold,
        (None, 0.0) otherwise.
    """
    max_distance = 1.0 - threshold
    query = text("""
        SELECT dc.cluster_id AS cluster_id,
               d.id AS desire_id,
               (d.prompt_embedding <=> :embedding) AS distance
        FROM desires d
        JOIN desire_clusters dc ON dc.desire_id = d.id
        WHERE d.prompt_embedding IS NOT NULL
          AND (d.prompt_embedding <=> :embedding) <= :max_distance
        ORDER BY distance ASC
        LIMIT 1
    """)
    result = session.execute(
        query,
        {"embedding": str(embedding), "max_distance": max_distance},
    )
    row = result.first()
    if row is None:
        logger.debug("No nearby cluster found (threshold=%.2f)", threshold)
        return (None, 0.0)
    similarity = 1.0 - row.distance
    logger.info(
        "Found nearby cluster %s via desire %s (similarity=%.3f)",
        row.cluster_id,
        row.desire_id,
        similarity,
    )
    return (row.cluster_id, similarity)
def create_cluster_sync(desire_id: UUID, session: Session) -> UUID:
    """Sync variant of create_cluster for Celery worker context."""
    new_cluster_id = uuid_mod.uuid4()
    session.add(
        DesireCluster(
            cluster_id=new_cluster_id,
            desire_id=desire_id,
            similarity=1.0,
        )
    )
    session.flush()
    logger.info("Created new cluster %s for desire %s", new_cluster_id, desire_id)
    return new_cluster_id
def add_to_cluster_sync(
    cluster_id: UUID,
    desire_id: UUID,
    similarity: float,
    session: Session,
) -> None:
    """Sync variant of add_to_cluster for Celery worker context.

    Idempotent: ON CONFLICT DO NOTHING means re-processing a desire never
    duplicates its membership row.
    """
    stmt = text("""
        INSERT INTO desire_clusters (cluster_id, desire_id, similarity)
        VALUES (:cluster_id, :desire_id, :similarity)
        ON CONFLICT (cluster_id, desire_id) DO NOTHING
    """)
    params = {
        "cluster_id": cluster_id,
        "desire_id": desire_id,
        "similarity": similarity,
    }
    session.execute(stmt, params)
    session.flush()
    logger.info(
        "Added desire %s to cluster %s (similarity=%.3f)",
        desire_id,
        cluster_id,
        similarity,
    )
def recalculate_heat_sync(cluster_id: UUID, session: Session) -> float:
    """Sync variant of recalculate_heat for Celery worker context.

    Heat = cluster size (linear). Returns the new heat score.
    """
    size_result = session.execute(
        text("""
        SELECT COUNT(*) AS cnt FROM desire_clusters
        WHERE cluster_id = :cluster_id
    """),
        {"cluster_id": cluster_id},
    )
    member_count = size_result.scalar_one()
    new_heat = float(member_count)

    # Apply the score to all member desires in one statement.
    session.execute(
        text("""
        UPDATE desires SET heat_score = :heat
        WHERE id IN (
            SELECT desire_id FROM desire_clusters
            WHERE cluster_id = :cluster_id
        )
    """),
        {"heat": new_heat, "cluster_id": cluster_id},
    )
    session.flush()

    logger.info(
        "Recalculated heat for cluster %s: size=%d, heat_score=%.1f",
        cluster_id,
        member_count,
        new_heat,
    )
    return new_heat
def cluster_desire_sync(
    desire_id: UUID,
    embedding: list[float],
    session: Session,
) -> dict:
    """Sync orchestrator for clustering a single desire (Celery worker context).

    Same flow as the async cluster_desire but on a synchronous Session.

    Returns:
        {"cluster_id": UUID, "is_new": bool, "heat_score": float}
    """
    matched_cluster, similarity = find_nearest_cluster_sync(embedding, session)

    if matched_cluster is None:
        # Nothing close enough — found a new single-member cluster.
        new_cluster = create_cluster_sync(desire_id, session)
        logger.info(
            "Desire %s started new cluster %s (heat=1.0)",
            desire_id,
            new_cluster,
        )
        return {
            "cluster_id": new_cluster,
            "is_new": True,
            "heat_score": 1.0,
        }

    # Join the matched cluster and refresh heat for every member.
    add_to_cluster_sync(matched_cluster, desire_id, similarity, session)
    heat = recalculate_heat_sync(matched_cluster, session)
    logger.info(
        "Desire %s joined cluster %s (similarity=%.3f, heat=%.1f)",
        desire_id,
        matched_cluster,
        similarity,
        heat,
    )
    return {
        "cluster_id": matched_cluster,
        "is_new": False,
        "heat_score": heat,
    }

View file

@ -0,0 +1,291 @@
"""Text embedding service using TF-IDF + TruncatedSVD.
Converts desire prompt text into 512-dimensional dense vectors suitable
for pgvector cosine similarity search. Pre-seeded with shader/visual-art
domain vocabulary so the model produces meaningful vectors even from a
single short text input.
"""
import logging
import time
import numpy as np
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer
logger = logging.getLogger(__name__)
# Shader / visual-art domain seed corpus.
# Gives the TF-IDF model a vocabulary foundation so it can produce
# meaningful vectors from short creative text descriptions.
# Target: 500+ unique TF-IDF features (unigrams + bigrams) to support
# near-512 SVD components without heavy padding.
# Each entry is a short space-separated bag of related terms; together they
# seed the TF-IDF vocabulary (unigrams + bigrams) before any real input.
_SEED_CORPUS: list[str] = [
    "particle system fluid simulation dynamics motion",
    "raymarching signed distance field sdf shapes volumes",
    "procedural noise fractal pattern generation recursive",
    "color palette gradient blend interpolation smooth",
    "audio reactive frequency spectrum visualization beat",
    "ragdoll physics dark moody atmosphere heavy",
    "kaleidoscope symmetry rotation mirror reflection",
    "voronoi cellular texture organic growth biological",
    "bloom glow post processing effect luminance",
    "retro pixel art scanlines crt monitor vintage",
    "geometry morphing vertex displacement deformation mesh",
    "wave propagation ripple interference oscillation",
    "fire smoke volumetric rendering density fog",
    "crystal refraction caustics light transparency",
    "terrain heightmap erosion landscape mountain valley",
    "shader glitch distortion databend corruption artifact",
    "feedback loop recursive transformation iteration",
    "physics collision rigid body dynamics impulse",
    "abstract minimal geometric composition shape",
    "aurora borealis atmospheric optical phenomenon sky",
    "underwater caustics god rays depth ocean",
    "cyberpunk neon wireframe grid futuristic urban",
    "organic growth branching lsystem tree root",
    "mandelbrot julia set fractal zoom iteration",
    "cloth simulation soft body drape fabric textile",
    "dissolve transition threshold mask binary cutoff",
    "chromatic aberration lens distortion optical warp",
    "shadow mapping ambient occlusion darkness depth",
    "motion blur temporal accumulation streak trail",
    "boids flocking swarm emergence collective behavior",
    "reaction diffusion turing pattern spots stripes",
    "perlin simplex worley noise texture procedural",
    "voxel rendering isometric cube block pixel",
    "gravity orbital celestial mechanics planet orbit",
    "psychedelic trippy color shift hue rotation",
    "spiral fibonacci golden ratio mathematical curve",
    "explosion debris shatter fragment destruction impact",
    "rain snow weather precipitation droplet splash",
    "electric lightning bolt plasma energy discharge",
    "tunnel infinite corridor perspective vanishing point",
    "metaball blob isosurface marching cubes implicit",
    "starfield galaxy nebula cosmic space stellar",
    "shadow puppet silhouette outline contour edge",
    "mosaic tessellation tile pattern hexagonal grid",
    "hologram iridescent spectrum rainbow interference",
    "ink watercolor paint brush stroke artistic",
    "sand dune desert wind erosion granular",
    "ice frost frozen crystal snowflake cold",
    "magma lava volcanic molten heat flow",
    "laser beam scanning projection line vector",
    "DNA helix molecular biology strand protein",
    "circuit board electronic trace signal digital",
    "camouflage pattern dithering halftone screening dots",
    "waveform synthesizer oscillator modulation frequency",
    "topographic contour map elevation isoline level",
    "origami fold paper crease geometric angular",
    "stained glass window colorful segmented panels",
    "smoke ring vortex toroidal turbulence curl",
    "pendulum harmonic oscillation swing periodic cycle",
    "cloud formation cumulus atmospheric convection wispy",
    "ripple pond surface tension concentric circular",
    "decay rust corrosion entropy degradation aging",
    "fiber optic strand luminous filament glow",
    "prism dispersion spectral separation wavelength band",
    "radar sonar ping pulse echo scanning",
    "compass rose navigation cardinal directional symbol",
    "clock mechanism gear cog rotation mechanical",
    "barcode matrix encoding data stripe identification",
    "fingerprint unique biometric whorl ridge pattern",
    "maze labyrinth path algorithm recursive backtrack",
    "chess checkerboard alternating square pattern grid",
    "domino cascade chain sequential trigger falling",
    "balloon inflation expansion pressure sphere elastic",
    "ribbon flowing fabric curve spline bezier",
    "confetti celebration scatter random distribution joyful",
    "ember spark ignition tiny particle hot",
    "bubble foam soap iridescent sphere surface tension",
    "whirlpool maelstrom spinning vortex fluid drain",
    "mirage shimmer heat haze atmospheric refraction",
    "echo reverberation delay repetition fading diminish",
    "pulse heartbeat rhythmic expanding ring concentric",
    "weave interlocking thread textile warp weft",
    "honeycomb hexagonal efficient packing natural structure",
    "coral reef branching organic marine growth colony",
    "mushroom spore cap stem fungal network mycelium",
    "neuron synapse network brain signal impulse dendrite",
    "constellation star connect dot line celestial chart",
    "seismograph earthquake wave amplitude vibration tremor",
    "aurora curtain charged particle magnetic solar wind",
    "tidal wave surge ocean force gravitational pull",
    "sandstorm particle erosion wind desert visibility",
    "volcanic eruption ash plume pyroclastic flow magma",
]
class EmbeddingService:
    """Produces 512-dim L2-normalized vectors from text using TF-IDF + SVD.

    The model is lazily fitted on the seed corpus on first use. Out-of-
    vocabulary (OOV) inputs fall back to a deterministic random unit
    vector derived from a *stable* hash of the text, so results are
    reproducible across processes and restarts.
    """

    def __init__(self) -> None:
        # Unigrams + bigrams; cap vocabulary so fit stays bounded.
        self._vectorizer = TfidfVectorizer(
            ngram_range=(1, 2),
            max_features=10_000,
            stop_words="english",
        )
        self._svd = TruncatedSVD(n_components=512, random_state=42)
        self._corpus: list[str] = list(_SEED_CORPUS)
        self._fitted: bool = False

    def _fit_if_needed(self) -> None:
        """Fit the vectorizer and SVD on the seed corpus if not yet fitted."""
        if self._fitted:
            return
        tfidf_matrix = self._vectorizer.fit_transform(self._corpus)
        # SVD n_components must be < number of features in the TF-IDF matrix.
        # If the corpus is too small, reduce SVD components temporarily.
        n_features = tfidf_matrix.shape[1]
        if n_features < self._svd.n_components:
            logger.warning(
                "TF-IDF produced only %d features, reducing SVD components "
                "from %d to %d",
                n_features,
                self._svd.n_components,
                n_features - 1,
            )
            self._svd = TruncatedSVD(
                n_components=n_features - 1, random_state=42
            )
        self._svd.fit(tfidf_matrix)
        self._fitted = True

    @staticmethod
    def _stable_seed(raw_text: str) -> int:
        """Deterministic 32-bit RNG seed for the OOV fallback vector.

        The built-in hash() is salted per process (PYTHONHASHSEED), so
        seeding with it would give the same text a different fallback
        vector after every worker restart — breaking reproducibility and
        clustering consistency. CRC32 is stable across runs and fits the
        RandomState seed range [0, 2**32).
        """
        import zlib  # local import: only needed on the rare OOV path

        return zlib.crc32(raw_text.encode("utf-8"))

    def embed_text(self, text: str) -> list[float]:
        """Embed a single text into a 512-dim L2-normalized vector.

        Args:
            text: Input text to embed.

        Returns:
            List of 512 floats (L2-normalized).

        Raises:
            ValueError: If text is empty or whitespace-only.
        """
        if not text or not text.strip():
            raise ValueError("Cannot embed empty or whitespace-only text")
        start = time.perf_counter()
        self._fit_if_needed()
        tfidf = self._vectorizer.transform([text])
        svd_vec = self._svd.transform(tfidf)[0]
        # L2 normalize — handle zero vectors from OOV text
        norm = np.linalg.norm(svd_vec)
        if norm > 1e-10:
            svd_vec = svd_vec / norm
        else:
            # Text produced no recognized vocabulary — generate a
            # deterministic low-magnitude vector from a stable text hash
            # so the vector is non-zero but won't cluster with anything.
            logger.warning(
                "Text produced zero TF-IDF vector (no recognized vocabulary): "
                "'%s'",
                text[:80],
            )
            rng = np.random.RandomState(self._stable_seed(text))
            svd_vec = rng.randn(len(svd_vec))
            svd_vec = svd_vec / np.linalg.norm(svd_vec)
        # Pad to 512 if SVD produced fewer components
        if len(svd_vec) < 512:
            padded = np.zeros(512)
            padded[: len(svd_vec)] = svd_vec
            # Re-normalize after padding
            pad_norm = np.linalg.norm(padded)
            if pad_norm > 0:
                padded = padded / pad_norm
            svd_vec = padded
        elapsed_ms = (time.perf_counter() - start) * 1000
        logger.info(
            "Embedded text (%d chars) → %d-dim vector in %.1fms",
            len(text),
            len(svd_vec),
            elapsed_ms,
        )
        return svd_vec.tolist()

    def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Embed multiple texts at once.

        Args:
            texts: List of input texts.

        Returns:
            List of 512-dim L2-normalized float lists.

        Raises:
            ValueError: If any text is empty or whitespace-only.
        """
        # Validate all inputs before any work so the batch is all-or-nothing.
        for i, text in enumerate(texts):
            if not text or not text.strip():
                raise ValueError(
                    f"Cannot embed empty or whitespace-only text at index {i}"
                )
        start = time.perf_counter()
        self._fit_if_needed()
        tfidf = self._vectorizer.transform(texts)
        svd_vecs = self._svd.transform(tfidf)
        results: list[list[float]] = []
        for idx, vec in enumerate(svd_vecs):
            norm = np.linalg.norm(vec)
            if norm > 1e-10:
                vec = vec / norm
            else:
                logger.warning(
                    "Text at index %d produced zero TF-IDF vector: '%s'",
                    idx,
                    texts[idx][:80],
                )
                # Stable seed (not built-in hash) keeps the fallback
                # deterministic across worker processes.
                rng = np.random.RandomState(self._stable_seed(texts[idx]))
                vec = rng.randn(len(vec))
                vec = vec / np.linalg.norm(vec)
            if len(vec) < 512:
                padded = np.zeros(512)
                padded[: len(vec)] = vec
                pad_norm = np.linalg.norm(padded)
                if pad_norm > 0:
                    padded = padded / pad_norm
                vec = padded
            results.append(vec.tolist())
        elapsed_ms = (time.perf_counter() - start) * 1000
        logger.info(
            "Batch-embedded %d texts → %d-dim vectors in %.1fms",
            len(texts),
            512,
            elapsed_ms,
        )
        return results
# Module-level singleton — fitted lazily on first embed call, then shared
# by all callers in this process.
embedding_service = EmbeddingService()


def embed_text(text: str) -> list[float]:
    """Embed a single text into a 512-dim normalized vector.

    Convenience wrapper around the module-level EmbeddingService singleton.
    """
    return embedding_service.embed_text(text)


def embed_batch(texts: list[str]) -> list[list[float]]:
    """Embed multiple texts into 512-dim normalized vectors.

    Convenience wrapper around the module-level EmbeddingService singleton.
    """
    return embedding_service.embed_batch(texts)

View file

@ -1,8 +1,15 @@
"""Fractafrag — Celery worker configuration."""
import logging
import time
import uuid as uuid_mod
from celery import Celery
import os
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
redis_url = os.environ.get("REDIS_URL", "redis://redis:6379/0")
celery_app = Celery(
@ -28,6 +35,19 @@ celery_app.conf.update(
celery_app.autodiscover_tasks(["app.worker"])
# ── Sync DB session factory for worker tasks ──────────────
def _get_sync_session_factory():
    """Lazy-init and cache the sync session factory for worker tasks.

    The engine and sessionmaker are created once and memoized on the
    function object: the previous implementation built a brand-new Engine
    (and connection pool) on *every* task invocation, leaking pools and
    defeating connection pooling entirely.
    """
    factory = getattr(_get_sync_session_factory, "_factory", None)
    if factory is None:
        # Import here to avoid settings being read at module import time.
        from app.config import get_settings

        settings = get_settings()
        engine = create_engine(settings.database_url_sync, pool_pre_ping=True)
        factory = sessionmaker(bind=engine)
        _get_sync_session_factory._factory = factory
    return factory
logger = logging.getLogger(__name__)
# ── Task Definitions ──────────────────────────────────────
@celery_app.task(name="render_shader", bind=True, max_retries=2)
@ -48,11 +68,77 @@ def embed_shader(self, shader_id: str):
pass
@celery_app.task(name="process_desire", bind=True)
@celery_app.task(name="process_desire", bind=True, max_retries=3)
def process_desire(self, desire_id: str):
    """Process a new desire: embed text, store embedding, cluster, update heat.

    Flow:
        1. Load desire from DB by id
        2. Embed prompt_text -> 512-dim vector
        3. Store embedding on desire row
        4. Run sync clustering (find nearest or create new cluster)
        5. Commit all changes

    On errors, retries up to 3 times with 30s backoff.
    On success, logs desire_id, cluster_id, heat_score, and elapsed_ms.
    On failure, desire keeps prompt_embedding=NULL and heat_score=1.0.
    """
    start = time.perf_counter()
    desire_uuid = uuid_mod.UUID(desire_id)
    SessionFactory = _get_sync_session_factory()
    session = SessionFactory()
    try:
        # Imports deferred so the worker module loads without app deps.
        from app.models.models import Desire
        from app.services.embedding import embed_text
        from app.services.clustering import cluster_desire_sync

        # 1. Load desire — a missing row is not an error (e.g. deleted
        # between enqueue and execution), so skip without retrying.
        desire = session.get(Desire, desire_uuid)
        if desire is None:
            logger.warning(
                "process_desire: desire %s not found, skipping", desire_id
            )
            return
        # 2. Embed prompt text
        embedding = embed_text(desire.prompt_text)
        # 3. Store embedding on desire (flushed so clustering can read it)
        desire.prompt_embedding = embedding
        session.flush()
        # 4. Cluster
        cluster_result = cluster_desire_sync(desire.id, embedding, session)
        # 5. Commit
        session.commit()
        elapsed_ms = (time.perf_counter() - start) * 1000
        logger.info(
            "process_desire completed: desire_id=%s cluster_id=%s "
            "is_new=%s heat_score=%.1f elapsed_ms=%.1f",
            desire_id,
            cluster_result["cluster_id"],
            cluster_result["is_new"],
            cluster_result["heat_score"],
            elapsed_ms,
        )
    except Exception as exc:
        # Roll back partial work before scheduling the retry.
        session.rollback()
        elapsed_ms = (time.perf_counter() - start) * 1000
        logger.error(
            "process_desire failed: desire_id=%s error=%s elapsed_ms=%.1f",
            desire_id,
            str(exc),
            elapsed_ms,
        )
        # self.retry raises Retry; re-raising keeps Celery's retry bookkeeping.
        raise self.retry(exc=exc, countdown=30)
    finally:
        session.close()
@celery_app.task(name="ai_generate", bind=True, max_retries=3)

View file

@ -24,6 +24,7 @@ dependencies = [
"python-multipart>=0.0.12",
"stripe>=11.0.0",
"numpy>=2.0.0",
"scikit-learn>=1.4",
]
[project.optional-dependencies]
@ -32,4 +33,5 @@ dev = [
"pytest-asyncio>=0.24.0",
"httpx>=0.28.0",
"ruff>=0.8.0",
"aiosqlite>=0.20.0",
]

View file

@ -0,0 +1 @@
"""Tests package for fractafrag-api."""

View file

@ -0,0 +1,190 @@
"""Pytest configuration and shared fixtures for fractafrag-api tests.
Integration test infrastructure:
- Async SQLite in-memory database (via aiosqlite)
- FastAPI test client with dependency overrides
- Auth dependency overrides (mock pro-tier user)
- Celery worker mock (process_desire.delay no-op)
Environment variables are set BEFORE any app.* imports to ensure
get_settings() picks up test values (database.py calls get_settings()
at module scope with @lru_cache).
"""
import os
import sys
import uuid
from pathlib import Path
from unittest.mock import MagicMock, patch
# ── 1. sys.path setup ─────────────────────────────────────
_api_root = str(Path(__file__).resolve().parent.parent)
if _api_root not in sys.path:
sys.path.insert(0, _api_root)
# ── 2. Set env vars BEFORE any app.* imports ──────────────
# We do NOT override DATABASE_URL — the module-level engine in database.py
# uses pool_size/max_overflow which are PostgreSQL-specific. The default
# PostgreSQL URL creates an engine that never actually connects (no queries
# hit it). Our integration tests override get_db with a test SQLite session.
# We only set dummy values for env vars that cause validation failures.
os.environ.setdefault("JWT_SECRET", "test-secret")
os.environ.setdefault("REDIS_URL", "redis://localhost:6379/0")
os.environ.setdefault("BYOK_MASTER_KEY", "test-master-key-0123456789abcdef")
# ── 3. Now safe to import app modules ─────────────────────
import pytest # noqa: E402
import pytest_asyncio # noqa: E402
from httpx import ASGITransport, AsyncClient # noqa: E402
from sqlalchemy import event, text # noqa: E402
from sqlalchemy.ext.asyncio import ( # noqa: E402
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.ext.compiler import compiles # noqa: E402
from pgvector.sqlalchemy import Vector # noqa: E402
from sqlalchemy.dialects.postgresql import UUID as PG_UUID, JSONB, ARRAY # noqa: E402
from app.database import Base, get_db # noqa: E402
from app.main import app # noqa: E402
from app.middleware.auth import get_current_user, require_tier # noqa: E402
from app.models.models import User # noqa: E402
# ── 4. SQLite type compilation overrides ──────────────────
# pgvector Vector, PostgreSQL UUID, JSONB, and ARRAY don't exist in SQLite.
# Register custom compilation rules so create_all() works.
@compiles(Vector, "sqlite")
def _compile_vector_sqlite(type_, compiler, **kw):
    """Render pgvector Vector as TEXT in SQLite."""
    return "TEXT"


# Override PostgreSQL UUID to TEXT for SQLite
@compiles(PG_UUID, "sqlite")
def _compile_pg_uuid_sqlite(type_, compiler, **kw):
    """Render PostgreSQL UUID as TEXT in SQLite (standard UUID is fine, dialect-specific isn't)."""
    return "TEXT"


# Override JSONB to TEXT for SQLite
@compiles(JSONB, "sqlite")
def _compile_jsonb_sqlite(type_, compiler, **kw):
    """Render JSONB as TEXT in SQLite."""
    return "TEXT"


# Override ARRAY to TEXT for SQLite
@compiles(ARRAY, "sqlite")
def _compile_array_sqlite(type_, compiler, **kw):
    """Render PostgreSQL ARRAY as TEXT in SQLite."""
    return "TEXT"


# Register Python uuid.UUID as a SQLite adapter so raw text() queries
# can bind UUID parameters without "type 'UUID' is not supported" errors.
# IMPORTANT: Use .hex (no hyphens) to match SQLAlchemy's UUID storage format in SQLite.
# Also register list adapter so ARRAY columns (compiled as TEXT in SQLite)
# can bind Python lists without "type 'list' is not supported" errors.
# NOTE(review): these adapters are process-global (sqlite3 module state),
# not scoped to the test engine — confirm no other sqlite usage is affected.
import json as _json  # noqa: E402
import sqlite3  # noqa: E402

sqlite3.register_adapter(uuid.UUID, lambda u: u.hex)
sqlite3.register_adapter(list, lambda lst: _json.dumps(lst))
# ── 5. Test database engine and session fixtures ──────────
# Shared test user ID — consistent across all integration tests
TEST_USER_ID = uuid.UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa")
@pytest_asyncio.fixture(scope="session")
async def db_engine():
    """Create an async SQLite engine and all tables. Session-scoped.

    Uses an in-memory database (`sqlite+aiosqlite://`), so the schema
    lives only for the duration of the test session. Tables are created
    up front from Base.metadata and dropped again at teardown.
    """
    engine = create_async_engine(
        "sqlite+aiosqlite://",
        echo=False,
        # SQLite doesn't support pool_size/max_overflow
    )
    # Create the full schema once for the whole session.
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield engine
    # Teardown: drop schema and release the engine's resources.
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)
    await engine.dispose()
@pytest_asyncio.fixture
async def db_session(db_engine):
    """Yield a fresh AsyncSession per test, rolled back for isolation.

    NOTE: the previous implementation wrapped the yield in
    ``async with session.begin():`` — that context manager COMMITS on a
    clean exit, so the rollback that followed was a no-op and committed
    state leaked between tests. Here we begin the transaction, yield, and
    explicitly roll the session back in a finally block, which actually
    undoes the test's writes. (``session.rollback()`` is safe even if the
    code under test committed mid-way — it then simply has nothing to undo.)
    """
    session_factory = async_sessionmaker(
        db_engine, class_=AsyncSession, expire_on_commit=False
    )
    async with session_factory() as session:
        await session.begin()
        try:
            yield session
        finally:
            # Undo everything the test did so no state leaks between tests.
            await session.rollback()
# ── 6. Mock user fixture ─────────────────────────────────
@pytest.fixture
def test_user():
    """Return a mock User object for auth dependency overrides."""
    mock_user = MagicMock(spec=User)
    # Pro-tier, non-system user — matches what the auth overrides expect.
    profile = {
        "id": TEST_USER_ID,
        "username": "testuser",
        "email": "testuser@test.com",
        "role": "user",
        "subscription_tier": "pro",
        "is_system": False,
        "trust_tier": "standard",
    }
    for attr, value in profile.items():
        setattr(mock_user, attr, value)
    return mock_user
# ── 7. FastAPI test client fixture ────────────────────────
@pytest_asyncio.fixture
async def client(db_session, test_user):
    """Async HTTP client wired to the FastAPI app with dependency overrides.
    Overrides:
    - get_db yields the test db_session
    - get_current_user returns test_user (pro tier)
    - require_tier returns test_user unconditionally (bypasses tier check)
    - process_desire.delay no-op (prevents Celery/Redis connection)
    """
    # Override get_db to yield the per-test session (see db_session fixture)
    async def _override_get_db():
        yield db_session
    # Override get_current_user to return the mock pro-tier user
    async def _override_get_current_user():
        return test_user
    app.dependency_overrides[get_db] = _override_get_db
    app.dependency_overrides[get_current_user] = _override_get_current_user
    # require_tier is a factory that returns inner functions depending on
    # get_current_user. Since we override get_current_user to return a pro-tier
    # user, the tier check inside require_tier will pass naturally.
    # We still need to mock process_desire to prevent Celery/Redis connection.
    # NOTE: the patch target is "app.worker.process_desire"; the router must
    # resolve the task through the module for this patch to take effect.
    with patch("app.worker.process_desire") as mock_task:
        # process_desire.delay() should be a no-op
        mock_task.delay = MagicMock(return_value=None)
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            yield ac
    # Teardown (runs after the test, even on failure): clean up dependency
    # overrides so later tests see the real dependencies again.
    app.dependency_overrides.clear()

View file

@ -0,0 +1,366 @@
"""Unit tests for the clustering service.
Tests use mocked async DB sessions to isolate clustering logic from
pgvector and database concerns. Synthetic 512-dim vectors verify the
service's orchestration, heat calculation, and threshold behavior.
"""
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.models.models import DesireCluster
from app.services.clustering import (
add_to_cluster,
cluster_desire,
create_cluster,
find_nearest_cluster,
recalculate_heat,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_embedding(dim: int = 512) -> list[float]:
"""Create a synthetic embedding vector for testing."""
import numpy as np
rng = np.random.default_rng(42)
vec = rng.standard_normal(dim)
vec = vec / np.linalg.norm(vec)
return vec.tolist()
def _mock_result_row(**kwargs):
"""Create a mock DB result row with named attributes."""
row = MagicMock()
for key, value in kwargs.items():
setattr(row, key, value)
return row
# ---------------------------------------------------------------------------
# Tests: cluster_desire orchestration
# ---------------------------------------------------------------------------
class TestClusterDesireOrchestration:
    """Exercise cluster_desire's branching with its sub-functions mocked."""

    @pytest.mark.asyncio
    @patch("app.services.clustering.find_nearest_cluster", new_callable=AsyncMock)
    @patch("app.services.clustering.create_cluster", new_callable=AsyncMock)
    async def test_new_desire_creates_own_cluster(
        self, mock_create, mock_find
    ) -> None:
        """When no nearby cluster exists, create a new one."""
        fresh_cluster = uuid.uuid4()
        incoming = uuid.uuid4()
        vec = _make_embedding()
        session = AsyncMock()
        mock_find.return_value = (None, 0.0)
        mock_create.return_value = fresh_cluster

        outcome = await cluster_desire(incoming, vec, session)

        mock_find.assert_awaited_once_with(vec, session)
        mock_create.assert_awaited_once_with(incoming, session)
        assert outcome["is_new"] is True
        assert outcome["cluster_id"] == fresh_cluster
        assert outcome["heat_score"] == 1.0

    @pytest.mark.asyncio
    @patch("app.services.clustering.find_nearest_cluster", new_callable=AsyncMock)
    @patch("app.services.clustering.add_to_cluster", new_callable=AsyncMock)
    @patch("app.services.clustering.recalculate_heat", new_callable=AsyncMock)
    async def test_similar_desire_joins_existing_cluster(
        self, mock_recalc, mock_add, mock_find
    ) -> None:
        """When a nearby cluster is found, join it and recalculate heat."""
        home_cluster = uuid.uuid4()
        incoming = uuid.uuid4()
        vec = _make_embedding()
        closeness = 0.92
        session = AsyncMock()
        mock_find.return_value = (home_cluster, closeness)
        mock_recalc.return_value = 3.0

        outcome = await cluster_desire(incoming, vec, session)

        mock_find.assert_awaited_once_with(vec, session)
        mock_add.assert_awaited_once_with(
            home_cluster, incoming, closeness, session
        )
        mock_recalc.assert_awaited_once_with(home_cluster, session)
        assert outcome["is_new"] is False
        assert outcome["cluster_id"] == home_cluster
        assert outcome["heat_score"] == 3.0

    @pytest.mark.asyncio
    @patch("app.services.clustering.find_nearest_cluster", new_callable=AsyncMock)
    @patch("app.services.clustering.create_cluster", new_callable=AsyncMock)
    async def test_cluster_desire_returns_observability_dict(
        self, mock_create, mock_find
    ) -> None:
        """Returned dict always carries cluster_id, is_new, heat_score."""
        mock_find.return_value = (None, 0.0)
        mock_create.return_value = uuid.uuid4()

        outcome = await cluster_desire(uuid.uuid4(), _make_embedding(), AsyncMock())

        for key in ("cluster_id", "is_new", "heat_score"):
            assert key in outcome
        assert isinstance(outcome["cluster_id"], uuid.UUID)
        assert isinstance(outcome["is_new"], bool)
        assert isinstance(outcome["heat_score"], float)
# ---------------------------------------------------------------------------
# Tests: recalculate_heat
# ---------------------------------------------------------------------------
class TestRecalculateHeat:
    """Heat recalculation against a mocked async session."""

    @staticmethod
    def _session_counting(members: int) -> AsyncMock:
        """AsyncMock session whose first execute() yields COUNT(*)=members.

        The second execute() models the UPDATE statement, whose return
        value is never inspected.
        """
        count_res = MagicMock()
        count_res.scalar_one.return_value = members
        session = AsyncMock()
        session.execute = AsyncMock(side_effect=[count_res, MagicMock()])
        return session

    @pytest.mark.asyncio
    async def test_heat_scales_with_cluster_size(self) -> None:
        """Heat score equals cluster size (linear scaling)."""
        session = self._session_counting(3)
        assert await recalculate_heat(uuid.uuid4(), session) == 3.0
        assert session.execute.await_count == 2
        assert session.flush.await_count >= 1

    @pytest.mark.asyncio
    async def test_heat_for_single_member_cluster(self) -> None:
        """A single-member cluster has heat_score = 1.0."""
        session = self._session_counting(1)
        assert await recalculate_heat(uuid.uuid4(), session) == 1.0

    @pytest.mark.asyncio
    async def test_heat_for_large_cluster(self) -> None:
        """Heat keeps scaling linearly for large clusters."""
        session = self._session_counting(15)
        assert await recalculate_heat(uuid.uuid4(), session) == 15.0
# ---------------------------------------------------------------------------
# Tests: find_nearest_cluster
# ---------------------------------------------------------------------------
class TestFindNearestCluster:
    """pgvector distance query behavior with the DB mocked out."""

    @staticmethod
    def _result(row):
        """Mock execute() result whose .first() yields row (or None)."""
        res = MagicMock()
        res.first.return_value = row
        return res

    @pytest.mark.asyncio
    async def test_empty_db_returns_none(self) -> None:
        """No desires with embeddings → no cluster match."""
        db = AsyncMock()
        db.execute = AsyncMock(return_value=self._result(None))

        cid, sim = await find_nearest_cluster(_make_embedding(), db)

        assert cid is None
        assert sim == 0.0

    @pytest.mark.asyncio
    async def test_match_found_with_cluster(self) -> None:
        """A desire within threshold that has a cluster → returns cluster."""
        neighbour = uuid.uuid4()
        home_cluster = uuid.uuid4()
        db = AsyncMock()
        db.execute = AsyncMock(side_effect=[
            # nearest desire at distance 0.08 → similarity 0.92
            self._result(_mock_result_row(desire_id=neighbour, distance=0.08)),
            # cluster membership lookup
            self._result(_mock_result_row(cluster_id=home_cluster)),
        ])

        cid, sim = await find_nearest_cluster(_make_embedding(), db)

        assert cid == home_cluster
        assert abs(sim - 0.92) < 1e-6

    @pytest.mark.asyncio
    async def test_match_found_without_cluster(self) -> None:
        """A nearby desire that has no cluster entry → returns None."""
        db = AsyncMock()
        db.execute = AsyncMock(side_effect=[
            self._result(_mock_result_row(desire_id=uuid.uuid4(), distance=0.10)),
            self._result(None),  # cluster lookup comes back empty
        ])

        cid, sim = await find_nearest_cluster(_make_embedding(), db)

        assert cid is None
        assert sim == 0.0

    @pytest.mark.asyncio
    async def test_threshold_boundary_at_0_82(self) -> None:
        """threshold=0.82 means max distance 0.18; an exact-edge hit counts.

        A desire at exactly distance=0.18 (similarity=0.82) is still
        returned by the SQL query (distance <= 0.18).
        """
        edge_cluster = uuid.uuid4()
        db = AsyncMock()
        db.execute = AsyncMock(side_effect=[
            self._result(_mock_result_row(desire_id=uuid.uuid4(), distance=0.18)),
            self._result(_mock_result_row(cluster_id=edge_cluster)),
        ])

        cid, sim = await find_nearest_cluster(
            _make_embedding(), db, threshold=0.82
        )

        assert cid == edge_cluster
        assert abs(sim - 0.82) < 1e-6

    @pytest.mark.asyncio
    async def test_below_threshold_returns_none(self) -> None:
        """Beyond the distance threshold, the SQL WHERE clause filters rows.

        With threshold=0.82 (max_distance=0.18), a desire at distance=0.19
        (similarity=0.81) never reaches Python. The mock simulates this by
        returning no rows.
        """
        db = AsyncMock()
        db.execute = AsyncMock(return_value=self._result(None))

        cid, sim = await find_nearest_cluster(
            _make_embedding(), db, threshold=0.82
        )

        assert cid is None
        assert sim == 0.0
# ---------------------------------------------------------------------------
# Tests: create_cluster
# ---------------------------------------------------------------------------
class TestCreateCluster:
    """Cluster creation paths."""

    @staticmethod
    def _session() -> AsyncMock:
        """AsyncMock session with a synchronous .add(), like a real Session."""
        sess = AsyncMock()
        sess.add = MagicMock()
        return sess

    @pytest.mark.asyncio
    async def test_create_cluster_returns_uuid(self) -> None:
        """A new cluster gets a valid UUID and is added + flushed."""
        sess = self._session()

        new_id = await create_cluster(uuid.uuid4(), sess)

        assert isinstance(new_id, uuid.UUID)
        sess.add.assert_called_once()
        sess.flush.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_create_cluster_adds_desire_cluster_row(self) -> None:
        """The DesireCluster row self-references with similarity=1.0."""
        sess = self._session()
        seed_desire = uuid.uuid4()

        new_id = await create_cluster(seed_desire, sess)

        membership = sess.add.call_args[0][0]
        assert isinstance(membership, DesireCluster)
        assert membership.cluster_id == new_id
        assert membership.desire_id == seed_desire
        assert membership.similarity == 1.0
# ---------------------------------------------------------------------------
# Tests: add_to_cluster
# ---------------------------------------------------------------------------
class TestAddToCluster:
    """Adding a desire to an existing cluster."""

    @pytest.mark.asyncio
    async def test_add_to_cluster_executes_insert(self) -> None:
        """Insert is executed with the right bound parameters and flushed."""
        session = AsyncMock()
        target_cluster = uuid.uuid4()
        joining_desire = uuid.uuid4()

        await add_to_cluster(target_cluster, joining_desire, 0.91, session)

        session.execute.assert_awaited_once()
        session.flush.assert_awaited()
        # Second positional arg to execute() is the params dict.
        bound = session.execute.call_args[0][1]
        assert bound["cluster_id"] == target_cluster
        assert bound["desire_id"] == joining_desire
        assert bound["similarity"] == 0.91

View file

@ -0,0 +1,250 @@
"""Pipeline integration tests — embed → cluster → heat.
Proves the full desire processing pipeline works end-to-end by:
1. Verifying similar texts produce embeddings with cosine similarity above
the clustering threshold (0.82)
2. Verifying dissimilar texts stay below the clustering threshold
3. Validating heat calculation logic for clustered desires
4. Checking that the router and worker are wired correctly (static assertions)
"""
import uuid
from pathlib import Path
from unittest.mock import MagicMock
import numpy as np
import pytest
from app.services.embedding import embed_text
def cosine_sim(a: list[float], b: list[float]) -> float:
    """Cosine similarity of two L2-normalized vectors (just their dot product)."""
    return float(np.asarray(a) @ np.asarray(b))
# ---------------------------------------------------------------------------
# Embedding pipeline: similar texts cluster, dissimilar texts don't
# ---------------------------------------------------------------------------
class TestSimilarDesiresClustering:
    """Verify that similar desire texts produce clusterable embeddings."""

    SIMILAR_TEXTS = [
        "ragdoll physics dark moody slow motion",
        "dark physics ragdoll slow motion moody",
        "slow motion ragdoll dark physics moody",
    ]

    def test_similar_desires_produce_clusterable_embeddings(self) -> None:
        """Every pair of similar texts lands above the 0.82 threshold."""
        embeddings = [embed_text(t) for t in self.SIMILAR_TEXTS]
        pairs = [
            (i, j)
            for i in range(len(embeddings))
            for j in range(i + 1, len(embeddings))
        ]
        for i, j in pairs:
            sim = cosine_sim(embeddings[i], embeddings[j])
            assert sim > 0.82, (
                f"Texts [{i}] and [{j}] should cluster (sim > 0.82), "
                f"got {sim:.4f}:\n"
                f"  [{i}] '{self.SIMILAR_TEXTS[i]}'\n"
                f"  [{j}] '{self.SIMILAR_TEXTS[j]}'"
            )

    def test_dissimilar_desire_does_not_cluster(self) -> None:
        """A dissimilar text stays below 0.82 similarity with every anchor."""
        dissimilar = embed_text("bright colorful kaleidoscope flowers rainbow")
        anchors = [embed_text(t) for t in self.SIMILAR_TEXTS]
        for i, emb in enumerate(anchors):
            sim = cosine_sim(dissimilar, emb)
            assert sim < 0.82, (
                f"Dissimilar text should NOT cluster with text [{i}] "
                f"(sim < 0.82), got {sim:.4f}:\n"
                f"  dissimilar: 'bright colorful kaleidoscope flowers rainbow'\n"
                f"  similar[{i}]: '{self.SIMILAR_TEXTS[i]}'"
            )
# ---------------------------------------------------------------------------
# Heat calculation logic
# ---------------------------------------------------------------------------
class TestPipelineHeatCalculation:
    """Heat score equals cluster size — sync variant used by the worker."""

    @staticmethod
    def _session_counting(members: int) -> MagicMock:
        """Sync session mock: first execute() is COUNT(*), second is UPDATE."""
        count_res = MagicMock()
        count_res.scalar_one.return_value = members
        sess = MagicMock()
        sess.execute = MagicMock(side_effect=[count_res, MagicMock()])
        return sess

    def test_pipeline_heat_calculation_logic(self) -> None:
        """A cluster of 3 desires yields heat_score = 3.0 for each member."""
        from app.services.clustering import recalculate_heat_sync

        sess = self._session_counting(3)
        assert recalculate_heat_sync(uuid.uuid4(), sess) == 3.0
        assert sess.execute.call_count == 2
        assert sess.flush.call_count >= 1

    def test_single_member_cluster_has_heat_1(self) -> None:
        """A new single-member cluster has heat_score = 1.0."""
        from app.services.clustering import recalculate_heat_sync

        sess = self._session_counting(1)
        assert recalculate_heat_sync(uuid.uuid4(), sess) == 1.0
# ---------------------------------------------------------------------------
# Sync clustering orchestrator
# ---------------------------------------------------------------------------
class TestSyncClusteringOrchestrator:
    """cluster_desire_sync orchestration with mocked sub-functions."""

    def test_new_desire_creates_cluster(self) -> None:
        """No nearby cluster → a fresh one is created."""
        from unittest.mock import patch

        from app.services.clustering import cluster_desire_sync

        incoming = uuid.uuid4()
        vec = embed_text("ragdoll physics dark moody slow")
        fresh_cluster = uuid.uuid4()
        sess = MagicMock()
        with patch("app.services.clustering.find_nearest_cluster_sync") as mock_find, \
             patch("app.services.clustering.create_cluster_sync") as mock_create:
            mock_find.return_value = (None, 0.0)
            mock_create.return_value = fresh_cluster
            outcome = cluster_desire_sync(incoming, vec, sess)
        assert outcome["is_new"] is True
        assert outcome["cluster_id"] == fresh_cluster
        assert outcome["heat_score"] == 1.0
        mock_find.assert_called_once_with(vec, sess)
        mock_create.assert_called_once_with(incoming, sess)

    def test_similar_desire_joins_existing_cluster(self) -> None:
        """Nearby cluster found → joins it and recalculates heat."""
        from unittest.mock import patch

        from app.services.clustering import cluster_desire_sync

        incoming = uuid.uuid4()
        vec = embed_text("ragdoll physics dark moody slow")
        home_cluster = uuid.uuid4()
        sess = MagicMock()
        with patch("app.services.clustering.find_nearest_cluster_sync") as mock_find, \
             patch("app.services.clustering.add_to_cluster_sync") as mock_add, \
             patch("app.services.clustering.recalculate_heat_sync") as mock_recalc:
            mock_find.return_value = (home_cluster, 0.91)
            mock_recalc.return_value = 3.0
            outcome = cluster_desire_sync(incoming, vec, sess)
        assert outcome["is_new"] is False
        assert outcome["cluster_id"] == home_cluster
        assert outcome["heat_score"] == 3.0
        mock_add.assert_called_once_with(
            home_cluster, incoming, 0.91, sess
        )
# ---------------------------------------------------------------------------
# Wiring checks: router + worker are connected
# ---------------------------------------------------------------------------
class TestWiring:
    """Static assertions that the router and worker are properly wired."""

    @staticmethod
    def _app_source(*parts: str) -> str:
        """Read a source file under the app package, relative to tests/."""
        root = Path(__file__).resolve().parent.parent / "app"
        return root.joinpath(*parts).read_text()

    def test_router_has_worker_enqueue(self) -> None:
        """desires.py contains process_desire.delay — fire-and-forget enqueue."""
        source = self._app_source("routers", "desires.py")
        assert "process_desire.delay" in source, (
            "Router should call process_desire.delay() to enqueue worker task"
        )

    def test_worker_task_is_implemented(self) -> None:
        """process_desire has a real body, not just 'pass'.

        Reads the worker source directly so celery need not be importable
        in the test environment.
        """
        source = self._app_source("worker", "__init__.py")
        assert "embed_text" in source, (
            "Worker should call embed_text to embed desire prompt"
        )
        assert "cluster_desire_sync" in source, (
            "Worker should call cluster_desire_sync to cluster the desire"
        )
        assert "session.commit" in source, (
            "Worker should commit the DB transaction"
        )

    def test_worker_has_structured_logging(self) -> None:
        """process_desire logs each key observability field."""
        source = self._app_source("worker", "__init__.py")
        expectations = (
            ("desire_id", "Should log desire_id"),
            ("cluster_id", "Should log cluster_id"),
            ("heat_score", "Should log heat_score"),
            ("elapsed_ms", "Should log elapsed_ms"),
        )
        for field, message in expectations:
            assert field in source, message

    def test_worker_has_error_handling_with_retry(self) -> None:
        """process_desire rolls back and retries on failure."""
        source = self._app_source("worker", "__init__.py")
        assert "self.retry" in source, (
            "Worker should use self.retry for transient error handling"
        )
        assert "session.rollback" in source, (
            "Worker should rollback on error before retrying"
        )

View file

@ -0,0 +1,137 @@
"""Unit tests for the text embedding service.
Validates that TF-IDF + TruncatedSVD produces 512-dim L2-normalized vectors
with meaningful cosine similarity for shader/visual-art domain text.
"""
import numpy as np
import pytest
from app.services.embedding import EmbeddingService, embed_text
def cosine_sim(a: list[float], b: list[float]) -> float:
    """Compute cosine similarity between two vectors.

    The embeddings are already L2-normalized, so this reduces to a dot product.
    """
    return float(np.asarray(a) @ np.asarray(b))
class TestEmbedDimension:
    """Output vectors must always be 512-dimensional lists of floats."""

    def test_embed_produces_512_dim_vector(self) -> None:
        result = embed_text("particle system fluid simulation")
        assert len(result) == 512, f"Expected 512 dims, got {len(result)}"

    def test_embed_returns_list_of_floats(self) -> None:
        vec = embed_text("fractal noise pattern")
        assert isinstance(vec, list)
        # No element may be anything but a float.
        assert not [x for x in vec if not isinstance(x, float)]
class TestNormalization:
    """Every embedding must come back L2-normalized."""

    def test_embed_vectors_are_normalized(self) -> None:
        norm = np.linalg.norm(embed_text("raymarching distance field shapes"))
        assert abs(norm - 1.0) < 1e-6, f"Expected norm ≈ 1.0, got {norm}"

    def test_various_inputs_all_normalized(self) -> None:
        samples = [
            "short",
            "a much longer description of a complex visual effect with many words",
            "ragdoll physics dark moody atmosphere simulation",
        ]
        for text in samples:
            norm = np.linalg.norm(embed_text(text))
            assert abs(norm - 1.0) < 1e-6, (
                f"Norm for '{text}' = {norm}, expected ≈ 1.0"
            )
class TestSimilarity:
    """Semantic similarity properties of the embedding space."""

    def test_similar_texts_have_high_cosine_similarity(self) -> None:
        left = embed_text("ragdoll physics dark and slow")
        right = embed_text("dark physics simulation ragdoll")
        sim = cosine_sim(left, right)
        assert sim > 0.8, (
            f"Similar texts should have >0.8 cosine sim, got {sim:.4f}"
        )

    def test_dissimilar_texts_have_low_cosine_similarity(self) -> None:
        left = embed_text("ragdoll physics dark")
        right = embed_text("bright colorful kaleidoscope flowers")
        sim = cosine_sim(left, right)
        assert sim < 0.5, (
            f"Dissimilar texts should have <0.5 cosine sim, got {sim:.4f}"
        )

    def test_identical_texts_have_perfect_similarity(self) -> None:
        text = "procedural noise fractal generation"
        sim = cosine_sim(embed_text(text), embed_text(text))
        assert sim > 0.999, (
            f"Identical texts should have ~1.0 cosine sim, got {sim:.4f}"
        )
class TestBatch:
    """Batch embedding must agree with one-at-a-time embedding."""

    def test_embed_batch_matches_individual(self) -> None:
        texts = [
            "particle system fluid",
            "ragdoll physics dark moody",
            "kaleidoscope symmetry rotation",
        ]
        # Independent, fresh services keep both paths deterministic.
        solo_service = EmbeddingService()
        solo = [solo_service.embed_text(t) for t in texts]
        batched = EmbeddingService().embed_batch(texts)
        assert len(batched) == len(solo)
        for i, (ind, bat) in enumerate(zip(solo, batched)):
            sim = cosine_sim(ind, bat)
            assert sim > 0.999, (
                f"Batch result {i} doesn't match individual: sim={sim:.6f}"
            )

    def test_batch_dimensions(self) -> None:
        vectors = EmbeddingService().embed_batch(
            ["fire smoke volumetric", "crystal refraction light"]
        )
        assert len(vectors) == 2
        assert all(len(vec) == 512 for vec in vectors)
class TestErrorHandling:
    """Invalid inputs must raise ValueError with a clear message."""

    # Expected substring of the ValueError message.
    _MATCH = "empty or whitespace"

    def test_empty_string_raises_valueerror(self) -> None:
        with pytest.raises(ValueError, match=self._MATCH):
            embed_text("")

    def test_whitespace_only_raises_valueerror(self) -> None:
        with pytest.raises(ValueError, match=self._MATCH):
            embed_text("  \n\t  ")

    def test_batch_with_empty_string_raises_valueerror(self) -> None:
        with pytest.raises(ValueError, match=self._MATCH):
            EmbeddingService().embed_batch(["valid text", ""])

    def test_batch_with_whitespace_raises_valueerror(self) -> None:
        with pytest.raises(ValueError, match=self._MATCH):
            EmbeddingService().embed_batch(["   ", "valid text"])

View file

@ -0,0 +1,294 @@
"""Tests for desire fulfillment endpoint and cluster_count annotation.
Covers:
- fulfill_desire endpoint: happy path, not-found, not-open, shader validation
(tested via source assertions since FastAPI isn't in the test environment)
- cluster_count annotation: batch query pattern, single desire query
- Schema field: cluster_count exists in DesirePublic
Approach: Per K005, router functions can't be imported without FastAPI installed.
We verify correctness through:
1. Source-level structure assertions (endpoint wiring, imports, validation logic)
2. Isolated logic unit tests (annotation loop, status transitions)
3. Schema field verification via Pydantic model introspection
"""
import uuid
from datetime import datetime, timezone
from pathlib import Path
from unittest.mock import MagicMock
import pytest
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _router_source() -> str:
    """Return the raw source text of the desires router module."""
    router_path = (
        Path(__file__).resolve().parents[1] / "app" / "routers" / "desires.py"
    )
    return router_path.read_text(encoding="utf-8")
def _schema_source() -> str:
    """Return the raw source text of the schemas module."""
    schema_path = (
        Path(__file__).resolve().parents[1] / "app" / "schemas" / "schemas.py"
    )
    return schema_path.read_text(encoding="utf-8")
def _make_mock_desire(
*,
desire_id=None,
status="open",
heat_score=1.0,
):
"""Create a mock object simulating a Desire ORM instance."""
d = MagicMock()
d.id = desire_id or uuid.uuid4()
d.status = status
d.heat_score = heat_score
d.cluster_count = 0 # default before annotation
return d
# ---------------------------------------------------------------------------
# fulfill_desire — happy path structure
# ---------------------------------------------------------------------------
class TestFulfillHappyPath:
    """Source-level checks of the fulfill endpoint's success branch."""

    def test_fulfill_sets_status_to_fulfilled(self):
        """The endpoint sets desire.status = 'fulfilled' on success."""
        assert 'desire.status = "fulfilled"' in _router_source()

    def test_fulfill_sets_fulfilled_by_shader(self):
        """The endpoint records which shader fulfilled the desire."""
        assert "desire.fulfilled_by_shader = shader_id" in _router_source()

    def test_fulfill_sets_fulfilled_at_timestamp(self):
        """The endpoint stamps desire.fulfilled_at with current UTC time."""
        src = _router_source()
        assert "desire.fulfilled_at" in src
        assert "datetime.now(timezone.utc)" in src

    def test_fulfill_returns_status_response(self):
        """The endpoint returns a dict with status, desire_id, shader_id."""
        src = _router_source()
        for needle in ('"status": "fulfilled"', '"desire_id"', '"shader_id"'):
            assert needle in src
# ---------------------------------------------------------------------------
# fulfill_desire — error paths
# ---------------------------------------------------------------------------
class TestFulfillDesireNotFound:
    """404 when desire doesn't exist."""

    def test_desire_not_found_raises_404(self):
        # After the desire lookup, the router checks the
        # scalar_one_or_none result and raises with this detail.
        assert "Desire not found" in _router_source()
class TestFulfillDesireNotOpen:
    """400 when desire is not in 'open' status."""

    def test_desire_not_open_check_exists(self):
        assert 'desire.status != "open"' in _router_source()

    def test_desire_not_open_error_message(self):
        assert "Desire is not open" in _router_source()
class TestFulfillShaderNotFound:
    """404 when shader_id doesn't match any shader."""

    def test_shader_lookup_exists(self):
        assert "select(Shader).where(Shader.id == shader_id)" in _router_source()

    def test_shader_not_found_raises_404(self):
        assert "Shader not found" in _router_source()
class TestFulfillShaderNotPublished:
    """400 when shader status is not 'published'."""

    def test_shader_status_validation(self):
        assert 'shader.status != "published"' in _router_source()

    def test_shader_not_published_error_message(self):
        assert "Shader must be published to fulfill a desire" in _router_source()
# ---------------------------------------------------------------------------
# cluster_count annotation — logic unit tests
# ---------------------------------------------------------------------------
class TestClusterCountAnnotation:
    """Verify cluster_count annotation logic patterns."""

    def test_list_desires_has_batch_cluster_query(self):
        """list_desires uses one batch query with ANY(:desire_ids)."""
        src = _router_source()
        for needle in (
            "ANY(:desire_ids)",
            "desire_clusters dc1",
            "desire_clusters dc2",
        ):
            assert needle in src

    def test_list_desires_avoids_n_plus_1(self):
        """Cluster counts come from a single batch, not per-desire queries."""
        src = _router_source()
        # Pattern: one batch query into a dict, then an annotation loop.
        assert "cluster_counts = {" in src
        assert "cluster_counts.get(d.id, 0)" in src

    def test_list_desires_skips_cluster_query_when_empty(self):
        """No desires returned → the cluster query is skipped entirely."""
        assert "if desire_ids:" in _router_source()

    def test_get_desire_annotates_single_cluster_count(self):
        """get_desire scopes its cluster count query to one desire_id."""
        assert "WHERE dc1.desire_id = :desire_id" in _router_source()

    def test_annotation_loop_sets_default_zero(self):
        """Desires without cluster entries default to cluster_count = 0."""
        assert "cluster_counts.get(d.id, 0)" in _router_source()

    def test_annotation_loop_logic(self):
        """Unit test: the annotation loop maps counts onto desires correctly."""
        clustered = _make_mock_desire()
        lonely = _make_mock_desire()
        paired = _make_mock_desire()
        # Simulated batch-query result: two of the three are clustered.
        counts = {clustered.id: 3, paired.id: 2}
        # Same loop the router runs after its batch query.
        for desire in (clustered, lonely, paired):
            desire.cluster_count = counts.get(desire.id, 0)
        assert clustered.cluster_count == 3
        assert lonely.cluster_count == 0  # not in any cluster
        assert paired.cluster_count == 2

    def test_get_desire_cluster_count_fallback(self):
        """get_desire sets cluster_count=0 when no cluster row exists."""
        assert "row[0] if row else 0" in _router_source()
# ---------------------------------------------------------------------------
# Schema field verification
# ---------------------------------------------------------------------------
class TestDesirePublicSchema:
    """Verify DesirePublic exposes cluster_count (checked via source text)."""

    def test_cluster_count_field_in_schema_source(self):
        """The schemas module declares a cluster_count field."""
        assert "cluster_count" in _schema_source()

    def test_cluster_count_default_zero(self):
        """cluster_count defaults to 0 in the schema."""
        assert "cluster_count: int = 0" in _schema_source()

    def test_schema_from_attributes_enabled(self):
        """DesirePublic uses from_attributes=True for ORM compatibility."""
        src = _schema_source()
        start = src.index("class DesirePublic")
        assert "from_attributes=True" in src[start:start + 200]

    def test_cluster_count_pydantic_model(self):
        """cluster_count sits between heat_score and fulfilled_by_shader."""
        src = _schema_source()
        start = src.index("class DesirePublic")
        section = src[start:start + 500]
        correctly_ordered = (
            section.index("heat_score")
            < section.index("cluster_count")
            < section.index("fulfilled_by_shader")
        )
        assert correctly_ordered, (
            "cluster_count should be between heat_score and fulfilled_by_shader"
        )
# ---------------------------------------------------------------------------
# Wiring assertions
# ---------------------------------------------------------------------------
class TestFulfillmentWiring:
    """Structural assertions that the router is properly wired.

    These tests inspect the router's source text (via ``_router_source``)
    rather than importing it, so they stay independent of runtime deps.
    """

    def test_router_imports_shader_model(self):
        """desires.py imports Shader for shader validation."""
        source = _router_source()
        # Check the actual model-import line instead of a hard-coded line
        # number: the old `source.split("\n")[8]` check broke whenever any
        # line above the import moved.
        model_import_lines = [
            line for line in source.split("\n")
            if line.startswith("from app.models import")
        ]
        assert any("Shader" in line for line in model_import_lines), (
            "Shader is not imported from app.models in desires.py"
        )

    def test_router_imports_text_from_sqlalchemy(self):
        """desires.py imports text from sqlalchemy for raw SQL."""
        source = _router_source()
        assert "from sqlalchemy import" in source
        assert "text" in source

    def test_fulfill_endpoint_requires_auth(self):
        """fulfill_desire uses get_current_user dependency."""
        source = _router_source()
        # Find the fulfill_desire function; .index raises if it is missing.
        fulfill_idx = source.index("async def fulfill_desire")
        fulfill_section = source[fulfill_idx:fulfill_idx + 500]
        assert "get_current_user" in fulfill_section

    def test_fulfill_endpoint_takes_shader_id_param(self):
        """fulfill_desire accepts shader_id as a query parameter."""
        source = _router_source()
        fulfill_idx = source.index("async def fulfill_desire")
        fulfill_section = source[fulfill_idx:fulfill_idx + 300]
        assert "shader_id" in fulfill_section

    def test_list_desires_returns_desire_public(self):
        """list_desires endpoint uses DesirePublic response model."""
        source = _router_source()
        assert "response_model=list[DesirePublic]" in source

    def test_get_desire_returns_desire_public(self):
        """get_desire endpoint uses DesirePublic response model."""
        source = _router_source()
        lines = source.split("\n")
        for i, line in enumerate(lines):
            if "async def get_desire" in line:
                # Check the decorator line above the endpoint definition.
                assert "response_model=DesirePublic" in lines[i - 1]
                break
        else:
            # Previously the loop could fall through without asserting
            # anything; fail loudly if the endpoint is missing.
            raise AssertionError("async def get_desire not found in router source")

View file

@ -0,0 +1,412 @@
"""Integration tests — end-to-end acceptance scenarios through FastAPI.
Uses async SQLite test database, real FastAPI endpoint handlers,
and dependency overrides for auth and Celery worker.
Test classes:
TestInfrastructureSmoke proves test infra works (T01)
TestClusteringScenario clustering + heat elevation via API (T02)
TestFulfillmentScenario desire fulfillment lifecycle (T02)
TestMCPFieldPassthrough MCP tool field passthrough (T02, source-level)
"""
import inspect
import json
import uuid
from pathlib import Path
import pytest
from httpx import AsyncClient
from sqlalchemy import select, update
# ── Smoke Test: proves integration infrastructure works ───
class TestInfrastructureSmoke:
"""Verify that the integration test infrastructure (DB, client, auth, Celery mock) works."""
@pytest.mark.asyncio
async def test_create_and_read_desire(self, client: AsyncClient):
"""POST a desire, then GET it back — proves DB, serialization, auth override, and Celery mock."""
# Create a desire
response = await client.post(
"/api/v1/desires",
json={"prompt_text": "glowing neon wireframe city"},
)
assert response.status_code == 201, f"Expected 201, got {response.status_code}: {response.text}"
data = response.json()
desire_id = data["id"]
assert data["prompt_text"] == "glowing neon wireframe city"
assert data["status"] == "open"
# Read it back
response = await client.get(f"/api/v1/desires/{desire_id}")
assert response.status_code == 200, f"Expected 200, got {response.status_code}: {response.text}"
data = response.json()
assert data["id"] == desire_id
assert data["prompt_text"] == "glowing neon wireframe city"
assert data["heat_score"] == 1.0
assert data["cluster_count"] == 0
# ── Clustering Scenario ──────────────────────────────────────
class TestClusteringScenario:
    """Prove that clustered desires have elevated heat and cluster_count via the API.

    Strategy: POST desires through the API, then directly insert DesireCluster
    rows and update heat_score in the test DB (simulating what the Celery worker
    pipeline does). Verify via GET /api/v1/desires/{id} that the API returns
    correct heat_score and cluster_count.

    Note: list_desires uses PostgreSQL ANY(:desire_ids) which doesn't work in
    SQLite, so we verify via individual GET requests instead.
    """

    @pytest.mark.asyncio
    async def test_similar_desires_cluster_and_elevate_heat(
        self, client: AsyncClient, db_session
    ):
        """Create 3 desires, cluster them, elevate heat, verify API returns correct data."""
        from app.models.models import Desire, DesireCluster

        # Create 3 desires via API
        prompt_texts = (
            "neon fractal explosion in deep space",
            "colorful fractal burst cosmic background",
            "glowing fractal nova against dark stars",
        )
        created_ids = []
        for prompt_text in prompt_texts:
            resp = await client.post(
                "/api/v1/desires", json={"prompt_text": prompt_text}
            )
            assert resp.status_code == 201, f"Create failed: {resp.text}"
            created_ids.append(resp.json()["id"])

        # Simulate clustering: one shared cluster containing all three desires.
        shared_cluster_id = uuid.uuid4()
        for desire_id in created_ids:
            db_session.add(
                DesireCluster(
                    cluster_id=shared_cluster_id,
                    desire_id=uuid.UUID(desire_id),
                    similarity=0.88,
                )
            )

        # Simulate heat recalculation on every clustered desire.
        for desire_id in created_ids:
            await db_session.execute(
                update(Desire)
                .where(Desire.id == uuid.UUID(desire_id))
                .values(heat_score=3.0)
            )
        await db_session.flush()

        # Each desire should now report elevated heat and a cluster of three.
        for desire_id in created_ids:
            resp = await client.get(f"/api/v1/desires/{desire_id}")
            assert resp.status_code == 200, f"GET {desire_id} failed: {resp.text}"
            data = resp.json()
            assert data["heat_score"] >= 3.0, (
                f"Desire {desire_id} heat_score={data['heat_score']}, expected >= 3.0"
            )
            assert data["cluster_count"] >= 3, (
                f"Desire {desire_id} cluster_count={data['cluster_count']}, expected >= 3"
            )

    @pytest.mark.asyncio
    async def test_lone_desire_has_default_heat(self, client: AsyncClient):
        """A single desire without clustering has heat_score=1.0 and cluster_count=0."""
        create_resp = await client.post(
            "/api/v1/desires",
            json={"prompt_text": "unique standalone art concept"},
        )
        assert create_resp.status_code == 201
        lone_id = create_resp.json()["id"]

        read_resp = await client.get(f"/api/v1/desires/{lone_id}")
        assert read_resp.status_code == 200
        body = read_resp.json()
        assert body["heat_score"] == 1.0, f"Expected heat_score=1.0, got {body['heat_score']}"
        assert body["cluster_count"] == 0, f"Expected cluster_count=0, got {body['cluster_count']}"

    @pytest.mark.asyncio
    async def test_desires_sorted_by_heat_descending(
        self, client: AsyncClient, db_session
    ):
        """When fetching desires, high-heat desires appear before low-heat ones.

        Uses individual GET since list_desires relies on PostgreSQL ANY().
        Verifies the ordering guarantee via direct heat_score comparison.
        """
        from app.models.models import Desire, DesireCluster

        # A "hot" desire: created via the API, then clustered + heated directly.
        hot_resp = await client.post(
            "/api/v1/desires",
            json={"prompt_text": "blazing hot fractal vortex"},
        )
        assert hot_resp.status_code == 201
        hot_id = hot_resp.json()["id"]

        db_session.add(
            DesireCluster(
                cluster_id=uuid.uuid4(),
                desire_id=uuid.UUID(hot_id),
                similarity=0.90,
            )
        )
        await db_session.execute(
            update(Desire)
            .where(Desire.id == uuid.UUID(hot_id))
            .values(heat_score=5.0)
        )
        await db_session.flush()

        # A "cold" desire: freshly created, never clustered.
        cold_resp = await client.post(
            "/api/v1/desires",
            json={"prompt_text": "calm minimal zen garden"},
        )
        assert cold_resp.status_code == 201
        cold_id = cold_resp.json()["id"]

        # Verify hot desire has higher heat than cold
        hot_data = (await client.get(f"/api/v1/desires/{hot_id}")).json()
        cold_data = (await client.get(f"/api/v1/desires/{cold_id}")).json()
        assert hot_data["heat_score"] > cold_data["heat_score"], (
            f"Hot ({hot_data['heat_score']}) should be > Cold ({cold_data['heat_score']})"
        )
# ── Fulfillment Scenario ─────────────────────────────────────
class TestFulfillmentScenario:
    """Prove desire fulfillment transitions status and links to a shader."""

    @staticmethod
    async def _insert_shader(db_session, title, glsl_code, status):
        """Insert a shader row directly into the test DB; return its UUID."""
        from app.models.models import Shader

        shader_id = uuid.uuid4()
        db_session.add(
            Shader(
                id=shader_id,
                title=title,
                glsl_code=glsl_code,
                status=status,
                author_id=None,
            )
        )
        await db_session.flush()
        return shader_id

    @pytest.mark.asyncio
    async def test_fulfill_desire_transitions_status(
        self, client: AsyncClient, db_session
    ):
        """Create desire, insert published shader, fulfill, verify status transition."""
        resp = await client.post(
            "/api/v1/desires",
            json={"prompt_text": "ethereal particle waterfall"},
        )
        assert resp.status_code == 201
        desire_id = resp.json()["id"]

        shader_id = await self._insert_shader(
            db_session,
            "Particle Waterfall",
            "void mainImage(out vec4 c, in vec2 f) { c = vec4(0); }",
            "published",
        )

        # Fulfill the desire
        resp = await client.post(
            f"/api/v1/desires/{desire_id}/fulfill",
            params={"shader_id": str(shader_id)},
        )
        assert resp.status_code == 200, f"Fulfill failed: {resp.text}"
        body = resp.json()
        assert body["status"] == "fulfilled"
        assert body["desire_id"] == desire_id
        assert body["shader_id"] == str(shader_id)

        # Read back: status flipped and shader linked.
        resp = await client.get(f"/api/v1/desires/{desire_id}")
        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "fulfilled"
        assert body["fulfilled_by_shader"] == str(shader_id)

    @pytest.mark.asyncio
    async def test_fulfill_requires_published_shader(
        self, client: AsyncClient, db_session
    ):
        """Fulfilling with a draft shader returns 400."""
        resp = await client.post(
            "/api/v1/desires",
            json={"prompt_text": "glitch art mosaic pattern"},
        )
        assert resp.status_code == 201
        desire_id = resp.json()["id"]

        shader_id = await self._insert_shader(
            db_session,
            "Draft Mosaic",
            "void mainImage(out vec4 c, in vec2 f) { c = vec4(1); }",
            "draft",
        )

        # Attempt fulfill — should fail
        resp = await client.post(
            f"/api/v1/desires/{desire_id}/fulfill",
            params={"shader_id": str(shader_id)},
        )
        assert resp.status_code == 400, f"Expected 400, got {resp.status_code}: {resp.text}"
        assert "published" in resp.json()["detail"].lower()

    @pytest.mark.asyncio
    async def test_fulfill_already_fulfilled_returns_400(
        self, client: AsyncClient, db_session
    ):
        """Fulfilling an already-fulfilled desire returns 400."""
        resp = await client.post(
            "/api/v1/desires",
            json={"prompt_text": "recursive mirror tunnel"},
        )
        assert resp.status_code == 201
        desire_id = resp.json()["id"]

        shader_id = await self._insert_shader(
            db_session,
            "Mirror Tunnel",
            "void mainImage(out vec4 c, in vec2 f) { c = vec4(0.5); }",
            "published",
        )

        fulfill_url = f"/api/v1/desires/{desire_id}/fulfill"
        fulfill_params = {"shader_id": str(shader_id)}

        # First fulfill — should succeed
        resp = await client.post(fulfill_url, params=fulfill_params)
        assert resp.status_code == 200

        # Second fulfill — should fail
        resp = await client.post(fulfill_url, params=fulfill_params)
        assert resp.status_code == 400, f"Expected 400, got {resp.status_code}: {resp.text}"
        assert "not open" in resp.json()["detail"].lower()
# ── MCP Field Passthrough (source-level) ─────────────────────
class TestMCPFieldPassthrough:
"""Verify MCP server tools pass through all required fields via source inspection.
The MCP server runs as a separate process and can't be tested through
FastAPI TestClient. These tests verify the source code structure to ensure
field passthrough is correct.
"""
@classmethod
def _read_mcp_server_source(cls) -> str:
"""Read the MCP server source file."""
# From services/api/tests/ → up 3 to services/ → mcp/server.py
mcp_path = Path(__file__).resolve().parent.parent.parent / "mcp" / "server.py"
assert mcp_path.exists(), f"MCP server.py not found at {mcp_path}"
return mcp_path.read_text()
def test_get_desire_queue_includes_cluster_fields(self):
"""get_desire_queue maps cluster_count, heat_score, style_hints, fulfilled_by_shader."""
source = self._read_mcp_server_source()
# Verify get_desire_queue function exists
assert "async def get_desire_queue" in source, "get_desire_queue function not found"
# Extract the function body (from def to next @mcp or end)
fn_start = source.index("async def get_desire_queue")
# Find next top-level decorator or end of file
next_decorator = source.find("\n@mcp.", fn_start + 1)
if next_decorator == -1:
fn_body = source[fn_start:]
else:
fn_body = source[fn_start:next_decorator]
required_fields = ["cluster_count", "heat_score", "style_hints", "fulfilled_by_shader"]
for field in required_fields:
assert field in fn_body, (
f"get_desire_queue missing field '{field}' in response mapping"
)
def test_fulfill_desire_tool_exists(self):
"""fulfill_desire function exists and uses api_post_with_params."""
source = self._read_mcp_server_source()
assert "async def fulfill_desire" in source, "fulfill_desire function not found"
# Extract function body
fn_start = source.index("async def fulfill_desire")
next_decorator = source.find("\n@mcp.", fn_start + 1)
if next_decorator == -1:
fn_body = source[fn_start:]
else:
fn_body = source[fn_start:next_decorator]
assert "api_post_with_params" in fn_body, (
"fulfill_desire should call api_post_with_params"
)
def test_fulfill_desire_returns_structured_response(self):
"""fulfill_desire returns JSON with status, desire_id, shader_id."""
source = self._read_mcp_server_source()
fn_start = source.index("async def fulfill_desire")
next_decorator = source.find("\n@mcp.", fn_start + 1)
if next_decorator == -1:
fn_body = source[fn_start:]
else:
fn_body = source[fn_start:next_decorator]
# Check the success-path return contains the required fields
required_keys = ['"status"', '"desire_id"', '"shader_id"']
for key in required_keys:
assert key in fn_body, (
f"fulfill_desire response missing key {key}"
)
def test_submit_shader_accepts_fulfills_desire_id(self):
"""submit_shader accepts fulfills_desire_id parameter and passes it to the API."""
source = self._read_mcp_server_source()
assert "async def submit_shader" in source, "submit_shader function not found"
fn_start = source.index("async def submit_shader")
next_decorator = source.find("\n@mcp.", fn_start + 1)
if next_decorator == -1:
fn_body = source[fn_start:]
else:
fn_body = source[fn_start:next_decorator]
# Verify parameter exists in function signature
assert "fulfills_desire_id" in fn_body, (
"submit_shader should accept fulfills_desire_id parameter"
)
# Verify it's passed to the payload
assert 'payload["fulfills_desire_id"]' in fn_body or \
'"fulfills_desire_id"' in fn_body, (
"submit_shader should include fulfills_desire_id in the API payload"
)

View file

@ -46,6 +46,11 @@ export default function Bounties() {
<span className="flex items-center gap-1">
🔥 Heat: {desire.heat_score.toFixed(1)}
</span>
{desire.cluster_count > 1 && (
<span className="text-purple-400">
👥 {desire.cluster_count} similar
</span>
)}
{desire.tip_amount_cents > 0 && (
<span className="text-green-400">
💰 ${(desire.tip_amount_cents / 100).toFixed(2)} tip

View file

@ -49,6 +49,11 @@ export default function BountyDetail() {
<h1 className="text-xl font-bold">{desire.prompt_text}</h1>
<div className="flex items-center gap-3 mt-3 text-sm text-gray-500">
<span>🔥 Heat: {desire.heat_score.toFixed(1)}</span>
{desire.cluster_count > 1 && (
<span className="text-purple-400">
👥 {desire.cluster_count} similar
</span>
)}
{desire.tip_amount_cents > 0 && (
<span className="text-green-400">
💰 ${(desire.tip_amount_cents / 100).toFixed(2)} bounty
@ -77,7 +82,7 @@ export default function BountyDetail() {
{desire.status === 'open' && (
<div className="mt-6 pt-4 border-t border-surface-3">
<Link to="/editor" className="btn-primary">
<Link to={`/editor?fulfill=${desire.id}`} className="btn-primary">
Fulfill this Desire
</Link>
<p className="text-xs text-gray-500 mt-2">

View file

@ -9,7 +9,7 @@
*/
import { useState, useEffect, useCallback, useRef } from 'react';
import { useParams, useNavigate } from 'react-router-dom';
import { useParams, useNavigate, useSearchParams } from 'react-router-dom';
import { useQuery } from '@tanstack/react-query';
import api from '@/lib/api';
import { useAuthStore } from '@/stores/auth';
@ -38,8 +38,13 @@ void mainImage(out vec4 fragColor, in vec2 fragCoord) {
export default function Editor() {
const { id } = useParams<{ id: string }>();
const navigate = useNavigate();
const [searchParams] = useSearchParams();
const { isAuthenticated, user } = useAuthStore();
// Fulfillment context — read once from URL, persist in ref so it survives navigation
const fulfillId = searchParams.get('fulfill');
const fulfillDesireId = useRef(fulfillId);
const [code, setCode] = useState(DEFAULT_SHADER);
const [liveCode, setLiveCode] = useState(DEFAULT_SHADER);
const [title, setTitle] = useState('Untitled Shader');
@ -70,6 +75,16 @@ export default function Editor() {
enabled: !!id,
});
// Fetch desire context when fulfilling
const { data: fulfillDesire } = useQuery({
queryKey: ['desire', fulfillDesireId.current],
queryFn: async () => {
const { data } = await api.get(`/desires/${fulfillDesireId.current}`);
return data;
},
enabled: !!fulfillDesireId.current,
});
useEffect(() => {
if (existingShader) {
setCode(existingShader.glsl_code);
@ -160,7 +175,10 @@ export default function Editor() {
}
} else {
// Create new shader
const { data } = await api.post('/shaders', payload);
const { data } = await api.post('/shaders', {
...payload,
fulfills_desire_id: fulfillDesireId.current || undefined,
});
if (publishStatus === 'published') {
navigate(`/shader/${data.id}`);
} else {
@ -278,6 +296,17 @@ export default function Editor() {
</div>
)}
{/* Desire fulfillment context banner */}
{fulfillDesire && (
<div className="px-4 py-3 bg-amber-600/10 border-b border-amber-600/20 flex items-center gap-3">
<span className="text-amber-400 text-sm font-medium">🎯 Fulfilling desire:</span>
<span className="text-gray-300 text-sm flex-1">{fulfillDesire.prompt_text}</span>
{fulfillDesire.style_hints && (
<span className="text-xs text-gray-500">Style hints available</span>
)}
</div>
)}
{/* Split pane: editor + drag handle + preview */}
<div ref={containerRef} className="flex-1 flex min-h-0">
{/* Code editor */}

View file

@ -41,6 +41,14 @@ async def api_post(path: str, data: dict):
return resp.json()
async def api_post_with_params(path: str, params: dict):
    """POST with query parameters (not JSON body). Used for endpoints like fulfill.

    Args:
        path: API path relative to /api/v1 (e.g. "/desires/{id}/fulfill").
        params: Query-string parameters to send with the POST.

    Returns:
        Parsed JSON body of the response.

    Raises:
        httpx.HTTPStatusError: via raise_for_status() on 4xx/5xx responses.
    """
    # Short-lived client per call; INTERNAL_AUTH headers authenticate this server to the API.
    async with httpx.AsyncClient(base_url=API_BASE, timeout=15.0) as client:
        resp = await client.post(f"/api/v1{path}", params=params, headers=INTERNAL_AUTH)
        resp.raise_for_status()
        return resp.json()
async def api_put(path: str, data: dict):
async with httpx.AsyncClient(base_url=API_BASE, timeout=15.0) as client:
resp = await client.put(f"/api/v1{path}", json=data, headers=INTERNAL_AUTH)
@ -121,7 +129,8 @@ async def get_shader_version_code(shader_id: str, version_number: int) -> str:
@mcp.tool()
async def submit_shader(title: str, glsl_code: str, description: str = "", tags: str = "",
shader_type: str = "2d", status: str = "published") -> str:
shader_type: str = "2d", status: str = "published",
fulfills_desire_id: str = "") -> str:
"""Submit a new GLSL shader to Fractafrag.
Shader format: void mainImage(out vec4 fragColor, in vec2 fragCoord)
@ -134,11 +143,15 @@ async def submit_shader(title: str, glsl_code: str, description: str = "", tags:
tags: Comma-separated tags (e.g. "fractal,noise,colorful")
shader_type: 2d, 3d, or audio-reactive
status: "published" to go live, "draft" to save privately
fulfills_desire_id: Optional UUID of a desire this shader fulfills
"""
tag_list = [t.strip() for t in tags.split(",") if t.strip()] if tags else []
result = await api_post("/shaders", {"title": title, "glsl_code": glsl_code,
"description": description, "tags": tag_list,
"shader_type": shader_type, "status": status})
payload = {"title": title, "glsl_code": glsl_code,
"description": description, "tags": tag_list,
"shader_type": shader_type, "status": status}
if fulfills_desire_id:
payload["fulfills_desire_id"] = fulfills_desire_id
result = await api_post("/shaders", payload)
return json.dumps({"id": result["id"], "title": result["title"],
"status": result.get("status"), "current_version": result.get("current_version", 1),
"message": f"Shader '{result['title']}' created.", "url": f"/shader/{result['id']}"})
@ -210,7 +223,11 @@ async def get_similar_shaders(shader_id: str, limit: int = 10) -> str:
@mcp.tool()
async def get_desire_queue(min_heat: float = 0, limit: int = 10) -> str:
"""Get open shader desires/bounties. These are community requests.
"""Get open shader desires/bounties with cluster context and style hints.
Returns community requests ranked by heat. Use cluster_count to identify
high-demand desires (many similar requests). Use style_hints to understand
the visual direction requested.
Args:
min_heat: Minimum heat score (higher = more demand)
@ -220,11 +237,42 @@ async def get_desire_queue(min_heat: float = 0, limit: int = 10) -> str:
return json.dumps({"count": len(desires),
"desires": [{"id": d["id"], "prompt_text": d["prompt_text"],
"heat_score": d.get("heat_score", 0),
"cluster_count": d.get("cluster_count", 0),
"style_hints": d.get("style_hints"),
"tip_amount_cents": d.get("tip_amount_cents", 0),
"status": d.get("status")}
"status": d.get("status"),
"fulfilled_by_shader": d.get("fulfilled_by_shader")}
for d in desires]})
@mcp.tool()
async def fulfill_desire(desire_id: str, shader_id: str) -> str:
    """Mark a desire as fulfilled by linking it to a published shader.

    The shader must be published. The desire must be open.
    Use get_desire_queue to find open desires, then submit_shader or
    use an existing shader ID to fulfill one.

    Args:
        desire_id: UUID of the desire to fulfill
        shader_id: UUID of the published shader that fulfills this desire

    Returns:
        JSON string: on success {"status", "desire_id", "shader_id", "message"};
        on API rejection {"error", "status_code"}.
    """
    try:
        # The fulfill endpoint takes shader_id as a query parameter, not a JSON
        # body. The response body is not needed — the echo below is built from
        # our own arguments (the old `result = ...` binding was unused).
        await api_post_with_params(
            f"/desires/{desire_id}/fulfill",
            {"shader_id": shader_id}
        )
        return json.dumps({"status": "fulfilled", "desire_id": desire_id,
                           "shader_id": shader_id,
                           "message": f"Desire {desire_id} fulfilled by shader {shader_id}."})
    except httpx.HTTPStatusError as e:
        # Surface the API's error detail (e.g. shader not published / desire
        # not open) instead of a raw traceback so the calling agent can react.
        try:
            error_detail = e.response.json().get("detail", str(e))
        except Exception:
            error_detail = str(e)
        return json.dumps({"error": error_detail, "status_code": e.response.status_code})
@mcp.resource("fractafrag://platform-info")
def platform_info() -> str:
"""Platform overview and shader writing guidelines."""