"""Tests for the vectorization module (Potrace + VTracer)."""

import xml.etree.ElementTree as ET

import numpy as np
import pytest

from pipeline.vectorize import potrace_trace, vtracer_trace

# Namespace map for ElementTree lookups, and the fully-qualified root tag that
# ElementTree reports (``{uri}local`` form) for a parsed SVG document.
SVG_NS = {"svg": "http://www.w3.org/2000/svg"}
SVG_TAG = "{http://www.w3.org/2000/svg}svg"


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_square(size: int = 100, x0: int = 20, x1: int = 80) -> np.ndarray:
    """Create a binary image with a filled white square on black background.

    The same ``[x0, x1)`` slice is applied to both rows and columns.
    """
    img = np.zeros((size, size), dtype=np.uint8)
    img[x0:x1, x0:x1] = 255
    return img


def _make_circle(size: int = 100, radius: int = 30) -> np.ndarray:
    """Create a binary image with a filled white circle centred in the frame."""
    img = np.zeros((size, size), dtype=np.uint8)
    cy, cx = size // 2, size // 2
    yy, xx = np.ogrid[:size, :size]
    mask = (xx - cx) ** 2 + (yy - cy) ** 2 < radius ** 2
    img[mask] = 255
    return img


def _parse_svg(svg_str: str) -> ET.Element:
    """Parse SVG string and return root element."""
    return ET.fromstring(svg_str)


def _paths(root: ET.Element) -> list:
    """Return all direct ``<path>`` children of an SVG root element."""
    return root.findall("svg:path", SVG_NS)


def _path_data(root: ET.Element) -> str:
    """Return the concatenated ``d`` attributes of all direct ``<path>`` children.

    Returns ``""`` when the document has no ``<path>`` elements, so tests that
    expect empty output fail (or pass) via their assertions instead of raising
    ``AttributeError`` on a missing element.
    """
    return "".join(p.get("d", "") for p in _paths(root))


# ---------------------------------------------------------------------------
# Basic output tests
# ---------------------------------------------------------------------------


class TestPotraceBasicOutput:
    def test_returns_string(self):
        svg = potrace_trace(_make_square())
        assert isinstance(svg, str)

    def test_svg_is_well_formed_xml(self):
        root = _parse_svg(potrace_trace(_make_square()))
        assert root.tag == SVG_TAG

    def test_svg_has_correct_dimensions(self):
        root = _parse_svg(potrace_trace(_make_square(size=200)))
        assert root.get("width") == "200"
        assert root.get("height") == "200"

    def test_svg_has_viewbox(self):
        root = _parse_svg(potrace_trace(_make_square(size=150)))
        assert root.get("viewBox") == "0 0 150 150"

    def test_svg_contains_path_element(self):
        root = _parse_svg(potrace_trace(_make_square()))
        assert len(_paths(root)) == 1

    def test_path_d_attribute_nonempty(self):
        root = _parse_svg(potrace_trace(_make_square()))
        d = _path_data(root)
        assert len(d) > 0
        assert "M" in d


# ---------------------------------------------------------------------------
# Tracing shapes
# ---------------------------------------------------------------------------


class TestPotraceShapes:
    def test_square_produces_corner_segments(self):
        """A sharp square with alphamax=0 should produce L commands (corners)."""
        root = _parse_svg(potrace_trace(_make_square(), alphamax=0.0))
        assert "L" in _path_data(root)

    def test_circle_produces_curve_segments(self):
        """A circle should produce C (cubic bezier) commands."""
        root = _parse_svg(potrace_trace(_make_circle()))
        assert "C" in _path_data(root)

    def test_path_closes_with_z(self):
        root = _parse_svg(potrace_trace(_make_square()))
        assert _path_data(root).rstrip().endswith("Z")


# ---------------------------------------------------------------------------
# Parameter tuning
# ---------------------------------------------------------------------------


class TestPotraceParams:
    def test_turdsize_removes_small_features(self):
        """High turdsize should remove a tiny speck, producing empty path data."""
        img = np.zeros((100, 100), dtype=np.uint8)
        img[50, 50] = 255  # single pixel speck
        root = _parse_svg(potrace_trace(img, turdsize=10))
        # A single pixel with turdsize=10 should be suppressed, so no subpath
        # (no moveto) may remain -- whether or not an empty <path> is emitted.
        assert "M" not in _path_data(root)

    def test_opticurve_off_vs_on(self):
        """Disabling opticurve should produce different (typically more verbose) output."""
        img = _make_circle()
        svg_on = potrace_trace(img, opticurve=True)
        svg_off = potrace_trace(img, opticurve=False)
        # They should differ (optimization reduces segments)
        assert svg_on != svg_off

    def test_alphamax_polygon_mode(self):
        """alphamax=0 forces polygon mode — output should contain L but no C commands."""
        root = _parse_svg(potrace_trace(_make_square(), alphamax=0.0))
        d = _path_data(root)
        assert "L" in d
        assert "C" not in d

    def test_default_params_produce_valid_svg(self):
        """Default parameters should produce valid SVG for a standard test image."""
        root = _parse_svg(potrace_trace(_make_square()))
        assert root.tag == SVG_TAG


# ---------------------------------------------------------------------------
# Edge cases & errors
# ---------------------------------------------------------------------------


class TestPotraceEdgeCases:
    def test_all_black_image(self):
        """An all-black (all-zero) image should produce valid SVG with empty path."""
        img = np.zeros((50, 50), dtype=np.uint8)
        root = _parse_svg(potrace_trace(img))
        assert root.tag == SVG_TAG
        # Nothing to trace: the docstring's "empty path" claim is now asserted.
        assert "M" not in _path_data(root)

    def test_all_white_image(self):
        """An all-white image should trace the entire frame."""
        img = np.full((50, 50), 255, dtype=np.uint8)
        root = _parse_svg(potrace_trace(img))
        assert "M" in _path_data(root)

    def test_rejects_3d_input(self):
        with pytest.raises(ValueError, match="2D"):
            potrace_trace(np.zeros((50, 50, 3), dtype=np.uint8))

    def test_rectangular_image(self):
        """Non-square images should work and set correct dimensions."""
        img = np.zeros((80, 120), dtype=np.uint8)
        img[10:70, 10:110] = 255
        root = _parse_svg(potrace_trace(img))
        assert root.get("width") == "120"
        assert root.get("height") == "80"

    def test_uint8_and_uint32_both_work(self):
        """Potrace should accept common numpy dtypes."""
        base = _make_square()
        root8 = _parse_svg(potrace_trace(base.astype(np.uint8)))
        root32 = _parse_svg(potrace_trace(base.astype(np.uint32)))
        # Comparing root tags alone passes even if the uint32 input traced to
        # nothing; require that both dtypes actually yield path data.
        assert root8.tag == root32.tag == SVG_TAG
        assert "M" in _path_data(root8)
        assert "M" in _path_data(root32)


# ---------------------------------------------------------------------------
# Integration with preprocessing output
# ---------------------------------------------------------------------------


class TestPotracePreprocessingIntegration:
    def test_accepts_thresholded_image(self):
        """Output of preprocessing threshold() is a valid input for potrace_trace."""
        from pipeline.preprocessing import threshold, to_grayscale
        import cv2

        # Simulate a preprocessed image
        img = np.zeros((100, 100, 3), dtype=np.uint8)
        cv2.rectangle(img, (20, 20), (80, 80), (255, 255, 255), -1)
        binary = threshold(to_grayscale(img))

        root = _parse_svg(potrace_trace(binary))
        assert root.tag == SVG_TAG
        assert "M" in _path_data(root)


# ===========================================================================
# VTracer tests
# ===========================================================================


def _make_color_image(size: int = 100) -> np.ndarray:
    """Create a BGR image with a colored rectangle."""
    img = np.zeros((size, size, 3), dtype=np.uint8)
    img[20:80, 20:80] = [0, 0, 255]  # Red rectangle in BGR
    return img


class TestVtracerBasicOutput:
    def test_returns_string(self):
        svg = vtracer_trace(_make_square())
        assert isinstance(svg, str)

    def test_svg_is_well_formed_xml(self):
        root = _parse_svg(vtracer_trace(_make_square()))
        assert root.tag == SVG_TAG

    def test_svg_has_correct_dimensions(self):
        root = _parse_svg(vtracer_trace(_make_square(size=200)))
        assert root.get("width") == "200"
        assert root.get("height") == "200"

    def test_svg_contains_path_element(self):
        root = _parse_svg(vtracer_trace(_make_square()))
        assert len(_paths(root)) >= 1

    def test_path_has_d_attribute(self):
        root = _parse_svg(vtracer_trace(_make_square()))
        assert len(_path_data(root)) > 0


class TestVtracerInputTypes:
    def test_accepts_grayscale_2d(self):
        """VTracer should accept 2D grayscale images."""
        root = _parse_svg(vtracer_trace(_make_square()))
        assert root.tag == SVG_TAG

    def test_accepts_bgr_3d(self):
        """VTracer should accept 3-channel BGR images."""
        root = _parse_svg(vtracer_trace(_make_color_image()))
        assert root.tag == SVG_TAG

    def test_accepts_bgra_4_channel(self):
        """VTracer should accept 4-channel BGRA images."""
        import cv2

        bgra = cv2.cvtColor(_make_color_image(), cv2.COLOR_BGR2BGRA)
        root = _parse_svg(vtracer_trace(bgra))
        assert root.tag == SVG_TAG

    def test_rejects_4d_input(self):
        with pytest.raises(ValueError, match="2D or 3D"):
            vtracer_trace(np.zeros((10, 10, 3, 2), dtype=np.uint8))

    def test_rectangular_image(self):
        img = np.zeros((60, 120), dtype=np.uint8)
        img[10:50, 10:110] = 255
        root = _parse_svg(vtracer_trace(img))
        assert root.get("width") == "120"
        assert root.get("height") == "60"


class TestVtracerParams:
    def test_color_mode(self):
        """Color mode should produce SVG with color fill attributes."""
        root = _parse_svg(vtracer_trace(_make_color_image(), colormode="color"))
        paths = _paths(root)
        assert len(paths) >= 1
        # The docstring's promise: at least one path carries a fill colour.
        assert any(p.get("fill") for p in paths)

    def test_filter_speckle_removes_noise(self):
        """High filter_speckle should remove tiny features."""
        img = np.zeros((100, 100), dtype=np.uint8)
        img[50, 50] = 255  # single pixel
        root_strict = _parse_svg(vtracer_trace(img, filter_speckle=100))
        root_loose = _parse_svg(vtracer_trace(img, filter_speckle=0))
        # Strict filtering should yield fewer paths or shorter path data.
        assert len(_path_data(root_strict)) <= len(_path_data(root_loose))

    def test_polygon_mode(self):
        """Polygon mode should produce L commands instead of curves."""
        root = _parse_svg(vtracer_trace(_make_square(), mode="polygon"))
        d = _path_data(root)
        assert "L" in d or "l" in d

    def test_spline_mode_default(self):
        """Default spline mode should produce C (curve) commands for curved shapes."""
        root = _parse_svg(vtracer_trace(_make_circle(), mode="spline"))
        d = _path_data(root)
        assert "C" in d or "c" in d


class TestVtracerEdgeCases:
    def test_all_black_image(self):
        img = np.zeros((50, 50), dtype=np.uint8)
        root = _parse_svg(vtracer_trace(img))
        assert root.tag == SVG_TAG

    def test_all_white_image(self):
        img = np.full((50, 50), 255, dtype=np.uint8)
        root = _parse_svg(vtracer_trace(img))
        assert root.tag == SVG_TAG


class TestVtracerPreprocessingIntegration:
    def test_accepts_thresholded_image(self):
        """Output of preprocessing pipeline is valid input for vtracer_trace."""
        from pipeline.preprocessing import threshold, to_grayscale
        import cv2

        img = np.zeros((100, 100, 3), dtype=np.uint8)
        cv2.rectangle(img, (20, 20), (80, 80), (255, 255, 255), -1)
        binary = threshold(to_grayscale(img))

        root = _parse_svg(vtracer_trace(binary))
        assert root.tag == SVG_TAG
        assert len(_paths(root)) >= 1


class TestVtracerVsPotraceComparison:
    def test_both_produce_valid_svg_from_same_input(self):
        """Both backends should produce valid SVG from the same binary image."""
        img = _make_square()
        root_p = _parse_svg(potrace_trace(img))
        root_v = _parse_svg(vtracer_trace(img))
        assert root_p.tag == SVG_TAG
        assert root_v.tag == SVG_TAG

    def test_outputs_differ(self):
        """Potrace and VTracer should produce different SVG for the same input."""
        img = _make_square()
        # They use different algorithms so output should differ
        assert potrace_trace(img) != vtracer_trace(img)