392 lines
15 KiB
Python
392 lines
15 KiB
Python
"""Tests for the vectorization module (Potrace + VTracer)."""
|
|
|
|
import xml.etree.ElementTree as ET
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from pipeline.vectorize import potrace_trace, vtracer_trace
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _make_square(size: int = 100, x0: int = 20, x1: int = 80) -> np.ndarray:
|
|
"""Create a binary image with a filled white square on black background."""
|
|
img = np.zeros((size, size), dtype=np.uint8)
|
|
img[x0:x1, x0:x1] = 255
|
|
return img
|
|
|
|
|
|
def _make_circle(size: int = 100, radius: int = 30) -> np.ndarray:
|
|
"""Create a binary image with a filled white circle."""
|
|
img = np.zeros((size, size), dtype=np.uint8)
|
|
cy, cx = size // 2, size // 2
|
|
Y, X = np.ogrid[:size, :size]
|
|
mask = (X - cx) ** 2 + (Y - cy) ** 2 < radius ** 2
|
|
img[mask] = 255
|
|
return img
|
|
|
|
|
|
def _parse_svg(svg_str: str) -> ET.Element:
|
|
"""Parse SVG string and return root element."""
|
|
return ET.fromstring(svg_str)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Basic output tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestPotraceBasicOutput:
    """Structural checks on the SVG document returned by ``potrace_trace``."""

    # Namespace map for ElementTree prefixed lookups, and the qualified root tag.
    NS = {"svg": "http://www.w3.org/2000/svg"}
    SVG_TAG = "{http://www.w3.org/2000/svg}svg"

    def test_returns_string(self):
        svg = potrace_trace(_make_square())
        assert isinstance(svg, str)

    def test_svg_is_well_formed_xml(self):
        # _parse_svg (ET.fromstring) raises ParseError on malformed XML.
        root = _parse_svg(potrace_trace(_make_square()))
        assert root.tag == self.SVG_TAG

    def test_svg_has_correct_dimensions(self):
        root = _parse_svg(potrace_trace(_make_square(size=200)))
        assert root.get("width") == "200"
        assert root.get("height") == "200"

    def test_svg_has_viewbox(self):
        root = _parse_svg(potrace_trace(_make_square(size=150)))
        assert root.get("viewBox") == "0 0 150 150"

    def test_svg_contains_path_element(self):
        root = _parse_svg(potrace_trace(_make_square()))
        paths = root.findall("svg:path", self.NS)
        assert len(paths) == 1

    def test_path_d_attribute_nonempty(self):
        root = _parse_svg(potrace_trace(_make_square()))
        path_el = root.find("svg:path", self.NS)
        # Fail with a clear assertion instead of an AttributeError on None.
        assert path_el is not None, "expected a <path> element in the SVG"
        d = path_el.get("d", "")
        assert d, "path 'd' attribute is empty"
        assert "M" in d
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tracing shapes
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestPotraceShapes:
    """Check which path segment types potrace emits for simple shapes."""

    NS = {"svg": "http://www.w3.org/2000/svg"}

    def _path_d(self, svg: str) -> str:
        """Return the ``d`` attribute of the first ``<path>``, asserting the path exists."""
        path_el = _parse_svg(svg).find("svg:path", self.NS)
        # Without this guard a missing path surfaces as AttributeError on None.
        assert path_el is not None, "expected a <path> element in the SVG"
        return path_el.get("d", "")

    def test_square_produces_corner_segments(self):
        """A sharp square with alphamax=0 should produce L commands (corners)."""
        d = self._path_d(potrace_trace(_make_square(), alphamax=0.0))
        assert "L" in d

    def test_circle_produces_curve_segments(self):
        """A circle should produce C (cubic bezier) commands."""
        d = self._path_d(potrace_trace(_make_circle()))
        assert "C" in d

    def test_path_closes_with_z(self):
        """The path data should end with a Z (closepath) command."""
        d = self._path_d(potrace_trace(_make_square()))
        assert d.rstrip().endswith("Z")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Parameter tuning
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestPotraceParams:
    """Check that potrace tuning parameters influence the traced output."""

    NS = {"svg": "http://www.w3.org/2000/svg"}

    def _path_d(self, svg: str) -> str:
        """Return the first ``<path>``'s ``d`` attribute, or "" when no path is present."""
        path_el = _parse_svg(svg).find("svg:path", self.NS)
        return "" if path_el is None else path_el.get("d", "")

    def test_turdsize_removes_small_features(self):
        """High turdsize should remove a small speck, producing empty path data."""
        img = np.zeros((100, 100), dtype=np.uint8)
        # 3x3 speck (area 9). A single pixel would already be dropped by
        # potrace's library default turdsize (2). A speck this size stays
        # below turdsize=10 but above that default, so it exercises the override.
        img[49:52, 49:52] = 255
        d = self._path_d(potrace_trace(img, turdsize=10))
        # A suppressed speck leaves no moveto. An absent <path> also counts as empty.
        assert "M" not in d

    def test_opticurve_off_vs_on(self):
        """Disabling opticurve should produce different (typically more verbose) output."""
        img = _make_circle()
        svg_on = potrace_trace(img, opticurve=True)
        svg_off = potrace_trace(img, opticurve=False)
        # They should differ (optimization reduces segments)
        assert svg_on != svg_off

    def test_alphamax_polygon_mode(self):
        """alphamax=0 forces polygon mode — output should contain L but no C commands."""
        d = self._path_d(potrace_trace(_make_square(), alphamax=0.0))
        assert "L" in d
        assert "C" not in d

    def test_default_params_produce_valid_svg(self):
        """Default parameters should produce valid SVG for a standard test image."""
        root = _parse_svg(potrace_trace(_make_square()))
        assert root.tag == "{http://www.w3.org/2000/svg}svg"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Edge cases & errors
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestPotraceEdgeCases:
    """Degenerate inputs and input validation for ``potrace_trace``."""

    NS = {"svg": "http://www.w3.org/2000/svg"}
    SVG_TAG = "{http://www.w3.org/2000/svg}svg"

    def test_all_black_image(self):
        """An all-black (all-zero) image should produce valid SVG with empty path."""
        img = np.zeros((50, 50), dtype=np.uint8)
        root = _parse_svg(potrace_trace(img))
        assert root.tag == self.SVG_TAG
        # Enforce the "empty path" half of the contract: no subpath may start.
        for path_el in root.findall("svg:path", self.NS):
            assert "M" not in path_el.get("d", "")

    def test_all_white_image(self):
        """An all-white image should trace the entire frame."""
        img = np.full((50, 50), 255, dtype=np.uint8)
        root = _parse_svg(potrace_trace(img))
        path_el = root.find("svg:path", self.NS)
        assert path_el is not None, "expected a <path> element in the SVG"
        assert "M" in path_el.get("d", "")

    def test_rejects_3d_input(self):
        with pytest.raises(ValueError, match="2D"):
            potrace_trace(np.zeros((50, 50, 3), dtype=np.uint8))

    def test_rectangular_image(self):
        """Non-square images should work and set correct dimensions."""
        img = np.zeros((80, 120), dtype=np.uint8)
        img[10:70, 10:110] = 255
        root = _parse_svg(potrace_trace(img))
        # SVG width follows columns (120), height follows rows (80).
        assert root.get("width") == "120"
        assert root.get("height") == "80"

    def test_uint8_and_uint32_both_work(self):
        """Potrace should accept common numpy dtypes and trace the same square with each."""
        base = _make_square()
        for dtype in (np.uint8, np.uint32):
            root = _parse_svg(potrace_trace(base.astype(dtype)))
            assert root.tag == self.SVG_TAG, dtype
            # Comparing root tags alone is always true. Also require an actual trace.
            path_el = root.find("svg:path", self.NS)
            assert path_el is not None, dtype
            assert "M" in path_el.get("d", ""), dtype
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Integration with preprocessing output
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestPotracePreprocessingIntegration:
    """``potrace_trace`` should consume the output of the preprocessing stage."""

    def test_accepts_thresholded_image(self):
        """Output of preprocessing threshold() is a valid input for potrace_trace."""
        from pipeline.preprocessing import threshold, to_grayscale
        import cv2

        # Simulate a preprocessed image: filled white rectangle (thickness=-1) on black.
        img = np.zeros((100, 100, 3), dtype=np.uint8)
        cv2.rectangle(img, (20, 20), (80, 80), (255, 255, 255), -1)
        gray = to_grayscale(img)
        binary = threshold(gray)
        root = _parse_svg(potrace_trace(binary))
        assert root.tag == "{http://www.w3.org/2000/svg}svg"
        ns = {"svg": "http://www.w3.org/2000/svg"}
        path_el = root.find("svg:path", ns)
        # Fail with a clear assertion instead of an AttributeError on None.
        assert path_el is not None, "expected a <path> element in the SVG"
        assert "M" in path_el.get("d", "")
|
|
|
|
|
|
# ===========================================================================
|
|
# VTracer tests
|
|
# ===========================================================================
|
|
|
|
def _make_color_image(size: int = 100) -> np.ndarray:
|
|
"""Create a BGR image with a colored rectangle."""
|
|
img = np.zeros((size, size, 3), dtype=np.uint8)
|
|
img[20:80, 20:80] = [0, 0, 255] # Red rectangle in BGR
|
|
return img
|
|
|
|
|
|
class TestVtracerBasicOutput:
    """Structural checks on the SVG document returned by ``vtracer_trace``."""

    NS = {"svg": "http://www.w3.org/2000/svg"}
    SVG_TAG = "{http://www.w3.org/2000/svg}svg"

    def test_returns_string(self):
        svg = vtracer_trace(_make_square())
        assert isinstance(svg, str)

    def test_svg_is_well_formed_xml(self):
        root = _parse_svg(vtracer_trace(_make_square()))
        assert root.tag == self.SVG_TAG

    def test_svg_has_correct_dimensions(self):
        root = _parse_svg(vtracer_trace(_make_square(size=200)))
        assert root.get("width") == "200"
        assert root.get("height") == "200"

    def test_svg_contains_path_element(self):
        root = _parse_svg(vtracer_trace(_make_square()))
        # Unlike the potrace check, require at least one path, not exactly one.
        assert len(root.findall("svg:path", self.NS)) >= 1

    def test_path_has_d_attribute(self):
        root = _parse_svg(vtracer_trace(_make_square()))
        path_el = root.find("svg:path", self.NS)
        # Fail with a clear assertion instead of an AttributeError on None.
        assert path_el is not None, "expected a <path> element in the SVG"
        assert len(path_el.get("d", "")) > 0
|
|
|
|
|
|
class TestVtracerInputTypes:
    """``vtracer_trace`` input-shape handling: 2D, BGR, BGRA, and rejection of 4D."""

    SVG_TAG = "{http://www.w3.org/2000/svg}svg"

    def _root(self, image):
        """Trace ``image`` with VTracer and return the parsed SVG root."""
        return _parse_svg(vtracer_trace(image))

    def test_accepts_grayscale_2d(self):
        """VTracer should accept 2D grayscale images."""
        assert self._root(_make_square()).tag == self.SVG_TAG

    def test_accepts_bgr_3d(self):
        """VTracer should accept 3-channel BGR images."""
        assert self._root(_make_color_image()).tag == self.SVG_TAG

    def test_accepts_bgra_4_channel(self):
        """VTracer should accept 4-channel BGRA images."""
        import cv2

        with_alpha = cv2.cvtColor(_make_color_image(), cv2.COLOR_BGR2BGRA)
        assert self._root(with_alpha).tag == self.SVG_TAG

    def test_rejects_4d_input(self):
        bad_input = np.zeros((10, 10, 3, 2), dtype=np.uint8)
        with pytest.raises(ValueError, match="2D or 3D"):
            vtracer_trace(bad_input)

    def test_rectangular_image(self):
        frame = np.zeros((60, 120), dtype=np.uint8)
        frame[10:50, 10:110] = 255
        root = self._root(frame)
        # Width tracks columns, height tracks rows.
        assert (root.get("width"), root.get("height")) == ("120", "60")
|
|
|
|
|
|
class TestVtracerParams:
    """Check that VTracer tuning parameters influence the traced output."""

    NS = {"svg": "http://www.w3.org/2000/svg"}

    def _paths(self, svg: str) -> list:
        """Return every top-level ``<path>`` element of ``svg``."""
        return _parse_svg(svg).findall("svg:path", self.NS)

    def _first_d(self, svg: str) -> str:
        """Return the first ``<path>``'s ``d`` attribute, asserting a path exists."""
        paths = self._paths(svg)
        # Fail with a clear assertion instead of an AttributeError on None.
        assert paths, "expected at least one <path> element in the SVG"
        return paths[0].get("d", "")

    def test_color_mode(self):
        """Color mode should produce SVG with color fill attributes."""
        paths = self._paths(vtracer_trace(_make_color_image(), colormode="color"))
        assert len(paths) >= 1
        # Check the fill attribute the docstring promises, not just path presence.
        assert any(p.get("fill") for p in paths), "no <path> carries a fill attribute"

    def test_filter_speckle_removes_noise(self):
        """High filter_speckle should remove tiny features."""
        img = np.zeros((100, 100), dtype=np.uint8)
        img[50, 50] = 255  # single pixel
        svg_strict = vtracer_trace(img, filter_speckle=100)
        svg_loose = vtracer_trace(img, filter_speckle=0)
        # Either fewer paths or shorter path data with strict filtering
        d_strict = "".join(p.get("d", "") for p in self._paths(svg_strict))
        d_loose = "".join(p.get("d", "") for p in self._paths(svg_loose))
        assert len(d_strict) <= len(d_loose)

    def test_polygon_mode(self):
        """Polygon mode should produce L commands instead of curves."""
        d = self._first_d(vtracer_trace(_make_square(), mode="polygon"))
        # Accept absolute or relative lineto.
        assert "L" in d or "l" in d

    def test_spline_mode_default(self):
        """Default spline mode should produce C (curve) commands for curved shapes."""
        d = self._first_d(vtracer_trace(_make_circle(), mode="spline"))
        # Accept absolute or relative cubic curveto.
        assert "C" in d or "c" in d
|
|
|
|
|
|
class TestVtracerEdgeCases:
    """Uniform (all-black / all-white) frames must still yield a valid SVG document."""

    SVG_TAG = "{http://www.w3.org/2000/svg}svg"

    def _check_uniform(self, value: int) -> None:
        """Trace a 50x50 uint8 frame filled with ``value`` and check the SVG root tag."""
        frame = np.full((50, 50), value, dtype=np.uint8)
        assert _parse_svg(vtracer_trace(frame)).tag == self.SVG_TAG

    def test_all_black_image(self):
        self._check_uniform(0)

    def test_all_white_image(self):
        self._check_uniform(255)
|
|
|
|
|
|
class TestVtracerPreprocessingIntegration:
    """``vtracer_trace`` should consume the output of the preprocessing stage."""

    def test_accepts_thresholded_image(self):
        """Output of preprocessing pipeline is valid input for vtracer_trace."""
        from pipeline.preprocessing import threshold, to_grayscale
        import cv2

        namespaces = {"svg": "http://www.w3.org/2000/svg"}
        source = np.zeros((100, 100, 3), dtype=np.uint8)
        # thickness=-1 asks OpenCV for a filled rectangle.
        cv2.rectangle(source, (20, 20), (80, 80), (255, 255, 255), -1)
        binarized = threshold(to_grayscale(source))
        root = _parse_svg(vtracer_trace(binarized))
        assert root.tag == "{http://www.w3.org/2000/svg}svg"
        assert len(root.findall("svg:path", namespaces)) >= 1
|
|
|
|
|
|
class TestVtracerVsPotraceComparison:
    """Cross-backend sanity checks using one shared binary input."""

    SVG_TAG = "{http://www.w3.org/2000/svg}svg"

    def _trace_both(self):
        """Trace the default square with each backend; return ``(potrace_svg, vtracer_svg)``."""
        shared = _make_square()
        return potrace_trace(shared), vtracer_trace(shared)

    def test_both_produce_valid_svg_from_same_input(self):
        """Both backends should produce valid SVG from the same binary image."""
        roots = [_parse_svg(doc) for doc in self._trace_both()]
        for root in roots:
            assert root.tag == self.SVG_TAG

    def test_outputs_differ(self):
        """Potrace and VTracer should produce different SVG for the same input."""
        potrace_svg, vtracer_svg = self._trace_both()
        # Distinct algorithms, so the documents should not be byte-identical.
        assert potrace_svg != vtracer_svg
|