# Source code for circle_bundles.viz.bundle_dash

# circle_bundles/viz/bundle_dash.py
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Tuple

import io
import socket
import numpy as np

# Public API of this module; the underscore-prefixed helpers below are internal.
__all__ = [
    "BundleVizInputs",
    "find_free_port",
    "prepare_bundle_viz_inputs",
    "prepare_bundle_viz_inputs_from_bundle",
    "make_bundle_app",
    "run_bundle_app",
    "show_bundle_vis",
    "save_bundle_snapshot",
]


# ----------------------------
# Data container
# ----------------------------

@dataclass
class BundleVizInputs:
    """
    Container for the (possibly downsampled) arrays the bundle viewer draws.

    Built by ``prepare_bundle_viz_inputs``. Every per-point array and mask is
    indexed over the same ``m`` retained points. When ``sample_inds`` is set,
    it gives each retained point's row in the original, full-size input.
    """

    base_points: np.ndarray                 # (m, d_base)
    data: np.ndarray                        # (m, d_data)
    dist_mat: np.ndarray                    # (m, m)
    colors: Optional[np.ndarray] = None     # (m,)
    densities: Optional[np.ndarray] = None  # (m,)

    # Total-space / fiber landmarks: list of (m,) bool masks over *downsampled* points
    data_landmark_masks: Optional[List[np.ndarray]] = None

    # Base landmarks:
    # - base_landmark_masks: list of (m,) bool masks over *downsampled* points (subset of base_points)
    # - base_landmark_points: list of (L_i, d_base) arrays (extra points not necessarily in dataset)
    base_landmark_masks: Optional[List[np.ndarray]] = None
    base_landmark_points: Optional[List[np.ndarray]] = None

    sample_inds: Optional[np.ndarray] = None           # (m,) indices into original


# ----------------------------
# Helpers
# ----------------------------

def find_free_port() -> int:
    """
    Return a TCP port number that was free at the moment of the call.

    The OS chooses the port (we bind to port 0). The probe socket is closed
    before returning, so another process may still claim the port first;
    treat the result as best-effort.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        probe.bind(("", 0))
        port = probe.getsockname()[1]
    return int(port)


def _embed_base_points_pca(base_points: np.ndarray) -> Tuple[np.ndarray, np.ndarray, Any]:
    """
    Embed base points into 3D with PCA.

    Returns
    -------
    emb : (n, 3) ndarray
        PCA coordinates. The array is zero-padded on the right when fewer than
        3 components can be fitted, i.e. base dim < 3 or fewer than 3 points.
    explained : ndarray
        Cumulative explained-variance ratio, one entry per fitted component (<= 3).
    pca : fitted ``sklearn.decomposition.PCA``
        Returned so extra landmark points can be mapped consistently via
        ``pca.transform``.

    Raises
    ------
    ImportError
        If scikit-learn is not installed (it is imported lazily, only here).
    ValueError
        If base_points is not 2D, or has zero columns or zero rows.
    """
    try:
        from sklearn.decomposition import PCA
    except ImportError as e:
        raise ImportError(
            "bundle_dash PCA embedding requires scikit-learn. Install with `pip install scikit-learn`."
        ) from e

    base_points = np.asarray(base_points)
    if base_points.ndim != 2:
        raise ValueError("base_points must be 2D (n_points, dim).")
    n_points = int(base_points.shape[0])
    d = int(base_points.shape[1])
    if d <= 0:
        raise ValueError("base_points has zero columns.")
    if n_points <= 0:
        raise ValueError("base_points has zero rows.")

    # scikit-learn rejects n_components > min(n_samples, n_features), so a
    # base with fewer than 3 points must ask for fewer components.
    pca = PCA(n_components=min(3, d, n_points))
    emb = pca.fit_transform(base_points.astype(float))
    # Already has exactly n_components (<= 3) entries.
    explained = np.cumsum(pca.explained_variance_ratio_)

    if emb.shape[1] < 3:
        emb = np.pad(emb, ((0, 0), (0, 3 - emb.shape[1])), mode="constant")
    return emb, explained, pca


def _normalize_to_unit_interval(vals: np.ndarray) -> np.ndarray:
    vals = np.asarray(vals, dtype=float).reshape(-1)
    vmin = float(np.min(vals))
    vmax = float(np.max(vals))
    if vmax <= vmin:
        return np.zeros_like(vals, dtype=float)
    return (vals - vmin) / (vmax - vmin)


def _normalize_bool_masks_any(inds_or_masks: Any, n: int, *, name: str) -> Optional[List[np.ndarray]]:
    """
    Accept:
      - None
      - ndarray (n,) bool-ish
      - ndarray (k,n) or (n,k) bool-ish
      - list/tuple of arrays (n,)
    Return: list of bool masks [(n,), ...] or None
    """
    if inds_or_masks is None:
        return None

    if isinstance(inds_or_masks, np.ndarray):
        arr = np.asarray(inds_or_masks)
        if arr.ndim == 1:
            if arr.shape[0] != n:
                raise ValueError(f"{name} length mismatch: expected {n}, got {arr.shape[0]}")
            return [arr.astype(bool)]
        if arr.ndim == 2:
            # allow (n,k) or (k,n); canonicalize to (k,n)
            if arr.shape[0] == n and arr.shape[1] != n:
                arr = arr.T
            if arr.shape[1] != n:
                raise ValueError(f"2D {name} must have one axis length {n}; got {arr.shape}")
            return [arr[i].astype(bool) for i in range(arr.shape[0])]
        raise ValueError(f"{name} ndarray must be 1D or 2D.")

    masks: List[np.ndarray] = []
    for m in inds_or_masks:
        mm = np.asarray(m).astype(bool)
        if mm.shape != (n,):
            raise ValueError(f"Each {name} mask must be shape ({n},), got {mm.shape}")
        masks.append(mm)
    return masks


def _subset_masks(masks: Optional[List[np.ndarray]], sample_inds: np.ndarray) -> Optional[List[np.ndarray]]:
    if masks is None:
        return None
    return [np.asarray(m, bool)[sample_inds] for m in masks]


def _normalize_base_landmarks(
    landmarks: Any,
    *,
    base_points: np.ndarray,
    name: str = "landmarks",
) -> Tuple[Optional[List[np.ndarray]], Optional[List[np.ndarray]]]:
    """
    Base-landmarks input normalization.

    Accepts:
      - None
      - points: ndarray (L, d_base)
      - mask:   ndarray (n,) bool
      - inds:   ndarray (L,) int
      - list/tuple of any mixture of the above (each element is one "group")

    Returns:
      (base_landmark_masks_full, base_landmark_points_full)
        - base_landmark_masks_full: list of (n,) bool masks selecting points in base_points
        - base_landmark_points_full: list of (L_i, d_base) arrays (extra points)
    """
    if landmarks is None:
        return None, None

    bp = np.asarray(base_points)
    if bp.ndim != 2:
        raise ValueError("base_points must be 2D.")
    n, d_base = int(bp.shape[0]), int(bp.shape[1])

    def _one(obj: Any) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        if obj is None:
            return None, None

        arr = np.asarray(obj)

        # bool mask
        if arr.dtype == bool and arr.ndim == 1:
            if arr.shape[0] != n:
                raise ValueError(f"{name} bool mask length mismatch: expected {n}, got {arr.shape[0]}")
            return arr.astype(bool), None

        # index array
        if np.issubdtype(arr.dtype, np.integer) and arr.ndim == 1:
            idx = arr.astype(int).reshape(-1)
            if idx.size == 0:
                return np.zeros((n,), dtype=bool), None
            if np.any(idx < 0) or np.any(idx >= n):
                raise ValueError(f"{name} index out of bounds for n={n}.")
            mask = np.zeros((n,), dtype=bool)
            mask[idx] = True
            return mask, None

        # points array
        if arr.ndim == 2 and int(arr.shape[1]) == d_base:
            pts = np.asarray(arr, dtype=float)
            return None, pts

        raise ValueError(
            f"{name} entries must be bool mask (n,), int indices (L,), or points (L,d_base={d_base}). "
            f"Got array with shape {arr.shape} dtype {arr.dtype}."
        )

    masks: List[np.ndarray] = []
    pts_groups: List[np.ndarray] = []

    if isinstance(landmarks, (list, tuple)):
        for g in landmarks:
            m, p = _one(g)
            if m is not None:
                masks.append(m.astype(bool))
            if p is not None:
                pts_groups.append(np.asarray(p, dtype=float))
    else:
        m, p = _one(landmarks)
        if m is not None:
            masks.append(m.astype(bool))
        if p is not None:
            pts_groups.append(np.asarray(p, dtype=float))

    return (masks if masks else None), (pts_groups if pts_groups else None)


def _parse_click_index(clickData: Any) -> Optional[int]:
    """
    Dash clickData typically looks like:
      {"points":[{"pointIndex": i, ...}]}  (sometimes pointNumber)
    """
    if not clickData or not isinstance(clickData, dict):
        return None
    pts = clickData.get("points", None)
    if not pts or not isinstance(pts, list):
        return None
    pt0 = pts[0] if pts else None
    if not isinstance(pt0, dict):
        return None
    idx = pt0.get("pointIndex", pt0.get("pointNumber", None))
    if idx is None:
        return None
    try:
        return int(idx)
    except Exception:
        return None


def _call_get_dist_mat(
    get_dist_mat: Optional[Callable[..., np.ndarray]],
    bp: np.ndarray,
    base_metric: Any,
) -> np.ndarray:
    """
    Priority:
      1) If base_metric has .pairwise, use base_metric.pairwise(bp) directly.
      2) Else if get_dist_mat provided, try get_dist_mat(bp, metric=base_metric) then fallback get_dist_mat(bp)
      3) Else fallback to circle_bundles.metrics.get_dist_mat(bp, metric=base_metric)
    """
    # 1) Metric object fast path
    if base_metric is not None and hasattr(base_metric, "pairwise"):
        return np.asarray(base_metric.pairwise(bp))

    # 2) User-provided callable
    if get_dist_mat is not None:
        try:
            return np.asarray(get_dist_mat(bp, metric=base_metric))
        except TypeError:
            return np.asarray(get_dist_mat(bp))

    # 3) Library fallback
    from ..metrics import get_dist_mat as _get_dist_mat
    return np.asarray(_get_dist_mat(bp, metric=base_metric))


def _fig_to_png_bytes(fig, *, scale: int = 1) -> bytes:
    """
    Convert a plotly figure to PNG bytes (requires kaleido).
    - scale=1 is much faster than scale=2 for 3D.
    """
    try:
        import plotly.io as pio
        # Avoid any attempt to fetch mathjax in headless chrome (can stall on some setups)
        pio.defaults.mathjax = None
        return fig.to_image(format="png", engine="kaleido", scale=int(scale))
    except Exception as e:
        msg = str(e)
        if "not compatible with this version of Kaleido" in msg or "compatible with this version of Kaleido" in msg:
            raise RuntimeError(
                "Plotly/Kaleido version mismatch.\n"
                "Fix by either:\n"
                "  pip install -U 'plotly>=6.1.1'   (recommended)\n"
                "or\n"
                "  pip install -U 'kaleido==0.2.1'  (for Plotly 5.x)\n"
            ) from e
        if "requires the kaleido package" in msg.lower():
            raise RuntimeError("Static export requires kaleido. Install with `pip install -U kaleido`.") from e
        raise


def _combine_pngs_to_pdf_bytes(
    png_left: bytes,
    png_right: bytes,
    *,
    gap: int = 24,   # pixels between panels
    pad: int = 0,    # outer padding
) -> bytes:
    """
    Combine two PNGs side-by-side into ONE tight PDF (no extra whitespace).
    Requires pillow: pip install pillow
    """
    try:
        from PIL import Image
    except ImportError as e:
        raise ImportError("Combining snapshots requires pillow. Install with `pip install pillow`.") from e

    L = Image.open(io.BytesIO(png_left)).convert("RGB")
    R = Image.open(io.BytesIO(png_right)).convert("RGB")

    W = L.width + int(gap) + R.width + 2 * int(pad)
    H = max(L.height, R.height) + 2 * int(pad)
    out = Image.new("RGB", (W, H), (255, 255, 255))

    yL = int(pad) + (H - 2 * int(pad) - L.height) // 2
    yR = int(pad) + (H - 2 * int(pad) - R.height) // 2
    out.paste(L, (int(pad), yL))
    out.paste(R, (int(pad) + L.width + int(gap), yR))

    buf = io.BytesIO()
    out.save(buf, format="PDF")  # tight to the composed image bounds
    return buf.getvalue()


# ----------------------------
# Prep (pure)
# ----------------------------

def prepare_bundle_viz_inputs(
    *,
    base_points: np.ndarray,
    data: np.ndarray,
    get_dist_mat: Optional[Callable[..., np.ndarray]] = None,
    full_dist_mat: Optional[np.ndarray] = None,
    base_metric: Any = None,
    same_metric: bool = False,
    max_samples: int = 10_000,
    colors: Optional[np.ndarray] = None,
    densities: Optional[np.ndarray] = None,
    data_landmark_inds: Any = None,
    landmarks: Any = None,
    rng: Optional[np.random.Generator] = None,
) -> BundleVizInputs:
    """
    Produce a downsampled view of base_points/data plus a distance matrix.

    Parameters
    ----------
    base_points : (n, d_base) array. A 1D array is treated as (n, 1).
    data : (n, d_data) array aligned row-for-row with base_points. A 1D array
        is treated as (n, 1), the same as base_points.
    get_dist_mat, base_metric : how the (m, m) base distance matrix is computed
        when full_dist_mat is not reused (see ``_call_get_dist_mat``).
    full_dist_mat, same_metric : precomputed (n, n) base distances, reused only
        when same_metric is True.
    max_samples : if n exceeds it, a random subset of exactly this many points
        (drawn without replacement, then sorted) is kept.
    colors, densities : optional length-n per-point values, subset together
        with the points.
    rng : numpy Generator used for downsampling (a fresh ``default_rng()`` if None).

    Landmarks
    ---------
    - data_landmark_inds: masks over points (n,) or list of such, used to highlight points in the *fiber* PCA panel.
    - landmarks: base-landmarks to highlight in the *base* panel.

    If full_dist_mat is provided and same_metric=True, it is subset via
    dist_mat = full_dist_mat[np.ix_(sample_inds, sample_inds)]. Without
    downsampling, that subset is the whole matrix.

    Raises
    ------
    ValueError on misaligned shapes or lengths, or when the distance matrix is not (m, m).
    """
    base_points = np.asarray(base_points)
    data = np.asarray(data)

    # 1D inputs are single-column arrays; handled symmetrically for base and data.
    if base_points.ndim == 1:
        base_points = base_points.reshape(-1, 1)
    if data.ndim == 1:
        data = data.reshape(-1, 1)

    if base_points.ndim != 2:
        raise ValueError("base_points must be 2D.")
    n = int(base_points.shape[0])
    if data.ndim != 2 or data.shape[0] != n:
        raise ValueError(f"data and base_points must align: data {data.shape} vs base {base_points.shape}")

    if colors is not None:
        colors = np.asarray(colors).reshape(-1)
        if colors.shape[0] != n:
            raise ValueError("colors must have length n.")
    if densities is not None:
        densities = np.asarray(densities).reshape(-1)
        if densities.shape[0] != n:
            raise ValueError("densities must have length n.")

    if rng is None:
        rng = np.random.default_rng()

    # total-space landmark masks (over full n)
    data_landmark_masks_full = _normalize_bool_masks_any(data_landmark_inds, n, name="data_landmark_inds")

    # base landmarks (over full n masks and/or extra points)
    base_landmark_masks_full, base_landmark_points_full = _normalize_base_landmarks(
        landmarks, base_points=base_points, name="landmarks"
    )

    if n > int(max_samples):
        sample_inds = rng.choice(n, size=int(max_samples), replace=False)
        sample_inds.sort()  # stable order helps toggling / reproducibility
    else:
        sample_inds = np.arange(n, dtype=int)

    bp = base_points[sample_inds]
    X = data[sample_inds]
    c = colors[sample_inds] if colors is not None else None
    dens = densities[sample_inds] if densities is not None else None

    data_lm = _subset_masks(data_landmark_masks_full, sample_inds)
    base_lm_masks = _subset_masks(base_landmark_masks_full, sample_inds)

    m = int(bp.shape[0])

    # Use / subset full dist matrix when available
    if full_dist_mat is not None and bool(same_metric):
        full_dist_mat = np.asarray(full_dist_mat)
        if full_dist_mat.shape != (n, n):
            raise ValueError(f"full_dist_mat must be (n,n) with n={n}. Got {full_dist_mat.shape}.")
        dist_mat = full_dist_mat[np.ix_(sample_inds, sample_inds)]
    else:
        dist_mat = _call_get_dist_mat(get_dist_mat, bp, base_metric)

    dist_mat = np.asarray(dist_mat)
    if dist_mat.shape != (m, m):
        raise ValueError(f"dist_mat must be (m,m). Got {dist_mat.shape} for m={m}")

    # extra base landmark points (not necessarily in the dataset); not downsampled
    base_lm_pts: Optional[List[np.ndarray]] = None
    if base_landmark_points_full is not None:
        d_base = int(base_points.shape[1])
        base_lm_pts = []
        for pts in base_landmark_points_full:
            pts = np.asarray(pts, dtype=float)
            if pts.ndim != 2 or pts.shape[1] != d_base:
                raise ValueError(f"Each landmarks points group must be (L_i, {d_base}), got {pts.shape}")
            base_lm_pts.append(pts)

    return BundleVizInputs(
        base_points=bp,
        data=X,
        dist_mat=dist_mat,
        colors=c,
        densities=dens,
        data_landmark_masks=data_lm,
        base_landmark_masks=base_lm_masks,
        base_landmark_points=base_lm_pts,
        sample_inds=sample_inds,
    )


def prepare_bundle_viz_inputs_from_bundle(
    bundle,
    *,
    get_dist_mat: Optional[Callable[..., np.ndarray]] = None,
    max_samples: int = 10_000,
    base_metric: Any = None,
    colors: Optional[np.ndarray] = None,
    densities: Optional[np.ndarray] = None,
    data_landmark_inds: Any = None,
    landmarks: Any = None,
    rng: Optional[np.random.Generator] = None,
) -> BundleVizInputs:
    """
    BundleResult-aware prep: a thin wrapper around ``prepare_bundle_viz_inputs``.

      - uses bundle.cover.base_points (required)
      - uses bundle.data
      - uses bundle.cover.metric as the default base_metric, falling back to
        ``EuclideanMetric()`` when neither is given
      - reuses bundle.cover.full_dist_mat (if present) only when the effective
        base_metric is the cover's own metric object (identity check). The
        precomputed distances presumably came from that metric.
      - get_dist_mat is optional. It is consulted only when the metric has no
        ``pairwise`` method; with None, the library's ``metrics.get_dist_mat``
        is used (see ``_call_get_dist_mat``).

    Raises
    ------
    AttributeError if bundle.cover.base_points is missing.
    """
    cover = bundle.cover
    base_points = getattr(cover, "base_points", None)
    if base_points is None:
        raise AttributeError("bundle.cover.base_points is missing (needed for show_bundle).")

    cover_metric = getattr(cover, "metric", None)
    if base_metric is None:
        base_metric = cover_metric
    if base_metric is None:
        from ..metrics import EuclideanMetric
        base_metric = EuclideanMetric()

    same_metric = base_metric is cover_metric
    full_dist_mat = getattr(cover, "full_dist_mat", None)

    return prepare_bundle_viz_inputs(
        base_points=np.asarray(base_points),
        data=np.asarray(bundle.data),
        get_dist_mat=get_dist_mat,
        full_dist_mat=full_dist_mat,
        base_metric=base_metric,
        same_metric=same_metric,
        max_samples=int(max_samples),
        colors=colors,
        densities=densities,
        data_landmark_inds=data_landmark_inds,
        landmarks=landmarks,
        rng=rng,
    )


# ----------------------------
# Pure figure construction (reused by Dash + snapshot saving)
# ----------------------------

_LANDMARK_COLORS = ("orange", "green", "purple", "cyan", "magenta", "yellow", "black")


def _landmark_color(i: int) -> str:
    """Color for landmark group ``i``; the palette cycles when there are more groups than colors."""
    return _LANDMARK_COLORS[i % len(_LANDMARK_COLORS)]


def _add_base_landmark_traces(
    go: Any,
    fig_base: Any,
    *,
    base_embedded: np.ndarray,
    base_landmark_masks: Optional[List[np.ndarray]],
    base_landmarks_embedded: Optional[List[np.ndarray]],
    marker_size: int,
) -> None:
    """
    Append base-panel landmark traces to ``fig_base`` in place.

    Dataset landmarks (bool masks over base points) come first, then extra
    landmark points that are already embedded in 3D. Dataset-landmark traces
    carry their base-point indices as ``customdata`` so a click on them can be
    mapped back to a base point. Extra points have no base index.
    """
    if base_landmark_masks is not None:
        for i, mask in enumerate(base_landmark_masks):
            idx = np.where(np.asarray(mask, bool))[0]
            if not idx.size:
                continue
            fig_base.add_trace(
                go.Scatter3d(
                    x=base_embedded[idx, 0],
                    y=base_embedded[idx, 1],
                    z=base_embedded[idx, 2],
                    mode="markers",
                    marker=dict(size=marker_size, color=_landmark_color(i), opacity=0.95),
                    customdata=idx,
                    name=f"Landmarks {i+1}",
                    hoverinfo="none",
                )
            )

    if base_landmarks_embedded is not None:
        for i, L3 in enumerate(base_landmarks_embedded):
            L3 = np.asarray(L3, dtype=float)
            if L3.size == 0:
                continue
            fig_base.add_trace(
                go.Scatter3d(
                    x=L3[:, 0], y=L3[:, 1], z=L3[:, 2],
                    mode="markers",
                    marker=dict(size=marker_size, color=_landmark_color(i), opacity=0.95),
                    name=f"Landmarks (extra) {i+1}",
                    hoverinfo="none",
                )
            )


def _make_figures(
    *,
    base_embedded: np.ndarray,                        # (n,3)
    explained_variance: np.ndarray,                   # (<=3,)
    base_landmark_masks: Optional[List[np.ndarray]],  # list of (n,) bool masks (subset of base points)
    base_landmarks_embedded: Optional[List[np.ndarray]],  # list of (L_i,3) extra points
    data: np.ndarray,                                # (n, d_data)
    dist_mat: np.ndarray,                            # (n,n)
    colors: Optional[np.ndarray],
    normalized_colors: Optional[np.ndarray],
    densities: Optional[np.ndarray],
    data_landmark_masks: Optional[List[np.ndarray]],  # list of (n,) bool masks (for fiber panel)
    selected_index: Optional[int],
    r: float,
    density_threshold: Optional[float] = None,
) -> Tuple["go.Figure", "go.Figure", str, str]:
    """
    Build the base-panel and fiber-panel figures for one viewer state.

    Parameters (all per-point arrays indexed over the same n displayed points)
    ----------
    base_embedded : (n, 3) base coordinates.
    explained_variance : base PCA cumulative explained-variance ratios. They
        are shown in variance_text unless a fiber PCA replaces them.
    base_landmark_masks / base_landmarks_embedded : base-panel landmarks.
    data : (n, d_data) total-space data. Rows near the selection feed the fiber PCA.
    dist_mat : (n, n) base distances; the neighborhood is ``dist < r`` (strict).
    colors, normalized_colors : when both are given, the fiber panel is colored
        by normalized_colors, and points whose raw color equals 0 are drawn gray.
    densities, density_threshold : when both are given, neighbors must have
        density > threshold.
    data_landmark_masks : (n,) bool masks highlighted in the fiber panel.
    selected_index : clicked base point. None or out of range means "no selection".

    Returns (fig_base, fig_data, label, variance_text).

    Notes
    -----
    - Imports plotly + sklearn lazily (only when called).
    - Trace 0 of the base panel holds all base points in order. Overlay traces
      (neighbors, dataset landmarks, selected) carry base-point indices in
      ``customdata``, because their ``pointNumber`` is only a position within
      that trace.
    - The fiber PCA fits at most min(3, n_neighbors, d_data) components.
      scikit-learn rejects more components than samples, which a 2-point
      neighborhood in >= 3D data would otherwise request.
    """
    try:
        import plotly.graph_objects as go
    except ImportError as e:
        raise ImportError(
            "bundle_dash figure construction requires plotly. Install with `pip install plotly`."
        ) from e

    try:
        from sklearn.decomposition import PCA
    except ImportError as e:
        raise ImportError(
            "bundle_dash fiber PCA requires scikit-learn. Install with `pip install scikit-learn`."
        ) from e

    n = int(base_embedded.shape[0])

    # Marker sizes (priority: red > landmark > blue > gray)
    size_gray = 2
    size_blue = 3
    size_landmark = 5
    size_red = 7

    # --- base plot: every point in gray (trace 0) ---
    fig_base = go.Figure()
    fig_base.add_trace(
        go.Scatter3d(
            x=base_embedded[:, 0], y=base_embedded[:, 1], z=base_embedded[:, 2],
            mode="markers",
            marker=dict(size=size_gray, color="lightgray", opacity=0.5),
            hoverinfo="none",
            name="Base Points",
        )
    )

    # --- data plot ---
    fig_data = go.Figure()
    variance_text = f"PCA Variance (Base): {np.round(explained_variance, 3)}"
    label = "Selected Point: (none)"

    has_selection = selected_index is not None and 0 <= selected_index < n

    if not has_selection:
        # Even with no selection, still show base landmarks (so you can orient yourself)
        _add_base_landmark_traces(
            go,
            fig_base,
            base_embedded=base_embedded,
            base_landmark_masks=base_landmark_masks,
            base_landmarks_embedded=base_landmarks_embedded,
            marker_size=size_landmark,
        )
    else:
        label = f"Selected Point ({selected_index})"

        nearby_indices = np.where(dist_mat[selected_index] < float(r))[0]

        if densities is not None and density_threshold is not None:
            keep = np.asarray(densities[nearby_indices] > float(density_threshold), dtype=bool)
            filtered = nearby_indices[keep]
        else:
            filtered = nearby_indices

        # Base plot: neighbors first (blue), then landmarks, then selected (red) last.
        if filtered.size:
            fig_base.add_trace(
                go.Scatter3d(
                    x=base_embedded[filtered, 0],
                    y=base_embedded[filtered, 1],
                    z=base_embedded[filtered, 2],
                    mode="markers",
                    marker=dict(size=size_blue, color="blue", opacity=0.8),
                    customdata=filtered,
                    name="Neighbors",
                    hoverinfo="none",
                )
            )

        _add_base_landmark_traces(
            go,
            fig_base,
            base_embedded=base_embedded,
            base_landmark_masks=base_landmark_masks,
            base_landmarks_embedded=base_landmarks_embedded,
            marker_size=size_landmark,
        )

        fig_base.add_trace(
            go.Scatter3d(
                x=[base_embedded[selected_index, 0]],
                y=[base_embedded[selected_index, 1]],
                z=[base_embedded[selected_index, 2]],
                mode="markers",
                marker=dict(size=size_red, color="red", opacity=1.0),
                customdata=[selected_index],
                name="Selected",
                hoverinfo="none",
            )
        )

        # --- fiber PCA over the (filtered) neighborhood ---
        nearby_data = data[filtered] if filtered.size else np.zeros((0, data.shape[1]), dtype=float)

        if nearby_data.shape[0] >= 2:
            n_components = min(3, nearby_data.shape[0], nearby_data.shape[1])
            pca_fiber = PCA(n_components=n_components)
            fiber_pca = pca_fiber.fit_transform(nearby_data)
            if fiber_pca.shape[1] < 3:
                fiber_pca = np.pad(fiber_pca, ((0, 0), (0, 3 - fiber_pca.shape[1])), mode="constant")

            fiber_var = np.cumsum(pca_fiber.explained_variance_ratio_)
            variance_text = f"PCA Variance (Fiber): {np.round(fiber_var, 3)}"

            if normalized_colors is not None and colors is not None:
                cvals = normalized_colors[filtered]
                orig = colors[filtered]
                nonzero = np.asarray(orig != 0, dtype=bool)

                if np.any(nonzero):
                    fig_data.add_trace(
                        go.Scatter3d(
                            x=fiber_pca[nonzero, 0], y=fiber_pca[nonzero, 1], z=fiber_pca[nonzero, 2],
                            mode="markers",
                            marker=dict(size=3, opacity=0.6, color=cvals[nonzero], colorscale="hsv", cmin=0, cmax=1),
                            name="Fiber (colored)",
                            hoverinfo="none",
                        )
                    )
                if np.any(~nonzero):
                    fig_data.add_trace(
                        go.Scatter3d(
                            x=fiber_pca[~nonzero, 0], y=fiber_pca[~nonzero, 1], z=fiber_pca[~nonzero, 2],
                            mode="markers",
                            marker=dict(size=3, opacity=0.5, color="gray"),
                            name="Fiber (zero)",
                            hoverinfo="none",
                        )
                    )
            else:
                fig_data.add_trace(
                    go.Scatter3d(
                        x=fiber_pca[:, 0], y=fiber_pca[:, 1], z=fiber_pca[:, 2],
                        mode="markers",
                        marker=dict(size=3, opacity=0.6, color="blue"),
                        name="Fiber",
                        hoverinfo="none",
                    )
                )

            # Total-space landmarks in the fiber panel (masks are over downsampled indexing)
            if data_landmark_masks is not None:
                for i, mask in enumerate(data_landmark_masks):
                    local = np.where(np.asarray(mask, bool)[filtered])[0]
                    if local.size:
                        fig_data.add_trace(
                            go.Scatter3d(
                                x=fiber_pca[local, 0], y=fiber_pca[local, 1], z=fiber_pca[local, 2],
                                mode="markers",
                                marker=dict(size=4, color=_landmark_color(i), opacity=0.9),
                                name=f"Data Landmarks {i+1}",
                                hoverinfo="none",
                            )
                        )
        elif nearby_data.shape[0] == 1:
            fig_data.add_trace(
                go.Scatter3d(
                    x=[0.0], y=[0.0], z=[0.0],
                    mode="markers",
                    marker=dict(size=5, opacity=0.9, color="blue"),
                    name="Fiber (1 point)",
                    hoverinfo="none",
                )
            )
            variance_text = "PCA Variance (Fiber): (neighborhood has 1 point)"
        else:
            variance_text = "PCA Variance (Fiber): (empty neighborhood)"

    # Show the 3D axes box/grid/backdrop, but hide tick labels (numbers).
    _axis_style = dict(
        showbackground=True,
        showgrid=True,
        zeroline=False,
        showticklabels=False,
        title="",  # no axis label text
    )
    _scene_style = dict(
        xaxis=_axis_style,
        yaxis=_axis_style,
        zaxis=_axis_style,
    )

    fig_base.update_layout(
        title="Base Points",
        scene=_scene_style,
        margin=dict(l=0, r=0, t=30, b=0),
        showlegend=False,
        uirevision="bundle-viewer",  # preserve camera between updates
        width=650,
        height=450,
    )
    fig_data.update_layout(
        title="Fiber Data",
        scene=_scene_style,
        margin=dict(l=0, r=0, t=30, b=0),
        showlegend=False,
        uirevision="bundle-viewer",  # preserve camera between updates
        width=650,
        height=450,
    )

    return fig_base, fig_data, label, variance_text


# ----------------------------
# Dash app (thin wrapper)
# ----------------------------

def make_bundle_app(
    viz: BundleVizInputs,
    *,
    initial_r: float = 0.1,
    r_max: float = 2.0,
):
    """
    Build the Dash app.

    Parameters
    ----------
    viz:
        Prepared viewer inputs (base points, data, base distance matrix and the
        optional colors / densities / landmark information).
    initial_r:
        Initial value of the neighborhood-radius slider.
    r_max:
        Upper bound of the neighborhood-radius slider (its lower bound is 0.01).

    Returns
    -------
    dash.Dash
        The configured app (not yet serving; see ``run_bundle_app``).

    Raises
    ------
    ImportError
        If dash is not installed.
    ValueError
        If ``viz.data`` rows or ``viz.dist_mat`` are not aligned with
        ``viz.base_points``.

    Notes
    -----
    - Imports dash lazily (only when this is called).
    - Plotly/sklearn are also lazy via helper functions called here.
    - Clicking the base plot stores the clicked index (as parsed by
      ``_parse_click_index``); the radius slider, and the density slider when
      densities are provided, then drive both figures.
    """
    try:
        import dash
        from dash import dcc, html
        from dash.dependencies import Input, Output, State
    except ImportError as e:
        raise ImportError(
            "make_bundle_app requires dash. Install with `pip install dash`."
        ) from e

    base_points = np.asarray(viz.base_points)
    data = np.asarray(viz.data)
    dist_mat = np.asarray(viz.dist_mat)
    colors = viz.colors
    densities = viz.densities

    data_landmark_masks = viz.data_landmark_masks
    base_landmark_masks = viz.base_landmark_masks
    base_landmark_points = viz.base_landmark_points

    n = int(base_points.shape[0])
    if data.shape[0] != n or dist_mat.shape != (n, n):
        raise ValueError("viz inputs misaligned.")

    base_embedded, explained_variance, pca_base = _embed_base_points_pca(base_points)
    normalized_colors = _normalize_to_unit_interval(colors) if colors is not None else None

    # Project extra base landmarks with the same fitted PCA as the base points,
    # zero-padding to 3 columns so they can be drawn in the 3D scene.
    base_landmarks_embedded: Optional[List[np.ndarray]] = None
    if base_landmark_points is not None:
        base_landmarks_embedded = []
        for g in base_landmark_points:
            g = np.asarray(g, dtype=float)
            if g.size == 0:
                # NOTE: an sklearn-style PCA.transform rejects 0-sample input;
                # the plotting code skips empty groups, so pass an empty (0, 3).
                base_landmarks_embedded.append(np.zeros((0, 3), dtype=float))
                continue
            L = pca_base.transform(g)
            if L.shape[1] < 3:
                L = np.pad(L, ((0, 0), (0, 3 - L.shape[1])), mode="constant")
            base_landmarks_embedded.append(L)

    def _build_figures(selected_index, r, density_threshold):
        """Return (fig_base, fig_data) for the given UI state."""
        fig_base, fig_data, _, _ = _make_figures(
            base_embedded=base_embedded,
            explained_variance=explained_variance,
            base_landmark_masks=base_landmark_masks,
            base_landmarks_embedded=base_landmarks_embedded,
            data=data,
            dist_mat=dist_mat,
            colors=colors,
            normalized_colors=normalized_colors,
            densities=densities,
            data_landmark_masks=data_landmark_masks,
            selected_index=None if selected_index is None else int(selected_index),
            r=float(r),
            density_threshold=None if density_threshold is None else float(density_threshold),
        )
        return fig_base, fig_data

    def _snapshot_filename(selected_index, r, density_threshold):
        """Build the download name; the density part is omitted when unset."""
        stem = f"bundle_snapshot_i{selected_index}_r{float(r):.3f}"
        if density_threshold is not None:
            stem += f"_d{float(density_threshold):.3f}"
        return stem + ".pdf"

    app = dash.Dash(__name__)

    layout_children = [
        # persistent selection + download
        dcc.Store(id="selected-index-store", data=None),
        dcc.Download(id="download-pdf"),
        html.Div(
            [
                html.Button(
                    "Save snapshot (PDF)",
                    id="save-snapshot-btn",
                    n_clicks=0,
                    style={"fontSize": 14, "padding": "8px 14px"},
                ),
            ],
            style={"textAlign": "center", "marginTop": "8px", "marginBottom": "8px"},
        ),
        html.Div(
            [
                html.Div(
                    [
                        dcc.Graph(
                            id="base-plot",
                            style={"width": "100%", "height": "400px", "margin-bottom": "25px"},
                            config={"displayModeBar": True},
                        ),
                    ],
                    style={"width": "50%", "display": "inline-block", "verticalAlign": "top"},
                ),
                html.Div(
                    [
                        dcc.Graph(
                            id="data-plot",
                            style={"width": "100%", "height": "400px", "margin-bottom": "25px"},
                            config={"displayModeBar": True},
                        ),
                    ],
                    style={"width": "50%", "display": "inline-block", "verticalAlign": "top"},
                ),
            ],
            style={"display": "flex", "width": "100%", "justify-content": "center"},
        ),
        html.Div(
            [
                dcc.Slider(
                    id="radius-slider",
                    min=0.01,
                    max=float(r_max),
                    step=0.01,
                    value=float(initial_r),
                    marks={0.01: "0.01", round(r_max / 2, 2): str(round(r_max / 2, 2)), r_max: str(r_max)},
                    tooltip={"placement": "bottom", "always_visible": True},
                    updatemode="drag",
                )
            ],
            style={"width": "80%", "margin": "auto", "margin-top": "20px"},
        ),
    ]

    if densities is not None:
        dmin = float(np.min(densities))
        dmax = float(np.max(densities))
        layout_children.append(
            html.Div(
                [
                    dcc.Slider(
                        id="density-slider",
                        min=dmin,
                        max=dmax,
                        step=0.01,
                        value=dmin,
                        marks={round(dmin, 2): str(round(dmin, 2)), round(dmax, 2): str(round(dmax, 2))},
                        tooltip={"placement": "bottom", "always_visible": True},
                        updatemode="drag",
                    ),
                ],
                style={"width": "80%", "margin": "auto", "margin-top": "20px"},
            )
        )

    app.layout = html.Div(layout_children, style={"margin": "auto", "maxWidth": "95vw"})

    # Persist selection: keep the previous index when a click carries none.
    @app.callback(
        Output("selected-index-store", "data"),
        Input("base-plot", "clickData"),
        State("selected-index-store", "data"),
    )
    def _remember_selected_index(clickData, current):
        idx = _parse_click_index(clickData)
        return current if idx is None else int(idx)

    # The density slider exists only when densities were given, so it is wired
    # in conditionally; the callbacks below default its value to None otherwise.
    figure_inputs = [Input("selected-index-store", "data"), Input("radius-slider", "value")]
    snapshot_states = [State("selected-index-store", "data"), State("radius-slider", "value")]
    if densities is not None:
        figure_inputs.append(Input("density-slider", "value"))
        snapshot_states.append(State("density-slider", "value"))

    # Update figures (use stored selection)
    @app.callback(
        [Output("base-plot", "figure"), Output("data-plot", "figure")],
        figure_inputs,
    )
    def update_figures(selected_index_store, r, density_threshold=None):
        return _build_figures(selected_index_store, r, density_threshold)

    # Download combined PDF snapshot (side-by-side)
    @app.callback(
        Output("download-pdf", "data"),
        Input("save-snapshot-btn", "n_clicks"),
        *snapshot_states,
        prevent_initial_call=True,
    )
    def _download_snapshot_pdf(n_clicks, selected_index_store, r, density_threshold=None):
        # Nothing to snapshot until a base point has been selected.
        if selected_index_store is None:
            return dash.no_update

        selected_index = int(selected_index_store)
        fig_base, fig_data = _build_figures(selected_index, r, density_threshold)

        png_left = _fig_to_png_bytes(fig_base, scale=1)
        png_right = _fig_to_png_bytes(fig_data, scale=1)
        pdf_bytes = _combine_pngs_to_pdf_bytes(png_left, png_right, gap=24, pad=0)

        return dcc.send_bytes(pdf_bytes, _snapshot_filename(selected_index, r, density_threshold))

    return app


def run_bundle_app(app, *, port: Optional[int] = None, debug: bool = False):
    """
    Serve a Dash app (e.g. one built by ``make_bundle_app``) locally.

    Parameters
    ----------
    app:
        The application object; only its ``run`` method is used.
    port:
        Port to listen on. When omitted, ``find_free_port`` chooses one.
    debug:
        Forwarded to ``app.run`` as a bool. The reloader is always disabled.

    Notes
    -----
    The printed URL always uses 127.0.0.1; no host is passed to ``app.run``.
    """
    chosen_port = int(port) if port is not None else find_free_port()
    print(f"Bundle viewer running at: http://127.0.0.1:{chosen_port}/")
    app.run(port=chosen_port, use_reloader=False, debug=bool(debug))


# ----------------------------
# Public general entrypoint
# ----------------------------

def show_bundle_vis(
    *,
    base_points: np.ndarray,
    data: np.ndarray,
    get_dist_mat: Optional[Callable[..., np.ndarray]] = None,
    full_dist_mat: Optional[np.ndarray] = None,
    base_metric: Any = None,
    same_metric: bool = False,
    initial_r: float = 0.1,
    r_max: float = 2.0,
    colors: Optional[np.ndarray] = None,
    densities: Optional[np.ndarray] = None,
    data_landmark_inds: Any = None,
    landmarks: Any = None,
    max_samples: int = 10_000,
    rng: Optional[np.random.Generator] = None,
    port: Optional[int] = None,
    debug: bool = False,
):
    """
    General interactive viewer for (data, base_points) where base_points live in some metric space.

    - Neighborhoods come from dist_mat computed on base_points (via get_dist_mat/base_metric).
    - "Fiber Data" shows PCA of data restricted to the selected neighborhood.

    Parameters
    ----------
    base_points, data, get_dist_mat, full_dist_mat, base_metric, same_metric,
    colors, densities, data_landmark_inds, landmarks, max_samples, rng:
        Forwarded to ``prepare_bundle_viz_inputs`` (``base_points``/``data``
        as arrays, ``same_metric`` as bool, ``max_samples`` as int).
    initial_r, r_max:
        Initial value and upper bound of the neighborhood-radius slider.
    port, debug:
        Forwarded to ``run_bundle_app``.

    Returns
    -------
    The Dash app, once ``run_bundle_app`` returns.

    Raises
    ------
    ImportError
        If dash is not installed. This is checked before any preprocessing.

    Snapshots
    ---------
    The save button downloads a combined PDF (side-by-side) containing ONLY
    the two Plotly figures (no sliders, no UI).
    """
    # Fail fast: check for dash before the (potentially expensive)
    # preprocessing instead of silently continuing and failing later.
    try:
        import dash  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "show_bundle_vis requires dash. Install with `pip install dash`."
        ) from e

    viz = prepare_bundle_viz_inputs(
        base_points=np.asarray(base_points),
        data=np.asarray(data),
        get_dist_mat=get_dist_mat,
        full_dist_mat=full_dist_mat,
        base_metric=base_metric,
        same_metric=bool(same_metric),
        max_samples=int(max_samples),
        colors=colors,
        densities=densities,
        data_landmark_inds=data_landmark_inds,
        landmarks=landmarks,
        rng=rng,
    )
    app = make_bundle_app(viz, initial_r=float(initial_r), r_max=float(r_max))
    run_bundle_app(app, port=port, debug=debug)
    return app
# ----------------------------
# Snapshot saving (offline, no Dash required)
# ----------------------------

def save_bundle_snapshot(
    viz: BundleVizInputs,
    *,
    selected_index: int,
    r: float,
    density_threshold: Optional[float] = None,
    base_html: Optional[str] = None,
    data_html: Optional[str] = None,
    base_image: Optional[str] = None,
    data_image: Optional[str] = None,
) -> Tuple["go.Figure", "go.Figure"]:
    """
    Create the two figures for a given (selected_index, r, density_threshold) and optionally save.

    Parameters
    ----------
    viz:
        Prepared viewer inputs.
    selected_index:
        Index (into the downsampled points) of the selected base point.
    r:
        Neighborhood radius.
    density_threshold:
        Optional density cut-off, forwarded to the figure builder.
    base_html, data_html:
        If given, the base / fiber figure is written to this HTML path.
    base_image, data_image:
        If given, the base / fiber figure is written to this static-image path.

    Returns
    -------
    (fig_base, fig_data)
        The two Plotly figures.

    Raises
    ------
    ImportError
        If plotly is not installed.
    ValueError
        If ``viz.data`` rows or ``viz.dist_mat`` are not aligned with
        ``viz.base_points``.

    Notes on saving:
      - HTML always works: fig.write_html("file.html")
      - Static images require 'kaleido': pip install -U kaleido
    """
    try:
        import plotly.graph_objects as go  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "save_bundle_snapshot requires plotly. Install with `pip install plotly`."
        ) from e

    base_points = np.asarray(viz.base_points)
    data = np.asarray(viz.data)
    dist_mat = np.asarray(viz.dist_mat)
    colors = viz.colors
    densities = viz.densities

    data_landmark_masks = viz.data_landmark_masks
    base_landmark_masks = viz.base_landmark_masks
    base_landmark_points = viz.base_landmark_points

    # Same alignment check the interactive viewer performs.
    n = int(base_points.shape[0])
    if data.shape[0] != n or dist_mat.shape != (n, n):
        raise ValueError("viz inputs misaligned.")

    base_embedded, explained_variance, pca_base = _embed_base_points_pca(base_points)
    normalized_colors = _normalize_to_unit_interval(colors) if colors is not None else None

    # Project extra base landmarks with the PCA fitted on the base points,
    # zero-padding to 3 columns for the 3D scene.
    base_landmarks_embedded: Optional[List[np.ndarray]] = None
    if base_landmark_points is not None:
        base_landmarks_embedded = []
        for g in base_landmark_points:
            g = np.asarray(g, dtype=float)
            if g.size == 0:
                # NOTE: an sklearn-style PCA.transform rejects 0-sample input;
                # the plotting code skips empty groups, so pass an empty (0, 3).
                base_landmarks_embedded.append(np.zeros((0, 3), dtype=float))
                continue
            L = pca_base.transform(g)
            if L.shape[1] < 3:
                L = np.pad(L, ((0, 0), (0, 3 - L.shape[1])), mode="constant")
            base_landmarks_embedded.append(L)

    fig_base, fig_data, _, _ = _make_figures(
        base_embedded=base_embedded,
        explained_variance=explained_variance,
        base_landmark_masks=base_landmark_masks,
        base_landmarks_embedded=base_landmarks_embedded,
        data=data,
        dist_mat=dist_mat,
        colors=colors,
        normalized_colors=normalized_colors,
        densities=densities,
        data_landmark_masks=data_landmark_masks,
        selected_index=int(selected_index),
        r=float(r),
        density_threshold=None if density_threshold is None else float(density_threshold),
    )

    if base_html is not None:
        fig_base.write_html(base_html)
    if data_html is not None:
        fig_data.write_html(data_html)
    if base_image is not None:
        fig_base.write_image(base_image)
    if data_image is not None:
        fig_data.write_image(data_image)

    return fig_base, fig_data