
Commit cc9dc68

wip code review

1 parent c9674ca commit cc9dc68

3 files changed: +151 -75 lines changed

src/spatialdata_io/_constants/_constants.py

Lines changed: 2 additions & 2 deletions
@@ -409,5 +409,5 @@ class VisiumHDKeys(ModeEnum):
     FILE_FORMAT = "file_format"
 
     # Cell Segmentation keys
-    CELL_SEG_KEY_HD = 'cell_segmentations'
-    NUCLEUS_SEG_KEY_HD = 'nucleus_segmentations'
+    CELL_SEG_KEY_HD = "cell_segmentations"
+    NUCLEUS_SEG_KEY_HD = "nucleus_segmentations"
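Reviewer note: the reader diff below composes these keys into element names via f-strings, which relies on ModeEnum members formatting as their string value. A quick sanity check, with a placeholder dataset_id that is not part of the commit:

    from spatialdata_io._constants._constants import VisiumHDKeys

    dataset_id = "sample"  # placeholder
    shapes_key = f"{dataset_id}_{VisiumHDKeys.CELL_SEG_KEY_HD}"
    # -> "sample_cell_segmentations", assuming ModeEnum formats to its value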

src/spatialdata_io/readers/visium_hd.py

Lines changed: 96 additions & 65 deletions
@@ -35,19 +35,20 @@
 if TYPE_CHECKING:
     from collections.abc import Mapping
 
-    import anndata
+    from anndata import AnnData
     from multiscale_spatial_image import MultiscaleSpatialImage
     from spatial_image import SpatialImage
     from spatialdata._types import ArrayLike
 
 RNG = default_rng(0)
 
+
 @inject_docs(vx=VisiumHDKeys)
 def visium_hd(
     path: str | Path,
     dataset_id: str | None = None,
     filtered_counts_file: bool = True,
-    load_segmentations_only: bool = True,
+    load_segmentations_only: bool = False,
     load_nucleus_segmentations: bool = False,
     bin_size: int | list[int] | None = None,
     bins_as_squares: bool = True,
@@ -74,6 +75,8 @@ def visium_hd(
     load_segmentations_only
         If `True`, only the segmented cell boundaries and their associated counts will be loaded. All binned data
         will be skipped.
+    load_nucleus_segmentations
+        ...
     bin_size
         When specified, load the data of a specific bin size, or a list of bin sizes. By default, it loads all the
         available bin sizes.
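Reviewer note: with the default flipped to load_segmentations_only=False, a call that opts in to nucleus segmentations would look roughly like this (the path and flag values are illustrative, not from the commit):

    from spatialdata_io import visium_hd

    sdata = visium_hd(
        "path/to/visium_hd_run",          # folder containing binned_outputs/ and segmented_outputs/
        load_segmentations_only=False,    # new default: binned data is loaded as well
        load_nucleus_segmentations=True,  # opt in to the nucleus shapes and table
    )

Also, the new load_nucleus_segmentations docstring entry is still a `...` placeholder and should be filled in before merge.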
@@ -115,15 +118,26 @@ def visium_hd(
     CELL_GEOJSON_PATH = SEGMENTED_OUTPUTS_PATH / VisiumHDKeys.CELL_SEGMENTATION_GEOJSON_PATH
     NUCLEUS_GEOJSON_PATH = SEGMENTED_OUTPUTS_PATH / VisiumHDKeys.NUCLEUS_SEGMENTATION_GEOJSON_PATH
     SCALE_FACTORS_PATH = SEGMENTED_OUTPUTS_PATH / VisiumHDKeys.SPATIAL / VisiumHDKeys.SCALEFACTORS_FILE
-    BARCODE_MAPPINGS_PATH = next((file for file in path.rglob("*") if file.name.endswith(VisiumHDKeys.BARCODE_MAPPINGS_FILE)), None)
-    FILTERED_MATRIX_2U_PATH = path / VisiumHDKeys.BINNED_OUTPUTS / f"{VisiumHDKeys.BIN_PREFIX}002um" / VisiumHDKeys.FILTERED_COUNTS_FILE
-    cell_segmentation_files_exist = COUNT_MATRIX_PATH.exists() and CELL_GEOJSON_PATH.exists() and SCALE_FACTORS_PATH.exists()
-    nucleus_segmentation_files_exist = NUCLEUS_GEOJSON_PATH.exists() and (BARCODE_MAPPINGS_PATH is not None and BARCODE_MAPPINGS_PATH.exists()) and FILTERED_MATRIX_2U_PATH.exists()
+    BARCODE_MAPPINGS_PATH = next(
+        (file for file in path.rglob("*") if file.name.endswith(VisiumHDKeys.BARCODE_MAPPINGS_FILE)), None
+    )
+    FILTERED_MATRIX_2U_PATH = (
+        path / VisiumHDKeys.BINNED_OUTPUTS / f"{VisiumHDKeys.BIN_PREFIX}002um" / VisiumHDKeys.FILTERED_COUNTS_FILE
+    )
+    cell_segmentation_files_exist = (
+        COUNT_MATRIX_PATH.exists() and CELL_GEOJSON_PATH.exists() and SCALE_FACTORS_PATH.exists()
+    )
+    nucleus_segmentation_files_exist = (
+        NUCLEUS_GEOJSON_PATH.exists()
+        and (BARCODE_MAPPINGS_PATH is not None and BARCODE_MAPPINGS_PATH.exists())
+        and FILTERED_MATRIX_2U_PATH.exists()
+    )
 
     if dataset_id is None:
         dataset_id = _infer_dataset_id(path)
 
     filename_prefix = _get_filename_prefix(path, dataset_id)
+
     def load_image(path: Path, suffix: str, scale_factors: list[int] | None = None) -> None:
         _load_image(
             path=path,
@@ -147,6 +161,7 @@ def load_image(path: Path, suffix: str, scale_factors: list[int] | None = None)
 
    # Load Binned Data (skipped if load_segmentations_only is True)
    if not load_segmentations_only:
+
        def _get_bins(path_bins: Path) -> list[str]:
            return sorted(
                [
@@ -158,9 +173,7 @@ def _get_bins(path_bins: Path) -> list[str]:
 
        all_path_bins = [path_bin for path_bin in all_files if VisiumHDKeys.BINNED_OUTPUTS in str(path_bin)]
        if len(all_path_bins) != 0:
-            path_bins_parts = all_path_bins[
-                -1
-            ].parts
+            path_bins_parts = all_path_bins[-1].parts
            path_bins = Path(*path_bins_parts[: path_bins_parts.index(VisiumHDKeys.BINNED_OUTPUTS) + 1])
        else:
            path_bins = path
@@ -181,6 +194,7 @@ def _get_bins(path_bins: Path) -> list[str]:
        if bin_size is None or bin_sizes == []:
            bin_sizes = all_bin_sizes
 
+        # iterate over the given bins and load the data
        for bin_size_str in bin_sizes:
            path_bin = path_bins / bin_size_str
            counts_file = VisiumHDKeys.FILTERED_COUNTS_FILE if filtered_counts_file else VisiumHDKeys.RAW_COUNTS_FILE
@@ -195,6 +209,7 @@ def _get_bins(path_bins: Path) -> list[str]:
            with open(path_bin_spatial / VisiumHDKeys.SCALEFACTORS_FILE) as file:
                scalefactors = json.load(file)
 
+            # consistency check
            found_bin_size = re.search(r"\d{3}", bin_size_str)
            assert found_bin_size is not None
            assert float(found_bin_size.group()) == scalefactors[VisiumHDKeys.SCALEFACTORS_BIN_SIZE_UM]
@@ -206,6 +221,7 @@ def _get_bins(path_bins: Path) -> list[str]:
 
            tissue_positions_file = path_bin_spatial / VisiumHDKeys.TISSUE_POSITIONS_FILE
 
+            # read coordinates and set up adata.obs and adata.obsm
            coords = pd.read_parquet(tissue_positions_file)
            assert all(
                coords.columns.values
@@ -221,7 +237,9 @@ def _get_bins(path_bins: Path) -> list[str]:
            coords.set_index(VisiumHDKeys.BARCODE, inplace=True, drop=True)
            coords_filtered = coords.loc[adata.obs.index]
            adata.obs = pd.merge(adata.obs, coords_filtered, how="left", left_index=True, right_index=True)
+            # compatibility to legacy squidpy
            adata.obsm["spatial"] = adata.obs[[VisiumHDKeys.LOCATIONS_X, VisiumHDKeys.LOCATIONS_Y]].values
+            # dropping the spatial coordinates (will be stored in shapes)
            adata.obs.drop(
                columns=[
                    VisiumHDKeys.LOCATIONS_X,
@@ -231,6 +249,8 @@ def _get_bins(path_bins: Path) -> list[str]:
            )
            adata.obs[VisiumHDKeys.INSTANCE_KEY] = np.arange(len(adata))
 
+            # scaling
+            transform_original = Identity()
            transform_lowres = Scale(
                np.array(
                    [
@@ -249,11 +269,13 @@ def _get_bins(path_bins: Path) -> list[str]:
                ),
                axes=("x", "y"),
            )
+            # parse shapes
            shapes_name = dataset_id + "_" + bin_size_str
            radius = scalefactors[VisiumHDKeys.SCALEFACTORS_SPOT_DIAMETER_FULLRES] / 2.0
 
            # Here we ensure that only the correct coordinate systems are created for the binned data
            transformations = {
+                dataset_id: transform_original,
                f"{dataset_id}_downscaled_hires": transform_hires,
                f"{dataset_id}_downscaled_lowres": transform_lowres,
            }
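Reviewer note: the newly added Identity() entry registers the shapes in the dataset's full-resolution coordinate system alongside the existing downscaled ones. A minimal, self-contained sketch of what these transforms do, with an assumed scale factor:

    import numpy as np
    from spatialdata.transformations import Identity, Scale

    hires_scalef = 0.08  # placeholder for scalefactors["tissue_hires_scalef"]
    to_hires = Scale(np.array([hires_scalef, hires_scalef]), axes=("x", "y"))
    # a bin centre at full-res (1000, 2000) maps to (80, 160) in the hires image;
    # Identity() leaves full-res coordinates unchanged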
@@ -272,6 +294,7 @@ def _get_bins(path_bins: Path) -> list[str]:
                GeoDataFrame(geometry=squares_series), transformations=transformations
            )
 
+            # parse table
            adata.obs[VisiumHDKeys.REGION_KEY] = shapes_name
            adata.obs[VisiumHDKeys.REGION_KEY] = adata.obs[VisiumHDKeys.REGION_KEY].astype("category")
 
@@ -290,45 +313,52 @@ def _get_bins(path_bins: Path) -> list[str]:
        cell_adata_hd = sc.read_10x_h5(COUNT_MATRIX_PATH)
        cell_adata_hd.var_names_make_unique()
 
-        shapes_transformations_hd = _make_shapes_transformation(scale_factors_path=SCALE_FACTORS_PATH, dataset_id=dataset_id) # Used for both cell and nucleus segmentations
+        shapes_transformations_hd = _make_shapes_transformation(
+            scale_factors_path=SCALE_FACTORS_PATH, dataset_id=dataset_id
+        ) # Used for both cell and nucleus segmentations
        cell_geojson_features_map = _make_geojson_features_map(CELL_GEOJSON_PATH)
-        cell_shapes_gdf = _extract_geometries_from_geojson(cell_adata_hd, geojson_features_map=cell_geojson_features_map)
+        cell_shapes_gdf = _extract_geometries_from_geojson(
+            cell_adata_hd, geojson_features_map=cell_geojson_features_map
+        )
 
        SHAPES_KEY_HD = f"{dataset_id}_{VisiumHDKeys.CELL_SEG_KEY_HD}"
-        cell_adata_hd.obs['cell_id'] = cell_adata_hd.obs.index
-        cell_adata_hd.obs['region'] = SHAPES_KEY_HD
-        cell_adata_hd.obs['region'] = cell_adata_hd.obs['region'].astype('category')
+        cell_adata_hd.obs["cell_id"] = cell_adata_hd.obs.index
+        cell_adata_hd.obs["region"] = SHAPES_KEY_HD
+        cell_adata_hd.obs["region"] = cell_adata_hd.obs["region"].astype("category")
        cell_adata_hd = cell_adata_hd[cell_shapes_gdf.index].copy()
 
        shapes[SHAPES_KEY_HD] = ShapesModel.parse(cell_shapes_gdf, transformations=shapes_transformations_hd)
        tables[VisiumHDKeys.CELL_SEG_KEY_HD] = TableModel.parse(
-            cell_adata_hd,
-            region=SHAPES_KEY_HD,
-            region_key='region',
-            instance_key='cell_id'
+            cell_adata_hd, region=SHAPES_KEY_HD, region_key="region", instance_key="cell_id"
        )
 
        # load nucleus segmentations if available
        if nucleus_segmentation_files_exist and load_nucleus_segmentations:
            print("Found nucleus segmentation data. Incorporating nucleus_segmentations.")
 
-            nucleus_adata_hd = _make_filtered_nucleus_adata(filtered_matrix_h5_path=FILTERED_MATRIX_2U_PATH,barcode_mappings_parquet_path=BARCODE_MAPPINGS_PATH)
-            geojson_features_map = _make_geojson_features_map(NUCLEUS_GEOJSON_PATH)
-            nucleus_shapes_gdf = _extract_geometries_from_geojson(adata=nucleus_adata_hd, geojson_features_map=geojson_features_map)
-
-            SHAPES_KEY_HD = f"{dataset_id}_{VisiumHDKeys.NUCLEUS_SEG_KEY_HD}"
-            nucleus_adata_hd.obs['cell_id'] = nucleus_adata_hd.obs.index
-            nucleus_adata_hd.obs['region'] = SHAPES_KEY_HD
-            nucleus_adata_hd.obs['region'] = nucleus_adata_hd.obs['region'].astype('category')
-            nucleus_adata_hd = nucleus_adata_hd[nucleus_shapes_gdf.index].copy()
-
-            shapes[SHAPES_KEY_HD] = ShapesModel.parse(nucleus_shapes_gdf, transformations=shapes_transformations_hd)
-            tables[VisiumHDKeys.NUCLEUS_SEG_KEY_HD] = TableModel.parse(
-                nucleus_adata_hd,
-                region=SHAPES_KEY_HD,
-                region_key='region',
-                instance_key='cell_id'
-            )
+            if BARCODE_MAPPINGS_PATH is None:
+                warnings.warn(
+                    "Cannot find the barcode mappings file, skipping nucleus segmentations.", UserWarning, stacklevel=2
+                )
+            else:
+                nucleus_adata_hd = _make_filtered_nucleus_adata(
+                    filtered_matrix_h5_path=FILTERED_MATRIX_2U_PATH, barcode_mappings_parquet_path=BARCODE_MAPPINGS_PATH
+                )
+                geojson_features_map = _make_geojson_features_map(NUCLEUS_GEOJSON_PATH)
+                nucleus_shapes_gdf = _extract_geometries_from_geojson(
+                    adata=nucleus_adata_hd, geojson_features_map=geojson_features_map
+                )
+
+                SHAPES_KEY_HD = f"{dataset_id}_{VisiumHDKeys.NUCLEUS_SEG_KEY_HD}"
+                nucleus_adata_hd.obs["cell_id"] = nucleus_adata_hd.obs.index
+                nucleus_adata_hd.obs["region"] = SHAPES_KEY_HD
+                nucleus_adata_hd.obs["region"] = nucleus_adata_hd.obs["region"].astype("category")
+                nucleus_adata_hd = nucleus_adata_hd[nucleus_shapes_gdf.index].copy()
+
+                shapes[SHAPES_KEY_HD] = ShapesModel.parse(nucleus_shapes_gdf, transformations=shapes_transformations_hd)
+                tables[VisiumHDKeys.NUCLEUS_SEG_KEY_HD] = TableModel.parse(
+                    nucleus_adata_hd, region=SHAPES_KEY_HD, region_key="region", instance_key="cell_id"
+                )
 
    # Read all images and add transformations for both binning and segmentation
    fullres_image_file_paths = []
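Reviewer note: after this block the returned SpatialData object should expose the segmentation elements under the new keys. A hypothetical look-up for a dataset_id of "sample" (key names taken from the diff, access pattern assumed):

    cells = sdata.shapes["sample_cell_segmentations"]    # GeoDataFrame of cell polygons
    cell_table = sdata.tables["cell_segmentations"]      # AnnData of per-cell counts
    nuclei = sdata.tables.get("nucleus_segmentations")   # only present when the GeoJSON and barcode mappings were found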
@@ -472,7 +502,9 @@ def _get_bins(path_bins: Path) -> list[str]:
        )
        warped = np.round(warped * 255).astype(np.uint8)
        if not load_segmentations_only:
-            warped = Image2DModel.parse(warped, dims=("y", "x", "c"), transformations={dataset_id: affine}, rgb=True)
+            warped = Image2DModel.parse(
+                warped, dims=("y", "x", "c"), transformations={dataset_id: affine}, rgb=True
+            )
            images[dataset_id + "_cytassist_image"] = warped
    elif load_all_images:
        warnings.warn(
@@ -645,12 +677,13 @@ def _get_transform_matrices(metadata: dict[str, Any], hd_layout: dict[str, Any])
 
    return transform_matrices
 
+
def _make_filtered_nucleus_adata(
    filtered_matrix_h5_path: Path,
    barcode_mappings_parquet_path: Path,
-    bin_col_name: str = 'square_002um',
-    aggregate_col_name: str = 'cell_id'
-) -> anndata.AnnData:
+    bin_col_name: str = "square_002um",
+    aggregate_col_name: str = "cell_id",
+) -> AnnData:
    """Generate a filtered AnnData object by aggregating 2um binned data based on nucleus segmentation.
 
    Uses a 2um filtered_feature_bc_matrix.h5 file and a barcode_mappings.parquet file containing
@@ -671,7 +704,7 @@ def _make_filtered_nucleus_adata(
 
    Returns:
    --------
-    anndata.AnnData
+    AnnData
        An AnnData object where the observations correspond to filtered cell IDs
        and the variables correspond to the original features from the input data.
    """
@@ -680,19 +713,19 @@ def _make_filtered_nucleus_adata(
    barcode_mappings = pq.read_table(barcode_mappings_parquet_path)
 
    # Filter to only include valid cell IDs that are in both nucleus and cell
-    barcode_mappings = barcode_mappings.filter((barcode_mappings['cell_id'].is_valid()) and barcode_mappings["in_nucleus"])
+    barcode_mappings = barcode_mappings.filter(
+        (barcode_mappings["cell_id"].is_valid()) and barcode_mappings["in_nucleus"]
+    )
 
    # Filter the 2um adata to only include squares present in the barcode mappings
    valid_squares = barcode_mappings[bin_col_name].unique()
    squares_to_keep = np.intersect1d(adata_2um.obs_names, valid_squares)
    adata_filtered = adata_2um[squares_to_keep, :].copy()
 
    # Map each square to its corresponding cell ID
-    square_to_cell_map = dict(zip(
-        barcode_mappings[bin_col_name].to_pylist(),
-        barcode_mappings[aggregate_col_name].to_pylist(), strict=False
-
-    ))
+    square_to_cell_map = dict(
+        zip(barcode_mappings[bin_col_name].to_pylist(), barcode_mappings[aggregate_col_name].to_pylist(), strict=False)
+    )
    ordered_cell_ids = [square_to_cell_map[square] for square in adata_filtered.obs_names]
    unique_cells = list(dict.fromkeys(ordered_cell_ids).keys())
    cell_to_idx = {cell: i for i, cell in enumerate(unique_cells)}
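Reviewer note: the reformatted filter still combines the two conditions with Python's `and`, which is not elementwise on Arrow arrays (it either raises or short-circuits on the left operand's truthiness). The elementwise conjunction is pyarrow.compute.and_, which is presumably what is intended here; a self-contained sketch with toy data:

    import pyarrow as pa
    import pyarrow.compute as pc

    # toy stand-in for barcode_mappings.parquet
    tbl = pa.table({"cell_id": [1, None, 2], "in_nucleus": [True, True, False]})
    mask = pc.and_(pc.is_valid(tbl["cell_id"]), tbl["in_nucleus"])
    print(tbl.filter(mask))  # keeps only the first row (valid cell_id AND in_nucleus)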
@@ -702,10 +735,7 @@ def _make_filtered_nucleus_adata(
    row_indices = np.arange(len(ordered_cell_ids))
    data = np.ones_like(row_indices)
 
-    aggregation_matrix = csc_matrix(
-        (data, (row_indices, col_indices)),
-        shape=(adata_filtered.n_obs, len(unique_cells))
-    )
+    aggregation_matrix = csc_matrix((data, (row_indices, col_indices)), shape=(adata_filtered.n_obs, len(unique_cells)))
 
    # Make the final AnnData object where cell IDs are filtered
    # to the data under the segmented nuclei
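Reviewer note: this builds a sparse bins-by-cells indicator matrix, presumably applied afterwards as aggregation_matrix.T @ X to pool 2um bin counts per cell (that step is outside this hunk). A minimal, self-contained illustration of the technique with toy data:

    import numpy as np
    from scipy.sparse import csc_matrix, csr_matrix

    X = csr_matrix(np.array([[1, 0], [2, 1], [0, 3]]))  # 3 bins x 2 genes
    ordered_cell_ids = ["cellA", "cellA", "cellB"]      # cell assignment per bin
    unique_cells = list(dict.fromkeys(ordered_cell_ids))
    col_indices = [unique_cells.index(c) for c in ordered_cell_ids]
    row_indices = np.arange(len(ordered_cell_ids))
    A = csc_matrix((np.ones(len(row_indices)), (row_indices, col_indices)), shape=(X.shape[0], len(unique_cells)))
    pooled = A.T @ X  # 2 cells x 2 genes: [[3, 1], [0, 3]]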
@@ -716,12 +746,13 @@ def _make_filtered_nucleus_adata(
 
    return adata_nucleus
 
-def _extract_geometries_from_geojson(adata: anndata.AnnData, geojson_features_map: dict[str, Any]) -> GeoDataFrame:
+
+def _extract_geometries_from_geojson(adata: AnnData, geojson_features_map: dict[str, Any]) -> GeoDataFrame:
    """Extract geometries and create a GeoDataFrame from a GeoJSON features map.
 
    Parameters
    ----------
-    cell_adata : anndata.AnnData
+    cell_adata : AnnData
        AnnData object containing cell data.
    geojson_features_map : dict[str, Any]
        Dictionary mapping cell IDs to GeoJSON features.
@@ -737,7 +768,7 @@ def _extract_geometries_from_geojson(adata: anndata.AnnData, geojson_features_ma
    for obs_index_str in adata.obs.index:
        feature = geojson_features_map.get(obs_index_str)
        if feature:
-            polygon_coords = np.array(feature['geometry']['coordinates'][0])
+            polygon_coords = np.array(feature["geometry"]["coordinates"][0])
            geometries.append(Polygon(polygon_coords))
            cell_ids_ordered.append(obs_index_str)
        else:
@@ -748,10 +779,8 @@ def _extract_geometries_from_geojson(adata: anndata.AnnData, geojson_features_ma
    geometries = [geometries[i] for i in valid_indices]
    cell_ids_ordered = [cell_ids_ordered[i] for i in valid_indices]
 
-    return GeoDataFrame({
-        'cell_id': cell_ids_ordered,
-        'geometry': geometries
-    }, index=cell_ids_ordered)
+    return GeoDataFrame({"cell_id": cell_ids_ordered, "geometry": geometries}, index=cell_ids_ordered)
+
 
def _make_shapes_transformation(scale_factors_path: Path, dataset_id: str) -> dict[str, Scale]:
    """Load scale factors for lowres and hires images and create transformations.
@@ -770,18 +799,20 @@ def _make_shapes_transformation(scale_factors_path: Path, dataset_id: str) -> di
    """
    with open(scale_factors_path) as f:
        scale_data_hd = json.load(f)
-    lowres_scale_factor_hd = scale_data_hd['tissue_lowres_scalef']
-    hires_scale_factor_hd = scale_data_hd['tissue_hires_scalef']
+    lowres_scale_factor_hd = scale_data_hd["tissue_lowres_scalef"]
+    hires_scale_factor_hd = scale_data_hd["tissue_hires_scalef"]
 
    return {
-        f"{dataset_id}_downscaled_lowres": Scale(np.array([lowres_scale_factor_hd, lowres_scale_factor_hd]), axes=("x", "y")),
-        f"{dataset_id}_downscaled_hires": Scale(np.array([hires_scale_factor_hd, hires_scale_factor_hd]), axes=("x", "y"))
+        f"{dataset_id}_downscaled_lowres": Scale(
+            np.array([lowres_scale_factor_hd, lowres_scale_factor_hd]), axes=("x", "y")
+        ),
+        f"{dataset_id}_downscaled_hires": Scale(
+            np.array([hires_scale_factor_hd, hires_scale_factor_hd]), axes=("x", "y")
+        ),
    }
 
+
def _make_geojson_features_map(geojson_path: Path) -> dict[str, Any]:
    with open(geojson_path) as f:
        geojson_data = json.load(f)
-    return {
-        f"cellid_{feature['properties']['cell_id']:09d}-1": feature
-        for feature in geojson_data['features']
-    }
+    return {f"cellid_{feature['properties']['cell_id']:09d}-1": feature for feature in geojson_data["features"]}
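Reviewer note: the dictionary key reconstructs the barcode format used as the obs index of the 10x cell matrix, zero-padding the GeoJSON cell_id to nine digits so features can be looked up by adata.obs.index. For example (values illustrative):

    feature = {"properties": {"cell_id": 42}}
    key = f"cellid_{feature['properties']['cell_id']:09d}-1"  # -> "cellid_000000042-1"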
