# Source code for circle_bundles.viz.pca_vis

# circle_bundles/viz/pca_viz.py
from __future__ import annotations

from typing import List, Optional, Sequence, Tuple, Union

import numpy as np


def _set_axes_equal_3d(ax, xyz: np.ndarray) -> None:
    """
    Give a Matplotlib 3D axis equal scale in x/y/z (as close as possible).

    Sets a cube-shaped data window centred on the point cloud, plus a 1:1:1
    box aspect when the Matplotlib version supports it. This keeps point
    clouds from looking squashed compared with Plotly's ``aspectmode='data'``.

    Parameters
    ----------
    ax:
        Matplotlib 3D axes (needs ``set_xlim``/``set_ylim``/``set_zlim``).
    xyz:
        Array of shape (n_points, >=3); only the first three columns are used.
        Inputs that are not 2D, have fewer than 3 columns, or have no rows
        are ignored (the axis is left untouched).
    """
    xyz = np.asarray(xyz, dtype=float)
    if xyz.ndim != 2 or xyz.shape[1] < 3 or xyz.shape[0] == 0:
        return

    lo = xyz[:, :3].min(axis=0)
    hi = xyz[:, :3].max(axis=0)
    mid = 0.5 * (lo + hi)

    # Use the largest extent for every axis so all three share one scale.
    max_range = float(np.max(hi - lo))
    if not np.isfinite(max_range) or max_range <= 0:
        max_range = 1.0
    r = 0.5 * max_range

    ax.set_xlim(mid[0] - r, mid[0] + r)
    ax.set_ylim(mid[1] - r, mid[1] + r)
    ax.set_zlim(mid[2] - r, mid[2] + r)

    # Only newer Matplotlib versions provide an explicit 3D box aspect.
    if hasattr(ax, "set_box_aspect"):
        ax.set_box_aspect((1, 1, 1))


def _style_static_3d_axes(ax, elev, azim) -> None:
    """
    Give a Matplotlib 3D axes a cleaner, Plotly-like look.

    Sets the camera to (elev, azim), keeps a subtle grid and makes the back
    panes light but visible. Every step is best-effort: this is purely
    cosmetic and relies partly on private Matplotlib attributes
    (``_axinfo``) that vary between versions, so failures are ignored.
    """
    try:
        ax.view_init(elev=elev, azim=azim)
    except Exception:  # cosmetic only; never fail the plot over the camera
        pass

    try:
        ax.grid(True)
    except Exception:
        pass

    for axis in (
        getattr(ax, "xaxis", None),
        getattr(ax, "yaxis", None),
        getattr(ax, "zaxis", None),
    ):
        if axis is None:
            continue
        # Light, visible panes (the "backdrop").
        try:
            axis.pane.set_alpha(0.12)
            axis.pane.set_edgecolor((0, 0, 0, 0.15))
        except Exception:
            pass
        # Subtle grid lines; `_axinfo` is private API, hence the guard.
        try:
            grid = axis._axinfo["grid"]
            grid["linewidth"] = 0.6
            grid["linestyle"] = "-"
            grid["color"] = (0, 0, 0, 0.12)
        except Exception:
            pass


def _select_point_colors(colors, n: int, idx: np.ndarray) -> np.ndarray:
    """
    Validate per-sample ``colors`` and return the entries selected by ``idx``.

    Raises
    ------
    ValueError
        If ``colors`` is a scalar or its first dimension is not ``n``.
    """
    colors_arr = np.asarray(colors, dtype=object)
    # A scalar (e.g. "red") has ndim 0 and no shape[0]; reject it explicitly
    # instead of letting an IndexError escape.
    if colors_arr.ndim == 0 or colors_arr.shape[0] != n:
        raise ValueError("colors must have length n_samples.")
    return colors_arr[idx]


def show_pca(
    data: np.ndarray,
    *,
    colors: Optional[Sequence] = None,
    U: Optional[np.ndarray] = None,
    size: float = 2.0,
    titles: Optional[List[str]] = None,
    # --- speed controls ---
    max_points: int = 2000,
    max_components: int = 50,
    use_randomized: bool = False,
    n_iter: int = 4,
    # --- numerical / plot stability ---
    jitter_eps: float = 1e-3,
    tol_flat_z: float = 1e-6,
    random_state: int = 0,
    # --- appearance ---
    set_cmap: str = "viridis",
    set_cmap_range: Tuple[float, float] = (0.10, 0.90),
    elev: float = 20,
    azim: float = 35,
    # --- behavior ---
    interactive: bool = False,
    show: bool = True,
    return_figs: bool = False,
    # --- plotly UX ---
    plotly_scroll_zoom: bool = True,
    plotly_double_click: Union[bool, str] = "reset",
) -> Union["go.Figure", "plt.Figure", None]:
    """
    PCA visualization (3D scatter + cumulative explained variance), side-by-side.

    - Fits PCA on the full dataset (up to max_components).
    - Plots first 3 PCs (subsampled to max_points).
    - Plots cumulative explained variance (CEV) vs k.
    - No printing.

    Parameters
    ----------
    data:
        2D array of shape (n_samples, d). If d < 3 it is padded with zero
        columns so that three coordinates can always be plotted.
    colors:
        Optional per-sample colors (first dimension must be n_samples). They
        are subsampled together with the points. When ``U`` is given they
        override the per-set colors.
    U:
        Optional boolean membership array of shape (n_sets, n_samples). Each
        row is drawn as its own labelled trace. Points that belong to no set
        are not drawn; points in several sets are drawn once per set.
    size:
        Marker size. Plotly uses it directly; Matplotlib uses ``size * 6`` as
        the scatter ``s`` value.
    titles:
        Legend labels, one per row of ``U`` (default ``"U_0"``, ``"U_1"``,
        ...). Only used when ``U`` is given.
    max_points:
        Maximum number of points drawn. Larger inputs are subsampled without
        replacement using ``random_state``. ``None`` draws every point.
    max_components:
        Upper bound on the number of PCA components. The count is also capped
        by n_samples and the (padded) dimension.
    use_randomized:
        Force scikit-learn's randomized SVD solver.
    n_iter:
        ``iterated_power`` for the randomized solver (ignored otherwise).
    jitter_eps, tol_flat_z:
        If the std of the plotted third PC is below ``tol_flat_z``, Gaussian
        noise with std ``jitter_eps`` is added to it (plot-stability
        heuristic for flat clouds).
    random_state:
        Seed for subsampling, jitter and the PCA solver.
    set_cmap, set_cmap_range:
        Colormap and the (lo, hi) range sampled from it to color the sets of
        ``U`` when ``colors`` is None.
    elev, azim:
        Matplotlib 3D camera angles (static mode only).
    interactive:
        If True, uses Plotly and builds a single combined Plotly Figure.
        If False, uses Matplotlib for BOTH panels.
    show:
        Display the figure.
    return_figs:
        Return the figure instead of None.
    plotly_scroll_zoom:
        If interactive, enables scroll zoom.
    plotly_double_click:
        Plotly double-click behavior. Use "reset" (default), "autosize",
        False, etc.

    Returns
    -------
    If return_figs=True:
        - interactive=True  -> plotly.graph_objects.Figure
        - interactive=False -> matplotlib.figure.Figure
    Else: None (but displays if show=True). A Matplotlib figure that is
    neither shown nor returned is closed.

    Raises
    ------
    ImportError
        If matplotlib / scikit-learn (or plotly when interactive) is missing.
    ValueError
        On malformed ``data``, ``U``, ``titles``, ``colors``, ``max_points``
        or ``max_components``.
    """
    # ------------------------------------------------------------
    # Lazy imports (optional heavy dependencies)
    # ------------------------------------------------------------
    try:
        import matplotlib.pyplot as plt
        from matplotlib.colors import to_hex
    except ImportError as e:
        raise ImportError(
            "show_pca requires matplotlib. Install with `pip install matplotlib`."
        ) from e

    try:
        from sklearn.decomposition import PCA
    except ImportError as e:
        raise ImportError(
            "show_pca requires scikit-learn. Install with `pip install scikit-learn`."
        ) from e

    # Plotly only if interactive
    if interactive:
        try:
            import plotly.graph_objects as go
            from plotly.subplots import make_subplots
        except ImportError as e:
            raise ImportError(
                "interactive show_pca requires plotly. Install with `pip install plotly`."
            ) from e

    # ------------------------------------------------------------
    # Input validation
    # ------------------------------------------------------------
    X = np.asarray(data, dtype=float)
    if X.ndim != 2:
        raise ValueError("show_pca expects a 2D array of shape (n_samples, d).")
    n, d = X.shape
    if n <= 0:
        raise ValueError("show_pca received empty data (n_samples=0).")
    if max_points is not None and int(max_points) < 1:
        raise ValueError("max_points must be a positive integer or None.")

    # Ensure at least 3 dims for plotting
    if d < 3:
        X = np.hstack([X, np.zeros((n, 3 - d), dtype=float)])
        d = 3

    n_components = int(min(max_components, n, d))
    if n_components < 1:
        raise ValueError("Not enough samples/dimensions to compute PCA.")

    # ------------------------------------------------------------
    # PCA fit
    # ------------------------------------------------------------
    # random_state is always passed: with svd_solver="auto" scikit-learn may
    # still pick its randomized solver for larger inputs, and the seed is
    # ignored by the deterministic solvers.
    pca = PCA(
        n_components=n_components,
        svd_solver="randomized" if use_randomized else "auto",
        random_state=random_state,
        iterated_power=n_iter if use_randomized else "auto",
    )
    pca.fit(X)
    cev = np.cumsum(np.asarray(pca.explained_variance_ratio_, dtype=float))
    ks = np.arange(1, cev.size + 1)

    # ------------------------------------------------------------
    # Subsample + project for plotting
    # ------------------------------------------------------------
    if max_points is not None and n > int(max_points):
        rng = np.random.default_rng(random_state)
        idx_plot = np.sort(rng.choice(n, size=int(max_points), replace=False))
    else:
        idx_plot = np.arange(n)

    Z = pca.transform(X[idx_plot])
    if Z.shape[1] < 3:
        Z = np.hstack([Z, np.zeros((Z.shape[0], 3 - Z.shape[1]), dtype=float)])
    Z3 = Z[:, :3].copy()

    # Jitter a (numerically) flat third PC so the 3D view is not degenerate.
    if np.std(Z3[:, 2]) < float(tol_flat_z):
        rng = np.random.default_rng(random_state)
        Z3[:, 2] += float(jitter_eps) * rng.standard_normal(Z3.shape[0])

    # ------------------------------------------------------------
    # Handle U / colors
    # ------------------------------------------------------------
    use_U = U is not None and np.size(U) > 0
    U_plot = None
    set_colors = None
    colors_plot = None

    if use_U:
        U_bool = np.asarray(U, dtype=bool)
        if U_bool.ndim != 2 or U_bool.shape[1] != n:
            raise ValueError(f"U must have shape (n_sets, n_samples) = (?, {n}).")
        U_plot = U_bool[:, idx_plot]
        n_sets = U_plot.shape[0]

        if titles is None:
            titles = [f"U_{j}" for j in range(n_sets)]
        elif len(titles) != n_sets:
            raise ValueError("titles must have length equal to U.shape[0].")

        if colors is None:
            cmap = plt.get_cmap(set_cmap)
            a, b = map(float, set_cmap_range)
            ts = np.linspace(a, b, n_sets) if n_sets > 1 else [(a + b) / 2.0]
            set_colors = [to_hex(cmap(t)) for t in ts]

    if colors is not None:
        colors_plot = _select_point_colors(colors, n, idx_plot)

    title_str = "PCA Summary"

    # ============================================================
    # INTERACTIVE (Plotly): single figure with 2 columns
    # ============================================================
    if interactive:
        fig = make_subplots(
            rows=1,
            cols=2,
            column_widths=[0.62, 0.38],
            specs=[[{"type": "scene"}, {"type": "xy"}]],
            horizontal_spacing=0.06,
            subplot_titles=("3D PCA Projection", "Cumulative Explained Variance"),
        )

        if use_U:
            for j in range(U_plot.shape[0]):
                idx = np.flatnonzero(U_plot[j])
                if idx.size == 0:
                    continue
                c = set_colors[j] if colors is None else colors_plot[idx]
                fig.add_trace(
                    go.Scatter3d(
                        x=Z3[idx, 0],
                        y=Z3[idx, 1],
                        z=Z3[idx, 2],
                        mode="markers",
                        marker=dict(size=float(size), color=c, opacity=0.85),
                        name=titles[j],
                        showlegend=True,
                    ),
                    row=1,
                    col=1,
                )
        else:
            fig.add_trace(
                go.Scatter3d(
                    x=Z3[:, 0],
                    y=Z3[:, 1],
                    z=Z3[:, 2],
                    mode="markers",
                    marker=dict(size=float(size), color=colors_plot, opacity=0.85),
                    name="data",
                    showlegend=True,
                ),
                row=1,
                col=1,
            )

        # CEV line (2D)
        fig.add_trace(
            go.Scatter(
                x=ks,
                y=cev,
                mode="lines+markers",
                name="Explained Variance",
                showlegend=False,
                hovertemplate="k=%{x}<br>CEV=%{y:.4f}<extra></extra>",
            ),
            row=1,
            col=2,
        )
        fig.update_xaxes(title_text="k", row=1, col=2)
        fig.update_yaxes(title_text="Explained Variance", range=[0.0, 1.01], row=1, col=2)

        # Preserve UI state so zooming the right panel doesn't reset the 3D
        # camera (and vice versa).
        fig.update_layout(
            title=title_str,
            scene=dict(aspectmode="data", uirevision="pca_scene"),
            margin=dict(l=0, r=0, t=60, b=0),
            legend=dict(x=0.01, y=0.99),
            uirevision="pca",  # global uirevision
        )
        fig.update_xaxes(uirevision="pca_axes", row=1, col=2)
        fig.update_yaxes(uirevision="pca_axes", row=1, col=2)

        if show:
            fig.show(
                config={
                    "scrollZoom": bool(plotly_scroll_zoom),
                    "doubleClick": plotly_double_click,
                }
            )
        return fig if return_figs else None

    # ============================================================
    # STATIC (Matplotlib): one figure with 1 row, 2 cols
    # ============================================================
    fig = plt.figure(figsize=(12, 5))
    ax3d = fig.add_subplot(1, 2, 1, projection="3d")
    axev = fig.add_subplot(1, 2, 2)

    scatter_kw = dict(s=float(size) * 6.0, alpha=0.85, depthshade=True)
    if use_U:
        plotted_any = False
        for j in range(U_plot.shape[0]):
            idx = np.flatnonzero(U_plot[j])
            if idx.size == 0:
                continue
            c = set_colors[j] if colors is None else colors_plot[idx]
            ax3d.scatter(
                Z3[idx, 0], Z3[idx, 1], Z3[idx, 2],
                c=c, label=titles[j], **scatter_kw,
            )
            plotted_any = True
        # Only add a legend when something was labelled (avoids a warning).
        if plotted_any:
            ax3d.legend(loc="upper left", bbox_to_anchor=(0.0, 1.0))
    else:
        ax3d.scatter(Z3[:, 0], Z3[:, 1], Z3[:, 2], c=colors_plot, **scatter_kw)

    ax3d.set_title("3D PCA Projection")
    # Make Matplotlib 3D look less squashed (closer to Plotly)
    _set_axes_equal_3d(ax3d, Z3)
    _style_static_3d_axes(ax3d, elev=elev, azim=azim)

    # CEV plot
    axev.plot(ks, cev, marker="o")
    axev.set_title("Cumulative Explained Variance")
    axev.set_xlabel("Number of Components (k)")
    axev.set_ylabel("Explained Variance")
    axev.set_ylim(0.0, 1.01)
    axev.grid(True)

    fig.suptitle(title_str)
    fig.tight_layout()

    if show:
        plt.show()
    if return_figs:
        return fig
    if not show:
        # Neither shown nor returned: close it so it doesn't linger in
        # pyplot's figure registry (and pop up on a later plt.show()).
        plt.close(fig)
    return None