Commit e5c2d16: initial review pass

Merge commit, 2 parents: 30e4076 + cc9dc68

1 file changed: +60, -52 lines


src/spatialdata_io/readers/visium_hd.py (60 additions, 52 deletions)
```diff
@@ -35,7 +35,7 @@
 if TYPE_CHECKING:
     from collections.abc import Mapping
 
-    import anndata
+    from anndata import AnnData
     from multiscale_spatial_image import MultiscaleSpatialImage
     from spatial_image import SpatialImage
     from spatialdata._types import ArrayLike
```
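Moving `anndata` under `TYPE_CHECKING` keeps the import out of the runtime path while still allowing `AnnData` annotations. A minimal sketch of the idiom (the function below is invented for illustration):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # seen by static type checkers only, never imported at runtime
    from anndata import AnnData


def first_n_obs(adata: AnnData, n: int = 5) -> AnnData:
    """Return a view of the first n observations (annotation-only usage)."""
    return adata[:n]
```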
```diff
@@ -48,7 +48,7 @@ def visium_hd(
     path: str | Path,
     dataset_id: str | None = None,
     filtered_counts_file: bool = True,
-    load_segmentations_only: bool = True,
+    load_segmentations_only: bool = False,
     load_nucleus_segmentations: bool = False,
     bin_size: int | list[int] | None = None,
     bins_as_squares: bool = True,
```
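With the default flipped back to `False`, a plain call loads the binned data again and segmentations-only loading becomes opt-in. A usage sketch (the path is a placeholder):

```python
from spatialdata_io import visium_hd

# default behavior: load the binned data (plus segmentations when present)
sdata = visium_hd("path/to/visium_hd_output")

# opt in explicitly to skip the binned data
sdata_seg = visium_hd("path/to/visium_hd_output", load_segmentations_only=True)
```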
```diff
@@ -75,6 +75,8 @@ def visium_hd(
     load_segmentations_only
         If `True`, only the segmented cell boundaries and their associated counts will be loaded. All binned data
         will be skipped.
+    load_nucleus_segmentations
+        ...
     bin_size
         When specified, load the data of a specific bin size, or a list of bin sizes. By default, it loads all the
         available bin sizes.
```
```diff
@@ -158,6 +160,9 @@ def load_image(path: Path, suffix: str, scale_factors: list[int] | None = None)
     )
 
     # TODO: load scalefactor independenly of the parameter load_segmentations_only
+    with open(SCALE_FACTORS_PATH) as file:
+        scalefactors = json.load(file)
+
     transform_lowres = Scale(
         np.array(
             [
```
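This block loads the dataset-level scale factors up front, so they are available even when the binned data is skipped. A sketch of what gets read; the path and key names follow the usual Visium HD `scalefactors_json.json` layout and are assumptions here:

```python
import json
from pathlib import Path

# hypothetical location; the reader resolves SCALE_FACTORS_PATH from the dataset folder
scale_factors_path = Path("outs/spatial/scalefactors_json.json")
with open(scale_factors_path) as file:
    scalefactors = json.load(file)

# typical keys (assumed): "tissue_hires_scalef", "tissue_lowres_scalef",
# "spot_diameter_fullres", "microns_per_pixel"
print(sorted(scalefactors))
```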
```diff
@@ -226,17 +231,19 @@ def _get_bins(path_bins: Path) -> list[str]:
 
         path_bin_spatial = path_bin / VisiumHDKeys.SPATIAL
 
+        # the scale factors of binned data are consistent to the global ones
+        # (already loaded in "scalefactors", but contain extra keys)
         with open(path_bin_spatial / VisiumHDKeys.SCALEFACTORS_FILE) as file:
-            scalefactors = json.load(file)
+            scalefactors_bins = json.load(file)
 
         # consistency check
         found_bin_size = re.search(r"\d{3}", bin_size_str)
         assert found_bin_size is not None
-        assert float(found_bin_size.group()) == scalefactors[VisiumHDKeys.SCALEFACTORS_BIN_SIZE_UM]
+        assert float(found_bin_size.group()) == scalefactors_bins[VisiumHDKeys.SCALEFACTORS_BIN_SIZE_UM]
         assert np.isclose(
-            scalefactors[VisiumHDKeys.SCALEFACTORS_BIN_SIZE_UM]
-            / scalefactors[VisiumHDKeys.SCALEFACTORS_SPOT_DIAMETER_FULLRES],
-            scalefactors[VisiumHDKeys.SCALEFACTORS_MICRONS_PER_PIXEL],
+            scalefactors_bins[VisiumHDKeys.SCALEFACTORS_BIN_SIZE_UM]
+            / scalefactors_bins[VisiumHDKeys.SCALEFACTORS_SPOT_DIAMETER_FULLRES],
+            scalefactors_bins[VisiumHDKeys.SCALEFACTORS_MICRONS_PER_PIXEL],
         )
 
         tissue_positions_file = path_bin_spatial / VisiumHDKeys.TISSUE_POSITIONS_FILE
```
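Renaming the local variable to `scalefactors_bins` avoids shadowing the global `scalefactors` loaded earlier. The asserted invariant is that the bin edge length in microns, divided by the same length in full-resolution pixels, equals the image resolution in microns per pixel. A worked example with invented numbers:

```python
import numpy as np

# illustrative values only, not taken from a real dataset
scalefactors_bins = {
    "bin_size_um": 2.0,              # bin edge length in microns
    "spot_diameter_fullres": 7.3,    # same edge length in full-res pixels
    "microns_per_pixel": 0.2739726,  # full-res image resolution
}

# 2.0 / 7.3 ~= 0.27397, so the three stored values agree
assert np.isclose(
    scalefactors_bins["bin_size_um"] / scalefactors_bins["spot_diameter_fullres"],
    scalefactors_bins["microns_per_pixel"],
)
```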
```diff
@@ -273,7 +280,7 @@ def _get_bins(path_bins: Path) -> list[str]:
         transform_original = Identity()
         # parse shapes
         shapes_name = dataset_id + "_" + bin_size_str
-        radius = scalefactors[VisiumHDKeys.SCALEFACTORS_SPOT_DIAMETER_FULLRES] / 2.0
+        radius = scalefactors_bins[VisiumHDKeys.SCALEFACTORS_SPOT_DIAMETER_FULLRES] / 2.0
         transformations = {
             dataset_id: transform_original,
             f"{dataset_id}_downscaled_hires": transform_hires,
```
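The radius is half the full-resolution spot diameter and is used when bins are parsed as circles rather than squares. A minimal sketch of how circle shapes are modeled in spatialdata (coordinates and radius invented):

```python
import geopandas as gpd
from shapely.geometry import Point
from spatialdata.models import ShapesModel
from spatialdata.transformations import Identity

# two invented bin centres; radius = spot_diameter_fullres / 2.0
gdf = gpd.GeoDataFrame(geometry=[Point(10.0, 20.0), Point(30.0, 40.0)])
gdf["radius"] = 7.3 / 2.0
circles = ShapesModel.parse(gdf, transformations={"global": Identity()})
```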
```diff
@@ -307,25 +314,6 @@ def _get_bins(path_bins: Path) -> list[str]:
         if var_names_make_unique:
             tables[bin_size_str].var_names_make_unique()
 
-    if annotate_table_by_labels:
-        for bin_size_str in bin_sizes:
-            shapes_name = dataset_id + "_" + bin_size_str
-            # add labels layer (rasterized bins).
-            labels_name = f"{dataset_id}_{bin_size_str}_labels"
-            labels_element = rasterize_bins(
-                sdata,
-                bins=shapes_name,
-                table_name=bin_size_str,
-                row_key=VisiumHDKeys.ARRAY_ROW,
-                col_key=VisiumHDKeys.ARRAY_COL,
-                value_key=None,
-                return_region_as_labels=True,
-            )
-            sdata[labels_name] = labels_element
-            rasterize_bins_link_table_to_labels(
-                sdata=sdata, table_name=bin_size_str, rasterized_labels_name=labels_name
-            )
-
     # Integrate the segmentation data (skipped if segmentation files are not found)
     if cell_segmentation_files_exist:
         print("Found segmentation data. Incorporating cell_segmentations.")
```
```diff
@@ -355,24 +343,29 @@ def _get_bins(path_bins: Path) -> list[str]:
     if nucleus_segmentation_files_exist and load_nucleus_segmentations:
         print("Found nucleus segmentation data. Incorporating nucleus_segmentations.")
 
-        nucleus_adata_hd = _make_filtered_nucleus_adata(
-            filtered_matrix_h5_path=FILTERED_MATRIX_2U_PATH, barcode_mappings_parquet_path=BARCODE_MAPPINGS_PATH
-        )
-        geojson_features_map = _make_geojson_features_map(NUCLEUS_GEOJSON_PATH)
-        nucleus_shapes_gdf = _extract_geometries_from_geojson(
-            adata=nucleus_adata_hd, geojson_features_map=geojson_features_map
-        )
+        if BARCODE_MAPPINGS_PATH is None:
+            warnings.warn(
+                "Cannot find the barcode mappings file, skipping nucleus segmentations.", UserWarning, stacklevel=2
+            )
+        else:
+            nucleus_adata_hd = _make_filtered_nucleus_adata(
+                filtered_matrix_h5_path=FILTERED_MATRIX_2U_PATH, barcode_mappings_parquet_path=BARCODE_MAPPINGS_PATH
+            )
+            geojson_features_map = _make_geojson_features_map(NUCLEUS_GEOJSON_PATH)
+            nucleus_shapes_gdf = _extract_geometries_from_geojson(
+                adata=nucleus_adata_hd, geojson_features_map=geojson_features_map
+            )
 
-        SHAPES_KEY_HD = f"{dataset_id}_{VisiumHDKeys.NUCLEUS_SEG_KEY_HD}"
-        nucleus_adata_hd.obs["cell_id"] = nucleus_adata_hd.obs.index
-        nucleus_adata_hd.obs["region"] = SHAPES_KEY_HD
-        nucleus_adata_hd.obs["region"] = nucleus_adata_hd.obs["region"].astype("category")
-        nucleus_adata_hd = nucleus_adata_hd[nucleus_shapes_gdf.index].copy()
+            SHAPES_KEY_HD = f"{dataset_id}_{VisiumHDKeys.NUCLEUS_SEG_KEY_HD}"
+            nucleus_adata_hd.obs["cell_id"] = nucleus_adata_hd.obs.index
+            nucleus_adata_hd.obs["region"] = SHAPES_KEY_HD
+            nucleus_adata_hd.obs["region"] = nucleus_adata_hd.obs["region"].astype("category")
+            nucleus_adata_hd = nucleus_adata_hd[nucleus_shapes_gdf.index].copy()
 
-        shapes[SHAPES_KEY_HD] = ShapesModel.parse(nucleus_shapes_gdf, transformations=shapes_transformations_hd)
-        tables[VisiumHDKeys.NUCLEUS_SEG_KEY_HD] = TableModel.parse(
-            nucleus_adata_hd, region=SHAPES_KEY_HD, region_key="region", instance_key="cell_id"
-        )
+            shapes[SHAPES_KEY_HD] = ShapesModel.parse(nucleus_shapes_gdf, transformations=shapes_transformations_hd)
+            tables[VisiumHDKeys.NUCLEUS_SEG_KEY_HD] = TableModel.parse(
+                nucleus_adata_hd, region=SHAPES_KEY_HD, region_key="region", instance_key="cell_id"
+            )
 
     # Read all images and add transformations for both binning and segmentation
     fullres_image_file_paths = []
```
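Beyond the warn-and-skip guard for a missing barcode mappings file, this hunk wires the nucleus table to its shapes element through region/instance keys. A toy sketch of that `TableModel.parse` wiring, with made-up data and element names:

```python
import numpy as np
import pandas as pd
from anndata import AnnData
from spatialdata.models import TableModel

# invented table annotating a shapes element called "sample_nucleus_boundaries"
obs = pd.DataFrame(index=["cell_0", "cell_1"])
obs["cell_id"] = obs.index
obs["region"] = pd.Categorical(["sample_nucleus_boundaries"] * 2)
adata = AnnData(X=np.zeros((2, 3)), obs=obs)

table = TableModel.parse(
    adata,
    region="sample_nucleus_boundaries",
    region_key="region",
    instance_key="cell_id",
)
```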
```diff
@@ -514,9 +507,7 @@ def _get_bins(path_bins: Path) -> list[str]:
             numpy_data, ProjectiveTransform(projective_shift).inverse, output_shape=transformed_shape, order=1
         )
         warped = np.round(warped * 255).astype(np.uint8)
-        warped = Image2DModel.parse(
-            warped, dims=("y", "x", "c"), transformations={dataset_id: affine}, rgb=True
-        )
+        warped = Image2DModel.parse(warped, dims=("y", "x", "c"), transformations={dataset_id: affine}, rgb=True)
         # we replace the cytassist image with the warped image
         images[dataset_id + "_cytassist_image"] = warped
     elif load_all_images:
```
```diff
@@ -528,7 +519,24 @@ def _get_bins(path_bins: Path) -> list[str]:
 
     sdata = SpatialData(tables=tables, images=images, shapes=shapes, labels=labels)
 
-
+    if annotate_table_by_labels:
+        for bin_size_str in bin_sizes:
+            shapes_name = dataset_id + "_" + bin_size_str
+            # add labels layer (rasterized bins).
+            labels_name = f"{dataset_id}_{bin_size_str}_labels"
+            labels_element = rasterize_bins(
+                sdata,
+                bins=shapes_name,
+                table_name=bin_size_str,
+                row_key=VisiumHDKeys.ARRAY_ROW,
+                col_key=VisiumHDKeys.ARRAY_COL,
+                value_key=None,
+                return_region_as_labels=True,
+            )
+            sdata[labels_name] = labels_element
+            rasterize_bins_link_table_to_labels(
+                sdata=sdata, table_name=bin_size_str, rasterized_labels_name=labels_name
+            )
 
     return sdata
```

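The `annotate_table_by_labels` block was moved here because `rasterize_bins` reads the bins element and its annotating table from an assembled `SpatialData` object, which only exists after the constructor call above. A hedged sketch of the same call on an existing object (element names invented; top-level import paths assumed):

```python
from spatialdata import rasterize_bins, rasterize_bins_link_table_to_labels

# assumes `sdata` already contains shapes "sample_016um" and table "016um"
labels = rasterize_bins(
    sdata,
    bins="sample_016um",
    table_name="016um",
    row_key="array_row",
    col_key="array_col",
    value_key=None,
    return_region_as_labels=True,
)
sdata["sample_016um_labels"] = labels
rasterize_bins_link_table_to_labels(
    sdata=sdata, table_name="016um", rasterized_labels_name="sample_016um_labels"
)
```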
```diff
@@ -680,8 +688,8 @@ def _make_filtered_nucleus_adata(
     barcode_mappings_parquet_path: Path,
     bin_col_name: str = "square_002um",
     aggregate_col_name: str = "cell_id",
-) -> anndata.AnnData:
-    """Generate a filtered AnnData object by aggregating binned data (default 2um) based on nucleus segmentation.
+) -> AnnData:
+    """Generate a filtered AnnData object by aggregating 2um binned data based on nucleus segmentation.
 
     Uses filtered_feature_bc_matrix.h5 file and a barcode_mappings.parquet file containing
     barcode mappings, filters the data to include only valid nucleus mappings,
```
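The aggregation described in this docstring is essentially a sparse group-by-sum of 2um bins onto cell IDs. A simplified sketch of the idea, not the package's actual implementation:

```python
import anndata as ad
import numpy as np
import pandas as pd
from scipy import sparse


def sum_bins_per_cell(bins: ad.AnnData, cell_ids: pd.Series) -> ad.AnnData:
    """Sum counts of all bins that map to the same cell_id (toy version)."""
    codes, cells = pd.factorize(cell_ids)
    # sparse indicator matrix: one row per cell, one column per bin
    indicator = sparse.csr_matrix(
        (np.ones(len(codes)), (codes, np.arange(len(codes)))),
        shape=(len(cells), bins.n_obs),
    )
    aggregated = indicator @ bins.X  # summed counts per cell
    return ad.AnnData(
        X=aggregated,
        obs=pd.DataFrame(index=pd.Index(cells).astype(str)),
        var=bins.var.copy(),
    )
```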
```diff
@@ -701,7 +709,7 @@ def _make_filtered_nucleus_adata(
 
     Returns:
     --------
-    anndata.AnnData
+    AnnData
         An AnnData object where the observations correspond to filtered cell IDs
         and the variables correspond to the original features from the input data.
     """
```
```diff
@@ -744,12 +752,12 @@ def _make_filtered_nucleus_adata(
     return adata_nucleus
 
 
-def _extract_geometries_from_geojson(adata: anndata.AnnData, geojson_features_map: dict[str, Any]) -> GeoDataFrame:
+def _extract_geometries_from_geojson(adata: AnnData, geojson_features_map: dict[str, Any]) -> GeoDataFrame:
     """Extract geometries and create a GeoDataFrame from a GeoJSON features map.
 
     Parameters
     ----------
-    cell_adata : anndata.AnnData
+    cell_adata : AnnData
         AnnData object containing cell data.
     geojson_features_map : dict[str, Any]
         Dictionary mapping cell IDs to GeoJSON features.
```
