# circle_bundles/trivializations/local_triv.py
from __future__ import annotations

import inspect
import warnings
from dataclasses import dataclass
from typing import Callable, Dict, Optional, Tuple, Any

import numpy as np

from ..utils.status_utils import _status, _status_clear
__all__ = [
"LocalTrivResult",
"DreimacCCConfig",
"compute_circular_coords_pca2",
"compute_circular_coords_dreimac",
"compute_local_triv",
]
@dataclass
class LocalTrivResult:
    """
    Output bundle of :func:`compute_local_triv`.

    Holds one row of circular coordinates per cover set, plus bookkeeping
    that records which sets succeeded and how much effort each one took.

    Conventions
    -----------
    - The cover ``U`` passed to :func:`compute_local_triv` has shape
      ``(n_sets, n_samples)``; ``f`` uses the same layout.
    - ``f[j, s]`` carries information only where ``U[j, s]`` is True.
      Every other entry keeps its initial value of 0.

    Attributes
    ----------
    f : ndarray of shape (n_sets, n_samples)
        Angles in radians for each cover set, wrapped to [0, 2π).
    valid : ndarray of shape (n_sets,)
        Boolean flags; ``valid[j]`` is True when set ``j`` was
        coordinatized without error.
    n_retries : ndarray of shape (n_sets,)
        Retry count per cover set. Only the Dreimac path writes to it;
        for other methods it stays 0.
    n_landmarks : ndarray of shape (n_sets,)
        Landmark count finally used per cover set. Only the Dreimac path
        writes to it; for other methods it stays 0.
    errors : dict[int, str]
        Error message keyed by cover-set index, for every set that was
        skipped (too small) or failed during coordinatization.

    Notes
    -----
    The class is a plain data holder with no behavior of its own; it is
    meant to be easy to inspect after the local trivialization step.
    """

    f: np.ndarray
    valid: np.ndarray
    n_retries: np.ndarray
    n_landmarks: np.ndarray
    errors: Dict[int, str]
# ----------------------------
# Circular-coordinate method configuration
# ----------------------------
@dataclass(frozen=True)
class DreimacCCConfig:
    """
    Immutable settings for Dreimac-based circular coordinates.

    An instance is passed as the ``cc`` argument of
    :func:`compute_local_triv`, which forwards the fields to
    :func:`compute_circular_coords_dreimac` for every cover set.

    Attributes
    ----------
    CircularCoords_cls : Any
        The Dreimac circular coordinates class (e.g. ``dreimac.CircularCoords``).
        It is instantiated as
        ``CircularCoords_cls(X_or_D, n_landmarks, prime=..., distance_matrix=...)``.
    landmarks_per_patch : int, default=200
        Initial number of landmarks per patch. It is capped at the patch
        size and grown automatically whenever an attempt fails.
    prime : int, default=41
        Prime passed to Dreimac as ``prime``.
    update_frac : float, default=0.25
        Growth rate for retries: after a failed attempt the landmark count
        is multiplied by ``1 + update_frac`` (rounded up, capped at the
        patch size).
    standard_range : bool, default=False
        Forwarded unchanged to ``get_coordinates(standard_range=...)``.
        It does NOT change the output range: the angles returned by
        :func:`compute_circular_coords_dreimac` are always wrapped to
        [0, 2π). In DREiMac this flag is documented as affecting how the
        filtration parameter is chosen; check the DREiMac API reference
        for the installed version.

    Notes
    -----
    - If a ``total_metric`` is supplied to :func:`compute_local_triv`
      (or the input data is itself a distance matrix), distance matrices
      are passed to Dreimac with ``distance_matrix=True``.
    - The dataclass is frozen, so fields cannot be reassigned after
      construction.
    """

    CircularCoords_cls: Any
    landmarks_per_patch: int = 200
    prime: int = 41
    update_frac: float = 0.25
    standard_range: bool = False
# ----------------------------
# PCA2 / MDS helpers
# ----------------------------
def _pca2_project(X: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Project a point cloud onto its first two principal directions.

    Parameters
    ----------
    X : (m, D) array
        Point cloud with one sample per row.

    Returns
    -------
    Y : (m, 2) array
        Coordinates of the centered points in the PCA basis.
    V : (D, 2) array
        Principal directions as columns.
    mu : (D,) array
        Column-wise mean that was subtracted before projecting.

    Raises
    ------
    ValueError
        If ``X`` is not 2-D or has no rows.

    Notes
    -----
    When fewer than two principal directions exist (``min(m, D) < 2``),
    the missing columns of ``V`` -- and hence of ``Y`` -- are filled with
    zeros so the documented shapes always hold.
    """
    X = np.asarray(X, dtype=float)
    if X.ndim != 2:
        raise ValueError(f"X must be 2D (m,D). Got shape {X.shape}.")
    n_rows, n_dims = X.shape
    if n_rows == 0:
        raise ValueError("Empty patch: cannot run PCA.")

    mu = X.mean(axis=0)
    Xc = X - mu

    # Only the right singular vectors are needed; rows of Vt are the
    # principal directions in decreasing order of singular value.
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)

    # The reduced SVD yields min(m, D) directions, which can be fewer than 2.
    # Zero-pad so callers can always index Y[:, 0] and Y[:, 1].
    n_avail = min(2, Vt.shape[0])
    V = np.zeros((n_dims, 2), dtype=float)
    V[:, :n_avail] = Vt[:n_avail].T

    Y = Xc @ V
    return Y, V, mu
def _mds2_from_dist(D: np.ndarray, *, eps: float = 1e-12) -> np.ndarray:
    """
    Embed points in the plane by classical (Torgerson) MDS.

    Parameters
    ----------
    D : (m, m) array
        Pairwise distances, assumed symmetric with a zero diagonal.
    eps : float
        If the largest eigenvalue of the centered Gram matrix is below
        this threshold, the embedding is treated as degenerate.

    Returns
    -------
    Y : (m, 2) array
        Planar coordinates; all zeros in the degenerate case.

    Raises
    ------
    ValueError
        If ``D`` is not a square 2-D array or is empty.
    """
    dist = np.asarray(D, dtype=float)
    is_square = dist.ndim == 2 and dist.shape[0] == dist.shape[1]
    if not is_square:
        raise ValueError(f"D must be square (m,m). Got {dist.shape}.")

    n = int(dist.shape[0])
    if n == 0:
        raise ValueError("Empty distance matrix: cannot run MDS.")

    # Double-centre the squared distances to obtain the Gram matrix.
    centering = np.eye(n) - np.full((n, n), 1.0 / float(n))
    gram = -0.5 * (centering @ np.square(dist) @ centering)

    # eigh returns ascending eigenvalues; reorder to descending.
    evals, evecs = np.linalg.eigh(gram)
    order = np.argsort(evals)[::-1]
    evals = np.maximum(evals[order], 0.0)  # clip negative round-off
    evecs = evecs[:, order]

    if evals.size == 0 or evals[0] < eps:
        return np.zeros((n, 2), dtype=float)

    scale = np.sqrt(evals[:2])
    return evecs[:, :2] * scale[None, :]
def compute_circular_coords_pca2(
    X: Optional[np.ndarray] = None,
    *,
    dist_mat: Optional[np.ndarray] = None,
    eps: float = 1e-12,
) -> np.ndarray:
    """
    Assign an angle to every point from a 2-D linear embedding.

    The points are first mapped to the plane -- by 2-D PCA when only ``X``
    is given, or by 2-D classical MDS when ``dist_mat`` is given -- and the
    angle is ``atan2`` of the planar coordinates. The result is then
    rotated so that the point farthest from the planar origin sits at
    angle 0 (the "farthest point" convention).

    Parameters
    ----------
    X : (m, D) array, optional
        Point cloud. Required when ``dist_mat`` is None; ignored otherwise.
    dist_mat : (m, m) array, optional
        Precomputed pairwise distances. Takes precedence over ``X``.
    eps : float
        Degeneracy threshold: if both planar coordinates have standard
        deviation below ``eps``, all angles are returned as 0. Also
        forwarded to the MDS helper.

    Returns
    -------
    angles : (m,) array
        Angles in radians, reduced modulo 2π.

    Raises
    ------
    ValueError
        If neither input is given, or the given input has the wrong shape.
    """
    if dist_mat is not None:
        dists = np.asarray(dist_mat, dtype=float)
        if dists.ndim != 2 or dists.shape[0] != dists.shape[1]:
            raise ValueError(f"dist_mat must be square (m,m). Got {dists.shape}.")
        planar = _mds2_from_dist(dists, eps=eps)
    elif X is None:
        raise ValueError("Provide X or dist_mat.")
    else:
        pts = np.asarray(X, dtype=float)
        if pts.ndim != 2:
            raise ValueError(f"X must be 2D (m,D). Got shape {pts.shape}.")
        planar, _, _ = _pca2_project(pts)

    u, v = planar[:, 0], planar[:, 1]

    # A collapsed embedding carries no angular information.
    if np.std(u) < eps and np.std(v) < eps:
        return np.zeros(len(u), dtype=float)

    two_pi = 2.0 * np.pi
    ang = np.mod(np.arctan2(v, u), two_pi)

    # Fix the rotational freedom: the point of largest radius gets angle 0.
    if ang.size > 0:
        farthest = int(np.argmax(u * u + v * v))
        ang = np.mod(ang - float(ang[farthest]), two_pi)
    return ang
def compute_circular_coords_dreimac(
    X: np.ndarray,
    *,
    n_landmarks_init: int,
    prime: int = 41,
    update_frac: float = 0.25,
    standard_range: bool = False,
    CircularCoords_cls=None,
    dist_mat: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, int, int]:
    """
    Compute circular coordinates with Dreimac, retrying with more landmarks.

    Each attempt builds ``CircularCoords_cls(data, n_landmarks, prime=...,
    distance_matrix=...)`` and calls ``get_coordinates(standard_range=...)``.
    If the attempt raises, or emits a warning containing
    "not covered by a landmark", the landmark count is increased and the
    attempt is repeated, until it succeeds or all points are landmarks.

    Parameters
    ----------
    X : (n_points, D) array
        Point cloud. Only its first dimension is used when ``dist_mat``
        is provided.
    n_landmarks_init : int
        Landmark count for the first attempt, clamped to ``[1, n_points]``.
    prime : int
        Forwarded to ``CircularCoords_cls``.
    update_frac : float
        Per-retry growth rate: the landmark count becomes
        ``ceil((1 + update_frac) * n_landmarks)``. It always grows by at
        least one so the loop terminates even for ``update_frac <= 0``.
    standard_range : bool
        Forwarded to ``get_coordinates``. The returned angles are wrapped
        to [0, 2π) regardless of this flag.
    CircularCoords_cls : type
        The Dreimac class to instantiate (e.g. ``dreimac.CircularCoords``).
    dist_mat : (n_points, n_points) array, optional
        If given, it is passed to Dreimac instead of ``X`` together with
        ``distance_matrix=True``.

    Returns
    -------
    angles : (n_points,) array
        Angles in radians, reduced modulo 2π.
    n_retries : int
        Number of failed attempts before the successful one.
    n_landmarks : int
        Landmark count used by the successful attempt.

    Raises
    ------
    ValueError
        If ``CircularCoords_cls`` is missing, the patch is empty,
        ``dist_mat`` has the wrong shape, or the computation still fails
        with ``n_landmarks == n_points``.
    """
    if CircularCoords_cls is None:
        raise ValueError("CircularCoords_cls must be provided (e.g., dreimac.CircularCoords).")

    X = np.asarray(X)
    n_points = int(X.shape[0])
    if n_points == 0:
        raise ValueError("Empty patch: cannot compute circular coordinates.")

    use_dist = dist_mat is not None
    if use_dist:
        D = np.asarray(dist_mat, dtype=float)
        if D.shape != (n_points, n_points):
            raise ValueError(f"dist_mat has shape {D.shape}, expected ({n_points},{n_points}).")
        X_or_D = D
    else:
        X_or_D = X

    # Start with at least one landmark: a count of 0 could never grow
    # multiplicatively and would stall the retry loop.
    n_landmarks = max(1, min(int(n_landmarks_init), n_points))
    growth = 1.0 + float(update_frac)
    n_retries = 0

    while True:
        try:
            # Record warnings so a coverage warning can be promoted to a retry.
            # Other warnings raised in this scope are captured and not re-emitted.
            with warnings.catch_warnings(record=True) as caught:
                warnings.simplefilter("always")
                coords = CircularCoords_cls(
                    X_or_D,
                    n_landmarks,
                    prime=prime,
                    distance_matrix=use_dist,
                )
                angles = coords.get_coordinates(standard_range=standard_range)

            if any("not covered by a landmark" in str(w.message).lower() for w in caught):
                raise RuntimeError("Dreimac: not covered by a landmark")

            angles = np.asarray(angles, dtype=float).reshape(-1)
            if angles.shape != (n_points,):
                raise ValueError(f"Dreimac returned shape {angles.shape}, expected ({n_points},).")

            return np.mod(angles, 2.0 * np.pi), n_retries, n_landmarks

        except Exception as e:
            # Any failure (including the checks above) triggers a retry
            # with more landmarks, until every point is a landmark.
            n_retries += 1
            if n_landmarks >= n_points:
                raise ValueError(
                    f"Circular coordinates failed even with n_landmarks=n_points={n_points}. "
                    f"Last error: {type(e).__name__}: {e}"
                ) from e
            grown = int(np.ceil(growth * n_landmarks))
            # Guarantee strict progress so the loop cannot spin forever
            # when update_frac is zero or negative.
            n_landmarks = min(max(grown, n_landmarks + 1), n_points)
def _pairwise_dist_or_none(total_metric: Optional[object], Xj: np.ndarray) -> Optional[np.ndarray]:
    """
    Evaluate ``total_metric`` on a patch, if a metric was supplied.

    Parameters
    ----------
    total_metric : object or None
        Object exposing ``pairwise(X)``; ``None`` means "no metric".
    Xj : (m, D) array
        Points of the patch.

    Returns
    -------
    (m, m) float array of pairwise distances, or ``None`` when
    ``total_metric`` is ``None``.

    Raises
    ------
    TypeError
        If ``total_metric`` has no ``pairwise`` attribute.
    ValueError
        If ``pairwise`` does not return an ``(m, m)`` array.
    """
    if total_metric is None:
        return None

    if not hasattr(total_metric, "pairwise"):
        raise TypeError("total_metric must have a .pairwise(X, Y=None) method.")

    dists = np.asarray(total_metric.pairwise(Xj), dtype=float)

    n_pts = int(Xj.shape[0])
    if dists.shape == (n_pts, n_pts):
        return dists
    raise ValueError(
        f"total_metric.pairwise returned shape {dists.shape}, expected ({n_pts},{n_pts})."
    )
def _cc_accepts_dist_mat(fn: Any) -> Optional[bool]:
    """Tell whether ``fn`` can take a ``dist_mat`` keyword; None if it cannot be introspected."""
    try:
        params = inspect.signature(fn).parameters.values()
    except (TypeError, ValueError):
        return None
    named_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    for p in params:
        if p.kind is inspect.Parameter.VAR_KEYWORD:
            return True
        if p.name == "dist_mat" and p.kind in named_kinds:
            return True
    return False


def _call_custom_cc(
    cc: Callable[..., Any],
    Xj: np.ndarray,
    Dj: Optional[np.ndarray],
    accepts_dist_mat: Optional[bool],
) -> Any:
    """Invoke a user callable as ``cc(Xj, dist_mat=Dj)`` when supported, else ``cc(Xj)``."""
    if accepts_dist_mat is None:
        # Signature unknown: probe with the keyword and fall back on TypeError.
        try:
            return cc(Xj, dist_mat=Dj)
        except TypeError:
            return cc(Xj)
    if accepts_dist_mat:
        return cc(Xj, dist_mat=Dj)
    return cc(Xj)


def compute_local_triv(
    data: np.ndarray,
    U: np.ndarray,
    *,
    cc: object = "pca2",
    total_metric: Optional[object] = None,
    min_patch_size: int = 10,
    verbose: bool = True,
    fail_fast: bool = True,
) -> LocalTrivResult:
    """
    Compute local circle coordinates ``f[j, s]`` on each cover set ``U[j]``.

    Parameters
    ----------
    data : (n_samples, D) or (n_samples, n_samples) array
        Either a point cloud or a full distance matrix. Any square 2-D
        array is treated as a distance matrix, so a point cloud with
        ``D == n_samples`` is interpreted as distances.
    U : (n_sets, n_samples) bool array
        Cover membership; ``U[j, s]`` is True when sample ``s`` lies in
        cover set ``j``.
    cc : {"pca2", DreimacCCConfig, callable}
        Circular-coordinate method:

        - ``"pca2"`` (default, case-insensitive):
          :func:`compute_circular_coords_pca2` -- PCA on points, or MDS
          when a per-patch distance matrix is available.
        - ``DreimacCCConfig(...)``:
          :func:`compute_circular_coords_dreimac` on points or distances.
        - callable: called as ``cc(Xj, dist_mat=Dj)`` if its signature
          accepts a ``dist_mat`` keyword (or ``**kwargs``), otherwise as
          ``cc(Xj)``. It must return ``m`` angles in radians, which are
          wrapped to [0, 2π).
    total_metric : object, optional
        Object with a ``pairwise(X)`` method, used in point-cloud mode to
        build the per-patch distance matrix. Ignored when ``data`` is
        already a distance matrix.
    min_patch_size : int
        Cover sets with fewer members are not coordinatized.
    verbose : bool
        Emit a progress status line per cover set.
    fail_fast : bool
        If True, the first too-small or failing cover set raises
        ``ValueError``. If False, the failure is recorded in
        ``result.errors`` and processing continues.

    Returns
    -------
    LocalTrivResult
        ``f`` has shape ``(n_sets, n_samples)`` and is meaningful only
        where ``U`` is True; ``valid`` marks the sets that succeeded.

    Raises
    ------
    ValueError
        On malformed ``U`` / ``data`` / ``cc``, or (with ``fail_fast``)
        on the first cover set that is too small or fails.

    Notes
    -----
    - In distance-matrix mode the per-patch submatrix is passed as
      ``dist_mat`` and ``Xj`` is a dummy ``(m, 1)`` array of zeros.
    - A ``TypeError`` raised inside a callable whose signature accepts
      ``dist_mat`` is reported as a failure of that cover set; the
      callable is not silently re-run without ``dist_mat``.
    """
    data = np.asarray(data)
    U = np.asarray(U, dtype=bool)
    if U.ndim != 2:
        raise ValueError(f"U must be 2D (n_sets, n_samples). Got shape {U.shape}.")
    n_sets, n_samples = U.shape

    if data.ndim != 2:
        raise ValueError(
            f"data must be either a point cloud (n_samples, D) or a distance matrix (n_samples, n_samples). "
            f"Got shape {data.shape}."
        )
    # A square array is always read as a full distance matrix.
    data_is_dist = data.shape[0] == data.shape[1]
    if data.shape[0] != n_samples:
        if data_is_dist:
            raise ValueError(
                f"Distance matrix has n={data.shape[0]} but U has n_samples={n_samples}."
            )
        raise ValueError(f"data has n={data.shape[0]} samples but U has n_samples={n_samples}.")

    cc_is_pca2 = isinstance(cc, str) and cc.lower() == "pca2"
    cc_is_dreimac = isinstance(cc, DreimacCCConfig)
    cc_is_callable = callable(cc)
    if not (cc_is_pca2 or cc_is_dreimac or cc_is_callable):
        raise ValueError(
            "cc must be 'pca2', a DreimacCCConfig(...), or a callable. "
            f"Got {type(cc).__name__}: {cc!r}"
        )
    # The call form for a custom callable does not depend on the patch,
    # so work it out once rather than probing on every iteration.
    accepts_dist_mat = _cc_accepts_dist_mat(cc) if cc_is_callable else None

    f = np.zeros((n_sets, n_samples), dtype=float)
    valid = np.zeros(n_sets, dtype=bool)
    n_retries = np.zeros(n_sets, dtype=int)
    n_landmarks = np.zeros(n_sets, dtype=int)
    errors: Dict[int, str] = {}

    for j in range(n_sets):
        if verbose:
            _status(f"Coordinatizing set {j+1}/{n_sets}...")

        mask = U[j]
        m = int(mask.sum())

        if m < int(min_patch_size):
            msg = f"Patch too small for set j={j}: |U[j]|={m} < min_patch_size={min_patch_size}."
            if fail_fast:
                if verbose:
                    _status_clear()
                raise ValueError(msg)
            errors[j] = msg
            continue

        try:
            # Build the per-patch inputs.
            if data_is_dist:
                idx = np.flatnonzero(mask)
                Dj = np.asarray(data[np.ix_(idx, idx)], dtype=float)
                # Placeholder for methods that insist on a point array.
                Xj = np.zeros((m, 1), dtype=float)
            else:
                Xj = np.asarray(data[mask], dtype=float)
                Dj = _pairwise_dist_or_none(total_metric, Xj)

            if cc_is_callable:
                # 1) A custom callable takes precedence.
                ang = _call_custom_cc(cc, Xj, Dj, accepts_dist_mat)  # type: ignore[arg-type]
                ang = np.asarray(ang, dtype=float).reshape(-1)
                if ang.shape != (m,):
                    raise ValueError(f"cc callable returned shape {ang.shape}, expected ({m},).")
                ang = np.mod(ang, 2.0 * np.pi)
            elif cc_is_pca2:
                # 2) PCA on points, or MDS whenever Dj is available.
                ang = compute_circular_coords_pca2(
                    None if data_is_dist else Xj,
                    dist_mat=Dj,
                )
            else:
                # 3) Dreimac; a non-None Dj switches it to distance-matrix mode.
                ang, retries, n_lmks = compute_circular_coords_dreimac(
                    Xj,
                    dist_mat=Dj,
                    n_landmarks_init=cc.landmarks_per_patch,
                    prime=cc.prime,
                    update_frac=cc.update_frac,
                    standard_range=cc.standard_range,
                    CircularCoords_cls=cc.CircularCoords_cls,
                )
                n_retries[j] = int(retries)
                n_landmarks[j] = int(n_lmks)

            f[j, mask] = ang
            valid[j] = True

        except Exception as e:
            msg = f"Failed on set j={j} (|U[j]|={m}): {type(e).__name__}: {e}"
            if fail_fast:
                if verbose:
                    _status_clear()
                raise ValueError(msg) from e
            errors[j] = msg
            valid[j] = False

    if verbose:
        _status_clear()

    return LocalTrivResult(
        f=f,
        valid=valid,
        n_retries=n_retries,
        n_landmarks=n_landmarks,
        errors=errors,
    )