Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
PYTHON: ${{ matrix.python }}

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python }}
uses: actions/setup-python@v4
with:
Expand Down
48 changes: 46 additions & 2 deletions src/spatialdata_io/readers/xenium.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,21 +248,33 @@ def xenium(
table.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] = "cell_labels"

if nucleus_boundaries:
invalid_nuc_ids = _find_invalid_ids(path, XeniumKeys.NUCLEUS_BOUNDARIES_FILE)
if len(invalid_nuc_ids) > 0:
logging.warning(
f"Found {len(invalid_nuc_ids)} invalid polygons for nuclei, removing the masks corresponding to the IDs: {invalid_nuc_ids}"
)
polygons["nucleus_boundaries"] = _get_polygons(
path,
XeniumKeys.NUCLEUS_BOUNDARIES_FILE,
specs,
n_jobs,
idx=table.obs[str(XeniumKeys.CELL_ID)].copy(),
invalid_ids=invalid_nuc_ids,
)

if cells_boundaries:
invalid_cell_ids = _find_invalid_ids(path, XeniumKeys.CELL_BOUNDARIES_FILE)
if len(invalid_cell_ids) > 0:
logging.warning(
f"Found {len(invalid_cell_ids)} invalid polygons for cells, removing the masks corresponding to the IDs: {invalid_cell_ids}"
)
polygons["cell_boundaries"] = _get_polygons(
path,
XeniumKeys.CELL_BOUNDARIES_FILE,
specs,
n_jobs,
idx=table.obs[str(XeniumKeys.CELL_ID)].copy(),
invalid_ids=invalid_cell_ids,
)

if transcripts:
Expand Down Expand Up @@ -338,7 +350,9 @@ def filter(self, record: logging.LogRecord) -> bool:
logger.removeFilter(IgnoreSpecificMessage())

if table is not None:
tables["table"] = table
valid_nucleus_mask = ~table.obs[XeniumKeys.CELL_ID].isin(invalid_nuc_ids)
valid_cell_mask = ~table.obs[XeniumKeys.CELL_ID].isin(invalid_cell_ids)
tables["table"] = table[valid_nucleus_mask & valid_cell_mask].copy()

elements_dict = {"images": images, "labels": labels, "points": points, "tables": tables, "shapes": polygons}
if cells_as_circles:
Expand All @@ -354,24 +368,54 @@ def filter(self, record: logging.LogRecord) -> bool:
return sdata


def _find_invalid_ids(
path: Path,
file: str,
) -> ArrayLike:
"""Filter out cell ids with too few vertices to form a valid polygon."""
df = pq.read_table(path / file).to_pandas()
invalid_ids = df.groupby(XeniumKeys.CELL_ID).filter(lambda x: len(x) < 4)[XeniumKeys.CELL_ID].unique()
return [] if len(invalid_ids) == 0 else invalid_ids


def _decode_cell_id_column(cell_id_column: pd.Series) -> pd.Series:
if isinstance(cell_id_column.iloc[0], bytes):
return cell_id_column.apply(lambda x: x.decode("utf-8"))
return cell_id_column


def _get_polygons(
path: Path, file: str, specs: dict[str, Any], n_jobs: int, idx: ArrayLike | None = None
path: Path,
file: str,
specs: dict[str, Any],
n_jobs: int,
idx: ArrayLike | None = None,
invalid_ids: ArrayLike | None = None,
) -> GeoDataFrame:
def _poly(arr: ArrayLike) -> Polygon:
return Polygon(arr[:-1])

if invalid_ids is None:
invalid_ids = []

# seems to be faster than pd.read_parquet
df = pq.read_table(path / file).to_pandas()
df[XeniumKeys.CELL_ID] = _decode_cell_id_column(df[XeniumKeys.CELL_ID])

# Filter based on valid cell IDs if idx is provided
if idx is not None:
idx = idx[~idx.isin(invalid_ids)]
if len(invalid_ids) > 0:
idx = idx.reset_index(drop=True)
df = df[df[XeniumKeys.CELL_ID].isin(idx)]
else:
# If no idx provided, just (potentially) filter out invalid IDs
df = df[~df[XeniumKeys.CELL_ID].isin(invalid_ids)]

group_by = df.groupby(XeniumKeys.CELL_ID)
index = pd.Series(group_by.indices.keys())
index = _decode_cell_id_column(index)

out = Parallel(n_jobs=n_jobs)(
delayed(_poly)(i.to_numpy())
for _, i in group_by[[XeniumKeys.BOUNDARIES_VERTEX_X, XeniumKeys.BOUNDARIES_VERTEX_Y]]
Expand Down
Loading