diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 8546b3d7..62e51ab0 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -26,7 +26,7 @@ jobs: PYTHON: ${{ matrix.python }} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python }} uses: actions/setup-python@v4 with: diff --git a/src/spatialdata_io/readers/xenium.py b/src/spatialdata_io/readers/xenium.py index ff036067..d9715b40 100644 --- a/src/spatialdata_io/readers/xenium.py +++ b/src/spatialdata_io/readers/xenium.py @@ -248,21 +248,33 @@ def xenium( table.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] = "cell_labels" if nucleus_boundaries: + invalid_nuc_ids = _find_invalid_ids(path, XeniumKeys.NUCLEUS_BOUNDARIES_FILE) + if len(invalid_nuc_ids) > 0: + logging.warning( + f"Found {len(invalid_nuc_ids)} invalid polygons for nuclei, removing the masks corresponding to the IDs: {invalid_nuc_ids}" + ) polygons["nucleus_boundaries"] = _get_polygons( path, XeniumKeys.NUCLEUS_BOUNDARIES_FILE, specs, n_jobs, idx=table.obs[str(XeniumKeys.CELL_ID)].copy(), + invalid_ids=invalid_nuc_ids, ) if cells_boundaries: + invalid_cell_ids = _find_invalid_ids(path, XeniumKeys.CELL_BOUNDARIES_FILE) + if len(invalid_cell_ids) > 0: + logging.warning( + f"Found {len(invalid_cell_ids)} invalid polygons for cells, removing the masks corresponding to the IDs: {invalid_cell_ids}" + ) polygons["cell_boundaries"] = _get_polygons( path, XeniumKeys.CELL_BOUNDARIES_FILE, specs, n_jobs, idx=table.obs[str(XeniumKeys.CELL_ID)].copy(), + invalid_ids=invalid_cell_ids, ) if transcripts: @@ -338,7 +350,9 @@ def filter(self, record: logging.LogRecord) -> bool: logger.removeFilter(IgnoreSpecificMessage()) if table is not None: - tables["table"] = table + valid_nucleus_mask = ~table.obs[XeniumKeys.CELL_ID].isin(invalid_nuc_ids) + valid_cell_mask = ~table.obs[XeniumKeys.CELL_ID].isin(invalid_cell_ids) + tables["table"] = table[valid_nucleus_mask & 
def _find_invalid_ids(
    path: Path,
    file: str,
) -> ArrayLike:
    """Return the cell IDs whose boundary has too few vertices to form a valid polygon.

    Parameters
    ----------
    path
        Directory that contains the boundaries parquet file.
    file
        Name of the boundaries parquet file, relative to ``path``.

    Returns
    -------
    Array with the unique cell IDs that have fewer than 4 boundary rows, in order of
    first appearance in the file. The array is empty when every boundary is valid.
    The IDs are decoded with ``_decode_cell_id_column``, i.e. the same way
    ``_get_polygons`` decodes its ID column before comparing against them.
    """
    # ``_get_polygons`` builds each polygon from all rows but the last one
    # (``Polygon(arr[:-1])``), so a boundary needs at least 4 rows to be left with
    # 3 vertices, the minimum for a polygon.
    min_rows_per_boundary = 4

    df = pq.read_table(path / file).to_pandas()
    cell_ids = df[XeniumKeys.CELL_ID]
    if cell_ids.empty:
        # Nothing to check; this also avoids decoding an empty column, which would
        # fail on ``iloc[0]`` inside ``_decode_cell_id_column``.
        return cell_ids.unique()

    # Without decoding, byte-encoded IDs would never match the decoded IDs they are
    # later compared with via ``isin``.
    cell_ids = _decode_cell_id_column(cell_ids)

    # Vectorised row count per ID instead of a Python callback per group;
    # ``sort=False`` keeps the IDs in order of first appearance.
    rows_per_id = cell_ids.groupby(cell_ids, sort=False).size()
    return rows_per_id[rows_per_id < min_rows_per_boundary].index.to_numpy()
XeniumKeys.BOUNDARIES_VERTEX_Y]]