# synthetic/s2_bundles.py
from __future__ import annotations
from typing import Optional, Tuple
import numpy as np
# Public API of this module: the names exported by a star-import.
# Underscore-prefixed helpers (_get_rng, _safe_normalize) are intentionally omitted.
__all__ = [
    "sample_sphere",
    "hopf_projection",
    "spin3_adjoint_to_so3",
    "so3_to_s2_projection",
    "sample_s2_trivial",
    "tangent_frame_on_s2",
    "sample_s2_unit_tangent",
]
# ----------------------------
# RNG + numerics helpers
# ----------------------------
def _get_rng(rng: Optional[np.random.Generator]) -> np.random.Generator:
    """Return `rng` as given, or a fresh `np.random.default_rng()` when it is None."""
    if rng is not None:
        return rng
    return np.random.default_rng()
def _safe_normalize(x: np.ndarray, *, axis: int = -1, eps: float = 1e-12) -> np.ndarray:
    """
    Scale `x` to unit Euclidean length along `axis`.

    The input is converted to a float array. Norms are floored at `eps` before
    dividing, so (near-)zero vectors are returned (almost) unchanged in
    direction instead of producing a division by zero.
    """
    arr = np.asarray(x, dtype=float)
    lengths = np.linalg.norm(arr, axis=axis, keepdims=True)
    return arr / np.maximum(lengths, float(eps))
# ----------------------------
# Sphere sampling
# ----------------------------
def sample_sphere(
    n: int,
    dim: int = 2,
    *,
    rng: Optional[np.random.Generator] = None,
) -> np.ndarray:
    """
    Draw n points on S^{dim} ⊂ R^{dim+1} by normalizing standard Gaussian vectors.

    Parameters
    ----------
    n : int
        Number of points; must be positive (coerced with ``int``).
    dim : int
        Intrinsic dimension of the sphere; must be >= 0 (coerced with ``int``).
    rng : numpy.random.Generator, optional
        Source of randomness. A fresh default generator is used when None.

    Returns
    -------
    (n, dim+1) float array whose rows are normalized to unit length.

    Raises
    ------
    ValueError
        If n <= 0 or dim < 0.

    Examples
    --------
    dim=2 -> S^2 in R^3, output shape (n,3)
    dim=3 -> S^3 in R^4, output shape (n,4)
    """
    n, dim = int(n), int(dim)
    if n <= 0:
        raise ValueError(f"n must be positive. Got {n}.")
    if dim < 0:
        raise ValueError(f"dim must be >= 0. Got {dim}.")

    # An isotropic Gaussian is rotation invariant, so its direction is uniform.
    gaussian = _get_rng(rng).normal(size=(n, dim + 1))
    return _safe_normalize(gaussian, axis=1)
# ----------------------------
# Hopf / Spin(3) / SO(3) helpers
# ----------------------------
def hopf_projection(
    data: np.ndarray,
    *,
    v: Optional[np.ndarray] = None,
    eps: float = 1e-12,
) -> np.ndarray:
    """
    Generalized Hopf projection q ↦ q v q^{-1}, where v ∈ S^2 ⊂ Im(H).

    Parameters
    ----------
    data : (n,4) real or (n,2) complex
        Quaternion coordinates: q = (a,b,c,d) with z1=a+ib, z2=c+id.
        Rows are rescaled to unit norm before use.
    v : (3,) array-like, optional
        Axis vector in R^3. Will be normalized. Default is e1 = (1,0,0).
    eps : float
        Floor applied to every norm used as a divisor.

    Returns
    -------
    (n,3) array on S^2.

    Raises
    ------
    ValueError
        If `data` is not 2D, or its second dimension is not 4 (real input)
        or 2 (complex input).
    """
    data = np.asarray(data)
    if data.ndim != 2:
        raise ValueError(f"data must be 2D. Got shape {data.shape}.")

    # Resolve and normalize the axis that the quaternions act on.
    axis_vec = (
        np.array([1.0, 0.0, 0.0], dtype=float)
        if v is None
        else np.asarray(v, dtype=float).reshape(3,)
    )
    tol = float(eps)
    axis_vec = axis_vec / max(np.linalg.norm(axis_vec), tol)

    # Split each row into the four real quaternion components.
    if np.iscomplexobj(data):
        if data.shape[1] != 2:
            raise ValueError(f"complex data must have shape (n,2). Got {data.shape}.")
        first, second = data[:, 0], data[:, 1]
        a, b, c, d = first.real, first.imag, second.real, second.imag
    else:
        if data.shape[1] != 4:
            raise ValueError(f"real data must have shape (n,4). Got {data.shape}.")
        a, b, c, d = data.T  # rows of the transpose are the columns of data

    # Project each quaternion onto S^3.
    scale = np.maximum(np.sqrt(a * a + b * b + c * c + d * d), tol)
    a, b, c, d = (comp / scale for comp in (a, b, c, d))

    # Rotation by a unit quaternion q = (a, u), u = (b, c, d):
    #     v' = v + 2a (u × v) + 2 u × (u × v)
    imag = np.stack([b, c, d], axis=1)  # (n,3)
    target = axis_vec[None, :]  # (1,3), broadcasts against imag
    cross_once = np.cross(imag, target)
    cross_twice = np.cross(imag, cross_once)
    rotated = target + 2.0 * a[:, None] * cross_once + 2.0 * cross_twice

    # Renormalize to remove accumulated round-off.
    lengths = np.linalg.norm(rotated, axis=1, keepdims=True)
    return rotated / np.maximum(lengths, tol)
def spin3_adjoint_to_so3(data: np.ndarray, *, eps: float = 1e-12) -> np.ndarray:
    """
    Adjoint map Spin(3) ≅ S^3 (unit quaternions) -> SO(3), vectorized.

    Each row is read as a quaternion q = a + b i + c j + d k, rescaled to unit
    norm, and converted to the matrix of the rotation v ↦ q v q^{-1}.

    Parameters
    ----------
    data : (n,4) real or (n,2) complex
        Input formats (same as hopf_projection):
        - real (n,4): [a,b,c,d] = [Re z1, Im z1, Re z2, Im z2]
        - complex (n,2): [z1,z2] where z1=a+ib, z2=c+id
    eps : float
        Floor applied to the quaternion norm before dividing. An all-zero row
        therefore yields the zero matrix rather than a division by zero.

    Returns
    -------
    R_flat : (n,9) float array, row-major flattening of 3x3 matrices.

    Raises
    ------
    ValueError
        If `data` is not 2D, or its second dimension is not 4 (real input)
        or 2 (complex input).
    """
    data = np.asarray(data)
    if data.ndim != 2:
        raise ValueError(f"data must be 2D. Got shape {data.shape}.")

    # Split each row into the four real quaternion components.
    if np.iscomplexobj(data):
        if data.shape[1] != 2:
            raise ValueError(f"complex data must have shape (n,2). Got {data.shape}.")
        z1, z2 = data[:, 0], data[:, 1]
        a, b, c, d = z1.real, z1.imag, z2.real, z2.imag
    else:
        if data.shape[1] != 4:
            raise ValueError(f"real data must have shape (n,4). Got {data.shape}.")
        a, b, c, d = data[:, 0], data[:, 1], data[:, 2], data[:, 3]

    # Project onto S^3 so the formulas below give a genuine rotation.
    nrm = np.maximum(np.sqrt(a * a + b * b + c * c + d * d), float(eps))
    a, b, c, d = a / nrm, b / nrm, c / nrm, d / nrm

    aa, bb, cc, dd = a * a, b * b, c * c, d * d
    ab, ac, ad = a * b, a * c, a * d
    bc, bd, cd = b * c, b * d, c * d

    # Row-major order: [r00 r01 r02 r10 r11 r12 r20 r21 r22]
    entries = (
        aa + bb - cc - dd, 2.0 * (bc - ad), 2.0 * (bd + ac),
        2.0 * (bc + ad), aa - bb + cc - dd, 2.0 * (cd - ab),
        2.0 * (bd - ac), 2.0 * (cd + ab), aa - bb - cc + dd,
    )
    # Always hand back float64, whatever the input precision was.
    return np.stack(entries, axis=1).astype(float, copy=False)
def so3_to_s2_projection(
    R: np.ndarray,
    *,
    v: Optional[np.ndarray] = None,
    eps: float = 1e-12,
) -> np.ndarray:
    """
    Projection SO(3) -> S^2 defined by R ↦ R v, with v ∈ S^2.

    Parameters
    ----------
    R : array-like
        One matrix or a batch of matrices, in any of these layouts:
        - (3,3)
        - (n,3,3)
        - (9,) row-major flatten
        - (n,9) row-major flatten
    v : (3,) array-like, optional
        Vector to rotate; it is normalized first. Default is e1 = (1,0,0).
    eps : float
        Floor applied to the norm of `v` before dividing.

    Returns
    -------
    (3,) array for a single matrix, (n,3) array for a batch.

    Raises
    ------
    ValueError
        If the shape of `R` is none of the accepted layouts.
    """
    R = np.asarray(R, dtype=float)

    axis_vec = (
        np.array([1.0, 0.0, 0.0], dtype=float)
        if v is None
        else np.asarray(v, dtype=float).reshape(3,)
    )
    axis_vec = axis_vec / max(np.linalg.norm(axis_vec), float(eps))

    shape = R.shape
    # Single matrix, given either as 3x3 or as its row-major flattening.
    if shape in ((3, 3), (9,)):
        return R.reshape(3, 3) @ axis_vec

    # Batch of matrices, given either as (n,3,3) or as (n,9).
    is_stack = R.ndim == 3 and shape[1:] == (3, 3)
    is_flat_stack = R.ndim == 2 and shape[1] == 9
    if is_stack or is_flat_stack:
        return np.einsum("nij,j->ni", R.reshape(-1, 3, 3), axis_vec)

    raise ValueError(f"Expected (3,3), (n,3,3), (9,), or (n,9). Got {R.shape}.")
# ----------------------------
# S^2 bundles / embeddings
# ----------------------------
def sample_s2_trivial(
    n_points: int,
    *,
    sigma: float = 0.0,
    rng: Optional[np.random.Generator] = None,
    radius_mean: float = 1.0,
    radius_clip: Tuple[float, float] = (0.0, 5.0),
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Product bundle S^2 × S^1 embedded as (base ∈ R^3, fiber ∈ R^2) in R^5.

    Parameters
    ----------
    n_points : int
        Number of samples; must be positive.
    sigma : float
        Standard deviation of the Gaussian noise on the fiber radius.
    rng : numpy.random.Generator, optional
        Source of randomness. A fresh default generator is used when None.
    radius_mean : float
        Mean fiber radius.
    radius_clip : (low, high)
        Interval the sampled radii are clipped to.

    Returns
    -------
    data : (n_points, 5) = [base_x, base_y, base_z, fiber_u, fiber_v]
    base_points : (n_points, 3) points on S^2
    angles : (n_points,) fiber angles in radians, drawn from [0, 2π)

    Raises
    ------
    ValueError
        If n_points <= 0.
    """
    n_points = int(n_points)
    if n_points <= 0:
        raise ValueError(f"n_points must be positive. Got {n_points}.")

    gen = _get_rng(rng)
    # Draw order (base, angle, radius) fixes the random stream for a given seed.
    base_points = sample_sphere(n_points, dim=2, rng=gen)
    angles = 2.0 * np.pi * gen.random(n_points)
    radii = np.clip(
        gen.normal(loc=float(radius_mean), scale=float(sigma), size=n_points),
        float(radius_clip[0]),
        float(radius_clip[1]),
    )

    # Polar -> Cartesian on the circle fiber, appended to the base coordinates.
    fiber_u = radii * np.cos(angles)
    fiber_v = radii * np.sin(angles)
    data = np.column_stack([base_points, fiber_u, fiber_v])
    return data, base_points, angles
def tangent_frame_on_s2(p: np.ndarray, *, eps: float = 1e-12) -> Tuple[np.ndarray, np.ndarray]:
    """
    Given p on S^2, return a positively oriented orthonormal basis (e1,e2) for T_p S^2.

    Convention: e2 = p × e1.

    Away from the poles e1 is the unit vector in the direction of increasing
    azimuth, (-sin φ, cos φ, 0) with φ = atan2(y, x). When the xy-radius of p
    is at most `eps`, e1 is instead the x-axis projected onto the tangent
    plane at p, which gives a consistent fallback near the poles.

    Parameters
    ----------
    p : (3,) array-like
        Point on the sphere; it is normalized before use.
    eps : float
        Threshold on the xy-radius below which the pole fallback is used.

    Returns
    -------
    e1, e2 : two (3,) arrays.
    """
    point = _safe_normalize(np.asarray(p, dtype=float).reshape(3,), axis=0)
    px, py = point[0], point[1]

    if np.hypot(px, py) <= float(eps):
        # Pole: project the x-axis onto the tangent plane.
        x_axis = np.array([1.0, 0.0, 0.0], dtype=float)
        e1 = x_axis - np.dot(x_axis, point) * point
    else:
        azimuth = np.arctan2(py, px)
        e1 = np.array([-np.sin(azimuth), np.cos(azimuth), 0.0], dtype=float)
    e1 = _safe_normalize(e1, axis=0)

    e2 = _safe_normalize(np.cross(point, e1), axis=0)
    return e1, e2
def sample_s2_unit_tangent(
    n_points: int,
    *,
    sigma: float = 0.0,
    equator: bool = False,
    rng: Optional[np.random.Generator] = None,
    radius_mean: float = 1.0,
    radius_clip: Tuple[float, float] = (0.0, 5.0),
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Sample the (scaled) unit tangent bundle of S^2 as (tangent ∈ R^3, base ∈ R^3) in R^6.

    Each tangent vector is r (cos θ e1 + sin θ e2), where (e1, e2) is the
    tangent frame at the base point with e1 along increasing azimuth and
    e2 = p × e1, θ is uniform in [0, 2π), and r is a Gaussian radius clipped
    to `radius_clip`.

    Parameters
    ----------
    n_points : int
        Number of samples; must be positive.
    sigma : float
        Standard deviation of the Gaussian noise on the tangent length.
    equator : bool
        If True, restrict base points to the equator z=0.
    rng : numpy.random.Generator, optional
        Source of randomness. A fresh default generator is used when None.
    radius_mean : float
        Mean tangent length.
    radius_clip : (low, high)
        Interval the sampled lengths are clipped to.

    Returns
    -------
    data : (n_points, 6) = [tangent_x, tangent_y, tangent_z, base_x, base_y, base_z]
    base_points : (n_points, 3) points on S^2
    angles : (n_points,) angles θ of the tangent vectors within their frames

    Raises
    ------
    ValueError
        If n_points <= 0.
    """
    n_points = int(n_points)
    if n_points <= 0:
        raise ValueError(f"n_points must be positive. Got {n_points}.")

    gen = _get_rng(rng)
    if equator:
        longitude = 2.0 * np.pi * gen.random(n_points)
        base_points = np.column_stack(
            [np.cos(longitude), np.sin(longitude), np.zeros(n_points)]
        )
    else:
        base_points = sample_sphere(n_points, dim=2, rng=gen)

    angles = 2.0 * np.pi * gen.random(n_points)
    radii = np.clip(
        gen.normal(loc=float(radius_mean), scale=float(sigma), size=n_points),
        float(radius_clip[0]),
        float(radius_clip[1]),
    )

    # First frame vector: azimuthal direction (-y, x, 0), normalized.
    px, py = base_points[:, 0], base_points[:, 1]
    e1 = _safe_normalize(np.column_stack([-py, px, np.zeros(n_points)]), axis=1)

    # At the poles the azimuthal direction degenerates; use the x-axis projected
    # onto the tangent plane instead.
    at_pole = np.hypot(px, py) <= 1e-12
    if np.any(at_pole):
        poles = base_points[at_pole]
        x_axis = np.tile(np.array([1.0, 0.0, 0.0]), (poles.shape[0], 1))
        projected = x_axis - (np.sum(x_axis * poles, axis=1, keepdims=True) * poles)
        e1[at_pole] = _safe_normalize(projected, axis=1)

    # Second frame vector completes the positively oriented frame.
    e2 = _safe_normalize(np.cross(base_points, e1), axis=1)

    cos_t = np.cos(angles)[:, None]
    sin_t = np.sin(angles)[:, None]
    tangent = radii[:, None] * (cos_t * e1 + sin_t * e2)

    data = np.column_stack([tangent, base_points])
    return data, base_points, angles