Skip to content

Fix multi-channel handling #451

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 37 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
c1ac439
mvp
timtreis Apr 22, 2025
6d0b305
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 22, 2025
8ee3565
Merge branch 'main' into bugfix/issue450-feature-pca-aggregation-for-…
timtreis Apr 22, 2025
64b360d
added bg-senstive PCA
timtreis Apr 22, 2025
07d07db
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 22, 2025
0b7fec2
adjusted img render params
timtreis Apr 22, 2025
693de0f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 22, 2025
11aa663
fixed bg being excluded from PCA
timtreis Apr 22, 2025
62c8c1b
merge
timtreis Apr 22, 2025
9b733bb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 22, 2025
dc53d5e
covering more test cases2
timtreis May 11, 2025
c07baf5
fixed typecheck
timtreis May 11, 2025
a5fd498
updated tests
timtreis May 11, 2025
59f9d99
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 11, 2025
9383c62
updated tests
timtreis May 11, 2025
d5e412f
merge
timtreis May 11, 2025
65e03d7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 11, 2025
1d8b9e6
more test fixed
timtreis May 11, 2025
022640c
merge
timtreis May 11, 2025
8544579
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 11, 2025
2db28dc
more fixes
timtreis May 11, 2025
6775124
fixed bug
timtreis May 11, 2025
7806448
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 11, 2025
f73c8e2
bugfix
timtreis May 11, 2025
3e73648
merge
timtreis May 11, 2025
7712eb2
bugfix
timtreis May 11, 2025
4af92f5
changed default from PCA to stack
timtreis May 11, 2025
ccc4954
bugfix
timtreis May 11, 2025
5a8a711
fade into transparent now
timtreis May 11, 2025
de09ddc
test
timtreis May 11, 2025
0d14846
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 11, 2025
8de5bbe
revert
timtreis May 11, 2025
e81d776
merge
timtreis May 11, 2025
9a0e241
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 11, 2025
99bc6f5
revert
timtreis May 11, 2025
f67ecd4
merge
timtreis May 11, 2025
2e991c5
fixed img
timtreis May 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,11 @@ def render_points(

return sdata

@_deprecation_alias(elements="element", quantiles_for_norm="percentiles_for_norm", version="version 0.3.0")
@_deprecation_alias(
elements="element",
quantiles_for_norm="percentiles_for_norm",
version="version 0.3.0",
)
def render_images(
self,
element: str | None = None,
Expand All @@ -464,6 +468,8 @@ def render_images(
palette: list[str] | str | None = None,
alpha: float | int = 1.0,
scale: str | None = None,
multichannel_strategy: str = "stack",
bg_threshold: float = 1e-4,
**kwargs: Any,
) -> sd.SpatialData:
"""
Expand Down Expand Up @@ -506,6 +512,12 @@ def render_images(
3) "full": Renders the full image without rasterization. In the case of
multiscale images, the highest resolution scale is selected. Note that
this may result in long computing times for large images.
multichannel_strategy : str, default "stack"
Method for rendering images with more than 3 channels.
"stack": Samples categorical colors and stacks the channels.
"pca": Uses PCA to reduce the number of channels to 3.
bg_threshold : float, default 1e-4
Threshold below which values are considered background in the PCA dimred for images with 3+ channels.
kwargs
Additional arguments to be passed to cmap, norm, and other rendering functions.

Expand All @@ -531,6 +543,8 @@ def render_images(
cmap=cmap,
norm=norm,
scale=scale,
multichannel_strategy=multichannel_strategy,
bg_threshold=bg_threshold,
)

sdata = self._copy()
Expand All @@ -556,6 +570,7 @@ def render_images(
na_color=param_values["na_color"],
**kwargs,
)

sdata.plotting_tree[f"{n_steps + 1}_render_images"] = ImageRenderParams(
element=element,
channel=param_values["channel"],
Expand All @@ -564,6 +579,8 @@ def render_images(
alpha=param_values["alpha"],
scale=param_values["scale"],
zorder=n_steps,
bg_threshold=param_values["bg_threshold"],
multichannel_strategy=param_values["multichannel_strategy"],
)
n_steps += 1

Expand Down
144 changes: 124 additions & 20 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def _render_shapes(
cax = None
if aggregate_with_reduction is not None:
vmin = aggregate_with_reduction[0].values if norm.vmin is None else norm.vmin
vmax = aggregate_with_reduction[1].values if norm.vmin is None else norm.vmax
vmax = aggregate_with_reduction[1].values if norm.vmax is None else norm.vmax
if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax:
# value (vmin=vmax) is placed in the middle of the colorbar so that we can distinguish it from over and
# under values in case clip=True or clip=False with cmap(under)=cmap(0) & cmap(over)=cmap(1)
Expand Down Expand Up @@ -843,23 +843,24 @@ def _render_images(
sm = plt.cm.ScalarMappable(cmap=cmap, norm=render_params.cmap_params.norm)
fig_params.fig.colorbar(sm, ax=ax)

# 2) Image has any number of channels but 1
else:
layers = {}
for ch_index, c in enumerate(channels):
layers[c] = img.sel(c=c).copy(deep=True).squeeze()

if not isinstance(render_params.cmap_params, list):
if render_params.cmap_params.norm is not None:
layers[c] = render_params.cmap_params.norm(layers[c])
for ch_idx, ch in enumerate(channels):
layers[ch] = img.sel(c=ch).copy(deep=True).squeeze()
if isinstance(render_params.cmap_params, list):
ch_norm = render_params.cmap_params[ch_idx].norm
ch_cmap_is_default = render_params.cmap_params[ch_idx].cmap_is_default
else:
if render_params.cmap_params[ch_index].norm is not None:
layers[c] = render_params.cmap_params[ch_index].norm(layers[c])
ch_norm = render_params.cmap_params.norm
ch_cmap_is_default = render_params.cmap_params.cmap_is_default

if not ch_cmap_is_default and ch_norm is not None:
layers[ch_idx] = ch_norm(layers[ch_idx])

# 2A) Image has 3 channels, no palette info, and no/only one cmap was given
if palette is None and n_channels == 3 and not isinstance(render_params.cmap_params, list):
if render_params.cmap_params.cmap_is_default: # -> use RGB
stacked = np.stack([layers[c] for c in channels], axis=-1)
stacked = np.stack([layers[ch] for ch in layers], axis=-1)
else: # -> use given cmap for each channel
channel_cmaps = [render_params.cmap_params.cmap] * n_channels
stacked = (
Expand Down Expand Up @@ -892,12 +893,105 @@ def _render_images(
# overwrite if n_channels == 2 for intuitive result
if n_channels == 2:
seed_colors = ["#ff0000ff", "#00ff00ff"]
else:
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
colored = np.stack(
[channel_cmaps[ch_ind](layers[ch]) for ch_ind, ch in enumerate(channels)],
0,
).sum(0)
colored = colored[:, :, :3]

elif n_channels == 3:
seed_colors = _get_colors_for_categorical_obs(list(range(n_channels)))
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
colored = np.stack(
[channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)],
0,
).sum(0)
colored = colored[:, :, :3]

channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
colored = np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0)
colored = colored[:, :, :3]
else:
if render_params.multichannel_strategy == "stack":
if isinstance(render_params.cmap_params, list):
cmap_is_default = render_params.cmap_params[0].cmap_is_default
else:
cmap_is_default = render_params.cmap_params.cmap_is_default

if cmap_is_default:
seed_colors = _get_colors_for_categorical_obs(list(range(n_channels)))
else:
# Sample n_channels colors evenly from the colormap
if isinstance(render_params.cmap_params, list):
seed_colors = [
render_params.cmap_params[i].cmap(i / (n_channels - 1)) for i in range(n_channels)
]
else:
seed_colors = [
render_params.cmap_params.cmap(i / (n_channels - 1)) for i in range(n_channels)
]
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]

# Stack (n_channels, height, width) → (height*width, n_channels)
H, W = next(iter(layers.values())).shape
comp_rgb = np.zeros((H, W, 3), dtype=float)

# For each channel: map to RGBA, apply constant alpha, then add
for ch_idx, ch in enumerate(channels):
layer_arr = layers[ch]
rgba = channel_cmaps[ch_idx](layer_arr)
rgba[..., 3] = render_params.alpha
comp_rgb += rgba[..., :3] * rgba[..., 3][..., None]

colored = np.clip(comp_rgb, 0, 1)
logger.info(
f"Your image has {n_channels} channels. Sampling categorical colors and using "
f"multichannel strategy '{render_params.multichannel_strategy}' to render."
)

elif render_params.multichannel_strategy == "pca":
from sklearn.decomposition import PCA

# Stack (n_channels, height, width) → (height*width, n_channels)
H, W = next(iter(layers.values())).shape
pixel_matrix = np.stack(
[
(layers[ch].data.ravel() if hasattr(layers[ch], "data") else layers[ch].ravel())
for ch in channels
],
axis=1,
)

# Calculate pixel sums and create mask for background
pixel_sums = np.sum(pixel_matrix, axis=1)
mask = pixel_sums > render_params.bg_threshold

# Only use non-background pixels for PCA
pca_rgb = np.zeros((H * W, 3))
if np.any(mask):
# Apply PCA only to non-background pixels
pca_result = PCA(n_components=3).fit_transform(pixel_matrix[mask])

# Take absolute values to ensure positive values
pca_result = np.abs(pca_result)

# Normalize each channel independently to [0,1]
for i in range(3):
channel_min = pca_result[:, i].min()
channel_max = pca_result[:, i].max()
if channel_max > channel_min:
pca_result[:, i] = (pca_result[:, i] - channel_min) / (channel_max - channel_min)

pca_rgb[mask] = pca_result
# Ensure background pixels stay at zero
pca_rgb[~mask] = 0
else:
logger.warning("All pixels are below background threshold.")

colored = pca_rgb.reshape(H, W, 3)

logger.info(
f"Your image has {n_channels} channels. Using multichannel strategy "
f"'{render_params.multichannel_strategy}' to project to RGB."
)

_ax_show_and_transform(
colored,
Expand All @@ -908,13 +1002,22 @@ def _render_images(
)

# 2C) Image has n channels and palette info
# immitating Napari's multi-channel additive blending (SRC_ALPHA, ONE):
elif palette is not None and not got_multiple_cmaps:
if len(palette) != n_channels:
raise ValueError("If 'palette' is provided, its length must match the number of channels.")
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in palette]

channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in palette if isinstance(c, str)]
colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0)
colored = colored[:, :, :3]
sample = next(iter(layers.values()))
H, W = sample.shape
comp_rgb = np.zeros((H, W, 3), dtype=float)

# for each channel: map to RGBA, apply constant alpha, then add
for idx, ch in enumerate(channels):
layer_arr = layers[ch]
rgba = channel_cmaps[idx](layer_arr)
rgba[..., 3] = render_params.alpha
comp_rgb += rgba[..., :3] * rgba[..., 3][..., None]

colored = np.clip(comp_rgb, 0, 1)

_ax_show_and_transform(
colored,
Expand All @@ -924,6 +1027,7 @@ def _render_images(
zorder=render_params.zorder,
)

# 2D) Image has n channels, no palette but cmap info
elif palette is None and got_multiple_cmaps:
channel_cmaps = [cp.cmap for cp in render_params.cmap_params] # type: ignore[union-attr]
colored = (
Expand Down
2 changes: 2 additions & 0 deletions src/spatialdata_plot/pl/render_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ class ImageRenderParams:
percentiles_for_norm: tuple[float | None, float | None] = (None, None)
scale: str | None = None
zorder: int = 0
multichannel_strategy: str = "stack"
bg_threshold: float = 1e-4


@dataclass
Expand Down
64 changes: 54 additions & 10 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1777,6 +1777,16 @@ def _ensure_table_and_layer_exist_in_sdata(
if method == "datashader" and ds_reduction is None:
param_dict["ds_reduction"] = "sum"

if element_type == "images":
if param_dict.get("multichannel_strategy") not in ["pca", "stack"]:
raise ValueError("Parameter 'multichannel_strategy' must be one of the following: 'pca', 'stack'.")

if param_dict.get("bg_threshold") is not None:
if not isinstance(param_dict["bg_threshold"], float | int):
raise TypeError("Parameter 'bg_threshold' must be a number.")
if param_dict["bg_threshold"] < 0:
raise ValueError("Parameter 'bg_threshold' must be a positive number.")

return param_dict


Expand Down Expand Up @@ -2006,7 +2016,7 @@ def _validate_col_for_column_table(
table_name = next(iter(tables))
if len(tables) > 1:
warnings.warn(
f"Multiple tables contain color column, using {table_name}",
f"Multiple tables contain column '{col_for_color}', using table '{table_name}'.",
UserWarning,
stacklevel=2,
)
Expand All @@ -2023,6 +2033,8 @@ def _validate_image_render_params(
cmap: list[Colormap | str] | Colormap | str | None,
norm: Normalize | None,
scale: str | None,
multichannel_strategy: str = "pca",
bg_threshold: float = 1e-4,
) -> dict[str, dict[str, Any]]:
param_dict: dict[str, Any] = {
"sdata": sdata,
Expand All @@ -2034,34 +2046,63 @@ def _validate_image_render_params(
"cmap": cmap,
"norm": norm,
"scale": scale,
"multichannel_strategy": multichannel_strategy,
"bg_threshold": bg_threshold,
}
param_dict = _type_check_params(param_dict, "images")

element_params: dict[str, dict[str, Any]] = {}
for el in param_dict["element"]:
element_params[el] = {}
spatial_element = param_dict["sdata"][el]

# robustly get channel names from image or multiscale image
spatial_element_ch = (
spatial_element.c if isinstance(spatial_element, DataArray) else spatial_element["scale0"].c
spatial_element.c.values if isinstance(spatial_element, DataArray) else spatial_element["scale0"].c.values
)
if (channel := param_dict["channel"]) is not None and (
(isinstance(channel[0], int) and max([abs(ch) for ch in channel]) <= len(spatial_element_ch))
or all(ch in spatial_element_ch for ch in channel)
):

channel = param_dict["channel"]
if channel is not None:
# Normalize channel to always be a list of str or a list of int
if isinstance(channel, str):
channel = [channel]

if isinstance(channel, int):
channel = [channel]

# If channel is a list, ensure all elements are the same type
if not (isinstance(channel, list) and channel and all(isinstance(c, type(channel[0])) for c in channel)):
raise TypeError("Each item in 'channel' list must be of the same type, either string or integer.")

invalid = [c for c in channel if c not in spatial_element_ch]
if invalid:
raise ValueError(
f"Invalid channel(s): {', '.join(str(c) for c in invalid)}. Valid choices are: {spatial_element_ch}"
)
element_params[el]["channel"] = channel
else:
element_params[el]["channel"] = None

element_params[el]["alpha"] = param_dict["alpha"]

if isinstance(palette := param_dict["palette"], list):
palette = param_dict["palette"]
assert isinstance(palette, list | type(None)) # if present, was converted to list, just to make sure

if isinstance(palette, list):
# case A: single palette for all channels
if len(palette) == 1:
palette_length = len(channel) if channel is not None else len(spatial_element_ch)
palette = palette * palette_length
if (channel is not None and len(palette) != len(channel)) and len(palette) != len(spatial_element_ch):
palette = None

# case B: one palette per channel (either given or derived from channel length)
channels_to_use = spatial_element_ch if element_params[el]["channel"] is None else channel
if channels_to_use is not None and len(palette) != len(channels_to_use):
raise ValueError(
f"Palette length ({len(palette)}) does not match channel length "
f"({', '.join(str(c) for c in channels_to_use)})."
)

element_params[el]["palette"] = palette

element_params[el]["na_color"] = param_dict["na_color"]

if (cmap := param_dict["cmap"]) is not None:
Expand All @@ -2080,6 +2121,9 @@ def _validate_image_render_params(
else:
element_params[el]["scale"] = scale

element_params[el]["multichannel_strategy"] = param_dict["multichannel_strategy"]
element_params[el]["bg_threshold"] = param_dict["bg_threshold"]

return element_params


Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/_images/Images_can_handle_one_channel.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Images_can_render_a_single_channel_from_image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Loading