Skip to content

Commit c6d6153

Browse files
timtreispre-commit-ci[bot]melonora
authored
Refactor of colorbar and norm logic (#346)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Wouter-Michiel Vierdag <[email protected]>
1 parent 6cef5df commit c6d6153

20 files changed

+33
-76
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,15 @@ and this project adheres to [Semantic Versioning][].
1818

1919
- Lowered RMSE-threshold for plot-based tests from 45 to 15 (#344)
2020
- When subsetting to `groups`, `NA` isn't automatically added to legend (#344)
21+
- When rendering a single image channel, a colorbar is now shown (#346)
22+
- Removed `percentiles_for_norm` parameter (#346)
23+
- Changed `norm` to no longer accept bools, only `mpl.colors.Normalise` or `None` (#346)
2124

2225
### Fixed
2326

2427
- Filtering with `groups` now preserves original cmap (#344)
2528
- Non-selected `groups` are now not shown in `na_color` (#344)
29+
- Several issues associated with `norm` and `colorbar` (#346)
2630

2731
## [0.2.5] - 2024-08-23
2832

src/spatialdata_plot/pl/basic.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def render_shapes(
166166
outline_color: str | list[float] = "#000000ff",
167167
outline_alpha: float | int = 0.0,
168168
cmap: Colormap | str | None = None,
169-
norm: bool | Normalize = False,
169+
norm: Normalize | None = None,
170170
scale: float | int = 1.0,
171171
method: str | None = None,
172172
table_name: str | None = None,
@@ -301,7 +301,7 @@ def render_points(
301301
palette: list[str] | str | None = None,
302302
na_color: ColorLike | None = "default",
303303
cmap: Colormap | str | None = None,
304-
norm: None | Normalize = None,
304+
norm: Normalize | None = None,
305305
size: float | int = 1.0,
306306
method: str | None = None,
307307
table_name: str | None = None,
@@ -422,7 +422,6 @@ def render_images(
422422
na_color: ColorLike | None = "default",
423423
palette: list[str] | str | None = None,
424424
alpha: float | int = 1.0,
425-
percentiles_for_norm: tuple[float, float] | None = None,
426425
scale: str | None = None,
427426
**kwargs: Any,
428427
) -> sd.SpatialData:
@@ -457,8 +456,6 @@ def render_images(
457456
Palette to color images. The number of palettes should be equal to the number of channels.
458457
alpha : float | int, default 1.0
459458
Alpha value for the images. Must be a numeric between 0 and 1.
460-
percentiles_for_norm : tuple[float, float] | None
461-
Optional pair of floats (pmin < pmax, 0-100) which will be used for quantile normalization.
462459
scale : str | None
463460
Influences the resolution of the rendering. Possibilities include:
464461
1) `None` (default): The image is rasterized to fit the canvas size. For
@@ -486,20 +483,14 @@ def render_images(
486483
cmap=cmap,
487484
norm=norm,
488485
scale=scale,
489-
percentiles_for_norm=percentiles_for_norm,
490486
)
491487

492488
sdata = self._copy()
493489
sdata = _verify_plotting_tree(sdata)
494490
n_steps = len(sdata.plotting_tree.keys())
495491

496492
for element, param_values in params_dict.items():
497-
# cmap_params = _prepare_cmap_norm(
498-
# cmap=params_dict[element]["cmap"],
499-
# norm=norm,
500-
# na_color=params_dict[element]["na_color"], # type: ignore[arg-type]
501-
# **kwargs,
502-
# )
493+
503494
cmap_params: list[CmapParams] | CmapParams
504495
if isinstance(cmap, list):
505496
cmap_params = [
@@ -525,7 +516,6 @@ def render_images(
525516
cmap_params=cmap_params,
526517
palette=param_values["palette"],
527518
alpha=param_values["alpha"],
528-
percentiles_for_norm=param_values["percentiles_for_norm"],
529519
scale=param_values["scale"],
530520
zorder=n_steps,
531521
)

src/spatialdata_plot/pl/render.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import datashader as ds
1010
import geopandas as gpd
1111
import matplotlib
12+
import matplotlib.pyplot as plt
1213
import matplotlib.transforms as mtransforms
1314
import numpy as np
1415
import pandas as pd
@@ -47,7 +48,6 @@
4748
_maybe_set_colors,
4849
_mpl_ax_contains_elements,
4950
_multiscale_to_spatial_image,
50-
_normalize,
5151
_rasterize_if_necessary,
5252
_set_color_source_vec,
5353
to_hex,
@@ -128,6 +128,7 @@ def _render_shapes(
128128
shapes = shapes.reset_index()
129129
color_source_vector = color_source_vector[mask]
130130
color_vector = color_vector[mask]
131+
131132
shapes = gpd.GeoDataFrame(shapes, geometry="geometry")
132133

133134
# Using dict.fromkeys here since set returns in arbitrary order
@@ -255,9 +256,13 @@ def _render_shapes(
255256
for path in _cax.get_paths():
256257
path.vertices = trans.transform(path.vertices)
257258

258-
# Sets the limits of the colorbar to the values instead of [0, 1]
259-
if not norm and not values_are_categorical:
260-
_cax.set_clim(min(color_vector), max(color_vector))
259+
if not values_are_categorical:
260+
# If the user passed a Normalize object with vmin/vmax we'll use those,
261+
# # if not we'll use the min/max of the color_vector
262+
_cax.set_clim(
263+
vmin=render_params.cmap_params.norm.vmin or min(color_vector),
264+
vmax=render_params.cmap_params.norm.vmax or max(color_vector),
265+
)
261266

262267
if len(set(color_vector)) != 1 or list(set(color_vector))[0] != to_hex(render_params.cmap_params.na_color):
263268
# necessary in case different shapes elements are annotated with one table
@@ -603,11 +608,6 @@ def _render_images(
603608
if n_channels == 1 and not isinstance(render_params.cmap_params, list):
604609
layer = img.sel(c=channels[0]).squeeze() if isinstance(channels[0], str) else img.isel(c=channels[0]).squeeze()
605610

606-
if render_params.percentiles_for_norm != (None, None):
607-
layer = _normalize(
608-
layer, pmin=render_params.percentiles_for_norm[0], pmax=render_params.percentiles_for_norm[1], clip=True
609-
)
610-
611611
if render_params.cmap_params.norm: # type: ignore[attr-defined]
612612
layer = render_params.cmap_params.norm(layer) # type: ignore[attr-defined]
613613

@@ -623,20 +623,16 @@ def _render_images(
623623

624624
_ax_show_and_transform(layer, trans_data, ax, cmap=cmap, zorder=render_params.zorder)
625625

626+
if legend_params.colorbar:
627+
sm = plt.cm.ScalarMappable(cmap=cmap, norm=render_params.cmap_params.norm)
628+
fig_params.fig.colorbar(sm, ax=ax)
629+
626630
# 2) Image has any number of channels but 1
627631
else:
628632
layers = {}
629633
for ch_index, c in enumerate(channels):
630634
layers[c] = img.sel(c=c).copy(deep=True).squeeze()
631635

632-
if render_params.percentiles_for_norm != (None, None):
633-
layers[c] = _normalize(
634-
layers[c],
635-
pmin=render_params.percentiles_for_norm[0],
636-
pmax=render_params.percentiles_for_norm[1],
637-
clip=True,
638-
)
639-
640636
if not isinstance(render_params.cmap_params, list):
641637
if render_params.cmap_params.norm is not None:
642638
layers[c] = render_params.cmap_params.norm(layers[c])

src/spatialdata_plot/pl/utils.py

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ def _get_scalebar(
489489

490490
def _prepare_cmap_norm(
491491
cmap: Colormap | str | None = None,
492-
norm: Normalize | bool = False,
492+
norm: Normalize | None = None,
493493
na_color: ColorLike | None = None,
494494
vmin: float | None = None,
495495
vmax: float | None = None,
@@ -1623,29 +1623,6 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
16231623
if scale < 0:
16241624
raise ValueError("Parameter 'scale' must be a positive number.")
16251625

1626-
if (percentiles_for_norm := param_dict.get("percentiles_for_norm")) is None:
1627-
percentiles_for_norm = (None, None)
1628-
elif not (isinstance(percentiles_for_norm, (list, tuple)) or len(percentiles_for_norm) != 2):
1629-
raise TypeError("Parameter 'percentiles_for_norm' must be a list or tuple of exactly two floats or None.")
1630-
elif not all(
1631-
isinstance(p, (float, int, type(None)))
1632-
and isinstance(p, type(percentiles_for_norm[0]))
1633-
and (p is None or 0 <= p <= 100)
1634-
for p in percentiles_for_norm
1635-
):
1636-
raise TypeError(
1637-
"Each item in 'percentiles_for_norm' must be of the same dtype and must be a float or int within [0, 100], "
1638-
"or None"
1639-
)
1640-
elif (
1641-
percentiles_for_norm[0] is not None
1642-
and percentiles_for_norm[1] is not None
1643-
and percentiles_for_norm[0] > percentiles_for_norm[1]
1644-
):
1645-
raise ValueError("The first number in 'percentiles_for_norm' must not be smaller than the second.")
1646-
if "percentiles_for_norm" in param_dict:
1647-
param_dict["percentiles_for_norm"] = percentiles_for_norm
1648-
16491626
if size := param_dict.get("size"):
16501627
if not isinstance(size, (float, int)):
16511628
raise TypeError("Parameter 'size' must be numeric.")
@@ -1886,7 +1863,6 @@ def _validate_image_render_params(
18861863
cmap: list[Colormap | str] | Colormap | str | None,
18871864
norm: Normalize | None,
18881865
scale: str | None,
1889-
percentiles_for_norm: tuple[float | None, float | None] | None,
18901866
) -> dict[str, dict[str, Any]]:
18911867
param_dict: dict[str, Any] = {
18921868
"sdata": sdata,
@@ -1898,7 +1874,6 @@ def _validate_image_render_params(
18981874
"cmap": cmap,
18991875
"norm": norm,
19001876
"scale": scale,
1901-
"percentiles_for_norm": percentiles_for_norm,
19021877
}
19031878
param_dict = _type_check_params(param_dict, "images")
19041879

@@ -1945,8 +1920,6 @@ def _validate_image_render_params(
19451920
else:
19461921
element_params[el]["scale"] = scale
19471922

1948-
element_params[el]["percentiles_for_norm"] = param_dict["percentiles_for_norm"]
1949-
19501923
return element_params
19511924

19521925

Loading
Loading
Loading
Loading
Loading
Loading
Binary file not shown.
Loading
Loading
Loading
1.27 KB
Loading
905 Bytes
Loading
Loading

tests/pl/test_render.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ def test_render_images_can_plot_one_cyx_image(request):
2323
def test_render_images_can_plot_multiple_cyx_images(share_coordinate_system: str, request):
2424
fun = request.getfixturevalue("get_sdata_with_multiple_images")
2525
sdata = fun(share_coordinate_system)
26-
sdata.pl.render_images().pl.show()
26+
sdata.pl.render_images().pl.show(
27+
colorbar=False, # otherwise we'll get one cbar per image in the same cs
28+
)
2729
axs = plt.gcf().get_axes()
2830

2931
if share_coordinate_system == "all":

tests/pl/test_render_images.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import matplotlib
33
import numpy as np
44
import scanpy as sc
5-
from matplotlib import pyplot as plt
65
from matplotlib.colors import Normalize
76
from spatial_image import to_spatial_image
87
from spatialdata import SpatialData
@@ -49,9 +48,6 @@ def test_plot_can_render_a_single_channel_from_image(self, sdata_blobs: SpatialD
4948
def test_plot_can_render_a_single_channel_from_multiscale_image(self, sdata_blobs: SpatialData):
5049
sdata_blobs.pl.render_images(element="blobs_multiscale_image", channel=0).pl.show()
5150

52-
def test_plot_can_render_a_single_channel_from_image_no_el(self, sdata_blobs: SpatialData):
53-
sdata_blobs.pl.render_images(channel=0).pl.show()
54-
5551
def test_plot_can_render_a_single_channel_str_from_image(self, sdata_blobs_str: SpatialData):
5652
sdata_blobs_str.pl.render_images(element="blobs_image", channel="c1").pl.show()
5753

@@ -70,16 +66,13 @@ def test_plot_can_render_two_channels_str_from_image(self, sdata_blobs_str: Spat
7066
def test_plot_can_render_two_channels_str_from_multiscale_image(self, sdata_blobs_str: SpatialData):
7167
sdata_blobs_str.pl.render_images(element="blobs_multiscale_image", channel=["c1", "c2"]).pl.show()
7268

73-
def test_plot_can_pass_vmin_vmax(self, sdata_blobs: SpatialData):
74-
fig, axs = plt.subplots(ncols=2, figsize=(6, 3))
75-
sdata_blobs.pl.render_images(element="blobs_image", channel=1).pl.show(ax=axs[0])
76-
sdata_blobs.pl.render_images(element="blobs_image", channel=1, vmin=0, vmax=0.4).pl.show(ax=axs[1])
77-
78-
def test_plot_can_pass_normalize(self, sdata_blobs: SpatialData):
79-
fig, axs = plt.subplots(ncols=2, figsize=(6, 3))
69+
def test_plot_can_pass_normalize_clip_True(self, sdata_blobs: SpatialData):
8070
norm = Normalize(vmin=0, vmax=0.4, clip=True)
81-
sdata_blobs.pl.render_images(element="blobs_image", channel=1).pl.show(ax=axs[0])
82-
sdata_blobs.pl.render_images(element="blobs_image", channel=1, norm=norm).pl.show(ax=axs[1])
71+
sdata_blobs.pl.render_images(element="blobs_image", channel=0, norm=norm).pl.show()
72+
73+
def test_plot_can_pass_normalize_clip_False(self, sdata_blobs: SpatialData):
74+
norm = Normalize(vmin=0, vmax=0.4, clip=False)
75+
sdata_blobs.pl.render_images(element="blobs_image", channel=0, norm=norm).pl.show()
8376

8477
def test_plot_can_pass_color_to_single_channel(self, sdata_blobs: SpatialData):
8578
sdata_blobs.pl.render_images(element="blobs_image", channel=1, palette="red").pl.show()
@@ -97,9 +90,6 @@ def test_plot_can_pass_cmap_to_each_channel(self, sdata_blobs: SpatialData):
9790
element="blobs_image", channel=[0, 1, 2], cmap=["Reds", "Greens", "Blues"]
9891
).pl.show()
9992

100-
def test_plot_can_normalize_image(self, sdata_blobs: SpatialData):
101-
sdata_blobs.pl.render_images(element="blobs_image", percentiles_for_norm=(5, 90)).pl.show()
102-
10393
def test_plot_can_render_multiscale_image(self, sdata_blobs: SpatialData):
10494
sdata_blobs.pl.render_images("blobs_multiscale_image").pl.show()
10595

tests/pl/test_render_shapes.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pandas as pd
77
import scanpy as sc
88
from anndata import AnnData
9+
from matplotlib.colors import Normalize
910
from shapely.geometry import MultiPolygon, Point, Polygon
1011
from spatialdata import SpatialData, deepcopy
1112
from spatialdata.models import ShapesModel, TableModel
@@ -146,7 +147,8 @@ def test_plot_colorbar_can_be_normalised(self, sdata_blobs: SpatialData):
146147
sdata_blobs["table"].obs["region"] = ["blobs_polygons"] * sdata_blobs["table"].n_obs
147148
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
148149
sdata_blobs.shapes["blobs_polygons"]["cluster"] = [1, 2, 3, 5, 20]
149-
sdata_blobs.pl.render_shapes("blobs_polygons", color="cluster", groups=["c1"], norm=True).pl.show()
150+
norm = Normalize(vmin=0, vmax=5, clip=True)
151+
sdata_blobs.pl.render_shapes("blobs_polygons", color="cluster", groups=["c1"], norm=norm).pl.show()
150152

151153
def test_plot_can_plot_shapes_after_spatial_query(self, sdata_blobs: SpatialData):
152154
# subset to only shapes, should be unnecessary after rasterizeation of multiscale images is included

0 commit comments

Comments
 (0)