Skip to content

Commit 2513645

Browse files
authored
Implement colorbar logic (#518)
1 parent 8d7aa1c commit 2513645

File tree

223 files changed

+625
-92
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

223 files changed

+625
-92
lines changed

src/spatialdata_plot/pl/basic.py

Lines changed: 174 additions & 4 deletions
Large diffs are not rendered by default.

src/spatialdata_plot/pl/render.py

Lines changed: 117 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from spatialdata_plot._logging import logger
2727
from spatialdata_plot.pl.render_params import (
2828
Color,
29+
ColorbarSpec,
2930
FigParams,
3031
ImageRenderParams,
3132
LabelsRenderParams,
@@ -61,6 +62,55 @@
6162
_Normalize = Normalize | abc.Sequence[Normalize]
6263

6364

65+
def _split_colorbar_params(params: dict[str, object] | None) -> tuple[dict[str, object], dict[str, object], str | None]:
66+
"""Split colorbar params into layout hints, Matplotlib kwargs, and label override."""
67+
layout: dict[str, object] = {}
68+
cbar_kwargs: dict[str, object] = {}
69+
label_override: str | None = None
70+
for key, value in (params or {}).items():
71+
key_lower = key.lower()
72+
if key_lower in {"loc", "location"}:
73+
layout["location"] = value
74+
elif key_lower == "width" or key_lower == "fraction":
75+
layout["fraction"] = value
76+
elif key_lower == "pad":
77+
layout["pad"] = value
78+
elif key_lower == "label":
79+
label_override = None if value is None else str(value)
80+
else:
81+
cbar_kwargs[key] = value
82+
return layout, cbar_kwargs, label_override
83+
84+
85+
def _resolve_colorbar_label(
86+
colorbar_params: dict[str, object] | None, fallback: str | None, *, is_default_channel_name: bool = False
87+
) -> str | None:
88+
"""Pick a colorbar label from params or fall back to provided value."""
89+
_, _, label = _split_colorbar_params(colorbar_params)
90+
if label is not None:
91+
return label
92+
if is_default_channel_name:
93+
return None
94+
return fallback
95+
96+
97+
def _should_request_colorbar(
98+
colorbar: bool | str | None,
99+
*,
100+
has_mappable: bool,
101+
is_continuous: bool,
102+
auto_condition: bool = True,
103+
) -> bool:
104+
"""Resolve colorbar setting to a final boolean request."""
105+
if not has_mappable or not is_continuous:
106+
return False
107+
if colorbar is True:
108+
return True
109+
if colorbar in {False, None}:
110+
return False
111+
return bool(auto_condition)
112+
113+
64114
def _render_shapes(
65115
sdata: sd.SpatialData,
66116
render_params: ShapesRenderParams,
@@ -69,6 +119,7 @@ def _render_shapes(
69119
fig_params: FigParams,
70120
scalebar_params: ScalebarParams,
71121
legend_params: LegendParams,
122+
colorbar_requests: list[ColorbarSpec] | None = None,
72123
) -> None:
73124
element = render_params.element
74125
col_for_color = render_params.col_for_color
@@ -80,7 +131,8 @@ def _render_shapes(
80131
filter_tables=bool(render_params.table_name),
81132
)
82133

83-
if (table_name := render_params.table_name) is None:
134+
table_name = render_params.table_name
135+
if table_name is None:
84136
table = None
85137
shapes = sdata_filt[element]
86138
else:
@@ -159,16 +211,13 @@ def _render_shapes(
159211
else:
160212
palette = ListedColormap(dict.fromkeys(color_vector[~pd.Categorical(color_source_vector).isnull()]))
161213

162-
if (
214+
has_valid_color = (
163215
len(set(color_vector)) != 1
164216
or list(set(color_vector))[0] != render_params.cmap_params.na_color.get_hex_with_alpha()
165-
):
217+
)
218+
if has_valid_color and color_source_vector is not None and col_for_color is not None:
166219
# necessary in case different shapes elements are annotated with one table
167-
if color_source_vector is not None and col_for_color is not None:
168-
color_source_vector = color_source_vector.remove_unused_categories()
169-
170-
# False if user specified color-like with 'color' parameter
171-
colorbar = False if col_for_color is None else legend_params.colorbar
220+
color_source_vector = color_source_vector.remove_unused_categories()
172221

173222
# Apply the transformation to the PatchCollection's paths
174223
trans, trans_data = _prepare_transformation(sdata_filt.shapes[element], coordinate_system)
@@ -515,8 +564,11 @@ def _render_shapes(
515564
if color_source_vector is not None and render_params.col_for_color is not None:
516565
color_source_vector = color_source_vector.remove_unused_categories()
517566

518-
# False if user specified color-like with 'color' parameter
519-
colorbar = False if render_params.col_for_color is None else legend_params.colorbar
567+
wants_colorbar = _should_request_colorbar(
568+
render_params.colorbar,
569+
has_mappable=cax is not None,
570+
is_continuous=render_params.col_for_color is not None and color_source_vector is None,
571+
)
520572

521573
_ = _decorate_axs(
522574
ax=ax,
@@ -534,7 +586,13 @@ def _render_shapes(
534586
legend_loc=legend_params.legend_loc,
535587
legend_fontoutline=legend_params.legend_fontoutline,
536588
na_in_legend=legend_params.na_in_legend,
537-
colorbar=colorbar,
589+
colorbar=wants_colorbar and legend_params.colorbar,
590+
colorbar_params=render_params.colorbar_params,
591+
colorbar_requests=colorbar_requests,
592+
colorbar_label=_resolve_colorbar_label(
593+
render_params.colorbar_params,
594+
col_for_color if isinstance(col_for_color, str) else None,
595+
),
538596
scalebar_dx=scalebar_params.scalebar_dx,
539597
scalebar_units=scalebar_params.scalebar_units,
540598
)
@@ -548,6 +606,7 @@ def _render_points(
548606
fig_params: FigParams,
549607
scalebar_params: ScalebarParams,
550608
legend_params: LegendParams,
609+
colorbar_requests: list[ColorbarSpec] | None = None,
551610
) -> None:
552611
element = render_params.element
553612
col_for_color = render_params.col_for_color
@@ -894,6 +953,12 @@ def _render_points(
894953
else:
895954
palette = ListedColormap(dict.fromkeys(color_vector[~pd.Categorical(color_source_vector).isnull()]))
896955

956+
wants_colorbar = _should_request_colorbar(
957+
render_params.colorbar,
958+
has_mappable=cax is not None,
959+
is_continuous=col_for_color is not None and color_source_vector is None,
960+
)
961+
897962
_ = _decorate_axs(
898963
ax=ax,
899964
cax=cax,
@@ -910,7 +975,13 @@ def _render_points(
910975
legend_loc=legend_params.legend_loc,
911976
legend_fontoutline=legend_params.legend_fontoutline,
912977
na_in_legend=legend_params.na_in_legend,
913-
colorbar=legend_params.colorbar,
978+
colorbar=wants_colorbar and legend_params.colorbar,
979+
colorbar_params=render_params.colorbar_params,
980+
colorbar_requests=colorbar_requests,
981+
colorbar_label=_resolve_colorbar_label(
982+
render_params.colorbar_params,
983+
col_for_color if isinstance(col_for_color, str) else None,
984+
),
914985
scalebar_dx=scalebar_params.scalebar_dx,
915986
scalebar_units=scalebar_params.scalebar_units,
916987
)
@@ -925,6 +996,7 @@ def _render_images(
925996
scalebar_params: ScalebarParams,
926997
legend_params: LegendParams,
927998
rasterize: bool,
999+
colorbar_requests: list[ColorbarSpec] | None = None,
9281000
) -> None:
9291001
sdata_filt = sdata.filter_by_coordinate_system(
9301002
coordinate_system=coordinate_system,
@@ -1003,9 +1075,26 @@ def _render_images(
10031075
norm=render_params.cmap_params.norm,
10041076
)
10051077

1006-
if legend_params.colorbar:
1078+
wants_colorbar = _should_request_colorbar(
1079+
render_params.colorbar,
1080+
has_mappable=n_channels == 1,
1081+
is_continuous=True,
1082+
auto_condition=n_channels == 1,
1083+
)
1084+
if wants_colorbar and legend_params.colorbar and colorbar_requests is not None:
10071085
sm = plt.cm.ScalarMappable(cmap=cmap, norm=render_params.cmap_params.norm)
1008-
fig_params.fig.colorbar(sm, ax=ax)
1086+
colorbar_requests.append(
1087+
ColorbarSpec(
1088+
ax=ax,
1089+
mappable=sm,
1090+
params=render_params.colorbar_params,
1091+
label=_resolve_colorbar_label(
1092+
render_params.colorbar_params,
1093+
str(channels[0]),
1094+
is_default_channel_name=isinstance(channels[0], (int, np.integer)),
1095+
),
1096+
)
1097+
)
10091098

10101099
# 2) Image has any number of channels but 1
10111100
else:
@@ -1165,6 +1254,7 @@ def _render_labels(
11651254
scalebar_params: ScalebarParams,
11661255
legend_params: LegendParams,
11671256
rasterize: bool,
1257+
colorbar_requests: list[ColorbarSpec] | None = None,
11681258
) -> None:
11691259
element = render_params.element
11701260
table_name = render_params.table_name
@@ -1310,6 +1400,12 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float)
13101400
else:
13111401
raise ValueError("Parameters 'fill_alpha' and 'outline_alpha' cannot both be 0.")
13121402

1403+
colorbar_requested = _should_request_colorbar(
1404+
render_params.colorbar,
1405+
has_mappable=cax is not None,
1406+
is_continuous=color is not None and color_source_vector is None and not categorical,
1407+
)
1408+
13131409
_ = _decorate_axs(
13141410
ax=ax,
13151411
cax=cax,
@@ -1326,7 +1422,13 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float)
13261422
legend_loc=legend_params.legend_loc,
13271423
legend_fontoutline=legend_params.legend_fontoutline,
13281424
na_in_legend=(legend_params.na_in_legend if groups is None else len(groups) == len(set(color_vector))),
1329-
colorbar=legend_params.colorbar,
1425+
colorbar=colorbar_requested and legend_params.colorbar,
1426+
colorbar_params=render_params.colorbar_params,
1427+
colorbar_requests=colorbar_requests,
1428+
colorbar_label=_resolve_colorbar_label(
1429+
render_params.colorbar_params,
1430+
color if isinstance(color, str) else None,
1431+
),
13301432
scalebar_dx=scalebar_params.scalebar_dx,
13311433
scalebar_units=scalebar_params.scalebar_units,
13321434
# scalebar_kwargs=scalebar_params.scalebar_kwargs,

src/spatialdata_plot/pl/render_params.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import numpy as np
88
from matplotlib.axes import Axes
9+
from matplotlib.cm import ScalarMappable
910
from matplotlib.colors import Colormap, ListedColormap, Normalize, rgb2hex, to_hex
1011
from matplotlib.figure import Figure
1112

@@ -183,6 +184,22 @@ class LegendParams:
183184
colorbar: bool = True
184185

185186

187+
@dataclass
188+
class ColorbarSpec:
189+
"""Data required to create a colorbar."""
190+
191+
ax: Axes
192+
mappable: ScalarMappable
193+
params: dict[str, object] | None = None
194+
label: str | None = None
195+
alpha: float | None = None
196+
197+
198+
CBAR_DEFAULT_LOCATION = "right"
199+
CBAR_DEFAULT_FRACTION = 0.075
200+
CBAR_DEFAULT_PAD = 0.015
201+
202+
186203
@dataclass
187204
class ScalebarParams:
188205
"""Scalebar params."""
@@ -213,6 +230,8 @@ class ShapesRenderParams:
213230
table_layer: str | None = None
214231
shape: Literal["circle", "hex", "visium_hex", "square"] | None = None
215232
ds_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None
233+
colorbar: bool | str | None = "auto"
234+
colorbar_params: dict[str, object] | None = None
216235

217236

218237
@dataclass
@@ -233,6 +252,8 @@ class PointsRenderParams:
233252
table_name: str | None = None
234253
table_layer: str | None = None
235254
ds_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None
255+
colorbar: bool | str | None = "auto"
256+
colorbar_params: dict[str, object] | None = None
236257

237258

238259
@dataclass
@@ -247,6 +268,8 @@ class ImageRenderParams:
247268
percentiles_for_norm: tuple[float | None, float | None] = (None, None)
248269
scale: str | None = None
249270
zorder: int = 0
271+
colorbar: bool | str | None = "auto"
272+
colorbar_params: dict[str, object] | None = None
250273

251274

252275
@dataclass
@@ -267,3 +290,5 @@ class LabelsRenderParams:
267290
table_name: str | None = None
268291
table_layer: str | None = None
269292
zorder: int = 0
293+
colorbar: bool | str | None = "auto"
294+
colorbar_params: dict[str, object] | None = None

0 commit comments

Comments
 (0)