From 136a9417f97abb6162e66fbbb3a244b1342c4e87 Mon Sep 17 00:00:00 2001 From: jlightner Date: Thu, 26 Mar 2026 04:15:01 +0000 Subject: [PATCH] =?UTF-8?q?test:=20Implemented=20potrace=5Ftrace()=20funct?= =?UTF-8?q?ion=20that=20converts=20preprocessed=20b=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - engine/pipeline/vectorize.py - engine/tests/test_vectorize.py GSD-Task: S01/T03 --- engine/pipeline/vectorize.py | 74 ++++++++++++ engine/tests/test_vectorize.py | 214 +++++++++++++++++++++++++++++++++ 2 files changed, 288 insertions(+) create mode 100644 engine/pipeline/vectorize.py create mode 100644 engine/tests/test_vectorize.py diff --git a/engine/pipeline/vectorize.py b/engine/pipeline/vectorize.py new file mode 100644 index 0000000..4158718 --- /dev/null +++ b/engine/pipeline/vectorize.py @@ -0,0 +1,74 @@ +"""Vectorization pipeline — converts preprocessed binary images to SVG.""" + +import numpy as np +import potrace + + +def potrace_trace( + binary_img: np.ndarray, + turdsize: int = 2, + alphamax: float = 1.0, + opticurve: bool = True, + opttolerance: float = 0.2, +) -> str: + """Trace a binary image using Potrace and return an SVG string. + + Args: + binary_img: 2D numpy array — nonzero pixels are foreground. + turdsize: Despeckle threshold; curves with enclosed area below this are removed. + alphamax: Corner detection threshold (0.0 = polygon, 1.3333 = no corners). + opticurve: Whether to optimize curves by reducing Bezier segments. + opttolerance: Tolerance for curve optimization. + + Returns: + Well-formed SVG string. + """ + if binary_img.ndim != 2: + raise ValueError(f"Expected 2D binary image, got shape {binary_img.shape}") + + h, w = binary_img.shape + + # Potrace interprets nonzero pixels as foreground. + # Convert to uint32 — pypotrace needs values that fit in a C int. + data = (binary_img > 0).astype(np.uint32) + + bmp = potrace.Bitmap(data) + path = bmp.trace( + turdsize=turdsize, + alphamax=alphamax, + opticurve=int(opticurve), + opttolerance=opttolerance, + ) + + return _path_to_svg(path, w, h) + + +def _path_to_svg(path, width: int, height: int) -> str: + """Convert a potrace Path object to an SVG string.""" + parts = [] + for curve in path: + sx, sy = curve.start_point + parts.append(f"M {sx:.3f},{sy:.3f}") + for segment in curve.segments: + if segment.is_corner: + cx, cy = segment.c + ex, ey = segment.end_point + parts.append(f"L {cx:.3f},{cy:.3f} L {ex:.3f},{ey:.3f}") + else: + c1x, c1y = segment.c1 + c2x, c2y = segment.c2 + ex, ey = segment.end_point + parts.append( + f"C {c1x:.3f},{c1y:.3f} {c2x:.3f},{c2y:.3f} {ex:.3f},{ey:.3f}" + ) + parts.append("Z") + + d = " ".join(parts) + + return ( + f'' + f'' + f"" + ) diff --git a/engine/tests/test_vectorize.py b/engine/tests/test_vectorize.py new file mode 100644 index 0000000..f55cc71 --- /dev/null +++ b/engine/tests/test_vectorize.py @@ -0,0 +1,214 @@ +"""Tests for the Potrace vectorization module.""" + +import xml.etree.ElementTree as ET + +import numpy as np +import pytest + +from pipeline.vectorize import potrace_trace + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_square(size: int = 100, x0: int = 20, x1: int = 80) -> np.ndarray: + """Create a binary image with a filled white square on black background.""" + img = np.zeros((size, size), dtype=np.uint8) + img[x0:x1, x0:x1] = 255 + return img + + +def _make_circle(size: int = 100, radius: int = 30) -> np.ndarray: + """Create a binary image with a filled white circle.""" + img = np.zeros((size, size), dtype=np.uint8) + cy, cx = size // 2, size // 2 + Y, X = np.ogrid[:size, :size] + mask = (X - cx) ** 2 + (Y - cy) ** 2 < radius ** 2 + img[mask] = 255 + return img + + +def _parse_svg(svg_str: str) -> ET.Element: + """Parse SVG string and return root element.""" + return ET.fromstring(svg_str) + + +# --------------------------------------------------------------------------- +# Basic output tests +# --------------------------------------------------------------------------- + +class TestPotraceBasicOutput: + def test_returns_string(self): + svg = potrace_trace(_make_square()) + assert isinstance(svg, str) + + def test_svg_is_well_formed_xml(self): + svg = potrace_trace(_make_square()) + root = _parse_svg(svg) + assert root.tag == "{http://www.w3.org/2000/svg}svg" + + def test_svg_has_correct_dimensions(self): + svg = potrace_trace(_make_square(size=200)) + root = _parse_svg(svg) + assert root.get("width") == "200" + assert root.get("height") == "200" + + def test_svg_has_viewbox(self): + svg = potrace_trace(_make_square(size=150)) + root = _parse_svg(svg) + assert root.get("viewBox") == "0 0 150 150" + + def test_svg_contains_path_element(self): + svg = potrace_trace(_make_square()) + root = _parse_svg(svg) + ns = {"svg": "http://www.w3.org/2000/svg"} + paths = root.findall("svg:path", ns) + assert len(paths) == 1 + + def test_path_d_attribute_nonempty(self): + svg = potrace_trace(_make_square()) + root = _parse_svg(svg) + ns = {"svg": "http://www.w3.org/2000/svg"} + path_el = root.find("svg:path", ns) + d = path_el.get("d", "") + assert len(d) > 0 + assert "M" in d + + +# --------------------------------------------------------------------------- +# Tracing shapes +# --------------------------------------------------------------------------- + +class TestPotraceShapes: + def test_square_produces_corner_segments(self): + """A sharp square with alphamax=0 should produce L commands (corners).""" + svg = potrace_trace(_make_square(), alphamax=0.0) + root = _parse_svg(svg) + ns = {"svg": "http://www.w3.org/2000/svg"} + d = root.find("svg:path", ns).get("d", "") + assert "L" in d + + def test_circle_produces_curve_segments(self): + """A circle should produce C (cubic bezier) commands.""" + svg = potrace_trace(_make_circle()) + root = _parse_svg(svg) + ns = {"svg": "http://www.w3.org/2000/svg"} + d = root.find("svg:path", ns).get("d", "") + assert "C" in d + + def test_path_closes_with_z(self): + svg = potrace_trace(_make_square()) + root = _parse_svg(svg) + ns = {"svg": "http://www.w3.org/2000/svg"} + d = root.find("svg:path", ns).get("d", "") + assert d.rstrip().endswith("Z") + + +# --------------------------------------------------------------------------- +# Parameter tuning +# --------------------------------------------------------------------------- + +class TestPotraceParams: + def test_turdsize_removes_small_features(self): + """High turdsize should remove a tiny speck, producing empty path data.""" + img = np.zeros((100, 100), dtype=np.uint8) + img[50, 50] = 255 # single pixel speck + svg = potrace_trace(img, turdsize=10) + root = _parse_svg(svg) + ns = {"svg": "http://www.w3.org/2000/svg"} + d = root.find("svg:path", ns).get("d", "") + # A single pixel with turdsize=10 should be suppressed → empty or near-empty path + assert "M" not in d or d.strip() == "" + + def test_opticurve_off_vs_on(self): + """Disabling opticurve should produce different (typically more verbose) output.""" + img = _make_circle() + svg_on = potrace_trace(img, opticurve=True) + svg_off = potrace_trace(img, opticurve=False) + # They should differ (optimization reduces segments) + assert svg_on != svg_off + + def test_alphamax_polygon_mode(self): + """alphamax=0 forces polygon mode — output should contain L but no C commands.""" + svg = potrace_trace(_make_square(), alphamax=0.0) + root = _parse_svg(svg) + ns = {"svg": "http://www.w3.org/2000/svg"} + d = root.find("svg:path", ns).get("d", "") + assert "L" in d + assert "C" not in d + + def test_default_params_produce_valid_svg(self): + """Default parameters should produce valid SVG for a standard test image.""" + svg = potrace_trace(_make_square()) + root = _parse_svg(svg) + assert root.tag == "{http://www.w3.org/2000/svg}svg" + + +# --------------------------------------------------------------------------- +# Edge cases & errors +# --------------------------------------------------------------------------- + +class TestPotraceEdgeCases: + def test_all_black_image(self): + """An all-black (all-zero) image should produce valid SVG with empty path.""" + img = np.zeros((50, 50), dtype=np.uint8) + svg = potrace_trace(img) + root = _parse_svg(svg) + assert root.tag == "{http://www.w3.org/2000/svg}svg" + + def test_all_white_image(self): + """An all-white image should trace the entire frame.""" + img = np.ones((50, 50), dtype=np.uint8) * 255 + svg = potrace_trace(img) + root = _parse_svg(svg) + ns = {"svg": "http://www.w3.org/2000/svg"} + path_el = root.find("svg:path", ns) + d = path_el.get("d", "") + assert "M" in d + + def test_rejects_3d_input(self): + with pytest.raises(ValueError, match="2D"): + potrace_trace(np.zeros((50, 50, 3), dtype=np.uint8)) + + def test_rectangular_image(self): + """Non-square images should work and set correct dimensions.""" + img = np.zeros((80, 120), dtype=np.uint8) + img[10:70, 10:110] = 255 + svg = potrace_trace(img) + root = _parse_svg(svg) + assert root.get("width") == "120" + assert root.get("height") == "80" + + def test_uint8_and_uint32_both_work(self): + """Potrace should accept common numpy dtypes.""" + base = _make_square() + svg8 = potrace_trace(base.astype(np.uint8)) + svg32 = potrace_trace(base.astype(np.uint32)) + # Both should be valid SVG with the same structure + root8 = _parse_svg(svg8) + root32 = _parse_svg(svg32) + assert root8.tag == root32.tag + + +# --------------------------------------------------------------------------- +# Integration with preprocessing output +# --------------------------------------------------------------------------- + +class TestPotracePreprocessingIntegration: + def test_accepts_thresholded_image(self): + """Output of preprocessing threshold() is a valid input for potrace_trace.""" + from pipeline.preprocessing import threshold, to_grayscale + import cv2 + + # Simulate a preprocessed image + img = np.zeros((100, 100, 3), dtype=np.uint8) + cv2.rectangle(img, (20, 20), (80, 80), (255, 255, 255), -1) + gray = to_grayscale(img) + binary = threshold(gray) + svg = potrace_trace(binary) + root = _parse_svg(svg) + assert root.tag == "{http://www.w3.org/2000/svg}svg" + ns = {"svg": "http://www.w3.org/2000/svg"} + d = root.find("svg:path", ns).get("d", "") + assert "M" in d