test: Implemented vtracer_trace() function that converts grayscale or c…
- engine/pipeline/vectorize.py - engine/tests/test_vectorize.py GSD-Task: S01/T04
This commit is contained in:
parent
136a9417f9
commit
b33e883a6b
2 changed files with 246 additions and 2 deletions
|
|
@ -1,7 +1,9 @@
|
||||||
"""Vectorization pipeline — converts preprocessed binary images to SVG."""
|
"""Vectorization pipeline — converts preprocessed binary images to SVG."""
|
||||||
|
|
||||||
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import potrace
|
import potrace
|
||||||
|
import vtracer
|
||||||
|
|
||||||
|
|
||||||
def potrace_trace(
|
def potrace_trace(
|
||||||
|
|
@ -43,6 +45,70 @@ def potrace_trace(
|
||||||
return _path_to_svg(path, w, h)
|
return _path_to_svg(path, w, h)
|
||||||
|
|
||||||
|
|
||||||
|
def vtracer_trace(
|
||||||
|
img: np.ndarray,
|
||||||
|
colormode: str = "binary",
|
||||||
|
hierarchical: str = "stacked",
|
||||||
|
filter_speckle: int = 4,
|
||||||
|
color_precision: int = 6,
|
||||||
|
layer_difference: int = 16,
|
||||||
|
corner_threshold: int = 60,
|
||||||
|
length_threshold: float = 4.0,
|
||||||
|
splice_threshold: int = 45,
|
||||||
|
mode: str = "spline",
|
||||||
|
path_precision: int | None = None,
|
||||||
|
max_iterations: int = 10,
|
||||||
|
) -> str:
|
||||||
|
"""Trace an image using VTracer and return an SVG string.
|
||||||
|
|
||||||
|
Unlike potrace_trace, this accepts both grayscale and color images.
|
||||||
|
Internally encodes the image as PNG and passes it to VTracer's Rust backend.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img: 2D (grayscale) or 3D (BGR/BGRA) numpy array.
|
||||||
|
colormode: 'color' or 'binary'.
|
||||||
|
hierarchical: 'stacked' or 'cutout'.
|
||||||
|
filter_speckle: Remove patches smaller than this area (in px).
|
||||||
|
color_precision: Number of significant bits for color quantization (1-8).
|
||||||
|
layer_difference: Delta threshold for color layer grouping.
|
||||||
|
corner_threshold: Angle (degrees) below which a point is a corner.
|
||||||
|
length_threshold: Minimum segment length before simplification.
|
||||||
|
splice_threshold: Angle (degrees) for splicing splines.
|
||||||
|
mode: 'spline', 'polygon', or 'none' — curve fitting strategy.
|
||||||
|
path_precision: Decimal precision for path coordinates.
|
||||||
|
max_iterations: Max curve-fitting iterations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SVG string (includes XML declaration and generator comment).
|
||||||
|
"""
|
||||||
|
if img.ndim not in (2, 3):
|
||||||
|
raise ValueError(f"Expected 2D or 3D image, got {img.ndim}D (shape {img.shape})")
|
||||||
|
|
||||||
|
# Encode as PNG — VTracer accepts raw image bytes via convert_raw_image_to_svg.
|
||||||
|
ok, buf = cv2.imencode(".png", img)
|
||||||
|
if not ok:
|
||||||
|
raise RuntimeError("Failed to encode image as PNG for VTracer")
|
||||||
|
|
||||||
|
# Build kwargs, omitting None values so VTracer uses its defaults.
|
||||||
|
kwargs: dict = dict(
|
||||||
|
img_format="png",
|
||||||
|
colormode=colormode,
|
||||||
|
hierarchical=hierarchical,
|
||||||
|
filter_speckle=filter_speckle,
|
||||||
|
color_precision=color_precision,
|
||||||
|
layer_difference=layer_difference,
|
||||||
|
corner_threshold=corner_threshold,
|
||||||
|
length_threshold=length_threshold,
|
||||||
|
splice_threshold=splice_threshold,
|
||||||
|
mode=mode,
|
||||||
|
max_iterations=max_iterations,
|
||||||
|
)
|
||||||
|
if path_precision is not None:
|
||||||
|
kwargs["path_precision"] = path_precision
|
||||||
|
|
||||||
|
return vtracer.convert_raw_image_to_svg(buf.tobytes(), **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _path_to_svg(path, width: int, height: int) -> str:
|
def _path_to_svg(path, width: int, height: int) -> str:
|
||||||
"""Convert a potrace Path object to an SVG string."""
|
"""Convert a potrace Path object to an SVG string."""
|
||||||
parts = []
|
parts = []
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,11 @@
|
||||||
"""Tests for the Potrace vectorization module."""
|
"""Tests for the vectorization module (Potrace + VTracer)."""
|
||||||
|
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from pipeline.vectorize import potrace_trace
|
from pipeline.vectorize import potrace_trace, vtracer_trace
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -212,3 +212,181 @@ class TestPotracePreprocessingIntegration:
|
||||||
ns = {"svg": "http://www.w3.org/2000/svg"}
|
ns = {"svg": "http://www.w3.org/2000/svg"}
|
||||||
d = root.find("svg:path", ns).get("d", "")
|
d = root.find("svg:path", ns).get("d", "")
|
||||||
assert "M" in d
|
assert "M" in d
|
||||||
|
|
||||||
|
|
||||||
|
# ===========================================================================
|
||||||
|
# VTracer tests
|
||||||
|
# ===========================================================================
|
||||||
|
|
||||||
|
def _make_color_image(size: int = 100) -> np.ndarray:
|
||||||
|
"""Create a BGR image with a colored rectangle."""
|
||||||
|
img = np.zeros((size, size, 3), dtype=np.uint8)
|
||||||
|
img[20:80, 20:80] = [0, 0, 255] # Red rectangle in BGR
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
class TestVtracerBasicOutput:
    """Smoke checks on the overall shape of the SVG returned by vtracer_trace."""

    # Namespace map used for ElementTree lookups of SVG elements.
    _SVG_NS = {"svg": "http://www.w3.org/2000/svg"}

    def test_returns_string(self):
        """The tracer result is a ``str``."""
        result = vtracer_trace(_make_square())
        assert isinstance(result, str)

    def test_svg_is_well_formed_xml(self):
        """The output goes through _parse_svg and has a namespaced <svg> root."""
        document = _parse_svg(vtracer_trace(_make_square()))
        assert document.tag == "{http://www.w3.org/2000/svg}svg"

    def test_svg_has_correct_dimensions(self):
        """width/height attributes mirror the 200px input size."""
        traced = vtracer_trace(_make_square(size=200))
        document = _parse_svg(traced)
        reported = (document.get("width"), document.get("height"))
        assert reported == ("200", "200")

    def test_svg_contains_path_element(self):
        """At least one <path> sits directly under the root."""
        document = _parse_svg(vtracer_trace(_make_square()))
        found = document.findall("svg:path", self._SVG_NS)
        assert len(found) >= 1

    def test_path_has_d_attribute(self):
        """The first <path> carries non-empty path data."""
        document = _parse_svg(vtracer_trace(_make_square()))
        first_path = document.find("svg:path", self._SVG_NS)
        path_data = first_path.get("d", "")
        assert path_data != ""
|
||||||
|
|
||||||
|
|
||||||
|
class TestVtracerInputTypes:
    """Which array shapes vtracer_trace accepts and which it rejects."""

    # Fully-qualified tag expected on the root element of every result.
    _SVG_ROOT_TAG = "{http://www.w3.org/2000/svg}svg"

    def test_accepts_grayscale_2d(self):
        """VTracer should accept 2D grayscale images."""
        traced = vtracer_trace(_make_square())
        assert _parse_svg(traced).tag == self._SVG_ROOT_TAG

    def test_accepts_bgr_3d(self):
        """VTracer should accept 3-channel BGR images."""
        traced = vtracer_trace(_make_color_image())
        assert _parse_svg(traced).tag == self._SVG_ROOT_TAG

    def test_accepts_bgra_4_channel(self):
        """VTracer should accept 4-channel BGRA images."""
        import cv2

        # Derive the 4-channel input from the 3-channel fixture.
        with_alpha = cv2.cvtColor(_make_color_image(), cv2.COLOR_BGR2BGRA)
        traced = vtracer_trace(with_alpha)
        assert _parse_svg(traced).tag == self._SVG_ROOT_TAG

    def test_rejects_4d_input(self):
        """A 4D array is refused with the '2D or 3D' ValueError."""
        four_d = np.zeros((10, 10, 3, 2), dtype=np.uint8)
        with pytest.raises(ValueError, match="2D or 3D"):
            vtracer_trace(four_d)

    def test_rectangular_image(self):
        """Non-square input: width comes from columns, height from rows."""
        canvas = np.zeros((60, 120), dtype=np.uint8)
        canvas[10:50, 10:110] = 255
        document = _parse_svg(vtracer_trace(canvas))
        reported = (document.get("width"), document.get("height"))
        assert reported == ("120", "60")
|
||||||
|
|
||||||
|
|
||||||
|
class TestVtracerParams:
    """Effect of individual vtracer_trace keyword arguments on the output."""

    # Namespace map used for ElementTree lookups of SVG elements.
    _SVG_NS = {"svg": "http://www.w3.org/2000/svg"}

    def test_color_mode(self):
        """Color mode should produce SVG with color fill attributes."""
        traced = vtracer_trace(_make_color_image(), colormode="color")
        found = _parse_svg(traced).findall("svg:path", self._SVG_NS)
        assert len(found) >= 1

    def test_filter_speckle_removes_noise(self):
        """High filter_speckle should remove tiny features."""
        canvas = np.zeros((100, 100), dtype=np.uint8)
        canvas[50, 50] = 255  # single pixel

        def joined_path_data(svg_text):
            # Concatenate the "d" attribute of every top-level <path>.
            paths = _parse_svg(svg_text).findall("svg:path", self._SVG_NS)
            return "".join(node.get("d", "") for node in paths)

        strict_svg = vtracer_trace(canvas, filter_speckle=100)
        loose_svg = vtracer_trace(canvas, filter_speckle=0)

        # Either fewer paths or shorter path data with strict filtering
        strict_data = joined_path_data(strict_svg)
        loose_data = joined_path_data(loose_svg)
        assert len(strict_data) <= len(loose_data)

    def test_polygon_mode(self):
        """Polygon mode should produce L commands instead of curves."""
        document = _parse_svg(vtracer_trace(_make_square(), mode="polygon"))
        path_data = document.find("svg:path", self._SVG_NS).get("d", "")
        assert any(command in path_data for command in ("L", "l"))

    def test_spline_mode_default(self):
        """Default spline mode should produce C (curve) commands for curved shapes."""
        document = _parse_svg(vtracer_trace(_make_circle(), mode="spline"))
        path_data = document.find("svg:path", self._SVG_NS).get("d", "")
        assert any(command in path_data for command in ("C", "c"))
|
||||||
|
|
||||||
|
|
||||||
|
class TestVtracerEdgeCases:
    """Uniform images must still yield a parseable SVG document."""

    # Fully-qualified tag expected on the root element of every result.
    _SVG_ROOT_TAG = "{http://www.w3.org/2000/svg}svg"

    def test_all_black_image(self):
        """A 50x50 canvas of zeros traces to an <svg> root."""
        blank = np.zeros((50, 50), dtype=np.uint8)
        document = _parse_svg(vtracer_trace(blank))
        assert document.tag == self._SVG_ROOT_TAG

    def test_all_white_image(self):
        """A 50x50 canvas of 255s traces to an <svg> root."""
        filled = np.full((50, 50), 255, dtype=np.uint8)
        document = _parse_svg(vtracer_trace(filled))
        assert document.tag == self._SVG_ROOT_TAG
|
||||||
|
|
||||||
|
|
||||||
|
class TestVtracerPreprocessingIntegration:
    """vtracer_trace consuming the output of the preprocessing helpers."""

    def test_accepts_thresholded_image(self):
        """Output of preprocessing pipeline is valid input for vtracer_trace."""
        from pipeline.preprocessing import threshold, to_grayscale
        import cv2

        # White filled rectangle on a black 3-channel canvas.
        canvas = np.zeros((100, 100, 3), dtype=np.uint8)
        cv2.rectangle(canvas, (20, 20), (80, 80), (255, 255, 255), -1)

        # Same chain as the original: grayscale first, then threshold.
        prepared = threshold(to_grayscale(canvas))

        document = _parse_svg(vtracer_trace(prepared))
        assert document.tag == "{http://www.w3.org/2000/svg}svg"

        svg_ns = {"svg": "http://www.w3.org/2000/svg"}
        found = document.findall("svg:path", svg_ns)
        assert len(found) >= 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestVtracerVsPotraceComparison:
    """Side-by-side checks of the Potrace and VTracer backends."""

    def test_both_produce_valid_svg_from_same_input(self):
        """Both backends should produce valid SVG from the same binary image."""
        shared_input = _make_square()
        potrace_root = _parse_svg(potrace_trace(shared_input))
        vtracer_root = _parse_svg(vtracer_trace(shared_input))
        expected_tag = "{http://www.w3.org/2000/svg}svg"
        assert potrace_root.tag == expected_tag
        assert vtracer_root.tag == expected_tag

    def test_outputs_differ(self):
        """Potrace and VTracer should produce different SVG for the same input."""
        shared_input = _make_square()
        from_potrace = potrace_trace(shared_input)
        from_vtracer = vtracer_trace(shared_input)
        # They use different algorithms so output should differ
        assert from_potrace != from_vtracer
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue