375 lines
12 KiB
Python
375 lines
12 KiB
Python
"""Tests for the post-processing pipeline (RDP, island detection, open path repair)."""
|
|
|
|
import math
|
|
|
|
import pytest
|
|
|
|
from pipeline.postprocess import (
|
|
PostProcessResult,
|
|
close_path,
|
|
detect_island,
|
|
is_closed,
|
|
node_count,
|
|
parse_svg_path,
|
|
postprocess_svg,
|
|
rdp_simplify,
|
|
signed_area,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _make_svg(d: str, width: int = 100, height: int = 100) -> str:
    """Build a minimal single-path SVG document for feeding ``postprocess_svg``.

    Args:
        d: SVG path data, inserted verbatim into the ``d`` attribute. It is not
            XML-escaped, so it must not contain ``"``, ``<`` or ``&``.
        width: Value used for the ``width`` attribute and the viewBox width.
        height: Value used for the ``height`` attribute and the viewBox height.

    Returns:
        An SVG string in the ``http://www.w3.org/2000/svg`` namespace containing
        exactly one ``<path>`` with ``fill="black"``, ``fill-rule="evenodd"`` and
        ``stroke="none"``.
    """
    # Only the pieces that interpolate values are f-strings; the constant
    # pieces are plain literals (avoids pyflakes F541).
    return (
        '<svg xmlns="http://www.w3.org/2000/svg" '
        f'width="{width}" height="{height}" '
        f'viewBox="0 0 {width} {height}">'
        f'<path d="{d}" fill="black" fill-rule="evenodd" stroke="none"/>'
        "</svg>"
    )
|
|
|
|
|
|
# Fixture path data (SVG ``d`` strings) shared by the tests below.
# "Counter-clockwise" / "clockwise" follow the sign convention asserted in
# TestSignedArea: positive signed area == counter-clockwise.

# Closed 100x100 square visiting (0,0) -> (100,0) -> (100,100) -> (0,100),
# i.e. counter-clockwise (positive signed area).
SQUARE_D = "M 0,0 L 100,0 L 100,100 L 0,100 Z"

# The same square traversed in the opposite order (clockwise, negative signed
# area) -- the winding used to mark an island/hole.
CW_SQUARE_D = "M 0,0 L 0,100 L 100,100 L 100,0 Z"

# Closed triangle with its apex at (50, 0) and its base on y = 100.
TRIANGLE_D = "M 50,0 L 100,100 L 0,100 Z"

# Open polyline: no trailing ``Z`` and the end point (100, 0) differs from the
# start point (0, 0).
OPEN_D = "M 0,0 L 50,50 L 100,0"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# SVG path parsing
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestParseSvgPath:
    """``parse_svg_path``: SVG ``d`` string -> list of subpaths of (x, y) points."""

    def test_simple_move_and_lines(self):
        subpaths = parse_svg_path("M 0,0 L 10,0 L 10,10 Z")
        assert len(subpaths) == 1
        expected_prefix = [(0.0, 0.0), (10.0, 0.0), (10.0, 10.0)]
        assert [subpaths[0][k] for k in range(3)] == expected_prefix

    def test_multiple_subpaths(self):
        # Each "M ... Z" group must become its own subpath.
        assert len(parse_svg_path("M 0,0 L 10,10 Z M 20,20 L 30,30 Z")) == 2

    def test_cubic_bezier(self):
        subpaths = parse_svg_path("M 0,0 C 10,20 30,40 50,60 Z")
        assert len(subpaths) == 1
        pts = subpaths[0]
        assert len(pts) >= 2
        # The curve's end point is emitted, and the trailing Z returns to the
        # start point, which therefore appears last.
        assert (50.0, 60.0) in pts
        assert pts[-1] == (0.0, 0.0)

    def test_relative_lineto(self):
        subpaths = parse_svg_path("M 10,10 l 5,0 l 0,5 Z")
        assert len(subpaths) == 1
        # Lower-case ``l`` offsets accumulate onto the current point.
        expected_prefix = [(10.0, 10.0), (15.0, 10.0), (15.0, 15.0)]
        assert [subpaths[0][k] for k in range(3)] == expected_prefix

    def test_horizontal_vertical(self):
        subpaths = parse_svg_path("M 0,0 H 10 V 10 Z")
        assert len(subpaths) == 1
        pts = subpaths[0]
        for corner in ((10.0, 0.0), (10.0, 10.0)):
            assert corner in pts

    def test_empty_path(self):
        assert parse_svg_path("") == []

    def test_move_only(self):
        subpaths = parse_svg_path("M 5,5")
        assert len(subpaths) == 1
        # A lone moveto still yields a one-point subpath.
        assert subpaths[0] == [(5.0, 5.0)]

    def test_quadratic_bezier(self):
        subpaths = parse_svg_path("M 0,0 Q 50,100 100,0 Z")
        assert len(subpaths) == 1
        pts = subpaths[0]
        assert (100.0, 0.0) in pts
        assert pts[-1] == (0.0, 0.0)  # Z closes back to the start
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# RDP simplification
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestRdpSimplify:
    """``rdp_simplify``: polyline simplification controlled by ``epsilon``."""

    def test_collinear_points_reduced(self):
        """Collinear interior points are dropped, leaving only the two endpoints."""
        diagonal = [(k, k) for k in range(5)]
        simplified = rdp_simplify(diagonal, epsilon=0.1)
        assert len(simplified) == 2
        assert (simplified[0], simplified[-1]) == ((0, 0), (4, 4))

    def test_preserves_corners(self):
        """A 90-degree corner survives a small epsilon."""
        corner = [(0, 0), (10, 0), (10, 10)]
        assert len(rdp_simplify(corner, epsilon=0.5)) == 3

    def test_epsilon_zero_preserves_all(self):
        """With epsilon 0 the off-line middle point is retained."""
        bump = [(0, 0), (5, 1), (10, 0)]
        assert len(rdp_simplify(bump, epsilon=0.0)) == 3

    def test_high_epsilon_aggressive(self):
        """A large epsilon flattens small wobbles down to the endpoints."""
        wobbly = [(0, 0), (5, 0.5), (10, 0), (15, 0.3), (20, 0)]
        assert len(rdp_simplify(wobbly, epsilon=10.0)) == 2

    def test_two_points_unchanged(self):
        segment = [(0, 0), (10, 10)]
        assert rdp_simplify(segment, epsilon=1.0) == [(0, 0), (10, 10)]

    def test_single_point_unchanged(self):
        assert rdp_simplify([(5, 5)], epsilon=1.0) == [(5, 5)]

    def test_empty_input(self):
        assert rdp_simplify([], epsilon=1.0) == []

    def test_reduces_node_count(self):
        """A densely sampled circle loses vertices but stays a polygon."""
        n = 100
        circle = []
        for i in range(n):
            theta = 2 * math.pi * i / n
            circle.append((50 + 40 * math.cos(theta), 50 + 40 * math.sin(theta)))
        simplified = rdp_simplify(circle, epsilon=2.0)
        # Fewer points than the input, yet at least a triangle.
        assert 3 <= len(simplified) < len(circle)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Signed area / winding detection
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestSignedArea:
    """``signed_area``: polygon area whose sign encodes winding (CCW > 0, CW < 0)."""

    def test_ccw_square_positive(self):
        """A counter-clockwise square has positive signed area."""
        ccw = [(0, 0), (100, 0), (100, 100), (0, 100)]
        assert signed_area(ccw) > 0

    def test_cw_square_negative(self):
        """A clockwise square has negative signed area."""
        cw = [(0, 0), (0, 100), (100, 100), (100, 0)]
        assert signed_area(cw) < 0

    def test_area_magnitude(self):
        """|area| of a 10x10 square is 100."""
        area = signed_area([(0, 0), (10, 0), (10, 10), (0, 10)])
        assert abs(area) == pytest.approx(100.0)

    def test_degenerate_line(self):
        """Two points enclose nothing."""
        area = signed_area([(0, 0), (10, 10)])
        assert area == 0.0

    def test_single_point(self):
        area = signed_area([(0, 0)])
        assert area == 0.0

    def test_empty(self):
        area = signed_area([])
        assert area == 0.0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Island detection
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestDetectIsland:
    """``detect_island``: clockwise rings are islands, counter-clockwise ones are not."""

    def test_ccw_is_not_island(self):
        outer_ring = [(0, 0), (100, 0), (100, 100), (0, 100)]
        verdict = detect_island(outer_ring)
        assert verdict is False

    def test_cw_is_island(self):
        hole_ring = [(0, 0), (0, 100), (100, 100), (100, 0)]
        verdict = detect_island(hole_ring)
        assert verdict is True
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Open path detection + repair
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestIsClosed:
    """``is_closed``: does a polyline end (within a tolerance) where it starts?"""

    def test_closed_path(self):
        ring = [(0, 0), (10, 0), (10, 10), (0, 0)]
        assert is_closed(ring) is True

    def test_open_path(self):
        polyline = [(0, 0), (10, 0), (10, 10)]
        assert is_closed(polyline) is False

    def test_nearly_closed(self):
        """An end point inside the tolerance still counts as closed."""
        # (0.5, 0.3) lies within 1.0 of the start point (0, 0).
        almost_ring = [(0, 0), (10, 0), (10, 10), (0.5, 0.3)]
        assert is_closed(almost_ring, tolerance=1.0) is True

    def test_single_point(self):
        lone = [(0, 0)]
        assert is_closed(lone) is False

    def test_empty(self):
        nothing = []
        assert is_closed(nothing) is False
|
|
|
|
|
|
class TestClosePath:
    """``close_path``: append the start point to an open polyline when needed."""

    def test_closes_open_path(self):
        closed = close_path([(0, 0), (10, 0), (10, 10)])
        # Exactly one point (the start) is appended.
        assert closed[0] == closed[-1]
        assert len(closed) == 4

    def test_already_closed(self):
        # A ring that already ends on its start point gains no duplicate.
        closed = close_path([(0, 0), (10, 0), (10, 10), (0, 0)])
        assert len(closed) == 4

    def test_empty(self):
        closed = close_path([])
        assert closed == []
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Node counting
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestNodeCount:
    """``node_count``: expected to equal the number of points passed in."""

    def test_counts_nodes(self):
        triple = [(0, 0), (1, 1), (2, 2)]
        assert node_count(triple) == len(triple)

    def test_empty(self):
        nodes = node_count([])
        assert nodes == 0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Full pipeline integration
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestPostprocessSvg:
    """End-to-end tests for ``postprocess_svg`` on small hand-built SVGs."""

    SVG_NS = "http://www.w3.org/2000/svg"

    @staticmethod
    def _parse_root(svg_text):
        """Parse *svg_text* with ElementTree and return its root element."""
        import xml.etree.ElementTree as ET

        return ET.fromstring(svg_text)

    @staticmethod
    def _circle_path_d(n=50, cx=50.0, cy=50.0, r=40.0):
        """Return a closed path ``d`` string approximating a circle with *n* vertices.

        Vertex ``i`` sits at angle ``2*pi*i/n``; coordinates are formatted to
        three decimals. Vertex 0 becomes the ``M`` point, the other ``n - 1``
        become ``L`` segments, and a trailing ``Z`` closes the ring.
        """
        vertices = [
            (cx + r * math.cos(2 * math.pi * i / n),
             cy + r * math.sin(2 * math.pi * i / n))
            for i in range(n)
        ]
        (x0, y0), rest = vertices[0], vertices[1:]
        segments = " ".join(f"L {x:.3f},{y:.3f}" for x, y in rest)
        return f"M {x0:.3f},{y0:.3f} {segments} Z"

    def test_returns_result_object(self):
        result = postprocess_svg(_make_svg(SQUARE_D))
        assert isinstance(result, PostProcessResult)

    def test_path_count(self):
        result = postprocess_svg(_make_svg(SQUARE_D))
        assert len(result.paths) >= 1

    def test_node_count_reduction(self):
        """Simplification should reduce or maintain node count."""
        result = postprocess_svg(_make_svg(SQUARE_D), epsilon=0.5)
        # Without this guard the loop below passes vacuously when no paths
        # come back at all.
        assert len(result.paths) >= 1
        for path in result.paths:
            assert path.node_count <= path.original_node_count

    def test_total_nodes_tracked(self):
        result = postprocess_svg(_make_svg(SQUARE_D))
        assert result.total_nodes == sum(p.node_count for p in result.paths)

    def test_closed_path_detected(self):
        result = postprocess_svg(_make_svg(SQUARE_D))
        # Square with Z should be detected as closed
        assert any(p.is_closed for p in result.paths)

    def test_open_path_detected(self):
        result = postprocess_svg(_make_svg(OPEN_D))
        assert result.open_path_count >= 1

    def test_auto_close(self):
        result = postprocess_svg(_make_svg(OPEN_D), auto_close=True)
        # After auto-close, no open paths should remain
        assert result.open_path_count == 0

    def test_island_detection(self):
        # Outer CCW square plus a smaller CW square strictly inside it. The
        # inner ring has the opposite winding, i.e. it is a hole/island. (The
        # CW_SQUARE_D fixture shares every vertex with SQUARE_D, so it is not
        # an *inner* path and, under evenodd, would cancel the fill entirely.)
        inner_hole_d = "M 25,25 L 25,75 L 75,75 L 75,25 Z"
        result = postprocess_svg(_make_svg(f"{SQUARE_D} {inner_hole_d}"))
        assert result.island_count >= 1

    def test_output_svg_is_well_formed(self):
        root = self._parse_root(postprocess_svg(_make_svg(SQUARE_D)).svg)
        assert root.tag == "{http://www.w3.org/2000/svg}svg"

    def test_output_svg_has_path(self):
        root = self._parse_root(postprocess_svg(_make_svg(SQUARE_D)).svg)
        # Search all descendants, not just direct children, so the check does
        # not depend on whether the output wraps its paths in a <g> group.
        paths = root.findall(".//svg:path", {"svg": self.SVG_NS})
        assert len(paths) >= 1

    def test_epsilon_affects_simplification(self):
        """Higher epsilon should produce fewer or equal nodes."""
        svg = _make_svg(self._circle_path_d(n=50))
        result_low = postprocess_svg(svg, epsilon=0.1)
        result_high = postprocess_svg(svg, epsilon=10.0)
        assert result_high.total_nodes <= result_low.total_nodes
|
|
|
|
|
|
class TestPostprocessWithVectorizerOutput:
    """Integration tests: real vectorizer SVG output fed through post-processing."""

    @staticmethod
    def _square_image():
        """Return a 100x100 uint8 image: zeros, with rows/cols 20..79 set to 255."""
        import numpy as np

        canvas = np.zeros((100, 100), dtype=np.uint8)
        canvas[20:80, 20:80] = 255
        return canvas

    @staticmethod
    def _check_postprocessed(svg):
        """Post-process *svg* at epsilon 1.0 and sanity-check the result."""
        result = postprocess_svg(svg, epsilon=1.0)
        assert isinstance(result, PostProcessResult)
        assert len(result.paths) >= 1
        assert result.total_nodes > 0

    def test_potrace_output(self):
        """Post-process real Potrace output."""
        image = self._square_image()
        from pipeline.vectorize import potrace_trace

        self._check_postprocessed(potrace_trace(image))

    def test_vtracer_output(self):
        """Post-process real VTracer output."""
        image = self._square_image()
        from pipeline.vectorize import vtracer_trace

        self._check_postprocessed(vtracer_trace(image))
|