diff --git a/docs/release-notes/3764.feature.md b/docs/release-notes/3764.feature.md new file mode 100644 index 0000000000..a9946fe192 --- /dev/null +++ b/docs/release-notes/3764.feature.md @@ -0,0 +1 @@ +{func}`scanpy.pl.dotplot` now supports a `group_colors` parameter for custom per-group coloring with perceptually uniform color gradients via OKLab interpolation. {smaller}`R Baber` diff --git a/pyproject.toml b/pyproject.toml index 5565870587..2a81086bd3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,11 +103,12 @@ test = [ # optional storage and processing modes "scanpy[dask]", "zarr>=2.18.7", - # additional tested algorithms + # additional tested features "scanpy[scrublet]", "scanpy[leiden]", "scanpy[skmisc]", "scanpy[dask-ml]", + "scanpy[plotting]", ] doc = [ "sphinx>=8.2.3", @@ -126,7 +127,7 @@ doc = [ "sphinxcontrib-bibtex", "sphinxcontrib-katex", # TODO: remove necessity for being able to import doc-linked classes - "scanpy[paga,dask-ml,leiden]", + "scanpy[paga,dask-ml,leiden,plotting]", "sam-algorithm", ] dev = [ @@ -145,6 +146,8 @@ skmisc = [ "scikit-misc>=0.5.1" ] # highly_variable_genes m harmony = [ "harmonypy" ] # Harmony dataset integration scanorama = [ "scanorama" ] # Scanorama dataset integration scrublet = [ "scikit-image>=0.23.1" ] # Doublet detection with automatic thresholds +# Plotting +plotting = [ "colour-science" ] # Acceleration rapids = [ "cudf>=0.9", "cuml>=0.9", "cugraph>=0.9" ] # GPU accelerated calculation of neighbors dask = [ "dask[array]>=2024.5.1", "anndata[dask]" ] # Use the Dask parallelization engine diff --git a/src/scanpy/plotting/_dotplot.py b/src/scanpy/plotting/_dotplot.py index c8bb7cb492..b903e781f3 100644 --- a/src/scanpy/plotting/_dotplot.py +++ b/src/scanpy/plotting/_dotplot.py @@ -6,12 +6,19 @@ from matplotlib import colormaps from .. import logging as logg -from .._compat import old_positionals +from .._compat import old_positionals, warn from .._settings import settings from .._utils import _doc_params, _empty from ._baseplot_class import BasePlot, doc_common_groupby_plot_args from ._docs import doc_common_plot_args, doc_show_save_ax, doc_vboundnorm -from ._utils import _dk, check_colornorm, fix_kwds, make_grid_spec, savefig_or_show +from ._utils import ( + _create_white_to_color_gradient, + _dk, + check_colornorm, + fix_kwds, + make_grid_spec, + savefig_or_show, +) if TYPE_CHECKING: from collections.abc import Mapping, Sequence @@ -156,6 +163,7 @@ def __init__( # noqa: PLR0913 vmax: float | None = None, vcenter: float | None = None, norm: Normalize | None = None, + group_colors: Mapping[str, ColorLike] | None = None, **kwds, ) -> None: BasePlot.__init__( @@ -182,13 +190,74 @@ def __init__( # noqa: PLR0913 **kwds, ) + # Set default style parameters + self.cmap = self.DEFAULT_COLORMAP + self.dot_max = self.DEFAULT_DOT_MAX + self.dot_min = self.DEFAULT_DOT_MIN + self.smallest_dot = self.DEFAULT_SMALLEST_DOT + self.largest_dot = self.DEFAULT_LARGEST_DOT + self.color_on = self.DEFAULT_COLOR_ON + self.size_exponent = self.DEFAULT_SIZE_EXPONENT + self.grid = False + self.plot_x_padding = self.DEFAULT_PLOT_X_PADDING + self.plot_y_padding = self.DEFAULT_PLOT_Y_PADDING + + self.dot_edge_color = self.DEFAULT_DOT_EDGECOLOR + self.dot_edge_lw = self.DEFAULT_DOT_EDGELW + + # set legend defaults + self.color_legend_title = self.DEFAULT_COLOR_LEGEND_TITLE + self.size_title = self.DEFAULT_SIZE_LEGEND_TITLE + self.legends_width = self.DEFAULT_LEGENDS_WIDTH + self.show_size_legend = True + self.show_colorbar = True + + # Store parameters needed by helper methods and prepare the dot data. + self.standard_scale = standard_scale + self.expression_cutoff = expression_cutoff + self.mean_only_expressed = mean_only_expressed + self.dot_color_df, self.dot_size_df = self._prepare_dot_data( + dot_color_df, dot_size_df + ) + self.group_cmaps = self._prepare_group_cmaps(group_colors) + + def _prepare_group_cmaps( + self, group_colors: Mapping[str, ColorLike] | None + ) -> dict[str, Colormap] | None: + if group_colors is None: + return None + group_cmaps = {} + missing_groups = [] + for group in self.dot_color_df.index: + if group in group_colors: + group_cmaps[group] = _create_white_to_color_gradient( + group_colors[group] + ) + else: + group_cmaps[group] = self.cmap + missing_groups.append(group) + if missing_groups: + warn( + f"The following groups will use the default colormap as no " + f"specific colors were assigned: {missing_groups}", + UserWarning, + ) + return group_cmaps + + def _prepare_dot_data( + self, dot_color_df: pd.DataFrame | None, dot_size_df: pd.DataFrame | None + ) -> tuple[pd.DataFrame, pd.DataFrame]: + """Calculate the dataframes for dot size and color. + + Refactored to helper to satisfy complexity checks. + """ # for if category defined by groupby (if any) compute for each var_name # 1. the fraction of cells in the category having a value >expression_cutoff # 2. the mean value over the category # 1. compute fraction of cells having value > expression_cutoff # transform obs_tidy into boolean matrix using the expression_cutoff - obs_bool = self.obs_tidy > expression_cutoff + obs_bool = self.obs_tidy > self.expression_cutoff # compute the sum per group which in the boolean matrix this is the number # of values >expression_cutoff, and divide the result by the total number of @@ -201,7 +270,7 @@ def __init__( # noqa: PLR0913 if dot_color_df is None: # 2. compute mean expression value value - if mean_only_expressed: + if self.mean_only_expressed: dot_color_df = ( self.obs_tidy .mask(~obs_bool) @@ -212,15 +281,15 @@ def __init__( # noqa: PLR0913 else: dot_color_df = self.obs_tidy.groupby(level=0, observed=True).mean() - if standard_scale == "group": + if self.standard_scale == "group": dot_color_df = dot_color_df.sub(dot_color_df.min(axis=1), axis=0) dot_color_df = dot_color_df.div( dot_color_df.max(axis=1), axis=0 ).fillna(0) - elif standard_scale == "var": + elif self.standard_scale == "var": dot_color_df -= dot_color_df.min(axis=0) dot_color_df = (dot_color_df / dot_color_df.max(axis=0)).fillna(0) - elif standard_scale is None: + elif self.standard_scale is None: pass else: logg.warning("Unknown type for standard_scale, ignored") @@ -252,35 +321,17 @@ def __init__( # noqa: PLR0913 # using the order from the doc_size_df dot_color_df = dot_color_df.loc[dot_size_df.index][dot_size_df.columns] - self.dot_color_df, self.dot_size_df = ( - df.loc[ - categories_order if categories_order is not None else self.categories - ] - for df in (dot_color_df, dot_size_df) + # Use self.categories_order if set, else self.categories + order = ( + self.categories_order + if self.categories_order is not None + else self.categories + ) + dot_color_df, dot_size_df = ( + df.loc[order] for df in (dot_color_df, dot_size_df) ) - self.standard_scale = standard_scale - - # Set default style parameters - self.cmap = self.DEFAULT_COLORMAP - self.dot_max = self.DEFAULT_DOT_MAX - self.dot_min = self.DEFAULT_DOT_MIN - self.smallest_dot = self.DEFAULT_SMALLEST_DOT - self.largest_dot = self.DEFAULT_LARGEST_DOT - self.color_on = self.DEFAULT_COLOR_ON - self.size_exponent = self.DEFAULT_SIZE_EXPONENT - self.grid = False - self.plot_x_padding = self.DEFAULT_PLOT_X_PADDING - self.plot_y_padding = self.DEFAULT_PLOT_Y_PADDING - - self.dot_edge_color = self.DEFAULT_DOT_EDGECOLOR - self.dot_edge_lw = self.DEFAULT_DOT_EDGELW - # set legend defaults - self.color_legend_title = self.DEFAULT_COLOR_LEGEND_TITLE - self.size_title = self.DEFAULT_SIZE_LEGEND_TITLE - self.legends_width = self.DEFAULT_LEGENDS_WIDTH - self.show_size_legend = True - self.show_colorbar = True + return dot_color_df, dot_size_df @old_positionals( "cmap", @@ -539,12 +590,27 @@ def _plot_legend(self, legend_ax, return_ax_dict, normalize): # third row: spacer to avoid color and size legend titles to overlap # fourth row: colorbar + # Define base heights for legend components as a fraction of figure height cbar_legend_height = self.min_figure_height * 0.08 size_legend_height = self.min_figure_height * 0.27 spacer_height = self.min_figure_height * 0.3 + # If group_colors is used, dynamically calculate the total height needed for all colorbars + if self.group_cmaps is not None: + # Use a slightly larger height for better spacing + per_cbar_height = self.min_figure_height * 0.12 + n_cbars = len(self.dot_color_df.index) + cbar_legend_height = per_cbar_height * n_cbars + + # Calculate the height of the top spacer to push content down + top_spacer_height = ( + self.height - size_legend_height - cbar_legend_height - spacer_height + ) + top_spacer_height = max(top_spacer_height, 0) # prevent negative height + + # Create the 4-row GridSpec for the legend area height_ratios = [ - self.height - size_legend_height - cbar_legend_height - spacer_height, + top_spacer_height, size_legend_height, spacer_height, cbar_legend_height, @@ -552,17 +618,79 @@ def _plot_legend(self, legend_ax, return_ax_dict, normalize): fig, legend_gs = make_grid_spec( legend_ax, nrows=4, ncols=1, height_ratios=height_ratios ) + # Hide the frame of the main legend container axis for a cleaner look + legend_ax.set_axis_off() + # Plot size legend into the second row of the grid if self.show_size_legend: size_legend_ax = fig.add_subplot(legend_gs[1]) self._plot_size_legend(size_legend_ax) return_ax_dict["size_legend_ax"] = size_legend_ax + # Plot colorbar(s) into the fourth row of the grid if self.show_colorbar: - color_legend_ax = fig.add_subplot(legend_gs[3]) + if self.group_cmaps is None: + color_legend_ax = fig.add_subplot(legend_gs[3]) + self._plot_colorbar(color_legend_ax, normalize) + return_ax_dict["color_legend_ax"] = color_legend_ax + else: + self._plot_stacked_colorbars(fig, legend_gs[3], normalize) + return_ax_dict["color_legend_ax"] = legend_ax + + def _plot_stacked_colorbars(self, fig, colorbar_area_spec, normalize): + """Plot the stacked colorbars legend when using group_colors.""" + import matplotlib as mpl + import matplotlib.colorbar + from matplotlib.cm import ScalarMappable + + plotted_groups = self.dot_color_df.index + groups_to_plot = list(plotted_groups) + n_cbars = len(groups_to_plot) + + # Create a sub-grid just for the colorbars + # Create an empty column to keep colorbars at 3/4 of legend width (1.5 like default with dp.legend_width = 2.0) + colorbar_gs = colorbar_area_spec.subgridspec( + n_cbars, 2, hspace=0.6, width_ratios=[3, 1] + ) + + # Create a dedicated normalizer for the legend + vmin = self.dot_color_df.values.min() + vmax = self.dot_color_df.values.max() + legend_norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax) + + for i, group_name in enumerate(groups_to_plot): + ax = fig.add_subplot( + colorbar_gs[i, 0] + ) # Place the colorbar Axes in the first, wider column + + # self.group_cmaps[group_name] is already a Colormap object (or string from fallback) + cmap = self.group_cmaps[group_name] + if isinstance(cmap, str): + cmap = colormaps.get_cmap(cmap) + + mappable = ScalarMappable(norm=legend_norm, cmap=cmap) + + cb = matplotlib.colorbar.Colorbar( + ax, mappable=mappable, orientation="horizontal" + ) + cb.ax.xaxis.set_tick_params(labelsize="small") + + ax.text( + 1.1, + 0.5, + group_name, + ha="left", + va="center", + transform=ax.transAxes, + fontsize="small", + ) + + if i == 0: + cb.ax.set_title(self.color_legend_title, fontsize="small") - self._plot_colorbar(color_legend_ax, normalize) - return_ax_dict["color_legend_ax"] = color_legend_ax + if i < n_cbars - 1: + cb.ax.xaxis.set_ticklabels([]) + cb.ax.xaxis.set_ticks([]) def _mainplot(self, ax: Axes): # work on a copy of the dataframes. This is to avoid changes @@ -589,6 +717,7 @@ def _mainplot(self, ax: Axes): _color_df, ax, cmap=self.cmap, + group_cmaps=self.group_cmaps, color_on=self.color_on, dot_max=self.dot_max, dot_min=self.dot_min, @@ -619,6 +748,7 @@ def _dotplot( # noqa: PLR0912, PLR0913, PLR0915 dot_ax: Axes, *, cmap: Colormap | str | None, + group_cmaps: Mapping[str, Colormap] | None, color_on: Literal["dot", "square"], dot_max: float | None, dot_min: float | None, @@ -741,7 +871,42 @@ def _dotplot( # noqa: PLR0912, PLR0913, PLR0915 size = size * (largest_dot - smallest_dot) + smallest_dot normalize = check_colornorm(vmin, vmax, vcenter, norm) - if color_on == "square": + if group_cmaps is not None: + # Plotting logic for group-specific colormaps + groups_iter = dot_color.columns if are_axes_swapped else dot_color.index + n_vars = dot_color.shape[0] if are_axes_swapped else dot_color.shape[1] + n_groups = len(groups_iter) + + # Here we loop through each group and plot it with its own cmap + for group_idx, group_name in enumerate(groups_iter): + group_cmap = group_cmaps[group_name] + # Handle fallback case where group_cmap might be a string + if isinstance(group_cmap, str): + group_cmap = colormaps.get_cmap(group_cmap) + + # Slice the flattened data arrays correctly depending on orientation + if not are_axes_swapped: + # Slicing data for a whole row + indices = slice(group_idx * n_vars, (group_idx + 1) * n_vars) + else: + # Slicing data for a whole column + indices = slice(group_idx, None, n_groups) + + x_group = x[indices] + y_group = y[indices] + size_group = size[indices] + mean_group = mean_flat[indices] + + color = group_cmap(normalize(mean_group)) + kwds_scatter = fix_kwds( + kwds, + s=size_group, + color=color, + linewidth=edge_lw, + edgecolor=edge_color, + ) + dot_ax.scatter(x_group, y_group, **kwds_scatter) + elif color_on == "square": if edge_color is None: from seaborn.utils import relative_luminance @@ -761,27 +926,28 @@ def _dotplot( # noqa: PLR0912, PLR0913, PLR0915 dot_ax.pcolor(dot_color.values, cmap=cmap, norm=normalize) for axis in ["top", "bottom", "left", "right"]: dot_ax.spines[axis].set_linewidth(1.5) - kwds = fix_kwds( + # Create a temporary kwargs dict for this group's scatter call + # to avoid modifying the original kwds dictionary within the loop. + kwds_scatter = fix_kwds( kwds, s=size, linewidth=edge_lw, facecolor="none", edgecolor=edge_color, ) - dot_ax.scatter(x, y, **kwds) + dot_ax.scatter(x, y, **kwds_scatter) else: edge_color = "none" if edge_color is None else edge_color edge_lw = 0.0 if edge_lw is None else edge_lw - color = cmap(normalize(mean_flat)) - kwds = fix_kwds( + kwds_scatter = fix_kwds( kwds, s=size, color=color, linewidth=edge_lw, edgecolor=edge_color, ) - dot_ax.scatter(x, y, **kwds) + dot_ax.scatter(x, y, **kwds_scatter) y_ticks = np.arange(dot_color.shape[0]) + 0.5 dot_ax.set_yticks(y_ticks) @@ -880,6 +1046,7 @@ def dotplot( # noqa: PLR0913 norm: Normalize | None = None, # Style parameters cmap: Colormap | str | None = DotPlot.DEFAULT_COLORMAP, + group_colors: Mapping[str, ColorLike] | None = None, dot_max: float | None = DotPlot.DEFAULT_DOT_MAX, dot_min: float | None = DotPlot.DEFAULT_DOT_MIN, smallest_dot: float = DotPlot.DEFAULT_SMALLEST_DOT, @@ -918,6 +1085,14 @@ def dotplot( # noqa: PLR0913 mean_only_expressed If True, gene expression is averaged only over the cells expressing the given genes. + group_colors + A mapping of group names to colors. + e.g. `{{'T-cell': 'blue', 'B-cell': '#aa40fc'}}`. + Colors can be specified as any valid matplotlib color. + If `group_colors` is used, a colormap is generated from white + to the given color for each group. + If a group is not present in the dictionary, the value of `cmap` + is used. dot_max If ``None``, the maximum dot size is set to the maximum fraction value found (e.g. 0.6). If given, the value should be a number between 0 and 1. @@ -948,7 +1123,7 @@ def dotplot( # noqa: PLR0913 Examples -------- Create a dot plot using the given markers and the PBMC example dataset grouped by - the category 'bulk_labels'. + the category `'bulk_labels'`. .. plot:: :context: close-figs @@ -958,15 +1133,17 @@ def dotplot( # noqa: PLR0913 markers = ['C1QA', 'PSAP', 'CD79A', 'CD79B', 'CST3', 'LYZ'] sc.pl.dotplot(adata, markers, groupby='bulk_labels', dendrogram=True) - Using var_names as dict: + Grouping `var_names` as well and specifying group colors for `groupby`: .. plot:: :context: close-figs + from matplotlib import cm markers = {{'T-cell': 'CD3D', 'B-cell': 'CD79A', 'myeloid': 'CST3'}} - sc.pl.dotplot(adata, markers, groupby='bulk_labels', dendrogram=True) + group_colors = dict(zip(adata.obs["bulk_labels"].cat.categories, cm.tab10.colors)) + sc.pl.dotplot(adata, markers, groupby='bulk_labels', group_colors=group_colors, dendrogram=True) - Get DotPlot object for fine tuning + Get `DotPlot` object for fine tuning .. plot:: :context: close-figs @@ -974,7 +1151,7 @@ def dotplot( # noqa: PLR0913 dp = sc.pl.dotplot(adata, markers, 'bulk_labels', return_fig=True) dp.add_totals().style(dot_edge_color='black', dot_edge_lw=0.5).show() - The axes used can be obtained using the get_axes() method + The axes used can be obtained using the `get_axes()` method .. code-block:: python @@ -986,6 +1163,14 @@ def dotplot( # noqa: PLR0913 # instead of `cmap` cmap = kwds.pop("color_map", cmap) + # Warn if both cmap and group_colors are specified + if group_colors is not None and cmap != DotPlot.DEFAULT_COLORMAP: + warn( + "Both `cmap` and `group_colors` are specified. " + "`group_colors` takes precedence for the specified groups.", + UserWarning, + ) + dp = DotPlot( adata, var_names, @@ -1005,6 +1190,7 @@ def dotplot( # noqa: PLR0913 var_group_rotation=var_group_rotation, layer=layer, dot_color_df=dot_color_df, + group_colors=group_colors, ax=ax, vmin=vmin, vmax=vmax, @@ -1024,7 +1210,9 @@ def dotplot( # noqa: PLR0913 dot_min=dot_min, smallest_dot=smallest_dot, dot_edge_lw=kwds.pop("linewidth", _empty), - ).legend(colorbar_title=colorbar_title, size_title=size_title) + ).legend( + colorbar_title=colorbar_title, size_title=size_title, width=2.0 + ) # Width 2.0 to avoid size legend circles to overlap if return_fig: return dp diff --git a/src/scanpy/plotting/_utils.py b/src/scanpy/plotting/_utils.py index f0e22d1ede..d648e3569b 100644 --- a/src/scanpy/plotting/_utils.py +++ b/src/scanpy/plotting/_utils.py @@ -25,7 +25,7 @@ from collections.abc import Collection from anndata import AnnData - from matplotlib.colors import Colormap + from matplotlib.colors import Colormap, ListedColormap from matplotlib.figure import Figure from matplotlib.typing import MarkerType from numpy.typing import ArrayLike @@ -41,6 +41,7 @@ "_FontSize", "_FontWeight", "_LegendLoc", + "_create_white_to_color_gradient", "_deprecated_scale", "_dk", "add_colors_for_categorical_sample_annotation", @@ -93,7 +94,7 @@ "upper center", "center", ] -type ColorLike = str | tuple[float, ...] +type ColorLike = str | tuple[float, float, float] | tuple[float, float, float, float] class _AxesSubplot(Axes, axes.SubplotBase): @@ -1105,3 +1106,64 @@ def _deprecated_scale( def _dk(dendrogram: bool | str | None) -> str | None: # noqa: FBT001 """Convert the `dendrogram` parameter to a `dendrogram_key` parameter.""" return None if isinstance(dendrogram, bool) else dendrogram + + +def _create_white_to_color_gradient( + color: ColorLike, n_steps: int = 256 +) -> ListedColormap: + """Generate a perceptually uniform colormap from white to a target color. + + This function uses the OKLab color space for interpolation to ensure that + the brightness of the generated colormap changes uniformly. + + Parameters + ---------- + color + The target color for the gradient. Can be any valid matplotlib color. + n_steps + The number of steps in the colormap. + + Returns + ------- + A `matplotlib.colors.ListedColormap` object. + """ + popt = np.get_printoptions() + try: + import colour + except ImportError: + msg = ( + "Please install the `colour-science` package to use `group_colors`: " + "`pip install colour-science` or `pip install scanpy[plotting]`" + ) + raise ImportError(msg) from None + finally: # https://github.com/colour-science/colour/issues/1388 + np.set_printoptions(legacy=popt["legacy"]) + + from matplotlib.colors import ListedColormap, to_hex + + # Convert the input color to a hex string + hex_color = to_hex(color, keep_alpha=False) + + # Define the color space for interpolation + space = "OKLab" + + # Convert start (white) and end (target color) to the OKLab color space + target_oklab = colour.convert(hex_color, "Hexadecimal", space) + white_oklab = colour.convert("#ffffff", "Hexadecimal", space) + + # Create the gradient through linear interpolation in OKLab + gradient = colour.algebra.lerp( + np.linspace(0, 1, n_steps)[..., np.newaxis], + white_oklab, + target_oklab, + ) + + # Convert the gradient back to sRGB for display + rgb_gradient = colour.convert(gradient, space, "sRGB") + + # Clip values to be within the valid [0, 1] range for RGB + clipped_rgb = np.clip(rgb_gradient, 0, 1) + + return ListedColormap( + clipped_rgb, name=color if isinstance(color, str) else hex_color + ) diff --git a/src/testing/scanpy/_pytest/marks.py b/src/testing/scanpy/_pytest/marks.py index 3060b6272a..8b25f0457d 100644 --- a/src/testing/scanpy/_pytest/marks.py +++ b/src/testing/scanpy/_pytest/marks.py @@ -30,6 +30,7 @@ def _generate_next_value_( mod: str + colour = "colour-science" dask = auto() dask_ml = auto() fa2 = auto() diff --git a/tests/_images/clustermap/expected.png b/tests/_images/clustermap/expected.png index 2ae53a6f8d..24c66502bd 100644 Binary files a/tests/_images/clustermap/expected.png and b/tests/_images/clustermap/expected.png differ diff --git a/tests/_images/clustermap_withcolor/expected.png b/tests/_images/clustermap_withcolor/expected.png index 246b4497d4..007a15624f 100644 Binary files a/tests/_images/clustermap_withcolor/expected.png and b/tests/_images/clustermap_withcolor/expected.png differ diff --git a/tests/_images/dotplot/expected.png b/tests/_images/dotplot/expected.png index 9c4b822369..45f2f5d867 100644 Binary files a/tests/_images/dotplot/expected.png and b/tests/_images/dotplot/expected.png differ diff --git a/tests/_images/dotplot2/expected.png b/tests/_images/dotplot2/expected.png index ea85317b98..c0202c0014 100644 Binary files a/tests/_images/dotplot2/expected.png and b/tests/_images/dotplot2/expected.png differ diff --git a/tests/_images/dotplot3/expected.png b/tests/_images/dotplot3/expected.png index 93b42e24ce..712ae1f1fd 100644 Binary files a/tests/_images/dotplot3/expected.png and b/tests/_images/dotplot3/expected.png differ diff --git a/tests/_images/dotplot_dict/expected.png b/tests/_images/dotplot_dict/expected.png index d805ea94db..cf35b845da 100644 Binary files a/tests/_images/dotplot_dict/expected.png and b/tests/_images/dotplot_dict/expected.png differ diff --git a/tests/_images/dotplot_gene_symbols/expected.png b/tests/_images/dotplot_gene_symbols/expected.png index 1f5c4e0c2f..7cc2f9f78b 100644 Binary files a/tests/_images/dotplot_gene_symbols/expected.png and b/tests/_images/dotplot_gene_symbols/expected.png differ diff --git a/tests/_images/dotplot_group_colors/expected.png b/tests/_images/dotplot_group_colors/expected.png new file mode 100644 index 0000000000..f04aa122cf Binary files /dev/null and b/tests/_images/dotplot_group_colors/expected.png differ diff --git a/tests/_images/dotplot_group_colors_fallback/expected.png b/tests/_images/dotplot_group_colors_fallback/expected.png new file mode 100644 index 0000000000..25003fc60b Binary files /dev/null and b/tests/_images/dotplot_group_colors_fallback/expected.png differ diff --git a/tests/_images/dotplot_group_colors_swap_axes/expected.png b/tests/_images/dotplot_group_colors_swap_axes/expected.png new file mode 100644 index 0000000000..6d22d650e1 Binary files /dev/null and b/tests/_images/dotplot_group_colors_swap_axes/expected.png differ diff --git a/tests/_images/dotplot_groupby_index/expected.png b/tests/_images/dotplot_groupby_index/expected.png index 57f5962ee2..85150912f1 100644 Binary files a/tests/_images/dotplot_groupby_index/expected.png and b/tests/_images/dotplot_groupby_index/expected.png differ diff --git a/tests/_images/dotplot_groupby_list_catorder/expected.png b/tests/_images/dotplot_groupby_list_catorder/expected.png index fd6c3453a0..4548e0c55a 100644 Binary files a/tests/_images/dotplot_groupby_list_catorder/expected.png and b/tests/_images/dotplot_groupby_list_catorder/expected.png differ diff --git a/tests/_images/dotplot_obj/expected.png b/tests/_images/dotplot_obj/expected.png index 746c0bfee6..7cee3fada0 100644 Binary files a/tests/_images/dotplot_obj/expected.png and b/tests/_images/dotplot_obj/expected.png differ diff --git a/tests/_images/dotplot_obj_std_scale_group/expected.png b/tests/_images/dotplot_obj_std_scale_group/expected.png index c62cebf090..bd6ab394b7 100644 Binary files a/tests/_images/dotplot_obj_std_scale_group/expected.png and b/tests/_images/dotplot_obj_std_scale_group/expected.png differ diff --git a/tests/_images/dotplot_obj_std_scale_group_swap_axes/expected.png b/tests/_images/dotplot_obj_std_scale_group_swap_axes/expected.png index 66c51656a6..3e6cc18b01 100644 Binary files a/tests/_images/dotplot_obj_std_scale_group_swap_axes/expected.png and b/tests/_images/dotplot_obj_std_scale_group_swap_axes/expected.png differ diff --git a/tests/_images/dotplot_obj_std_scale_var/expected.png b/tests/_images/dotplot_obj_std_scale_var/expected.png index ad2a9130af..82b4e87c77 100644 Binary files a/tests/_images/dotplot_obj_std_scale_var/expected.png and b/tests/_images/dotplot_obj_std_scale_var/expected.png differ diff --git a/tests/_images/dotplot_obj_std_scale_var_swap_axes/expected.png b/tests/_images/dotplot_obj_std_scale_var_swap_axes/expected.png index 6e0b1a7d8a..d67890976e 100644 Binary files a/tests/_images/dotplot_obj_std_scale_var_swap_axes/expected.png and b/tests/_images/dotplot_obj_std_scale_var_swap_axes/expected.png differ diff --git a/tests/_images/dotplot_obj_swap_axes/expected.png b/tests/_images/dotplot_obj_swap_axes/expected.png index ab366c1a25..70f42e450e 100644 Binary files a/tests/_images/dotplot_obj_swap_axes/expected.png and b/tests/_images/dotplot_obj_swap_axes/expected.png differ diff --git a/tests/_images/dotplot_std_scale_group/expected.png b/tests/_images/dotplot_std_scale_group/expected.png index 72572a40db..e8bafbe644 100644 Binary files a/tests/_images/dotplot_std_scale_group/expected.png and b/tests/_images/dotplot_std_scale_group/expected.png differ diff --git a/tests/_images/dotplot_std_scale_var/expected.png b/tests/_images/dotplot_std_scale_var/expected.png new file mode 100644 index 0000000000..029327f6bd Binary files /dev/null and b/tests/_images/dotplot_std_scale_var/expected.png differ diff --git a/tests/_images/dotplot_totals/expected.png b/tests/_images/dotplot_totals/expected.png index a7d1269c94..384124146d 100644 Binary files a/tests/_images/dotplot_totals/expected.png and b/tests/_images/dotplot_totals/expected.png differ diff --git a/tests/_images/multiple_plots/expected.png b/tests/_images/multiple_plots/expected.png index f0857c4721..2b514ccc30 100644 Binary files a/tests/_images/multiple_plots/expected.png and b/tests/_images/multiple_plots/expected.png differ diff --git a/tests/_images/ranked_genes_dotplot/expected.png b/tests/_images/ranked_genes_dotplot/expected.png index d6c97610f7..bad71a0858 100644 Binary files a/tests/_images/ranked_genes_dotplot/expected.png and b/tests/_images/ranked_genes_dotplot/expected.png differ diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 1a935a560d..73f558766b 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -148,7 +148,7 @@ def test_heatmap_var_as_dict(image_comparer) -> None: @needs.leidenalg -@pytest.mark.parametrize("swap_axes", [True, False]) +@pytest.mark.parametrize("swap_axes", [True, False], ids=["swap_axes", "default"]) def test_heatmap_alignment(*, image_comparer, swap_axes: bool) -> None: """Test that plot elements are well aligned.""" save_and_compare_images = partial(image_comparer, ROOT, tol=15) @@ -369,7 +369,7 @@ def test_dotplot_matrixplot_stacked_violin( save_and_compare_images(id) -@pytest.mark.parametrize("swap_axes", [True, False]) +@pytest.mark.parametrize("swap_axes", [True, False], ids=["swap_axes", "default"]) @pytest.mark.parametrize("standard_scale", ["var", "group", None]) def test_dotplot_obj( image_comparer, standard_scale: Literal["var", "group"] | None, *, swap_axes: bool @@ -1498,7 +1498,7 @@ def pbmc_filtered() -> Callable[[], AnnData]: return pbmc.copy -@pytest.mark.parametrize("use_raw", [True, None]) +@pytest.mark.parametrize("use_raw", [True, None], ids=["use_raw", "default"]) def test_scatter_no_basis_raw(check_same_image, pbmc_filtered, tmp_path, use_raw): """Test scatterplots of raw layer with no basis.""" adata = pbmc_filtered() @@ -1901,3 +1901,135 @@ def test_violin_scale_warning(monkeypatch): def test_dogplot() -> None: """Test that the dogplot function runs without errors.""" sc.pl.dogplot() + + +def test_dotplot_group_colors_raises_error_on_missing_dep( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Check that an informative ImportError is raised when colour-science is missing.""" + import sys + + # Remove colour from sys.modules if present and block reimport + monkeypatch.setitem(sys.modules, "colour", None) + + adata = pbmc68k_reduced() + markers = ["CD79A"] + group_colors = {"CD19+ B": "blue"} + + with pytest.raises(ImportError, match="pip install colour-science"): + sc.pl.dotplot( + adata, + markers, + groupby="bulk_labels", + group_colors=group_colors, + show=False, + ) + + +@needs.colour +@pytest.mark.parametrize("swap_axes", [True, False], ids=["swap_axes", "default"]) +def test_dotplot_group_colors(*, image_comparer, swap_axes: bool) -> None: + """Check group_colors parameter with custom colors per group.""" + save_and_compare_images = partial(image_comparer, ROOT, tol=15) + + adata = pbmc68k_reduced() + + markers = ["SERPINB1", "IGFBP7", "GNLY", "IFITM1", "IMP3", "UBALD2", "LTB", "CLPP"] + + group_colors = { + "CD14+ Monocyte": "gray", + "Dendritic": "#a65628", # brown + "CD8+ Cytotoxic T": "red", + "CD8+/CD45RA+ Naive Cytotoxic": "green", + "CD4+/CD45RA+/CD25- Naive T": "orange", + "CD4+/CD25 T Reg": "blue", + "CD4+/CD45RO+ Memory": "#ff7f00", # orange + "CD19+ B": "#984ea3", # purple + "CD56+ NK": "pink", + "CD34+": "cyan", + } + + sc.pl.dotplot( + adata, + markers, + groupby="bulk_labels", + group_colors=group_colors, + dendrogram=True, + swap_axes=swap_axes, + show=False, + ) + save_and_compare_images(f"dotplot_group_colors{'_swap_axes' if swap_axes else ''}") + + +@needs.colour +def test_dotplot_group_colors_fallback(image_comparer) -> None: + """Check that fallback to default cmap works for groups not in group_colors.""" + save_and_compare_images = partial(image_comparer, ROOT, tol=15) + + adata = pbmc68k_reduced() + + markers = ["SERPINB1", "IGFBP7", "GNLY", "IFITM1"] + + # Intentionally incomplete dict to test fallback + group_colors = { + "CD14+ Monocyte": "gray", + "Dendritic": "purple", + } + + # Expect warning about missing groups since we only specify 2 of 10 groups + with pytest.warns( + UserWarning, match="will use the default colormap as no specific colors" + ): + sc.pl.dotplot( + adata, + markers, + groupby="bulk_labels", + group_colors=group_colors, + cmap="Reds", # Fallback cmap + dendrogram=True, + show=False, + ) + save_and_compare_images("dotplot_group_colors_fallback") + + +@needs.colour +def test_dotplot_group_colors_warns_on_cmap() -> None: + """Check that a warning is raised when both cmap and group_colors are passed.""" + adata = pbmc68k_reduced() + markers = ["CD79A"] + group_colors = {"CD19+ B": "blue"} + + # Expect both warnings: one for cmap+group_colors, one for missing groups + with pytest.warns(UserWarning, match="cmap|colormap") as record: + sc.pl.dotplot( + adata, + markers, + groupby="bulk_labels", + group_colors=group_colors, + cmap="viridis", + show=False, + ) + # Check that we got both expected warnings + warning_messages = [str(w.message) for w in record] + assert any("Both `cmap` and `group_colors`" in msg for msg in warning_messages) + assert any("no specific colors were assigned" in msg for msg in warning_messages) + + +@needs.colour +def test_dotplot_group_colors_warns_on_missing_groups() -> None: + """Check that a warning is raised when not all groups have colors assigned.""" + adata = pbmc68k_reduced() + markers = ["CD79A"] + # Only assign color to one group - others should trigger warning + group_colors = {"CD19+ B": "blue"} + + with pytest.warns( + UserWarning, match="will use the default colormap as no specific colors" + ): + sc.pl.dotplot( + adata, + markers, + groupby="bulk_labels", + group_colors=group_colors, + show=False, + )