Skip to content

Speed up categorical regressor with numba #3353

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 39 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
086f70d
add function and test
Intron7 Nov 11, 2024
37244a9
add test
Intron7 Nov 11, 2024
b4ecb0a
add test for regressor
Intron7 Nov 11, 2024
36858d9
add release note
Intron7 Nov 11, 2024
be1bccc
Merge branch 'main' into create_cat_regressor
Intron7 Nov 11, 2024
a1a59ae
update typing
Intron7 Nov 11, 2024
7b41bc8
update test
Intron7 Nov 11, 2024
119a142
update test
Intron7 Nov 12, 2024
d77fa9c
update dtype
Intron7 Nov 12, 2024
236e356
rename cats
Intron7 Nov 12, 2024
bb9cde4
Update tests/test_preprocessing.py
Intron7 Nov 12, 2024
bbb5035
Update tests/test_preprocessing.py
Intron7 Nov 12, 2024
2a92193
Update tests/test_preprocessing.py
Intron7 Nov 12, 2024
c7b78c0
Update tests/test_preprocessing.py
ilan-gold Nov 12, 2024
b001c0e
remove test
Intron7 Nov 13, 2024
c3ce03e
update kernel
Intron7 Nov 13, 2024
c50226a
remove test
Intron7 Nov 13, 2024
c6665f4
make test together
Intron7 Nov 21, 2024
858e247
cleanup
Intron7 Nov 21, 2024
2421bd5
add comment
Intron7 Nov 21, 2024
2e16c45
Merge branch 'main' into create_cat_regressor
Intron7 Dec 16, 2024
1b7d7e1
Merge branch 'main' into create_cat_regressor
Intron7 Jan 23, 2025
3b7fe6e
Merge branch 'main' into create_cat_regressor
Intron7 Feb 10, 2025
726a625
update doc strings and clean up names
Intron7 Feb 10, 2025
104a0f3
Update src/scanpy/preprocessing/_simple.py
Intron7 Feb 11, 2025
f9b13be
update dtypes
Intron7 Feb 11, 2025
6eafd04
update atol for test
Intron7 Feb 11, 2025
1dae8f4
remove int fix
Intron7 Feb 11, 2025
39ad1c0
Update docs/release-notes/3353.performance.md
Intron7 Apr 14, 2025
4f3db86
Merge branch 'main' into create_cat_regressor
Intron7 Apr 14, 2025
2d578c8
Update src/scanpy/preprocessing/_simple.py
Intron7 Apr 14, 2025
eedb314
Fix sparse check
flying-sheep Apr 14, 2025
6a19bc6
Merge branch 'main' into create_cat_regressor
ilan-gold May 21, 2025
0053bce
(fix): correct dtype check
ilan-gold May 21, 2025
75be7e2
(fix): regress_out with int tested
ilan-gold May 22, 2025
2fdd0bf
Merge branch 'main' into create_cat_regressor
ilan-gold May 22, 2025
2b3f1f1
(fix): float32 regress data type
ilan-gold May 22, 2025
0ab9fee
Merge branch 'create_cat_regressor' of github.com:scverse/scanpy into…
ilan-gold May 22, 2025
832981b
(fix): atol/rtol
ilan-gold May 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release-notes/3353.performance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Speed up for a categorical regressor in {func}`~scanpy.pp.regress_out` {smaller}`S Dicks`
31 changes: 25 additions & 6 deletions src/scanpy/preprocessing/_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,23 @@
DT = TypeVar("DT")


@njit
def _create_regressor_categorical(
X: np.ndarray, number_categories: int, cat_array: np.ndarray
) -> np.ndarray:
# create regressor matrix for categorical variables
# would be best to use X dtype but this matches old behavior
regressors = np.zeros(X.shape, dtype=np.float32)
# iterate over categories
for category in range(number_categories):
# iterate over genes and calculate mean expression
# for each gene per category
mask = category == cat_array
for ix in numba.prange(X.T.shape[0]):
regressors[mask, ix] = X.T[ix, mask].mean()
return regressors


@njit
def get_resid(
data: np.ndarray,
Expand Down Expand Up @@ -722,13 +739,15 @@
)
raise ValueError(msg)
logg.debug("... regressing on per-gene means within categories")
regressors = np.zeros(X.shape, dtype="float32")
# set number of categories to the same dtype as the categories
cat_array = adata.obs[keys[0]].cat.codes.to_numpy()
number_categories = cat_array.dtype.type(len(adata.obs[keys[0]].cat.categories))

X = to_dense(X, order="F") if isinstance(X, CSBase) else X
# TODO figure out if we should use a numba kernel for this
for category in adata.obs[keys[0]].cat.categories:
mask = (category == adata.obs[keys[0]]).values
for ix, x in enumerate(X.T):
regressors[mask, ix] = x[mask].mean()
if np.issubdtype(X.dtype, np.integer):
target_dtype = np.float32 if X.dtype.itemsize <= 4 else np.float64
X = X.astype(target_dtype)

Check warning on line 749 in src/scanpy/preprocessing/_simple.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/preprocessing/_simple.py#L748-L749

Added lines #L748 - L749 were not covered by tests
regressors = _create_regressor_categorical(X, number_categories, cat_array)
variable_is_categorical = True
# regress on one or several ordinal variables
else:
Expand Down
Binary file added tests/_data/cat_regressor_for_int_input.npy
Binary file not shown.
Binary file added tests/_data/regress_test_small_cat.npy
Binary file not shown.
37 changes: 32 additions & 5 deletions tests/test_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import scanpy as sc
from scanpy._compat import CSBase
from scanpy.datasets._datasets import pbmc3k
from testing.scanpy._helpers import (
anndata_v0_8_constructor_compat,
check_rep_mutation,
Expand Down Expand Up @@ -351,6 +352,25 @@ def test_regress_out_ordinal():
np.testing.assert_array_equal(single.X, multi.X)


@pytest.mark.parametrize("dtype", [np.uint32, np.float64, np.uint64])
def test_regress_out_int(dtype):
adata = pbmc3k()[:200, :200].copy()
adata.X = adata.X.astype(np.float64 if dtype != np.uint32 else np.float32)
dtype = adata.X.dtype
adata.obs["labels"] = pd.Categorical(
(["A"] * (adata.X.shape[0] - 100)) + (["B"] * 100)
)
adata_other = adata.copy()
adata_other.X = adata_other.X.astype(dtype)
# results using only one processor
sc.pp.regress_out(adata, keys=["labels"])
sc.pp.regress_out(adata_other, keys=["labels"])
assert_equal(adata_other, adata)
# This file was generated under scanpy 1.10.3
ground_truth = np.load(DATA_PATH / "cat_regressor_for_int_input.npy")
np.testing.assert_allclose(ground_truth, adata_other.X, atol=1e-5, rtol=1e-5)


@pytest.mark.parametrize("dtype", [np.int64, np.float64, np.int32])
def test_regress_out_layer(dtype):
from scipy.sparse import random
Expand Down Expand Up @@ -417,14 +437,21 @@ def test_regress_out_constants():
assert_equal(adata, adata_copy)


def test_regress_out_reproducible():
adata = pbmc68k_reduced()
@pytest.mark.parametrize(
("keys", "test_file", "atol"),
[
(["n_counts", "percent_mito"], "regress_test_small.npy", 0.0),
(["bulk_labels"], "regress_test_small_cat.npy", 1e-6),
],
)
def test_regress_out_reproducible(keys, test_file, atol):
adata = sc.datasets.pbmc68k_reduced()
adata = adata.raw.to_adata()[:200, :200].copy()
sc.pp.regress_out(adata, keys=["n_counts", "percent_mito"])
sc.pp.regress_out(adata, keys=keys)
# This file was generated from the original implementation in version 1.10.3
# Now we compare new implementation with the old one
tester = np.load(DATA_PATH / "regress_test_small.npy")
np.testing.assert_allclose(adata.X, tester)
tester = np.load(DATA_PATH / test_file)
np.testing.assert_allclose(adata.X, tester, atol=atol)


def test_regress_out_constants_equivalent():
Expand Down
Loading