"""Tests for the preset system — loader, GET /engine/presets, and preset-driven trace."""

import json

import cv2
import numpy as np
import pytest
from fastapi.testclient import TestClient

from main import app
from presets.loader import all_presets, get_preset, preset_names, reload, resolve_params

client = TestClient(app)


def _make_test_png(width: int = 100, height: int = 100) -> bytes:
    """Create a test PNG: a filled white rectangle on a black background.

    The rectangle spans from 1/5 to 4/5 of the image in each dimension, so it
    stays inside the canvas for any requested size. At the default 100x100
    this is the rectangle from (20, 20) to (80, 80).

    Args:
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        The PNG-encoded image bytes.

    Raises:
        RuntimeError: If OpenCV fails to encode the image.
    """
    img = np.zeros((height, width, 3), dtype=np.uint8)
    # Scale the rectangle to the canvas instead of hard-coding (20, 20)-(80, 80),
    # which would be clipped (or fall entirely off-canvas) for images smaller
    # than 100x100. OpenCV points are (x, y).
    top_left = (width // 5, height // 5)
    bottom_right = (width * 4 // 5, height * 4 // 5)
    # Negative thickness tells OpenCV to fill the rectangle.
    cv2.rectangle(img, top_left, bottom_right, (255, 255, 255), -1)
    ok, buf = cv2.imencode(".png", img)
    if not ok:
        raise RuntimeError("cv2.imencode failed to encode the test PNG")
    return buf.tobytes()


@pytest.fixture
def test_png() -> bytes:
    """PNG bytes of the default 100x100 test image."""
    return _make_test_png()


@pytest.fixture
def fresh_presets() -> None:
    """Call ``presets.loader.reload()`` before a test.

    Applied to the loader/resolution test classes via ``usefixtures``. Every
    test in those classes then starts from a freshly reloaded preset registry,
    and none can forget to reload. Presumably ``reload()`` re-reads the preset
    definitions; confirm against presets.loader.
    """
    reload()


# -----------------------------------------------------------------------
# Preset Loader Unit Tests
# -----------------------------------------------------------------------


@pytest.mark.usefixtures("fresh_presets")
class TestPresetLoader:
    """Tests for loading presets and looking them up by name."""

    def test_all_presets_loads_five(self):
        presets = all_presets()
        assert len(presets) == 5
        assert set(presets.keys()) == {"sign", "patch", "stencil", "detailed", "custom"}

    def test_preset_names_sorted(self):
        names = preset_names()
        assert names == sorted(names)
        assert "sign" in names
        assert "custom" in names

    def test_get_preset_returns_config(self):
        sign = get_preset("sign")
        assert sign is not None
        assert sign["name"] == "sign"
        assert "preprocessing" in sign
        assert "vectorization" in sign
        assert "postprocessing" in sign

    def test_get_preset_unknown_returns_none(self):
        assert get_preset("nonexistent") is None

    def test_sign_preset_has_aggressive_simplification(self):
        sign = get_preset("sign")
        assert sign["postprocessing"]["epsilon"] > 1.0
        assert sign["preprocessing"]["morph_kernel_size"] >= 5

    def test_detailed_preset_has_low_simplification(self):
        detailed = get_preset("detailed")
        assert detailed["postprocessing"]["epsilon"] < 1.0
        assert detailed["vectorization"]["potrace"]["turdsize"] <= 2

    def test_stencil_preset_has_manual_threshold(self):
        stencil = get_preset("stencil")
        assert stencil["preprocessing"]["threshold_manual"] is not None

    def test_custom_preset_has_empty_params(self):
        custom = get_preset("custom")
        assert custom["preprocessing"] == {}
        assert custom["vectorization"]["potrace"] == {}

    def test_each_preset_has_description(self):
        for name, config in all_presets().items():
            assert "description" in config, f"Preset {name} missing description"
            assert len(config["description"]) > 0


# -----------------------------------------------------------------------
# Preset Resolution Tests
# -----------------------------------------------------------------------


@pytest.mark.usefixtures("fresh_presets")
class TestPresetResolution:
    """Tests for resolve_params merging logic."""

    def test_resolve_uses_preset_defaults(self):
        resolved = resolve_params("sign")
        assert resolved["vectorization_mode"] == "potrace"
        assert resolved["postprocessing"]["epsilon"] == 2.5
        assert resolved["preprocessing"]["morph_kernel_size"] == 5

    def test_resolve_user_override_epsilon(self):
        resolved = resolve_params("sign", {"epsilon": 0.5})
        assert resolved["postprocessing"]["epsilon"] == 0.5

    def test_resolve_user_override_mode(self):
        resolved = resolve_params("sign", {"mode": "vtracer"})
        assert resolved["vectorization_mode"] == "vtracer"

    def test_resolve_user_override_vectorizer_param(self):
        resolved = resolve_params("sign", {"turdsize": 99})
        assert resolved["vectorizer_params"]["turdsize"] == 99

    def test_resolve_custom_preset_falls_through(self):
        resolved = resolve_params("custom", {"epsilon": 3.0, "turdsize": 5})
        assert resolved["postprocessing"]["epsilon"] == 3.0

    def test_resolve_unknown_preset_uses_user_params(self):
        resolved = resolve_params("nonexistent", {"mode": "vtracer", "epsilon": 2.0})
        assert resolved["vectorization_mode"] == "vtracer"
        assert resolved["postprocessing"]["epsilon"] == 2.0


#
----------------------------------------------------------------------- # GET /engine/presets Endpoint Tests # ----------------------------------------------------------------------- class TestPresetsEndpoint: """Tests for the GET /engine/presets endpoint.""" def test_get_presets_returns_all(self): resp = client.get("/engine/presets") assert resp.status_code == 200 body = resp.json() assert "presets" in body presets = body["presets"] assert len(presets) == 5 assert "sign" in presets assert "patch" in presets assert "stencil" in presets assert "detailed" in presets assert "custom" in presets def test_preset_structure(self): resp = client.get("/engine/presets") body = resp.json() for name, config in body["presets"].items(): assert "name" in config assert "description" in config assert "preprocessing" in config assert "vectorization" in config assert "postprocessing" in config # ----------------------------------------------------------------------- # Preset-Driven Trace Endpoint Tests # ----------------------------------------------------------------------- class TestPresetTrace: """Tests for /engine/trace with preset selection.""" def test_trace_with_sign_preset(self, test_png): resp = client.post( "/engine/trace", files={"file": ("test.png", test_png, "image/png")}, data={"preset": "sign"}, ) assert resp.status_code == 200 body = resp.json() assert body["format"] == "svg" assert "