Source code for circle_bundles.viz.fiber_vis

from __future__ import annotations

from typing import Callable, Optional, Sequence

import numpy as np

from .image_utils import render_to_rgba

__all__ = ["fiber_vis"]


[docs] def fiber_vis( data: np.ndarray, vis_func: Optional[Callable] = None, *, vis_data: Optional[np.ndarray] = None, selected_indices: Optional[Sequence[int]] = None, max_images: int = 12, zoom: float = 0.2, figsize=(10, 8), dpi: int = 150, save_path: Optional[str] = None, random_state: Optional[int] = None, ax=None, clear_ax: bool = True, show: bool = True, scatter_alpha: float = 0.15, scatter_s: float = 10.0, ): """ Visualize up to `max_images` items from `data` by embedding them to 3D (PCA). - If `vis_func` is provided: overlay thumbnails rendered by `vis_func`. - If `vis_func` is None: just show the embedded points. Parameters ---------- data : (N, d) array vis_func : optional callable(datum) -> (Figure | ndarray image) Something render_to_rgba can handle after vis_func returns. If None, no thumbnails are rendered; points only. vis_data : optional data source for vis_func (same length N) If provided, thumbnails are rendered from vis_data[idx] while embedding uses data[idx]. selected_indices : optional explicit indices to visualize max_images : cap on number of items/points shown (and thumbnails if vis_func given) random_state : RNG seed used when selected_indices is None Subplot usage ------------- If `ax` is provided, it must be a 3D axis (projection='3d'). In that case, this function draws into `ax` and will not create a new figure. """ import matplotlib.pyplot as plt from matplotlib.offsetbox import OffsetImage, AnnotationBbox from mpl_toolkits.mplot3d.proj3d import proj_transform data = np.asarray(data) if data.ndim != 2: raise ValueError(f"data must be 2D (N,d). Got shape {data.shape}.") N = int(data.shape[0]) if N == 0: raise ValueError("Empty data array.") if vis_data is not None: vis_data = np.asarray(vis_data) if vis_data.shape[0] != N: raise ValueError( f"vis_data must have same length as data. Got {vis_data.shape[0]} vs {N}." ) max_images = int(max_images) if max_images <= 0: raise ValueError("max_images must be positive.") # ---- choose indices ---- if selected_indices is None: rng = np.random.default_rng(random_state) k = min(max_images, N) selected_indices = np.sort( rng.choice(np.arange(N, dtype=int), size=k, replace=False) ).tolist() else: selected_indices = [int(i) for i in selected_indices][: min(max_images, len(selected_indices))] selected_indices = [i for i in selected_indices if 0 <= i < N] if len(selected_indices) == 0: raise ValueError("selected_indices produced no valid indices in range.") selected_indices = list(selected_indices) selected_data = data[np.asarray(selected_indices, dtype=int)] # ---- embed (PCA) ---- # Lazy import so viz submodule doesn't require sklearn at import time. from sklearn.decomposition import PCA if selected_data.shape[0] == 1: embedded = np.zeros((1, 3), dtype=float) else: n_comp = 3 if selected_data.shape[1] >= 3 else min(3, selected_data.shape[1]) pca = PCA(n_components=n_comp) emb = pca.fit_transform(selected_data) if emb.shape[1] < 3: embedded = np.pad(emb, ((0, 0), (0, 3 - emb.shape[1])), mode="constant") else: embedded = emb[:, :3] # ---- figure / axes ---- created_fig = False if ax is None: fig = plt.figure(figsize=figsize, dpi=int(dpi)) ax = fig.add_subplot(111, projection="3d") created_fig = True else: fig = ax.figure ax_is_3d = getattr(ax, "name", "") == "3d" if not ax_is_3d: raise ValueError("Provided ax is not a 3D axis. Create it with projection='3d'.") if clear_ax: ax.cla() # ---- always show points ---- ax.scatter( embedded[:, 0], embedded[:, 1], embedded[:, 2], alpha=float(scatter_alpha), s=float(scatter_s), ) # If no vis_func supplied, we're done (aside from aesthetics/save/show) if vis_func is not None: # Ensure transforms/projection are ready fig.canvas.draw() # ---- overlays ---- for i, (x, y, z) in enumerate(embedded): idx = int(selected_indices[i]) try: datum = vis_data[idx] if vis_data is not None else selected_data[i] rendered = vis_func(datum) img = render_to_rgba(rendered, transparent_border=True, trim=True) x2, y2, _ = proj_transform(float(x), float(y), float(z), ax.get_proj()) ab = AnnotationBbox( OffsetImage(img, zoom=float(zoom)), (x2, y2), xycoords="data", frameon=False, ) ax.add_artist(ab) except Exception as e: print(f"Error rendering image at index {idx}: {type(e).__name__}: {e}") # ---- aesthetics ---- ax.grid(True) try: ax.set_box_aspect([1, 1, 1]) except Exception: pass # Only do layout/show if we created the figure if created_fig: plt.tight_layout() if save_path is not None: fig.savefig(save_path, dpi=300, bbox_inches="tight") if show: plt.show() else: # In subplot mode, still honor save_path (saves the full figure) if save_path is not None: fig.savefig(save_path, dpi=300, bbox_inches="tight") return fig, ax