diff --git a/docs/release-notes/3939.fix.md b/docs/release-notes/3939.fix.md new file mode 100644 index 0000000000..2f489979d8 --- /dev/null +++ b/docs/release-notes/3939.fix.md @@ -0,0 +1 @@ +Ensure `scanpy.pp.scrublet()` preserves (categorical) data types within `.obs` diff --git a/src/scanpy/preprocessing/_scrublet/__init__.py b/src/scanpy/preprocessing/_scrublet/__init__.py index e6394df458..b17821c723 100644 --- a/src/scanpy/preprocessing/_scrublet/__init__.py +++ b/src/scanpy/preprocessing/_scrublet/__init__.py @@ -262,7 +262,6 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None): # Run Scrublet independently on batches and return just the # scrublet-relevant parts of the objects to add to the input object - batches = np.unique(adata.obs[batch_key]) scrubbed = [ _run_scrublet( @@ -271,14 +270,14 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None): ) for batch in batches ] - scrubbed_obs = pd.concat([scrub["obs"] for scrub in scrubbed]) + scrubbed_obs = pd.concat([scrub["obs"] for scrub in scrubbed]).astype( + adata.obs.dtypes + ) # Now reset the obs to get the scrublet scores - adata.obs = scrubbed_obs.loc[adata.obs_names.values] # Save the .uns from each batch separately - adata.uns["scrublet"] = {} adata.uns["scrublet"]["batches"] = dict( zip(batches, [scrub["uns"] for scrub in scrubbed], strict=True) @@ -286,14 +285,12 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None): # Record that we've done batched analysis, so e.g. the plotting # function knows what to do. - adata.uns["scrublet"]["batched_by"] = batch_key else: scrubbed = _run_scrublet(adata_obs, adata_sim) # Copy outcomes to input object from our processed version - adata.obs["doublet_score"] = scrubbed["obs"]["doublet_score"] adata.obs["predicted_doublet"] = scrubbed["obs"]["predicted_doublet"] adata.uns["scrublet"] = scrubbed["uns"] diff --git a/tests/test_scrublet.py b/tests/test_scrublet.py index b8d7560b4e..9844d908c3 100644 --- a/tests/test_scrublet.py +++ b/tests/test_scrublet.py @@ -228,3 +228,13 @@ def test_scrublet_simulate_doublets(): adata_sim.obsm["doublet_parents"], np.array([[13, 132], [106, 43], [152, 3], [160, 103]]), ) + + +def test_scrublet_dtypes() -> None: + """Test that Scrublet does not change dtypes of existing data.obs cols.""" + adata = pbmc200() + adata.obs["batch"] = pd.Categorical(100 * ["a"] + 100 * ["b"]) + + sc.pp.scrublet(adata, use_approx_neighbors=False, batch_key="batch") + + assert adata.obs["batch"].dtype == "category"