diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2a9cc7f821..e206ef46ff 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -64,19 +64,19 @@ jobs: - name: Install dependencies if: matrix.dependencies-version == null - run: uv pip install --system --compile "scanpy[dev,test-full] @ ." + run: uv pip install --system --compile "scanpy[dev,test-full,leiden] @ ." - name: Install dependencies (no optional features) if: matrix.dependencies-version == 'min-optional' - run: uv pip install --system --compile "scanpy[dev,test-min] @ ." + run: uv pip install --system --compile "scanpy[dev,test-min,leiden] @ ." - name: Install dependencies (minimum versions) if: matrix.dependencies-version == 'minimum' run: | uv pip install --system --compile tomli packaging - deps=$(python3 ci/scripts/min-deps.py pyproject.toml --extra dev test) + deps=$(python3 ci/scripts/min-deps.py pyproject.toml --extra dev test leiden) uv pip install --system --compile $deps "scanpy @ ." - name: Install dependencies (pre-release versions) if: matrix.dependencies-version == 'pre-release' - run: uv pip install -v --system --compile --pre "scanpy[dev,test-full] @ ." "anndata[dev,test] @ git+https://github.com/scverse/anndata.git" + run: uv pip install -v --system --compile --pre "scanpy[dev,test-full,leiden] @ ." "anndata[dev,test] @ git+https://github.com/scverse/anndata.git" - name: Run pytest if: matrix.test-type == null diff --git a/.gitignore b/.gitignore index de85b8a6b7..dcdae7c37d 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,10 @@ /docs/api/generated /docs/external/generated /docs/jupyter_execute +cluster tree demo figure.pptx +mytest.py +_cluster_tree_standelone.py +expected.png # tests /*cache/ @@ -45,3 +49,5 @@ Thumbs.db # asv benchmark files /benchmarks/.asv /benchmarks/data/ +myenv/ +test.py diff --git a/.readthedocs.yml b/.readthedocs.yml index 0ede485a47..ac72c6888c 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -23,3 +23,4 @@ python: - doc - dev # for towncrier - leiden + - pytest-mpl # image comparison diff --git a/src/scanpy/plotting/__init__.py b/src/scanpy/plotting/__init__.py index b7deeb2b84..13e759c0ef 100644 --- a/src/scanpy/plotting/__init__.py +++ b/src/scanpy/plotting/__init__.py @@ -13,6 +13,7 @@ tracksplot, violin, ) +from ._cluster_tree import cluster_decision_tree from ._dotplot import DotPlot, dotplot from ._matrixplot import MatrixPlot, matrixplot from ._preprocessing import filter_genes_dispersion, highly_variable_genes @@ -58,6 +59,7 @@ "DotPlot", "MatrixPlot", "StackedViolin", + "cluster_decision_tree", "clustermap", "correlation_matrix", "dendrogram", diff --git a/src/scanpy/plotting/_cluster_tree.py b/src/scanpy/plotting/_cluster_tree.py new file mode 100644 index 0000000000..98ae30b9d4 --- /dev/null +++ b/src/scanpy/plotting/_cluster_tree.py @@ -0,0 +1,1334 @@ +from __future__ import annotations + +import math +from typing import TYPE_CHECKING, TypedDict, cast + +import igraph as ig +import matplotlib.colors as mcolors +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.patches import FancyArrowPatch, PathPatch +from matplotlib.path import Path + +if TYPE_CHECKING: + from typing import NotRequired + + import networkx as nx + import pandas as pd + from anndata import AnnData + + +class OutputSettings(TypedDict): + output_path: NotRequired[str | None] + draw: NotRequired[bool] + figsize: NotRequired[tuple[float, float] | None] + dpi: NotRequired[int | None] + + +class NodeStyle(TypedDict): + node_size: NotRequired[float] 
+ node_color: NotRequired[str] + node_colormap: NotRequired[list[str] | None] + node_label_fontsize: NotRequired[float] + + +class EdgeStyle(TypedDict): + edge_color: NotRequired[str] + edge_curvature: NotRequired[float] + edge_threshold: NotRequired[float] + show_weight: NotRequired[bool] + edge_label_threshold: NotRequired[float] + edge_label_position: NotRequired[float] + edge_label_fontsize: NotRequired[float] + + +class GeneLabelSettings(TypedDict): + show_gene_labels: NotRequired[bool] + n_top_genes: NotRequired[int] + gene_label_threshold: NotRequired[float] + gene_label_style: NotRequired[dict[str, float]] + top_genes_dict: NotRequired[dict[tuple[str, str], list[str]] | None] + + +class LevelLabelStyle(TypedDict): + level_label_offset: NotRequired[float] + level_label_fontsize: NotRequired[float] + + +class TitleStyle(TypedDict): + title: NotRequired[str] + title_fontsize: NotRequired[float] + + +class LayoutSettings(TypedDict): + node_spacing: NotRequired[float] + level_spacing: NotRequired[float] + orientation: NotRequired[str] + barycenter_sweeps: NotRequired[int] + use_reingold_tilford: NotRequired[bool] + + +class ClusteringSettings(TypedDict): + prefix: NotRequired[str] + edge_threshold: NotRequired[float] + + +class ClusterTreePlotter: + def __init__( + self, + adata: AnnData, + resolutions: list[float], + *, + output_settings: OutputSettings | None = None, + node_style: NodeStyle | None = None, + edge_style: EdgeStyle | None = None, + gene_label_settings: GeneLabelSettings | None = None, + level_label_style: LevelLabelStyle | None = None, + title_style: TitleStyle | None = None, + layout_settings: LayoutSettings | None = None, + clustering_settings: ClusteringSettings | None = None, + ): + """ + Initialize the cluster tree plotter. + + Parameters + ---------- + adata + AnnData object with clustering results. + resolutions + List of resolution values. + output_settings + Output settings (output_path, draw, figsize, dpi). + node_style + Node styling (node_size, node_color, node_colormap, node_label_fontsize). + edge_style + Edge styling (edge_color, edge_curvature, edge_threshold, ...). + gene_label_settings + Gene label settings (show_gene_labels, n_top_genes, ...). + level_label_style + Level label settings (level_label_offset, level_label_fontsize). + title_style + Title settings (title, title_fontsize). + layout_settings + Layout settings (node_spacing, level_spacing). + clustering_settings + Clustering settings (prefix). 
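+
+ Examples
+ --------
+ A minimal sketch of the intended workflow; it assumes
+ :func:`scanpy.tl.find_cluster_resolution` has already been run on
+ ``adata`` for the same resolutions:
+
+ >>> plotter = ClusterTreePlotter(adata, resolutions=[0.0, 0.5])
+ >>> plotter.build_cluster_graph()
+ >>> plotter.compute_cluster_layout()
+ >>> plotter.draw_cluster_tree()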
+ """ + self.adata = adata + self.resolutions = resolutions + self.output_settings = self._merge_with_default( + output_settings, self.default_output_settings() + ) + self.node_style = self._merge_with_default( + node_style, self.default_node_style() + ) + self.edge_style = self._merge_with_default( + edge_style, self.default_edge_style() + ) + self.gene_label_settings = self._merge_with_default( + gene_label_settings, self.default_gene_label_settings() + ) + self.level_label_style = self._merge_with_default( + level_label_style, self.default_level_label_style() + ) + self.title_style = self._merge_with_default( + title_style, self.default_title_style() + ) + self.layout_settings = self._merge_with_default( + layout_settings, self.default_layout_settings() + ) + self.clustering_settings = self._merge_with_default( + clustering_settings, self.default_clustering_settings() + ) + + self.settings = {} + self.settings["output"] = self.output_settings + self.settings["node"] = self.node_style + self.settings["edge"] = self.edge_style + self.settings["gene_label"] = self.gene_label_settings + self.settings["level_label"] = self.level_label_style + self.settings["title"] = self.title_style + self.settings["layout"] = self.layout_settings + self.settings["clustering"] = self.clustering_settings + + # Initialize attributes + self.G = None + self.pos = None + self.ax = plt.gca() # Initialize self.ax with the current axis + self.fig = None + + def _merge_with_default(self, user_dict, default_dict): + return {**default_dict, **(user_dict or {})} + + @staticmethod + def default_output_settings() -> OutputSettings: + return {"output_path": None, "draw": False, "figsize": (12, 6), "dpi": 300} + + @staticmethod + def default_node_style() -> NodeStyle: + return { + "node_size": 500, + "node_color": "prefix", + "node_colormap": None, + "node_label_fontsize": 12, + } + + @staticmethod + def default_edge_style() -> EdgeStyle: + return { + "edge_color": "parent", + "edge_curvature": 0.01, + "edge_threshold": 0.01, + "show_weight": True, + "edge_label_threshold": 0.05, + "edge_label_position": 0.8, + "edge_label_fontsize": 8, + } + + @staticmethod + def default_gene_label_settings() -> GeneLabelSettings: + return { + "show_gene_labels": False, + "n_top_genes": 2, + "gene_label_threshold": 0.001, + "gene_label_style": {"offset": 0.5, "fontsize": 8}, + "top_genes_dict": None, + } + + @staticmethod + def default_level_label_style() -> LevelLabelStyle: + return {"level_label_offset": 15, "level_label_fontsize": 12} + + @staticmethod + def default_title_style() -> TitleStyle: + return {"title": "Hierarchical Leiden Clustering", "title_fontsize": 20} + + @staticmethod + def default_layout_settings() -> LayoutSettings: + return { + "node_spacing": 5.0, + "level_spacing": 1.5, + "orientation": "vertical", + "barycenter_sweeps": 2, + "use_reingold_tilford": False, + } + + @staticmethod + def default_clustering_settings() -> ClusteringSettings: + return {"prefix": "leiden_res_", "edge_threshold": 0.05} + + def build_cluster_graph(self) -> None: + """ + Build a directed graph representing hierarchical clustering. + + Uses self.adata.obs, self.settings["clustering"]["prefix"], and self.settings["clustering"]["edge_threshold"]. + Stores the graph in self.G and updates top_genes_dict. 
+ """ + import networkx as nx + + prefix = self.settings["clustering"]["prefix"] + edge_threshold = self.settings["clustering"]["edge_threshold"] + data = self.adata.obs + + # Validate input data + matching_columns = [col for col in data.columns if col.startswith(prefix)] + if not matching_columns: + msg = f"No columns found with prefix '{prefix}' in the DataFrame." + raise ValueError(msg) + + self.G = nx.DiGraph() + + # Extract resolutions from column names + resolutions_col = [col[len(prefix) :] for col in matching_columns] + resolutions_col = sorted( + [float(r) for r in resolutions_col if r.replace(".", "", 1).isdigit()] + ) + + # Add nodes with resolution attribute for layout + for i, res in enumerate(resolutions_col): + clusters = data[f"{prefix}{res}"].unique() + for cluster in sorted(clusters): + node = f"{res}_C{cluster}" + self.G.add_node(node, resolution=i, cluster=cluster) + + # Build edges between consecutive resolutions + for i in range(len(resolutions_col) - 1): + res1 = f"{prefix}{resolutions_col[i]}" + res2 = f"{prefix}{resolutions_col[i + 1]}" + + grouped = ( + data.loc[:, [res1, res2]] + .astype(str) + .groupby(res1, observed=False)[res2] + .value_counts(normalize=True) + ) + + for key, frac in grouped.items(): + parent, child = key if isinstance(key, tuple) else (key, None) + parent = str(parent) if parent is not None else "" + child = str(child) + parent_node = f"{resolutions_col[i]}_C{parent}" + child_node = f"{resolutions_col[i + 1]}_C{child}" + if frac >= edge_threshold: + self.G.add_edge(parent_node, child_node, weight=frac) + + self.settings["gene_label"]["top_genes_dict"] = self.adata.uns.get( + "top_genes_dict", {} + ) + + def compute_cluster_layout(self) -> dict[str, tuple[float, float]]: + """Compute node positions for the cluster decision tree with crossing minimization.""" + import networkx as nx + + if self.G is None: + msg = "Graph is not initialized. Call build_graph() first." + raise ValueError(msg) + + use_reingold_tilford = self.settings["layout"]["use_reingold_tilford"] + node_spacing = self.settings["layout"]["node_spacing"] + level_spacing = self.settings["layout"]["level_spacing"] + orientation = self.settings["layout"]["orientation"] + barycenter_sweeps = self.settings["layout"]["barycenter_sweeps"] + # Step 1: Apply Reingold-Tilford layout or fallback to multipartite layout + if use_reingold_tilford: + pos = self._apply_reingold_tilford_layout(self.G, node_spacing) + else: + pos = nx.multipartite_layout( + self.G, subset_key="resolution", scale=int(node_spacing) + ) + + # Step 2: Adjust orientation + pos = self._adjust_orientation( + pos=cast("dict[str, tuple[float, float]]", pos), orientation=orientation + ) + + # Step 3: Increase vertical spacing + pos = self._adjust_vertical_spacing(pos, level_spacing) + + # Step 4: Barycenter-based reordering to minimize edge crossings + pos = self._barycenter_sweep( + self.G, pos, self.resolutions, node_spacing, barycenter_sweeps + ) + + # Step 5: Optimize node ordering + filtered_edges = [ + (u, v, d["weight"]) + for u, v, d in self.G.edges(data=True) + if d["weight"] >= 0.02 + ] + edges = [(u, v) for u, v, w in filtered_edges] + edges_set = set(edges) + if len(edges_set) < len(edges): + print( + f"Warning: Found {len(edges) - len(edges_set)} duplicate edges in the visualization." 
+ ) + edges = list(edges_set) + self._optimize_node_ordering(self.G, pos, edges, self.resolutions) + self.pos = pos + return self.pos + + def _apply_reingold_tilford_layout( + self, G: nx.DiGraph, node_spacing: float + ) -> dict[str, tuple[float, float]]: + """Apply Reingold-Tilford layout to the graph.""" + import networkx as nx + + try: + nodes = list(G.nodes) + edges = [(u, v) for u, v in G.edges()] + g = ig.Graph() + g.add_vertices(nodes) + g.add_edges([(nodes.index(u), nodes.index(v)) for u, v in edges]) + layout = g.layout_reingold_tilford(root=[0]) + return dict(zip(nodes, layout.coords, strict=False)) + except ImportError as e: + print( + f"igraph not installed or failed: {e}. Falling back to multipartite_layout." + ) + return dict( + nx.multipartite_layout( + G, subset_key="resolution", scale=int(node_spacing) + ) + ) + + def _adjust_orientation( + self, pos: dict[str, tuple[float, float]], orientation: str + ) -> dict[str, tuple[float, float]]: + """Adjust the node positions for the specified orientation.""" + if orientation == "vertical": + return {node: (y, -x) for node, (x, y) in pos.items()} + return pos + + def _adjust_vertical_spacing( + self, pos: dict[str, tuple[float, float]], level_spacing: float + ) -> dict[str, tuple[float, float]]: + """Increase vertical spacing between nodes at different levels.""" + new_pos = {} + for node, (x, y) in pos.items(): + new_y = y * level_spacing + new_pos[node] = (x, new_y) + return new_pos + + def _barycenter_sweep( + self, + G: nx.DiGraph, + pos: dict[str, tuple[float, float]], + resolutions: list, + node_spacing: float, + barycenter_sweeps: int, + ) -> dict[str, tuple[float, float]]: + """Perform barycenter-based reordering to minimize edge crossings.""" + for _sweep in range(barycenter_sweeps): + # Downward sweep: Adjust nodes based on parent positions + pos = self._downward_sweep(G, pos, resolutions, node_spacing) + # Upward sweep: Adjust nodes based on child positions + pos = self._upward_sweep(G, pos, resolutions, node_spacing) + self.pos = pos + return pos + + def _downward_sweep( + self, G: nx.DiGraph, pos: dict, resolutions: list, node_spacing: float + ) -> dict[str, tuple[float, float]]: + """Perform downward sweep in barycenter reordering.""" + for res in resolutions[1:]: + nodes_at_level = [node for node in G.nodes if node.startswith(f"{res}_C")] + node_to_barycenter = {} + for node in nodes_at_level: + parents = list(G.predecessors(node)) + barycenter = ( + np.mean([pos[parent][0] for parent in parents]) if parents else 0 + ) + node_to_barycenter[node] = barycenter + sorted_nodes = sorted( + node_to_barycenter.keys(), key=lambda x: node_to_barycenter[x] + ) + y_level = pos[sorted_nodes[0]][1] + n_nodes = len(sorted_nodes) + x_positions = ( + np.linspace( + -node_spacing * (n_nodes - 1) / 2, + node_spacing * (n_nodes - 1) / 2, + n_nodes, + ) + if n_nodes > 1 + else [0] + ) + for node, x in zip(sorted_nodes, x_positions, strict=True): + pos[node] = (x, y_level) + return pos + + def _upward_sweep( + self, + G: nx.DiGraph, + pos: dict[str, tuple[float, float]], + resolutions: list, + node_spacing: float, + ) -> dict[str, tuple[float, float]]: + """Perform upward sweep in barycenter reordering.""" + for res in reversed(resolutions[:-1]): + nodes_at_level = [node for node in G.nodes if node.startswith(f"{res}_C")] + node_to_barycenter = {} + for node in nodes_at_level: + children = list(G.successors(node)) + barycenter = ( + np.mean([pos[child][0] for child in children]) if children else 0 + ) + node_to_barycenter[node] = 
barycenter + sorted_nodes = sorted( + node_to_barycenter.keys(), key=lambda x: node_to_barycenter[x] + ) + y_level = pos[sorted_nodes[0]][1] + n_nodes = len(sorted_nodes) + x_positions = ( + np.linspace( + -node_spacing * (n_nodes - 1) / 2, + node_spacing * (n_nodes - 1) / 2, + n_nodes, + ) + if n_nodes > 1 + else [0] + ) + for node, x in zip(sorted_nodes, x_positions, strict=True): + pos[node] = (x, y_level) + return pos + + def _optimize_node_ordering( + self, + G: nx.DiGraph, + pos: dict[str, tuple[float, float]], + edges: list[tuple[str, str]], + resolutions: list, + max_iterations=10, + ) -> None: + """Optimize node ordering at each level to minimize edge crossings by swapping adjacent nodes.""" + # Group nodes by resolution level + level_nodes = { + res_idx: [ + node for node in G.nodes if G.nodes[node]["resolution"] == res_idx + ] + for res_idx in range(len(resolutions)) + } + + for res_idx in range(len(resolutions)): + nodes = level_nodes[res_idx] + if len(nodes) < 2: + continue + + # Sort nodes by their x-coordinate to establish an initial order + nodes.sort(key=lambda node: pos[node][0]) + + iteration = 0 + improved = True + while improved and iteration < max_iterations: + improved = False + for i in range(len(nodes) - 1): + node1, node2 = nodes[i], nodes[i + 1] + x1, y1 = pos[node1] + x2, y2 = pos[node2] + + # Compute current number of crossings + current_crossings = self._count_crossings(G, pos, edges) + + # Swap positions and compute new crossings + pos[node1] = (x2, y1) + pos[node2] = (x1, y2) + new_crossings = self._count_crossings(G, pos, edges) + + # If swapping reduces crossings, keep the swap + if new_crossings < current_crossings: + nodes[i], nodes[i + 1] = nodes[i + 1], nodes[i] + improved = True + else: + # Revert the swap if it doesn't improve crossings + pos[node1] = (x1, y1) + pos[node2] = (x2, y2) + + iteration += 1 + + def _count_crossings( + self, + G: nx.DiGraph, + pos: dict[str, tuple[float, float]], + edges: list[tuple[str, str]], + ) -> int: + """Count the number of edge crossings in the graph based on node positions.""" + crossings = 0 + for i, (u1, v1) in enumerate(edges): + for _j, (u2, v2) in enumerate(edges[i + 1 :], start=i + 1): + # Skip edges at the same level to avoid counting self-crossings + level_u1 = G.nodes[u1]["resolution"] + level_v1 = G.nodes[v1]["resolution"] + level_u2 = G.nodes[u2]["resolution"] + level_v2 = G.nodes[v2]["resolution"] + if level_u1 == level_u2 and level_v1 == level_v2: + continue + + # Get coordinates of the edge endpoints + x1_start, y1_start = pos[u1] + x1_end, y1_end = pos[v1] + x2_start, y2_start = pos[u2] + x2_end, y2_end = pos[v2] + + # Compute the direction vectors of the edges + dx1 = x1_end - x1_start + dy1 = y1_end - y1_start + dx2 = x2_end - x2_start + dy2 = y2_end - y2_start + + # Compute the denominator for the line intersection formula + denom = dx1 * dy2 - dy1 * dx2 + if abs(denom) < 1e-8: # Adjusted threshold for numerical stability + continue + + # Compute intersection parameters s and t + s = ((x2_start - x1_start) * dy2 - (y2_start - y1_start) * dx2) / denom + t = ((x2_start - x1_start) * dy1 - (y2_start - y1_start) * dx1) / denom + + # Check if the intersection occurs within both edge segments + if 0 < s < 1 and 0 < t < 1: + crossings += 1 + + return crossings + + def draw_cluster_tree(self) -> None: + """Draw a hierarchical cluster tree with nodes, edges, and labels.""" + if self.G is None or self.pos is None: + msg = "Graph or positions not initialized. 
Call build_graph() and compute_cluster_layout() first." + raise ValueError(msg) + if "cluster_resolution_cluster_data" not in self.adata.uns: + msg = "adata.uns['cluster_resolution_cluster_data'] not found." + raise ValueError(msg) + + import networkx as nx + + # Retrieve settings + settings = self._get_draw_settings() + data = settings["data"] + prefix = settings["prefix"] + + # Step 1: Compute Cluster Sizes, Node Sizes, and Node Colors + cluster_sizes = self._compute_cluster_sizes(data, prefix, self.resolutions) + node_sizes = self._scale_node_sizes( + data, prefix, self.resolutions, cluster_sizes, settings["node_size"] + ) + color_schemes = self._generate_node_color_schemes( + data, + prefix, + self.resolutions, + settings["node_color"], + settings["node_colormap"], + ) + node_colors = self._assign_node_colors( + data, prefix, self.resolutions, settings["node_color"], color_schemes + ) + # Step 2: Set up the plot figure and axis + self.fig = plt.figure(figsize=settings["figsize"], dpi=settings["dpi"]) + self.ax = self.fig.add_subplot(111) + # Step 3: Compute Edge Weights, Edge Colors + edges, weights, edge_colors = self._compute_edge_weights_colors( + self.G, settings["edge_threshold"], settings["edge_color"], node_colors + ) + # Step 4: Draw Nodes and Node Labels + node_styles = {"colors": node_colors, "sizes": node_sizes} + node_labels, gene_labels = self._draw_nodes_and_labels( + self.G, + self.pos, + self.resolutions, + node_styles=node_styles, + data=data, + prefix=prefix, + top_genes_dict=self.adata.uns.get("cluster_resolution_top_genes", {}), + show_gene_labels=settings["show_gene_labels"], + n_top_genes=settings["n_top_genes"], + gene_label_threshold=settings["gene_label_threshold"], + ) + nx.draw_networkx_labels( + self.G, + self.pos, + labels=node_labels, + font_size=int(settings["node_label_fontsize"]), + font_color="black", + ) + # Step 5: Draw Gene Labels + gene_label_bottoms = {} + if settings["show_gene_labels"] and gene_labels: + gene_label_bottoms = self._draw_gene_labels( + self.ax, + self.pos, + gene_labels, + node_sizes=node_sizes, + node_colors=node_colors, + offset=settings["gene_label_offset"], + fontsize=settings["gene_label_fontsize"], + ) + # Step 6: Build and Draw Edge Labels + edge_labels = self._build_edge_labels( + self.G, settings["edge_threshold"], settings["edge_label_threshold"] + ) + edge_label_style = { + "position": settings["edge_label_position"], + "fontsize": settings["edge_label_fontsize"], + } + self._draw_edges_with_labels( + self.ax, + self.pos, + edges, + weights, + edge_colors=edge_colors, + node_sizes=node_sizes, + gene_label_bottoms=gene_label_bottoms, + show_gene_labels=settings["show_gene_labels"], + edge_labels=edge_labels, + edge_label_style=edge_label_style, + ) + # Step 7: Draw Level Labels + self._draw_level_labels( + resolutions=self.resolutions, + pos=self.pos, + data=self.adata.uns["cluster_resolution_cluster_data"], + prefix=prefix, + level_label_offset=settings["level_label_offset"], + level_label_fontsize=settings["level_label_fontsize"], + ) + # Step 8: Final Plot Settings + self.ax.set_title(settings["title"], fontsize=settings["title_fontsize"]) + self.ax.axis("off") + # Save or show the plot + if settings["output_path"]: + plt.savefig(settings["output_path"], bbox_inches="tight") + if settings["draw"]: + plt.show() + + def _get_draw_settings(self) -> dict: + """Retrieve settings for drawing the cluster tree.""" + data = self.adata.uns["cluster_resolution_cluster_data"] + return { + "data": data, + "prefix": 
self.settings["clustering"]["prefix"], + "node_size": self.settings["node"]["node_size"], + "node_color": self.settings["node"]["node_color"], + "node_colormap": self.settings["node"]["node_colormap"], + "figsize": self.settings["output"]["figsize"], + "dpi": self.settings["output"]["dpi"], + "edge_threshold": self.settings["edge"]["edge_threshold"], + "edge_color": self.settings["edge"]["edge_color"], + "show_gene_labels": self.settings["gene_label"]["show_gene_labels"], + "n_top_genes": self.settings["gene_label"]["n_top_genes"], + "gene_label_threshold": self.settings["gene_label"]["gene_label_threshold"], + "node_label_fontsize": self.settings["node"]["node_label_fontsize"], + "gene_label_offset": self.settings["gene_label"]["gene_label_style"][ + "offset" + ], + "gene_label_fontsize": self.settings["gene_label"]["gene_label_style"][ + "fontsize" + ], + "edge_label_threshold": self.settings["edge"]["edge_label_threshold"], + "edge_label_position": self.settings["edge"]["edge_label_position"], + "edge_label_fontsize": self.settings["edge"]["edge_label_fontsize"], + "level_label_offset": self.settings["level_label"]["level_label_offset"], + "level_label_fontsize": self.settings["level_label"][ + "level_label_fontsize" + ], + "title": self.settings["title"]["title"], + "title_fontsize": self.settings["title"]["title_fontsize"], + "output_path": self.settings["output"]["output_path"], + "draw": self.settings["output"]["draw"], + } + + def _compute_cluster_sizes( + self, data: pd.DataFrame, prefix: str, resolutions: list + ) -> dict[str, int]: + """Compute cluster sizes for each node.""" + cluster_sizes = {} + for res in resolutions: + res_key = f"{prefix}{res}" + counts = data[res_key].value_counts() + for cluster, count in counts.items(): + node = f"{res}_C{cluster}" + cluster_sizes[node] = count + return cluster_sizes + + def _scale_node_sizes( + self, + data: pd.DataFrame, + prefix: str, + resolutions: list, + cluster_sizes: dict[str, int], + node_size: float, + ) -> dict[str, float]: + """Scale node sizes based on cluster sizes and node_size setting.""" + node_sizes = {} + for res in resolutions: + nodes_at_level = [ + f"{res}_C{cluster}" for cluster in data[f"{prefix}{res}"].unique() + ] + sizes = np.array([cluster_sizes[node] for node in nodes_at_level]) + if len(sizes) > 1: + min_size, max_size = sizes.min(), sizes.max() + if min_size != max_size: + normalized_sizes = 0.5 + (sizes - min_size) / (max_size - min_size) + else: + normalized_sizes = np.ones_like(sizes) * 0.5 + scaled_sizes = normalized_sizes * node_size + else: + scaled_sizes = np.array([node_size]) + if len(nodes_at_level) != len(scaled_sizes): + msg = ( + f"Length mismatch at resolution {res}: " + f"{len(nodes_at_level)} nodes vs {len(scaled_sizes)} sizes" + ) + raise ValueError(msg) + node_sizes.update(dict(zip(nodes_at_level, scaled_sizes, strict=False))) + return node_sizes + + def _generate_node_color_schemes( + self, + data: pd.DataFrame, + prefix: str, + resolutions: list, + node_color: str | None, + node_colormap: list[str] | None, + ) -> list[str] | dict[str, list] | None: + """Generate color schemes for nodes.""" + import seaborn as sns + + if node_color != "prefix": + return None + + if node_colormap is None: + return { + r: sns.color_palette("Set3", n_colors=data[f"{prefix}{r}"].nunique()) + for r in resolutions + } + + if len(node_colormap) < len(resolutions): + node_colormap = list(node_colormap) + [ + node_colormap[i % len(node_colormap)] + for i in range(len(resolutions) - len(node_colormap)) + ] + + 
color_schemes = {}
+ for i, r in enumerate(resolutions):
+ color_spec = node_colormap[i]
+ if (isinstance(color_spec, str) and mcolors.is_color_like(color_spec)) or (
+ isinstance(color_spec, tuple)
+ and len(color_spec) in (3, 4)
+ and all(isinstance(x, int | float) for x in color_spec)
+ ):
+ color_schemes[r] = [color_spec]
+ else:
+ try:
+ color_schemes[r] = sns.color_palette(
+ color_spec, n_colors=data[f"{prefix}{r}"].nunique()
+ )
+ except ValueError:
+ print(
+ f"Warning: '{color_spec}' is not valid for {r}. Using 'Set3'."
+ )
+ color_schemes[r] = sns.color_palette(
+ "Set3", n_colors=data[f"{prefix}{r}"].nunique()
+ )
+ return color_schemes
+
+ def _assign_node_colors(
+ self,
+ data: pd.DataFrame,
+ prefix: str,
+ resolutions: list,
+ node_color: str,
+ color_schemes: list[str] | dict[str, list] | None,
+ ) -> dict[str, str]:
+ node_colors = {}
+ for res in resolutions:
+ clusters = data[f"{prefix}{res}"].unique()
+ for cluster in clusters:
+ node = f"{res}_C{cluster}"
+ if node_color == "prefix":
+ if color_schemes is None:
+ msg = "color_schemes is None but node_color='prefix'"
+ raise RuntimeError(msg)
+ colors = color_schemes[res]
+ node_colors[node] = (
+ colors[0]
+ if len(colors) == 1
+ else colors[int(cluster) % len(colors)]
+ )
+ else:
+ node_colors[node] = node_color
+ return node_colors
+
+ def _compute_edge_weights_colors(
+ self,
+ G: nx.DiGraph,
+ edge_threshold: float,
+ edge_color: str,
+ node_colors: dict,
+ ) -> tuple[list, list, list]:
+ """Compute edge weights and colors based on the graph and edge_threshold."""
+ edges = [
+ (u, v) for u, v, d in G.edges(data=True) if d["weight"] >= edge_threshold
+ ]
+ weights = [
+ max(d["weight"] * 5, 1.0)
+ for u, v, d in G.edges(data=True)
+ if d["weight"] >= edge_threshold
+ ]
+ edge_colors = []
+ for u, v in edges:
+ d = G[u][v]
+ if edge_color == "parent":
+ edge_colors.append(node_colors[u])
+ elif edge_color == "samples":
+ # plt.cm.get_cmap was removed in newer Matplotlib releases; use plt.get_cmap instead
+ edge_colors.append(plt.get_cmap("viridis")(d["weight"] / 5))
+ else:
+ edge_colors.append(edge_color)
+ return edges, weights, edge_colors
+
+ def _draw_nodes_and_labels(
+ self,
+ G: nx.DiGraph,
+ pos: dict[str, tuple[float, float]],
+ resolutions: list,
+ *,
+ node_styles: dict,
+ data: pd.DataFrame,
+ prefix: str,
+ top_genes_dict: dict[tuple[str, str], list[str]],
+ show_gene_labels: bool,
+ n_top_genes: int,
+ gene_label_threshold: float,
+ ) -> tuple[dict, dict]:
+ """Draw the nodes and their labels."""
+ import networkx as nx
+
+ node_colors = node_styles["colors"]
+ node_sizes = node_styles["sizes"]
+ node_labels = {}
+ gene_labels = {}
+ for res in resolutions:
+ clusters = data[f"{prefix}{res}"].unique()
+ for cluster in clusters:
+ node = f"{res}_C{cluster}"
+ color = node_colors[node]
+ size = node_sizes[node]
+ nx.draw_networkx_nodes(
+ G,
+ pos,
+ nodelist=[node],
+ node_size=size,
+ node_color=color,
+ edgecolors="none",
+ )
+ node_labels[node] = str(cluster)
+ if show_gene_labels and top_genes_dict:
+ res_idx = resolutions.index(float(res))
+ if res_idx == 0:
+ continue  # No parent level for the top resolution
+ parent_res = resolutions[res_idx - 1]
+ parent_clusters = data[f"{prefix}{parent_res}"].unique()
+ for parent_cluster in parent_clusters:
+ parent_node = f"{parent_res}_C{parent_cluster}"
+ try:
+ edge_weight = G[parent_node][node]["weight"]
+ except KeyError:
+ continue
+ if edge_weight >= gene_label_threshold:
+ key = (f"res_{parent_node}", f"res_{node}")
+ if key in top_genes_dict:
+ genes = top_genes_dict[key][:n_top_genes]
+ gene_labels[node] = "\n".join(genes) if
genes else "" + return node_labels, gene_labels + + def _draw_gene_labels( + self, + ax, + pos: dict[str, tuple[float, float]], + gene_labels: dict[str, str], + *, + node_sizes: dict[str, float], + node_colors: dict[str, str], + offset: float = 0.2, + fontsize: float = 8, + ) -> dict[str, float]: + """Draw gene labels in boxes below nodes with matching boundary colors.""" + gene_label_bottoms = {} + for node, label in gene_labels.items(): + if label: + x, y = pos[node] + # Compute the node radius in data coordinates + radius = math.sqrt(node_sizes[node] / math.pi) + _fig_width, fig_height = ax.figure.get_size_inches() + radius_fig = radius / (72 * fig_height) + # xlim = ax.get_xlim() + ylim = ax.get_ylim() + data_height = ylim[0] - ylim[1] + radius_data = radius_fig * data_height + + # Position the top of the label box just below the node + box_top_y = y - radius_data - offset + + # Compute the height of the label box based on the number of lines + num_lines = label.count("\n") + 1 + line_height = 0.03 # Reduced line height for better scaling + label_height = num_lines * line_height + 0.04 # Reduced padding + box_center_y = box_top_y - label_height / 2 + + # Draw the label + ax.text( + x, + box_center_y, + label, + fontsize=fontsize, + ha="center", + va="center", + color="black", + bbox=dict( + facecolor="white", + edgecolor=node_colors[node], + boxstyle="round,pad=0.2", # Reduced padding for the box + ), + ) + gene_label_bottoms[node] = box_top_y - label_height + return gene_label_bottoms + + def _build_edge_labels( + self, G: nx.DiGraph, edge_threshold: float, edge_label_threshold: float + ) -> dict: + """Build the edge labels to display on the plot.""" + edge_labels = { + (u, v): f"{w:.2f}" + for u, v, w in [ + (u, v, d["weight"]) + for u, v, d in G.edges(data=True) + if d["weight"] >= edge_threshold + ] + if w >= edge_label_threshold + } + return edge_labels + + def _draw_edges_with_labels( + self, + ax, + pos: dict[str, tuple[float, float]], + edges: list, + weights: list, + *, + edge_colors: list, + node_sizes: dict, + gene_label_bottoms: dict, + show_gene_labels: bool, + edge_labels: dict, + edge_label_style: dict, + ) -> None: + """Draw edges with labels using Bezier curves.""" + edge_label_position = edge_label_style["position"] + edge_label_fontsize = edge_label_style["fontsize"] + for (u, v), w, e_color in zip(edges, weights, edge_colors, strict=False): + x1, y1 = pos[u] + x2, y2 = pos[v] + radius_parent = math.sqrt(node_sizes[u] / math.pi) + radius_child = math.sqrt(node_sizes[v] / math.pi) + _fig_width, fig_height = ax.figure.get_size_inches() + radius_parent_fig = radius_parent / (72 * fig_height) + radius_child_fig = radius_child / (72 * fig_height) + ylim = ax.get_ylim() + data_height = ylim[0] - ylim[1] + radius_parent_data = radius_parent_fig * data_height + radius_child_data = radius_child_fig * data_height + start_y = ( + gene_label_bottoms[u] + if (show_gene_labels and u in gene_label_bottoms and edge_labels.get(u)) + else y1 - radius_parent_data + ) + start_x = x1 + end_x, end_y = x2, y2 - radius_child_data + + p0, p1, p2, p3 = self._draw_curved_edge( + ax, + start_x, + start_y, + end_x, + end_y, + linewidth=w, + color=e_color, + edge_curvature=0.01, + ) + + if (u, v) in edge_labels and p0 is not None: + t = edge_label_position + point = self._evaluate_bezier(t, p0, p1, p2, p3) + label_x, label_y = point[0], point[1] + tangent = self._evaluate_bezier_tangent(t, p0, p1, p2, p3) + tangent_angle = np.arctan2(tangent[1], tangent[0]) + rotation = np.degrees(tangent_angle) 
+ if rotation > 90: + rotation -= 180 + elif rotation < -90: + rotation += 180 + ax.text( + label_x, + label_y, + edge_labels[(u, v)], + fontsize=edge_label_fontsize, + rotation=rotation, + ha="center", + va="center", + bbox=None, + ) + + def _draw_curved_edge( + self, + ax, + start_x: float, + start_y: float, + end_x: float, + end_y: float, + *, + linewidth: float, + color: str, + edge_curvature: float = 0.1, + arrow_size: float = 12, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Draw a gentle S-shaped curved edge between two points with an arrowhead. Retun a tuple of Bézier control points (p0, p1, p2, p3) for label positioning.""" + # Define the start and end points + p0 = np.array([start_x, start_y]) + p3 = np.array([end_x, end_y]) + + # Calculate the vector from start to end + vec = p3 - p0 + length = np.sqrt(vec[0] ** 2 + vec[1] ** 2) + + if length == 0: + empty_array = np.array([[], []]) + return empty_array, empty_array, empty_array, empty_array + + # Unit vector along the edge + unit_vec = vec / length + + # Perpendicular vector for creating the S-shape + perp_vec = np.array([-unit_vec[1], unit_vec[0]]) + + # Define control points for a single cubic Bézier curve with an S-shape, Place control points at 1/3 and 2/3 along the edge, with small perpendicular offsets + offset = length * edge_curvature + p1 = ( + p0 + (p3 - p0) / 3 + perp_vec * offset + ) # First control point (bend outward) + p2 = ( + p0 + 2 * (p3 - p0) / 3 - perp_vec * offset + ) # Second control point (bend inward) + + # Define the path vertices and codes for a single cubic Bézier curve + vertices = [ + (start_x, start_y), # Start point + (p1[0], p1[1]), # First control point + (p2[0], p2[1]), # Second control point + (end_x, end_y), # End point + ] + codes = [ + Path.MOVETO, # Move to start + Path.CURVE4, # Cubic Bézier curve (needs 3 points: p0, p1, p2) + Path.CURVE4, # Continuation of the Bézier curve + Path.CURVE4, # End of the Bézier curve + ] + + # Create the path + path = Path(vertices, codes) + + # Draw the curve + patch = PathPatch( + path, facecolor="none", edgecolor=color, linewidth=linewidth, alpha=0.8 + ) + ax.add_patch(patch) + + # Add an arrowhead at the end + arrow = FancyArrowPatch( + (end_x, end_y), + (end_x, end_y), + arrowstyle="->", + mutation_scale=arrow_size, + color=color, + linewidth=linewidth, + alpha=0.8, + ) + ax.add_patch(arrow) + + return p0, p1, p2, p3 + + def _evaluate_bezier( + self, t: float, p0: np.ndarray, p1: np.ndarray, p2: np.ndarray, p3: np.ndarray + ) -> np.ndarray: + """Evaluate a cubic Bezier curve at parameter t.""" + if not 0 <= t <= 1: + msg = "Parameter t must be in the range [0, 1]" + raise ValueError(msg) + + t2 = t * t + t3 = t2 * t + mt = 1 - t + mt2 = mt * mt + mt3 = mt2 * mt + return mt3 * p0 + 3 * mt2 * t * p1 + 3 * mt * t2 * p2 + t3 * p3 + + def _evaluate_bezier_tangent( + self, t: float, p0: np.ndarray, p1: np.ndarray, p2: np.ndarray, p3: np.ndarray + ) -> np.ndarray: + """Compute the tangent vector of a cubic Bezier curve at parameter t.""" + if not 0 <= t <= 1: + msg = "Parameter t must be in the range [0, 1]" + raise ValueError(msg) + + t2 = t * t + mt = 1 - t + mt2 = mt * mt + return 3 * mt2 * (p1 - p0) + 6 * mt * t * (p2 - p1) + 3 * t2 * (p3 - p2) + + def _draw_level_labels( + self, + resolutions: list, + pos: dict[str, tuple[float, float]], + data: pd.DataFrame, + *, + prefix: str, + level_label_offset: float, + level_label_fontsize: float, + ) -> None: + """Draw level labels for each resolution in the plot.""" + level_positions = {} + 
for node, (_x, y) in pos.items():
+ res = node.split("_")[0]
+ level_positions[res] = y
+
+ cluster_counts = {}
+ for res in resolutions:
+ res_str = f"{res:.1f}"
+ col_name = f"{prefix}{res_str}"
+ if col_name not in data.columns:
+ msg = f"Column {col_name} not found in data. Ensure clustering results are present."
+ raise ValueError(msg)
+ num_clusters = len(data[col_name].dropna().unique())
+ cluster_counts[res_str] = num_clusters
+
+ min_x = min(p[0] for p in pos.values())
+ label_offset = min_x - level_label_offset
+ for res in resolutions:
+ res_str = f"{res:.1f}"
+ label_pos = level_positions[res_str]
+ num_clusters = cluster_counts[res_str]
+ label_text = f"Resolution {res_str}:\n {num_clusters} clusters"
+ plt.text(
+ label_offset,
+ label_pos,
+ label_text,
+ fontsize=level_label_fontsize,
+ verticalalignment="center",
+ bbox=dict(facecolor="white", edgecolor="black", alpha=0.7),
+ )
+
+ @staticmethod
+ def cluster_decision_tree(
+ adata: AnnData,
+ resolutions: list[float],
+ *,
+ output_settings: dict | OutputSettings | None = None,
+ node_style: dict | NodeStyle | None = None,
+ edge_style: dict | EdgeStyle | None = None,
+ gene_label_settings: dict | GeneLabelSettings | None = None,
+ level_label_style: dict | LevelLabelStyle | None = None,
+ title_style: dict | TitleStyle | None = None,
+ layout_settings: dict | LayoutSettings | None = None,
+ clustering_settings: dict | ClusteringSettings | None = None,
+ ) -> nx.DiGraph:
+ """Plot a hierarchical clustering decision tree based on multiple resolutions.
+
+ This static method visualizes clustering results computed beforehand with
+ :func:`scanpy.tl.find_cluster_resolution`: it constructs a decision tree representing
+ hierarchical relationships between clusters and draws it as a directed graph.
+ Nodes represent clusters at different resolutions, edges represent transitions between
+ clusters, and edge weights indicate the proportion of cells transitioning from a parent
+ to a child cluster.
+
+ Parameters
+ ----------
+ adata
+ Annotated data matrix with clustering results in adata.uns["cluster_resolution_cluster_data"].
+ resolutions
+ List of resolution values matching the precomputed clustering columns.
+ output_settings
+ Dictionary with output options (output_path, draw, figsize, dpi).
+ node_style
+ Dictionary with node appearance (node_size, node_color, node_colormap, node_label_fontsize).
+ edge_style
+ Dictionary with edge appearance (edge_color, edge_curvature, edge_threshold, etc.).
+ gene_label_settings
+ Dictionary with gene label options (show_gene_labels, n_top_genes, etc.).
+ level_label_style
+ Dictionary with level label options (level_label_offset, level_label_fontsize).
+ title_style
+ Dictionary with title options (title, title_fontsize).
+ layout_settings
+ Dictionary with layout options (orientation, node_spacing, level_spacing, etc.).
+ clustering_settings
+ Dictionary with clustering options (prefix, edge_threshold).
+
+ Returns
+ -------
+ Directed graph representing the hierarchical clustering.
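+
+ Examples
+ --------
+ A minimal example; :func:`scanpy.tl.find_cluster_resolution` must be run first
+ so the required clustering columns and ``adata.uns`` entries exist:
+
+ >>> import scanpy as sc
+ >>> adata = sc.datasets.pbmc68k_reduced()
+ >>> sc.pp.neighbors(adata)
+ >>> sc.tl.find_cluster_resolution(adata, resolutions=[0.0, 0.5])
+ >>> sc.pl.cluster_decision_tree(adata, resolutions=[0.0, 0.5])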
+ + """ + # Run all validations + ClusterTreePlotter._validate_parameters(output_settings, node_style, edge_style) + ClusterTreePlotter._validate_clustering_data( + adata, resolutions, clustering_settings + ) + ClusterTreePlotter._validate_gene_labels(adata, gene_label_settings) + + # Initialize ClusterTreePlotter + plotter = ClusterTreePlotter( + adata, + resolutions, + output_settings=cast("OutputSettings", output_settings), + node_style=cast("NodeStyle", node_style), + edge_style=cast("EdgeStyle", edge_style), + gene_label_settings=cast("GeneLabelSettings", gene_label_settings), + level_label_style=cast("LevelLabelStyle", level_label_style), + title_style=cast("TitleStyle", title_style), + layout_settings=cast("LayoutSettings", layout_settings), + clustering_settings=cast("ClusteringSettings", clustering_settings), + ) + # Build graph and compute layout + plotter.build_cluster_graph() + plotter.compute_cluster_layout() + + # Draw if requested + if (output_settings or {}).get("draw", True) or (output_settings or {}).get( + "output_path" + ): + plotter.draw_cluster_tree() + + if plotter.G is None: + msg = "Graph is not initialized. Ensure build_cluster_graph() has been called." + raise ValueError(msg) + return plotter.G + + @staticmethod + def _validate_parameters(output_settings, node_style, edge_style): + if output_settings: + figsize = output_settings.get("figsize") + if ( + not isinstance(figsize, tuple | list) + or len(figsize) != 2 + or any(dim <= 0 for dim in figsize) + ): + msg = "figsize must be a tuple of two positive numbers (width, height)." + raise ValueError(msg) + + dpi = output_settings.get("dpi", 0) + if not isinstance(dpi, int | float) or dpi <= 0: + msg = "dpi must be a positive number." + raise ValueError(msg) + + if output_settings.get("draw") not in [True, False, None]: + msg = "draw must be True, False, or None." + raise ValueError(msg) + + if node_style: + node_size_val = node_style.get("node_size") + if node_size_val is not None and node_size_val <= 0: + msg = "node_size must be a positive number." + raise ValueError(msg) + + if edge_style and ( + (edge_style.get("edge_threshold", 0)) < 0 + or edge_style.get("edge_label_threshold", 0) < 0 + ): + msg = "edge_threshold and edge_label_threshold must be non-negative." + raise ValueError(msg) + + @staticmethod + def _validate_clustering_data(adata, resolutions, clustering_settings): + if "cluster_resolution_cluster_data" not in adata.uns: + msg = "adata.uns['cluster_resolution_cluster_data'] not found. Run `sc.tl.cluster_resolution_finder` first." + raise ValueError(msg) + if not resolutions: + msg = "You must provide a list of resolutions." + raise ValueError(msg) + + prefix = (clustering_settings or {}).get("prefix", "leiden_res_") + cluster_columns = [f"{prefix}{res}" for res in resolutions] + data = adata.uns["cluster_resolution_cluster_data"] + missing = [col for col in cluster_columns if col not in data.columns] + if missing: + msg = f"Missing clustering columns: {missing}" + raise ValueError(msg) + + @staticmethod + def _validate_gene_labels(adata, gene_label_settings): + if ( + gene_label_settings + and gene_label_settings.get("show_gene_labels", False) + and "cluster_resolution_top_genes" not in adata.uns + ): + msg = "Gene labels requested but `adata.uns['cluster_resolution_top_genes']` not found. Run `sc.tl.cluster_resolution_finder` first." 
+ raise ValueError(msg) + + +cluster_decision_tree = ClusterTreePlotter.cluster_decision_tree diff --git a/src/scanpy/tools/__init__.py b/src/scanpy/tools/__init__.py index e426470b84..5118fb5f9f 100644 --- a/src/scanpy/tools/__init__.py +++ b/src/scanpy/tools/__init__.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING +from ._cluster_resolution import find_cluster_resolution from ._dendrogram import dendrogram from ._diffmap import diffmap from ._dpt import dpt @@ -47,6 +48,7 @@ def __getattr__(name: str) -> Any: "draw_graph", "embedding_density", "filter_rank_genes_groups", + "find_cluster_resolution", "ingest", "leiden", "louvain", diff --git a/src/scanpy/tools/_cluster_resolution.py b/src/scanpy/tools/_cluster_resolution.py new file mode 100644 index 0000000000..85cb26464e --- /dev/null +++ b/src/scanpy/tools/_cluster_resolution.py @@ -0,0 +1,347 @@ +from __future__ import annotations + +import sys +from typing import TYPE_CHECKING + +import pandas as pd + +if TYPE_CHECKING: + from collections.abc import Sequence + from typing import Literal + + from anndata import AnnData + + +def find_cluster_specific_genes( + adata: AnnData, + resolutions: Sequence[float], + *, + prefix: str = "leiden_res_", + method: Literal["wilcoxon"] = "wilcoxon", + n_top_genes: int = 3, + min_cells: int = 2, + deg_mode: Literal["within_parent", "per_resolution"] = "within_parent", + verbose: bool = False, +) -> dict[tuple[str, str], list[str]]: + """Find differentially expressed genes for clusters in two modes.""" + from . import rank_genes_groups + + if deg_mode not in ["within_parent", "per_resolution"]: + msg = "deg_mode must be 'within_parent' or 'per_resolution'" + raise ValueError(msg) + + # Validate resolutions and clustering columns + for res in resolutions: + col = f"{prefix}{res}" + if col not in adata.obs: + msg = f"Column {col} not found in adata.obs" + raise ValueError(msg) + + top_genes_dict: dict[tuple[str, str], list[str]] = {} + + if deg_mode == "within_parent": + top_genes_dict.update( + find_within_parent_degs( + adata, + resolutions, + prefix=prefix, + n_top_genes=n_top_genes, + min_cells=min_cells, + rank_genes_groups=rank_genes_groups, + verbose=verbose, + ) + ) + elif deg_mode == "per_resolution": + top_genes_dict.update( + find_per_resolution_degs( + adata, + resolutions, + prefix=prefix, + n_top_genes=n_top_genes, + min_cells=min_cells, + rank_genes_groups=rank_genes_groups, + verbose=verbose, + ) + ) + + return top_genes_dict + + +def find_within_parent_degs( + adata: AnnData, + resolutions: Sequence[float], + *, + prefix: str, + n_top_genes: int, + min_cells: int, + rank_genes_groups, + verbose: bool = False, +) -> dict[tuple[str, str], list[str]]: + top_genes_dict = {} + + for i, res in enumerate(resolutions[:-1]): + res_key = f"{prefix}{res}" + next_res_key = f"{prefix}{resolutions[i + 1]}" + clusters = adata.obs[res_key].cat.categories + + for cluster in clusters: + cluster_mask = adata.obs[res_key] == cluster + cluster_adata = adata[cluster_mask, :] + + subclusters = cluster_adata.obs[next_res_key].value_counts() + valid_subclusters = subclusters[subclusters >= min_cells].index + + if len(valid_subclusters) < 2: + if verbose: + print( + f"Skipping res_{res}_C{cluster}: < 2 subclusters with >= {min_cells} cells." 
+ ) + continue + + subcluster_mask = cluster_adata.obs[next_res_key].isin(valid_subclusters) + deg_adata = cluster_adata[subcluster_mask, :] + + try: + rank_genes_groups(deg_adata, groupby=next_res_key, method="wilcoxon") + for subcluster in valid_subclusters: + names = deg_adata.uns["rank_genes_groups"]["names"][subcluster] + scores = deg_adata.uns["rank_genes_groups"]["scores"][subcluster] + top_genes = [ + name + for name, score in zip(names, scores, strict=False) + if score > 0 + ][:n_top_genes] + parent_node = f"res_{res}_C{cluster}" + child_node = f"res_{resolutions[i + 1]}_C{subcluster}" + top_genes_dict[(parent_node, child_node)] = top_genes + if verbose: + print(f"{parent_node} -> {child_node}: {top_genes}") + except KeyError as e: + print(f"Key error when processing {parent_node} -> {child_node}: {e}") + continue + except TypeError as e: + print( + f"Type error with the data when processing {parent_node} -> {child_node}: {e}" + ) + continue + + return top_genes_dict + + +def find_per_resolution_degs( + adata: AnnData, + resolutions: Sequence[float], + *, + prefix: str, + n_top_genes: int, + min_cells: int, + rank_genes_groups, + verbose: bool = False, +) -> dict[tuple[str, str], list[str]]: + top_genes_dict = {} + + for i, res in enumerate(resolutions[1:], 1): + res_key = f"{prefix}{res}" + prev_res_key = f"{prefix}{resolutions[i - 1]}" + clusters = adata.obs[res_key].cat.categories + valid_clusters = [ + c for c in clusters if (adata.obs[res_key] == c).sum() >= min_cells + ] + + if not valid_clusters: + if verbose: + print( + f"Skipping resolution {res}: no clusters with >= {min_cells} cells." + ) + continue + + deg_adata = adata[adata.obs[res_key].isin(valid_clusters), :] + try: + rank_genes_groups( + deg_adata, groupby=res_key, method="wilcoxon", reference="rest" + ) + for cluster in valid_clusters: + names = deg_adata.uns["rank_genes_groups"]["names"][cluster] + scores = deg_adata.uns["rank_genes_groups"]["scores"][cluster] + top_genes = [ + name + for name, score in zip(names, scores, strict=False) + if score > 0 + ][:n_top_genes] + parent_cluster = adata.obs[deg_adata.obs[res_key] == cluster][ + prev_res_key + ].mode()[0] + parent_node = f"res_{resolutions[i - 1]}_C{parent_cluster}" + child_node = f"res_{res}_C{cluster}" + top_genes_dict[(parent_node, child_node)] = top_genes + if verbose: + print(f"{parent_node} -> {child_node}: {top_genes}") + except KeyError as e: + print(f"Key error when processing {parent_node} -> {child_node}: {e}") + continue + except TypeError as e: + print( + f"Type error with the data when processing {parent_node} -> {child_node}: {e}" + ) + continue + + return top_genes_dict + + +def find_cluster_resolution( + adata: AnnData, + resolutions: list[float], + *, + prefix: str = "leiden_res_", + method: Literal["wilcoxon"] = "wilcoxon", + n_top_genes: int = 3, + min_cells: int = 2, + deg_mode: Literal["within_parent", "per_resolution"] = "within_parent", + flavor: Literal["igraph"] = "igraph", + n_iterations: int = 2, + verbose: bool = False, +) -> None: + """ + Find clusters across multiple resolutions and identify cluster-specific genes. + + This function performs Leiden clustering at specified resolutions, identifies + differentially expressed genes (DEGs) for clusters, and stores the results in `adata`. + + Params + ------ + adata + The annotated data matrix. + resolutions + List of resolution values for Leiden clustering (e.g., [0.0, 0.2, 0.5]). + prefix + Prefix for clustering keys in `adata.obs` (e.g., "leiden_res_"). 
+ method + Method for differential expression analysis: only "wilcoxon" is supported. + n_top_genes + Number of top genes to identify per child cluster. + min_cells + Minimum number of cells required in a subcluster to include it. + deg_mode + Mode for DEG analysis: "within_parent" (compare child to parent cluster) or + "per_resolution" (compare within each resolution). + flavor + Flavor of Leiden clustering: only "igraph" is supported. + n_iterations + Number of iterations for Leiden clustering. + + Returns + ------- + None + + The following annotations are added to `adata`: + + leiden_res_{resolution} + Cluster assignments for each resolution in `adata.obs`. + cluster_resolution_top_genes + Dictionary mapping (parent_node, child_node) pairs to lists of top marker genes, + stored in `adata.uns`. + + Notes + ----- + This function requires the `igraph` library for Leiden clustering, which is included in the + `leiden` extra. Install it with: ``pip install scanpy[leiden]``. + + Requires `sc.pp.neighbors` to be run on `adata` beforehand. + + Examples + -------- + >>> import scanpy as sc + >>> adata = sc.datasets.pbmc68k_reduced() + >>> sc.pp.neighbors(adata) + >>> sc.tl.find_cluster_resolution(adata, resolutions=[0.0, 0.5]) + >>> sc.pl.cluster_decision_tree(adata, resolutions=[0.0, 0.5]) + """ + import io + + from . import leiden + + # Suppress prints if pytest is running + if "pytest" in sys.modules: + sys.stdout = io.StringIO() + + _validate_cluster_resolution_inputs(adata, resolutions, method, flavor) + + # Run Leiden clustering + for resolution in resolutions: + res_key = f"{prefix}{resolution}" + try: + leiden( + adata, + resolution=resolution, + flavor="igraph", + n_iterations=n_iterations, + key_added=res_key, + ) + if "pytest" not in sys.modules and not hasattr( + sys, "_called_from_test" + ): # Suppress print in tests + print(f"Completed Leiden clustering for resolution {resolution}") + except ValueError as e: + msg = f"Leiden clustering failed at resolution {resolution} due to invalid value: {e}" + raise RuntimeError(msg) from None + except TypeError as e: + msg = f"Leiden clustering failed at resolution {resolution} due to incorrect type: {e}" + raise RuntimeError(msg) from None + except RuntimeError as e: + msg = f"Leiden clustering failed at resolution {resolution}: {e}" + raise RuntimeError(msg) from None + + # Find cluster-specific genes + top_genes_dict = find_cluster_specific_genes( + adata=adata, + resolutions=resolutions, + prefix=prefix, + method=method, + n_top_genes=n_top_genes, + min_cells=min_cells, + deg_mode=deg_mode, + verbose=verbose, + ) + + # Create DataFrame for clusterDecisionTree + try: + cluster_data = pd.DataFrame( + {f"{prefix}{r}": adata.obs[f"{prefix}{r}"] for r in resolutions} + ) + except KeyError as e: + msg = f"Failed to create cluster_data DataFrame: missing column {e}" + raise RuntimeError(msg) from None + except ValueError as e: + msg = f"Failed to create cluster_data DataFrame due to invalid value: {e}" + raise RuntimeError(msg) from None + except TypeError as e: + msg = f"Failed to create cluster_data DataFrame due to incorrect type: {e}" + raise RuntimeError(msg) from None + + # Store the results in adata.uns + adata.uns["cluster_resolution_top_genes"] = top_genes_dict + adata.uns["cluster_resolution_cluster_data"] = cluster_data + + +def _validate_cluster_resolution_inputs( + adata: AnnData, + resolutions: Sequence[float], + method: str, + flavor: str, +) -> None: + """Validate inputs for the find_cluster_resolution function.""" + if not 
resolutions: + msg = "resolutions list cannot be empty" + raise ValueError(msg) + if not all(isinstance(r, int | float) and r >= 0 for r in resolutions): + msg = "All resolutions must be non-negative numbers" + raise ValueError(msg) + if method != "wilcoxon": + msg = "Only method='wilcoxon' is supported" + raise ValueError(msg) + if flavor != "igraph": + msg = "Only flavor='igraph' is supported" + raise ValueError(msg) + if "neighbors" not in adata.uns: + msg = "adata must have precomputed neighbors (run sc.pp.neighbors first)." + raise ValueError(msg) diff --git a/tests/_images/cluster_decision_tree_plot/expected.png b/tests/_images/cluster_decision_tree_plot/expected.png new file mode 100644 index 0000000000..04a51fce28 Binary files /dev/null and b/tests/_images/cluster_decision_tree_plot/expected.png differ diff --git a/tests/test_cluster_resolution.py b/tests/test_cluster_resolution.py new file mode 100644 index 0000000000..2dbea14d24 --- /dev/null +++ b/tests/test_cluster_resolution.py @@ -0,0 +1,159 @@ +# tests/test_cluster_resolution.py +from __future__ import annotations + +import re + +import pandas as pd +import pytest + +import scanpy as sc +from scanpy.tools._cluster_resolution import find_cluster_resolution +from testing.scanpy._helpers.data import pbmc68k_reduced + + +@pytest.fixture +def adata_for_test(): + """Fixture to provide a preprocessed AnnData object for testing.""" + import scanpy as sc + + adata = pbmc68k_reduced() + sc.pp.neighbors(adata) + return adata + + +# Test 1: Basic functionality +def test_cluster_resolution_finder_basic(adata_for_test): + """Test that cluster_resolution_finder runs without errors and modifies adata.""" + adata = adata_for_test.copy() # Create a copy to avoid modifying the fixture + resolutions = [0.1, 0.5] + result = find_cluster_resolution( + adata, + resolutions, + prefix="leiden_res_", + method="wilcoxon", + n_top_genes=2, + min_cells=2, + deg_mode="within_parent", + flavor="igraph", + n_iterations=2, + ) + + # Check that the function returns None + assert result is None + + # Check that clustering columns were added to adata.obs + for res in resolutions: + assert f"leiden_res_{res}" in adata.obs + + # Check that top_genes_dict was added to adata.uns + assert "cluster_resolution_top_genes" in adata.uns + top_genes_dict = adata.uns["cluster_resolution_top_genes"] + assert isinstance(top_genes_dict, dict) + assert len(top_genes_dict) > 0 + for (parent, child), genes in top_genes_dict.items(): + assert isinstance(parent, str) + assert isinstance(child, str) + assert isinstance(genes, list) + assert len(genes) <= 2 # n_top_genes=2 + + # Check that cluster_data was added to adata.uns + assert "cluster_resolution_cluster_data" in adata.uns + cluster_data = adata.uns["cluster_resolution_cluster_data"] + assert isinstance(cluster_data, pd.DataFrame) + for res in resolutions: + assert f"leiden_res_{res}" in cluster_data.columns + + +# Test 2: Conflicting arguments (invalid deg_mode) +def test_cluster_resolution_finder_invalid_deg_mode(adata_for_test): + """Test that an invalid deg_mode raises a ValueError.""" + adata = adata_for_test.copy() + with pytest.raises( + ValueError, match=r"deg_mode must be 'within_parent' or 'per_resolution'" + ): + find_cluster_resolution( + adata, + resolutions=[0.1], + deg_mode="invalid_mode", # type: ignore[arg-type] + ) + + +# Test 3: Input values that should cause an error (empty resolutions) +def test_cluster_resolution_finder_empty_resolutions(adata_for_test): + """Test that an empty resolutions list raises a 
+    adata = adata_for_test.copy()
+    with pytest.raises(ValueError, match=r"resolutions list cannot be empty"):
+        find_cluster_resolution(
+            adata,
+            resolutions=[],
+        )
+
+
+# Test 4: Input values that should cause an error (negative resolutions)
+def test_cluster_resolution_finder_negative_resolutions(adata_for_test):
+    """Test that negative resolutions raise a ValueError."""
+    adata = adata_for_test.copy()
+    with pytest.raises(
+        ValueError, match="All resolutions must be non-negative numbers"
+    ):
+        sc.tl.find_cluster_resolution(
+            adata,
+            resolutions=[0.1, -0.5],
+        )
+
+
+# Test 5: Input values that should cause an error (missing neighbors)
+def test_cluster_resolution_finder_missing_neighbors():
+    """Test that an adata object without neighbors raises a ValueError."""
+    adata = sc.datasets.pbmc68k_reduced()  # Create a fresh adata
+    # Remove neighbors if they exist
+    if "neighbors" in adata.uns:
+        del adata.uns["neighbors"]
+    # Also remove connectivities and distances to ensure leiden doesn't recompute
+    if "connectivities" in adata.obsp:
+        del adata.obsp["connectivities"]
+    if "distances" in adata.obsp:
+        del adata.obsp["distances"]
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "adata must have precomputed neighbors (run sc.pp.neighbors first)."
+        ),
+    ):
+        sc.tl.find_cluster_resolution(
+            adata,
+            resolutions=[0.1],
+        )
+
+
+# Test 6: Helpful error message (unsupported method)
+def test_cluster_resolution_finder_unsupported_method(adata_for_test):
+    """Test that an unsupported method raises a ValueError with a helpful message."""
+    adata = adata_for_test.copy()
+    with pytest.raises(ValueError, match="Only method='wilcoxon' is supported"):
+        find_cluster_resolution(
+            adata,
+            resolutions=[0.1],
+            method="t-test",  # type: ignore[arg-type]
+        )
+
+
+# Test 7: Bounds on returned values (n_top_genes)
+@pytest.mark.parametrize("n_top_genes", [1, 3])
+def test_cluster_resolution_finder_n_top_genes(adata_for_test, n_top_genes):
+    """Test that n_top_genes bounds the number of genes stored in adata.uns."""
+    adata = adata_for_test.copy()
+    resolutions = [0.1, 0.5]
+    result = sc.tl.find_cluster_resolution(
+        adata,
+        resolutions,
+        n_top_genes=n_top_genes,
+    )
+
+    # Check that the function returns None
+    assert result is None
+
+    # Check the number of genes in adata.uns["cluster_resolution_top_genes"]
+    top_genes_dict = adata.uns["cluster_resolution_top_genes"]
+    for genes in top_genes_dict.values():
+        assert len(genes) <= n_top_genes
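Editor's note (not part of the patch): the tests above pin down the expected call signature of the new tool and the `adata.uns` keys it writes. A minimal sketch of an end-to-end invocation, with argument names taken from the tests and the resolution values purely illustrative, would be:

# Sketch only: mirrors the calls made in test_cluster_resolution_finder_basic.
import scanpy as sc

adata = sc.datasets.pbmc68k_reduced()
sc.pp.neighbors(adata)  # precomputed neighbors are required, per the validation above
sc.tl.find_cluster_resolution(
    adata,
    resolutions=[0.1, 0.5],
    prefix="leiden_res_",
    method="wilcoxon",
    n_top_genes=2,
    min_cells=2,
    deg_mode="within_parent",
    flavor="igraph",
    n_iterations=2,
)
adata.obs["leiden_res_0.1"]                    # cluster labels per resolution
adata.uns["cluster_resolution_top_genes"]      # {(parent, child): [top genes]}
adata.uns["cluster_resolution_cluster_data"]   # DataFrame of cluster assignments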
diff --git a/tests/test_cluster_tree.py b/tests/test_cluster_tree.py
new file mode 100644
index 0000000000..36e433fea8
--- /dev/null
+++ b/tests/test_cluster_tree.py
@@ -0,0 +1,217 @@
+from __future__ import annotations
+
+import networkx as nx
+import pytest
+
+from scanpy.plotting._cluster_tree import cluster_decision_tree
+from scanpy.tools._cluster_resolution import find_cluster_resolution
+from testing.scanpy._helpers.data import pbmc68k_reduced
+from testing.scanpy._pytest.marks import needs
+
+pytestmark = [needs.leidenalg]
+
+
+@pytest.fixture
+def adata_for_test():
+    """Fixture to provide a preprocessed AnnData object for testing."""
+    import scanpy as sc
+
+    adata = pbmc68k_reduced()
+    sc.pp.neighbors(adata)
+    return adata
+
+
+@pytest.fixture
+def adata_with_clusters(adata_for_test):
+    """Fixture providing clustering data and top_genes_dict for cluster_decision_tree."""
+    adata = adata_for_test.copy()
+    resolutions = [0.0, 0.2, 0.5, 1.0, 1.5, 2.0]
+    find_cluster_resolution(
+        adata,
+        resolutions,
+        prefix="leiden_res_",
+        n_top_genes=2,
+        min_cells=2,
+        deg_mode="within_parent",
+        flavor="igraph",
+        n_iterations=2,
+    )
+    return adata, resolutions
+
+
+# Test 1: Basic functionality without gene labels
+def test_cluster_decision_tree_basic(adata_with_clusters):
+    """Test that cluster_decision_tree runs without errors and returns a graph."""
+    adata, resolutions = adata_with_clusters
+
+    G = cluster_decision_tree(
+        adata=adata,
+        resolutions=resolutions,
+    )
+
+    assert isinstance(G, nx.DiGraph)
+    assert len(G.nodes) > 0
+    assert len(G.edges) > 0
+
+    for node in G.nodes:
+        assert "resolution" in G.nodes[node]
+        assert "cluster" in G.nodes[node]
+
+
+# Test 2: Basic functionality with gene labels
+def test_cluster_decision_tree_with_gene_labels(adata_with_clusters):
+    """Test that cluster_decision_tree handles gene labels when show_gene_labels is True."""
+    adata, resolutions = adata_with_clusters
+
+    G = cluster_decision_tree(
+        adata=adata,
+        resolutions=resolutions,
+        gene_label_settings={
+            "show_gene_labels": True,
+            "n_top_genes": 2,
+        },
+    )
+
+    assert isinstance(G, nx.DiGraph)
+    assert len(G.nodes) > 0
+    assert len(G.edges) > 0
+
+
+# Test 3: Error condition (show_gene_labels=True but top_genes_dict missing in adata.uns)
+def test_cluster_decision_tree_missing_top_genes_dict(adata_with_clusters):
+    """Test that show_gene_labels=True raises an error if top_genes_dict is missing in adata.uns."""
+    adata, resolutions = adata_with_clusters
+
+    del adata.uns["cluster_resolution_top_genes"]
+
+    with pytest.raises(
+        ValueError,
+        match=r"Gene labels requested but `adata\.uns\['cluster_resolution_top_genes'\]` not found\. Run `sc\.tl\.cluster_resolution_finder` first\.",
+    ):
+        cluster_decision_tree(
+            adata=adata,
+            resolutions=resolutions,
+            gene_label_settings={"show_gene_labels": True},
+        )
+
+
+# Test 4: Conflicting arguments (negative node_size)
+def test_cluster_decision_tree_negative_node_size(adata_with_clusters):
+    """Test that a negative node_size raises a ValueError."""
+    adata, resolutions = adata_with_clusters
+
+    with pytest.raises(ValueError, match=r"node_size must be a positive number."):
+        cluster_decision_tree(
+            adata=adata, resolutions=resolutions, node_style={"node_size": -100}
+        )
+
+
+# Test 5: Error conditions (invalid figsize)
+def test_cluster_decision_tree_invalid_figsize(adata_with_clusters):
+    """Test that an invalid figsize raises a ValueError."""
+    adata, resolutions = adata_with_clusters
+
+    with pytest.raises(
+        ValueError,
+        match=r"figsize must be a tuple of two positive numbers \(width, height\)\.",
+    ):
+        cluster_decision_tree(
+            adata=adata,
+            resolutions=resolutions,
+            output_settings={"figsize": (0, 5)},  # Invalid width
+        )
+
+
+# Test 6: Helpful error message (missing cluster_data in adata.uns)
+def test_cluster_decision_tree_missing_cluster_data(adata_with_clusters):
+    """Test that a missing cluster_data in adata.uns raises a ValueError."""
+    adata, resolutions = adata_with_clusters
+
+    del adata.uns["cluster_resolution_cluster_data"]
+
+    with pytest.raises(
+        ValueError,
+        match=r"adata\.uns\['cluster_resolution_cluster_data'\] not found\. Run `sc\.tl\.cluster_resolution_finder` first\.",
+    ):
+        cluster_decision_tree(
+            adata=adata,
+            resolutions=resolutions,
+        )
+
+
+# Test 7: Orthogonal effects (draw argument)
+def test_cluster_decision_tree_draw_argument(adata_with_clusters):
+    """Test that the draw argument doesn't affect the graph output."""
+    adata, resolutions = adata_with_clusters
+
+    G_no_draw = cluster_decision_tree(
+        adata=adata,
+        resolutions=resolutions,
+        output_settings={"draw": False},
+    )
+
+    from unittest import mock
+
+    with mock.patch("matplotlib.pyplot.show"):
+        G_draw = cluster_decision_tree(
+            adata=adata,
+            resolutions=resolutions,
+            output_settings={"draw": True},
+        )
+
+    assert nx.is_isomorphic(G_no_draw, G_draw)
+    assert G_no_draw.nodes(data=True) == G_draw.nodes(data=True)
+
+    def make_edge_hashable(edges):
+        # Edge attribute dicts may contain lists (unhashable); convert them to
+        # tuples so the edge sets of both graphs can be compared.
+        return {
+            (
+                u,
+                v,
+                tuple(
+                    (k, tuple(v) if isinstance(v, list) else v)
+                    for k, v in sorted(d.items())
+                ),
+            )
+            for u, v, d in edges
+        }
+
+    assert make_edge_hashable(G_no_draw.edges(data=True)) == make_edge_hashable(
+        G_draw.edges(data=True)
+    )
+
+
+# Test 8: Equivalent inputs (node_colormap)
+@pytest.mark.parametrize(
+    "node_colormap",
+    [
+        None,
+        ["Set3", "Set3"],
+    ],
+)
+def test_cluster_decision_tree_node_colormap(adata_with_clusters, node_colormap):
+    """Test that node_colormap=None and a uniform colormap produce similar results."""
+    adata, resolutions = adata_with_clusters
+
+    G = cluster_decision_tree(
+        adata=adata,
+        resolutions=resolutions,
+        node_style={"node_colormap": node_colormap},
+    )
+    assert isinstance(G, nx.DiGraph)
+    assert len(G.nodes) > 0
+
+
+# Test 9: Bounds on gene labels (n_top_genes)
+@pytest.mark.parametrize("n_top_genes", [1, 3])
+def test_cluster_decision_tree_n_top_genes(adata_with_clusters, n_top_genes):
+    """Test that n_top_genes bounds the number of gene labels per node."""
+    adata, resolutions = adata_with_clusters
+
+    G = cluster_decision_tree(
+        adata=adata,
+        resolutions=resolutions,
+        gene_label_settings={"show_gene_labels": True, "n_top_genes": n_top_genes},
+    )
+
+    assert isinstance(G, nx.DiGraph)
+    assert len(G.nodes) > 0
+    assert len(G.edges) > 0
+
+    for node in G.nodes:
+        if "top_genes" in G.nodes[node]:
+            assert len(G.nodes[node]["top_genes"]) <= n_top_genes
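Editor's note (not part of the patch): a sketch of the downstream plotting call these tests exercise, assuming the function is exposed as `sc.pl.cluster_decision_tree`. The nested settings dicts and their keys come from the tests above; the concrete values (figure size in particular) are illustrative assumptions.

# Sketch only: plotting after sc.tl.find_cluster_resolution has populated adata.uns.
G = sc.pl.cluster_decision_tree(
    adata=adata,
    resolutions=[0.0, 0.2, 0.5, 1.0, 1.5, 2.0],
    gene_label_settings={"show_gene_labels": True, "n_top_genes": 2},
    node_style={"node_colormap": ["Set3", "Set3"]},  # as in the parametrized test
    output_settings={"figsize": (12, 8), "draw": True},  # figsize is an assumed value
)
# Returns a networkx.DiGraph; nodes carry "resolution" and "cluster" attributes,
# plus "top_genes" when gene labels are requested.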