"""Tests for the Potrace vectorization module."""

import xml.etree.ElementTree as ET

import numpy as np
import pytest

from pipeline.vectorize import potrace_trace

# XML namespace of the SVG documents produced by potrace_trace.
SVG_NS_URI = "http://www.w3.org/2000/svg"
# Prefix mapping for ElementTree's find()/findall() namespace lookups.
SVG_NS = {"svg": SVG_NS_URI}
# Fully-qualified tag ElementTree reports for the <svg> root element.
SVG_ROOT_TAG = f"{{{SVG_NS_URI}}}svg"


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_square(size: int = 100, x0: int = 20, x1: int = 80) -> np.ndarray:
    """Create a binary image with a filled white square on black background.

    ``x0`` (inclusive) and ``x1`` (exclusive) bound the square along both
    axes, so the default is a 60x60 square inside a 100x100 frame.
    """
    img = np.zeros((size, size), dtype=np.uint8)
    img[x0:x1, x0:x1] = 255
    return img


def _make_circle(size: int = 100, radius: int = 30) -> np.ndarray:
    """Create a binary image with a filled white circle centred in the frame."""
    img = np.zeros((size, size), dtype=np.uint8)
    cy, cx = size // 2, size // 2
    # Open grids broadcast to a (size, size) distance field without
    # materialising two full coordinate arrays.
    yy, xx = np.ogrid[:size, :size]
    mask = (xx - cx) ** 2 + (yy - cy) ** 2 < radius**2
    img[mask] = 255
    return img


def _parse_svg(svg_str: str) -> ET.Element:
    """Parse SVG string and return root element."""
    return ET.fromstring(svg_str)


def _path_d(root: ET.Element) -> str:
    """Return the ``d`` attribute of the first SVG ``<path>`` under *root*.

    Returns ``""`` when there is no ``<path>`` element or it lacks a ``d``
    attribute. This lets "no geometry" assertions fail (or pass) cleanly
    instead of crashing with ``AttributeError`` on ``None``.
    """
    path_el = root.find("svg:path", SVG_NS)
    return "" if path_el is None else path_el.get("d", "")


# ---------------------------------------------------------------------------
# Basic output tests
# ---------------------------------------------------------------------------


class TestPotraceBasicOutput:
    """Structural checks on the SVG document returned by potrace_trace."""

    def test_returns_string(self):
        svg = potrace_trace(_make_square())
        assert isinstance(svg, str)

    def test_svg_is_well_formed_xml(self):
        svg = potrace_trace(_make_square())
        root = _parse_svg(svg)
        assert root.tag == SVG_ROOT_TAG

    def test_svg_has_correct_dimensions(self):
        svg = potrace_trace(_make_square(size=200))
        root = _parse_svg(svg)
        assert root.get("width") == "200"
        assert root.get("height") == "200"

    def test_svg_has_viewbox(self):
        svg = potrace_trace(_make_square(size=150))
        root = _parse_svg(svg)
        assert root.get("viewBox") == "0 0 150 150"

    def test_svg_contains_path_element(self):
        svg = potrace_trace(_make_square())
        root = _parse_svg(svg)
        paths = root.findall("svg:path", SVG_NS)
        assert len(paths) == 1

    def test_path_d_attribute_nonempty(self):
        svg = potrace_trace(_make_square())
        root = _parse_svg(svg)
        path_el = root.find("svg:path", SVG_NS)
        # Fail with a clear assertion rather than AttributeError on None.
        assert path_el is not None
        d = path_el.get("d", "")
        assert len(d) > 0
        assert "M" in d


# ---------------------------------------------------------------------------
# Tracing shapes
# ---------------------------------------------------------------------------


class TestPotraceShapes:
    """Checks that traced geometry uses the expected SVG path commands."""

    def test_square_produces_corner_segments(self):
        """A sharp square with alphamax=0 should produce L commands (corners)."""
        svg = potrace_trace(_make_square(), alphamax=0.0)
        d = _path_d(_parse_svg(svg))
        assert "L" in d

    def test_circle_produces_curve_segments(self):
        """A circle should produce C (cubic bezier) commands."""
        svg = potrace_trace(_make_circle())
        d = _path_d(_parse_svg(svg))
        assert "C" in d

    def test_path_closes_with_z(self):
        svg = potrace_trace(_make_square())
        d = _path_d(_parse_svg(svg))
        assert d.rstrip().endswith("Z")


# ---------------------------------------------------------------------------
# Parameter tuning
# ---------------------------------------------------------------------------


class TestPotraceParams:
    """Checks that Potrace tuning parameters influence the traced output."""

    def test_turdsize_removes_small_features(self):
        """High turdsize should remove a tiny speck, producing empty path data."""
        img = np.zeros((100, 100), dtype=np.uint8)
        img[50, 50] = 255  # single pixel speck
        svg = potrace_trace(img, turdsize=10)
        d = _path_d(_parse_svg(svg))
        # A suppressed speck leaves no move-to command. A missing <path>
        # element also counts as "no geometry" (_path_d returns "").
        assert "M" not in d

    def test_opticurve_off_vs_on(self):
        """Disabling opticurve should produce different (typically more verbose) output."""
        img = _make_circle()
        svg_on = potrace_trace(img, opticurve=True)
        svg_off = potrace_trace(img, opticurve=False)
        # They should differ (optimization reduces segments)
        assert svg_on != svg_off

    def test_alphamax_polygon_mode(self):
        """alphamax=0 forces polygon mode — output should contain L but no C commands."""
        svg = potrace_trace(_make_square(), alphamax=0.0)
        d = _path_d(_parse_svg(svg))
        assert "L" in d
        assert "C" not in d

    def test_default_params_produce_valid_svg(self):
        """Default parameters should produce valid SVG for a standard test image."""
        svg = potrace_trace(_make_square())
        root = _parse_svg(svg)
        assert root.tag == SVG_ROOT_TAG


# ---------------------------------------------------------------------------
# Edge cases & errors
# ---------------------------------------------------------------------------


class TestPotraceEdgeCases:
    """Degenerate images, invalid input shapes and dtype handling."""

    def test_all_black_image(self):
        """An all-black (all-zero) image should produce valid SVG with empty path."""
        img = np.zeros((50, 50), dtype=np.uint8)
        svg = potrace_trace(img)
        root = _parse_svg(svg)
        assert root.tag == SVG_ROOT_TAG
        # The docstring promises an empty path, so check that nothing was traced.
        assert "M" not in _path_d(root)

    def test_all_white_image(self):
        """An all-white image should trace the entire frame."""
        img = np.ones((50, 50), dtype=np.uint8) * 255
        svg = potrace_trace(img)
        d = _path_d(_parse_svg(svg))
        assert "M" in d

    def test_rejects_3d_input(self):
        with pytest.raises(ValueError, match="2D"):
            potrace_trace(np.zeros((50, 50, 3), dtype=np.uint8))

    def test_rectangular_image(self):
        """Non-square images should work and set correct dimensions."""
        img = np.zeros((80, 120), dtype=np.uint8)
        img[10:70, 10:110] = 255
        svg = potrace_trace(img)
        root = _parse_svg(svg)
        assert root.get("width") == "120"
        assert root.get("height") == "80"

    def test_uint8_and_uint32_both_work(self):
        """Potrace should accept common numpy dtypes."""
        base = _make_square()
        svg8 = potrace_trace(base.astype(np.uint8))
        svg32 = potrace_trace(base.astype(np.uint32))
        # Both should be valid SVG with the same structure
        root8 = _parse_svg(svg8)
        root32 = _parse_svg(svg32)
        assert root8.tag == root32.tag
        # Comparing tags alone would pass even if one dtype traced nothing.
        assert "M" in _path_d(root8)
        assert "M" in _path_d(root32)


# ---------------------------------------------------------------------------
# Integration with preprocessing output
# ---------------------------------------------------------------------------


class TestPotracePreprocessingIntegration:
    """End-to-end check that preprocessing output feeds straight into tracing."""

    def test_accepts_thresholded_image(self):
        """Output of preprocessing threshold() is a valid input for potrace_trace."""
        # Skip instead of erroring when OpenCV isn't installed. Do this before
        # importing preprocessing, which may itself depend on cv2.
        cv2 = pytest.importorskip("cv2")
        from pipeline.preprocessing import threshold, to_grayscale

        # Simulate a preprocessed image
        img = np.zeros((100, 100, 3), dtype=np.uint8)
        cv2.rectangle(img, (20, 20), (80, 80), (255, 255, 255), -1)
        gray = to_grayscale(img)
        binary = threshold(gray)

        svg = potrace_trace(binary)
        root = _parse_svg(svg)
        assert root.tag == SVG_ROOT_TAG
        assert "M" in _path_d(root)