# Source code for circle_bundles.reduction.frame_reduction

# circle_bundles/reduction/frame_reduction.py
"""
O(2)-equivariant frame reduction for Stiefel frames Y ∈ V(2, D).

This module provides optional dimensionality reduction for collections of 2-frames
(`(D,2)` matrices with orthonormal columns), used in the bundle pipeline when the
ambient feature dimension D is large.

Design goals
------------
- Keep bundle/gluing logic out of this file (this is a preprocessing utility).
- Preserve the right O(2)-action: Y ↦ YQ for Q ∈ O(2) (i.e., equivariant reduction).
- Provide two reducers:
    * subspace_pca : always available, fast, deterministic baseline
    * psc          : optional dependency (HarlinLee/PSC), higher-quality but heavier

Conventions
-----------
- Frames are represented as real arrays of shape (D,2).
- A "projection" Π(Y) refers to projecting Y into a learned d-dimensional subspace:
      Π(Y) = B B^T Y   or   Π(Y) = α α^T Y
  with B/α having orthonormal columns in R^D.
- After projecting, we re-orthonormalize columns via the polar factor to return to V(2,d).

Where it is used
----------------
This is intended to be called either:
- "pre_classifying": reduce raw Stiefel frames before building the classifying map
  (paper-faithful), or
- "post_stiefel": reduce after a Stiefel projection step (legacy / experimental).

Optional dependency
-------------------
The PSC reducer imports PSC lazily inside PSC functions so that PSC is not required
for the rest of the package.
"""


from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple, Literal

import numpy as np

from ..utils.status_utils import _status, _status_clear  # shared status helpers

from ..trivializations.gauge_canon import GaugeCanonConfig, compute_samplewise_gauge_from_o2_cocycle, apply_gauge_to_frames, GaugeCanonReport  # noqa


# Allowed reducer names: "none" disables reduction, "subspace_pca" is the
# always-available baseline, "psc" uses the optional PSC package.
ReduceMethod = Literal["none", "subspace_pca", "psc"]


# Pipeline stage at which the reduction is applied (see the module docstring,
# section "Where it is used").
ReduceStage = Literal["pre_classifying", "post_stiefel"]


# Names exported by `from <module> import *`; only the config and report
# dataclasses are listed here.
__all__ = [
    "FrameReducerConfig",
    "FrameReductionReport",
]


def canonicalize_frames_before_reduction(
    *,
    Phi_true: np.ndarray,
    Omega_true: dict,
    U: np.ndarray,
    pou: np.ndarray,
    edges: Optional[Sequence[Tuple[int, int]]] = None,
    cfg: Optional[GaugeCanonConfig] = None,
) -> Tuple[np.ndarray, GaugeCanonReport, np.ndarray, dict]:
    """
    Gauge-canonicalize frames: derive a gauge from ``Omega_true`` and apply it
    to ``Phi_true``.

    Parameters
    ----------
    Phi_true : np.ndarray
        Frames to be re-gauged; forwarded to ``apply_gauge_to_frames``.
    Omega_true : dict
        O(2) cocycle data; forwarded to
        ``compute_samplewise_gauge_from_o2_cocycle``.
    U, pou, edges
        Forwarded unchanged to the gauge computation (``U`` is also forwarded
        to the gauge application).
    cfg : Optional[GaugeCanonConfig]
        Gauge configuration. When omitted, ``GaugeCanonConfig(enabled=True)``
        is used.

    Returns
    -------
    Phi_star : np.ndarray
        Re-gauged frames, (n_sets,n_samples,D,2).
    report : GaugeCanonReport
        Report produced by the gauge computation.
    gauge : np.ndarray
        The gauge itself, (n_sets,n_samples,2,2).
    Omega_star : dict
        Re-gauged cocycle, edges -> (n_samples,2,2).
    """
    # Fall back to an enabled default configuration when none is supplied.
    active_cfg = GaugeCanonConfig(enabled=True) if cfg is None else cfg

    gauge_outputs = compute_samplewise_gauge_from_o2_cocycle(
        Omega_true,
        U=U,
        pou=pou,
        edges=edges,
        cfg=active_cfg,
    )
    gauge, Omega_star, report = gauge_outputs

    regauged_frames = apply_gauge_to_frames(Phi_true, gauge, U=U)
    return regauged_frames, report, gauge, Omega_star

# ============================================================
# Reports / config
# ============================================================

@dataclass
class FrameReducerConfig:
    """
    Configuration for O(2)-equivariant dimensionality reduction of Stiefel frames.

    Attributes
    ----------
    method : ReduceMethod
        - "none"        : no reduction
        - "subspace_pca": always-available, O(2)-equivariant baseline
        - "psc"         : PSC-based reducer (optional dependency)
    stage : ReduceStage
        - "pre_classifying": reduce the raw Stiefel point cloud before building
          the classifying map
        - "post_stiefel"   : reduce after a Stiefel projection step
          (legacy / experimental)
    d : int
        Target ambient dimension after reduction (must satisfy 2 <= d <= D at
        runtime). A value of 0 is treated as "unset" by callers (i.e., no
        reduction unless set).
    max_frames : Optional[int]
        Optional subsampling cap for fitting the reducer (speed control).
    rng_seed : int
        RNG seed used for subsampling.
    psc_verbosity : int
        Verbosity forwarded to PSC's optimization routine (if method="psc").
    """

    method: ReduceMethod = "none"
    stage: ReduceStage = "pre_classifying"
    d: int = 0
    max_frames: Optional[int] = None
    rng_seed: int = 0
    psc_verbosity: int = 0
@dataclass
class FrameReductionReport:
    """
    Summary statistics for a fitted frame reduction map Π : V(2,D) -> V(2,d).

    Attributes
    ----------
    method : ReduceMethod
        Reduction method used.
    D_in : int
        Original ambient dimension D.
    d_out : int
        Reduced ambient dimension d.
    sup_proj_err : float
        Supremum (over processed frames) of the ambient projection error:
        sup ||Y - Π(Y)||_F
    mean_proj_err : float
        Mean (over processed frames) of the ambient projection error:
        mean ||Y - Π(Y)||_F
    """

    method: ReduceMethod
    D_in: int
    d_out: int
    sup_proj_err: float   # sup ||Y - Π(Y)||_F (ε_red)
    mean_proj_err: float  # mean ||Y - Π(Y)||_F (\bar{ε}_red)

    def to_text(self, *, decimals: int = 3) -> str:
        """
        Render a human-readable multi-line summary (for notebooks/logging).

        Parameters
        ----------
        decimals : int
            Number of decimal places used for numeric fields.

        Returns
        -------
        str
            A formatted text block describing the fitted reduction and its errors.
        """
        r = int(decimals)
        return (
            "\n"
            + "=" * 12
            + " Frame Reduction "
            + "=" * 12
            + "\n\n"
            + f"Method: {self.method}\n"
            + f"Ambient dim: D={self.D_in} → d={self.d_out}\n"
            + f"Sup projection error (||Y - Π(Y)||_F): {self.sup_proj_err:.{r}f}\n"
            + f"Mean projection error (||Y - Π(Y)||_F): {self.mean_proj_err:.{r}f}\n"
            + "\n"
            + "=" * 40
            + "\n"
        )


# ============================================================
# Progress helpers
# ============================================================

def _maybe_status(progress: bool, msg: str) -> None:
    """Emit ``msg`` through the shared status helper, but only if ``progress``."""
    if progress:
        _status(msg)


def _maybe_clear(progress: bool) -> None:
    """Clear the shared status line, but only if ``progress``."""
    if progress:
        _status_clear()


def _auto_every(n: int, *, target_updates: int = 25, min_every: int = 200) -> int:
    """
    Choose an update frequency that doesn't spam clear_output in notebooks.

    Returns ``max(min_every, ceil(n / target_updates))``, floored at 1; for
    ``n <= 0`` the result is 1.
    """
    n = int(n)
    if n <= 0:
        return 1
    every = max(int(min_every), int(np.ceil(n / max(1, int(target_updates)))))
    return max(1, every)


# ============================================================
# Helpers
# ============================================================

def _polar_orthonormalize_2cols(Yd: np.ndarray, *, eps: float = 1e-12) -> np.ndarray:
    """
    Nearest Stiefel frame to Yd (d,2) via polar factor.

    Computes ``Yd (Yd^T Yd)^{-1/2}`` from the eigendecomposition of the 2x2
    Gram matrix. Eigenvalues are clipped below at ``eps`` so that a
    rank-deficient input does not cause a division by zero.

    Raises
    ------
    ValueError
        If ``Yd`` is not a 2-D array with exactly two columns.
    """
    Yd = np.asarray(Yd, dtype=float)
    if Yd.ndim != 2 or Yd.shape[1] != 2:
        raise ValueError("Yd must have shape (d,2).")
    AtA = Yd.T @ Yd
    eigvals, eigvecs = np.linalg.eigh(AtA)
    eigvals = np.clip(eigvals, eps, None)
    inv_sqrt = eigvecs @ np.diag(1.0 / np.sqrt(eigvals)) @ eigvecs.T
    return Yd @ inv_sqrt


def _collect_frames_as_list(
    Phi_true: np.ndarray,
    U: Optional[np.ndarray] = None,
    *,
    max_frames: Optional[int] = None,
    rng_seed: int = 0,
) -> List[np.ndarray]:
    """
    Collect frames as a Python list of (D,2) arrays.

    Accepts either:
      - Phi_true with shape (n_sets,n_samples,D,2) plus U (n_sets,n_samples), OR
      - Phi_true with shape (m,D,2) (packed dataset), in which case U is ignored.

    If ``max_frames`` is given and fewer than the number of collected frames,
    a subsample of that size is drawn without replacement using
    ``np.random.default_rng(rng_seed)``.

    Raises
    ------
    ValueError
        On an unsupported ``Phi_true`` shape, a missing ``U`` in the grid case,
        or a ``U`` whose shape does not match ``Phi_true``.
    """
    Phi_true = np.asarray(Phi_true, dtype=float)

    # Case A: packed dataset (m,D,2)
    if Phi_true.ndim == 3:
        if Phi_true.shape[2] != 2:
            raise ValueError("Packed frames must have shape (m,D,2).")
        frames = list(Phi_true)

    # Case B: grid of frames (n_sets,n_samples,D,2)
    elif Phi_true.ndim == 4:
        if U is None:
            raise ValueError("U must be provided when Phi_true has shape (n_sets,n_samples,D,2).")
        U = np.asarray(U, dtype=bool)
        n_sets, n_samples, _, two = Phi_true.shape
        if two != 2:
            raise ValueError("Expected frames in V(2,D): last dim must be 2.")
        if U.shape != (n_sets, n_samples):
            raise ValueError("U shape mismatch with Phi_true.")
        # Boolean-mask indexing selects the True cells in row-major (j, s) order.
        frames = list(Phi_true[U])

    else:
        raise ValueError(f"Phi_true must have ndim 3 or 4. Got shape {Phi_true.shape}.")

    if max_frames is not None and len(frames) > int(max_frames):
        rng = np.random.default_rng(int(rng_seed))
        idx = rng.choice(len(frames), size=int(max_frames), replace=False)
        frames = [frames[int(i)] for i in idx]

    return frames


def _mean_sq_projection_error(frames: Sequence[np.ndarray], basis: np.ndarray) -> float:
    """Mean of ||Y - basis basis^T Y||_F^2 over ``frames`` (0.0 if empty)."""
    sq_errs = [
        float(np.linalg.norm(Y - basis @ (basis.T @ Y), ord="fro") ** 2)
        for Y in frames
    ]
    return float(np.mean(sq_errs)) if sq_errs else 0.0


def _transform_and_score_frames(
    Phi_true: np.ndarray,
    U: np.ndarray,
    basis: np.ndarray,
    *,
    tag: str,
    progress: bool,
) -> Tuple[np.ndarray, float, float]:
    """
    Map every active frame through ``basis`` and measure the projection error.

    For each (j, s) with ``U[j, s]`` True, stores the polar-orthonormalized
    ``basis^T Y`` in the output and records ``||Y - basis basis^T Y||_F``.
    Inactive cells of the output stay zero.

    Returns ``(Phi_red, sup_err, mean_err)``; both errors are 0.0 when no cell
    of ``U`` is True.
    """
    n_sets, n_samples = Phi_true.shape[:2]
    d = int(basis.shape[1])
    Phi_red = np.zeros((n_sets, n_samples, d, 2), dtype=float)
    recon_errs: List[float] = []

    total = int(np.count_nonzero(U))
    every = _auto_every(total, target_updates=25, min_every=200)

    _maybe_status(progress, f"[frame reduction | {tag}] transforming frames: 0/{total}")

    # np.argwhere visits the True cells in row-major order, i.e. the same
    # order as nested loops over j then s.
    for done, (j, s) in enumerate(np.argwhere(U), start=1):
        Y = Phi_true[j, s]                    # (D,2)
        coords = basis.T @ Y                  # (d,2) coordinates in the subspace
        Yproj = basis @ coords                # Π(Y) in ambient space
        recon_errs.append(float(np.linalg.norm(Y - Yproj, ord="fro")))
        Phi_red[j, s] = _polar_orthonormalize_2cols(coords)

        if done % every == 0 or done == total:
            _maybe_status(progress, f"[frame reduction | {tag}] transforming frames: {done}/{total}")

    errs = np.asarray(recon_errs, dtype=float)
    sup_err = float(np.max(errs)) if errs.size else 0.0
    mean_err = float(np.mean(errs)) if errs.size else 0.0
    return Phi_red, sup_err, mean_err


def _import_psc(progress: bool):
    """
    Lazily import the optional PSC package.

    Returns the pair ``(PCA, manopt_alpha)`` from ``PSC.projections``. On any
    import failure the status line is cleared (if ``progress``) and an
    ImportError is raised, chained to the original exception.
    """
    try:
        from PSC.projections import PCA as PSC_PCA, manopt_alpha  # type: ignore
    except Exception as e:
        _maybe_clear(progress)
        raise ImportError(
            "PSC package not importable. Install it so that "
            "`from PSC.projections import PCA, manopt_alpha` works."
        ) from e
    return PSC_PCA, manopt_alpha


# ============================================================
# Subspace PCA (equivariant baseline)
# ============================================================

def subspace_pca_fit(frames: Sequence[np.ndarray], *, d: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Fit O(2)-equivariant subspace PCA on frames in V(2,D).

    Mean projector M = avg(Y Y^T) is invariant under Y -> YQ for Q∈O(2).
    Take top-d eigenspace of M.

    Returns:
      B: (D,d) orthonormal basis
      evals: (d,) top eigenvalues

    Raises
    ------
    ValueError
        If ``frames`` is empty, a frame does not have shape (D,2), or ``d``
        violates 2 <= d <= D.
    """
    # len() rather than truthiness: `frames` may be a NumPy array.
    if len(frames) == 0:
        raise ValueError("Need at least one frame to fit a reducer.")

    Y0 = np.asarray(frames[0], dtype=float)
    if Y0.ndim != 2 or Y0.shape[1] != 2:
        raise ValueError("Each frame must have shape (D,2).")
    D = int(Y0.shape[0])

    if not (2 <= d <= D):
        raise ValueError(f"Need 2 <= d <= D. Got d={d}, D={D}.")

    M = np.zeros((D, D), dtype=float)
    for Y in frames:
        Y = np.asarray(Y, dtype=float)
        if Y.shape != (D, 2):
            raise ValueError("All frames must share the same ambient dimension D and have shape (D,2).")
        M += Y @ Y.T
    M /= float(len(frames))

    # Symmetrize before eigh to remove floating-point asymmetry.
    vals, vecs = np.linalg.eigh(0.5 * (M + M.T))
    order = np.argsort(vals)[::-1]  # eigh returns ascending order; we want descending
    vals = vals[order]
    vecs = vecs[:, order]

    B = vecs[:, :d]
    evals = vals[:d]
    return B, evals


def subspace_pca_transform_frame(B: np.ndarray, Y: np.ndarray) -> np.ndarray:
    """
    Y (D,2) -> (d,2): B^T Y, then polar re-orthonormalize.

    Raises
    ------
    ValueError
        If ``Y`` is not (D,2), ``B`` is not 2-D, or their ambient dimensions
        disagree.
    """
    B = np.asarray(B, dtype=float)
    Y = np.asarray(Y, dtype=float)
    if Y.ndim != 2 or Y.shape[1] != 2:
        raise ValueError("Y must have shape (D,2).")
    if B.ndim != 2:
        raise ValueError("B must be (D,d).")
    if B.shape[0] != Y.shape[0]:
        raise ValueError("B and Y have incompatible ambient dimensions.")
    return _polar_orthonormalize_2cols(B.T @ Y)


def reduce_frames_subspace_pca(
    *,
    Phi_true: np.ndarray,
    U: np.ndarray,
    d: int,
    max_frames: Optional[int] = None,
    rng_seed: int = 0,
    progress: bool = False,
) -> Tuple[np.ndarray, FrameReductionReport, np.ndarray]:
    """
    Reduce all frames Phi_true[j,s] (D,2) -> Phi_red[j,s] (d,2) using subspace PCA.

    The basis is fitted on (at most ``max_frames``) active frames and then
    applied to every active frame; inactive cells of ``Phi_red`` are zero.

    Returns:
      Phi_red: (n_sets, n_samples, d, 2)
      report
      B: (D,d)

    Raises
    ------
    ValueError
        On shape mismatches, ``d`` outside 2 <= d <= D, or when ``U`` selects
        no frames.
    """
    Phi_true = np.asarray(Phi_true, dtype=float)
    U = np.asarray(U, dtype=bool)

    if Phi_true.ndim != 4:
        raise ValueError(f"Expected Phi_true with shape (n_sets,n_samples,D,2). Got {Phi_true.shape}.")
    n_sets, n_samples, D, two = Phi_true.shape
    if two != 2:
        raise ValueError("Phi_true last dim must be 2.")
    if U.shape != (n_sets, n_samples):
        raise ValueError("U shape mismatch with Phi_true.")
    if not (2 <= int(d) <= int(D)):
        raise ValueError(f"Need 2 <= d <= D. Got d={d}, D={D}.")

    _maybe_status(progress, f"[frame reduction | subspace_pca] collecting frames (max_frames={max_frames})...")
    frames_fit = _collect_frames_as_list(Phi_true, U, max_frames=max_frames, rng_seed=rng_seed)
    if len(frames_fit) == 0:
        _maybe_clear(progress)
        raise ValueError("No frames were available to reduce (U has no True entries?).")

    _maybe_status(progress, f"[frame reduction | subspace_pca] fitting basis B (D={D} → d={d}) on {len(frames_fit)} frames...")
    B, _ = subspace_pca_fit(frames_fit, d=int(d))

    Phi_red, sup_err, mean_err = _transform_and_score_frames(
        Phi_true, U, B, tag="subspace_pca", progress=progress
    )

    rep = FrameReductionReport(
        method="subspace_pca",
        D_in=int(D),
        d_out=int(d),
        sup_proj_err=sup_err,
        mean_proj_err=mean_err,
    )

    _maybe_clear(progress)
    return Phi_red, rep, B


def reduction_curve_subspace_pca(
    *,
    Phi_true: np.ndarray,
    U: np.ndarray,
    dims: Sequence[int],
    max_frames: Optional[int] = None,
    rng_seed: int = 0,
    progress: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute a subspace PCA reconstruction-error curve across multiple target dimensions.

    For each d in `dims`, this uses the subspace basis B_d (fitted on the same
    sampled frames) and reports the mean squared reconstruction error:
        mean ||Y - B_d B_d^T Y||_F^2.

    Returns
    -------
    dims_arr : np.ndarray
        Target dimensions as an int array.
    mean_sq_err : np.ndarray
        Mean squared reconstruction errors corresponding to dims_arr.

    Raises
    ------
    ValueError
        If no frames are available or some requested dimension violates
        2 <= d <= D.
    """
    Phi_true = np.asarray(Phi_true, dtype=float)
    U = np.asarray(U, dtype=bool)

    _maybe_status(progress, f"[frame reduction | subspace_pca] collecting frames for curve (max_frames={max_frames})...")
    frames_fit = _collect_frames_as_list(Phi_true, U, max_frames=max_frames, rng_seed=rng_seed)
    if len(frames_fit) == 0:
        _maybe_clear(progress)
        raise ValueError("No frames available for curve computation.")

    D = int(frames_fit[0].shape[0])
    dims_arr = np.asarray(list(dims), dtype=int)
    if np.any(dims_arr < 2) or np.any(dims_arr > D):
        _maybe_clear(progress)
        raise ValueError(f"All dims must satisfy 2 <= d <= D={D}.")

    mean_sq_err = np.zeros_like(dims_arr, dtype=float)
    n_dims = len(dims_arr)
    every = _auto_every(n_dims, target_updates=25, min_every=1)

    _maybe_status(progress, f"[frame reduction | subspace_pca] curve: 0/{n_dims}")

    # B_d is always the leading d columns of the same sorted eigenbasis, so the
    # fit is done once at full dimension and sliced per target dimension.
    # Skipped for empty `dims`, where validation above has not guaranteed D >= 2.
    B_full = subspace_pca_fit(frames_fit, d=D)[0] if n_dims else None

    for t, dd in enumerate(dims_arr):
        mean_sq_err[t] = _mean_sq_projection_error(frames_fit, B_full[:, : int(dd)])

        if (t + 1) % every == 0 or (t + 1) == n_dims:
            _maybe_status(progress, f"[frame reduction | subspace_pca] curve: {t+1}/{n_dims}")

    _maybe_clear(progress)
    return dims_arr, mean_sq_err


# ============================================================
# PSC (optional dependency)
# ============================================================

def reduce_frames_psc(
    *,
    Phi_true: np.ndarray,
    U: np.ndarray,
    d: int,
    max_frames: Optional[int] = None,
    rng_seed: int = 0,
    verbosity: int = 0,
    progress: bool = False,
) -> Tuple[np.ndarray, FrameReductionReport, np.ndarray]:
    """
    Reduce frames using the PSC package (HarlinLee/PSC).

    PSC's PCA provides the initial basis, which is then refined with
    ``manopt_alpha``; the resulting ``alpha`` is applied to every active frame.

    Returns:
      Phi_red: (n_sets,n_samples,d,2)
      report
      alpha: (D,d)

    Raises
    ------
    ValueError
        On shape mismatches, ``d`` outside 2 <= d <= D, no active frames, or
        when PSC returns a basis of unexpected shape.
    ImportError
        If PSC is not importable.
    """
    Phi_true = np.asarray(Phi_true, dtype=float)
    U = np.asarray(U, dtype=bool)

    if Phi_true.ndim != 4:
        raise ValueError(f"Expected Phi_true with shape (n_sets,n_samples,D,2). Got {Phi_true.shape}.")
    n_sets, n_samples, D, two = Phi_true.shape
    if two != 2:
        raise ValueError("PSC reducer expects k=2 frames (last dim must be 2).")
    if U.shape != (n_sets, n_samples):
        raise ValueError("U shape mismatch with Phi_true.")
    if not (2 <= int(d) <= int(D)):
        raise ValueError(f"Need 2 <= d <= D. Got d={d}, D={D}.")

    _maybe_status(progress, f"[frame reduction | psc] collecting frames (max_frames={max_frames})...")
    frames_fit = _collect_frames_as_list(Phi_true, U, max_frames=max_frames, rng_seed=rng_seed)
    if len(frames_fit) == 0:
        _maybe_clear(progress)
        raise ValueError("No frames available to fit PSC reducer (U has no True entries?).")

    ys = np.stack(frames_fit, axis=0)  # (s, D, 2)

    PSC_PCA, manopt_alpha = _import_psc(progress)

    _maybe_status(progress, f"[frame reduction | psc] PSC_PCA init (D={D} → d={d}) on {len(frames_fit)} frames...")
    alpha_init = PSC_PCA(ys, int(d))  # (D,d)

    _maybe_status(progress, f"[frame reduction | psc] manopt_alpha (verbosity={int(verbosity)})...")
    alpha = manopt_alpha(ys, alpha_init, verbosity=int(verbosity))
    alpha = np.asarray(alpha, dtype=float)

    if alpha.shape != (D, int(d)):
        _maybe_clear(progress)
        raise ValueError(f"PSC returned alpha with shape {alpha.shape}, expected {(D,int(d))}.")

    Phi_red, sup_err, mean_err = _transform_and_score_frames(
        Phi_true, U, alpha, tag="psc", progress=progress
    )

    rep = FrameReductionReport(
        method="psc",
        D_in=int(D),
        d_out=int(d),
        sup_proj_err=sup_err,
        mean_proj_err=mean_err,
    )

    _maybe_clear(progress)
    return Phi_red, rep, alpha


def reduction_curve_psc(
    *,
    Phi_true: np.ndarray,
    U: np.ndarray,
    dims: Sequence[int],
    max_frames: Optional[int] = None,
    rng_seed: int = 0,
    psc_verbosity: int = 0,
    use_manopt: bool = True,
    plot: bool = True,
    progress: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute a PSC reconstruction-error curve across multiple target dimensions.

    For each d in `dims`, this fits α_d via PSC (optionally refined with manopt)
    and reports:
        mean ||Y - α_d α_d^T Y||_F^2.

    Parameters
    ----------
    use_manopt : bool
        If True, refine the PCA initializer with PSC's manifold optimization routine.
    plot : bool
        If True, display a simple matplotlib curve in the current environment.

    Returns
    -------
    dims_arr : np.ndarray
        Target dimensions as an int array.
    mean_sq_err : np.ndarray
        Mean squared reconstruction errors corresponding to dims_arr.

    Raises
    ------
    ImportError
        If PSC is not importable.
    ValueError
        If no frames are available or some requested dimension violates
        2 <= d <= D.
    """
    Phi_true = np.asarray(Phi_true, dtype=float)
    U = np.asarray(U, dtype=bool)

    _maybe_status(progress, f"[frame reduction | psc] collecting frames for curve (max_frames={max_frames})...")
    frames_fit = _collect_frames_as_list(Phi_true, U, max_frames=max_frames, rng_seed=rng_seed)
    if len(frames_fit) == 0:
        _maybe_clear(progress)
        raise ValueError("No frames available for PSC curve computation.")

    D = int(frames_fit[0].shape[0])
    dims_arr = np.asarray(list(dims), dtype=int)
    if np.any(dims_arr < 2) or np.any(dims_arr > D):
        _maybe_clear(progress)
        raise ValueError(f"All dims must satisfy 2 <= d <= D={D}.")

    ys = np.stack(frames_fit, axis=0)  # (s, D, 2)

    PSC_PCA, manopt_alpha = _import_psc(progress)

    mean_sq_err = np.zeros_like(dims_arr, dtype=float)
    n_dims = len(dims_arr)
    every = _auto_every(n_dims, target_updates=25, min_every=1)

    _maybe_status(progress, f"[frame reduction | psc] curve: 0/{n_dims}")

    for t, dd in enumerate(dims_arr):
        alpha = PSC_PCA(ys, int(dd))
        if use_manopt:
            alpha = manopt_alpha(ys, alpha, verbosity=int(psc_verbosity))
        alpha = np.asarray(alpha, dtype=float)

        mean_sq_err[t] = _mean_sq_projection_error(frames_fit, alpha)

        if (t + 1) % every == 0 or (t + 1) == n_dims:
            _maybe_status(progress, f"[frame reduction | psc] curve: {t+1}/{n_dims}")

    # Clear the status line before plotting so it is not left behind if the
    # matplotlib import or the plotting call raises.
    # NOTE(review): if _status_clear clears notebook output, this ordering also
    # keeps the figure from being wiped — confirm against status_utils.
    _maybe_clear(progress)

    if plot:
        import matplotlib.pyplot as plt

        plt.figure()
        plt.plot(dims_arr, mean_sq_err, marker="o")
        plt.xlabel("Target dimension d")
        plt.ylabel("Mean squared projection error")
        plt.title("PSC projection error curve")
        plt.show()

    return dims_arr, mean_sq_err