Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release-notes/3939.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Ensure `scanpy.pp.scrublet()` preserves (categorical) data types within `.obs`
9 changes: 3 additions & 6 deletions src/scanpy/preprocessing/_scrublet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,6 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None):

# Run Scrublet independently on batches and return just the
# scrublet-relevant parts of the objects to add to the input object

batches = np.unique(adata.obs[batch_key])
scrubbed = [
_run_scrublet(
Expand All @@ -271,29 +270,27 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None):
)
for batch in batches
]
scrubbed_obs = pd.concat([scrub["obs"] for scrub in scrubbed])
scrubbed_obs = pd.concat([scrub["obs"] for scrub in scrubbed]).astype(
adata.obs.dtypes
)

# Now reset the obs to get the scrublet scores

adata.obs = scrubbed_obs.loc[adata.obs_names.values]

# Save the .uns from each batch separately

adata.uns["scrublet"] = {}
adata.uns["scrublet"]["batches"] = dict(
zip(batches, [scrub["uns"] for scrub in scrubbed], strict=True)
)

# Record that we've done batched analysis, so e.g. the plotting
# function knows what to do.

adata.uns["scrublet"]["batched_by"] = batch_key

else:
scrubbed = _run_scrublet(adata_obs, adata_sim)

# Copy outcomes to input object from our processed version

adata.obs["doublet_score"] = scrubbed["obs"]["doublet_score"]
adata.obs["predicted_doublet"] = scrubbed["obs"]["predicted_doublet"]
adata.uns["scrublet"] = scrubbed["uns"]
Expand Down
10 changes: 10 additions & 0 deletions tests/test_scrublet.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,13 @@ def test_scrublet_simulate_doublets():
adata_sim.obsm["doublet_parents"],
np.array([[13, 132], [106, 43], [152, 3], [160, 103]]),
)


def test_scrublet_dtypes() -> None:
"""Test that Scrublet does not change dtypes of existing data.obs cols."""
adata = pbmc200()
adata.obs["batch"] = pd.Categorical(100 * ["a"] + 100 * ["b"])

sc.pp.scrublet(adata, use_approx_neighbors=False, batch_key="batch")

assert adata.obs["batch"].dtype == "category"
Loading