2626from spatialdata_plot ._logging import logger
2727from spatialdata_plot .pl .render_params import (
2828 Color ,
29+ ColorbarSpec ,
2930 FigParams ,
3031 ImageRenderParams ,
3132 LabelsRenderParams ,
6162_Normalize = Normalize | abc .Sequence [Normalize ]
6263
6364
65+ def _split_colorbar_params (params : dict [str , object ] | None ) -> tuple [dict [str , object ], dict [str , object ], str | None ]:
66+ """Split colorbar params into layout hints, Matplotlib kwargs, and label override."""
67+ layout : dict [str , object ] = {}
68+ cbar_kwargs : dict [str , object ] = {}
69+ label_override : str | None = None
70+ for key , value in (params or {}).items ():
71+ key_lower = key .lower ()
72+ if key_lower in {"loc" , "location" }:
73+ layout ["location" ] = value
74+ elif key_lower == "width" or key_lower == "fraction" :
75+ layout ["fraction" ] = value
76+ elif key_lower == "pad" :
77+ layout ["pad" ] = value
78+ elif key_lower == "label" :
79+ label_override = None if value is None else str (value )
80+ else :
81+ cbar_kwargs [key ] = value
82+ return layout , cbar_kwargs , label_override
83+
84+
85+ def _resolve_colorbar_label (
86+ colorbar_params : dict [str , object ] | None , fallback : str | None , * , is_default_channel_name : bool = False
87+ ) -> str | None :
88+ """Pick a colorbar label from params or fall back to provided value."""
89+ _ , _ , label = _split_colorbar_params (colorbar_params )
90+ if label is not None :
91+ return label
92+ if is_default_channel_name :
93+ return None
94+ return fallback
95+
96+
97+ def _should_request_colorbar (
98+ colorbar : bool | str | None ,
99+ * ,
100+ has_mappable : bool ,
101+ is_continuous : bool ,
102+ auto_condition : bool = True ,
103+ ) -> bool :
104+ """Resolve colorbar setting to a final boolean request."""
105+ if not has_mappable or not is_continuous :
106+ return False
107+ if colorbar is True :
108+ return True
109+ if colorbar in {False , None }:
110+ return False
111+ return bool (auto_condition )
112+
113+
64114def _render_shapes (
65115 sdata : sd .SpatialData ,
66116 render_params : ShapesRenderParams ,
@@ -69,6 +119,7 @@ def _render_shapes(
69119 fig_params : FigParams ,
70120 scalebar_params : ScalebarParams ,
71121 legend_params : LegendParams ,
122+ colorbar_requests : list [ColorbarSpec ] | None = None ,
72123) -> None :
73124 element = render_params .element
74125 col_for_color = render_params .col_for_color
@@ -80,7 +131,8 @@ def _render_shapes(
80131 filter_tables = bool (render_params .table_name ),
81132 )
82133
83- if (table_name := render_params .table_name ) is None :
134+ table_name = render_params .table_name
135+ if table_name is None :
84136 table = None
85137 shapes = sdata_filt [element ]
86138 else :
@@ -159,16 +211,13 @@ def _render_shapes(
159211 else :
160212 palette = ListedColormap (dict .fromkeys (color_vector [~ pd .Categorical (color_source_vector ).isnull ()]))
161213
162- if (
214+ has_valid_color = (
163215 len (set (color_vector )) != 1
164216 or list (set (color_vector ))[0 ] != render_params .cmap_params .na_color .get_hex_with_alpha ()
165- ):
217+ )
218+ if has_valid_color and color_source_vector is not None and col_for_color is not None :
166219 # necessary in case different shapes elements are annotated with one table
167- if color_source_vector is not None and col_for_color is not None :
168- color_source_vector = color_source_vector .remove_unused_categories ()
169-
170- # False if user specified color-like with 'color' parameter
171- colorbar = False if col_for_color is None else legend_params .colorbar
220+ color_source_vector = color_source_vector .remove_unused_categories ()
172221
173222 # Apply the transformation to the PatchCollection's paths
174223 trans , trans_data = _prepare_transformation (sdata_filt .shapes [element ], coordinate_system )
@@ -515,8 +564,11 @@ def _render_shapes(
515564 if color_source_vector is not None and render_params .col_for_color is not None :
516565 color_source_vector = color_source_vector .remove_unused_categories ()
517566
518- # False if user specified color-like with 'color' parameter
519- colorbar = False if render_params .col_for_color is None else legend_params .colorbar
567+ wants_colorbar = _should_request_colorbar (
568+ render_params .colorbar ,
569+ has_mappable = cax is not None ,
570+ is_continuous = render_params .col_for_color is not None and color_source_vector is None ,
571+ )
520572
521573 _ = _decorate_axs (
522574 ax = ax ,
@@ -534,7 +586,13 @@ def _render_shapes(
534586 legend_loc = legend_params .legend_loc ,
535587 legend_fontoutline = legend_params .legend_fontoutline ,
536588 na_in_legend = legend_params .na_in_legend ,
537- colorbar = colorbar ,
589+ colorbar = wants_colorbar and legend_params .colorbar ,
590+ colorbar_params = render_params .colorbar_params ,
591+ colorbar_requests = colorbar_requests ,
592+ colorbar_label = _resolve_colorbar_label (
593+ render_params .colorbar_params ,
594+ col_for_color if isinstance (col_for_color , str ) else None ,
595+ ),
538596 scalebar_dx = scalebar_params .scalebar_dx ,
539597 scalebar_units = scalebar_params .scalebar_units ,
540598 )
@@ -548,6 +606,7 @@ def _render_points(
548606 fig_params : FigParams ,
549607 scalebar_params : ScalebarParams ,
550608 legend_params : LegendParams ,
609+ colorbar_requests : list [ColorbarSpec ] | None = None ,
551610) -> None :
552611 element = render_params .element
553612 col_for_color = render_params .col_for_color
@@ -894,6 +953,12 @@ def _render_points(
894953 else :
895954 palette = ListedColormap (dict .fromkeys (color_vector [~ pd .Categorical (color_source_vector ).isnull ()]))
896955
956+ wants_colorbar = _should_request_colorbar (
957+ render_params .colorbar ,
958+ has_mappable = cax is not None ,
959+ is_continuous = col_for_color is not None and color_source_vector is None ,
960+ )
961+
897962 _ = _decorate_axs (
898963 ax = ax ,
899964 cax = cax ,
@@ -910,7 +975,13 @@ def _render_points(
910975 legend_loc = legend_params .legend_loc ,
911976 legend_fontoutline = legend_params .legend_fontoutline ,
912977 na_in_legend = legend_params .na_in_legend ,
913- colorbar = legend_params .colorbar ,
978+ colorbar = wants_colorbar and legend_params .colorbar ,
979+ colorbar_params = render_params .colorbar_params ,
980+ colorbar_requests = colorbar_requests ,
981+ colorbar_label = _resolve_colorbar_label (
982+ render_params .colorbar_params ,
983+ col_for_color if isinstance (col_for_color , str ) else None ,
984+ ),
914985 scalebar_dx = scalebar_params .scalebar_dx ,
915986 scalebar_units = scalebar_params .scalebar_units ,
916987 )
@@ -925,6 +996,7 @@ def _render_images(
925996 scalebar_params : ScalebarParams ,
926997 legend_params : LegendParams ,
927998 rasterize : bool ,
999+ colorbar_requests : list [ColorbarSpec ] | None = None ,
9281000) -> None :
9291001 sdata_filt = sdata .filter_by_coordinate_system (
9301002 coordinate_system = coordinate_system ,
@@ -1003,9 +1075,26 @@ def _render_images(
10031075 norm = render_params .cmap_params .norm ,
10041076 )
10051077
1006- if legend_params .colorbar :
1078+ wants_colorbar = _should_request_colorbar (
1079+ render_params .colorbar ,
1080+ has_mappable = n_channels == 1 ,
1081+ is_continuous = True ,
1082+ auto_condition = n_channels == 1 ,
1083+ )
1084+ if wants_colorbar and legend_params .colorbar and colorbar_requests is not None :
10071085 sm = plt .cm .ScalarMappable (cmap = cmap , norm = render_params .cmap_params .norm )
1008- fig_params .fig .colorbar (sm , ax = ax )
1086+ colorbar_requests .append (
1087+ ColorbarSpec (
1088+ ax = ax ,
1089+ mappable = sm ,
1090+ params = render_params .colorbar_params ,
1091+ label = _resolve_colorbar_label (
1092+ render_params .colorbar_params ,
1093+ str (channels [0 ]),
1094+ is_default_channel_name = isinstance (channels [0 ], (int , np .integer )),
1095+ ),
1096+ )
1097+ )
10091098
10101099 # 2) Image has any number of channels but 1
10111100 else :
@@ -1165,6 +1254,7 @@ def _render_labels(
11651254 scalebar_params : ScalebarParams ,
11661255 legend_params : LegendParams ,
11671256 rasterize : bool ,
1257+ colorbar_requests : list [ColorbarSpec ] | None = None ,
11681258) -> None :
11691259 element = render_params .element
11701260 table_name = render_params .table_name
@@ -1310,6 +1400,12 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float)
13101400 else :
13111401 raise ValueError ("Parameters 'fill_alpha' and 'outline_alpha' cannot both be 0." )
13121402
1403+ colorbar_requested = _should_request_colorbar (
1404+ render_params .colorbar ,
1405+ has_mappable = cax is not None ,
1406+ is_continuous = color is not None and color_source_vector is None and not categorical ,
1407+ )
1408+
13131409 _ = _decorate_axs (
13141410 ax = ax ,
13151411 cax = cax ,
@@ -1326,7 +1422,13 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float)
13261422 legend_loc = legend_params .legend_loc ,
13271423 legend_fontoutline = legend_params .legend_fontoutline ,
13281424 na_in_legend = (legend_params .na_in_legend if groups is None else len (groups ) == len (set (color_vector ))),
1329- colorbar = legend_params .colorbar ,
1425+ colorbar = colorbar_requested and legend_params .colorbar ,
1426+ colorbar_params = render_params .colorbar_params ,
1427+ colorbar_requests = colorbar_requests ,
1428+ colorbar_label = _resolve_colorbar_label (
1429+ render_params .colorbar_params ,
1430+ color if isinstance (color , str ) else None ,
1431+ ),
13301432 scalebar_dx = scalebar_params .scalebar_dx ,
13311433 scalebar_units = scalebar_params .scalebar_units ,
13321434 # scalebar_kwargs=scalebar_params.scalebar_kwargs,
0 commit comments