# Source code for circle_bundles.optical_flow.patch_viz

# optical_flow/patch_viz.py
from __future__ import annotations

import math
from typing import Callable, Literal, Tuple

import matplotlib.patches as mpl_patches
import matplotlib.pyplot as plt
import numpy as np

# Label for the two flat patch layouts handled in this module:
# "intensity" (vector length n^2) and "flow" (vector length 2*n^2).
PatchKind = Literal["intensity", "flow"]

# Public names exported by this module.
__all__ = [
    "PatchKind",
    "make_patch_visualizer",
]


def _infer_patch_kind_and_n(length: int) -> Tuple[PatchKind, int]:
    """Infer the patch kind and side length ``n`` from a flat vector length.

    Parameters
    ----------
    length
        Number of entries in the flattened patch vector (coerced with ``int``).

    Returns
    -------
    (kind, n)
        ``("intensity", n)`` when ``length == n**2``, otherwise
        ``("flow", n)`` when ``length == 2 * n**2``.

    Raises
    ------
    ValueError
        If ``length`` is not positive, or is neither ``n**2`` nor
        ``2 * n**2`` for some positive integer ``n``.
    """
    length = int(length)

    # An empty (or negative-length) vector has no n x n layout to draw;
    # without this guard 0 would be accepted as a 0 x 0 intensity patch.
    if length <= 0:
        raise ValueError(f"Invalid patch length {length}: must be a positive integer.")

    # math.isqrt is exact for arbitrarily large ints, unlike a float sqrt.
    side = math.isqrt(length)
    if side * side == length:
        return "intensity", side

    half, remainder = divmod(length, 2)
    if remainder == 0:
        side = math.isqrt(half)
        if side * side == half:
            return "flow", side

    raise ValueError("Invalid patch length: must be n^2 or 2*n^2 for some integer n.")


def make_patch_visualizer(
    *,
    # figure formatting
    fig_size: float = 3.0,
    dpi: int = 120,
    # intensity settings
    intensity_cmap: str = "gray",
    intensity_normalize: Literal["minmax", "none"] = "minmax",
    show_intensity_grid: bool = True,
    intensity_grid_lw: float = 0.8,
    # flow settings
    flipud: bool = True,
    flow_normalize: Literal["maxmag", "none"] = "maxmag",
    quiver_width: float = 0.02,
    headwidth: float = 5,
    headlength: float = 5,
    grid_lw: float = 7,
    border_lw_mult: float = 2.0,
) -> Callable[[np.ndarray], plt.Figure]:
    """
    Build a function that renders one flat patch vector as a square figure.

    The returned callable flattens its input, infers the layout from the
    vector length, and draws on a fresh ``fig_size`` x ``fig_size`` figure:

    - intensity (n^2): reshaped column-major to (n, n) and shown with
      ``imshow``; optionally min-max rescaled, optionally overlaid with a
      cell grid and border
    - flow (2*n^2): reshaped column-major to (n, n, 2) and shown as one
      ``quiver`` arrow per cell, centred on the cell; optionally flipped
      vertically and rescaled by the largest vector magnitude; always
      overlaid with a cell grid and border

    Parameters
    ----------
    fig_size, dpi
        Figure side length passed as ``figsize`` and figure resolution.
    intensity_cmap
        Colormap passed to ``imshow``.
    intensity_normalize
        ``"minmax"`` rescales the image to [0, 1] (all zeros when the image
        is constant); any other value leaves the values untouched.
    show_intensity_grid
        Whether to draw the cell grid and border on intensity patches.
    intensity_grid_lw
        Grid line width for intensity patches.
    flipud
        Whether to reverse the row order of the flow field before drawing.
    flow_normalize
        ``"maxmag"`` divides the field by its largest vector norm when that
        norm is positive; any other value leaves the field untouched.
    quiver_width, headwidth, headlength
        Arrow shape settings forwarded to ``quiver``.
    grid_lw
        Grid line width for flow patches.
    border_lw_mult
        Border line width as a multiple of the grid line width in use.

    Returns
    -------
    vis_func(patch_vector) -> matplotlib Figure
    """

    def _blank_canvas(side: int):
        # Square axes spanning [0, side] x [0, side] that fill the figure,
        # with no ticks and no visible spines.
        figure, axes = plt.subplots(figsize=(fig_size, fig_size), dpi=dpi)
        figure.patch.set_alpha(0.0)  # transparent outside the axes
        axes.set_facecolor("white")  # opaque inside the axes
        axes.set_xlim(0, side)
        axes.set_ylim(0, side)
        axes.set_aspect("equal")
        figure.subplots_adjust(left=0, right=1, bottom=0, top=1)
        axes.set_xticks([])
        axes.set_yticks([])
        for edge in axes.spines.values():
            edge.set_visible(False)
        return figure, axes

    def _overlay_grid(axes, side: int, line_width: float) -> None:
        # Cell grid drawn through unlabeled, zero-length ticks at every
        # integer coordinate, plus a thicker rectangle around the patch.
        axes.set_xticks(np.arange(side + 1))
        axes.set_yticks(np.arange(side + 1))
        axes.grid(True, which="both", linewidth=float(line_width), color="black")
        axes.tick_params(axis="both", which="both", length=0)
        axes.set_xticklabels([])
        axes.set_yticklabels([])

        border = mpl_patches.Rectangle(
            (0, 0),
            side,
            side,
            linewidth=float(border_lw_mult) * float(line_width),
            edgecolor="black",
            facecolor="none",
        )
        axes.add_patch(border)

    def _paint_intensity(axes, values: np.ndarray, side: int) -> None:
        image = values.reshape((side, side), order="F")

        if intensity_normalize == "minmax":
            lowest = float(np.min(image))
            highest = float(np.max(image))
            if highest > lowest:
                image = (image - lowest) / (highest - lowest)
            else:
                # Constant image: avoid dividing by zero.
                image = np.zeros_like(image)

        axes.imshow(
            image[::-1, :],  # rows reversed, then drawn with origin="lower"
            cmap=intensity_cmap,
            origin="lower",
            extent=(0, side, 0, side),
            interpolation="nearest",
        )

        if show_intensity_grid:
            _overlay_grid(axes, side, intensity_grid_lw)

    def _paint_flow(axes, values: np.ndarray, side: int) -> None:
        field = values.reshape((side, side, 2), order="F")

        if flipud:
            field = field[::-1, :, :]

        if flow_normalize == "maxmag":
            peak = float(np.max(np.linalg.norm(field, axis=2)))
            if peak > 0:
                field = field / peak

        u = field[:, :, 0]
        v = field[:, :, 1]

        # Start each arrow half a vector behind the cell centre so the
        # arrow's midpoint sits on the centre.
        centre_x, centre_y = np.meshgrid(np.arange(side) + 0.5, np.arange(side) + 0.5)
        tail_x = centre_x - u / 2.0
        tail_y = centre_y - v / 2.0

        axes.quiver(
            tail_x,
            tail_y,
            u,
            v,
            scale=side,
            width=float(quiver_width),
            headwidth=float(headwidth),
            headlength=float(headlength),
        )

        _overlay_grid(axes, side, grid_lw)

    def vis_func(patch_vector: np.ndarray) -> plt.Figure:
        flat = np.asarray(patch_vector, dtype=float).ravel()
        # Validate the length before creating a figure.
        kind, side = _infer_patch_kind_and_n(flat.size)

        figure, axes = _blank_canvas(side)
        painter = _paint_intensity if kind == "intensity" else _paint_flow
        painter(axes, flat, side)
        return figure

    return vis_func