From e28b60d6cc3461a57dd0a89124e984043482712e Mon Sep 17 00:00:00 2001 From: Danila Date: Tue, 2 Jul 2024 18:05:53 -0700 Subject: [PATCH] Improve pull/push interfaces - Add pull=True argument to the update functions - Handle categorical data with different categories in _pull_attr - Improve error messages for pull/push operations --- src/mudata/_core/mudata.py | 49 ++++++++++++++++++++++++++++++-------- src/mudata/_core/utils.py | 12 ++++++++-- 2 files changed, 49 insertions(+), 12 deletions(-) diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index 406cf73..9715f1f 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -1053,12 +1053,21 @@ def obs_vector(self, key: str, layer: str | None = None) -> np.ndarray: raise KeyError(f"There is no key {key} in MuData .obs or in .obs of any modalities.") return self.obs[key].values - def update_obs(self): + def update_obs(self, pull: bool = True): """ - Update .obs slot of MuData with the newest .obs data from all the modalities + Update global .obs_names according to the .obs_names of all the modalities. + If pull=True, update .obs slot of MuData with the newest .obs data from all the modalities. + + NOTE: from v0.4, this method will use pull=False by default. + + Params + ------ + pull + If True, pull .obs columns from all modalities. + True by default (will change to False by default in the next versions). """ join_common = self.axis == 1 - self._update_attr("obs", axis=1, join_common=join_common) + self._update_attr("obs", axis=1, join_common=join_common, pull=pull) def obs_names_make_unique(self): """ @@ -1155,12 +1164,21 @@ def var_vector(self, key: str, layer: str | None = None) -> np.ndarray: raise KeyError(f"There is no key {key} in MuData .var or in .var of any modalities.") return self.var[key].values - def update_var(self): + def update_var(self, pull: bool = True): """ - Update .var slot of MuData with the newest .var data from all the modalities + Update global .var_names according to the .var_names of all the modalities. + If pull=True, update .var slot of MuData with the newest .var data from all the modalities. + + NOTE: from v0.4, this method will use pull=False by default. + + Params + ------ + pull + If True, pull .var columns from all modalities. + True by default (will change to False by default in the next versions). """ join_common = self.axis == 0 - self._update_attr("var", axis=0, join_common=join_common) + self._update_attr("var", axis=0, join_common=join_common, pull=pull) def var_names_make_unique(self): """ @@ -1348,12 +1366,20 @@ def uns_keys(self) -> list[str]: """List keys of unstructured annotation.""" return list(self._uns.keys()) - def update(self): + def update(self, pull: bool = True): """ Update both .obs and .var of MuData with the data from all the modalities + + NOTE: from v0.4, this method will use pull=False by default. + + Params + ------ + pull + If True, pull columns from all modalities. + True by default (will change to False by default in the next versions). """ - self.update_var() - self.update_obs() + self.update_var(pull=pull) + self.update_obs(pull=pull) @property def axis(self) -> int: @@ -1808,7 +1834,9 @@ def _push_attr( if count > 1 and c in getattr(self, attr).columns: raise ValueError( f"Cannot push multiple columns with the same name {c} with and without modality prefix. " - "You might have to explicitely specify columns to push." + "You might have to explicitely specify columns to push.\n" + "In case there are columns with the same name with and without modality prefix, " + "this has to be resolved first." ) attrmap = getattr(self, f"{attr}map") @@ -1835,6 +1863,7 @@ def _push_attr( if not only_drop: # TODO: _maybe_coerce_to_bool # TODO: _maybe_coerce_to_int + # TODO: _prune_unused_categories mod_df = getattr(mod, attr).set_index(np.arange(mod_n_attr)) mod_df = _update_and_concat(mod_df, df) mod_df = mod_df.set_index(getattr(mod, f"{attr}_names")) diff --git a/src/mudata/_core/utils.py b/src/mudata/_core/utils.py index 7203d56..3b97c81 100644 --- a/src/mudata/_core/utils.py +++ b/src/mudata/_core/utils.py @@ -152,11 +152,19 @@ def _classify_prefixed_columns( def _update_and_concat(df1: pd.DataFrame, df2: pd.DataFrame) -> pd.DataFrame: df = df1.copy() - # This converts boolean to object dtype, unfrtunately + # This converts boolean to object dtype, unfortunately # df.update(df2) common_cols = df1.columns.intersection(df2.columns) for col in common_cols: - df[col].update(df2[col]) + if isinstance(df[col].values, pd.Categorical) and isinstance( + df2[col].values, pd.Categorical + ): + common_cats = pd.api.types.union_categoricals([df[col], df2[col]]).categories + df[col] = df[col].cat.set_categories(common_cats) + df2[col] = df2[col].cat.set_categories(common_cats) + df[col].update(df2[col]) + else: + df[col].update(df2[col]) new_cols = df2.columns.difference(df1.columns) res = pd.concat([df, df2[new_cols]], axis=1, sort=False, verify_integrity=True) return res