# Source code for circle_bundles.metrics

# circle_bundles/metrics.py
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Callable, Optional, Protocol, Union

import numpy as np

try:
    from scipy.spatial.distance import cdist as _cdist  # optional
except Exception:  # pragma: no cover
    _cdist = None


# ============================================================
# Vectorized metric objects
# ============================================================

class Metric(Protocol):
    """
    Structural interface for batch ("vectorized") distance functions.

    An object satisfies ``Metric`` when it exposes:

    - ``name``: a human-readable label, and
    - :meth:`pairwise`: a method producing the full matrix of distances
      between two collections of points.

    The interface is kept deliberately small; it is intended to be accepted
    by covers, bundle construction, and visualization utilities.

    Notes
    -----
    - The metric axioms are assumed by the library but never checked.
    - Inputs and outputs are NumPy arrays.

    See Also
    --------
    as_metric :
        Utility for converting scalar distance functions into vectorized metrics.
    """

    name: str

    def pairwise(self, X: np.ndarray, Y: Optional[np.ndarray] = None) -> np.ndarray:
        """
        Evaluate the distance between every point of ``X`` and every point of ``Y``.

        Parameters
        ----------
        X :
            ``n`` points, given as an array of shape ``(n, d)`` or ``(n,)``.
        Y :
            ``m`` points, given as an array of shape ``(m, d)`` or ``(m,)``.
            When left as ``None``, the rows of ``X`` are compared with
            each other.

        Returns
        -------
        D :
            Array of shape ``(n, m)`` with ``D[i, j] = d(X[i], Y[j])``.
        """
        ...


@dataclass(frozen=True)
class EuclideanMetric:
    r"""
    Standard Euclidean metric on :math:`\mathbb{R}^d`.

    This metric computes ordinary Euclidean distances between vectors using
    the ℓ² norm. It is the metric returned by ``as_metric(None)``, i.e. the
    fallback when no other metric is specified.

    Attributes
    ----------
    name :
        Display name for the metric (default ``"euclidean"``).
    """

    name: str = "euclidean"

    @staticmethod
    def _as_point_matrix(A) -> np.ndarray:
        """Return ``A`` as an array of row-points, safe for subtraction."""
        A = np.asarray(A)
        # Bool and (unsigned) integer dtypes either fail or wrap around on
        # subtraction (e.g. uint8: 0 - 3 == 253), so promote them to float.
        # Float and complex inputs are left untouched.
        if A.dtype.kind in "biu":
            A = A.astype(float)
        # A 1-D array is read as n points on the real line.
        if A.ndim == 1:
            A = A.reshape(-1, 1)
        return A

    def pairwise(self, X: np.ndarray, Y: Optional[np.ndarray] = None) -> np.ndarray:
        """
        Compute Euclidean pairwise distances.

        Parameters
        ----------
        X :
            Array of shape ``(n, d)`` or ``(n,)`` representing ``n`` points.
        Y :
            Optional array of shape ``(m, d)`` or ``(m,)``.
            If omitted, uses ``X``.

        Returns
        -------
        D :
            Euclidean distance matrix of shape ``(n, m)``.

        Raises
        ------
        ValueError
            If ``X`` and ``Y`` are both 2-D but have a different number of
            columns. Without this check NumPy broadcasting would silently
            produce meaningless distances when one side has a single column.
        """
        X = self._as_point_matrix(X)
        Y = X if Y is None else self._as_point_matrix(Y)

        if X.ndim == 2 and Y.ndim == 2 and X.shape[1] != Y.shape[1]:
            raise ValueError(
                f"X and Y must have the same feature dimension. "
                f"Got X={X.shape}, Y={Y.shape}."
            )

        # Broadcast to (n, m, d) and take the norm over the feature axis.
        return np.linalg.norm(X[:, None, :] - Y[None, :, :], axis=-1)
@dataclass(frozen=True)
class S1AngleMetric:
    r"""
    Geodesic distance on the circle :math:`\mathbb{S}^1` using angles.

    Points are represented by angles (in radians), interpreted modulo
    :math:`2\pi`. The distance between two angles is the shorter arc length
    between them on the circle:

    .. math::

        d(\theta_1, \theta_2)
        = \min\left(\delta,\ 2\pi - \delta\right),
        \quad \delta = |\theta_2 - \theta_1| \bmod 2\pi.

    This metric is appropriate when the base space is a circle and data are
    naturally parameterized by angles.

    Attributes
    ----------
    name :
        Metric identifier (default ``"S1_angle"``).
    base_name :
        Short name of the base space for plots/UI.
    base_name_latex :
        LaTeX symbol used in summaries and tables.
    """

    name: str = "S1_angle"
    base_name: str = "S^1"
    base_name_latex: str = r"\mathbb{S}^1"

    def pairwise(self, X: np.ndarray, Y: Optional[np.ndarray] = None) -> np.ndarray:
        """
        Compute circular geodesic distances between angle arrays.

        Parameters
        ----------
        X :
            Array of angles (any real values); it is flattened to shape ``(n,)``.
        Y :
            Optional array of angles, flattened to shape ``(m,)``.
            If omitted, uses ``X``.

        Returns
        -------
        D :
            Distance matrix of shape ``(n, m)`` with entries in ``[0, π]``.
        """
        two_pi = 2.0 * np.pi
        t1 = np.asarray(X, dtype=float).reshape(-1)[:, None]
        t2 = t1.T if Y is None else np.asarray(Y, dtype=float).reshape(-1)[None, :]

        # Reduce the raw difference mod 2π first: without this, inputs that
        # differ by more than one full turn would make ``2π - d`` negative
        # and the result would not be a distance.
        d = np.mod(np.abs(t2 - t1), two_pi)
        return np.minimum(d, two_pi - d)
@dataclass(frozen=True)
class RP1AngleMetric:
    r"""
    Geodesic distance on the real projective line :math:`\mathbb{RP}^1`
    in angular coordinates.

    :math:`\mathbb{RP}^1` is the circle with antipodal points identified, so
    an angle only matters modulo :math:`\pi`. Each input angle is first
    reduced modulo :math:`\pi`; writing :math:`\delta` for the absolute
    difference of the reduced angles, the distance is

    .. math::

        d(\theta_1, \theta_2) = \min\left(\delta,\ \pi - \delta\right).

    A typical use is a base variable that encodes *unoriented directions*.

    Attributes
    ----------
    name :
        Metric identifier (default ``"RP1_angle"``).
    base_name :
        Short name of the base space for plots/UI.
    base_name_latex :
        LaTeX symbol used in summaries and tables.
    """

    name: str = "RP1_angle"
    base_name: str = "RP^1"
    base_name_latex: str = r"\mathbb{RP}^1"

    def pairwise(self, X: np.ndarray, Y: Optional[np.ndarray] = None) -> np.ndarray:
        """
        Projective geodesic distances between two arrays of angles.

        Parameters
        ----------
        X :
            Angles in radians (any real values); flattened to shape ``(n,)``.
        Y :
            Optional angles, flattened to shape ``(m,)``. ``X`` is used when
            this is ``None``.

        Returns
        -------
        D :
            Distance matrix of shape ``(n, m)``.
        """

        def _reduced(angles) -> np.ndarray:
            # Flatten and bring every angle into [0, π).
            return np.mod(np.asarray(angles, dtype=float).reshape(-1), np.pi)

        rows = _reduced(X)[:, None]
        if Y is None:
            cols = rows.T
        else:
            cols = _reduced(Y)[None, :]

        gap = np.abs(cols - rows)
        return np.minimum(gap, np.pi - gap)
@dataclass(frozen=True)
class S1UnitVectorMetric:
    r"""
    Geodesic distance on :math:`\mathbb{S}^1` for points stored as unit
    vectors in :math:`\mathbb{R}^2`.

    Each point is an (approximately) unit-length vector on the unit circle,
    and the distance between two of them is the angle they enclose:

    .. math::

        d(p, q) = \arccos(\langle p, q \rangle),

    with the inner product clamped to ``[-1, 1]`` so that round-off cannot
    push it outside the domain of ``arccos``.

    Choose this metric when base points are kept as 2D unit vectors
    (e.g. ``(cos θ, sin θ)``) instead of angles.

    Attributes
    ----------
    name :
        Metric identifier (default ``"S1_unitvec"``).
    base_name :
        Short name of the base space for plots/UI.
    base_name_latex :
        LaTeX symbol used in summaries and tables.

    Notes
    -----
    - Inputs are *not* normalized here. Vectors that are not unit length
      should be normalized before calling :meth:`pairwise`, or a different
      metric should be used.
    - Results are in radians and lie in ``[0, π]``.
    """

    name: str = "S1_unitvec"
    base_name: str = "S^1"
    base_name_latex: str = r"\mathbb{S}^1"

    def pairwise(self, X: np.ndarray, Y: Optional[np.ndarray] = None) -> np.ndarray:
        r"""
        Angles between unit-vector samples on :math:`\mathbb{S}^1`.

        Parameters
        ----------
        X :
            Array of shape ``(n, 2)`` whose rows are unit vectors in
            :math:`\mathbb{R}^2`.
        Y :
            Optional array of shape ``(m, 2)``. ``X`` is used when this is
            ``None``.

        Returns
        -------
        D :
            Distance matrix of shape ``(n, m)``, in radians.
        """
        P = np.asarray(X, dtype=float)
        if Y is None:
            Q = P
        else:
            Q = np.asarray(Y, dtype=float)

        # All inner products at once, clamped into arccos' domain.
        cosines = np.clip(P @ Q.T, -1.0, 1.0)
        return np.arccos(cosines)
@dataclass(frozen=True)
class RP1UnitVectorMetric:
    r"""
    Geodesic distance on :math:`\mathbb{RP}^1` for points stored as unit
    vectors in :math:`\mathbb{R}^2`.

    A point of :math:`\mathbb{RP}^1` is a unit vector in
    :math:`\mathbb{R}^2` taken up to sign:

    .. math::

        p \sim -p.

    The distance between the classes ``[p]`` and ``[q]`` is therefore

    .. math::

        d([p],[q]) = \arccos(|\langle p, q \rangle|).

    Use it when the base variable is an *unoriented direction* (an axis,
    not an arrow).

    Attributes
    ----------
    name :
        Metric identifier (default ``"RP1_unitvec"``).
    base_name :
        Short name of the base space for plots/UI.
    base_name_latex :
        LaTeX symbol used in summaries and tables.

    Notes
    -----
    - Rows are normalized internally, so small deviations from unit norm
      are tolerated.
    - Distances are in radians and lie in ``[0, π/2]``, a consequence of the
      antipodal identification.
    """

    name: str = "RP1_unitvec"
    base_name: str = "RP^1"
    base_name_latex: str = r"\mathbb{RP}^1"

    def pairwise(self, X: np.ndarray, Y: Optional[np.ndarray] = None) -> np.ndarray:
        r"""
        Projective geodesic distances between vector samples.

        Parameters
        ----------
        X :
            Array of shape ``(n, 2)`` of vectors in :math:`\mathbb{R}^2`;
            rows should be (approximately) unit length.
        Y :
            Optional array of shape ``(m, 2)``. ``X`` is used when this is
            ``None``.

        Returns
        -------
        D :
            Distance matrix of shape ``(n, m)``, in radians.

        Raises
        ------
        ValueError
            If ``X`` or ``Y`` does not have exactly two columns (1D inputs
            are reshaped to a single column first, so they are rejected).
        """

        def _unit_rows(A: np.ndarray) -> np.ndarray:
            # Divide each row by its norm; the floor avoids division by zero.
            lengths = np.linalg.norm(A, axis=1, keepdims=True)
            return A / np.maximum(lengths, 1e-12)

        P = np.asarray(X, dtype=float)
        Q = P if Y is None else np.asarray(Y, dtype=float)

        # A 1D array becomes a single column, which then fails the width
        # check below: this metric really needs (n, 2) input.
        if P.ndim == 1:
            P = P.reshape(-1, 1)
        if Q.ndim == 1:
            Q = Q.reshape(-1, 1)

        if not (P.shape[1] == 2 and Q.shape[1] == 2):
            raise ValueError(f"RP1UnitVectorMetric expects (n,2) arrays. Got X={P.shape}, Y={Q.shape}.")

        cosines = np.clip(_unit_rows(P) @ _unit_rows(Q).T, -1.0, 1.0)
        # |cos| folds antipodal representatives together.
        return np.arccos(np.abs(cosines))
@dataclass(frozen=True)
class RP2UnitVectorMetric:
    r"""
    Chordal metric on :math:`\mathbb{RP}^2` for points stored as unit
    vectors in :math:`\mathbb{R}^3`.

    A point of :math:`\mathbb{RP}^2` is represented by a unit vector
    ``p ∈ S^2 ⊂ R^3`` taken up to sign (``p ~ -p``). The distance used here
    is the quotient of the ambient Euclidean distance:

    .. math::

        d([p],[q]) = \min(\|p - q\|,\; \|p + q\|).

    It is cheap to evaluate in batch, which makes it a practical choice for
    embedding-based computations.

    Attributes
    ----------
    name :
        Metric identifier (default ``"RP2_unitvec"``).
    base_name :
        Short name of the base space for plots/UI.
    base_name_latex :
        LaTeX symbol used in summaries and tables.

    Notes
    -----
    - This is a *chordal* metric, not the intrinsic geodesic metric on
      :math:`\mathbb{RP}^2`.
    - An intrinsic projective geodesic distance would need a different
      class (e.g. one based on ``arccos(|<p,q>|)``).
    - Inputs are not normalized by this class.
    """

    name: str = "RP2_unitvec"
    base_name: str = "RP^2"
    base_name_latex: str = r"\mathbb{RP}^2"

    def pairwise(self, X: np.ndarray, Y: Optional[np.ndarray] = None) -> np.ndarray:
        r"""
        Chordal quotient distances on :math:`\mathbb{RP}^2`.

        Parameters
        ----------
        X :
            Array of shape ``(n, 3)`` of vectors in :math:`\mathbb{R}^3`,
            ideally unit vectors on :math:`S^2`.
        Y :
            Optional array of shape ``(m, 3)``. ``X`` is used when this is
            ``None``.

        Returns
        -------
        D :
            Distance matrix of shape ``(n, m)``.
        """
        P = np.asarray(X, dtype=float)
        Q = P if Y is None else np.asarray(Y, dtype=float)

        # Broadcast once; reuse the views for both sign choices.
        left = P[:, None, :]
        right = Q[None, :, :]
        to_q = np.linalg.norm(left - right, axis=-1)
        to_minus_q = np.linalg.norm(left + right, axis=-1)
        return np.minimum(to_q, to_minus_q)
def _t2_flat_pairwise_angles(X: np.ndarray, Y: np.ndarray) -> np.ndarray: r""" Pairwise flat torus distance on :math:`\mathbb{T}^2 = (\mathbb{R}/2\pi\mathbb{Z})^2` using angle coordinates. This helper computes the standard *flat* (product) metric on the 2-torus by applying circular distance in each angular coordinate and then taking the Euclidean norm: .. math:: d((\theta_1,\theta_2),(\phi_1,\phi_2)) = \sqrt{ d_{S^1}(\theta_1,\phi_1)^2 + d_{S^1}(\theta_2,\phi_2)^2 }, where :math:`d_{S^1}(a,b) = \min(|a-b|, 2\pi - |a-b|)`. Parameters ---------- X : Array of shape ``(n, 2)`` containing angles in radians, interpreted modulo ``2π``. Y : Array of shape ``(m, 2)`` containing angles in radians. Returns ------- D : Distance matrix of shape ``(n, m)``. Notes ----- - This is the *flat* torus metric, not a quotient metric. - Used internally by several quotient constructions (Klein bottle, diagonal Z₂ quotients). """ X = np.asarray(X, dtype=float) Y = np.asarray(Y, dtype=float) diff = np.abs(X[:, None, :] - Y[None, :, :]) torus_diff = np.minimum(diff, 2.0 * np.pi - diff) return np.linalg.norm(torus_diff, axis=-1)
@dataclass(frozen=True)
class T2FlatMetric:
    r"""
    Flat metric on the 2-torus :math:`\mathbb{T}^2`.

    Points are represented as angle pairs ``(θ₁, θ₂)`` in radians, interpreted
    modulo ``2π`` in each coordinate. The distance is computed using the
    product of circular distances in each factor:

    .. math::

        d(x,y) = \sqrt{ d_{S^1}(x_1,y_1)^2 + d_{S^1}(x_2,y_2)^2 }.

    This metric is appropriate when:

    - your base space is a genuine torus (no quotient identifications), and
    - coordinates are stored explicitly as angles.

    Attributes
    ----------
    name :
        Metric identifier (default ``"T2_flat"``).
    base_name :
        Short name of the base space for plots/UI.
    base_name_latex :
        LaTeX symbol used in summaries and tables.

    Notes
    -----
    - Input angles may lie outside ``[0, 2π)``; :meth:`pairwise` wraps both
      arguments into that range before computing distances.
    - The same flat torus distance is the *upstairs* metric from which the
      Z₂ quotient metrics (Klein bottle, diagonal quotient) are built.
    """

    name: str = "T2_flat"
    base_name: str = "T^2"
    base_name_latex: str = r"\mathbb{T}^2"

    def pairwise(self, X: np.ndarray, Y: Optional[np.ndarray] = None) -> np.ndarray:
        """
        Compute flat torus distances between angle-coordinate samples.

        Parameters
        ----------
        X :
            Array of shape ``(n, 2)`` containing angles in radians.
        Y :
            Optional array of shape ``(m, 2)``. If omitted, uses ``X``.

        Returns
        -------
        D :
            Distance matrix of shape ``(n, m)``.

        Raises
        ------
        ValueError
            If ``X`` or ``Y`` does not have shape ``(*, 2)``.
        """
        X = np.asarray(X, dtype=float)
        Y0 = X if Y is None else np.asarray(Y, dtype=float)

        if X.ndim != 2 or X.shape[1] != 2:
            raise ValueError(f"X must be (n,2) angles. Got {X.shape}.")
        if Y0.ndim != 2 or Y0.shape[1] != 2:
            raise ValueError(f"Y must be (m,2) angles. Got {Y0.shape}.")

        # Wrap explicitly so that coordinate differences never exceed one
        # full turn; this is what makes "any real angle" a valid input.
        two_pi = 2.0 * np.pi
        return _t2_flat_pairwise_angles(np.mod(X, two_pi), np.mod(Y0, two_pi))
# ============================================================
# Product metrics on concatenated (base | fiber) vectors
# ============================================================

@dataclass(frozen=True)
class ProductMetricConcat:
    """
    Product metric on concatenated vectors ``Z = [base | fiber]``.

    Distance::

        d(z,z')^2 = (base_weight * dB(base, base'))^2
                    + (fiber_weight * ||fiber - fiber'||)^2

    Attributes
    ----------
    base_metric :
        Metric object applied to the first ``base_dim`` columns
        (can be torus, RP1, etc.).
    base_dim :
        Number of leading columns that form the base block; the remaining
        columns form the fiber block.
    base_weight, fiber_weight :
        Scale factors applied to the base and fiber distances.
    name :
        Metric identifier (default ``"product_concat"``).

    Notes
    -----
    - Uses Euclidean distance for the fiber block.
    - Fully vectorized, returns an ``(n, m)`` matrix, satisfies the
      ``Metric`` protocol.
    """

    base_metric: "Metric"
    base_dim: int
    base_weight: float = 1.0
    fiber_weight: float = 1.0
    name: str = "product_concat"

    def pairwise(self, X: np.ndarray, Y: Optional[np.ndarray] = None) -> np.ndarray:
        """
        Compute weighted product distances between rows of ``X`` and ``Y``.

        Parameters
        ----------
        X :
            Array of shape ``(n, base_dim + k)``.
        Y :
            Optional array of shape ``(m, base_dim + k)``. If omitted,
            uses ``X``.

        Returns
        -------
        D :
            Distance matrix of shape ``(n, m)``.

        Raises
        ------
        ValueError
            If an input is not 2D or has fewer than ``base_dim`` columns.
        """
        X = np.asarray(X, dtype=float)
        Y0 = X if Y is None else np.asarray(Y, dtype=float)

        if X.ndim != 2 or Y0.ndim != 2:
            raise ValueError("ProductMetricConcat expects 2D arrays.")
        if X.shape[1] < self.base_dim or Y0.shape[1] < self.base_dim:
            raise ValueError(
                f"base_dim={self.base_dim} exceeds feature dim: X={X.shape}, Y={Y0.shape}"
            )

        XB, XF = X[:, : self.base_dim], X[:, self.base_dim :]
        YB, YF = Y0[:, : self.base_dim], Y0[:, self.base_dim :]

        DB = self.base_metric.pairwise(XB, YB)  # (n,m)
        DF = np.linalg.norm(XF[:, None, :] - YF[None, :, :], axis=-1)  # (n,m)

        bw = float(self.base_weight)
        fw = float(self.fiber_weight)
        return np.sqrt((bw * DB) ** 2 + (fw * DF) ** 2)


# ============================================================
# Flat Z2 quotient metrics on angle-coordinates (base-first!)
# ============================================================

def _wrapped_angle_pairs(X: np.ndarray, Y: Optional[np.ndarray]):
    """Validate ``(n,2)`` / ``(m,2)`` (base, fiber) angles and wrap them into ``[0, 2π)``."""
    X = np.asarray(X, dtype=float)
    Y0 = X if Y is None else np.asarray(Y, dtype=float)

    if X.ndim != 2 or X.shape[1] != 2:
        raise ValueError(f"X must be (n,2) angles (base,fiber). Got {X.shape}.")
    if Y0.ndim != 2 or Y0.shape[1] != 2:
        raise ValueError(f"Y must be (m,2) angles (base,fiber). Got {Y0.shape}.")

    # Wrapping *both* arguments keeps every coordinate difference below one
    # full turn, so the flat torus distance is evaluated on proper
    # representatives even when callers pass unwrapped angles.
    two_pi = 2.0 * np.pi
    return np.mod(X, two_pi), np.mod(Y0, two_pi)


@dataclass(frozen=True)
class KleinBottleFlatMetric:
    r"""
    Flat Klein bottle metric as a :math:`\mathbb{Z}_2` quotient of the flat torus.

    We represent points by angle coordinates ``(b, f)`` in radians (interpreted
    mod ``2π``), using the **base-first convention**:

    - ``b`` = base angle
    - ``f`` = fiber angle

    The Klein bottle arises as the quotient of :math:`\mathbb{T}^2` by the action

    .. math::

        g(b,f) = (b+\pi,\,-f).

    The induced quotient distance is

    .. math::

        d_{\mathrm{KB}}([x],[y]) = \min\{ d_{\mathbb{T}^2}(x,y),\ d_{\mathbb{T}^2}(x,g(y)) \},

    where :math:`d_{\mathbb{T}^2}` is the flat torus metric (coordinatewise
    circular distance, then Euclidean norm).

    Attributes
    ----------
    name :
        Metric identifier (default ``"KB_flat"``).

    Notes
    -----
    - This is an *intrinsic quotient metric* computed via minimizing over the
      two representatives ``y`` and ``g(y)``.
    - Inputs may be any real angles; both arguments are wrapped into
      ``[0, 2π)`` before distances are computed.
    """

    name: str = "KB_flat"

    def pairwise(self, X: np.ndarray, Y: Optional[np.ndarray] = None) -> np.ndarray:
        """
        Compute pairwise Klein bottle distances on angle-coordinate data.

        Parameters
        ----------
        X :
            Array of shape ``(n, 2)`` containing angles ``(b, f)`` in radians.
        Y :
            Optional array of shape ``(m, 2)``. If omitted, uses ``X``.

        Returns
        -------
        D :
            Distance matrix of shape ``(n, m)`` where ``D[i,j]`` is the
            quotient distance between ``X[i]`` and ``Y[j]``.

        Raises
        ------
        ValueError
            If ``X`` or ``Y`` is not a 2D array of shape ``(*, 2)``.
        """
        X0, Y0 = _wrapped_angle_pairs(X, Y)

        # g(b,f) = (b + pi, -f) modulo 2pi
        two_pi = 2.0 * np.pi
        Y1 = Y0.copy()
        Y1[:, 0] = np.mod(Y1[:, 0] + np.pi, two_pi)
        Y1[:, 1] = np.mod(-Y1[:, 1], two_pi)

        D0 = _t2_flat_pairwise_angles(X0, Y0)
        D1 = _t2_flat_pairwise_angles(X0, Y1)
        return np.minimum(D0, D1)


@dataclass(frozen=True)
class TorusDiagFlatMetric:
    r"""
    Flat metric on the :math:`\mathbb{Z}_2` quotient of the flat torus by a
    diagonal π-shift.

    Points are angle pairs ``(b, f)`` in radians (mod ``2π``), again using the
    **base-first convention**. The :math:`\mathbb{Z}_2` action is

    .. math::

        g(b,f) = (b+\pi,\ f+\pi).

    The induced quotient distance is

    .. math::

        d([x],[y]) = \min\{ d_{\mathbb{T}^2}(x,y),\ d_{\mathbb{T}^2}(x,g(y)) \}.

    Topologically, this quotient is homeomorphic to
    :math:`\mathbb{RP}^1 \times \mathbb{S}^1` (i.e. a trivial circle bundle over
    :math:`\mathbb{RP}^1`), but this class only encodes the metric structure via
    the quotient construction.

    Attributes
    ----------
    name :
        Metric identifier (default ``"T2_diag_flat"``).

    Notes
    -----
    - Symmetric in the two coordinates as a space, but we keep the (base,fiber)
      naming for consistency with the rest of the library.
    - Inputs may be any real angles; both arguments are wrapped into
      ``[0, 2π)`` before distances are computed.
    """

    name: str = "T2_diag_flat"

    def pairwise(self, X: np.ndarray, Y: Optional[np.ndarray] = None) -> np.ndarray:
        """
        Compute pairwise distances for the diagonal π-shift quotient.

        Parameters
        ----------
        X :
            Array of shape ``(n, 2)`` containing angles ``(b, f)`` in radians.
        Y :
            Optional array of shape ``(m, 2)``. If omitted, uses ``X``.

        Returns
        -------
        D :
            Distance matrix of shape ``(n, m)`` giving quotient distances.

        Raises
        ------
        ValueError
            If ``X`` or ``Y`` is not a 2D array of shape ``(*, 2)``.
        """
        X0, Y0 = _wrapped_angle_pairs(X, Y)

        # g(b,f) = (b + pi, f + pi) modulo 2pi
        Y1 = np.mod(Y0 + np.pi, 2.0 * np.pi)

        D0 = _t2_flat_pairwise_angles(X0, Y0)
        D1 = _t2_flat_pairwise_angles(X0, Y1)
        return np.minimum(D0, D1)


def T2_Z2QuotientFlatMetric(kind: str = "klein") -> "Metric":
    r"""
    Factory for :math:`\mathbb{Z}_2` quotient metrics on angle-coordinate data
    ``(base, fiber)``.

    This returns a metric object implementing the quotient distance induced by
    the flat torus metric upstairs.

    Parameters
    ----------
    kind :
        Which :math:`\mathbb{Z}_2` action to use (case-insensitive,
        surrounding whitespace ignored):

        - ``"klein"`` (aliases: ``"kb"``, ``"klein_bottle"``):
          :math:`(b,f) \sim (b+\pi,\,-f)` giving the Klein bottle.
        - ``"diag"`` (aliases: ``"diagonal"``, ``"diag_pi"``, ``"diagonal_pi"``):
          :math:`(b,f) \sim (b+\pi,\ f+\pi)` giving the diagonal π-shift quotient.

    Returns
    -------
    metric :
        A metric object with a vectorized ``pairwise(X, Y=None)`` method.

    Raises
    ------
    ValueError
        If ``kind`` is not recognized.

    Examples
    --------
    >>> M = T2_Z2QuotientFlatMetric("klein")
    >>> D = M.pairwise(X)   # X has shape (n,2) of (base,fiber) angles
    """
    kind = str(kind).lower().strip()
    if kind in {"klein", "kb", "klein_bottle"}:
        return KleinBottleFlatMetric()
    if kind in {"diag", "diagonal", "diag_pi", "diagonal_pi"}:
        return TorusDiagFlatMetric()
    raise ValueError(f"Unknown kind={kind!r}. Expected 'klein' or 'diag'.")


# ============================================================
# Converting scalar metrics -> vectorized metrics
# ============================================================

@dataclass(frozen=True)
class SciPyCdistMetric:
    """
    Fallback wrapper for a scalar metric(p,q) using scipy.spatial.distance.cdist.

    Attributes
    ----------
    metric :
        The value forwarded as ``metric=`` to ``cdist``.
    name :
        Metric identifier (default ``"scipy_cdist"``).
    """

    metric: Callable
    name: str = "scipy_cdist"

    def pairwise(self, X: np.ndarray, Y: Optional[np.ndarray] = None) -> np.ndarray:
        """
        Evaluate the wrapped metric on all pairs via ``cdist``.

        1D inputs are treated as ``n`` one-dimensional points.

        Raises
        ------
        ImportError
            If SciPy could not be imported when this module was loaded.
        """
        if _cdist is None:
            raise ImportError("SciPy not available: cannot use cdist fallback for custom metrics.")
        X = np.asarray(X)
        Y = X if Y is None else np.asarray(Y)
        if X.ndim == 1:
            X = X.reshape(-1, 1)
        if Y.ndim == 1:
            Y = Y.reshape(-1, 1)
        return _cdist(X, Y, metric=self.metric)


def as_metric(metric: Union["Metric", Callable, None]) -> "Metric":
    """
    Convert either a Metric object or a callable(p,q) into a Metric object.

    - ``None`` yields ``EuclideanMetric()``.
    - Anything with a ``pairwise`` attribute is returned unchanged.
    - Any other value is wrapped in :class:`SciPyCdistMetric`, named after
      its ``__name__`` (or ``"custom_metric"`` if it has none).
    """
    if metric is None:
        return EuclideanMetric()
    if hasattr(metric, "pairwise"):
        return metric  # type: ignore[return-value]
    return SciPyCdistMetric(metric=metric, name=getattr(metric, "__name__", "custom_metric"))


# ============================================================
# Euclidean Z2 quotient metrics on embedded data
# ============================================================

ActionFn = Callable[[np.ndarray], np.ndarray]


def _pairwise_euclidean(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
    """Return the ``(n, m)`` Euclidean distance matrix between rows of ``X`` and ``Y``."""
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    return np.linalg.norm(X[:, None, :] - Y[None, :, :], axis=-1)


@dataclass(frozen=True)
class Z2QuotientMetricEuclidean:
    """
    Z2 quotient metric induced from ambient Euclidean distance in R^d.

        d([x],[y]) = min(||x - y||, ||x - g(y)||)

    Attributes
    ----------
    action :
        The map ``g``; it is applied to a whole ``(m, dim)`` array in
        :meth:`pairwise` and to a single ``(dim,)`` vector in :meth:`dist`,
        and must return an array of the same shape.
    dim :
        Ambient dimension ``d``.
    name :
        Metric identifier.
    """

    action: ActionFn
    dim: int
    name: str = "Z2QuotientMetricEuclidean"

    def pairwise(self, X: np.ndarray, Y: Optional[np.ndarray] = None) -> np.ndarray:
        """
        Compute the ``(n, m)`` matrix of quotient distances.

        Raises
        ------
        ValueError
            If ``X`` or ``Y`` is not of shape ``(*, dim)``, or if ``action``
            changes the shape of its input.
        """
        X = np.asarray(X, dtype=float)
        Y0 = X if Y is None else np.asarray(Y, dtype=float)

        if X.ndim != 2 or X.shape[1] != self.dim:
            raise ValueError(f"X must be (n,{self.dim}). Got {X.shape}.")
        if Y0.ndim != 2 or Y0.shape[1] != self.dim:
            raise ValueError(f"Y must be (m,{self.dim}). Got {Y0.shape}.")

        Y1 = np.asarray(self.action(Y0), dtype=float)
        if Y1.shape != Y0.shape:
            raise ValueError(f"action returned shape {Y1.shape}, expected {Y0.shape}.")

        D0 = _pairwise_euclidean(X, Y0)
        D1 = _pairwise_euclidean(X, Y1)
        return np.minimum(D0, D1)

    def dist(self, p: np.ndarray, q: np.ndarray) -> float:
        """Scalar version of :meth:`pairwise` for two single points of size ``dim``."""
        p = np.asarray(p, dtype=float).reshape(self.dim)
        q = np.asarray(q, dtype=float).reshape(self.dim)
        q2 = np.asarray(self.action(q), dtype=float).reshape(self.dim)
        return float(min(np.linalg.norm(p - q), np.linalg.norm(p - q2)))


# ============================================================
# C^2 torus (R^4) Z2 quotient metrics with base-first convention
# ============================================================
# We assume the torus is embedded as (z1,z2) in C^2 with real coords:
#   z1 = x1 + i x2,   z2 = x3 + i x4.
#
# To make "first coordinate is base projection" explicit, we allow:
#   base_in="z1"  or  base_in="z2"
# ============================================================

def act_klein_C2_torus_base_in_z1(Y: np.ndarray) -> np.ndarray:
    """
    base_in="z1": base angle lives in z1, fiber angle lives in z2.

    Klein action in angles (b,f): (b,f) ~ (b+pi, -f)
    becomes (z1,z2) -> (-z1, conj(z2)).

    Real coords: (x1,x2,x3,x4) -> (-x1,-x2, x3,-x4)

    Returns a new array; raises ``ValueError`` if the last axis is not 4.
    """
    Y = np.asarray(Y, dtype=float)
    if Y.shape[-1] != 4:
        raise ValueError(f"Expected last dim 4 for C^2 data. Got {Y.shape}.")
    out = Y.copy()
    out[..., 0:2] *= -1.0  # -z1
    out[..., 3] *= -1.0  # conj(z2)
    return out


def act_klein_C2_torus_base_in_z2(Y: np.ndarray) -> np.ndarray:
    """
    base_in="z2": base angle lives in z2, fiber angle lives in z1.

    Klein action in (b,f): (b,f) ~ (b+pi, -f)
    in (z1,z2) ordering means:
        (z_fiber, z_base) -> (conj(z_fiber), -z_base)
    i.e. (z1,z2) -> (conj(z1), -z2).

    Real coords: (x1,x2,x3,x4) -> (x1,-x2, -x3,-x4)

    Returns a new array; raises ``ValueError`` if the last axis is not 4.
    """
    Y = np.asarray(Y, dtype=float)
    if Y.shape[-1] != 4:
        raise ValueError(f"Expected last dim 4 for C^2 data. Got {Y.shape}.")
    out = Y.copy()
    out[..., 1] *= -1.0  # conj(z1)
    out[..., 2:4] *= -1.0  # -z2
    return out


def act_diag_C2_torus_base_in_z1(Y: np.ndarray) -> np.ndarray:
    """
    Diagonal pi-shift: (b,f) ~ (b+pi, f+pi)
    corresponds to (z1,z2)->(-z1,-z2), independent of which factor is called base.

    Raises ``ValueError`` if the last axis is not 4.
    """
    Y = np.asarray(Y, dtype=float)
    if Y.shape[-1] != 4:
        raise ValueError(f"Expected last dim 4 for C^2 data. Got {Y.shape}.")
    return -Y


def act_diag_C2_torus_base_in_z2(Y: np.ndarray) -> np.ndarray:
    """Same map as :func:`act_diag_C2_torus_base_in_z1`; kept for clarity/symmetry."""
    return act_diag_C2_torus_base_in_z1(Y)


# Back-compat aliases (older code may import these names)
def act_klein_C2_torus(Y: np.ndarray) -> np.ndarray:
    """Alias of :func:`act_klein_C2_torus_base_in_z1`."""
    return act_klein_C2_torus_base_in_z1(Y)


def act_diag_C2_torus(Y: np.ndarray) -> np.ndarray:
    """Alias of :func:`act_diag_C2_torus_base_in_z1`."""
    return act_diag_C2_torus_base_in_z1(Y)
def Torus_KleinQuotientMetric_R4(*, base_in: str = "z2") -> "Metric":
    """
    Z2 quotient metric on R^4 C^2-torus embedding that implements the Klein
    identification (base,fiber) ~ (base+pi, -fiber) with an explicit
    base-factor choice.

    Parameters
    ----------
    base_in :
        Which complex factor carries the base angle (case-insensitive,
        surrounding whitespace ignored):

        - ``"z1"``: base angle is encoded in z1
        - ``"z2"`` (default): base angle is encoded in z2

    Returns
    -------
    metric :
        A ``Z2QuotientMetricEuclidean`` with ``dim=4`` whose ``name`` records
        the chosen base factor.

    Raises
    ------
    ValueError
        If ``base_in`` is neither ``"z1"`` nor ``"z2"``.
    """
    base_in = str(base_in).lower().strip()
    if base_in == "z1":
        act = act_klein_C2_torus_base_in_z1
        nm = "T2_to_Klein_R4(base=z1)"
    elif base_in == "z2":
        act = act_klein_C2_torus_base_in_z2
        nm = "T2_to_Klein_R4(base=z2)"
    else:
        raise ValueError("base_in must be 'z1' or 'z2'.")
    return Z2QuotientMetricEuclidean(action=act, dim=4, name=nm)
def Torus_DiagQuotientMetric_R4(*, base_in: str = "z2") -> "Metric":
    """
    Build the Z2 quotient metric for the diagonal pi-shift on the C^2 torus
    embedded in R^4:

        (base,fiber) ~ (base+pi, fiber+pi)

    The identification does not depend on which factor is the base;
    ``base_in`` is accepted for API consistency and only appears in the
    resulting metric's ``name``.

    Parameters
    ----------
    base_in :
        ``"z1"`` or ``"z2"`` (case-insensitive, surrounding whitespace
        ignored).

    Raises
    ------
    ValueError
        If ``base_in`` is neither ``"z1"`` nor ``"z2"``.
    """
    base_key = str(base_in).lower().strip()
    if base_key != "z1" and base_key != "z2":
        raise ValueError("base_in must be 'z1' or 'z2'.")

    return Z2QuotientMetricEuclidean(
        action=act_diag_C2_torus_base_in_z1,
        dim=4,
        name=f"T2_to_Diag_R4(base={base_key})",
    )
def Torus_Z2QuotientMetric_R4(kind: str = "klein", *, base_in: str = "z2") -> "Metric":
    """
    Factory for Z2 quotient metrics on a C^2 torus embedded in R^4.

    Parameters
    ----------
    kind :
        Which identification to use (case-insensitive, stripped):

        - ``"klein"`` (aliases ``"kb"``, ``"klein_bottle"``):
          ``(b, f) ~ (b + pi, -f)``
        - ``"diag"`` (aliases ``"diagonal"``, ``"diag_pi"``, ``"diagonal_pi"``):
          ``(b, f) ~ (b + pi, f + pi)``
    base_in :
        ``"z2"`` (default) or ``"z1"``. Forwarded unchanged to the selected
        factory; it matters for ``klein``, while ``diag`` is symmetric but
        still accepts it.

    Raises
    ------
    ValueError
        If ``kind`` is not one of the recognised spellings.
    """
    kind = str(kind).lower().strip()
    if kind in {"klein", "kb", "klein_bottle"}:
        return Torus_KleinQuotientMetric_R4(base_in=base_in)
    if kind in {"diag", "diagonal", "diag_pi", "diagonal_pi"}:
        return Torus_DiagQuotientMetric_R4(base_in=base_in)
    raise ValueError(f"Unknown kind={kind!r}. Expected 'klein' or 'diag'.")


# ============================================================
# Existing R^5 Z2 quotient (kept for compatibility)
# ============================================================

@dataclass(frozen=True)
class Z2QuotientMetricR5:
    """
    Z2 quotient metric induced from ambient Euclidean distance in R^5.

    Points are in R^5, intended as (v, Re z, Im z) with v in R^3, z on S^1.
    The distance between two points is the smaller of the Euclidean distance
    to the second point and to its image under ``action``.

    Attributes
    ----------
    action :
        Callable applied to an array of points; must return an array of the
        same shape.
    name :
        Display name for the metric.
    """
    action: "ActionFn"
    name: str = "Z2QuotientMetricR5"

    def pairwise(self, X: np.ndarray, Y: Optional[np.ndarray] = None) -> np.ndarray:
        """
        Return the ``(n, m)`` matrix of quotient distances.

        Parameters
        ----------
        X :
            Array of shape ``(n, 5)``.
        Y :
            Optional array of shape ``(m, 5)``; defaults to ``X``.

        Raises
        ------
        ValueError
            If ``X`` or ``Y`` is not two-dimensional with 5 columns, or if
            ``action`` changes the shape of its input.
        """
        X = np.asarray(X, dtype=float)
        Y0 = X if Y is None else np.asarray(Y, dtype=float)

        if X.ndim != 2 or X.shape[1] != 5:
            raise ValueError(f"X must be (n,5). Got {X.shape}.")
        if Y0.ndim != 2 or Y0.shape[1] != 5:
            raise ValueError(f"Y must be (m,5). Got {Y0.shape}.")

        Y1 = np.asarray(self.action(Y0), dtype=float)
        if Y1.shape != Y0.shape:
            raise ValueError(f"action returned shape {Y1.shape}, expected {Y0.shape}.")

        # Elementwise minimum over the two representatives of each class of Y.
        D0 = _pairwise_euclidean(X, Y0)
        D1 = _pairwise_euclidean(X, Y1)
        return np.minimum(D0, D1)

    def dist(self, p: np.ndarray, q: np.ndarray) -> float:
        """Scalar version of :meth:`pairwise` for two single points of R^5."""
        p = np.asarray(p, dtype=float).reshape(5)
        q = np.asarray(q, dtype=float).reshape(5)
        q2 = np.asarray(self.action(q), dtype=float).reshape(5)
        return float(min(np.linalg.norm(p - q), np.linalg.norm(p - q2)))


def act_base_only(Y: np.ndarray) -> np.ndarray:
    """(v, a, b) -> (-v, a, b)  i.e. (v,z) -> (-v, z).  Returns a new array."""
    out = np.array(Y, dtype=float)  # always a copy, so the input is never mutated
    out[..., :3] *= -1.0
    return out


def act_pi_twist(Y: np.ndarray) -> np.ndarray:
    """(v, a, b) -> (-v, -a, -b)  i.e. (v,z) -> (-v, -z).  Returns a new array."""
    out = np.array(Y, dtype=float)
    out[..., :3] *= -1.0
    out[..., 3:5] *= -1.0
    return out


def act_reflection_twist(Y: np.ndarray) -> np.ndarray:
    """(v, a, b) -> (-v, a, -b)  i.e. (v,z) -> (-v, conj z).  Returns a new array."""
    out = np.array(Y, dtype=float)
    out[..., :3] *= -1.0
    out[..., 4] *= -1.0
    return out
def RP2_TrivialMetric() -> Z2QuotientMetricR5:
    """
    Metric for the trivial circle bundle over RP^2.

    Uses the identification ``(v, z) ~ (-v, z)``: only the base coordinates
    are negated, the fiber is left untouched.
    """
    metric = Z2QuotientMetricR5(name="RP2xS1", action=act_base_only)
    return metric
def RP2_TwistMetric() -> Z2QuotientMetricR5:
    """
    Metric for the orientable nontrivial bundle over RP^2 (monodromy -1).

    Uses the identification ``(v, z) ~ (-v, -z)``: base and fiber coordinates
    are both negated.
    """
    metric = Z2QuotientMetricR5(name="RP2_Twist", action=act_pi_twist)
    return metric
def RP2_FlipMetric() -> Z2QuotientMetricR5:
    """
    Metric for the non-orientable bundle over RP^2 (reflection on the fiber).

    Uses the identification ``(v, z) ~ (-v, conj z)``: the base is negated and
    the fiber is complex-conjugated.
    """
    metric = Z2QuotientMetricR5(name="RP2_Klein", action=act_reflection_twist)
    return metric
# ============================================================
# Z_p quotient metric on S^3 compatible with hopf_projection
# ============================================================

def _safe_normalize_rows(X: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Divide each row of ``X`` by its norm, flooring the norm at ``eps``."""
    X = np.asarray(X, dtype=float)
    nrm = np.maximum(np.linalg.norm(X, axis=1, keepdims=True), eps)
    return X / nrm


def _quat_mul(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    """Hamilton product, vectorized. q=(a,b,c,d) with a scalar."""
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)
    aw, ax, ay, az = A[..., 0], A[..., 1], A[..., 2], A[..., 3]
    bw, bx, by, bz = B[..., 0], B[..., 1], B[..., 2], B[..., 3]
    w = aw * bw - ax * bx - ay * by - az * bz
    x = aw * bx + ax * bw + ay * bz - az * by
    y = aw * by - ax * bz + ay * bw + az * bx
    z = aw * bz + ax * by - ay * bx + az * bw
    return np.stack([w, x, y, z], axis=-1)


def _s3_geodesic_pairwise(X: np.ndarray, Y: np.ndarray, *, eps: float = 1e-12) -> np.ndarray:
    """Pairwise great-circle angle (arccos of the dot product) of normalized rows."""
    X = _safe_normalize_rows(X, eps=eps)
    Y = _safe_normalize_rows(Y, eps=eps)
    # Clip guards arccos against dot products that drift slightly outside [-1, 1].
    dots = np.clip(X @ Y.T, -1.0, 1.0)
    return np.arccos(dots)


def _s3_chordal_pairwise(X: np.ndarray, Y: np.ndarray, *, eps: float = 1e-12) -> np.ndarray:
    """Pairwise Euclidean (chordal) distance between normalized rows."""
    X = _safe_normalize_rows(X, eps=eps)
    Y = _safe_normalize_rows(Y, eps=eps)
    return np.linalg.norm(X[:, None, :] - Y[None, :, :], axis=-1)


def _unit_quat_from_axis_angle(v: np.ndarray, theta: float, *, eps: float = 1e-12) -> np.ndarray:
    """
    Return the quaternion ``(cos(theta), sin(theta) * v_hat)``.

    Note that ``theta`` is used directly (no half-angle), i.e. this is
    ``exp(theta * v_hat)`` for the normalized imaginary axis ``v_hat``.
    """
    v = np.asarray(v, dtype=float).reshape(3,)
    nv = max(np.linalg.norm(v), eps)
    vhat = v / nv
    return np.array([np.cos(theta), *(np.sin(theta) * vhat)], dtype=float)


def _default_v_axis() -> np.ndarray:
    """Fresh copy of the default axis ``(1, 0, 0)``."""
    return np.array([1.0, 0.0, 0.0], dtype=float)


def _array_key(a: Optional[np.ndarray]) -> Optional[tuple]:
    """Hashable, value-based stand-in for an optional array (used by __eq__/__hash__)."""
    if a is None:
        return None
    return tuple(np.asarray(a, dtype=float).ravel().tolist())


def _as_quat_clouds(X: np.ndarray, Y: Optional[np.ndarray]) -> tuple:
    """Coerce ``X``/``Y`` to float arrays and check both are ``(k, 4)``; ``Y`` defaults to ``X``."""
    X = np.asarray(X, dtype=float)
    Y0 = X if Y is None else np.asarray(Y, dtype=float)
    if X.ndim != 2 or X.shape[1] != 4:
        raise ValueError(f"X must be (n,4) quaternions. Got {X.shape}.")
    if Y0.ndim != 2 or Y0.shape[1] != 4:
        raise ValueError(f"Y must be (m,4) quaternions. Got {Y0.shape}.")
    return X, Y0


def _zp_elements(v_axis: np.ndarray, p: int, *, eps: float) -> np.ndarray:
    """Stack the ``p`` quaternions ``g^m`` (m = 0..p-1), with angle step ``2*pi/p``, as ``(p, 4)``."""
    theta0 = 2.0 * np.pi / float(p)
    return np.stack(
        [_unit_quat_from_axis_angle(v_axis, m * theta0, eps=eps) for m in range(p)],
        axis=0,
    )


def _min_over_zp(
    Xn: np.ndarray,
    Yrep: np.ndarray,
    gs: np.ndarray,
    dist_fn: Callable[..., np.ndarray],
    *,
    eps: float,
) -> np.ndarray:
    """Elementwise minimum of ``dist_fn(Xn, Yrep * g)`` over the rows ``g`` of ``gs``."""
    # gs always has at least one row (p >= 1), so the first term seeds the minimum.
    best = dist_fn(Xn, _quat_mul(Yrep, gs[0][None, :]), eps=eps)
    for g in gs[1:]:
        best = np.minimum(best, dist_fn(Xn, _quat_mul(Yrep, g[None, :]), eps=eps))
    return best


# eq=False: the auto-generated __eq__/__hash__ cannot handle the ndarray field
# (ambiguous truth value / unhashable), so value-based versions are defined
# explicitly below.
@dataclass(frozen=True, eq=False)
class ZpHopfQuotientMetricS3:
    """
    Quotient metric on S^3 / <g>, where g = exp(v * 2pi/p) lies in the Hopf fiber circle S^1_v.

    d([x],[y]) = min_{m=0..p-1} d_S3(x, y * g^m)

    Attributes
    ----------
    p :
        Order of the cyclic group; must be positive.
    v_axis :
        Imaginary axis of the generator; normalized in ``__post_init__``.
    base :
        ``"geodesic"`` or ``"chordal"`` distance on S^3.
    name :
        Display name for the metric.
    eps :
        Floor applied to norms before dividing.

    Instances compare equal (and hash equal) when all of the fields above
    have equal values.
    """
    p: int
    v_axis: np.ndarray = field(default_factory=_default_v_axis)
    base: str = "geodesic"  # "geodesic" or "chordal"
    name: str = "ZpHopfQuotientMetricS3"
    eps: float = 1e-12

    def __post_init__(self) -> None:
        p = int(self.p)
        if p <= 0:
            raise ValueError(f"p must be positive. Got {self.p}.")
        if self.base not in ("geodesic", "chordal"):
            raise ValueError(f"base must be 'geodesic' or 'chordal'. Got {self.base}.")
        v = np.asarray(self.v_axis, dtype=float).reshape(3,)
        nv = max(np.linalg.norm(v), self.eps)
        v = v / nv
        # The dataclass is frozen, so normalized values are stored via object.__setattr__.
        object.__setattr__(self, "p", p)
        object.__setattr__(self, "v_axis", v)

    def _key(self) -> tuple:
        return (self.p, _array_key(self.v_axis), self.base, self.name, self.eps)

    def __eq__(self, other: object) -> bool:
        if other.__class__ is not self.__class__:
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self) -> int:
        return hash(self._key())

    def pairwise(self, X: np.ndarray, Y: Optional[np.ndarray] = None) -> np.ndarray:
        """
        Return the ``(n, m)`` matrix of quotient distances between quaternion rows.

        ``X`` must be ``(n, 4)`` and ``Y`` (default ``X``) ``(m, 4)``; rows are
        normalized before use. Raises ``ValueError`` on any other shape.
        """
        X, Y0 = _as_quat_clouds(X, Y)

        Xn = _safe_normalize_rows(X, eps=self.eps)
        Yn = _safe_normalize_rows(Y0, eps=self.eps)

        dist_fn = _s3_geodesic_pairwise if self.base == "geodesic" else _s3_chordal_pairwise
        gs = _zp_elements(self.v_axis, self.p, eps=self.eps)
        return _min_over_zp(Xn, Yn, gs, dist_fn, eps=self.eps)


def _pick_u_perp_v(v_axis: np.ndarray, *, eps: float = 1e-12) -> np.ndarray:
    """
    Return a unit vector perpendicular to ``v_axis``.

    Crosses the normalized ``v`` with ``(1, 0, 0)``, switching to ``(0, 1, 0)``
    when ``v`` is nearly parallel to the first choice (|dot| > 0.9).
    """
    v = np.asarray(v_axis, dtype=float).reshape(3,)
    nv = max(np.linalg.norm(v), eps)
    v = v / nv
    a = np.array([1.0, 0.0, 0.0])
    if abs(np.dot(a, v)) > 0.9:
        a = np.array([0.0, 1.0, 0.0])
    u = np.cross(v, a)
    nu = max(np.linalg.norm(u), eps)
    return u / nu


def _right_mul_by_u(Q: np.ndarray, u_axis: np.ndarray) -> np.ndarray:
    """Right-multiply each quaternion row of ``Q`` by the pure quaternion ``(0, u_axis)``."""
    u_axis = np.asarray(u_axis, dtype=float).reshape(3,)
    u_quat = np.array([0.0, u_axis[0], u_axis[1], u_axis[2]], dtype=float)
    return _quat_mul(Q, u_quat[None, :])


# eq=False for the same reason as ZpHopfQuotientMetricS3: ndarray fields break
# the generated __eq__/__hash__.
@dataclass(frozen=True, eq=False)
class Z2LensAntipodalQuotientMetricS3:
    """
    Metric on the quotient of the lens space L(p,1)=S^3/<g> by the Z2-action
    tau(q)=q*u that covers antipodal on S^2 under pi_v(q)=q v q^{-1}.

    Requires p even so that tau^2 is trivial on L(p,1).

    Attributes
    ----------
    p :
        Order of the cyclic group; must be positive and even.
    v_axis :
        Imaginary axis of the generator; normalized in ``__post_init__``.
    base :
        ``"geodesic"`` or ``"chordal"`` distance on S^3.
    name :
        Display name for the metric.
    eps :
        Floor applied to norms before dividing.
    u_axis :
        Optional axis for ``tau``; normalized if given. When ``None``, an axis
        perpendicular to ``v_axis`` is chosen at evaluation time.

    Instances compare equal (and hash equal) when all of the fields above
    have equal values.
    """
    p: int
    v_axis: np.ndarray = field(default_factory=_default_v_axis)
    base: str = "geodesic"  # "geodesic" or "chordal"
    name: str = "Z2LensAntipodalQuotientMetricS3"
    eps: float = 1e-12
    u_axis: Optional[np.ndarray] = None

    def __post_init__(self) -> None:
        p = int(self.p)
        if p <= 0:
            raise ValueError(f"p must be positive. Got {self.p}.")
        if p % 2 != 0:
            raise ValueError(f"antipodal quotient requires even p. Got p={p}.")
        if self.base not in ("geodesic", "chordal"):
            raise ValueError(f"base must be 'geodesic' or 'chordal'. Got {self.base}.")

        v = np.asarray(self.v_axis, dtype=float).reshape(3,)
        v = v / max(np.linalg.norm(v), self.eps)

        ua = None
        if self.u_axis is not None:
            ua = np.asarray(self.u_axis, dtype=float).reshape(3,)
            ua = ua / max(np.linalg.norm(ua), self.eps)

        # The dataclass is frozen, so normalized values are stored via object.__setattr__.
        object.__setattr__(self, "p", p)
        object.__setattr__(self, "v_axis", v)
        object.__setattr__(self, "u_axis", ua)

    def _key(self) -> tuple:
        return (
            self.p,
            _array_key(self.v_axis),
            self.base,
            self.name,
            self.eps,
            _array_key(self.u_axis),
        )

    def __eq__(self, other: object) -> bool:
        if other.__class__ is not self.__class__:
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self) -> int:
        return hash(self._key())

    def pairwise(self, X: np.ndarray, Y: Optional[np.ndarray] = None) -> np.ndarray:
        """
        Return the ``(n, m)`` matrix of quotient distances between quaternion rows.

        For each pair, the minimum is taken over the Z_p orbit of ``y`` and the
        Z_p orbit of ``tau(y) = y * u``. ``X`` must be ``(n, 4)`` and ``Y``
        (default ``X``) ``(m, 4)``; raises ``ValueError`` otherwise.
        """
        X, Y0 = _as_quat_clouds(X, Y)

        Xn = _safe_normalize_rows(X, eps=self.eps)
        Yn = _safe_normalize_rows(Y0, eps=self.eps)

        dist_fn = _s3_geodesic_pairwise if self.base == "geodesic" else _s3_chordal_pairwise
        gs = _zp_elements(self.v_axis, self.p, eps=self.eps)

        if self.u_axis is None:
            u_axis = _pick_u_perp_v(self.v_axis, eps=self.eps)
        else:
            u_axis = self.u_axis
        u_axis = u_axis / max(np.linalg.norm(u_axis), self.eps)

        Ytau = _right_mul_by_u(Yn, u_axis)

        return np.minimum(
            _min_over_zp(Xn, Yn, gs, dist_fn, eps=self.eps),
            _min_over_zp(Xn, Ytau, gs, dist_fn, eps=self.eps),
        )
def S3QuotientMetric(
    p: int,
    *,
    v_axis: Optional[np.ndarray] = None,
    antipodal: bool = False,
    base: str = "geodesic",
    u_axis: Optional[np.ndarray] = None,
    name: Optional[str] = None,
) -> "Metric":
    """
    Factory for quotient metrics on S^3 (points given as quaternions).

    Parameters
    ----------
    p :
        Order of the cyclic group; must be positive (and even when
        ``antipodal=True``).
    v_axis :
        Length-3 axis of the cyclic generator. ``None`` (default) means
        ``(1, 0, 0)``.
    antipodal :
        If ``False`` (default), return a ``ZpHopfQuotientMetricS3``.
        If ``True``, return a ``Z2LensAntipodalQuotientMetricS3``.
    base :
        ``"geodesic"`` (default) or ``"chordal"``.
    u_axis :
        Optional axis forwarded to the antipodal metric; ignored otherwise.
    name :
        Display name. When omitted (or empty), ``"S3/Z_{p}"`` or
        ``"S3/(Z_{p}⋊Z2)_antipodal"`` is used.

    Raises
    ------
    ValueError
        If ``p`` is not positive, ``base`` is not recognised, or
        ``antipodal=True`` with odd ``p``.
    """
    p = int(p)
    if p <= 0:
        raise ValueError(f"p must be positive. Got {p}.")
    if base not in ("geodesic", "chordal"):
        raise ValueError(f"base must be 'geodesic' or 'chordal'. Got {base}.")

    # A None sentinel instead of an ndarray default: the array would be built
    # once at definition time and shared by every call.
    if v_axis is None:
        v_axis = _default_v_axis()
    v_axis = np.asarray(v_axis, dtype=float).reshape(3,)

    if antipodal:
        if p % 2 != 0:
            raise ValueError(f"antipodal=True requires even p. Got p={p}.")
        return Z2LensAntipodalQuotientMetricS3(
            p=p,
            v_axis=v_axis,
            base=base,
            u_axis=u_axis,
            name=name or f"S3/(Z_{p}⋊Z2)_antipodal",
        )

    return ZpHopfQuotientMetricS3(
        p=p,
        v_axis=v_axis,
        base=base,
        name=name or f"S3/Z_{p}",
    )
# ============================================================
# Old scalar metric functions (kept for compatibility)
# ============================================================

def _wrapped_angle_diff(p, q, period=2.0 * np.pi):
    """
    Componentwise circular difference ``min(d, period - d)`` with
    ``d = |p - q| mod period``.

    Reducing modulo the period first keeps the result non-negative even when
    the inputs are not confined to a single period.
    """
    d = np.abs(np.asarray(p) - np.asarray(q)) % period
    return np.minimum(d, period - d)


def S1_dist(theta1, theta2):
    """Circular distance between angles (radians, period 2*pi); works elementwise."""
    return _wrapped_angle_diff(theta2, theta1)


def RP1_dist(theta1, theta2):
    """Circular distance between angles with period pi; works elementwise."""
    return _wrapped_angle_diff(theta2, theta1, period=np.pi)


def S1_dist2(p, q):
    """Angle ``arccos(p . q)`` between two vectors, with the dot product clipped to [-1, 1]."""
    return np.arccos(np.clip(np.dot(p, q), -1.0, 1.0))


def RP1_dist2(p, q):
    """Like :func:`S1_dist2`, but identifying ``q`` with ``-q``: ``min(ang, pi - ang)``."""
    ang = np.arccos(np.clip(np.dot(p, q), -1.0, 1.0))
    return np.minimum(ang, np.pi - ang)


def Euc_met(p, q):
    """Euclidean distance ``||p - q||``."""
    return np.linalg.norm(np.asarray(p) - np.asarray(q))


def RP2_dist(p, q):
    """Distance identifying ``q`` with ``-q``: ``min(||p - q||, ||p + q||)``."""
    p = np.asarray(p)
    q = np.asarray(q)
    return min(np.linalg.norm(p - q), np.linalg.norm(p + q))


def T2_dist(p, q):
    """Flat torus distance on angle coordinates (each coordinate has period 2*pi)."""
    return np.linalg.norm(_wrapped_angle_diff(p, q))


def KB_flat_dist(p, q):
    """
    Scalar Klein bottle flat distance on angle coords p=(base,fiber), q=(base,fiber),
    using (b,f)~(b+pi,-f).
    """
    p = np.asarray(p, dtype=float).reshape(2,)
    q = np.asarray(q, dtype=float).reshape(2,)
    # Image of q under the identification, reduced back into [0, 2*pi).
    qg = np.array([(q[0] + np.pi) % (2.0 * np.pi), (-q[1]) % (2.0 * np.pi)], dtype=float)
    return min(float(T2_dist(p, q)), float(T2_dist(p, qg)))


def T2_diag_flat_dist(p, q):
    """
    Scalar flat distance for the Z2 quotient on (base,fiber):
      (b,f) ~ (b+pi, f+pi)
    """
    p = np.asarray(p, dtype=float).reshape(2,)
    q = np.asarray(q, dtype=float).reshape(2,)
    # Image of q under the identification, reduced back into [0, 2*pi).
    qg = np.array([(q[0] + np.pi) % (2.0 * np.pi), (q[1] + np.pi) % (2.0 * np.pi)], dtype=float)
    return min(float(T2_dist(p, q)), float(T2_dist(p, qg)))


# ============================================================
# Distance matrices helper (backwards compatible)
# ============================================================

def get_dist_mat(data1, data2=None, metric=Euc_met):
    """
    Backwards-compatible distance matrix helper.

    Parameters
    ----------
    data1 :
        First point collection (converted with ``np.asarray``).
    data2 :
        Optional second collection. When omitted, ``None`` is passed as the
        second argument of ``pairwise``.
    metric :
        One of:

        - one of the old scalar functions (Euc_met, S1_dist, RP1_dist, ...),
          matched by identity and replaced by the corresponding vectorized
          metric class
        - a new Metric object with .pairwise
        - an arbitrary callable(p,q) (requires SciPy for fallback)

        Anything that is not one of the old scalar functions is passed
        through ``as_metric``.

    Returns
    -------
    The result of ``pairwise`` on the selected metric object.
    """
    X = np.asarray(data1)
    Y = None if data2 is None else np.asarray(data2)

    # Identity checks (not a dict lookup): ``metric`` may be an unhashable
    # object, and each class name is only resolved on its own branch.
    if metric is Euc_met:
        M = EuclideanMetric()
    elif metric is S1_dist:
        M = S1AngleMetric()
    elif metric is RP1_dist:
        M = RP1AngleMetric()
    elif metric is S1_dist2:
        M = S1UnitVectorMetric()
    elif metric is RP1_dist2:
        M = RP1UnitVectorMetric()
    elif metric is RP2_dist:
        M = RP2UnitVectorMetric()
    elif metric is T2_dist:
        M = T2FlatMetric()
    elif metric is KB_flat_dist:
        M = KleinBottleFlatMetric()
    elif metric is T2_diag_flat_dist:
        M = TorusDiagFlatMetric()
    else:
        M = as_metric(metric)

    return M.pairwise(X, Y)