diff --git a/engine/pipeline/vectorize.py b/engine/pipeline/vectorize.py index 4158718..04c3c40 100644 --- a/engine/pipeline/vectorize.py +++ b/engine/pipeline/vectorize.py @@ -1,7 +1,9 @@ """Vectorization pipeline — converts preprocessed binary images to SVG.""" +import cv2 import numpy as np import potrace +import vtracer def potrace_trace( @@ -43,6 +45,70 @@ def potrace_trace( return _path_to_svg(path, w, h) +def vtracer_trace( + img: np.ndarray, + colormode: str = "binary", + hierarchical: str = "stacked", + filter_speckle: int = 4, + color_precision: int = 6, + layer_difference: int = 16, + corner_threshold: int = 60, + length_threshold: float = 4.0, + splice_threshold: int = 45, + mode: str = "spline", + path_precision: int | None = None, + max_iterations: int = 10, +) -> str: + """Trace an image using VTracer and return an SVG string. + + Unlike potrace_trace, this accepts both grayscale and color images. + Internally encodes the image as PNG and passes it to VTracer's Rust backend. + + Args: + img: 2D (grayscale) or 3D (BGR/BGRA) numpy array. + colormode: 'color' or 'binary'. + hierarchical: 'stacked' or 'cutout'. + filter_speckle: Remove patches smaller than this area (in px). + color_precision: Number of significant bits for color quantization (1-8). + layer_difference: Delta threshold for color layer grouping. + corner_threshold: Angle (degrees) below which a point is a corner. + length_threshold: Minimum segment length before simplification. + splice_threshold: Angle (degrees) for splicing splines. + mode: 'spline', 'polygon', or 'none' — curve fitting strategy. + path_precision: Decimal precision for path coordinates. + max_iterations: Max curve-fitting iterations. + + Returns: + SVG string (includes XML declaration and generator comment). + """ + if img.ndim not in (2, 3): + raise ValueError(f"Expected 2D or 3D image, got {img.ndim}D (shape {img.shape})") + + # Encode as PNG — VTracer accepts raw image bytes via convert_raw_image_to_svg. + ok, buf = cv2.imencode(".png", img) + if not ok: + raise RuntimeError("Failed to encode image as PNG for VTracer") + + # Build kwargs, omitting None values so VTracer uses its defaults. + kwargs: dict = dict( + img_format="png", + colormode=colormode, + hierarchical=hierarchical, + filter_speckle=filter_speckle, + color_precision=color_precision, + layer_difference=layer_difference, + corner_threshold=corner_threshold, + length_threshold=length_threshold, + splice_threshold=splice_threshold, + mode=mode, + max_iterations=max_iterations, + ) + if path_precision is not None: + kwargs["path_precision"] = path_precision + + return vtracer.convert_raw_image_to_svg(buf.tobytes(), **kwargs) + + def _path_to_svg(path, width: int, height: int) -> str: """Convert a potrace Path object to an SVG string.""" parts = [] diff --git a/engine/tests/test_vectorize.py b/engine/tests/test_vectorize.py index f55cc71..a8994f0 100644 --- a/engine/tests/test_vectorize.py +++ b/engine/tests/test_vectorize.py @@ -1,11 +1,11 @@ -"""Tests for the Potrace vectorization module.""" +"""Tests for the vectorization module (Potrace + VTracer).""" import xml.etree.ElementTree as ET import numpy as np import pytest -from pipeline.vectorize import potrace_trace +from pipeline.vectorize import potrace_trace, vtracer_trace # --------------------------------------------------------------------------- @@ -212,3 +212,181 @@ class TestPotracePreprocessingIntegration: ns = {"svg": "http://www.w3.org/2000/svg"} d = root.find("svg:path", ns).get("d", "") assert "M" in d + + +# =========================================================================== +# VTracer tests +# =========================================================================== + +def _make_color_image(size: int = 100) -> np.ndarray: + """Create a BGR image with a colored rectangle.""" + img = np.zeros((size, size, 3), dtype=np.uint8) + img[20:80, 20:80] = [0, 0, 255] # Red rectangle in BGR + return img + + +class TestVtracerBasicOutput: + def test_returns_string(self): + svg = vtracer_trace(_make_square()) + assert isinstance(svg, str) + + def test_svg_is_well_formed_xml(self): + svg = vtracer_trace(_make_square()) + root = _parse_svg(svg) + assert root.tag == "{http://www.w3.org/2000/svg}svg" + + def test_svg_has_correct_dimensions(self): + svg = vtracer_trace(_make_square(size=200)) + root = _parse_svg(svg) + assert root.get("width") == "200" + assert root.get("height") == "200" + + def test_svg_contains_path_element(self): + svg = vtracer_trace(_make_square()) + root = _parse_svg(svg) + ns = {"svg": "http://www.w3.org/2000/svg"} + paths = root.findall("svg:path", ns) + assert len(paths) >= 1 + + def test_path_has_d_attribute(self): + svg = vtracer_trace(_make_square()) + root = _parse_svg(svg) + ns = {"svg": "http://www.w3.org/2000/svg"} + path_el = root.find("svg:path", ns) + d = path_el.get("d", "") + assert len(d) > 0 + + +class TestVtracerInputTypes: + def test_accepts_grayscale_2d(self): + """VTracer should accept 2D grayscale images.""" + img = _make_square() + svg = vtracer_trace(img) + root = _parse_svg(svg) + assert root.tag == "{http://www.w3.org/2000/svg}svg" + + def test_accepts_bgr_3d(self): + """VTracer should accept 3-channel BGR images.""" + img = _make_color_image() + svg = vtracer_trace(img) + root = _parse_svg(svg) + assert root.tag == "{http://www.w3.org/2000/svg}svg" + + def test_accepts_bgra_4_channel(self): + """VTracer should accept 4-channel BGRA images.""" + import cv2 + bgr = _make_color_image() + bgra = cv2.cvtColor(bgr, cv2.COLOR_BGR2BGRA) + svg = vtracer_trace(bgra) + root = _parse_svg(svg) + assert root.tag == "{http://www.w3.org/2000/svg}svg" + + def test_rejects_4d_input(self): + with pytest.raises(ValueError, match="2D or 3D"): + vtracer_trace(np.zeros((10, 10, 3, 2), dtype=np.uint8)) + + def test_rectangular_image(self): + img = np.zeros((60, 120), dtype=np.uint8) + img[10:50, 10:110] = 255 + svg = vtracer_trace(img) + root = _parse_svg(svg) + assert root.get("width") == "120" + assert root.get("height") == "60" + + +class TestVtracerParams: + def test_color_mode(self): + """Color mode should produce SVG with color fill attributes.""" + img = _make_color_image() + svg = vtracer_trace(img, colormode="color") + root = _parse_svg(svg) + ns = {"svg": "http://www.w3.org/2000/svg"} + paths = root.findall("svg:path", ns) + assert len(paths) >= 1 + + def test_filter_speckle_removes_noise(self): + """High filter_speckle should remove tiny features.""" + img = np.zeros((100, 100), dtype=np.uint8) + img[50, 50] = 255 # single pixel + svg_strict = vtracer_trace(img, filter_speckle=100) + svg_loose = vtracer_trace(img, filter_speckle=0) + # Strict filtering should produce fewer or no path elements + root_strict = _parse_svg(svg_strict) + root_loose = _parse_svg(svg_loose) + ns = {"svg": "http://www.w3.org/2000/svg"} + paths_strict = root_strict.findall("svg:path", ns) + paths_loose = root_loose.findall("svg:path", ns) + # Either fewer paths or shorter path data with strict filtering + d_strict = "".join(p.get("d", "") for p in paths_strict) + d_loose = "".join(p.get("d", "") for p in paths_loose) + assert len(d_strict) <= len(d_loose) + + def test_polygon_mode(self): + """Polygon mode should produce L commands instead of curves.""" + svg = vtracer_trace(_make_square(), mode="polygon") + root = _parse_svg(svg) + ns = {"svg": "http://www.w3.org/2000/svg"} + path_el = root.find("svg:path", ns) + d = path_el.get("d", "") + assert "L" in d or "l" in d + + def test_spline_mode_default(self): + """Default spline mode should produce C (curve) commands for curved shapes.""" + svg = vtracer_trace(_make_circle(), mode="spline") + root = _parse_svg(svg) + ns = {"svg": "http://www.w3.org/2000/svg"} + path_el = root.find("svg:path", ns) + d = path_el.get("d", "") + assert "C" in d or "c" in d + + +class TestVtracerEdgeCases: + def test_all_black_image(self): + img = np.zeros((50, 50), dtype=np.uint8) + svg = vtracer_trace(img) + root = _parse_svg(svg) + assert root.tag == "{http://www.w3.org/2000/svg}svg" + + def test_all_white_image(self): + img = np.ones((50, 50), dtype=np.uint8) * 255 + svg = vtracer_trace(img) + root = _parse_svg(svg) + assert root.tag == "{http://www.w3.org/2000/svg}svg" + + +class TestVtracerPreprocessingIntegration: + def test_accepts_thresholded_image(self): + """Output of preprocessing pipeline is valid input for vtracer_trace.""" + from pipeline.preprocessing import threshold, to_grayscale + import cv2 + + img = np.zeros((100, 100, 3), dtype=np.uint8) + cv2.rectangle(img, (20, 20), (80, 80), (255, 255, 255), -1) + gray = to_grayscale(img) + binary = threshold(gray) + svg = vtracer_trace(binary) + root = _parse_svg(svg) + assert root.tag == "{http://www.w3.org/2000/svg}svg" + ns = {"svg": "http://www.w3.org/2000/svg"} + paths = root.findall("svg:path", ns) + assert len(paths) >= 1 + + +class TestVtracerVsPotraceComparison: + def test_both_produce_valid_svg_from_same_input(self): + """Both backends should produce valid SVG from the same binary image.""" + img = _make_square() + svg_potrace = potrace_trace(img) + svg_vtracer = vtracer_trace(img) + root_p = _parse_svg(svg_potrace) + root_v = _parse_svg(svg_vtracer) + assert root_p.tag == "{http://www.w3.org/2000/svg}svg" + assert root_v.tag == "{http://www.w3.org/2000/svg}svg" + + def test_outputs_differ(self): + """Potrace and VTracer should produce different SVG for the same input.""" + img = _make_square() + svg_potrace = potrace_trace(img) + svg_vtracer = vtracer_trace(img) + # They use different algorithms so output should differ + assert svg_potrace != svg_vtracer