# circle_bundles/fiberwise_clustering.py
from __future__ import annotations
from collections import defaultdict
from typing import Any, Dict, List, Optional, Sequence, Tuple
import numpy as np
# Public API of this module: the names exported by ``from <module> import *``.
# ``safe_add_edges`` and the underscore-prefixed helpers defined below are
# intentionally not listed here.
__all__ = [
"fiberwise_clustering",
"plot_fiberwise_pca_grid",
"plot_fiberwise_summary_bars",
"get_cluster_persistence",
"get_filtered_cluster_graph",
"get_weights"
]
# ---------------------------------
# Internal: dependency guards
# ---------------------------------
def _require_networkx():
    """
    Import ``networkx`` on demand and return the module object.

    Raises
    ------
    ImportError
        If the import fails for any reason. The original exception is chained
        as the cause and the message contains an install hint.
    """
    try:
        import networkx  # type: ignore
    except Exception as exc:  # pragma: no cover
        message = "fiberwise_clustering requires networkx. Install with `pip install networkx`."
        raise ImportError(message) from exc
    else:
        return networkx
def _require_sklearn():
    """
    Import the scikit-learn pieces used by this module on demand.

    Returns
    -------
    (DBSCAN, PCA):
        The ``sklearn.cluster.DBSCAN`` and ``sklearn.decomposition.PCA`` classes,
        in that order.

    Raises
    ------
    ImportError
        If either import fails for any reason. The original exception is chained
        as the cause and the message contains an install hint.
    """
    try:
        from sklearn.cluster import DBSCAN as dbscan_cls  # type: ignore
        from sklearn.decomposition import PCA as pca_cls  # type: ignore
    except Exception as exc:  # pragma: no cover
        message = (
            "fiberwise_clustering requires scikit-learn. Install with `pip install scikit-learn`."
        )
        raise ImportError(message) from exc
    else:
        return dbscan_cls, pca_cls
# ---------------------------------
# Graph construction helpers
# ---------------------------------
def safe_add_edges(G, U: np.ndarray, cl: np.ndarray) -> None:
    """
    Connect cluster-nodes of different fibers that share non-noise samples.

    For every pair of fibers ``j < k`` and every sample lying in both fibers
    whose label is not ``-1`` in either, the nodes ``(j, cl[j, s])`` and
    ``(k, cl[k, s])`` are linked, provided both nodes already exist in ``G``.
    The graph is modified in-place; nothing is returned.

    Node convention:
        node = (fiber_index, cluster_label)

    Edge attributes:
        - indices_shared : sorted list of the distinct sample indices supporting
          the edge. If the edge already exists, the new indices are merged into
          its current ``indices_shared`` list.

    Raises
    ------
    ValueError
        If ``cl`` does not have the same shape as ``U``.
    """
    membership = np.asarray(U, dtype=bool)
    labels = np.asarray(cl)
    n_fibers, n_samples = membership.shape
    if labels.shape != (n_fibers, n_samples):
        raise ValueError(f"cl must have shape {membership.shape}, got {labels.shape}")

    # Supporting samples per undirected edge; a plain dict keeps first-seen order.
    support: Dict[Tuple[Tuple[int, int], Tuple[int, int]], List[int]] = {}

    for hi in range(n_fibers):
        row_hi = membership[hi]
        for lo in range(hi):
            common = np.where(row_hi & membership[lo])[0]
            if common.size == 0:
                continue

            lab_lo = labels[lo, common]
            lab_hi = labels[hi, common]
            # A sample only counts when it is non-noise in both fibers.
            keep = (lab_lo != -1) & (lab_hi != -1)
            if not np.any(keep):
                continue

            triples = zip(common[keep], lab_lo[keep].astype(int), lab_hi[keep].astype(int))
            for sample, a, b in triples:
                node_lo = (lo, int(a))
                node_hi = (hi, int(b))
                # Only clusters that were registered as nodes can be linked.
                if node_lo in G and node_hi in G:
                    # Order the endpoints so each undirected edge has one key.
                    if node_lo <= node_hi:
                        edge_key = (node_lo, node_hi)
                    else:
                        edge_key = (node_hi, node_lo)
                    support.setdefault(edge_key, []).append(int(sample))

    for (u, v), samples in support.items():
        fresh = sorted(set(samples))
        if not G.has_edge(u, v):
            G.add_edge(u, v, indices_shared=fresh)
            continue
        existing = G.edges[u, v].get("indices_shared", [])
        G.edges[u, v]["indices_shared"] = sorted(set(existing).union(fresh))
def get_weights(G, method: str = "cardinality"):
    """
    Attach an overlap-based ``"weight"`` attribute to every edge of a cluster graph.

    The edge attribute ``"weight"`` is written (or overwritten) in-place. The weight
    measures how much the two endpoint clusters overlap in sample membership, which
    is what :func:`get_cluster_persistence` and :func:`get_filtered_cluster_graph`
    threshold on.

    Parameters
    ----------
    G:
        Graph exposing ``G.edges(data=True)`` and ``G.nodes[node]`` (e.g. a
        ``networkx.Graph``). Nodes are expected to carry an ``"indices"`` attribute
        listing their member samples; a missing attribute is treated as empty.
    method:
        Overlap-to-weight rule. One of:

        - ``"cardinality"``: weight = ``|A ∩ B|``
        - ``"rel_card"``: weight = ``|A ∩ B| / min(|A|, |B|)``
        - ``"rel_card2"``: weight = ``|A ∩ B| / ((|A| + |B|)/2)``

        ``A`` and ``B`` are the endpoint membership lists.

    Returns
    -------
    G:
        The same graph object that was passed in.

    Raises
    ------
    ValueError
        If ``method`` is not one of the three supported names.

    Notes
    -----
    - ``|A ∩ B|`` is ``len(indices_shared)`` when the edge has that attribute;
      otherwise it is computed from the endpoint ``indices``.
    - ``|A|`` and ``|B|`` are the lengths of the endpoint ``indices`` lists.
    - For the relative rules a zero denominator yields weight ``0.0``.
    """
    if method not in ("cardinality", "rel_card", "rel_card2"):
        raise ValueError("method must be one of: 'cardinality', 'rel_card', 'rel_card2'")

    def _members(node):
        # Member sample indices stored on a node (empty when absent).
        return G.nodes[node].get("indices", [])

    for a, b, attrs in G.edges(data=True):
        supporting = attrs.get("indices_shared")
        if supporting is None:
            # No precomputed support on the edge: intersect the endpoint memberships.
            overlap = float(len(set(_members(a)) & set(_members(b))))
        else:
            overlap = float(len(supporting))

        if method == "cardinality":
            attrs["weight"] = overlap
            continue

        n_a = float(len(_members(a)))
        n_b = float(len(_members(b)))
        if method == "rel_card":
            scale = float(min(n_a, n_b))
        else:  # rel_card2
            scale = float((n_a + n_b) / 2.0)
        attrs["weight"] = overlap / scale if scale > 0 else 0.0

    return G
# ---------------------------------
# Main compute routine
# ---------------------------------
def fiberwise_clustering(
    data: np.ndarray,
    U: np.ndarray,
    eps_values: np.ndarray,
    min_sample_values: np.ndarray,
    *,
    build_pca_embeddings: bool = True,
    pca_dim: int = 2,
    verbose: bool = True,
):
    """
    Cluster each fiber with DBSCAN and then merge clusters globally via an overlap graph.

    This is a useful diagnostic tool when you have a cover (fibers) over the dataset and
    want to see whether local cluster structure aligns across overlaps.

    Workflow
    --------
    1) For each fiber r (row of ``U``), run DBSCAN on the points ``data[U[r]]`` using
       ``eps_values[r]`` and ``min_sample_values[r]``. Empty fibers are skipped.
    2) Create a graph whose nodes are *fiber-clusters* ``(r, label)`` (excluding label ``-1``).
       Each node stores the list of sample indices belonging to that DBSCAN cluster.
    3) Add an edge between two nodes if the corresponding fiber-clusters share at least one
       sample index that is non-noise in both fibers. Edges store ``indices_shared``.
    4) Define a global component label per sample by taking connected components of the graph
       and assigning each sample that appears in any node of that connected component.

    Parameters
    ----------
    data:
        Array of shape ``(n_samples, d)`` containing the ambient data vectors.
    U:
        Boolean membership matrix of shape ``(n_fibers, n_samples)``. Row ``r`` indicates
        which samples lie in fiber/cover set ``U_r``.
    eps_values:
        Array of shape ``(n_fibers,)`` with the DBSCAN ``eps`` parameter used for each fiber.
    min_sample_values:
        Array of shape ``(n_fibers,)`` with the DBSCAN ``min_samples`` parameter used for each fiber.
    build_pca_embeddings:
        If True, compute PCA embeddings (within each fiber) for quick plotting with
        :func:`plot_fiberwise_pca_grid`.
    pca_dim:
        Number of PCA components to compute per fiber (typically 2). A fiber gets no PCA
        entry when it has a single point, when ``pca_dim <= 0``, or when ``pca_dim``
        exceeds the number of points in the fiber or the number of features ``d``.
    verbose:
        If True, progress messages are emitted through the package status helper and
        the status is cleared before returning.

    Returns
    -------
    components:
        Integer array of shape ``(n_samples,)`` giving a global component label per sample.
        Samples not belonging to any non-noise fiber-cluster remain ``-1``.
    G:
        A ``networkx.Graph``. Nodes are tuples ``(fiber_idx, cluster_label)`` with node attribute
        ``indices`` (list of sample indices). Edges indicate overlap, and store
        ``indices_shared`` (the supporting sample indices).
    graph_dict:
        A simple serialization-friendly representation of the graph (nodes/links).
    cl:
        Integer array of shape ``(n_fibers, n_samples)`` giving the DBSCAN label of each sample
        within each fiber. Label ``-1`` denotes noise (or not present in that fiber).
    summary:
        Dict of helpful arrays for downstream plotting, including:

        - ``fiber_component_counts``: number of DBSCAN clusters per fiber (excluding noise)
        - ``global_component_counts``: number of graph nodes per global component
        - ``point_counts``: number of samples assigned to each global component
        - ``pca_store``: per-fiber PCA embeddings if requested
        - ``n_fibers`` / ``n_samples``: the dimensions of ``U``

    Raises
    ------
    ImportError
        If ``networkx`` or ``scikit-learn`` cannot be imported.
    ValueError
        If ``data`` is not 2D, or the shapes of ``data``, ``eps_values`` or
        ``min_sample_values`` do not match ``U``.

    Notes
    -----
    - This routine requires optional dependencies: ``networkx`` and ``scikit-learn``.
    - Edge weights are not computed here. If you want a weighted filtration on the cluster graph,
      call :func:`get_weights` or attach your own ``edge["weight"]`` values.

    Examples
    --------
    >>> components, G, graph_dict, cl, summary = fiberwise_clustering(
    ...     data, U,
    ...     eps_values=np.full(U.shape[0], 0.25),
    ...     min_sample_values=np.full(U.shape[0], 10),
    ... )
    >>> fig, _ = plot_fiberwise_summary_bars(summary)
    """
    from ..utils.status_utils import _status, _status_clear

    def _v(msg: str):
        if verbose:
            _status(msg)

    nx = _require_networkx()
    DBSCAN, PCA = _require_sklearn()

    U = np.asarray(U, dtype=bool)
    data = np.asarray(data)
    eps_values = np.asarray(eps_values, dtype=float)
    min_sample_values = np.asarray(min_sample_values, dtype=int)

    n_fibers, n_samples = U.shape
    if data.ndim != 2:
        raise ValueError(f"data must be 2D (n_samples, d). Got shape {data.shape}.")
    if data.shape[0] != n_samples:
        raise ValueError(f"data must have n_samples={n_samples} rows, got {data.shape[0]}")
    if eps_values.shape != (n_fibers,):
        raise ValueError(f"eps_values must have shape ({n_fibers},), got {eps_values.shape}")
    if min_sample_values.shape != (n_fibers,):
        raise ValueError(f"min_sample_values must have shape ({n_fibers},), got {min_sample_values.shape}")

    cl = -1 * np.ones((n_fibers, n_samples), dtype=int)
    G = nx.Graph()
    fiber_component_counts = np.zeros(n_fibers, dtype=int)
    pca_store: Dict[int, Dict[str, Any]] = {}

    # --- Fiberwise DBSCAN + graph nodes ---
    for r in range(n_fibers):
        _v(f"Clustering set {r+1}/{n_fibers}...")
        row_inds = np.where(U[r])[0]
        if row_inds.size == 0:
            continue

        fiber_pts = data[row_inds]
        db = DBSCAN(eps=float(eps_values[r]), min_samples=int(min_sample_values[r]))
        labels = db.fit_predict(fiber_pts)
        cl[r, row_inds] = labels

        unique_clusters = set(labels.tolist())
        unique_clusters.discard(-1)  # noise is never a graph node
        fiber_component_counts[r] = len(unique_clusters)
        for lab in unique_clusters:
            cluster_indices = row_inds[labels == lab]
            G.add_node((r, int(lab)), indices=cluster_indices.tolist())

        # PCA cannot produce more components than min(n_points, n_features).
        # Such fibers are skipped so that an optional plotting aid never aborts
        # the clustering itself.
        if (
            build_pca_embeddings
            and fiber_pts.shape[0] > 1
            and 0 < int(pca_dim) <= min(fiber_pts.shape)
        ):
            pca = PCA(n_components=int(pca_dim))
            pca_data = pca.fit_transform(fiber_pts)
            pca_store[r] = {"pca": pca_data, "clusters": labels, "row_inds": row_inds}

    _v("Computing global clusters...")

    # --- Overlap edges ---
    safe_add_edges(G, U, cl)

    # --- Global components: map nodes -> union of their sample indices ---
    components = -1 * np.ones(n_samples, dtype=int)
    global_component_counts: List[int] = []
    point_counts: List[int] = []
    connected_components = list(nx.connected_components(G))
    for comp_id, comp_nodes in enumerate(connected_components):
        global_component_counts.append(len(comp_nodes))
        all_inds: List[int] = []
        for node in comp_nodes:
            all_inds.extend(G.nodes[node].get("indices", []))
        all_inds = np.unique(all_inds).astype(int)
        components[all_inds] = comp_id
        point_counts.append(int(len(all_inds)))

    # --- graph_dict (simple serialization) ---
    nodes_dict = {
        f"fiber{node[0]}_cluster{node[1]}": G.nodes[node].get("indices", [])
        for node in G.nodes
    }
    links_dict: Dict[str, List[str]] = defaultdict(list)
    for u, v in G.edges:
        ku = f"fiber{u[0]}_cluster{u[1]}"
        kv = f"fiber{v[0]}_cluster{v[1]}"
        # Links are recorded in both directions.
        links_dict[ku].append(kv)
        links_dict[kv].append(ku)

    graph_dict = {
        "nodes": nodes_dict,
        "links": dict(links_dict),
        "simplices": [[f"fiber{node[0]}_cluster{node[1]}"] for node in G.nodes],
        "meta_data": {
            "projection": "custom",
            "n_cubes": int(n_fibers),
            "perc_overlap": 0.5,
            "clusterer": "DBSCAN()",
            "scaler": "None",
            "nerve_min_intersection": 1,
        },
        "meta_nodes": {},
    }

    summary = {
        "fiber_component_counts": fiber_component_counts,
        "global_component_counts": np.array(global_component_counts, dtype=int),
        "point_counts": np.array(point_counts, dtype=int),
        "pca_store": pca_store,
        "n_fibers": int(n_fibers),
        "n_samples": int(n_samples),
    }

    if verbose:
        _status_clear()

    return components, G, graph_dict, cl, summary
# ---------------------------------
# Plotting helpers
# ---------------------------------
def plot_fiberwise_pca_grid(
    summary: dict,
    *,
    to_view: Optional[Sequence[int]] = None,
    cmap: str = "viridis",
    point_size: float = 5.0,
    n_cols: int = 4,
    save_path: Optional[str] = None,
):
    """
    Plot per-fiber PCA scatter plots colored by DBSCAN cluster labels.

    This function consumes the ``summary`` returned by :func:`fiberwise_clustering`
    when called with ``build_pca_embeddings=True``. The first two PCA columns of
    each fiber are plotted.

    Parameters
    ----------
    summary:
        The ``summary`` dict returned by :func:`fiberwise_clustering`.
    to_view:
        Optional sequence (list, tuple, NumPy array, ...) of fiber indices to plot.
        Indices without PCA data are ignored. If omitted or empty, plots all fibers
        that have PCA data available.
    cmap:
        Matplotlib colormap name used to color points by cluster label.
    point_size:
        Scatter marker size passed to Matplotlib.
    n_cols:
        Number of columns in the subplot grid (must be at least 1).
    save_path:
        If provided, save the figure to this path. The file is always written in
        PDF format.

    Returns
    -------
    fig, axes:
        The Matplotlib figure and a flattened array of axes. Returns ``(None, None)``
        if none of the requested fibers has PCA data.

    Raises
    ------
    ValueError
        If PCA data is missing (i.e. ``summary["pca_store"]`` is empty or absent),
        or if ``n_cols`` is smaller than 1.
    """
    pca_store = summary.get("pca_store", {})
    if not pca_store:
        raise ValueError("No PCA data found. Run fiberwise_clustering with build_pca_embeddings=True.")

    # Test for None / length explicitly: the truth value of a NumPy array is either
    # ambiguous (several elements) or misleading (``np.array([0])`` is falsy).
    if to_view is None or len(to_view) == 0:
        fibers = sorted(pca_store.keys())
    else:
        fibers = [int(r) for r in to_view if int(r) in pca_store]
    if not fibers:
        return None, None

    n_cols = int(n_cols)
    if n_cols < 1:
        raise ValueError(f"n_cols must be a positive integer, got {n_cols}")

    # Imported only once there is something to draw, so the validation and
    # early-return paths above work without matplotlib.
    import matplotlib.pyplot as plt

    n_rows = int(np.ceil(len(fibers) / n_cols))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows))
    axes = np.atleast_1d(axes).ravel()

    # Blank out grid cells that have no fiber assigned.
    for ax in axes[len(fibers):]:
        ax.axis("off")

    for ax, r in zip(axes, fibers):
        pca_data = pca_store[r]["pca"]
        clusters = pca_store[r]["clusters"]
        ax.scatter(pca_data[:, 0], pca_data[:, 1], s=point_size, c=clusters, cmap=cmap)
        ax.set_title(f"Fiber {r}", fontsize=12)
        ax.set_aspect("equal", "box")
        ax.grid(True, alpha=0.3)

    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path, format="pdf", bbox_inches="tight")
    plt.show()
    return fig, axes
def plot_fiberwise_summary_bars(
    summary: dict,
    *,
    hide_biggest: bool = False,
    save_path: Optional[str] = None,
    dpi: int = 200,
):
    """
    Draw three bar charts summarizing fiberwise clustering and global components.

    Panels, left to right:
        1) number of DBSCAN clusters per fiber (in fiber order)
        2) number of cluster-nodes per global connected component (ascending)
        3) number of samples per global connected component (ascending)

    Parameters
    ----------
    summary:
        The ``summary`` dict returned by :func:`fiberwise_clustering`. The keys
        ``"fiber_component_counts"``, ``"global_component_counts"`` and
        ``"point_counts"`` are read.
    hide_biggest:
        If True, the largest entry is left out of the third panel only (useful when
        one component dominates the scale).
    save_path:
        If provided, a PDF is written to this path with ``"_summary"`` inserted
        before a trailing ``.pdf`` (or ``"_summary.pdf"`` appended otherwise).
    dpi:
        Figure DPI.

    Returns
    -------
    fig, axs:
        Matplotlib figure and axes array.
    """
    import matplotlib.pyplot as plt

    per_fiber = np.asarray(summary["fiber_component_counts"])
    nodes_per_component = sorted(np.asarray(summary["global_component_counts"]).tolist())
    points_per_component = sorted(np.asarray(summary["point_counts"]).tolist())
    if hide_biggest and points_per_component:
        # Sorted ascending, so the largest component is the last entry.
        points_per_component = points_per_component[:-1]

    panels = (
        (per_fiber, "Number of Clusters per Fiber", "Fiber Number", "Number of Clusters"),
        (
            nodes_per_component,
            "Clusters per Global Component (Sorted)",
            "Global Component Index",
            "Cluster Count",
        ),
        (
            points_per_component,
            "Points per Global Component (Sorted)",
            "Global Component Index",
            "Point Count",
        ),
    )

    fig, axs = plt.subplots(1, 3, figsize=(18, 6), dpi=int(dpi))
    for ax, (heights, title, x_label, y_label) in zip(axs, panels):
        positions = np.arange(1, len(heights) + 1)  # 1-based bar positions
        ax.bar(positions, heights, edgecolor="black", linewidth=1.0)
        ax.set_title(title)
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.grid(True, alpha=0.25)

    fig.tight_layout()

    if save_path is not None:
        stem = save_path[:-4] if save_path.lower().endswith(".pdf") else save_path
        fig.savefig(stem + "_summary.pdf", format="pdf", bbox_inches="tight")

    plt.show()
    return fig, axs
# ---------------------------------
# Edge-threshold "persistence" + filtering
# ---------------------------------
def _ensure_has_weights(G) -> None:
    """
    Raise if ``G`` has edges but not a single one carries a ``"weight"`` attribute.

    A graph without edges passes, and so does a graph where at least one edge
    is weighted (the remaining edges are not checked).

    Raises
    ------
    ValueError
        If the graph has edges and none of them has a ``"weight"`` attribute.
    """
    if G.number_of_edges() == 0:
        return
    for _, _, attrs in G.edges(data=True):
        if "weight" in attrs:
            return
    raise ValueError(
        "Graph has no edge weights. Run `get_weights(G, method=...)` first "
        "or add a 'weight' attribute to each edge."
    )
def get_cluster_persistence(
    G,
    *,
    show_results: bool = True,
    save_path: Optional[str] = None,
):
    """
    Track connected-component count as low-weight edges are removed.

    Interprets the cluster graph ``G`` as a weighted graph (edge attribute ``"weight"``),
    and forms a simple edge-threshold filtration on a copy of the graph:

    - Start with the full graph.
    - For each distinct weight value ``w`` in increasing order, remove all edges with weight
      exactly ``w`` (equivalently: keep only edges with weight > w).
    - Record the number of connected components after each removal step.

    ``G`` itself is not modified.

    Parameters
    ----------
    G:
        A ``networkx.Graph`` whose edges have a numeric ``"weight"`` attribute
        (use :func:`get_weights` or attach weights yourself). An edge without a
        weight is treated as having weight ``0.0``.
    show_results:
        If True, plot the curve ``(# components) vs (threshold weight)``. The initial
        full-graph entry is not part of the plot. Matplotlib is only needed in this case.
    save_path:
        If provided and ``show_results=True``, save the plot to this path. The file is
        always written in PDF format.

    Returns
    -------
    history:
        List of dicts with keys:

        - ``"weight"``: the threshold value (first entry uses ``-np.inf`` for the full graph)
        - ``"n_components"``: number of connected components at that stage
        - ``"components"``: the list of node-sets returned by ``nx.connected_components``

    Raises
    ------
    ImportError
        If ``networkx`` cannot be imported.
    ValueError
        If the graph has edges but none have a ``"weight"`` attribute.
    """
    nx = _require_networkx()
    _ensure_has_weights(G)

    Gc = G.copy()
    comps0 = list(nx.connected_components(Gc))
    history: List[Dict[str, Any]] = [
        {"weight": -np.inf, "n_components": len(comps0), "components": comps0}
    ]

    # Bucket the edges by weight so each distinct threshold is one removal step.
    edges_by_w: Dict[float, List[Tuple[Any, Any]]] = defaultdict(list)
    for u, v, data in Gc.edges(data=True):
        edges_by_w[float(data.get("weight", 0.0))].append((u, v))

    for w in sorted(edges_by_w):
        Gc.remove_edges_from(edges_by_w[w])
        comps = list(nx.connected_components(Gc))
        history.append({"weight": w, "n_components": len(comps), "components": comps})

    if show_results:
        # Imported lazily: computing the history alone must not require matplotlib.
        import matplotlib.pyplot as plt

        ws = [h["weight"] for h in history[1:]]  # skip the -inf entry
        ns = [h["n_components"] for h in history[1:]]
        plt.figure(figsize=(10, 6))
        plt.plot(ws, ns, marker="o", linewidth=2)
        plt.xlabel("Weight Threshold")
        plt.ylabel("Number of Connected Components")
        plt.title("Cluster Persistence: Weight vs. Number of Connected Components")
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        if save_path is not None:
            plt.savefig(save_path, format="pdf", bbox_inches="tight")
        plt.show()

    return history
def _separate_intersections_on_removed_edges(
    G,
    cl: np.ndarray,
    *,
    thresh: float,
    rule: str = "to_smaller_cluster",
):
    """
    For each edge with weight <= thresh, split intersection samples so that those samples
    belong to ONLY ONE endpoint cluster (chosen by rule).

    Neither ``G`` nor ``cl`` is modified; the work is done on copies.

    Parameters
    ----------
    G:
        Weighted cluster graph. Nodes are ``(fiber_index, cluster_label)`` tuples with an
        ``"indices"`` attribute; edges may carry ``"indices_shared"``.
    cl:
        Label matrix indexed as ``cl[fiber_index, sample_index]``.
    thresh:
        Edges with ``weight <= thresh`` are treated as removed. An edge without a
        weight counts as weight ``0.0``.
    rule:
        - 'to_smaller_cluster': keep shared samples in smaller cluster, remove from larger
        - 'to_larger_cluster' : keep shared samples in larger cluster, remove from smaller

        On a size tie the samples are removed from the second endpoint of the edge.
        Sizes are the current ones, i.e. they reflect removals made for edges
        processed earlier.

    Returns
    -------
    G_clean:
        Copy of ``G`` with updated node ``"indices"``; all edges are still present.
    cl_clean:
        Copy of ``cl`` where every removed (fiber, sample) entry that carried the
        node's label is set to ``-1``.
    filtered_G:
        New graph with the nodes (and updated attributes) of ``G_clean`` and only the
        edges with ``weight > thresh``.

    Raises
    ------
    ValueError
        If ``rule`` is unknown, or the graph has edges but none has a ``"weight"``.
    """
    nx = _require_networkx()
    if rule not in ("to_smaller_cluster", "to_larger_cluster"):
        raise ValueError("rule must be 'to_smaller_cluster' or 'to_larger_cluster'")
    _ensure_has_weights(G)

    G_clean = G.copy()
    cl_clean = np.array(cl, copy=True)
    cutoff = float(thresh)

    # Collect the sub-threshold edges first; node attributes are edited below.
    edges_to_remove: List[Tuple[Any, Any]] = [
        (u, v)
        for u, v, data in G_clean.edges(data=True)
        if float(data.get("weight", 0.0)) <= cutoff
    ]

    for u, v in edges_to_remove:
        inds_u = set(G_clean.nodes[u].get("indices", []))
        inds_v = set(G_clean.nodes[v].get("indices", []))

        # Only samples that are STILL members of both endpoints need resolving.
        # ``indices_shared`` was recorded before any removal, so it may list a sample
        # that an earlier edge already took out of one endpoint. Removing it from
        # the other endpoint as well could leave the sample in no cluster at all.
        inter = inds_u & inds_v
        shared_edge = G_clean.edges[u, v].get("indices_shared", None)
        if shared_edge is not None:
            inter &= set(shared_edge)
        if not inter:
            continue

        size_u = len(inds_u)
        size_v = len(inds_v)
        if rule == "to_smaller_cluster":
            remove_from = v if size_u <= size_v else u
        else:
            remove_from = v if size_u >= size_v else u

        r, label = remove_from
        # Assign a NEW list rather than mutating in place, so the list object held
        # by the caller's graph is never touched.
        G_clean.nodes[remove_from]["indices"] = [
            idx for idx in G_clean.nodes[remove_from].get("indices", []) if idx not in inter
        ]

        for idx in inter:
            if cl_clean[r, idx] == label:
                cl_clean[r, idx] = -1

    filtered_G = nx.Graph()
    filtered_G.add_nodes_from(G_clean.nodes(data=True))
    for u, v, data in G_clean.edges(data=True):
        if float(data.get("weight", 0.0)) > cutoff:
            filtered_G.add_edge(u, v, **data)

    return G_clean, cl_clean, filtered_G
def get_filtered_cluster_graph(
    data: np.ndarray,
    G,
    cl: np.ndarray,
    *,
    thresh: float,
    rule: str = "to_smaller_cluster",
    show_results: bool = True,
    hide_biggest: bool = False,
    save_path: Optional[str] = None,
):
    """
    Threshold the weighted cluster graph and recompute global components.

    This is a "post-processing" step for :func:`fiberwise_clustering` once you have
    a weighted cluster graph ``G``:

    1) Remove edges with ``weight <= thresh`` (keep only edges with ``weight > thresh``).
    2) For each removed edge, resolve shared sample indices so that intersection samples are
       assigned to only one endpoint cluster (controlled by ``rule``). This updates the node
       membership lists and also produces a cleaned label matrix ``cl_clean``.
    3) Recompute connected components of the filtered graph and assign samples accordingly.

    The inputs ``G`` and ``cl`` are not modified.

    Parameters
    ----------
    data:
        Original data array of shape ``(n_samples, d)`` (used here only for shape validation).
    G:
        A ``networkx.Graph`` whose edges have a numeric ``"weight"`` attribute and whose nodes
        store an ``"indices"`` list of member sample indices.
    cl:
        The per-fiber DBSCAN label matrix returned by :func:`fiberwise_clustering`
        (shape ``(n_fibers, n_samples)``).
    thresh:
        Edge threshold. Edges with ``weight <= thresh`` are removed.
    rule:
        How to resolve intersection samples on removed edges:

        - ``"to_smaller_cluster"``: keep shared samples in the smaller endpoint cluster
        - ``"to_larger_cluster"``: keep shared samples in the larger endpoint cluster
    show_results:
        If True, plot bar charts for clusters-per-component and points-per-component.
        Matplotlib is only needed in this case.
    hide_biggest:
        If True, drop the largest component from the points-per-component panel.
    save_path:
        If provided and ``show_results=True``, save a PDF with ``"_summary"`` appended.

    Returns
    -------
    components_filtered:
        Integer array of shape ``(n_samples,)`` giving the new global component label per sample.
        Unassigned samples remain ``-1``.
    filtered_G:
        The thresholded graph containing only edges with ``weight > thresh``.
    graph_dict_filtered:
        Serialization-friendly representation of the filtered graph.
    cl_clean:
        Cleaned copy of ``cl`` where samples removed from a cluster due to intersection resolution
        are set to ``-1`` for that fiber.
    comp_inds:
        Boolean array of shape ``(n_components, n_samples)`` indicating membership of each
        global component.

    Raises
    ------
    ImportError
        If ``networkx`` cannot be imported.
    ValueError
        If ``data`` is not 2D, if ``cl`` is not a 2D array with ``n_samples`` columns,
        if ``rule`` is unknown, or if the graph has edges but none have a ``"weight"``
        attribute.
    """
    nx = _require_networkx()

    data = np.asarray(data)
    if data.ndim != 2:
        raise ValueError(f"data must be 2D (n_samples, d). Got {data.shape}.")
    n_samples = data.shape[0]

    # ``cl`` is indexed as cl[fiber, sample] further down; reject a mismatch up front.
    cl_shape = np.shape(cl)
    if len(cl_shape) != 2 or cl_shape[1] != n_samples:
        raise ValueError(f"cl must have shape (n_fibers, {n_samples}), got {cl_shape}")

    _, cl_clean, filtered_G = _separate_intersections_on_removed_edges(
        G, cl, thresh=float(thresh), rule=rule
    )

    comps = list(nx.connected_components(filtered_G))
    components_filtered = -1 * np.ones(n_samples, dtype=int)
    global_component_counts: List[int] = []
    point_counts: List[int] = []
    for comp_id, component in enumerate(comps):
        global_component_counts.append(len(component))
        all_inds: List[int] = []
        for node in component:
            all_inds.extend(filtered_G.nodes[node].get("indices", []))
        all_inds = np.unique(all_inds).astype(int)
        components_filtered[all_inds] = comp_id
        point_counts.append(int(len(all_inds)))

    n_components = len(comps)
    comp_inds = np.zeros((n_components, n_samples), dtype=bool)
    for j in range(n_components):
        comp_inds[j, components_filtered == j] = True

    if show_results:
        # Imported lazily: the computation above must not require matplotlib.
        import matplotlib.pyplot as plt

        sorted_component_counts = sorted(global_component_counts)
        sorted_point_counts = sorted(point_counts)
        if hide_biggest and len(sorted_point_counts) > 0:
            sorted_point_counts = sorted_point_counts[:-1]

        fig, axs = plt.subplots(1, 2, figsize=(12, 5))
        axs[0].bar(range(1, len(sorted_component_counts) + 1), sorted_component_counts, edgecolor="black")
        axs[0].set_title("Clusters per Global Component")
        axs[0].set_xlabel("Global Component Index")
        axs[0].set_ylabel("Cluster Count")
        axs[0].grid(True, alpha=0.25)
        axs[1].bar(range(1, len(sorted_point_counts) + 1), sorted_point_counts, edgecolor="black")
        axs[1].set_title("Points per Global Component")
        axs[1].set_xlabel("Global Component Index")
        axs[1].set_ylabel("Point Count")
        axs[1].grid(True, alpha=0.25)
        fig.tight_layout()

        if save_path is not None:
            stem = save_path[:-4] if save_path.lower().endswith(".pdf") else save_path
            out = stem + "_summary.pdf"
            fig.savefig(out, format="pdf", bbox_inches="tight")
            print(f"✅ Saved filtered summary figure to {out}")
        plt.show()

    links_dict: Dict[str, List[str]] = defaultdict(list)
    for u, v in filtered_G.edges:
        ku = f"fiber{u[0]}_cluster{u[1]}"
        kv = f"fiber{v[0]}_cluster{v[1]}"
        # Links are recorded in both directions.
        links_dict[ku].append(kv)
        links_dict[kv].append(ku)

    graph_dict_filtered = {
        "nodes": {
            f"fiber{node[0]}_cluster{node[1]}": filtered_G.nodes[node].get("indices", [])
            for node in filtered_G.nodes
        },
        "links": dict(links_dict),
        "simplices": [[f"fiber{node[0]}_cluster{node[1]}"] for node in filtered_G.nodes],
        "meta_data": {
            "projection": "custom",
            # Same key as in the graph_dict built by fiberwise_clustering:
            # one "cube" per fiber, i.e. per row of ``cl``.
            "n_cubes": int(cl_shape[0]),
            "perc_overlap": 0.5,
            "clusterer": "DBSCAN()",
            "scaler": "None",
            "nerve_min_intersection": 1,
            "edge_threshold": float(thresh),
            "intersection_rule": rule,
        },
        "meta_nodes": {},
    }

    return components_filtered, filtered_G, graph_dict_filtered, cl_clean, comp_inds