Skip to content

Commit

Permalink
make obsmap and viewmap correct for views
Browse files Browse the repository at this point in the history
  • Loading branch information
ilia-kats committed Jan 24, 2025
1 parent 37eb8c1 commit f545472
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/psf/black
rev: 20.8b1 # Replace by any tag/version: https://github.com/psf/black/tags
rev: 24.10.0 # Replace by any tag/version: https://github.com/psf/black/tags
hooks:
- id: black
language_version: python3
6 changes: 5 additions & 1 deletion src/mudata/_core/mudata.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,12 @@ def _init_as_view(self, mudata_ref: "MuData", index):

for attr, idx in (("obs", obsidx), ("var", varidx)):
posmap = {}
size = getattr(self, attr).shape[0]
for mod, mapping in getattr(mudata_ref, attr + "map").items():
posmap[mod] = mapping[idx].copy()
cposmap = np.zeros((size,), dtype=mapping.dtype)
cidx = mapping[idx] > 0
cposmap[cidx > 0] = np.arange(cidx.sum()) + 1
posmap[mod] = cposmap
setattr(self, "_" + attr + "map", posmap)

self.is_view = True
Expand Down
17 changes: 15 additions & 2 deletions tests/test_view_copy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path

import numpy as np
import pandas as pd
import pytest
from anndata import AnnData

Expand All @@ -10,8 +11,15 @@

@pytest.fixture()
def mdata():
mod1 = AnnData(np.arange(0, 100, 0.1).reshape(-1, 10))
mod2 = AnnData(np.arange(101, 2101, 1).reshape(-1, 20))
rng = np.random.default_rng(42)
mod1 = AnnData(
np.arange(0, 100, 0.1).reshape(-1, 10),
obs=pd.DataFrame(index=rng.choice(150, size=100, replace=False)),
)
mod2 = AnnData(
np.arange(101, 2101, 1).reshape(-1, 20),
obs=pd.DataFrame(index=rng.choice(150, size=100, replace=False)),
)
mods = {"mod1": mod1, "mod2": mod2}
# Make var_names different in different modalities
for m in ["mod1", "mod2"]:
Expand Down Expand Up @@ -70,6 +78,11 @@ def test_view_view(self, mdata):
assert mdata_view.is_view
assert mdata_view.n_obs == view_n_obs

for modname, mod in mdata_view.mod.items():
assert mdata_view.obsmap[modname].max() == mod.n_obs
idx = mdata_view.obsmap[modname]
assert np.all(mdata_view.obs_names[idx > 0] == mod.obs_names[idx[idx > 0] - 1])

view_view_n_obs = 2
mdata_view_view = mdata_view[list(range(view_view_n_obs)), :]
assert mdata_view_view.is_view
Expand Down

0 comments on commit f545472

Please sign in to comment.