Skip to content

Commit 6cce44e

Browse files
authored
Fix clims when plotting shapes element annotations with matplotlib rendering (#368)
1 parent 6ffe22b commit 6cce44e

File tree

6 files changed

+55
-28
lines changed

6 files changed

+55
-28
lines changed

src/spatialdata_plot/pl/basic.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,13 @@ def render_shapes(
241241
sd.SpatialData
242242
The modified SpatialData object with the rendered shapes.
243243
"""
244+
# TODO add Normalize object in tutorial notebook and point to that notebook here
245+
if "vmin" in kwargs or "vmax" in kwargs:
246+
warnings.warn(
247+
"`vmin` and `vmax` are deprecated. Pass matplotlib `Normalize` object to norm instead.",
248+
DeprecationWarning,
249+
stacklevel=2,
250+
)
244251
params_dict = _validate_shape_render_params(
245252
self._sdata,
246253
element=element,
@@ -269,7 +276,6 @@ def render_shapes(
269276
cmap=cmap,
270277
norm=norm,
271278
na_color=params_dict[element]["na_color"], # type: ignore[arg-type]
272-
**kwargs,
273279
)
274280
sdata.plotting_tree[f"{n_steps+1}_render_shapes"] = ShapesRenderParams(
275281
element=element,
@@ -363,6 +369,13 @@ def render_points(
363369
sd.SpatialData
364370
The modified SpatialData object with the rendered shapes.
365371
"""
372+
# TODO add Normalize object in tutorial notebook and point to that notebook here
373+
if "vmin" in kwargs or "vmax" in kwargs:
374+
warnings.warn(
375+
"`vmin` and `vmax` are deprecated. Pass matplotlib `Normalize` object to norm instead.",
376+
DeprecationWarning,
377+
stacklevel=2,
378+
)
366379
params_dict = _validate_points_render_params(
367380
self._sdata,
368381
element=element,
@@ -392,7 +405,6 @@ def render_points(
392405
cmap=cmap,
393406
norm=norm,
394407
na_color=param_values["na_color"], # type: ignore[arg-type]
395-
**kwargs,
396408
)
397409
sdata.plotting_tree[f"{n_steps+1}_render_points"] = PointsRenderParams(
398410
element=element,
@@ -473,6 +485,13 @@ def render_images(
473485
sd.SpatialData
474486
The SpatialData object with the rendered images.
475487
"""
488+
# TODO add Normalize object in tutorial notebook and point to that notebook here
489+
if "vmin" in kwargs or "vmax" in kwargs:
490+
warnings.warn(
491+
"`vmin` and `vmax` are deprecated. Pass matplotlib `Normalize` object to norm instead.",
492+
DeprecationWarning,
493+
stacklevel=2,
494+
)
476495
params_dict = _validate_image_render_params(
477496
self._sdata,
478497
element=element,
@@ -498,7 +517,6 @@ def render_images(
498517
cmap=c,
499518
norm=norm,
500519
na_color=param_values["na_color"],
501-
**kwargs,
502520
)
503521
for c in cmap
504522
]
@@ -598,6 +616,13 @@ def render_labels(
598616
-------
599617
None
600618
"""
619+
# TODO add Normalize object in tutorial notebook and point to that notebook here
620+
if "vmin" in kwargs or "vmax" in kwargs:
621+
warnings.warn(
622+
"`vmin` and `vmax` are deprecated. Pass matplotlib `Normalize` object to norm instead.",
623+
DeprecationWarning,
624+
stacklevel=2,
625+
)
601626
params_dict = _validate_label_render_params(
602627
self._sdata,
603628
element=element,
@@ -623,7 +648,6 @@ def render_labels(
623648
cmap=cmap,
624649
norm=norm,
625650
na_color=param_values["na_color"], # type: ignore[arg-type]
626-
**kwargs,
627651
)
628652
sdata.plotting_tree[f"{n_steps+1}_render_labels"] = LabelsRenderParams(
629653
element=element,

src/spatialdata_plot/pl/utils.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
LinearSegmentedColormap,
3333
ListedColormap,
3434
Normalize,
35-
TwoSlopeNorm,
3635
to_rgba,
3736
)
3837
from matplotlib.figure import Figure
@@ -339,7 +338,7 @@ def _get_collection_shape(
339338
c = cmap(c)
340339
else:
341340
try:
342-
norm = colors.Normalize(vmin=min(c), vmax=max(c))
341+
norm = colors.Normalize(vmin=min(c), vmax=max(c)) if norm is None else norm
343342
except ValueError as e:
344343
raise ValueError(
345344
"Could not convert values in the `color` column to float, if `color` column represents"
@@ -353,7 +352,7 @@ def _get_collection_shape(
353352
c = cmap(c)
354353
else:
355354
try:
356-
norm = colors.Normalize(vmin=min(c), vmax=max(c))
355+
norm = colors.Normalize(vmin=min(c), vmax=max(c)) if norm is None else norm
357356
except ValueError as e:
358357
raise ValueError(
359358
"Could not convert values in the `color` column to float, if `color` column represents"
@@ -491,11 +490,8 @@ def _prepare_cmap_norm(
491490
cmap: Colormap | str | None = None,
492491
norm: Normalize | None = None,
493492
na_color: ColorLike | None = None,
494-
vmin: float | None = None,
495-
vmax: float | None = None,
496-
vcenter: float | None = None,
497-
**kwargs: Any,
498493
) -> CmapParams:
494+
# TODO: check refactoring norm out here as it gets overwritten later
499495
cmap_is_default = cmap is None
500496
if cmap is None:
501497
cmap = rcParams["image.cmap"]
@@ -505,13 +501,7 @@ def _prepare_cmap_norm(
505501
cmap = copy(cmap)
506502

507503
if norm is None:
508-
norm = Normalize(vmin=vmin, vmax=vmax, clip=True)
509-
elif isinstance(norm, Normalize) or not norm:
510-
pass # TODO
511-
elif vcenter is None:
512-
norm = Normalize(vmin=vmin, vmax=vmax, clip=True)
513-
else:
514-
norm = TwoSlopeNorm(vmin=vmin, vmax=vmax, vcenter=vcenter)
504+
norm = Normalize(vmin=None, vmax=None, clip=False)
515505

516506
na_color, na_color_modified_by_user = _sanitise_na_color(na_color)
517507
cmap.set_bad(na_color)
Loading
24.9 KB
Loading
-346 Bytes
Loading

tests/pl/test_render_shapes.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def _make_multi():
9797

9898
def test_plot_can_color_from_geodataframe(self, sdata_blobs: SpatialData):
9999
blob = deepcopy(sdata_blobs)
100-
blob["table"].obs["region"] = ["blobs_polygons"] * sdata_blobs["table"].n_obs
100+
blob["table"].obs["region"] = "blobs_polygons"
101101
blob["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
102102
blob.shapes["blobs_polygons"]["value"] = [1, 10, 1, 20, 1]
103103
blob.pl.render_shapes(
@@ -111,7 +111,7 @@ def test_plot_can_scale_shapes(self, sdata_blobs: SpatialData):
111111
def test_plot_can_filter_with_groups(self, sdata_blobs: SpatialData):
112112
_, axs = plt.subplots(nrows=1, ncols=2, layout="tight")
113113

114-
sdata_blobs["table"].obs["region"] = ["blobs_polygons"] * sdata_blobs["table"].n_obs
114+
sdata_blobs["table"].obs["region"] = "blobs_polygons"
115115
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
116116
sdata_blobs.shapes["blobs_polygons"]["cluster"] = "c1"
117117
sdata_blobs.shapes["blobs_polygons"].iloc[3:5, 1] = "c2"
@@ -125,7 +125,7 @@ def test_plot_can_filter_with_groups(self, sdata_blobs: SpatialData):
125125
)
126126

127127
def test_plot_coloring_with_palette(self, sdata_blobs: SpatialData):
128-
sdata_blobs["table"].obs["region"] = ["blobs_polygons"] * sdata_blobs["table"].n_obs
128+
sdata_blobs["table"].obs["region"] = "blobs_polygons"
129129
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
130130
sdata_blobs.shapes["blobs_polygons"]["cluster"] = "c1"
131131
sdata_blobs.shapes["blobs_polygons"].iloc[3:5, 1] = "c2"
@@ -138,13 +138,13 @@ def test_plot_coloring_with_palette(self, sdata_blobs: SpatialData):
138138
).pl.show()
139139

140140
def test_plot_colorbar_respects_input_limits(self, sdata_blobs: SpatialData):
141-
sdata_blobs["table"].obs["region"] = ["blobs_polygons"] * sdata_blobs["table"].n_obs
141+
sdata_blobs["table"].obs["region"] = "blobs_polygons"
142142
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
143143
sdata_blobs.shapes["blobs_polygons"]["cluster"] = [1, 2, 3, 5, 20]
144-
sdata_blobs.pl.render_shapes("blobs_polygons", color="cluster", groups=["c1"]).pl.show()
144+
sdata_blobs.pl.render_shapes("blobs_polygons", color="cluster").pl.show()
145145

146146
def test_plot_colorbar_can_be_normalised(self, sdata_blobs: SpatialData):
147-
sdata_blobs["table"].obs["region"] = ["blobs_polygons"] * sdata_blobs["table"].n_obs
147+
sdata_blobs["table"].obs["region"] = "blobs_polygons"
148148
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
149149
sdata_blobs.shapes["blobs_polygons"]["cluster"] = [1, 2, 3, 5, 20]
150150
norm = Normalize(vmin=0, vmax=5, clip=True)
@@ -186,7 +186,7 @@ def test_plot_can_plot_with_annotation_despite_random_shuffling(self, sdata_blob
186186

187187
def test_plot_can_plot_queried_with_annotation_despite_random_shuffling(self, sdata_blobs: SpatialData):
188188
sdata_blobs["table"].obs["region"] = "blobs_circles"
189-
new_table = sdata_blobs["table"][:5]
189+
new_table = sdata_blobs["table"][:5].copy()
190190
new_table.uns["spatialdata_attrs"]["region"] = "blobs_circles"
191191
new_table.obs["instance_id"] = np.array(range(5))
192192

@@ -214,7 +214,7 @@ def test_plot_can_plot_queried_with_annotation_despite_random_shuffling(self, sd
214214

215215
def test_plot_can_color_two_shapes_elements_by_annotation(self, sdata_blobs: SpatialData):
216216
sdata_blobs["table"].obs["region"] = "blobs_circles"
217-
new_table = sdata_blobs["table"][:10]
217+
new_table = sdata_blobs["table"][:10].copy()
218218
new_table.uns["spatialdata_attrs"]["region"] = ["blobs_circles", "blobs_polygons"]
219219
new_table.obs["instance_id"] = np.concatenate((np.array(range(5)), np.array(range(5))))
220220

@@ -230,7 +230,7 @@ def test_plot_can_color_two_shapes_elements_by_annotation(self, sdata_blobs: Spa
230230

231231
def test_plot_can_color_two_queried_shapes_elements_by_annotation(self, sdata_blobs: SpatialData):
232232
sdata_blobs["table"].obs["region"] = "blobs_circles"
233-
new_table = sdata_blobs["table"][:10]
233+
new_table = sdata_blobs["table"][:10].copy()
234234
new_table.uns["spatialdata_attrs"]["region"] = ["blobs_circles", "blobs_polygons"]
235235
new_table.obs["instance_id"] = np.concatenate((np.array(range(5)), np.array(range(5))))
236236

@@ -312,7 +312,20 @@ def test_plot_datashader_can_color_by_category(self, sdata_blobs: SpatialData):
312312
sdata_blobs.pl.render_shapes(element="blobs_polygons", color="category", method="datashader").pl.show()
313313

314314
def test_plot_datashader_can_color_by_value(self, sdata_blobs: SpatialData):
315-
sdata_blobs["table"].obs["region"] = ["blobs_polygons"] * sdata_blobs["table"].n_obs
315+
sdata_blobs["table"].obs["region"] = "blobs_polygons"
316316
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
317317
sdata_blobs.shapes["blobs_polygons"]["value"] = [1, 10, 1, 20, 1]
318318
sdata_blobs.pl.render_shapes(element="blobs_polygons", color="value", method="datashader").pl.show()
319+
320+
def test_plot_can_set_clims_clip(self, sdata_blobs: SpatialData):
321+
table_shapes = sdata_blobs["table"][:5].copy()
322+
table_shapes.obs.instance_id = list(range(5))
323+
table_shapes.obs["region"] = "blobs_circles"
324+
table_shapes.obs["dummy_gene_expression"] = [i * 10 for i in range(5)]
325+
table_shapes.uns["spatialdata_attrs"]["region"] = "blobs_circles"
326+
sdata_blobs["new_table"] = table_shapes
327+
328+
norm = Normalize(vmin=20, vmax=40, clip=True)
329+
sdata_blobs.pl.render_shapes(
330+
"blobs_circles", color="dummy_gene_expression", norm=norm, table_name="new_table"
331+
).pl.show()

0 commit comments

Comments
 (0)