# synthetic/mesh_vis.py
from __future__ import annotations
from collections import defaultdict
from typing import Callable, List, Optional, Sequence, Tuple, Union
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
from matplotlib.figure import Figure
from mpl_toolkits.mplot3d import proj3d
from mpl_toolkits.mplot3d.art3d import Line3DCollection, Poly3DCollection
from matplotlib.colors import LinearSegmentedColormap
# NOTE: we keep using your existing canonical renderer from circle_bundles.
# If you ever want synthetic to be standalone, we can move fig_to_rgba into synthetic/viz_utils.py.
from circle_bundles.viz.image_utils import fig_to_rgba
# Names exported by a star-import of this module, grouped by topic.
__all__ = [
    # densities
    "make_density_visualizer",
    # mesh helpers
    "FaceGroup",
    "expand_face_groups",
    # mesh visualizers
    "make_tri_prism_visualizer",
    "make_star_pyramid_visualizer",
    # animation / export
    "fig_to_rgb_array",
    "make_rotating_mesh_clip",
]
# ============================
# Densities
# ============================
# [docs]  (appears to be Sphinx source-link residue; not part of the code)
def make_density_visualizer(
    *,
    grid_size: int = 32,
    axis: str = "x",
    cmap: str = "inferno",
    normalize: bool = False,
    figsize: Tuple[float, float] = (3.0, 3.0),
    dpi: int = 150,
) -> Callable[[np.ndarray], Figure]:
    """
    Build a function that renders a voxel density as a 2D sum-projection.

    Parameters
    ----------
    grid_size : int
        Edge length of the cubic voxel grid; densities hold ``grid_size**3``
        values. Must be at least 1.
    axis : {'x', 'y', 'z'}
        Axis to sum over ('x' -> array axis 0, 'y' -> 1, 'z' -> 2).
    cmap : str
        Colormap name handed to ``imshow``.
    normalize : bool
        If True, divide the projection by its maximum (when that maximum is
        positive) so the displayed image peaks at 1.
    figsize, dpi
        Figure sizing passed to ``plt.subplots``.

    Returns
    -------
    vis_func(density) -> matplotlib Figure
        ``density`` may be flat with ``grid_size**3`` entries or already shaped
        ``(grid_size, grid_size, grid_size)``.

    Raises
    ------
    ValueError
        At construction time if ``grid_size`` is not positive or ``axis`` is
        not one of 'x', 'y', 'z'; at call time if the density has the wrong
        size or shape.
    """
    grid_size = int(grid_size)
    # Reject degenerate grids up front; otherwise the failure only shows up
    # later as an obscure reshape / empty-reduction error inside vis_func.
    if grid_size < 1:
        raise ValueError(f"grid_size must be a positive integer, got {grid_size}.")

    axis_map = {"x": 0, "y": 1, "z": 2}
    if axis not in axis_map:
        raise ValueError("axis must be one of: 'x', 'y', 'z'")
    ax_idx = axis_map[axis]

    n_voxels = grid_size**3
    vol_shape = (grid_size,) * 3

    def vis_func(density: np.ndarray) -> Figure:
        arr = np.asarray(density, dtype=float)
        if arr.ndim == 1:
            if arr.size != n_voxels:
                raise ValueError(f"Density size mismatch: got {arr.size}, expected {n_voxels}.")
            vol = arr.reshape(vol_shape)
        elif arr.shape == vol_shape:
            vol = arr
        else:
            raise ValueError(
                f"Density must be flat length {n_voxels} or shape {vol_shape}. Got {arr.shape}."
            )

        proj = vol.sum(axis=ax_idx)
        if normalize:
            peak = float(np.max(proj))
            # Skip the division for an all-zero (or non-positive) projection.
            if peak > 0:
                proj = proj / peak

        # facecolor="none" keeps the figure background transparent.
        fig, ax = plt.subplots(figsize=figsize, dpi=dpi, facecolor="none")
        ax.set_axis_off()
        # transpose so x/y look natural in imshow; origin lower for consistent orientation
        ax.imshow(proj.T, origin="lower", cmap=cmap)
        return fig

    return vis_func
# ============================
# Mesh utilities + visualizers
# ============================
# A face group is either an explicit sequence of face indices (e.g. [0, 5, 9])
# or a pair of ints read as the half-open range (start, end_exclusive).
FaceGroup = Union[Sequence[int], Tuple[int, int]]
def expand_face_groups(face_groups: Sequence[FaceGroup]) -> List[List[int]]:
    """
    Turn every face group into an explicit list of Python-int face indices.

    A group that is a 2-tuple of integers (Python or NumPy) is read as the
    half-open range ``(start, end_excl)``; anything else is iterated and each
    element converted with ``int``.

    Raises
    ------
    ValueError
        If a range group has ``end_excl <= start``.
    """
    integer_types = (int, np.integer)

    def _expand(group) -> List[int]:
        looks_like_range = (
            isinstance(group, tuple)
            and len(group) == 2
            and all(isinstance(bound, integer_types) for bound in group)
        )
        if not looks_like_range:
            return [int(idx) for idx in group]

        lo, hi = (int(bound) for bound in group)
        if hi <= lo:
            raise ValueError(f"Invalid face range (start,end_excl) = {(lo, hi)}")
        return list(range(lo, hi))

    return [_expand(group) for group in face_groups]
# [docs]  (appears to be Sphinx source-link residue; not part of the code)
def make_tri_prism_visualizer(
    mesh,
    face_groups: Optional[Sequence[FaceGroup]] = None,
    *,
    face_colors_list: Optional[Sequence[str]] = None,
    alpha: float = 1.0,
    show_edges: bool = True,
    edge_color: str = "black",
    edge_width: float = 2.5,
    elev: float = 0.0,
    azim: float = 0.0,
    figsize: Tuple[float, float] = (4.0, 4.0),
    dpi: int = 150,
    depth_sort: bool = True,
) -> Callable[[np.ndarray], Figure]:
    """
    Build a renderer for a triangle mesh whose faces are colored by group.

    Parameters
    ----------
    mesh : trimesh.Trimesh-like
        Must have ``.vertices`` and ``.faces``. Only the vertex count and the
        face index array are read here; vertex positions are supplied per call.
    face_groups : sequence of FaceGroup, optional
        Groups of face indices (explicit lists or ``(start, end_excl)`` ranges).
        Group ``k`` is painted with
        ``face_colors_list[k % len(face_colors_list)]``; faces that belong to
        no group are drawn light gray. May be omitted only when the mesh has
        exactly 8 faces, in which case the grouping
        ``(0,1), (1,2), (2,4), (4,6), (6,8)`` is used.
    face_colors_list : sequence of color specs, optional
        Defaults to a five-color pastel palette. Must not be empty.
    alpha : float
        Opacity applied to every face color.
    show_edges : bool
        If True, draw the edges that belong to exactly one triangle. On a mesh
        where every edge is shared by two triangles nothing is outlined.
    edge_color, edge_width
        Style of those edge segments.
    elev, azim : float
        Camera angles passed to ``view_init``.
    figsize, dpi
        Figure sizing.
    depth_sort : bool
        If True, triangles are reordered by mean projected depth before being
        handed to Matplotlib.

    Returns
    -------
    vis_func(flat_mesh) -> Figure
        ``flat_mesh`` must reshape to ``(n_vertices, 3)`` vertex positions.

    Raises
    ------
    ValueError
        If ``face_colors_list`` is empty, or if ``face_groups`` is omitted for
        a mesh that does not have 8 faces.
    """
    if face_colors_list is None:
        face_colors_list = [
            '#FFB4A2',  # Rectangular Side 2 – coral blush
            '#FAE3B4',  # Soft buttery yellow
            '#A8DADC',  # Rectangular Side 1 – seafoam
            '#E3C8F2',  # Triangle Face 1 (top)
            '#9BB1FF',  # Rectangular Side 3 – pastel periwinkle
        ]
    elif len(face_colors_list) == 0:
        # Without this check the palette lookup below fails with an opaque
        # ZeroDivisionError from the modulo.
        raise ValueError("face_colors_list must contain at least one color.")

    n_vertices = int(np.asarray(mesh.vertices).shape[0])
    faces = np.asarray(mesh.faces, dtype=int)
    n_faces = int(faces.shape[0])

    if face_groups is None:
        # default only when topology matches the canonical tri-prism triangulation
        if n_faces != 8:
            raise ValueError(
                "face_groups=None only supported for the canonical tri-prism mesh "
                "(expected 8 triangle faces). Please pass face_groups explicitly."
            )
        face_groups = [(0, 1), (1, 2), (2, 4), (4, 6), (6, 8)]
    groups = expand_face_groups(face_groups)

    # Color per triangle face index; later groups overwrite earlier ones.
    face_color_map: dict[int, Tuple[float, float, float, float]] = {}
    for gi, grp in enumerate(groups):
        rgba = tuple(to_rgba(face_colors_list[gi % len(face_colors_list)], alpha))
        for f_idx in grp:
            face_color_map[int(f_idx)] = rgba

    # The per-face color table depends only on topology and options, so build
    # it once here instead of on every vis_func call.
    ungrouped_rgba = (0.7, 0.7, 0.7, alpha)
    face_rgba = np.array(
        [face_color_map.get(i, ungrouped_rgba) for i in range(n_faces)],
        dtype=float,
    )

    # Precompute boundary edges once (topology only): an undirected edge is a
    # boundary edge when exactly one triangle uses it.
    edge_count: dict[Tuple[int, int], int] = defaultdict(int)
    for f in faces:
        for t in range(3):
            e = tuple(sorted((int(f[t]), int(f[(t + 1) % 3]))))
            edge_count[e] += 1
    boundary_edges = [e for e, c in edge_count.items() if c == 1]

    def vis_func(flat_mesh: np.ndarray) -> Figure:
        verts = np.asarray(flat_mesh, dtype=float).reshape((n_vertices, 3))
        tris = verts[faces]  # (n_faces, 3, 3)
        colors = face_rgba.copy()

        fig = plt.figure(figsize=figsize, dpi=dpi, facecolor="none")
        ax = fig.add_subplot(111, projection="3d", facecolor="none")
        ax.set_axis_off()

        # Fix the full camera state (view angles, cubic limits, box aspect)
        # BEFORE asking for the projection matrix: get_proj() is built from
        # the current limits and aspect as well as from elev/azim, so querying
        # it earlier would sort against a different projection than the one
        # actually rendered.
        ax.view_init(elev=float(elev), azim=float(azim))
        max_range = float(np.ptp(verts, axis=0).max() + 1e-12)
        mid = verts.mean(axis=0)
        lims = [(float(m - max_range / 2), float(m + max_range / 2)) for m in mid]
        ax.set_xlim(*lims[0])
        ax.set_ylim(*lims[1])
        ax.set_zlim(*lims[2])
        ax.set_box_aspect([1, 1, 1])

        if depth_sort:
            M = ax.get_proj()
            _x2, _y2, z2 = proj3d.proj_transform(
                tris[:, :, 0].ravel(),
                tris[:, :, 1].ravel(),
                tris[:, :, 2].ravel(),
                M,
            )
            tri_depth = np.asarray(z2).reshape(len(tris), 3).mean(axis=1)
            # Ascending mean projected depth (intended as back-to-front).
            # NOTE(review): Poly3DCollection also orders polygons itself at
            # draw time via set_zsort below — confirm whether this pre-sort
            # (and its direction) still has a visible effect.
            order = np.argsort(tri_depth)
            tris = tris[order]
            colors = colors[order]

        poly = Poly3DCollection(tris, facecolors=colors, edgecolor="none")
        poly.set_zsort("average")
        ax.add_collection3d(poly)

        if show_edges and boundary_edges:
            segments = [(verts[i], verts[j]) for i, j in boundary_edges]
            lc = Line3DCollection(segments, colors=edge_color, linewidths=float(edge_width))
            ax.add_collection3d(lc)

        return fig

    return vis_func
def _default_pastel_gradient(n: int) -> LinearSegmentedColormap:
    """
    Return the default pastel colormap sized for ``n`` faces.

    The map interpolates through five anchor colors and is built with
    ``max(n, 5)`` quantization levels, so for ``n <= 5`` it has one level per
    anchor and for larger ``n`` it has ``n`` interpolated levels. The colormap
    is named ``pastel_sunset_<n>``.
    """
    anchors = ['#FFB4A2', '#FAE3B4', '#A8DADC', '#E3C8F2', '#9BB1FF']
    n_levels = max(n, len(anchors))
    cmap_name = f'pastel_sunset_{n}'
    return LinearSegmentedColormap.from_list(cmap_name, anchors, N=n_levels)
# [docs]  (appears to be Sphinx source-link residue; not part of the code)
def make_star_pyramid_visualizer(
    mesh,
    *,
    base_color: str = "#94A3B8",
    edge_color: str = "gray",
    alpha: float = 1.0,
    colormap: str | LinearSegmentedColormap | None = None,
    figsize: Tuple[float, float] = (4.0, 4.0),
    dpi: int = 150,
    elev: float = 0.0,
    azim: float = 0.0,
) -> Callable[[np.ndarray], Figure]:
    """
    Build a renderer for a star pyramid mesh with a gradient on its side faces.

    The apex is taken to be the LAST vertex of ``mesh.vertices``; every face
    that contains it is a "side" face. Side faces are ordered once, on the
    template vertices, by ``arctan2(z, y)`` of the midpoint of their two
    non-apex vertices, and colors are sampled evenly over [0, 1] along that
    order. Because the mapping is fixed per face index, colors stay attached
    to the same faces whatever vertex positions are rendered later.

    Parameters
    ----------
    mesh : trimesh.Trimesh-like
        Must have ``.vertices`` and ``.faces``.
    base_color
        Color of every face that does not contain the apex.
    edge_color, alpha
        Passed to the polygon collection.
    colormap : str, colormap object, or None
        ``None`` selects the default pastel gradient, a string is looked up
        with ``plt.get_cmap``, and any other callable mapping a float in
        [0, 1] to a color (for example a Matplotlib colormap instance of any
        class) is used as given.
    figsize, dpi
        Figure sizing.
    elev, azim : float
        Camera angles passed to ``view_init``.

    Returns
    -------
    vis_func(flat_mesh) -> Figure
        ``flat_mesh`` must reshape to ``(n_vertices, 3)`` vertex positions.

    Raises
    ------
    TypeError
        If ``colormap`` is neither None, a string, nor callable.
    """
    faces = np.asarray(mesh.faces, dtype=int)
    verts0 = np.asarray(mesh.vertices, dtype=float)
    n_vertices = int(verts0.shape[0])
    apex_index = n_vertices - 1

    # ----------------------------
    # Stable side-face ordering, computed from the template mesh only
    # ----------------------------
    side_idx = np.array([i for i, f in enumerate(faces) if apex_index in f], dtype=int)
    if side_idx.size > 0:
        angles = []
        for fi in side_idx:
            rim = [int(v) for v in faces[int(fi)] if int(v) != apex_index]
            midpoint = 0.5 * (verts0[rim[0]] + verts0[rim[1]])
            # Angle of the rim-edge midpoint in the yz-plane.
            angles.append(np.arctan2(midpoint[2], midpoint[1]))
        side_sorted = side_idx[np.argsort(np.asarray(angles))]
    else:
        side_sorted = side_idx

    # Resolve the colormap once. Anything callable is honoured, so colormap
    # objects that are not LinearSegmentedColormap (e.g. what plt.get_cmap
    # returns for many named maps) are no longer silently replaced by the
    # default gradient.
    if colormap is None:
        cmap = _default_pastel_gradient(len(side_sorted))
    elif isinstance(colormap, str):
        cmap = plt.get_cmap(colormap)
    elif callable(colormap):
        cmap = colormap
    else:
        raise TypeError(
            "colormap must be None, a colormap name, or a callable colormap; "
            f"got {type(colormap).__name__}."
        )

    # Fixed color per face index: base color everywhere, gradient on the sides.
    face_colors: List[object] = [base_color] * len(faces)
    if side_sorted.size > 0:
        samples = np.linspace(0.0, 1.0, side_sorted.size, endpoint=True)
        for rank, fi in enumerate(side_sorted):
            face_colors[int(fi)] = cmap(float(samples[rank]))

    def vis_func(flat_mesh: np.ndarray) -> Figure:
        verts = np.asarray(flat_mesh, dtype=float).reshape((n_vertices, 3))
        tris = verts[faces]

        fig = plt.figure(figsize=figsize, dpi=dpi, facecolor="none")
        ax = fig.add_subplot(111, projection="3d", facecolor="none")
        ax.set_axis_off()

        poly = Poly3DCollection(
            tris,
            facecolors=face_colors,  # fixed face-index -> color mapping
            edgecolor=edge_color,
            alpha=float(alpha),
        )
        ax.add_collection3d(poly)

        # Cubic limits centred on the vertex mean, sized by the largest extent.
        max_range = float(np.ptp(verts, axis=0).max() + 1e-12)
        mid = verts.mean(axis=0)
        lims = [(float(m - max_range / 2), float(m + max_range / 2)) for m in mid]
        ax.set_xlim(*lims[0])
        ax.set_ylim(*lims[1])
        ax.set_zlim(*lims[2])
        ax.set_box_aspect([1, 1, 1])
        ax.view_init(elev=float(elev), azim=float(azim))
        return fig

    return vis_func
# ============================
# Figure -> array + rotating clip
# ============================
def fig_to_rgb_array(fig: Figure) -> np.ndarray:
    """
    Render a Matplotlib figure and return it as an RGB image array.

    The figure is rasterized by ``circle_bundles.viz.image_utils.fig_to_rgba``;
    the last channel of that result is dropped and the remaining three are
    returned as an independent copy (presumably (H, W, 3) uint8, matching the
    helper's RGBA output — verify against ``fig_to_rgba``).
    """
    rgb_view = fig_to_rgba(fig)[..., :3]
    return rgb_view.copy()
def make_rotating_mesh_clip(
    flat_mesh: np.ndarray,
    vis_func: Callable[[np.ndarray], Figure],
    out_path: Optional[str] = "rotation.gif",
    *,
    n_frames: int = 120,
    azim_start: float = 0.0,
    azim_end: float = 360.0,
    elev: float = 20.0,
    close_figs: bool = True,
    fps: int = 24,
) -> List[np.ndarray]:
    """
    Create a rotating 3D clip by changing the camera view each frame.

    For each of ``n_frames`` azimuth angles evenly spaced over
    ``[azim_start, azim_end)``, ``vis_func(flat_mesh)`` is called to build a
    fresh figure, the first axis exposing ``view_init`` is re-aimed, and the
    figure is rasterized to RGB.

    Parameters
    ----------
    flat_mesh : np.ndarray
        Passed unchanged to ``vis_func`` for every frame.
    vis_func : callable
        Must return a figure containing at least one 3D axis.
    out_path : str or None
        Destination ending in ``.gif`` or ``.mp4`` (case-insensitive). Pass
        ``None`` to skip writing and only return the frames.
    n_frames, azim_start, azim_end
        Number of frames and the azimuth sweep (end point excluded).
    elev : float
        Elevation applied to every frame; it overrides whatever elevation
        ``vis_func`` set on the axis.
    close_figs : bool
        Close each figure once its frame has been captured.
    fps : int
        Playback rate used when writing the clip.

    Returns
    -------
    frames : list[np.ndarray]
        List of RGB frames as produced by ``fig_to_rgb_array``.

    Raises
    ------
    ValueError
        If ``out_path`` has an unsupported extension (checked before any frame
        is rendered), or a figure from ``vis_func`` has no 3D axis.
    ImportError
        If a file is requested but imageio is not installed (also checked
        before rendering).
    """
    # Validate the output request first so a bad path or a missing dependency
    # fails immediately instead of after every frame has been rendered.
    mimsave = None
    write_gif = False
    if out_path is not None:
        lower = out_path.lower()
        if lower.endswith(".gif"):
            write_gif = True
        elif not lower.endswith(".mp4"):
            raise ValueError("out_path must end with .gif or .mp4")
        try:
            import imageio.v2 as imageio
        except ImportError as e:  # pragma: no cover
            raise ImportError("Writing gifs/mp4 requires imageio (`pip install imageio`).") from e
        mimsave = imageio.mimsave

    flat_mesh = np.asarray(flat_mesh)
    frames: List[np.ndarray] = []
    azims = np.linspace(float(azim_start), float(azim_end), int(n_frames), endpoint=False)

    for a in azims:
        fig = vis_func(flat_mesh)
        try:
            # Duck-type the 3D axis: the first axis that offers view_init.
            ax3d = next((ax for ax in fig.axes if hasattr(ax, "view_init")), None)
            if ax3d is None:
                raise ValueError("vis_func figure did not contain a 3D axis (no view_init found).")
            ax3d.view_init(elev=float(elev), azim=float(a))
            frames.append(fig_to_rgb_array(fig))
        finally:
            # Close even when this frame failed, so figures do not pile up.
            if close_figs:
                plt.close(fig)

    if mimsave is not None:
        if write_gif:
            # NOTE(review): `duration` is given in seconds per frame here; some
            # imageio releases document GIF `duration` in milliseconds —
            # confirm against the installed imageio version.
            mimsave(out_path, frames, duration=1.0 / float(fps))
        else:
            mimsave(out_path, frames, fps=int(fps))

    return frames