from __future__ import annotations
from typing import Callable, Optional, Sequence, Tuple, Union, Iterable
import numpy as np
import matplotlib.pyplot as plt
from .image_utils import render_to_rgba
# Public entry points exported by ``from <module> import *``.
__all__ = [
    "lattice_vis",
    "scatter_lattice_vis",
]
[docs]
def lattice_vis(
    data: Sequence,
    coords: np.ndarray,
    vis_func: Callable[[object], Union[np.ndarray, plt.Figure]],
    *,
    per_row: int = 7,
    per_col: int = 7,
    padding: float = 0.05,
    figsize: float | Tuple[float, float] = 10,
    thumb_px: int = 200,
    dpi: int = 200,
    save_path: Optional[str] = None,
    transparent_border: bool = True,
    white_thresh: int = 250,
    ax=None,
    clear_ax: bool = True,
):
    """
    Visualize a dataset by placing rendered thumbnails at 2D coordinates, using a
    lattice-based nearest-neighbor selection to pick representative examples.

    Each coordinate axis is first affine-rescaled independently to [0, 1]; an
    axis with (numerically) zero extent -- e.g. a single datum -- is mapped to
    0.5 so its points sit in the middle instead of on an edge. A regular
    lattice of target points (``per_row`` x ``per_col``) is then created
    inside the unit square, spanning [padding, 1 - padding] along each axis
    (an axis with a single target uses 0.5). Targets are visited starting at
    the bottom-left, sweeping along x first; each selects the nearest *unused*
    datum (greedy, without replacement; ties go to the lowest index). The
    selected thumbnails are then placed at their true rescaled positions, but
    mapped into a "safe center region" so that each thumbnail remains fully
    visible and is not clipped by the axes boundary.

    Thumbnails are drawn by creating inset axes positioned in *figure-fraction*
    coordinates corresponding to the provided axis' rectangle. This allows
    consistent pixel-sized thumbnails (``thumb_px``) regardless of axis data
    limits.

    Parameters
    ----------
    data
        Sequence of N objects to visualize (one per coordinate). Indexed with
        ``data[i]`` for each selected ``i``.
    coords
        Array of shape (N, 2) giving finite 2D coordinates for each datum.
    vis_func
        Callable mapping a single datum to either an image array or a Matplotlib
        Figure. The output is converted to an RGBA image via ``render_to_rgba``.
    per_row, per_col
        Number of lattice targets in the x- and y-directions used for selecting
        representative points. At most ``per_row * per_col`` thumbnails are
        drawn (fewer if N is smaller).
    padding
        Inset margin (clipped to [0, 0.49]) used when constructing lattice
        targets in the rescaled coordinate system.
    figsize
        Figure size in inches if ``ax is None``. If a single number is given,
        a square figure is created.
    thumb_px
        Desired thumbnail size in pixels (square), measured in the figure's
        display space at the figure's own dpi. Must be positive; if too large
        relative to the figure/axis size, an error is raised.
    dpi
        Dots-per-inch used when creating a new figure (``ax is None``) and when
        saving to disk. If ``ax`` belongs to a figure with a different dpi,
        saved thumbnails scale by ``dpi / fig.dpi``.
    save_path
        Optional path to save the resulting figure.
    transparent_border
        Passed to ``render_to_rgba``. If True, attempts to make background
        whitespace transparent when trimming.
    white_thresh
        Passed to ``render_to_rgba``. Pixel intensity threshold for detecting
        near-white background when trimming/making transparent.
    ax
        Optional Matplotlib axis to draw into. If provided, thumbnails are placed
        within this axis' rectangle in figure coordinates.
    clear_ax
        If True, clears ``ax`` before drawing and turns the axis off.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure containing the visualization.
    ax : matplotlib.axes.Axes
        The base axis used as the placement region (turned off when
        ``clear_ax`` is True).

    Raises
    ------
    ValueError
        If ``coords`` is not shape (N, 2) or contains NaN/inf, if
        ``len(data) != N`` or ``N == 0``, if ``per_row``/``per_col`` or
        ``thumb_px`` are non-positive, if the axis rectangle has zero size,
        or if ``thumb_px`` is too large to fit within the figure or provided
        axis region.

    Notes
    -----
    *Selection vs placement:* lattice targets are used only to choose which data
    points to show; thumbnails are placed at the selected points' actual rescaled
    coordinates (subject to the safe-center remapping that prevents clipping).
    The selection procedure is greedy without replacement and may not produce a
    globally optimal assignment between lattice targets and points.
    """
    # ---- validation -------------------------------------------------------
    coords = np.asarray(coords, dtype=float)
    if coords.ndim != 2 or coords.shape[1] != 2:
        raise ValueError(f"coords must be an (N,2) array. Got {coords.shape}.")
    N = int(coords.shape[0])
    if len(data) != N:
        raise ValueError(f"data length must match coords rows. Got len(data)={len(data)} vs N={N}.")
    if N == 0:
        raise ValueError("Empty coords/data.")
    if not np.all(np.isfinite(coords)):
        # NaN/inf would poison min/max rescaling and every distance below.
        raise ValueError("coords must be finite (no NaN/inf).")
    per_row = int(per_row)
    per_col = int(per_col)
    if per_row <= 0 or per_col <= 0:
        raise ValueError("per_row and per_col must be positive.")
    if float(thumb_px) <= 0:
        raise ValueError("thumb_px must be positive.")

    # ---- rescale coords to [0, 1]^2 (used for selection and placement) ----
    min_vals = coords.min(axis=0)
    span = coords.max(axis=0) - min_vals
    degenerate = np.abs(span) < 1e-12  # axis with no extent to rescale
    scaled_coords = (coords - min_vals) / np.where(degenerate, 1.0, span)
    # Centre constant axes instead of pinning them to the 0 edge.
    scaled_coords[:, degenerate] = 0.5

    # ---- lattice targets (used only for selection) ------------------------
    pad = float(np.clip(padding, 0.0, 0.49))

    def _axis_targets(n: int) -> np.ndarray:
        # np.linspace(a, b, 1) returns [a]; a lone target belongs at the centre.
        return np.array([0.5]) if n == 1 else np.linspace(pad, 1 - pad, n)

    grid_x, grid_y = np.meshgrid(_axis_targets(per_row), _axis_targets(per_col), indexing="xy")
    lattice_pts = np.column_stack([grid_x.ravel(), grid_y.ravel()])

    # ---- greedy nearest-unused selection ----------------------------------
    # Claimed points get an infinite distance so argmin skips them; argmin
    # returns the first minimum, so ties resolve to the lowest index.
    available = np.ones(N, dtype=bool)
    selected_indices: list[int] = []
    for lp in lattice_pts:
        if not available.any():
            break  # more targets than points: every point is already selected
        d = np.linalg.norm(scaled_coords - lp, axis=1)
        d[~available] = np.inf
        idx = int(np.argmin(d))
        selected_indices.append(idx)
        available[idx] = False
    selected_coords = scaled_coords[selected_indices]

    # ---- figure / axis handling -------------------------------------------
    if ax is None:
        if np.ndim(figsize) == 0:  # scalar (incl. numpy scalars) -> square figure
            figsize = (float(figsize), float(figsize))
        fig = plt.figure(figsize=figsize, dpi=int(dpi))
        ax = fig.add_subplot(111)
    else:
        fig = ax.figure
    if clear_ax:
        ax.cla()
        ax.axis("off")

    # Draw once so ax.get_position() and fig.bbox reflect the final layout.
    fig.canvas.draw()

    # Axis rectangle in figure-fraction coordinates.
    ax_bbox_fig = ax.get_position()
    ax_left, ax_bottom = float(ax_bbox_fig.x0), float(ax_bbox_fig.y0)
    ax_w, ax_h = float(ax_bbox_fig.width), float(ax_bbox_fig.height)
    if ax_w <= 0 or ax_h <= 0:
        raise ValueError(
            f"ax has a zero-size position (width={ax_w}, height={ax_h}); cannot place thumbnails."
        )

    # Convert the desired thumbnail pixel size to figure fractions, then to
    # fractions of the axis rectangle.
    width_fig_frac = float(thumb_px) / float(fig.bbox.width)
    height_fig_frac = float(thumb_px) / float(fig.bbox.height)
    if width_fig_frac >= 1 or height_fig_frac >= 1:
        raise ValueError(
            "thumb_px too large relative to figure pixel size (thumbnail doesn't fit). "
            f"width_fig_frac={width_fig_frac:.3f}, height_fig_frac={height_fig_frac:.3f}."
        )
    width_ax_frac = width_fig_frac / ax_w
    height_ax_frac = height_fig_frac / ax_h
    if width_ax_frac >= 1 or height_ax_frac >= 1:
        raise ValueError(
            "thumb_px too large relative to the provided axis size. "
            f"width_ax_frac={width_ax_frac:.3f}, height_ax_frac={height_ax_frac:.3f}."
        )

    # Safe centre region (axis fraction): thumbnail centres confined to
    # [x0, x1] x [y0, y1] never let a thumbnail extend past the axis rectangle.
    x0, x1 = width_ax_frac / 2, 1 - width_ax_frac / 2
    y0, y1 = height_ax_frac / 2, 1 - height_ax_frac / 2

    # ---- place thumbnails as inset axes (figure-fraction coordinates) ----
    for idx, (cx, cy) in zip(selected_indices, selected_coords):
        # Render first so a failing vis_func leaves no empty inset axes behind.
        # NOTE(review): if vis_func returns a Figure, confirm render_to_rgba
        # closes it; otherwise open figures accumulate across thumbnails.
        img = render_to_rgba(
            vis_func(data[idx]),
            transparent_border=bool(transparent_border),
            trim=True,
            white_thresh=int(white_thresh),
        )
        u = x0 + float(cx) * (x1 - x0)  # axis-fraction x of thumbnail centre
        v = y0 + float(cy) * (y1 - y0)  # axis-fraction y of thumbnail centre
        # Axis fraction -> figure fraction, then shift centre to lower-left corner.
        left_fig = ax_left + u * ax_w - width_fig_frac / 2
        bottom_fig = ax_bottom + v * ax_h - height_fig_frac / 2
        ax_in = fig.add_axes([left_fig, bottom_fig, width_fig_frac, height_fig_frac])
        ax_in.imshow(img, interpolation="nearest")
        ax_in.set_facecolor("none")
        ax_in.axis("off")

    if save_path is not None:
        # bbox_inches="tight" is usually fine, but may clip inset axes on some
        # backends; drop it if clipping is observed.
        fig.savefig(save_path, dpi=int(dpi), bbox_inches="tight")
    return fig, ax
[docs]
def scatter_lattice_vis(
    data: Sequence,
    coords: np.ndarray,
    vis_func: Callable[[object], Union[np.ndarray, plt.Figure]],
    *,
    # selection
    selected_indices: Optional[Sequence[int]] = None,
    per_row: int = 7,
    per_col: int = 7,
    padding: float = 0.05,
    # appearance
    point_size: float = 3.0,
    point_alpha: float = 0.1,
    highlight_size: float = 50.0,
    highlight_lw: float = 1.5,
    thumb_px: int = 30,
    thumb_offset_px: int = 5,  # gap between dot and thumbnail (points; see docstring)
    leader_line: bool = False,
    leader_lw: float = 0.8,
    # figure
    figsize: float | Tuple[float, float] = 10,
    dpi: int = 200,
    save_path: Optional[str] = None,
    transparent_border: bool = True,
    white_thresh: int = 250,
    ax=None,
    clear_ax: bool = True,
):
    """
    Scatter all points, ring the selected ones in red, and draw each selected
    point's thumbnail just to the left of its dot.

    Coordinates are affine-rescaled per axis to [0, 1] (an axis with
    numerically zero extent is mapped to 0.5) and plotted on an equal-aspect
    axis with limits [-0.02, 1.02] in both directions; ticks and frame are
    hidden.

    Selection
    ---------
    - If ``selected_indices`` is provided, it is used as-is (converted to
      ints; negative indices wrap as in NumPy indexing).
    - Otherwise representatives are chosen with the same greedy
      lattice-nearest-neighbour scheme as ``lattice_vis``: ``per_row`` x
      ``per_col`` targets over [padding, 1 - padding]^2 (a single target on an
      axis uses 0.5), each claiming the nearest not-yet-claimed point; ties go
      to the lowest index.

    Placement
    ---------
    Thumbnails are drawn with Matplotlib's ``OffsetImage``/``AnnotationBbox``,
    so their size does not depend on the data limits. Each thumbnail is scaled
    so its longest side is ``thumb_px``. It is vertically centred on its dot
    and horizontally centred ``thumb_px / 2 + thumb_offset_px`` to the left of
    it. For square or landscape images this leaves a ``thumb_offset_px`` gap
    between the dot centre and the thumbnail's right edge.

    Units: despite the ``_px`` suffix, ``thumb_px`` and ``thumb_offset_px``
    are effectively typographic points (1/72 inch). The offset uses
    ``boxcoords="offset points"``, and ``OffsetImage`` (default DPI
    correction) sizes images in points. On a figure with dpi D the
    thumbnail's longest side is about ``thumb_px * D / 72`` pixels.

    Parameters not listed above (``transparent_border``, ``white_thresh``) are
    forwarded to ``render_to_rgba``. ``leader_line`` draws a thin black line
    from each dot to its thumbnail.

    Returns
    -------
    fig : matplotlib.figure.Figure
    ax : matplotlib.axes.Axes
    selected_indices : list[int]
        The indices whose thumbnails were drawn.

    Raises
    ------
    ValueError
        If ``coords`` is not (N, 2) or contains NaN/inf, if ``len(data) != N``,
        if ``N == 0``, or if lattice selection is used with non-positive
        ``per_row``/``per_col``.
    IndexError
        If a provided ``selected_indices`` entry is outside ``[-N, N)``.

    Notes
    -----
    When this function creates the figure itself (``ax is None``) it calls
    ``plt.show()`` before returning.
    """
    # ---- validation -------------------------------------------------------
    coords = np.asarray(coords, dtype=float)
    if coords.ndim != 2 or coords.shape[1] != 2:
        raise ValueError(f"coords must be an (N,2) array. Got {coords.shape}.")
    N = int(coords.shape[0])
    if len(data) != N:
        raise ValueError(f"data length must match coords rows. Got len(data)={len(data)} vs N={N}.")
    if N == 0:
        raise ValueError("Empty coords/data.")
    if not np.all(np.isfinite(coords)):
        # NaN/inf would poison min/max rescaling and every distance below.
        raise ValueError("coords must be finite (no NaN/inf).")

    # ---- rescale coords to [0, 1]^2 (used for plotting and selection) -----
    min_vals = coords.min(axis=0)
    span = coords.max(axis=0) - min_vals
    degenerate = np.abs(span) < 1e-12  # axis with no extent to rescale
    scaled = (coords - min_vals) / np.where(degenerate, 1.0, span)
    # Centre constant axes instead of pinning them to the 0 edge.
    scaled[:, degenerate] = 0.5

    # ---- choose indices ---------------------------------------------------
    if selected_indices is None:
        per_row = int(per_row)
        per_col = int(per_col)
        if per_row <= 0 or per_col <= 0:
            raise ValueError("per_row and per_col must be positive.")
        pad = float(np.clip(padding, 0.0, 0.49))

        def _axis_targets(n: int) -> np.ndarray:
            # np.linspace(a, b, 1) returns [a]; a lone target belongs at the centre.
            return np.array([0.5]) if n == 1 else np.linspace(pad, 1 - pad, n)

        gx, gy = np.meshgrid(_axis_targets(per_row), _axis_targets(per_col), indexing="xy")
        lattice_pts = np.column_stack([gx.ravel(), gy.ravel()])

        # Greedy nearest-unused: claimed points get an infinite distance so
        # argmin skips them; argmin's first-minimum rule breaks ties by index.
        available = np.ones(N, dtype=bool)
        chosen: list[int] = []
        for lp in lattice_pts:
            if not available.any():
                break  # more targets than points: every point is already chosen
            d = np.linalg.norm(scaled - lp, axis=1)
            d[~available] = np.inf
            idx = int(np.argmin(d))
            chosen.append(idx)
            available[idx] = False
        selected_indices = chosen
    else:
        selected_indices = [int(i) for i in selected_indices]
        out_of_range = [i for i in selected_indices if not -N <= i < N]
        if out_of_range:
            raise IndexError(f"selected_indices out of range for N={N}: {out_of_range}")

    # ---- figure / axis ----------------------------------------------------
    created_fig = ax is None
    if created_fig:
        if np.ndim(figsize) == 0:  # scalar (incl. numpy scalars) -> square figure
            figsize = (float(figsize), float(figsize))
        fig, ax = plt.subplots(figsize=figsize, dpi=int(dpi))
    else:
        fig = ax.figure
    if clear_ax:
        ax.cla()
    ax.set_aspect("equal", adjustable="box")
    ax.set_xlim(-0.02, 1.02)
    ax.set_ylim(-0.02, 1.02)

    # All points, faint.
    ax.scatter(scaled[:, 0], scaled[:, 1], s=float(point_size), alpha=float(point_alpha))

    # Selected points as hollow red rings drawn above the base scatter.
    sel = np.asarray(selected_indices, dtype=int)
    ax.scatter(
        scaled[sel, 0],
        scaled[sel, 1],
        s=float(highlight_size),
        facecolors="none",
        edgecolors="red",
        linewidths=float(highlight_lw),
        zorder=5,
    )

    # ---- thumbnails left of each selected point ---------------------------
    from matplotlib.offsetbox import AnnotationBbox, OffsetImage

    # Horizontal offset (points) of each thumbnail centre from its dot.
    dx = -(thumb_px / 2 + float(thumb_offset_px))
    for idx in selected_indices:
        # NOTE(review): if vis_func returns a Figure, confirm render_to_rgba
        # closes it; otherwise open figures accumulate across thumbnails.
        img = render_to_rgba(
            vis_func(data[idx]),
            transparent_border=bool(transparent_border),
            trim=True,
            white_thresh=int(white_thresh),
        )
        # Scale so the longest image side spans thumb_px (points; see docstring).
        zoom = float(thumb_px) / float(max(img.shape[0], img.shape[1]))
        ab = AnnotationBbox(
            OffsetImage(img, zoom=zoom),
            xy=(float(scaled[idx, 0]), float(scaled[idx, 1])),
            xybox=(dx, 0.0),
            xycoords="data",
            boxcoords="offset points",
            frameon=False,
            pad=0.0,
            # Fresh dict per annotation so no two artists share mutable props.
            arrowprops=(
                dict(arrowstyle="-", lw=float(leader_lw), alpha=0.8, color="black")
                if leader_line
                else None
            ),
            zorder=10,
        )
        ax.add_artist(ab)

    # Hide ticks, tick labels and frame: only dots and thumbnails remain.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
    ax.set_frame_on(False)

    if save_path is not None:
        fig.savefig(save_path, dpi=int(dpi), bbox_inches="tight")
    if created_fig:
        # Kept for backward compatibility; lattice_vis does not do this.
        plt.show()
    return fig, ax, selected_indices