test: Implemented vtracer_trace() function that converts grayscale or c…

- engine/pipeline/vectorize.py
- engine/tests/test_vectorize.py

GSD-Task: S01/T04
This commit is contained in:
jlightner 2026-03-26 04:18:31 +00:00
parent 136a9417f9
commit b33e883a6b
2 changed files with 246 additions and 2 deletions

View file

@ -1,7 +1,9 @@
"""Vectorization pipeline — converts preprocessed binary images to SVG."""
import cv2
import numpy as np
import potrace
import vtracer
def potrace_trace(
@ -43,6 +45,70 @@ def potrace_trace(
return _path_to_svg(path, w, h)
def vtracer_trace(
img: np.ndarray,
colormode: str = "binary",
hierarchical: str = "stacked",
filter_speckle: int = 4,
color_precision: int = 6,
layer_difference: int = 16,
corner_threshold: int = 60,
length_threshold: float = 4.0,
splice_threshold: int = 45,
mode: str = "spline",
path_precision: int | None = None,
max_iterations: int = 10,
) -> str:
"""Trace an image using VTracer and return an SVG string.
Unlike potrace_trace, this accepts both grayscale and color images.
Internally encodes the image as PNG and passes it to VTracer's Rust backend.
Args:
img: 2D (grayscale) or 3D (BGR/BGRA) numpy array.
colormode: 'color' or 'binary'.
hierarchical: 'stacked' or 'cutout'.
filter_speckle: Remove patches smaller than this area (in px).
color_precision: Number of significant bits for color quantization (1-8).
layer_difference: Delta threshold for color layer grouping.
corner_threshold: Angle (degrees) below which a point is a corner.
length_threshold: Minimum segment length before simplification.
splice_threshold: Angle (degrees) for splicing splines.
mode: 'spline', 'polygon', or 'none' curve fitting strategy.
path_precision: Decimal precision for path coordinates.
max_iterations: Max curve-fitting iterations.
Returns:
SVG string (includes XML declaration and generator comment).
"""
if img.ndim not in (2, 3):
raise ValueError(f"Expected 2D or 3D image, got {img.ndim}D (shape {img.shape})")
# Encode as PNG — VTracer accepts raw image bytes via convert_raw_image_to_svg.
ok, buf = cv2.imencode(".png", img)
if not ok:
raise RuntimeError("Failed to encode image as PNG for VTracer")
# Build kwargs, omitting None values so VTracer uses its defaults.
kwargs: dict = dict(
img_format="png",
colormode=colormode,
hierarchical=hierarchical,
filter_speckle=filter_speckle,
color_precision=color_precision,
layer_difference=layer_difference,
corner_threshold=corner_threshold,
length_threshold=length_threshold,
splice_threshold=splice_threshold,
mode=mode,
max_iterations=max_iterations,
)
if path_precision is not None:
kwargs["path_precision"] = path_precision
return vtracer.convert_raw_image_to_svg(buf.tobytes(), **kwargs)
def _path_to_svg(path, width: int, height: int) -> str:
"""Convert a potrace Path object to an SVG string."""
parts = []

View file

@ -1,11 +1,11 @@
"""Tests for the Potrace vectorization module."""
"""Tests for the vectorization module (Potrace + VTracer)."""
import xml.etree.ElementTree as ET
import numpy as np
import pytest
from pipeline.vectorize import potrace_trace
from pipeline.vectorize import potrace_trace, vtracer_trace
# ---------------------------------------------------------------------------
@ -212,3 +212,181 @@ class TestPotracePreprocessingIntegration:
ns = {"svg": "http://www.w3.org/2000/svg"}
d = root.find("svg:path", ns).get("d", "")
assert "M" in d
# ===========================================================================
# VTracer tests
# ===========================================================================
def _make_color_image(size: int = 100) -> np.ndarray:
"""Create a BGR image with a colored rectangle."""
img = np.zeros((size, size, 3), dtype=np.uint8)
img[20:80, 20:80] = [0, 0, 255] # Red rectangle in BGR
return img
class TestVtracerBasicOutput:
def test_returns_string(self):
svg = vtracer_trace(_make_square())
assert isinstance(svg, str)
def test_svg_is_well_formed_xml(self):
svg = vtracer_trace(_make_square())
root = _parse_svg(svg)
assert root.tag == "{http://www.w3.org/2000/svg}svg"
def test_svg_has_correct_dimensions(self):
svg = vtracer_trace(_make_square(size=200))
root = _parse_svg(svg)
assert root.get("width") == "200"
assert root.get("height") == "200"
def test_svg_contains_path_element(self):
svg = vtracer_trace(_make_square())
root = _parse_svg(svg)
ns = {"svg": "http://www.w3.org/2000/svg"}
paths = root.findall("svg:path", ns)
assert len(paths) >= 1
def test_path_has_d_attribute(self):
svg = vtracer_trace(_make_square())
root = _parse_svg(svg)
ns = {"svg": "http://www.w3.org/2000/svg"}
path_el = root.find("svg:path", ns)
d = path_el.get("d", "")
assert len(d) > 0
class TestVtracerInputTypes:
def test_accepts_grayscale_2d(self):
"""VTracer should accept 2D grayscale images."""
img = _make_square()
svg = vtracer_trace(img)
root = _parse_svg(svg)
assert root.tag == "{http://www.w3.org/2000/svg}svg"
def test_accepts_bgr_3d(self):
"""VTracer should accept 3-channel BGR images."""
img = _make_color_image()
svg = vtracer_trace(img)
root = _parse_svg(svg)
assert root.tag == "{http://www.w3.org/2000/svg}svg"
def test_accepts_bgra_4_channel(self):
"""VTracer should accept 4-channel BGRA images."""
import cv2
bgr = _make_color_image()
bgra = cv2.cvtColor(bgr, cv2.COLOR_BGR2BGRA)
svg = vtracer_trace(bgra)
root = _parse_svg(svg)
assert root.tag == "{http://www.w3.org/2000/svg}svg"
def test_rejects_4d_input(self):
with pytest.raises(ValueError, match="2D or 3D"):
vtracer_trace(np.zeros((10, 10, 3, 2), dtype=np.uint8))
def test_rectangular_image(self):
img = np.zeros((60, 120), dtype=np.uint8)
img[10:50, 10:110] = 255
svg = vtracer_trace(img)
root = _parse_svg(svg)
assert root.get("width") == "120"
assert root.get("height") == "60"
class TestVtracerParams:
def test_color_mode(self):
"""Color mode should produce SVG with color fill attributes."""
img = _make_color_image()
svg = vtracer_trace(img, colormode="color")
root = _parse_svg(svg)
ns = {"svg": "http://www.w3.org/2000/svg"}
paths = root.findall("svg:path", ns)
assert len(paths) >= 1
def test_filter_speckle_removes_noise(self):
"""High filter_speckle should remove tiny features."""
img = np.zeros((100, 100), dtype=np.uint8)
img[50, 50] = 255 # single pixel
svg_strict = vtracer_trace(img, filter_speckle=100)
svg_loose = vtracer_trace(img, filter_speckle=0)
# Strict filtering should produce fewer or no path elements
root_strict = _parse_svg(svg_strict)
root_loose = _parse_svg(svg_loose)
ns = {"svg": "http://www.w3.org/2000/svg"}
paths_strict = root_strict.findall("svg:path", ns)
paths_loose = root_loose.findall("svg:path", ns)
# Either fewer paths or shorter path data with strict filtering
d_strict = "".join(p.get("d", "") for p in paths_strict)
d_loose = "".join(p.get("d", "") for p in paths_loose)
assert len(d_strict) <= len(d_loose)
def test_polygon_mode(self):
"""Polygon mode should produce L commands instead of curves."""
svg = vtracer_trace(_make_square(), mode="polygon")
root = _parse_svg(svg)
ns = {"svg": "http://www.w3.org/2000/svg"}
path_el = root.find("svg:path", ns)
d = path_el.get("d", "")
assert "L" in d or "l" in d
def test_spline_mode_default(self):
"""Default spline mode should produce C (curve) commands for curved shapes."""
svg = vtracer_trace(_make_circle(), mode="spline")
root = _parse_svg(svg)
ns = {"svg": "http://www.w3.org/2000/svg"}
path_el = root.find("svg:path", ns)
d = path_el.get("d", "")
assert "C" in d or "c" in d
class TestVtracerEdgeCases:
def test_all_black_image(self):
img = np.zeros((50, 50), dtype=np.uint8)
svg = vtracer_trace(img)
root = _parse_svg(svg)
assert root.tag == "{http://www.w3.org/2000/svg}svg"
def test_all_white_image(self):
img = np.ones((50, 50), dtype=np.uint8) * 255
svg = vtracer_trace(img)
root = _parse_svg(svg)
assert root.tag == "{http://www.w3.org/2000/svg}svg"
class TestVtracerPreprocessingIntegration:
def test_accepts_thresholded_image(self):
"""Output of preprocessing pipeline is valid input for vtracer_trace."""
from pipeline.preprocessing import threshold, to_grayscale
import cv2
img = np.zeros((100, 100, 3), dtype=np.uint8)
cv2.rectangle(img, (20, 20), (80, 80), (255, 255, 255), -1)
gray = to_grayscale(img)
binary = threshold(gray)
svg = vtracer_trace(binary)
root = _parse_svg(svg)
assert root.tag == "{http://www.w3.org/2000/svg}svg"
ns = {"svg": "http://www.w3.org/2000/svg"}
paths = root.findall("svg:path", ns)
assert len(paths) >= 1
class TestVtracerVsPotraceComparison:
def test_both_produce_valid_svg_from_same_input(self):
"""Both backends should produce valid SVG from the same binary image."""
img = _make_square()
svg_potrace = potrace_trace(img)
svg_vtracer = vtracer_trace(img)
root_p = _parse_svg(svg_potrace)
root_v = _parse_svg(svg_vtracer)
assert root_p.tag == "{http://www.w3.org/2000/svg}svg"
assert root_v.tag == "{http://www.w3.org/2000/svg}svg"
def test_outputs_differ(self):
"""Potrace and VTracer should produce different SVG for the same input."""
img = _make_square()
svg_potrace = potrace_trace(img)
svg_vtracer = vtracer_trace(img)
# They use different algorithms so output should differ
assert svg_potrace != svg_vtracer