Skip to content

Commit

Permalink
Improve pull/push interfaces
Browse files Browse the repository at this point in the history
- Add pull=True argument to the update functions
- Handle categorical data with different categories in _pull_attr
- Improve error messages for pull/push operations
  • Loading branch information
gtca committed Jul 3, 2024
1 parent 37ec45e commit e28b60d
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 12 deletions.
49 changes: 39 additions & 10 deletions src/mudata/_core/mudata.py
Original file line number Diff line number Diff line change
Expand Up @@ -1053,12 +1053,21 @@ def obs_vector(self, key: str, layer: str | None = None) -> np.ndarray:
raise KeyError(f"There is no key {key} in MuData .obs or in .obs of any modalities.")
return self.obs[key].values

def update_obs(self, pull: bool = True):
    """
    Update global .obs_names according to the .obs_names of all the modalities.

    If pull=True, update .obs slot of MuData with the newest .obs data from all the modalities.

    NOTE: from v0.4, this method will use pull=False by default.

    Params
    ------
    pull
        If True, pull .obs columns from all modalities.
        True by default (will change to False by default in the next versions).
    """
    # Common columns are joined only when observations are the shared axis (axis == 1).
    self._update_attr("obs", axis=1, join_common=(self.axis == 1), pull=pull)

def obs_names_make_unique(self):
"""
Expand Down Expand Up @@ -1155,12 +1164,21 @@ def var_vector(self, key: str, layer: str | None = None) -> np.ndarray:
raise KeyError(f"There is no key {key} in MuData .var or in .var of any modalities.")
return self.var[key].values

def update_var(self, pull: bool = True):
    """
    Update global .var_names according to the .var_names of all the modalities.

    If pull=True, update .var slot of MuData with the newest .var data from all the modalities.

    NOTE: from v0.4, this method will use pull=False by default.

    Params
    ------
    pull
        If True, pull .var columns from all modalities.
        True by default (will change to False by default in the next versions).
    """
    # Common columns are joined only when variables are the shared axis (axis == 0).
    shared_vars = self.axis == 0
    self._update_attr("var", axis=0, join_common=shared_vars, pull=pull)

def var_names_make_unique(self):
"""
Expand Down Expand Up @@ -1348,12 +1366,20 @@ def uns_keys(self) -> list[str]:
"""List keys of unstructured annotation."""
return list(self._uns.keys())

def update(self, pull: bool = True):
    """
    Update both .obs and .var of MuData with the data from all the modalities.

    NOTE: from v0.4, this method will use pull=False by default.

    Params
    ------
    pull
        If True, pull columns from all modalities.
        True by default (will change to False by default in the next versions).
    """
    # Keep the original order: variables are refreshed before observations.
    for refresh in (self.update_var, self.update_obs):
        refresh(pull=pull)

@property
def axis(self) -> int:
Expand Down Expand Up @@ -1808,7 +1834,9 @@ def _push_attr(
if count > 1 and c in getattr(self, attr).columns:
raise ValueError(
f"Cannot push multiple columns with the same name {c} with and without modality prefix. "
"You might have to explicitely specify columns to push."
"You might have to explicitely specify columns to push.\n"
"In case there are columns with the same name with and without modality prefix, "
"this has to be resolved first."
)

attrmap = getattr(self, f"{attr}map")
Expand All @@ -1835,6 +1863,7 @@ def _push_attr(
if not only_drop:
# TODO: _maybe_coerce_to_bool
# TODO: _maybe_coerce_to_int
# TODO: _prune_unused_categories
mod_df = getattr(mod, attr).set_index(np.arange(mod_n_attr))
mod_df = _update_and_concat(mod_df, df)
mod_df = mod_df.set_index(getattr(mod, f"{attr}_names"))
Expand Down
12 changes: 10 additions & 2 deletions src/mudata/_core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,19 @@ def _classify_prefixed_columns(

def _update_and_concat(df1: pd.DataFrame, df2: pd.DataFrame) -> pd.DataFrame:
df = df1.copy()
# This converts boolean to object dtype, unfrtunately
# This converts boolean to object dtype, unfortunately
# df.update(df2)
common_cols = df1.columns.intersection(df2.columns)
for col in common_cols:
df[col].update(df2[col])
if isinstance(df[col].values, pd.Categorical) and isinstance(
df2[col].values, pd.Categorical
):
common_cats = pd.api.types.union_categoricals([df[col], df2[col]]).categories
df[col] = df[col].cat.set_categories(common_cats)
df2[col] = df2[col].cat.set_categories(common_cats)
df[col].update(df2[col])
else:
df[col].update(df2[col])
new_cols = df2.columns.difference(df1.columns)
res = pd.concat([df, df2[new_cols]], axis=1, sort=False, verify_integrity=True)
return res
Expand Down

0 comments on commit e28b60d

Please sign in to comment.