diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 97c7898..696c2af 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/psf/black - rev: 20.8b1 # Replace by any tag/version: https://github.com/psf/black/tags + rev: 24.10.0 # Replace by any tag/version: https://github.com/psf/black/tags hooks: - id: black language_version: python3 diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index a2cdf87..c77e85a 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -315,8 +315,12 @@ def _init_as_view(self, mudata_ref: "MuData", index): for attr, idx in (("obs", obsidx), ("var", varidx)): posmap = {} + size = getattr(self, attr).shape[0] for mod, mapping in getattr(mudata_ref, attr + "map").items(): - posmap[mod] = mapping[idx].copy() + cposmap = np.zeros((size,), dtype=mapping.dtype) + cidx = mapping[idx] > 0 + cposmap[cidx > 0] = np.arange(cidx.sum()) + 1 + posmap[mod] = cposmap setattr(self, "_" + attr + "map", posmap) self.is_view = True diff --git a/tests/test_view_copy.py b/tests/test_view_copy.py index 41f3a31..819d60d 100644 --- a/tests/test_view_copy.py +++ b/tests/test_view_copy.py @@ -1,6 +1,7 @@ from pathlib import Path import numpy as np +import pandas as pd import pytest from anndata import AnnData @@ -10,8 +11,15 @@ @pytest.fixture() def mdata(): - mod1 = AnnData(np.arange(0, 100, 0.1).reshape(-1, 10)) - mod2 = AnnData(np.arange(101, 2101, 1).reshape(-1, 20)) + rng = np.random.default_rng(42) + mod1 = AnnData( + np.arange(0, 100, 0.1).reshape(-1, 10), + obs=pd.DataFrame(index=rng.choice(150, size=100, replace=False)), + ) + mod2 = AnnData( + np.arange(101, 2101, 1).reshape(-1, 20), + obs=pd.DataFrame(index=rng.choice(150, size=100, replace=False)), + ) mods = {"mod1": mod1, "mod2": mod2} # Make var_names different in different modalities for m in ["mod1", "mod2"]: @@ -70,6 +78,11 @@ def test_view_view(self, mdata): assert mdata_view.is_view assert mdata_view.n_obs == view_n_obs + for modname, mod in mdata_view.mod.items(): + assert mdata_view.obsmap[modname].max() == mod.n_obs + idx = mdata_view.obsmap[modname] + assert np.all(mdata_view.obs_names[idx > 0] == mod.obs_names[idx[idx > 0] - 1]) + view_view_n_obs = 2 mdata_view_view = mdata_view[list(range(view_view_n_obs)), :] assert mdata_view_view.is_view