3535if TYPE_CHECKING :
3636 from collections .abc import Mapping
3737
38- import anndata
38+ from anndata import AnnData
3939 from multiscale_spatial_image import MultiscaleSpatialImage
4040 from spatial_image import SpatialImage
4141 from spatialdata ._types import ArrayLike
4242
4343RNG = default_rng (0 )
4444
45+
4546@inject_docs (vx = VisiumHDKeys )
4647def visium_hd (
4748 path : str | Path ,
4849 dataset_id : str | None = None ,
4950 filtered_counts_file : bool = True ,
50- load_segmentations_only : bool = True ,
51+ load_segmentations_only : bool = False ,
5152 load_nucleus_segmentations : bool = False ,
5253 bin_size : int | list [int ] | None = None ,
5354 bins_as_squares : bool = True ,
@@ -74,6 +75,8 @@ def visium_hd(
7475 load_segmentations_only
7576 If `True`, only the segmented cell boundaries and their associated counts will be loaded. All binned data
7677 will be skipped.
78+ load_nucleus_segmentations
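79+ If `True` and the nucleus segmentation files are found, the nucleus boundaries and their associated counts (aggregated from the 2 um bins via the barcode mappings file) are loaded as well.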
7780 bin_size
7881 When specified, load the data of a specific bin size, or a list of bin sizes. By default, it loads all the
7982 available bin sizes.
@@ -115,15 +118,26 @@ def visium_hd(
115118 CELL_GEOJSON_PATH = SEGMENTED_OUTPUTS_PATH / VisiumHDKeys .CELL_SEGMENTATION_GEOJSON_PATH
116119 NUCLEUS_GEOJSON_PATH = SEGMENTED_OUTPUTS_PATH / VisiumHDKeys .NUCLEUS_SEGMENTATION_GEOJSON_PATH
117120 SCALE_FACTORS_PATH = SEGMENTED_OUTPUTS_PATH / VisiumHDKeys .SPATIAL / VisiumHDKeys .SCALEFACTORS_FILE
118- BARCODE_MAPPINGS_PATH = next ((file for file in path .rglob ("*" ) if file .name .endswith (VisiumHDKeys .BARCODE_MAPPINGS_FILE )), None )
119- FILTERED_MATRIX_2U_PATH = path / VisiumHDKeys .BINNED_OUTPUTS / f"{ VisiumHDKeys .BIN_PREFIX } 002um" / VisiumHDKeys .FILTERED_COUNTS_FILE
120- cell_segmentation_files_exist = COUNT_MATRIX_PATH .exists () and CELL_GEOJSON_PATH .exists () and SCALE_FACTORS_PATH .exists ()
121- nucleus_segmentation_files_exist = NUCLEUS_GEOJSON_PATH .exists () and (BARCODE_MAPPINGS_PATH is not None and BARCODE_MAPPINGS_PATH .exists ()) and FILTERED_MATRIX_2U_PATH .exists ()
121+ BARCODE_MAPPINGS_PATH = next (
122+ (file for file in path .rglob ("*" ) if file .name .endswith (VisiumHDKeys .BARCODE_MAPPINGS_FILE )), None
123+ )
124+ FILTERED_MATRIX_2U_PATH = (
125+ path / VisiumHDKeys .BINNED_OUTPUTS / f"{ VisiumHDKeys .BIN_PREFIX } 002um" / VisiumHDKeys .FILTERED_COUNTS_FILE
126+ )
127+ cell_segmentation_files_exist = (
128+ COUNT_MATRIX_PATH .exists () and CELL_GEOJSON_PATH .exists () and SCALE_FACTORS_PATH .exists ()
129+ )
130+ nucleus_segmentation_files_exist = (
131+ NUCLEUS_GEOJSON_PATH .exists ()
132+ and (BARCODE_MAPPINGS_PATH is not None and BARCODE_MAPPINGS_PATH .exists ())
133+ and FILTERED_MATRIX_2U_PATH .exists ()
134+ )
122135
123136 if dataset_id is None :
124137 dataset_id = _infer_dataset_id (path )
125138
126139 filename_prefix = _get_filename_prefix (path , dataset_id )
140+
127141 def load_image (path : Path , suffix : str , scale_factors : list [int ] | None = None ) -> None :
128142 _load_image (
129143 path = path ,
@@ -147,6 +161,7 @@ def load_image(path: Path, suffix: str, scale_factors: list[int] | None = None)
147161
148162 # Load Binned Data (skipped if load_segmentations_only is True)
149163 if not load_segmentations_only :
164+
150165 def _get_bins (path_bins : Path ) -> list [str ]:
151166 return sorted (
152167 [
@@ -158,9 +173,7 @@ def _get_bins(path_bins: Path) -> list[str]:
158173
159174 all_path_bins = [path_bin for path_bin in all_files if VisiumHDKeys .BINNED_OUTPUTS in str (path_bin )]
160175 if len (all_path_bins ) != 0 :
161- path_bins_parts = all_path_bins [
162- - 1
163- ].parts
176+ path_bins_parts = all_path_bins [- 1 ].parts
164177 path_bins = Path (* path_bins_parts [: path_bins_parts .index (VisiumHDKeys .BINNED_OUTPUTS ) + 1 ])
165178 else :
166179 path_bins = path
@@ -181,6 +194,7 @@ def _get_bins(path_bins: Path) -> list[str]:
181194 if bin_size is None or bin_sizes == []:
182195 bin_sizes = all_bin_sizes
183196
197+ # iterate over the given bins and load the data
184198 for bin_size_str in bin_sizes :
185199 path_bin = path_bins / bin_size_str
186200 counts_file = VisiumHDKeys .FILTERED_COUNTS_FILE if filtered_counts_file else VisiumHDKeys .RAW_COUNTS_FILE
@@ -195,6 +209,7 @@ def _get_bins(path_bins: Path) -> list[str]:
195209 with open (path_bin_spatial / VisiumHDKeys .SCALEFACTORS_FILE ) as file :
196210 scalefactors = json .load (file )
197211
212+ # consistency check
198213 found_bin_size = re .search (r"\d{3}" , bin_size_str )
199214 assert found_bin_size is not None
200215 assert float (found_bin_size .group ()) == scalefactors [VisiumHDKeys .SCALEFACTORS_BIN_SIZE_UM ]
@@ -206,6 +221,7 @@ def _get_bins(path_bins: Path) -> list[str]:
206221
207222 tissue_positions_file = path_bin_spatial / VisiumHDKeys .TISSUE_POSITIONS_FILE
208223
224+ # read coordinates and set up adata.obs and adata.obsm
209225 coords = pd .read_parquet (tissue_positions_file )
210226 assert all (
211227 coords .columns .values
@@ -221,7 +237,9 @@ def _get_bins(path_bins: Path) -> list[str]:
221237 coords .set_index (VisiumHDKeys .BARCODE , inplace = True , drop = True )
222238 coords_filtered = coords .loc [adata .obs .index ]
223239 adata .obs = pd .merge (adata .obs , coords_filtered , how = "left" , left_index = True , right_index = True )
240+ # compatibility to legacy squidpy
224241 adata .obsm ["spatial" ] = adata .obs [[VisiumHDKeys .LOCATIONS_X , VisiumHDKeys .LOCATIONS_Y ]].values
242+ # dropping the spatial coordinates (will be stored in shapes)
225243 adata .obs .drop (
226244 columns = [
227245 VisiumHDKeys .LOCATIONS_X ,
@@ -231,6 +249,8 @@ def _get_bins(path_bins: Path) -> list[str]:
231249 )
232250 adata .obs [VisiumHDKeys .INSTANCE_KEY ] = np .arange (len (adata ))
233251
252+ # scaling
253+ transform_original = Identity ()
234254 transform_lowres = Scale (
235255 np .array (
236256 [
@@ -249,11 +269,13 @@ def _get_bins(path_bins: Path) -> list[str]:
249269 ),
250270 axes = ("x" , "y" ),
251271 )
272+ # parse shapes
252273 shapes_name = dataset_id + "_" + bin_size_str
253274 radius = scalefactors [VisiumHDKeys .SCALEFACTORS_SPOT_DIAMETER_FULLRES ] / 2.0
254275
255276 # Here we ensure that only the correct coordinate systems are created for the binned data
256277 transformations = {
278+ dataset_id : transform_original ,
257279 f"{ dataset_id } _downscaled_hires" : transform_hires ,
258280 f"{ dataset_id } _downscaled_lowres" : transform_lowres ,
259281 }
@@ -272,6 +294,7 @@ def _get_bins(path_bins: Path) -> list[str]:
272294 GeoDataFrame (geometry = squares_series ), transformations = transformations
273295 )
274296
297+ # parse table
275298 adata .obs [VisiumHDKeys .REGION_KEY ] = shapes_name
276299 adata .obs [VisiumHDKeys .REGION_KEY ] = adata .obs [VisiumHDKeys .REGION_KEY ].astype ("category" )
277300
@@ -290,45 +313,52 @@ def _get_bins(path_bins: Path) -> list[str]:
290313 cell_adata_hd = sc .read_10x_h5 (COUNT_MATRIX_PATH )
291314 cell_adata_hd .var_names_make_unique ()
292315
293- shapes_transformations_hd = _make_shapes_transformation (scale_factors_path = SCALE_FACTORS_PATH , dataset_id = dataset_id ) # Used for both cell and nucleus segmentations
316+ shapes_transformations_hd = _make_shapes_transformation (
317+ scale_factors_path = SCALE_FACTORS_PATH , dataset_id = dataset_id
318+ ) # Used for both cell and nucleus segmentations
294319 cell_geojson_features_map = _make_geojson_features_map (CELL_GEOJSON_PATH )
295- cell_shapes_gdf = _extract_geometries_from_geojson (cell_adata_hd , geojson_features_map = cell_geojson_features_map )
320+ cell_shapes_gdf = _extract_geometries_from_geojson (
321+ cell_adata_hd , geojson_features_map = cell_geojson_features_map
322+ )
296323
297324 SHAPES_KEY_HD = f"{ dataset_id } _{ VisiumHDKeys .CELL_SEG_KEY_HD } "
298- cell_adata_hd .obs [' cell_id' ] = cell_adata_hd .obs .index
299- cell_adata_hd .obs [' region' ] = SHAPES_KEY_HD
300- cell_adata_hd .obs [' region' ] = cell_adata_hd .obs [' region' ].astype (' category' )
325+ cell_adata_hd .obs [" cell_id" ] = cell_adata_hd .obs .index
326+ cell_adata_hd .obs [" region" ] = SHAPES_KEY_HD
327+ cell_adata_hd .obs [" region" ] = cell_adata_hd .obs [" region" ].astype (" category" )
301328 cell_adata_hd = cell_adata_hd [cell_shapes_gdf .index ].copy ()
302329
303330 shapes [SHAPES_KEY_HD ] = ShapesModel .parse (cell_shapes_gdf , transformations = shapes_transformations_hd )
304331 tables [VisiumHDKeys .CELL_SEG_KEY_HD ] = TableModel .parse (
305- cell_adata_hd ,
306- region = SHAPES_KEY_HD ,
307- region_key = 'region' ,
308- instance_key = 'cell_id'
332+ cell_adata_hd , region = SHAPES_KEY_HD , region_key = "region" , instance_key = "cell_id"
309333 )
310334
311335 # load nucleus segmentations if available
312336 if nucleus_segmentation_files_exist and load_nucleus_segmentations :
313337 print ("Found nucleus segmentation data. Incorporating nucleus_segmentations." )
314338
315- nucleus_adata_hd = _make_filtered_nucleus_adata (filtered_matrix_h5_path = FILTERED_MATRIX_2U_PATH ,barcode_mappings_parquet_path = BARCODE_MAPPINGS_PATH )
316- geojson_features_map = _make_geojson_features_map (NUCLEUS_GEOJSON_PATH )
317- nucleus_shapes_gdf = _extract_geometries_from_geojson (adata = nucleus_adata_hd , geojson_features_map = geojson_features_map )
318-
319- SHAPES_KEY_HD = f"{ dataset_id } _{ VisiumHDKeys .NUCLEUS_SEG_KEY_HD } "
320- nucleus_adata_hd .obs ['cell_id' ] = nucleus_adata_hd .obs .index
321- nucleus_adata_hd .obs ['region' ] = SHAPES_KEY_HD
322- nucleus_adata_hd .obs ['region' ] = nucleus_adata_hd .obs ['region' ].astype ('category' )
323- nucleus_adata_hd = nucleus_adata_hd [nucleus_shapes_gdf .index ].copy ()
324-
325- shapes [SHAPES_KEY_HD ] = ShapesModel .parse (nucleus_shapes_gdf , transformations = shapes_transformations_hd )
326- tables [VisiumHDKeys .NUCLEUS_SEG_KEY_HD ] = TableModel .parse (
327- nucleus_adata_hd ,
328- region = SHAPES_KEY_HD ,
329- region_key = 'region' ,
330- instance_key = 'cell_id'
331- )
339+ if BARCODE_MAPPINGS_PATH is None :
340+ warnings .warn (
341+ "Cannot find the barcode mappings file, skipping nucleus segmentations." , UserWarning , stacklevel = 2
342+ )
343+ else :
344+ nucleus_adata_hd = _make_filtered_nucleus_adata (
345+ filtered_matrix_h5_path = FILTERED_MATRIX_2U_PATH , barcode_mappings_parquet_path = BARCODE_MAPPINGS_PATH
346+ )
347+ geojson_features_map = _make_geojson_features_map (NUCLEUS_GEOJSON_PATH )
348+ nucleus_shapes_gdf = _extract_geometries_from_geojson (
349+ adata = nucleus_adata_hd , geojson_features_map = geojson_features_map
350+ )
351+
352+ SHAPES_KEY_HD = f"{ dataset_id } _{ VisiumHDKeys .NUCLEUS_SEG_KEY_HD } "
353+ nucleus_adata_hd .obs ["cell_id" ] = nucleus_adata_hd .obs .index
354+ nucleus_adata_hd .obs ["region" ] = SHAPES_KEY_HD
355+ nucleus_adata_hd .obs ["region" ] = nucleus_adata_hd .obs ["region" ].astype ("category" )
356+ nucleus_adata_hd = nucleus_adata_hd [nucleus_shapes_gdf .index ].copy ()
357+
358+ shapes [SHAPES_KEY_HD ] = ShapesModel .parse (nucleus_shapes_gdf , transformations = shapes_transformations_hd )
359+ tables [VisiumHDKeys .NUCLEUS_SEG_KEY_HD ] = TableModel .parse (
360+ nucleus_adata_hd , region = SHAPES_KEY_HD , region_key = "region" , instance_key = "cell_id"
361+ )
332362
333363 # Read all images and add transformations for both binning and segmentation
334364 fullres_image_file_paths = []
@@ -472,7 +502,9 @@ def _get_bins(path_bins: Path) -> list[str]:
472502 )
473503 warped = np .round (warped * 255 ).astype (np .uint8 )
474504 if not load_segmentations_only :
475- warped = Image2DModel .parse (warped , dims = ("y" , "x" , "c" ), transformations = {dataset_id : affine }, rgb = True )
505+ warped = Image2DModel .parse (
506+ warped , dims = ("y" , "x" , "c" ), transformations = {dataset_id : affine }, rgb = True
507+ )
476508 images [dataset_id + "_cytassist_image" ] = warped
477509 elif load_all_images :
478510 warnings .warn (
@@ -645,12 +677,13 @@ def _get_transform_matrices(metadata: dict[str, Any], hd_layout: dict[str, Any])
645677
646678 return transform_matrices
647679
680+
648681def _make_filtered_nucleus_adata (
649682 filtered_matrix_h5_path : Path ,
650683 barcode_mappings_parquet_path : Path ,
651- bin_col_name : str = ' square_002um' ,
652- aggregate_col_name : str = ' cell_id'
653- ) -> anndata . AnnData :
684+ bin_col_name : str = " square_002um" ,
685+ aggregate_col_name : str = " cell_id" ,
686+ ) -> AnnData :
654687 """Generate a filtered AnnData object by aggregating 2um binned data based on nucleus segmentation.
655688
656689 Uses a 2um filtered_feature_bc_matrix.h5 file and a barcode_mappings.parquet file containing
@@ -671,7 +704,7 @@ def _make_filtered_nucleus_adata(
671704
672705 Returns:
673706 --------
674- anndata. AnnData
707+ AnnData
675708 An AnnData object where the observations correspond to filtered cell IDs
676709 and the variables correspond to the original features from the input data.
677710 """
@@ -680,19 +713,19 @@ def _make_filtered_nucleus_adata(
680713 barcode_mappings = pq .read_table (barcode_mappings_parquet_path )
681714
682715 # Filter to only include valid cell IDs that are in both nucleus and cell
683- barcode_mappings = barcode_mappings .filter ((barcode_mappings ['cell_id' ].is_valid ()) and barcode_mappings ["in_nucleus" ])
716+ barcode_mappings = barcode_mappings .filter (
717+ (barcode_mappings ["cell_id" ].is_valid ()) and barcode_mappings ["in_nucleus" ]
718+ )
684719
685720 # Filter the 2um adata to only include squares present in the barcode mappings
686721 valid_squares = barcode_mappings [bin_col_name ].unique ()
687722 squares_to_keep = np .intersect1d (adata_2um .obs_names , valid_squares )
688723 adata_filtered = adata_2um [squares_to_keep , :].copy ()
689724
690725 # Map each square to its corresponding cell ID
691- square_to_cell_map = dict (zip (
692- barcode_mappings [bin_col_name ].to_pylist (),
693- barcode_mappings [aggregate_col_name ].to_pylist (), strict = False
694-
695- ))
726+ square_to_cell_map = dict (
727+ zip (barcode_mappings [bin_col_name ].to_pylist (), barcode_mappings [aggregate_col_name ].to_pylist (), strict = False )
728+ )
696729 ordered_cell_ids = [square_to_cell_map [square ] for square in adata_filtered .obs_names ]
697730 unique_cells = list (dict .fromkeys (ordered_cell_ids ).keys ())
698731 cell_to_idx = {cell : i for i , cell in enumerate (unique_cells )}
@@ -702,10 +735,7 @@ def _make_filtered_nucleus_adata(
702735 row_indices = np .arange (len (ordered_cell_ids ))
703736 data = np .ones_like (row_indices )
704737
705- aggregation_matrix = csc_matrix (
706- (data , (row_indices , col_indices )),
707- shape = (adata_filtered .n_obs , len (unique_cells ))
708- )
738+ aggregation_matrix = csc_matrix ((data , (row_indices , col_indices )), shape = (adata_filtered .n_obs , len (unique_cells )))
709739
710740 # Make the final AnnData object where cell IDs are filtered
711741 # to the data under the segmented nuclei
@@ -716,12 +746,13 @@ def _make_filtered_nucleus_adata(
716746
717747 return adata_nucleus
718748
719- def _extract_geometries_from_geojson (adata : anndata .AnnData , geojson_features_map : dict [str , Any ]) -> GeoDataFrame :
749+
750+ def _extract_geometries_from_geojson (adata : AnnData , geojson_features_map : dict [str , Any ]) -> GeoDataFrame :
720751 """Extract geometries and create a GeoDataFrame from a GeoJSON features map.
721752
722753 Parameters
723754 ----------
724- cell_adata : anndata. AnnData
755+ adata : AnnData
725756 AnnData object containing cell data.
726757 geojson_features_map : dict[str, Any]
727758 Dictionary mapping cell IDs to GeoJSON features.
@@ -737,7 +768,7 @@ def _extract_geometries_from_geojson(adata: anndata.AnnData, geojson_features_ma
737768 for obs_index_str in adata .obs .index :
738769 feature = geojson_features_map .get (obs_index_str )
739770 if feature :
740- polygon_coords = np .array (feature [' geometry' ][ ' coordinates' ][0 ])
771+ polygon_coords = np .array (feature [" geometry" ][ " coordinates" ][0 ])
741772 geometries .append (Polygon (polygon_coords ))
742773 cell_ids_ordered .append (obs_index_str )
743774 else :
@@ -748,10 +779,8 @@ def _extract_geometries_from_geojson(adata: anndata.AnnData, geojson_features_ma
748779 geometries = [geometries [i ] for i in valid_indices ]
749780 cell_ids_ordered = [cell_ids_ordered [i ] for i in valid_indices ]
750781
751- return GeoDataFrame ({
752- 'cell_id' : cell_ids_ordered ,
753- 'geometry' : geometries
754- }, index = cell_ids_ordered )
782+ return GeoDataFrame ({"cell_id" : cell_ids_ordered , "geometry" : geometries }, index = cell_ids_ordered )
783+
755784
756785def _make_shapes_transformation (scale_factors_path : Path , dataset_id : str ) -> dict [str , Scale ]:
757786 """Load scale factors for lowres and hires images and create transformations.
@@ -770,18 +799,20 @@ def _make_shapes_transformation(scale_factors_path: Path, dataset_id: str) -> di
770799 """
771800 with open (scale_factors_path ) as f :
772801 scale_data_hd = json .load (f )
773- lowres_scale_factor_hd = scale_data_hd [' tissue_lowres_scalef' ]
774- hires_scale_factor_hd = scale_data_hd [' tissue_hires_scalef' ]
802+ lowres_scale_factor_hd = scale_data_hd [" tissue_lowres_scalef" ]
803+ hires_scale_factor_hd = scale_data_hd [" tissue_hires_scalef" ]
775804
776805 return {
777- f"{ dataset_id } _downscaled_lowres" : Scale (np .array ([lowres_scale_factor_hd , lowres_scale_factor_hd ]), axes = ("x" , "y" )),
778- f"{ dataset_id } _downscaled_hires" : Scale (np .array ([hires_scale_factor_hd , hires_scale_factor_hd ]), axes = ("x" , "y" ))
806+ f"{ dataset_id } _downscaled_lowres" : Scale (
807+ np .array ([lowres_scale_factor_hd , lowres_scale_factor_hd ]), axes = ("x" , "y" )
808+ ),
809+ f"{ dataset_id } _downscaled_hires" : Scale (
810+ np .array ([hires_scale_factor_hd , hires_scale_factor_hd ]), axes = ("x" , "y" )
811+ ),
779812 }
780813
814+
def _make_geojson_features_map(geojson_path: Path) -> dict[str, Any]:
    """Index the features of a segmentation GeoJSON file by formatted cell ID.

    Parameters
    ----------
    geojson_path
        Path to a GeoJSON file with a top-level ``"features"`` list; every
        feature must provide an integer ``properties["cell_id"]``.

    Returns
    -------
    Dictionary mapping ``"cellid_<cell_id zero-padded to 9 digits>-1"`` to the
    corresponding feature, in the order in which the features appear in the
    file. If two features share a cell ID, the later one wins.

    Raises
    ------
    KeyError
        If ``"features"``, ``"properties"`` or ``"cell_id"`` is missing.
    """
    # GeoJSON is UTF-8 by specification; do not rely on the platform default encoding.
    with open(geojson_path, encoding="utf-8") as f:
        geojson_data = json.load(f)
    return {f"cellid_{feature['properties']['cell_id']:09d}-1": feature for feature in geojson_data["features"]}
0 commit comments