diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 2dd9019f..dc13834f 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -453,7 +453,11 @@ def render_points( return sdata - @_deprecation_alias(elements="element", quantiles_for_norm="percentiles_for_norm", version="version 0.3.0") + @_deprecation_alias( + elements="element", + quantiles_for_norm="percentiles_for_norm", + version="version 0.3.0", + ) def render_images( self, element: str | None = None, @@ -464,6 +468,8 @@ def render_images( palette: list[str] | str | None = None, alpha: float | int = 1.0, scale: str | None = None, + multichannel_strategy: str = "stack", + bg_threshold: float = 1e-4, **kwargs: Any, ) -> sd.SpatialData: """ @@ -506,6 +512,12 @@ def render_images( 3) "full": Renders the full image without rasterization. In the case of multiscale images, the highest resolution scale is selected. Note that this may result in long computing times for large images. + multichannel_strategy : str, default "stack" + Method for rendering images with more than 3 channels. + "stack": Samples categorical colors and stacks the channels. + "pca": Uses PCA to reduce the number of channels to 3. + bg_threshold : float, default 1e-4 + Threshold below which values are considered background in the PCA dimred for images with 3+ channels. kwargs Additional arguments to be passed to cmap, norm, and other rendering functions. @@ -531,6 +543,8 @@ def render_images( cmap=cmap, norm=norm, scale=scale, + multichannel_strategy=multichannel_strategy, + bg_threshold=bg_threshold, ) sdata = self._copy() @@ -556,6 +570,7 @@ def render_images( na_color=param_values["na_color"], **kwargs, ) + sdata.plotting_tree[f"{n_steps + 1}_render_images"] = ImageRenderParams( element=element, channel=param_values["channel"], @@ -564,6 +579,8 @@ def render_images( alpha=param_values["alpha"], scale=param_values["scale"], zorder=n_steps, + bg_threshold=param_values["bg_threshold"], + multichannel_strategy=param_values["multichannel_strategy"], ) n_steps += 1 diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 8a1dc737..bad7bc22 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -338,7 +338,7 @@ def _render_shapes( cax = None if aggregate_with_reduction is not None: vmin = aggregate_with_reduction[0].values if norm.vmin is None else norm.vmin - vmax = aggregate_with_reduction[1].values if norm.vmin is None else norm.vmax + vmax = aggregate_with_reduction[1].values if norm.vmax is None else norm.vmax if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax: # value (vmin=vmax) is placed in the middle of the colorbar so that we can distinguish it from over and # under values in case clip=True or clip=False with cmap(under)=cmap(0) & cmap(over)=cmap(1) @@ -843,23 +843,24 @@ def _render_images( sm = plt.cm.ScalarMappable(cmap=cmap, norm=render_params.cmap_params.norm) fig_params.fig.colorbar(sm, ax=ax) - # 2) Image has any number of channels but 1 else: layers = {} - for ch_index, c in enumerate(channels): - layers[c] = img.sel(c=c).copy(deep=True).squeeze() - - if not isinstance(render_params.cmap_params, list): - if render_params.cmap_params.norm is not None: - layers[c] = render_params.cmap_params.norm(layers[c]) + for ch_idx, ch in enumerate(channels): + layers[ch] = img.sel(c=ch).copy(deep=True).squeeze() + if isinstance(render_params.cmap_params, list): + ch_norm = render_params.cmap_params[ch_idx].norm + ch_cmap_is_default = render_params.cmap_params[ch_idx].cmap_is_default else: - if render_params.cmap_params[ch_index].norm is not None: - layers[c] = render_params.cmap_params[ch_index].norm(layers[c]) + ch_norm = render_params.cmap_params.norm + ch_cmap_is_default = render_params.cmap_params.cmap_is_default + + if not ch_cmap_is_default and ch_norm is not None: + layers[ch_idx] = ch_norm(layers[ch_idx]) # 2A) Image has 3 channels, no palette info, and no/only one cmap was given if palette is None and n_channels == 3 and not isinstance(render_params.cmap_params, list): if render_params.cmap_params.cmap_is_default: # -> use RGB - stacked = np.stack([layers[c] for c in channels], axis=-1) + stacked = np.stack([layers[ch] for ch in layers], axis=-1) else: # -> use given cmap for each channel channel_cmaps = [render_params.cmap_params.cmap] * n_channels stacked = ( @@ -892,12 +893,105 @@ def _render_images( # overwrite if n_channels == 2 for intuitive result if n_channels == 2: seed_colors = ["#ff0000ff", "#00ff00ff"] - else: + channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors] + colored = np.stack( + [channel_cmaps[ch_ind](layers[ch]) for ch_ind, ch in enumerate(channels)], + 0, + ).sum(0) + colored = colored[:, :, :3] + + elif n_channels == 3: seed_colors = _get_colors_for_categorical_obs(list(range(n_channels))) + channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors] + colored = np.stack( + [channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], + 0, + ).sum(0) + colored = colored[:, :, :3] - channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors] - colored = np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0) - colored = colored[:, :, :3] + else: + if render_params.multichannel_strategy == "stack": + if isinstance(render_params.cmap_params, list): + cmap_is_default = render_params.cmap_params[0].cmap_is_default + else: + cmap_is_default = render_params.cmap_params.cmap_is_default + + if cmap_is_default: + seed_colors = _get_colors_for_categorical_obs(list(range(n_channels))) + else: + # Sample n_channels colors evenly from the colormap + if isinstance(render_params.cmap_params, list): + seed_colors = [ + render_params.cmap_params[i].cmap(i / (n_channels - 1)) for i in range(n_channels) + ] + else: + seed_colors = [ + render_params.cmap_params.cmap(i / (n_channels - 1)) for i in range(n_channels) + ] + channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors] + + # Stack (n_channels, height, width) → (height*width, n_channels) + H, W = next(iter(layers.values())).shape + comp_rgb = np.zeros((H, W, 3), dtype=float) + + # For each channel: map to RGBA, apply constant alpha, then add + for ch_idx, ch in enumerate(channels): + layer_arr = layers[ch] + rgba = channel_cmaps[ch_idx](layer_arr) + rgba[..., 3] = render_params.alpha + comp_rgb += rgba[..., :3] * rgba[..., 3][..., None] + + colored = np.clip(comp_rgb, 0, 1) + logger.info( + f"Your image has {n_channels} channels. Sampling categorical colors and using " + f"multichannel strategy '{render_params.multichannel_strategy}' to render." + ) + + elif render_params.multichannel_strategy == "pca": + from sklearn.decomposition import PCA + + # Stack (n_channels, height, width) → (height*width, n_channels) + H, W = next(iter(layers.values())).shape + pixel_matrix = np.stack( + [ + (layers[ch].data.ravel() if hasattr(layers[ch], "data") else layers[ch].ravel()) + for ch in channels + ], + axis=1, + ) + + # Calculate pixel sums and create mask for background + pixel_sums = np.sum(pixel_matrix, axis=1) + mask = pixel_sums > render_params.bg_threshold + + # Only use non-background pixels for PCA + pca_rgb = np.zeros((H * W, 3)) + if np.any(mask): + # Apply PCA only to non-background pixels + pca_result = PCA(n_components=3).fit_transform(pixel_matrix[mask]) + + # Take absolute values to ensure positive values + pca_result = np.abs(pca_result) + + # Normalize each channel independently to [0,1] + for i in range(3): + channel_min = pca_result[:, i].min() + channel_max = pca_result[:, i].max() + if channel_max > channel_min: + pca_result[:, i] = (pca_result[:, i] - channel_min) / (channel_max - channel_min) + + pca_rgb[mask] = pca_result + # Ensure background pixels stay at zero + pca_rgb[~mask] = 0 + else: + logger.warning("All pixels are below background threshold.") + + colored = pca_rgb.reshape(H, W, 3) + + logger.info( + f"Your image has {n_channels} channels. Using multichannel strategy " + f"'{render_params.multichannel_strategy}' to project to RGB." + ) _ax_show_and_transform( colored, @@ -908,13 +1002,22 @@ def _render_images( ) # 2C) Image has n channels and palette info + # immitating Napari's multi-channel additive blending (SRC_ALPHA, ONE): elif palette is not None and not got_multiple_cmaps: - if len(palette) != n_channels: - raise ValueError("If 'palette' is provided, its length must match the number of channels.") + channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in palette] - channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in palette if isinstance(c, str)] - colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0) - colored = colored[:, :, :3] + sample = next(iter(layers.values())) + H, W = sample.shape + comp_rgb = np.zeros((H, W, 3), dtype=float) + + # for each channel: map to RGBA, apply constant alpha, then add + for idx, ch in enumerate(channels): + layer_arr = layers[ch] + rgba = channel_cmaps[idx](layer_arr) + rgba[..., 3] = render_params.alpha + comp_rgb += rgba[..., :3] * rgba[..., 3][..., None] + + colored = np.clip(comp_rgb, 0, 1) _ax_show_and_transform( colored, @@ -924,6 +1027,7 @@ def _render_images( zorder=render_params.zorder, ) + # 2D) Image has n channels, no palette but cmap info elif palette is None and got_multiple_cmaps: channel_cmaps = [cp.cmap for cp in render_params.cmap_params] # type: ignore[union-attr] colored = ( diff --git a/src/spatialdata_plot/pl/render_params.py b/src/spatialdata_plot/pl/render_params.py index b44175c3..24bc967a 100644 --- a/src/spatialdata_plot/pl/render_params.py +++ b/src/spatialdata_plot/pl/render_params.py @@ -125,6 +125,8 @@ class ImageRenderParams: percentiles_for_norm: tuple[float | None, float | None] = (None, None) scale: str | None = None zorder: int = 0 + multichannel_strategy: str = "stack" + bg_threshold: float = 1e-4 @dataclass diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index a2e8f767..9485445f 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -1777,6 +1777,16 @@ def _ensure_table_and_layer_exist_in_sdata( if method == "datashader" and ds_reduction is None: param_dict["ds_reduction"] = "sum" + if element_type == "images": + if param_dict.get("multichannel_strategy") not in ["pca", "stack"]: + raise ValueError("Parameter 'multichannel_strategy' must be one of the following: 'pca', 'stack'.") + + if param_dict.get("bg_threshold") is not None: + if not isinstance(param_dict["bg_threshold"], float | int): + raise TypeError("Parameter 'bg_threshold' must be a number.") + if param_dict["bg_threshold"] < 0: + raise ValueError("Parameter 'bg_threshold' must be a positive number.") + return param_dict @@ -2006,7 +2016,7 @@ def _validate_col_for_column_table( table_name = next(iter(tables)) if len(tables) > 1: warnings.warn( - f"Multiple tables contain color column, using {table_name}", + f"Multiple tables contain column '{col_for_color}', using table '{table_name}'.", UserWarning, stacklevel=2, ) @@ -2023,6 +2033,8 @@ def _validate_image_render_params( cmap: list[Colormap | str] | Colormap | str | None, norm: Normalize | None, scale: str | None, + multichannel_strategy: str = "pca", + bg_threshold: float = 1e-4, ) -> dict[str, dict[str, Any]]: param_dict: dict[str, Any] = { "sdata": sdata, @@ -2034,34 +2046,63 @@ def _validate_image_render_params( "cmap": cmap, "norm": norm, "scale": scale, + "multichannel_strategy": multichannel_strategy, + "bg_threshold": bg_threshold, } param_dict = _type_check_params(param_dict, "images") - element_params: dict[str, dict[str, Any]] = {} for el in param_dict["element"]: element_params[el] = {} spatial_element = param_dict["sdata"][el] + # robustly get channel names from image or multiscale image spatial_element_ch = ( - spatial_element.c if isinstance(spatial_element, DataArray) else spatial_element["scale0"].c + spatial_element.c.values if isinstance(spatial_element, DataArray) else spatial_element["scale0"].c.values ) - if (channel := param_dict["channel"]) is not None and ( - (isinstance(channel[0], int) and max([abs(ch) for ch in channel]) <= len(spatial_element_ch)) - or all(ch in spatial_element_ch for ch in channel) - ): + + channel = param_dict["channel"] + if channel is not None: + # Normalize channel to always be a list of str or a list of int + if isinstance(channel, str): + channel = [channel] + + if isinstance(channel, int): + channel = [channel] + + # If channel is a list, ensure all elements are the same type + if not (isinstance(channel, list) and channel and all(isinstance(c, type(channel[0])) for c in channel)): + raise TypeError("Each item in 'channel' list must be of the same type, either string or integer.") + + invalid = [c for c in channel if c not in spatial_element_ch] + if invalid: + raise ValueError( + f"Invalid channel(s): {', '.join(str(c) for c in invalid)}. Valid choices are: {spatial_element_ch}" + ) element_params[el]["channel"] = channel else: element_params[el]["channel"] = None element_params[el]["alpha"] = param_dict["alpha"] - if isinstance(palette := param_dict["palette"], list): + palette = param_dict["palette"] + assert isinstance(palette, list | type(None)) # if present, was converted to list, just to make sure + + if isinstance(palette, list): + # case A: single palette for all channels if len(palette) == 1: palette_length = len(channel) if channel is not None else len(spatial_element_ch) palette = palette * palette_length - if (channel is not None and len(palette) != len(channel)) and len(palette) != len(spatial_element_ch): - palette = None + + # case B: one palette per channel (either given or derived from channel length) + channels_to_use = spatial_element_ch if element_params[el]["channel"] is None else channel + if channels_to_use is not None and len(palette) != len(channels_to_use): + raise ValueError( + f"Palette length ({len(palette)}) does not match channel length " + f"({', '.join(str(c) for c in channels_to_use)})." + ) + element_params[el]["palette"] = palette + element_params[el]["na_color"] = param_dict["na_color"] if (cmap := param_dict["cmap"]) is not None: @@ -2080,6 +2121,9 @@ def _validate_image_render_params( else: element_params[el]["scale"] = scale + element_params[el]["multichannel_strategy"] = param_dict["multichannel_strategy"] + element_params[el]["bg_threshold"] = param_dict["bg_threshold"] + return element_params diff --git a/tests/_images/Extent_extent_of_img_is_correct_after_spatial_query.png b/tests/_images/Extent_extent_of_img_is_correct_after_spatial_query.png index c22b9f2b..16bedd33 100644 Binary files a/tests/_images/Extent_extent_of_img_is_correct_after_spatial_query.png and b/tests/_images/Extent_extent_of_img_is_correct_after_spatial_query.png differ diff --git a/tests/_images/Images_can_handle_actual_number_of_channels.png b/tests/_images/Images_can_handle_actual_number_of_channels.png new file mode 100644 index 00000000..72485169 Binary files /dev/null and b/tests/_images/Images_can_handle_actual_number_of_channels.png differ diff --git a/tests/_images/Images_can_handle_mixed_channel_order.png b/tests/_images/Images_can_handle_mixed_channel_order.png new file mode 100644 index 00000000..ed4dc136 Binary files /dev/null and b/tests/_images/Images_can_handle_mixed_channel_order.png differ diff --git a/tests/_images/Images_can_handle_multiple_channels_pca_strategy.png b/tests/_images/Images_can_handle_multiple_channels_pca_strategy.png new file mode 100644 index 00000000..12da8a07 Binary files /dev/null and b/tests/_images/Images_can_handle_multiple_channels_pca_strategy.png differ diff --git a/tests/_images/Images_can_handle_multiple_channels_stack_strategy.png b/tests/_images/Images_can_handle_multiple_channels_stack_strategy.png new file mode 100644 index 00000000..d13e3da9 Binary files /dev/null and b/tests/_images/Images_can_handle_multiple_channels_stack_strategy.png differ diff --git a/tests/_images/Images_can_handle_multiple_cmaps.png b/tests/_images/Images_can_handle_multiple_cmaps.png new file mode 100644 index 00000000..d3a41891 Binary files /dev/null and b/tests/_images/Images_can_handle_multiple_cmaps.png differ diff --git a/tests/_images/Images_can_handle_one_channel.png b/tests/_images/Images_can_handle_one_channel.png new file mode 100644 index 00000000..4c3d7699 Binary files /dev/null and b/tests/_images/Images_can_handle_one_channel.png differ diff --git a/tests/_images/Images_can_handle_one_palette_per_img_channel.png b/tests/_images/Images_can_handle_one_palette_per_img_channel.png new file mode 100644 index 00000000..ed4dc136 Binary files /dev/null and b/tests/_images/Images_can_handle_one_palette_per_img_channel.png differ diff --git a/tests/_images/Images_can_handle_one_palette_per_user_channel.png b/tests/_images/Images_can_handle_one_palette_per_user_channel.png new file mode 100644 index 00000000..ed4dc136 Binary files /dev/null and b/tests/_images/Images_can_handle_one_palette_per_user_channel.png differ diff --git a/tests/_images/Images_can_handle_scrambled_channels.png b/tests/_images/Images_can_handle_scrambled_channels.png new file mode 100644 index 00000000..85e3f242 Binary files /dev/null and b/tests/_images/Images_can_handle_scrambled_channels.png differ diff --git a/tests/_images/Images_can_handle_single_channel_default_color.png b/tests/_images/Images_can_handle_single_channel_default_color.png new file mode 100644 index 00000000..4c3d7699 Binary files /dev/null and b/tests/_images/Images_can_handle_single_channel_default_color.png differ diff --git a/tests/_images/Images_can_handle_single_channel_with_cmap.png b/tests/_images/Images_can_handle_single_channel_with_cmap.png new file mode 100644 index 00000000..0e1fecf8 Binary files /dev/null and b/tests/_images/Images_can_handle_single_channel_with_cmap.png differ diff --git a/tests/_images/Images_can_handle_subset_of_channels.png b/tests/_images/Images_can_handle_subset_of_channels.png new file mode 100644 index 00000000..9abf2f73 Binary files /dev/null and b/tests/_images/Images_can_handle_subset_of_channels.png differ diff --git a/tests/_images/Images_can_handle_three_channels_single_cmap.png b/tests/_images/Images_can_handle_three_channels_single_cmap.png new file mode 100644 index 00000000..655fc446 Binary files /dev/null and b/tests/_images/Images_can_handle_three_channels_single_cmap.png differ diff --git a/tests/_images/Images_can_render_a_single_channel_from_image.png b/tests/_images/Images_can_render_a_single_channel_from_image.png index 349f9218..4c3d7699 100644 Binary files a/tests/_images/Images_can_render_a_single_channel_from_image.png and b/tests/_images/Images_can_render_a_single_channel_from_image.png differ diff --git a/tests/_images/Images_can_render_a_single_channel_str_from_image.png b/tests/_images/Images_can_render_a_single_channel_str_from_image.png index 349f9218..4c3d7699 100644 Binary files a/tests/_images/Images_can_render_a_single_channel_str_from_image.png and b/tests/_images/Images_can_render_a_single_channel_str_from_image.png differ diff --git a/tests/_images/Images_can_render_a_single_channel_str_from_multiscale_image.png b/tests/_images/Images_can_render_a_single_channel_str_from_multiscale_image.png index 349f9218..4c3d7699 100644 Binary files a/tests/_images/Images_can_render_a_single_channel_str_from_multiscale_image.png and b/tests/_images/Images_can_render_a_single_channel_str_from_multiscale_image.png differ diff --git a/tests/conftest.py b/tests/conftest.py index 1c085297..09bce08e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -104,6 +104,31 @@ def test_sdata_multiple_images(): return sdata +def make_multichannel_blobs(n=512, nch=6, sigma=0.1, radius=0.6, random_state=0): + x = np.linspace(-1, 1, n) + X, Y = np.meshgrid(x, x) + + angles = np.linspace(0, 2 * np.pi, nch, endpoint=False) + centers = [(radius * np.cos(a), radius * np.sin(a)) for a in angles] + + chans = [] + for cx, cy in centers: + g = np.exp(-((X - cx) ** 2 + (Y - cy) ** 2) / (2 * sigma**2)) + chans.append(g) + return np.stack(chans, axis=-1).transpose(2, 1, 0) + + +@pytest.fixture +def sdata_multichannel() -> SpatialData: + """Creates a SpatialData object with 5 channels arranged in a circle. + + Each channel is a Gaussian blob positioned at evenly spaced angles around a circle. + The blobs have a radius of 0.4 and sigma of 0.2. + """ + data = make_multichannel_blobs(n=256, nch=5, sigma=0.2, radius=0.4) + return sd.SpatialData(images={"multichannel_image": Image2DModel.parse(data)}) + + @pytest.fixture def test_sdata_multiple_images_with_table(): """Creates an sdata object with multiple images.""" diff --git a/tests/pl/test_render_images.py b/tests/pl/test_render_images.py index 5484ac4e..3e1c7d38 100644 --- a/tests/pl/test_render_images.py +++ b/tests/pl/test_render_images.py @@ -1,6 +1,7 @@ import dask.array as da import matplotlib import numpy as np +import pytest import scanpy as sc from matplotlib.colors import Normalize from spatial_image import to_spatial_image @@ -42,9 +43,6 @@ def test_plot_can_pass_cmap_list(self, sdata_blobs: SpatialData): cmap=[matplotlib.colormaps["seismic"], matplotlib.colormaps["Reds"], matplotlib.colormaps["Blues"]], ).pl.show() - def test_plot_can_render_a_single_channel_from_image(self, sdata_blobs: SpatialData): - sdata_blobs.pl.render_images(element="blobs_image", channel=0).pl.show() - def test_plot_can_render_a_single_channel_from_multiscale_image(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_images(element="blobs_multiscale_image", channel=0).pl.show() @@ -84,11 +82,6 @@ def test_plot_can_pass_color_to_single_channel(self, sdata_blobs: SpatialData): def test_plot_can_pass_cmap_to_single_channel(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_images(element="blobs_image", channel=1, cmap="Reds").pl.show() - def test_plot_can_pass_color_to_each_channel(self, sdata_blobs: SpatialData): - sdata_blobs.pl.render_images( - element="blobs_image", channel=[0, 1, 2], palette=["red", "green", "blue"] - ).pl.show() - def test_plot_can_pass_cmap_to_each_channel(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_images( element="blobs_image", channel=[0, 1, 2], cmap=["Reds", "Greens", "Blues"] @@ -132,3 +125,93 @@ def test_plot_can_stick_to_zorder(self, sdata_blobs: SpatialData): def test_plot_can_render_multiscale_image_with_custom_cmap(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_images("blobs_multiscale_image", channel=0, scale="scale2", cmap="Greys").pl.show() + + def test_plot_can_handle_one_palette_per_img_channel(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_images(element="blobs_image", palette=["red", "green", "blue"]).pl.show() + + def test_plot_can_handle_one_palette_per_user_channel(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_images( + element="blobs_image", channel=[0, 1, 2], palette=["red", "green", "blue"] + ).pl.show() + + def test_plot_can_handle_mixed_channel_order(self, sdata_blobs: SpatialData): + """Test that channels can be specified in any order and are correctly matched with their palette colors""" + sdata_blobs.pl.render_images( + element="blobs_image", channel=[2, 0, 1], palette=["blue", "red", "green"] + ).pl.show() + + def test_plot_can_handle_single_channel_default_color(self, sdata_blobs: SpatialData): + """Test that a single channel without palette uses default color mapping""" + sdata_blobs.pl.render_images(element="blobs_image", channel=0).pl.show() + + def test_plot_can_handle_single_channel_with_cmap(self, sdata_blobs: SpatialData): + """Test that a single channel can use a cmap instead of a palette color""" + sdata_blobs.pl.render_images(element="blobs_image", channel=0, cmap="Reds").pl.show() + + def test_plot_can_handle_one_channel(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_images(element="blobs_image", channel=[0]).pl.show() + + def test_plot_can_handle_subset_of_channels(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_images(element="blobs_image", channel=[0, 2]).pl.show() + + def test_plot_can_handle_actual_number_of_channels(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_images(element="blobs_image", channel=[0, 1, 2]).pl.show() + + def test_plot_can_handle_scrambled_channels(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_images(element="blobs_image", channel=[0, 2, 1]).pl.show() + + def test_plot_can_handle_three_channels_single_cmap(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_images(element="blobs_image", channel=[0, 1, 2], cmap="viridis").pl.show() + + def test_plot_can_handle_multiple_channels_stack_strategy(self, sdata_multichannel: SpatialData): + sdata_multichannel.pl.render_images(element="multichannel_image", multichannel_strategy="stack").pl.show() + + def test_plot_can_handle_multiple_channels_pca_strategy(self, sdata_multichannel: SpatialData): + sdata_multichannel.pl.render_images(element="multichannel_image", multichannel_strategy="pca").pl.show() + + def test_plot_can_handle_multiple_cmaps(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_images( + element="blobs_image", channel=[0, 1, 2], cmap=["viridis", "Reds", "Blues"] + ).pl.show() + + +def test_fails_with_palette_and_multiple_cmaps(sdata_blobs: SpatialData): + with pytest.raises(ValueError, match="Both `palette` and `cmap` are specified. Please specify only one of them."): + sdata_blobs.pl.render_images( + element="blobs_image", + channel=[0, 1, 2], + palette=["red", "green", "blue"], + cmap=["viridis", "Reds", "Blues"], + ).pl.show() + + +def test_fail_when_len_palette_is_not_equal_to_len_img_channels(sdata_blobs: SpatialData): + with pytest.raises(ValueError, match="Palette length"): + sdata_blobs.pl.render_images(element="blobs_image", palette=["red", "green"]).pl.show() + + +def test_fail_when_len_palette_is_not_equal_to_len_user_channels(sdata_blobs: SpatialData): + with pytest.raises(ValueError, match="Palette length"): + sdata_blobs.pl.render_images(element="blobs_image", channel=[0, 1, 2], palette=["red", "green"]).pl.show() + + +def test_fail_when_len_cmap_not_equal_len_img_channels(sdata_blobs): + with pytest.raises(ValueError, match="If 'cmap' is provided, its length must match the number of channels."): + sdata_blobs.pl.render_images(element="blobs_image", cmap=["Reds", "Blues"]).pl.show() + + +def test_fail_when_len_cmap_not_equal_len_user_channels(sdata_blobs): + with pytest.raises(ValueError, match="If 'cmap' is provided, its length must match the number of channels."): + sdata_blobs.pl.render_images(element="blobs_image", channel=[0, 1, 2], cmap=["viridis", "Reds"]).pl.show() + + +def test_fail_invalid_multichannel_strategy(sdata_multichannel): + with pytest.raises( + ValueError, match="Parameter 'multichannel_strategy' must be one of the following: 'pca', 'stack'." + ): + sdata_multichannel.pl.render_images(element="multichannel_image", multichannel_strategy="foo").pl.show() + + +def test_fail_channel_index_out_of_range(sdata_blobs): + with pytest.raises(ValueError, match="Invalid channel"): + sdata_blobs.pl.render_images(element="blobs_image", channel=10).pl.show()