12
12
13
13
import matplotlib
14
14
import matplotlib .patches as mpatches
15
- import matplotlib .patches as mplp
16
15
import matplotlib .path as mpath
17
16
import matplotlib .pyplot as plt
18
- import multiscale_spatial_image as msi
19
17
import numpy as np
20
18
import pandas as pd
21
19
import shapely
49
47
from scanpy .plotting ._tools .scatterplots import _add_categorical_legend
50
48
from scanpy .plotting ._utils import add_colors_for_categorical_sample_annotation
51
49
from scanpy .plotting .palettes import default_20 , default_28 , default_102
52
- from shapely .geometry import LineString , Polygon
53
50
from skimage .color import label2rgb
54
51
from skimage .morphology import erosion , square
55
52
from skimage .segmentation import find_boundaries
@@ -283,6 +280,30 @@ def _sanitise_na_color(na_color: ColorLike | None) -> tuple[str, bool]:
283
280
raise ValueError (f"Invalid na_color value: { na_color } " )
284
281
285
282
283
+ def _get_centroid_of_pathpatch (pathpatch : mpatches .PathPatch ) -> tuple [float , float ]:
284
+ # Extract the vertices from the PathPatch
285
+ path = pathpatch .get_path ()
286
+ vertices = path .vertices
287
+ x = vertices [:, 0 ]
288
+ y = vertices [:, 1 ]
289
+
290
+ area = 0.5 * np .sum (x [:- 1 ] * y [1 :] - x [1 :] * y [:- 1 ])
291
+
292
+ # Calculate the centroid coordinates
293
+ centroid_x = np .sum ((x [:- 1 ] + x [1 :]) * (x [:- 1 ] * y [1 :] - x [1 :] * y [:- 1 ])) / (6 * area )
294
+ centroid_y = np .sum ((y [:- 1 ] + y [1 :]) * (x [:- 1 ] * y [1 :] - x [1 :] * y [:- 1 ])) / (6 * area )
295
+
296
+ return centroid_x , centroid_y
297
+
298
+
299
+ def _scale_pathpatch_around_centroid (pathpatch : mpatches .PathPatch , scale_factor : float ) -> None :
300
+
301
+ centroid = _get_centroid_of_pathpatch (pathpatch )
302
+ vertices = pathpatch .get_path ().vertices
303
+ scaled_vertices = np .array ([centroid + (vertex - centroid ) * scale_factor for vertex in vertices ])
304
+ pathpatch .get_path ().vertices = scaled_vertices
305
+
306
+
286
307
def _get_collection_shape (
287
308
shapes : list [GeoDataFrame ],
288
309
c : Any ,
@@ -352,63 +373,64 @@ def _get_collection_shape(
352
373
outline_c = outline_c * fill_c .shape [0 ]
353
374
354
375
shapes_df = pd .DataFrame (shapes , copy = True )
355
-
356
- # remove empty points/polygons
357
376
shapes_df = shapes_df [shapes_df ["geometry" ].apply (lambda geom : not geom .is_empty )]
358
-
359
- # reset index of shapes_df for case of spatial query
360
377
shapes_df = shapes_df .reset_index (drop = True )
361
378
362
- rows = []
363
-
364
- def assign_fill_and_outline_to_row (
365
- shapes : list [GeoDataFrame ], fill_c : list [Any ], outline_c : list [Any ], row : pd .Series , idx : int
379
+ def _assign_fill_and_outline_to_row (
380
+ fill_c : list [Any ], outline_c : list [Any ], row : dict [str , Any ], idx : int , is_multiple_shapes : bool
366
381
) -> None :
367
382
try :
368
- if len ( shapes ) > 1 and len (fill_c ) == 1 :
369
- row ["fill_c" ] = fill_c
370
- row ["outline_c" ] = outline_c
383
+ if is_multiple_shapes and len (fill_c ) == 1 :
384
+ row ["fill_c" ] = fill_c [ 0 ]
385
+ row ["outline_c" ] = outline_c [ 0 ]
371
386
else :
372
387
row ["fill_c" ] = fill_c [idx ]
373
388
row ["outline_c" ] = outline_c [idx ]
374
389
except IndexError as e :
375
- raise IndexError ("Could not assign fill and outline colors due to a mismatch in row-numbers." ) from e
376
-
377
- # Match colors to the geometry, potentially expanding the row in case of
378
- # multipolygons
379
- for idx , row in shapes_df .iterrows ():
380
- geom = row ["geometry" ]
381
- if geom .geom_type == "Polygon" :
382
- row = row .to_dict ()
383
- coords = np .array (geom .exterior .coords )
384
- centroid = np .mean (coords , axis = 0 )
385
- scaled_coords = [(centroid + (np .array (coord ) - centroid ) * s ).tolist () for coord in geom .exterior .coords ]
386
- row ["geometry" ] = mplp .Polygon (scaled_coords , closed = True )
387
- assign_fill_and_outline_to_row (shapes , fill_c , outline_c , row , idx )
388
- rows .append (row )
389
-
390
- elif geom .geom_type == "MultiPolygon" :
391
- # mp = _make_patch_from_multipolygon(geom)
392
- for polygon in geom .geoms :
393
- mp_copy = row .to_dict ()
394
- coords = np .array (polygon .exterior .coords )
395
- centroid = np .mean (coords , axis = 0 )
396
- scaled_coords = [(centroid + (coord - centroid ) * s ).tolist () for coord in coords ]
397
- mp_copy ["geometry" ] = mplp .Polygon (scaled_coords , closed = True )
398
- assign_fill_and_outline_to_row (shapes , fill_c , outline_c , mp_copy , idx )
399
- rows .append (mp_copy )
400
-
401
- elif geom .geom_type == "Point" :
402
- row = row .to_dict ()
403
- scaled_radius = row ["radius" ] * s
404
- row ["geometry" ] = mplp .Circle (
405
- (geom .x , geom .y ), radius = scaled_radius
406
- ) # Circle is always scaled from its center
407
- assign_fill_and_outline_to_row (shapes , fill_c , outline_c , row , idx )
408
- rows .append (row )
409
-
410
- patches = pd .DataFrame (rows )
411
-
390
+ raise IndexError ("Could not assign fill and outline colors due to a mismatch in row numbers." ) from e
391
+
392
+ def _process_polygon (row : pd .Series , s : float ) -> dict [str , Any ]:
393
+ coords = np .array (row ["geometry" ].exterior .coords )
394
+ centroid = np .mean (coords , axis = 0 )
395
+ scaled_coords = (centroid + (coords - centroid ) * s ).tolist ()
396
+ return {** row .to_dict (), "geometry" : mpatches .Polygon (scaled_coords , closed = True )}
397
+
398
+ def _process_multipolygon (row : pd .Series , s : float ) -> list [dict [str , Any ]]:
399
+ mp = _make_patch_from_multipolygon (row ["geometry" ])
400
+ row_dict = row .to_dict ()
401
+ for m in mp :
402
+ _scale_pathpatch_around_centroid (m , s )
403
+
404
+ return [{** row_dict , "geometry" : m } for m in mp ]
405
+
406
+ def _process_point (row : pd .Series , s : float ) -> dict [str , Any ]:
407
+ return {
408
+ ** row .to_dict (),
409
+ "geometry" : mpatches .Circle ((row ["geometry" ].x , row ["geometry" ].y ), radius = row ["radius" ] * s ),
410
+ }
411
+
412
+ def _create_patches (shapes_df : GeoDataFrame , fill_c : list [Any ], outline_c : list [Any ], s : float ) -> pd .DataFrame :
413
+ rows = []
414
+ is_multiple_shapes = len (shapes_df ) > 1
415
+
416
+ for idx , row in shapes_df .iterrows ():
417
+ geom_type = row ["geometry" ].geom_type
418
+ processed_rows = []
419
+
420
+ if geom_type == "Polygon" :
421
+ processed_rows .append (_process_polygon (row , s ))
422
+ elif geom_type == "MultiPolygon" :
423
+ processed_rows .extend (_process_multipolygon (row , s ))
424
+ elif geom_type == "Point" :
425
+ processed_rows .append (_process_point (row , s ))
426
+
427
+ for processed_row in processed_rows :
428
+ _assign_fill_and_outline_to_row (fill_c , outline_c , processed_row , idx , is_multiple_shapes )
429
+ rows .append (processed_row )
430
+
431
+ return pd .DataFrame (rows )
432
+
433
+ patches = _create_patches (shapes_df , fill_c , outline_c , s )
412
434
return PatchCollection (
413
435
patches ["geometry" ].values .tolist (),
414
436
snap = False ,
@@ -788,7 +810,7 @@ def _map_color_seg(
788
810
cell_id = np .array (cell_id )
789
811
if color_vector is not None and isinstance (color_vector .dtype , pd .CategoricalDtype ):
790
812
# users wants to plot a categorical column
791
- if isinstance ( na_color , tuple ) and len ( na_color ) == 4 and np .any (color_source_vector .isna ()):
813
+ if np .any (color_source_vector .isna ()):
792
814
cell_id [color_source_vector .isna ()] = 0
793
815
val_im : ArrayLike = map_array (seg , cell_id , color_vector .codes + 1 )
794
816
cols = colors .to_rgba_array (color_vector .categories )
@@ -873,9 +895,9 @@ def _modify_categorical_color_mapping(
873
895
modified_mapping = {key : mapping [key ] for key in mapping if key in groups or key == "NaN" }
874
896
elif len (palette ) == len (groups ) and isinstance (groups , list ) and isinstance (palette , list ):
875
897
modified_mapping = dict (zip (groups , palette ))
876
-
877
898
else :
878
899
raise ValueError (f"Expected palette to be of length `{ len (groups )} `, found `{ len (palette )} `." )
900
+
879
901
return modified_mapping
880
902
881
903
@@ -891,7 +913,7 @@ def _get_default_categorial_color_mapping(
891
913
palette = default_102
892
914
else :
893
915
palette = ["grey" for _ in range (len_cat )]
894
- logger .info ("input has more than 103 categories. Uniform " " 'grey' color will be used for all categories." )
916
+ logger .info ("input has more than 103 categories. Uniform 'grey' color will be used for all categories." )
895
917
896
918
return {cat : to_hex (to_rgba (col )[:3 ]) for cat , col in zip (color_source_vector .categories , palette [:len_cat ])}
897
919
@@ -922,54 +944,6 @@ def _get_categorical_color_mapping(
922
944
return _modify_categorical_color_mapping (base_mapping , groups , palette )
923
945
924
946
925
- def _get_palette (
926
- categories : Sequence [Any ],
927
- adata : AnnData | None = None ,
928
- cluster_key : None | str = None ,
929
- palette : ListedColormap | str | list [str ] | None = None ,
930
- alpha : float = 1.0 ,
931
- ) -> Mapping [str , str ] | None :
932
- palette = None if isinstance (palette , list ) and palette [0 ] is None else palette
933
- if adata is not None and palette is None :
934
- try :
935
- palette = adata .uns [f"{ cluster_key } _colors" ] # type: ignore[arg-type]
936
- if len (palette ) != len (categories ):
937
- raise ValueError (
938
- f"Expected palette to be of length `{ len (categories )} `, found `{ len (palette )} `. "
939
- + f"Removing the colors in `adata.uns` with `adata.uns.pop('{ cluster_key } _colors')` may help."
940
- )
941
- return {cat : to_hex (to_rgba (col )[:3 ]) for cat , col in zip (categories , palette )}
942
- except KeyError as e :
943
- logger .warning (e )
944
- return None
945
-
946
- len_cat = len (categories )
947
-
948
- if palette is None :
949
- if len_cat <= 20 :
950
- palette = default_20
951
- elif len_cat <= 28 :
952
- palette = default_28
953
- elif len_cat <= len (default_102 ): # 103 colors
954
- palette = default_102
955
- else :
956
- palette = ["grey" for _ in range (len_cat )]
957
- logger .info ("input has more than 103 categories. Uniform " "'grey' color will be used for all categories." )
958
- return {cat : to_hex (to_rgba (col )[:3 ]) for cat , col in zip (categories , palette [:len_cat ])}
959
-
960
- if isinstance (palette , str ):
961
- cmap = ListedColormap ([palette ])
962
- elif isinstance (palette , list ):
963
- cmap = ListedColormap (palette )
964
- elif isinstance (palette , ListedColormap ):
965
- cmap = palette
966
- else :
967
- raise TypeError (f"Palette is { type (palette )} but should be string or list." )
968
- palette = [to_hex (np .round (x , 5 )) for x in cmap (np .linspace (0 , 1 , len_cat ), alpha = alpha )]
969
-
970
- return dict (zip (categories , palette ))
971
-
972
-
973
947
def _maybe_set_colors (
974
948
source : AnnData , target : AnnData , key : str , palette : str | ListedColormap | Cycler | Sequence [Any ] | None = None
975
949
) -> None :
@@ -1137,34 +1111,6 @@ def save_fig(fig: Figure, path: str | Path, make_dir: bool = True, ext: str = "p
1137
1111
fig .savefig (path , ** kwargs )
1138
1112
1139
1113
1140
- def _get_cs_element_map (
1141
- element : str | Sequence [str ] | None ,
1142
- element_map : Mapping [str , Any ],
1143
- ) -> Mapping [str , str ]:
1144
- """Get the mapping between the coordinate system and the class."""
1145
- # from spatialdata.models import Image2DModel, Image3DModel, Labels2DModel, Labels3DModel, PointsModel, ShapesModel
1146
- element = list (element_map .keys ())[0 ] if element is None else element
1147
- element = [element ] if isinstance (element , str ) else element
1148
- d = {}
1149
- for e in element :
1150
- cs = list (element_map [e ].attrs ["transform" ].keys ())[0 ]
1151
- d [cs ] = e
1152
- # model = get_model(element_map["blobs_labels"])
1153
- # if model in [Image2DModel, Image3DModel, Labels2DModel, Labels3DModel]
1154
- return d
1155
-
1156
-
1157
- def _multiscale_to_image (sdata : sd .SpatialData ) -> sd .SpatialData :
1158
- if sdata .images is None :
1159
- raise ValueError ("No images found in the SpatialData object." )
1160
-
1161
- for k , v in sdata .images .items ():
1162
- if isinstance (v , msi .multiscale_spatial_image .DataTree ):
1163
- sdata .images [k ] = Image2DModel .parse (v ["scale0" ].ds .to_array ().squeeze (axis = 0 ))
1164
-
1165
- return sdata
1166
-
1167
-
1168
1114
def _get_linear_colormap (colors : list [str ], background : str ) -> list [LinearSegmentedColormap ]:
1169
1115
return [LinearSegmentedColormap .from_list (c , [background , c ], N = 256 ) for c in colors ]
1170
1116
@@ -1176,62 +1122,6 @@ def _get_listed_colormap(color_dict: dict[str, str]) -> ListedColormap:
1176
1122
return ListedColormap (["black" ] + colors , N = len (colors ) + 1 )
1177
1123
1178
1124
1179
- def _translate_image (
1180
- image : DataArray ,
1181
- translation : sd .transformations .transformations .Translation ,
1182
- ) -> DataArray :
1183
- shifts : dict [str , int ] = {axis : int (translation .translation [idx ]) for idx , axis in enumerate (translation .axes )}
1184
- img = image .values .copy ()
1185
- # for yx images (important for rasterized MultiscaleImages as labels)
1186
- expanded_dims = False
1187
- if len (img .shape ) == 2 :
1188
- img = np .expand_dims (img , axis = 0 )
1189
- expanded_dims = True
1190
-
1191
- shifted_channels = []
1192
-
1193
- # split channels, shift axes individually, them recombine
1194
- if len (img .shape ) == 3 :
1195
- for c in range (img .shape [0 ]):
1196
- channel = img [c , :, :]
1197
-
1198
- # iterates over [x, y]
1199
- for axis , shift in shifts .items ():
1200
- pad_x , pad_y = (0 , 0 ), (0 , 0 )
1201
- if axis == "x" and shift > 0 :
1202
- pad_x = (abs (shift ), 0 )
1203
- elif axis == "x" and shift < 0 :
1204
- pad_x = (0 , abs (shift ))
1205
-
1206
- if axis == "y" and shift > 0 :
1207
- pad_y = (abs (shift ), 0 )
1208
- elif axis == "y" and shift < 0 :
1209
- pad_y = (0 , abs (shift ))
1210
-
1211
- channel = np .pad (channel , (pad_y , pad_x ), mode = "constant" )
1212
-
1213
- shifted_channels .append (channel )
1214
-
1215
- if expanded_dims :
1216
- return Labels2DModel .parse (
1217
- np .array (shifted_channels [0 ]),
1218
- dims = ["y" , "x" ],
1219
- transformations = image .attrs ["transform" ],
1220
- )
1221
- return Image2DModel .parse (
1222
- np .array (shifted_channels ),
1223
- dims = ["c" , "y" , "x" ],
1224
- transformations = image .attrs ["transform" ],
1225
- )
1226
-
1227
-
1228
- def _convert_polygon_to_linestrings (polygon : Polygon ) -> list [LineString ]:
1229
- b = polygon .boundary .coords
1230
- linestrings = [LineString (b [k : k + 2 ]) for k in range (len (b ) - 1 )]
1231
-
1232
- return [list (ls .coords ) for ls in linestrings ]
1233
-
1234
-
1235
1125
def _split_multipolygon_into_outer_and_inner (mp : shapely .MultiPolygon ): # type: ignore
1236
1126
# https://stackoverflow.com/a/21922058
1237
1127
0 commit comments