# Source code for circle_bundles.viz.thumb_grids

# circle_bundles/viz/thumb_grids.py
from __future__ import annotations

from typing import Optional, Sequence, Callable, Any, Tuple, Union

import numpy as np
import matplotlib.pyplot as plt

from .image_utils import render_to_rgba


def _select_indices(
    n_total: int,
    n_take: int,
    angles: Optional[Sequence[float]],
    sampling_method: Optional[str],
    rng: np.random.Generator,
) -> list[int]:
    """Choose which `n_take` of the `n_total` items to display.

    See `show_data_vis` for the meaning of `sampling_method`. The calls made on
    `rng` are kept in a fixed order so a given seed always selects the same items.

    Raises
    ------
    ValueError
        If `angles` has the wrong length or contains non-finite values, or if
        `sampling_method` is not recognised.
    """
    if sampling_method == "angle" and angles is not None:
        ang = np.asarray(list(angles), dtype=float).reshape(-1)
        if ang.shape[0] != n_total:
            raise ValueError(f"angles must have length {n_total}, got {ang.shape[0]}.")
        if not np.all(np.isfinite(ang)):
            # NaN would poison min/max/argmin and make the angle sort unstable.
            raise ValueError("angles must be finite (no NaN or inf values).")

        # Evenly spaced target angles over the observed range. Each target takes
        # its nearest sample. Neighbouring targets can hit the same sample, so
        # duplicates are skipped and any shortfall is topped up at random below.
        centers = np.linspace(float(ang.min()), float(ang.max()), n_take)
        used: set[int] = set()
        selected: list[int] = []
        for center in centers:
            idx = int(np.argmin(np.abs(ang - center)))
            if idx not in used:
                used.add(idx)
                selected.append(idx)

        shortfall = n_take - len(selected)
        if shortfall > 0:
            # remaining.size == n_total - len(selected) >= shortfall because
            # n_take <= n_total, so sampling without replacement cannot fail.
            remaining = np.array([i for i in range(n_total) if i not in used], dtype=int)
            filler = rng.choice(remaining, size=shortfall, replace=False)
            selected.extend(int(x) for x in filler)

        # Display thumbnails in increasing-angle order.
        return sorted(selected, key=lambda i: float(ang[i]))

    if sampling_method in ("random", None) or sampling_method == "angle":
        # Reaching here with "angle" means no `angles` were supplied: fall back
        # to uniform random sampling without replacement.
        return [int(x) for x in rng.choice(n_total, size=n_take, replace=False)]

    if sampling_method == "first":
        return list(range(n_take))

    raise ValueError(f"Unknown sampling_method={sampling_method!r}")


def _labels_from_sequence(
    label_func: Any,
    n_total: int,
    selected: Sequence[int],
) -> Optional[list]:
    """Return the per-thumbnail labels when `label_func` is a label sequence.

    Returns None when `label_func` is None or callable; callables are evaluated
    per item by the caller instead.

    Raises
    ------
    TypeError
        If `label_func` is neither None, callable, an ndarray, nor a non-string
        sequence.
    ValueError
        If a label sequence does not have one entry per data item.
    """
    import collections.abc

    if label_func is None or callable(label_func):
        return None

    is_label_seq = isinstance(label_func, np.ndarray) or (
        isinstance(label_func, collections.abc.Sequence)
        and not isinstance(label_func, (str, bytes, bytearray))
    )
    if not is_label_seq:
        # Previously such values were silently ignored (no labels drawn).
        raise TypeError(
            "label_func must be a callable or a sequence of labels, "
            f"got {type(label_func).__name__}."
        )
    if len(label_func) != n_total:
        raise ValueError("If label_func is a sequence, it must have length len(data).")
    return [label_func[i] for i in selected]


def show_data_vis(
    data: Sequence[Any],
    vis_func: Callable[[Any], Union[np.ndarray, plt.Figure]],
    *,
    angles: Optional[Sequence[float]] = None,
    max_samples: int = 100,
    n_cols: int = 10,
    seed: Optional[int] = None,
    save_path: Optional[str] = None,
    label_func: Optional[Union[Callable[[Any], Any], Sequence[Any]]] = None,
    label_position: str = "below",
    sampling_method: str = "angle",
    font_size: int = 15,
    transparent_border: bool = True,
    white_thresh: int = 250,
    wspace: float = 0.25,
    hspace: float = 0.10,
    figsize_per_cell: float = 2.0,
    dpi: int = 150,
    pad_frac: float = 0.1,
    show: bool = True,
    save_dpi: Optional[float] = 300,
) -> Tuple[plt.Figure, np.ndarray]:
    """Show a thumbnail grid of rendered data.

    IMPORTANT:
    - `data` is treated as a Sequence. We do NOT coerce it to np.asarray(data),
      because that can break for object-like payloads (meshes, images with
      varying shape, etc.)

    Parameters
    ----------
    data
        Items to visualise. Only ``len(data)`` and integer indexing are used.
    vis_func
        Called on each selected item. The result (an array or a matplotlib
        Figure) is converted to an image with ``render_to_rgba`` and drawn in
        one grid cell.
    angles
        Optional per-item angle, same length as `data`. Used by
        ``sampling_method="angle"``. Must be finite.
    max_samples
        Maximum number of thumbnails (at least 1). Capped at ``len(data)``.
    n_cols
        Number of grid columns (values below 1 are treated as 1).
    seed
        Seed for ``np.random.default_rng``. It drives random sampling and the
        random top-up in angle sampling.
    save_path
        If given, the figure is saved there with ``bbox_inches="tight"``.
    label_func
        Either a callable mapping an item to its label, or a sequence of labels
        with one entry per item in `data`. A None label draws no text.
    label_position
        ``"below"`` puts labels under each thumbnail. Any other value puts
        them above.
    sampling_method
        ``"angle"`` (default): take the items whose `angles` are nearest to
        evenly spaced values between the min and max angle, and display them
        in angle order. Falls back to ``"random"`` when `angles` is None.
        ``"random"`` or None: uniform random sample without replacement.
        ``"first"``: the first items of `data`.
    font_size
        Label font size.
    transparent_border, white_thresh
        Forwarded to ``render_to_rgba`` (which is also called with
        ``trim=True``).
    wspace, hspace
        GridSpec spacing between cells.
    figsize_per_cell
        Figure width/height contributed by each grid column/row (matplotlib
        figsize units, i.e. inches).
    dpi
        DPI of the on-screen figure.
    pad_frac
        Extra margin around each thumbnail, as a fraction of its width/height.
        0 or None disables padding.
    show
        Call ``plt.show()`` before returning.
    save_dpi
        DPI used when saving to `save_path` (default 300, as before). None
        defers to matplotlib's ``rcParams["savefig.dpi"]``.

    Returns
    -------
    fig
        The created figure.
    axes
        1-D object array of all ``n_rows * n_cols`` axes in row-major order.
        Trailing blank cells are included.

    Raises
    ------
    ValueError
        Empty `data`, ``max_samples < 1``, bad `angles`, a label sequence of
        the wrong length, or an unknown `sampling_method`.
    TypeError
        `label_func` is neither callable nor a label sequence.
    """
    n_total = len(data)
    if n_total == 0:
        raise ValueError("show_data_vis received empty data.")
    # Validate before creating any figure. A 0-row GridSpec would otherwise
    # fail deep inside matplotlib and leave an orphan figure behind.
    if int(max_samples) < 1:
        raise ValueError(f"max_samples must be >= 1, got {max_samples}.")

    rng = np.random.default_rng(seed)
    n_take = int(min(max_samples, n_total))
    n_cols = max(1, int(n_cols))

    # ---- choose indices and (optional) pre-computed labels ----
    selected = _select_indices(n_total, n_take, angles, sampling_method, rng)
    labels_sel = _labels_from_sequence(label_func, n_total, selected)

    # ---- layout via GridSpec ----
    n = len(selected)
    n_rows = int(np.ceil(n / n_cols))
    fig = plt.figure(
        figsize=(float(figsize_per_cell) * n_cols, float(figsize_per_cell) * n_rows),
        dpi=int(dpi),
    )
    gs = fig.add_gridspec(
        n_rows,
        n_cols,
        wspace=float(wspace),
        hspace=float(hspace),
        left=0.02,
        right=0.98,
        bottom=0.02,
        top=0.98,
    )

    label_below = label_position == "below"
    axes: list[plt.Axes] = []
    for cell in range(n_rows * n_cols):
        row, col = divmod(cell, n_cols)
        ax = fig.add_subplot(gs[row, col])
        ax.axis("off")
        axes.append(ax)
        if cell >= n:
            continue  # trailing cells of the last row stay blank

        item = data[selected[cell]]  # fetch once; reused for vis_func and label_func
        # NOTE(review): if vis_func returns a fresh plt.Figure per call, confirm
        # render_to_rgba closes it. Otherwise every thumbnail leaves an open
        # figure behind (and plt.show() below would display them too).
        rendered = vis_func(item)
        img_u8 = render_to_rgba(
            rendered,
            transparent_border=bool(transparent_border),
            trim=True,
            white_thresh=int(white_thresh),
        )
        ax.imshow(img_u8, interpolation="nearest")

        if pad_frac and float(pad_frac) > 0:
            pad = float(pad_frac)
            h, w = img_u8.shape[0], img_u8.shape[1]
            # imshow's default extent (origin="upper") is [-0.5, w-0.5] x
            # [h-0.5, -0.5]; pad symmetrically around that box.
            ax.set_xlim(-0.5 - pad * w, w - 0.5 + pad * w)
            ax.set_ylim(h - 0.5 + pad * h, -0.5 - pad * h)

        if callable(label_func):
            label = label_func(item)
        elif labels_sel is not None:
            label = labels_sel[cell]
        else:
            label = None

        if label is not None:
            ax.text(
                0.5,
                -0.10 if label_below else 1.05,
                str(label),
                transform=ax.transAxes,
                ha="center",
                va="top" if label_below else "bottom",
                fontsize=int(font_size),
            )

    if save_path is not None:
        fig.savefig(save_path, bbox_inches="tight", dpi=save_dpi)
    if show:
        plt.show()
    return fig, np.array(axes, dtype=object)