Source code for circle_bundles.viz.angles

from __future__ import annotations

from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np


__all__ = [
    "set_pi_ticks",
    "fit_o2_on_circle",
    "align_angles_to",
    "compare_angle_pairs",
    "compare_trivs",
]


# ----------------------------
# Tick helpers
# ----------------------------

def set_pi_ticks(ax, fontsize: int = 12) -> None:
    ticks = [0, np.pi / 2, np.pi, 3 * np.pi / 2, 2 * np.pi]
    labels = [r"$0$", r"$\frac{\pi}{2}$", r"$\pi$", r"$\frac{3\pi}{2}$", r"$2\pi$"]
    ax.set_xticks(ticks)
    ax.set_xticklabels(labels, fontsize=fontsize)
    ax.set_yticks(ticks)
    ax.set_yticklabels(labels, fontsize=fontsize)


def _set_angle_ticks(ax, *, which: str, ticks: List[float], labels: List[str], fontsize: int) -> None:
    if which == "x":
        ax.set_xticks(ticks)
        ax.set_xticklabels(labels, fontsize=fontsize)
    elif which == "y":
        ax.set_yticks(ticks)
        ax.set_yticklabels(labels, fontsize=fontsize)
    else:
        raise ValueError("which must be 'x' or 'y'.")


def _pi_ticks_0_to_2pi() -> Tuple[List[float], List[str]]:
    ticks = [0, np.pi / 2, np.pi, 3 * np.pi / 2, 2 * np.pi]
    labels = [r"$0$", r"$\frac{\pi}{2}$", r"$\pi$", r"$\frac{3\pi}{2}$", r"$2\pi$"]
    return ticks, labels


def _pi_ticks_0_to_pi() -> Tuple[List[float], List[str]]:
    ticks = [0, np.pi / 2, np.pi]
    labels = [r"$0$", r"$\frac{\pi}{2}$", r"$\pi$"]
    return ticks, labels


# ----------------------------
# O(2) alignment on S^1
# ----------------------------

def fit_o2_on_circle(angles_ref: np.ndarray, angles_mov: np.ndarray) -> Tuple[np.ndarray, float, float]:
    """
    Fit A in O(2) so that (cos,sin)(angles_mov) @ A ≈ (cos,sin)(angles_ref).

    Returns:
      A : (2,2) orthogonal matrix
      mean_err : mean Euclidean error in R^2 after alignment
      rms_err  : RMS Euclidean error in R^2 after alignment
    """
    a1 = np.asarray(angles_ref).reshape(-1)
    a2 = np.asarray(angles_mov).reshape(-1)
    if a1.shape != a2.shape:
        raise ValueError("angles_ref and angles_mov must have same shape.")
    if a1.size == 0:
        raise ValueError("Empty angle arrays.")

    X = np.c_[np.cos(a2), np.sin(a2)]  # moving
    Y = np.c_[np.cos(a1), np.sin(a1)]  # reference

    M = X.T @ Y
    U, _, Vt = np.linalg.svd(M)
    A = U @ Vt  # in O(2)

    R = X @ A
    diff = R - Y
    per_pt = np.linalg.norm(diff, axis=1)
    mean_err = float(np.mean(per_pt))
    rms_err = float(np.sqrt(np.mean(per_pt**2)))
    return A, mean_err, rms_err


def align_angles_to(angles_ref: np.ndarray, angles_mov: np.ndarray) -> np.ndarray:
    """Return angles_mov after optimal O(2) alignment to angles_ref (mod 2π)."""
    A, _, _ = fit_o2_on_circle(angles_ref, angles_mov)
    X = np.c_[np.cos(angles_mov), np.sin(angles_mov)]
    X2 = X @ A
    return np.mod(np.arctan2(X2[:, 1], X2[:, 0]), 2 * np.pi)


# ----------------------------
# Plot helpers
# ----------------------------

def compare_angle_pairs(
    angle_arrays: List[np.ndarray],
    pairs: List[Tuple[int, int]],
    *,
    labels: Optional[List[str]] = None,
    align: bool = False,
    s: float = 1.0,
    fontsize: int = 14,
    ncols: int | str = "auto",
    titles: Optional[List[str]] = None,
    titlesize: int = 14,
    show_metrics: bool = True,
    metric: str = "mean",  # "mean" or "rms"
    x_range: Tuple[float, float] = (0.0, 2 * np.pi),
    y_range: Tuple[float, float] = (0.0, 2 * np.pi),
    x_ticks: Tuple[List[float], List[str]] | None | str = "auto",
    y_ticks: Tuple[List[float], List[str]] | None | str = "auto",
):
    """
    Scatter plots angle_arrays[i] vs angle_arrays[j] for each (i,j) in pairs.

    If align=True, aligns the *second* array in each pair to the first using O(2).

    Tick control:
      - "auto" (default): choose nice pi-ticks for [0,π] or [0,2π] if matched.
      - None: leave matplotlib defaults.
      - (ticks, labels): explicit.
    """
    # Lazy import so importing viz module doesn't require a GUI backend.
    import matplotlib.pyplot as plt

    if labels is None:
        labels = [f"angle[{i}]" for i in range(len(angle_arrays))]

    if ncols == "auto":
        nrows = int(np.ceil(np.sqrt(len(pairs)))) if len(pairs) > 0 else 1
        ncols_i = int(np.ceil(len(pairs) / nrows)) if len(pairs) > 0 else 1
    else:
        ncols_i = int(ncols)
        nrows = int(np.ceil(len(pairs) / ncols_i)) if len(pairs) > 0 else 1

    fig, axes = plt.subplots(nrows=nrows, ncols=ncols_i, figsize=(5 * ncols_i, 5 * nrows))
    axes = np.array(axes).reshape(-1)

    def resolve_ticks(ticks_setting, rng):
        if ticks_setting is None:
            return None
        if ticks_setting != "auto":
            return ticks_setting  # explicit (ticks, labels)

        lo, hi = float(rng[0]), float(rng[1])
        if np.isclose(lo, 0.0) and np.isclose(hi, 2 * np.pi):
            return _pi_ticks_0_to_2pi()
        if np.isclose(lo, 0.0) and np.isclose(hi, np.pi):
            return _pi_ticks_0_to_pi()
        return None

    xt = resolve_ticks(x_ticks, x_range)
    yt = resolve_ticks(y_ticks, y_range)

    for t, (i, j) in enumerate(pairs):
        a = np.asarray(angle_arrays[i]).reshape(-1)
        b = np.asarray(angle_arrays[j]).reshape(-1)

        if a.shape != b.shape:
            raise ValueError(f"Pair {t}: shapes differ: {a.shape} vs {b.shape}")

        _, mean_err, rms_err = fit_o2_on_circle(a, b)
        b_plot = align_angles_to(a, b) if align else b

        ax = axes[t]
        ax.scatter(a, b_plot, s=float(s), alpha=0.7)

        ax.set_xlim(*x_range)
        ax.set_ylim(*y_range)

        if xt is not None:
            ticks, ticklabels = xt
            _set_angle_ticks(ax, which="x", ticks=list(ticks), labels=list(ticklabels), fontsize=int(fontsize))
        if yt is not None:
            ticks, ticklabels = yt
            _set_angle_ticks(ax, which="y", ticks=list(ticks), labels=list(ticklabels), fontsize=int(fontsize))

        ax.set_xlabel(labels[i], fontsize=int(fontsize))
        ax.set_ylabel(labels[j] + (" (aligned)" if align else ""), fontsize=int(fontsize))

        if titles is not None:
            title = titles[t]
        else:
            if show_metrics:
                val = mean_err if metric == "mean" else rms_err
                title = f"{metric} err (circle): {val:.2g}"
            else:
                title = ""
        ax.set_title(title, fontsize=int(titlesize))

        ax.grid(True, linestyle="--", alpha=0.4)
        ax.set_aspect("auto")
        ax.set_box_aspect(1)

    for k in range(len(pairs), len(axes)):
        axes[k].axis("off")

    plt.tight_layout()
    return fig


def _select_edges_by_error(
    U: np.ndarray,
    f: np.ndarray,
    edges: List[Tuple[int, int]],
    *,
    metric: str = "mean",
    max_pairs: int = 25,
) -> Tuple[List[Tuple[int, int]], Dict[Tuple[int, int], float], Dict[Tuple[int, int], int]]:
    """
    Decide which edges to plot.

    Returns:
      selected_edges
      err_by_edge: edge -> error
      overlap_by_edge: edge -> |Uj ∩ Uk|
    """
    if metric not in ("mean", "rms"):
        raise ValueError("metric must be 'mean' or 'rms'.")

    err_by_edge: Dict[Tuple[int, int], float] = {}
    ov_by_edge: Dict[Tuple[int, int], int] = {}

    for (j, k) in edges:
        j, k = int(j), int(k)
        mask = U[j] & U[k]
        ov = int(mask.sum())
        if ov == 0:
            continue

        a = f[j, mask]
        b = f[k, mask]
        _, mean_err, rms_err = fit_o2_on_circle(a, b)
        err = float(mean_err if metric == "mean" else rms_err)

        e = (j, k)
        err_by_edge[e] = err
        ov_by_edge[e] = ov

    candidates = list(err_by_edge.keys())
    if not candidates:
        return [], err_by_edge, ov_by_edge

    if len(candidates) <= int(max_pairs):
        selected = [e for e in edges if e in err_by_edge]
        return selected, err_by_edge, ov_by_edge

    candidates_sorted = sorted(candidates, key=lambda e: err_by_edge[e])
    best = candidates_sorted[0]
    worst = candidates_sorted[-1]
    median = candidates_sorted[len(candidates_sorted) // 2]

    selected: List[Tuple[int, int]] = []
    for e in (worst, median, best):  # worst→median→best
        if e not in selected:
            selected.append(e)

    return selected, err_by_edge, ov_by_edge

def compare_trivs_from_U(
    U: np.ndarray,
    f: np.ndarray,
    *,
    edges: Optional[List[Tuple[int, int]]] = None,
    ncols: int | str = "auto",
    title_size: int = 14,
    align: bool = False,
    s: float = 1.0,
    save_path: Optional[str] = None,
    show: bool = True,
    max_pairs: int = 25,
    metric: str = "mean",
    return_selected: bool = False,
):
    """
    Compare local trivializations on overlaps for each nerve edge (j,k), using cover-free inputs.

    New behavior:
      - If number of nonempty overlaps <= max_pairs: plot all.
      - Otherwise: plot WORST / MEDIAN / BEST (by chosen metric).
    """
    import matplotlib.pyplot as plt

    U = np.asarray(U, dtype=bool)
    if U.ndim != 2:
        raise ValueError("U must be 2D (n_sets, n_samples).")

    n_sets, n_samples = U.shape
    f = np.asarray(f)
    if f.shape != (n_sets, n_samples):
        raise ValueError(f"f must have shape {(n_sets, n_samples)}, got {f.shape}")

    # If edges not provided: infer all nonempty overlaps (same policy as cover.nerve_edges())
    if edges is None:
        edges = []
        for j in range(n_sets):
            Uj = U[j]
            for k in range(j + 1, n_sets):
                if np.any(Uj & U[k]):
                    edges.append((j, k))

    # normalize / cast
    edges = [(int(a), int(b)) for (a, b) in edges if int(a) != int(b)]
    edges = [(min(a, b), max(a, b)) for (a, b) in edges]

    selected_edges, err_by_edge, ov_by_edge = _select_edges_by_error(
        U,
        f,
        edges,
        metric=metric,
        max_pairs=int(max_pairs),
    )

    if not selected_edges:
        raise ValueError("No nonempty overlaps found on the provided edges.")

    # were we subsampled?
    n_nonempty = len([e for e in edges if (int(e[0]), int(e[1])) in err_by_edge])
    subsampled = (len(selected_edges) < n_nonempty)

    tag_by_edge: Dict[Tuple[int, int], str] = {}
    if subsampled and len(selected_edges) >= 2:
        sorted_edges = sorted(err_by_edge.keys(), key=lambda e: err_by_edge[e])
        best = sorted_edges[0]
        worst = sorted_edges[-1]
        median = sorted_edges[len(sorted_edges) // 2]
        tag_by_edge[best] = "BEST"
        tag_by_edge[worst] = "WORST"
        tag_by_edge[median] = "MEDIAN"
        ncols = 3

    angle_arrays: List[np.ndarray] = []
    labels: List[str] = []
    titles: List[str] = []
    pairs: List[Tuple[int, int]] = []

    for (j, k) in selected_edges:
        j, k = int(j), int(k)
        mask = U[j] & U[k]
        if not np.any(mask):
            continue

        a = f[j, mask]
        b = f[k, mask]

        angle_arrays.append(a)
        angle_arrays.append(b)
        labels.append(fr"$f_{{{j}}}$")
        labels.append(fr"$f_{{{k}}}$")

        ov = int(ov_by_edge[(j, k)])
        err = float(err_by_edge[(j, k)])

        tag = tag_by_edge.get((j, k), "")
        tag_str = f" [{tag}]" if tag else ""
        titles.append(
            tag_str
            + fr"  ($|U_{{{j}}}\cap U_{{{k}}}|={ov}$, {metric} err={err:.2g})"
        )

        pairs.append((2 * len(pairs), 2 * len(pairs) + 1))

    fig = compare_angle_pairs(
        angle_arrays,
        pairs,
        labels=labels,
        align=align,
        s=s,
        ncols=ncols,
        titles=titles,
        titlesize=title_size,
        metric=metric,
        show_metrics=False,
    )

    if save_path is not None:
        fig.savefig(save_path, bbox_inches="tight")

    if show:
        plt.show()
    else:
        plt.close(fig)

    if return_selected:
        return fig, selected_edges, err_by_edge

    return fig


[docs] def compare_trivs( cover, f: np.ndarray, *, edges: Optional[List[Tuple[int, int]]] = None, ncols: int | str = "auto", title_size: int = 14, align: bool = False, s: float = 1.0, save_path: Optional[str] = None, show: bool = True, max_pairs: int = 25, metric: str = "mean", return_selected: bool = False, ): """ Backwards-compatible wrapper: accepts `cover` and calls compare_trivs_from_U(U=cover.U,...). """ U = np.asarray(cover.U, dtype=bool) if edges is None: # preserve old behavior if cover provides nerve_edges() if hasattr(cover, "nerve_edges") and callable(getattr(cover, "nerve_edges")): edges = list(cover.nerve_edges()) else: edges = None # let compare_trivs_from_U infer return compare_trivs_from_U( U=U, f=f, edges=edges, ncols=ncols, title_size=title_size, align=align, s=s, save_path=save_path, show=show, max_pairs=max_pairs, metric=metric, return_selected=return_selected, )