1919from matplotlib .backend_bases import RendererBase
2020from matplotlib .colors import Colormap , Normalize
2121from matplotlib .figure import Figure
22- from matplotlib . transforms import Bbox
22+ from mpl_toolkits . axes_grid1 . inset_locator import inset_axes
2323from spatialdata import get_extent
2424from spatialdata ._utils import _deprecation_alias
2525from xarray import DataArray , DataTree
@@ -960,9 +960,8 @@ def _draw_colorbar(
960960 spec : ColorbarSpec ,
961961 fig : Figure ,
962962 renderer : RendererBase ,
963- axis_bbox : Bbox ,
964- base_offsets : dict [str , float ],
965- trackers : dict [str , float ],
963+ base_offsets_axes : dict [str , float ],
964+ trackers_axes : dict [str , float ],
966965 ) -> None :
967966 base_layout = {
968967 "location" : CBAR_DEFAULT_LOCATION ,
@@ -983,35 +982,23 @@ def _draw_colorbar(
983982 fraction = float (cast (float | int , layout .get ("fraction" , base_layout ["fraction" ])))
984983 pad = float (cast (float | int , layout .get ("pad" , base_layout ["pad" ])))
985984
986- span_width = axis_bbox .width + base_offsets ["left" ] + base_offsets ["right" ]
987- span_height = axis_bbox .height + base_offsets ["top" ] + base_offsets ["bottom" ]
988-
989985 if location in {"left" , "right" }:
990- pad_fig = pad * axis_bbox .width
991- width_fig = fraction * axis_bbox .width
992- height_fig = span_height
993- if location == "left" :
994- x0 = trackers ["left" ] - pad_fig - width_fig
995- y0 = axis_bbox .y0 - base_offsets ["bottom" ]
996- trackers ["left" ] = x0
997- else :
998- x0 = trackers ["right" ] + pad_fig
999- y0 = axis_bbox .y0 - base_offsets ["bottom" ]
1000- trackers ["right" ] = x0 + width_fig
1001- cax = fig .add_axes ([x0 , y0 , width_fig , height_fig ])
986+ pad_axes = pad + trackers_axes [location ]
987+ x0 = - pad_axes - fraction if location == "left" else 1 + pad_axes
988+ bbox = (float (x0 ), 0.0 , float (fraction ), 1.0 )
1002989 else :
1003- pad_fig = pad * axis_bbox . height
1004- height_fig = fraction * axis_bbox . height
1005- width_fig = span_width
1006- if location == "bottom" :
1007- x0 = axis_bbox . x0 - base_offsets [ "left" ]
1008- y0 = trackers [ "bottom" ] - pad_fig - height_fig
1009- trackers [ "bottom" ] = y0
1010- else :
1011- x0 = axis_bbox . x0 - base_offsets [ "left" ]
1012- y0 = trackers [ "top" ] + pad_fig
1013- trackers [ "top" ] = y0 + height_fig
1014- cax = fig . add_axes ([ x0 , y0 , width_fig , height_fig ] )
990+ pad_axes = pad + trackers_axes [ location ]
991+ y0 = - pad_axes - fraction if location == "bottom" else 1 + pad_axes
992+ bbox = ( 0.0 , float ( y0 ), 1.0 , float ( fraction ))
993+ cax = inset_axes (
994+ spec . ax ,
995+ width = "100%" ,
996+ height = "100%" ,
997+ loc = "center" ,
998+ bbox_to_anchor = bbox ,
999+ bbox_transform = spec . ax . transAxes ,
1000+ borderpad = 0.0 ,
1001+ )
10151002
10161003 cb = fig .colorbar (spec .mappable , cax = cax , ** cbar_kwargs )
10171004 if location == "left" :
@@ -1037,15 +1024,15 @@ def _draw_colorbar(
10371024 if spec .alpha is not None :
10381025 with contextlib .suppress (Exception ):
10391026 cb .solids .set_alpha (spec .alpha )
1040- bbox_axes = cb .ax .get_tightbbox (renderer ).transformed (fig . transFigure .inverted ())
1027+ bbox_axes = cb .ax .get_tightbbox (renderer ).transformed (spec . ax . transAxes .inverted ())
10411028 if location == "left" :
1042- trackers ["left" ] = bbox_axes .x0
1029+ trackers_axes ["left" ] = pad_axes + bbox_axes .width
10431030 elif location == "right" :
1044- trackers ["right" ] = bbox_axes .x1
1031+ trackers_axes ["right" ] = pad_axes + bbox_axes .width
10451032 elif location == "bottom" :
1046- trackers ["bottom" ] = bbox_axes .y0
1033+ trackers_axes ["bottom" ] = pad_axes + bbox_axes .height
10471034 elif location == "top" :
1048- trackers ["top" ] = bbox_axes .y1
1035+ trackers_axes ["top" ] = pad_axes + bbox_axes .height
10491036
10501037 cs_contents = _get_cs_contents (sdata )
10511038
@@ -1210,22 +1197,16 @@ def _draw_colorbar(
12101197 continue
12111198 seen_mappables .add (mappable_id )
12121199 unique_specs .append (spec )
1213- axis_bbox = axis .get_position ()
1214- tight_bbox = axis .get_tightbbox (renderer ).transformed (fig .transFigure .inverted ())
1215- base_offsets = {
1216- "left" : axis_bbox .x0 - tight_bbox .x0 ,
1217- "right" : tight_bbox .x1 - axis_bbox .x1 ,
1218- "bottom" : axis_bbox .y0 - tight_bbox .y0 ,
1219- "top" : tight_bbox .y1 - axis_bbox .y1 ,
1220- }
1221- trackers = {
1222- "left" : axis_bbox .x0 - base_offsets ["left" ],
1223- "right" : axis_bbox .x1 + base_offsets ["right" ],
1224- "bottom" : axis_bbox .y0 - base_offsets ["bottom" ],
1225- "top" : axis_bbox .y1 + base_offsets ["top" ],
1200+ tight_bbox = axis .get_tightbbox (renderer ).transformed (axis .transAxes .inverted ())
1201+ base_offsets_axes = {
1202+ "left" : max (0.0 , - tight_bbox .x0 ),
1203+ "right" : max (0.0 , tight_bbox .x1 - 1 ),
1204+ "bottom" : max (0.0 , - tight_bbox .y0 ),
1205+ "top" : max (0.0 , tight_bbox .y1 - 1 ),
12261206 }
1207+ trackers_axes = {k : base_offsets_axes [k ] for k in base_offsets_axes }
12271208 for spec in unique_specs :
1228- _draw_colorbar (spec , fig , renderer , axis_bbox , base_offsets , trackers )
1209+ _draw_colorbar (spec , fig , renderer , base_offsets_axes , trackers_axes )
12291210
12301211 if fig_params .fig is not None and save is not None :
12311212 save_fig (fig_params .fig , path = save )
0 commit comments