Skip to content

Commit 3b4978d

Browse files
Backport PR #3529: Speed up wilcoxon rank-sum test with numba (#3539)
Co-authored-by: George Wu <[email protected]>
1 parent 0226a24 commit 3b4978d

File tree

4 files changed

+58
-26
lines changed

4 files changed

+58
-26
lines changed

benchmarks/benchmarks/tools.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,11 @@ def time_leiden():
4747

4848
def peakmem_leiden():
4949
sc.tl.leiden(adata, flavor="igraph")
50+
51+
52+
def time_rank_genes_groups() -> None:
    """Run Wilcoxon ``rank_genes_groups`` on ``adata``, grouped by ``"bulk_labels"``.

    Follows the ``time_*`` naming of the other benchmarks in this module
    (e.g. ``time_leiden``); the return value of the call is discarded.
    """
    sc.tl.rank_genes_groups(adata, "bulk_labels", method="wilcoxon")
54+
55+
56+
def peakmem_rank_genes_groups() -> None:
    """Run Wilcoxon ``rank_genes_groups`` on ``adata``, grouped by ``"bulk_labels"``.

    Follows the ``peakmem_*`` naming of the other benchmarks in this module
    (e.g. ``peakmem_leiden``); the return value of the call is discarded.
    """
    sc.tl.rank_genes_groups(adata, "bulk_labels", method="wilcoxon")
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Speed up wilcoxon rank-sum test with numba {smaller}`G Wu`

src/scanpy/experimental/pp/_highly_variable_genes.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from math import sqrt
66
from typing import TYPE_CHECKING
77

8-
import numba as nb
8+
import numba
99
import numpy as np
1010
import pandas as pd
1111
import scipy.sparse as sp_sparse
@@ -62,7 +62,7 @@ def clac_clipped_res_sparse(gene: int, cell: int, value: np.float64) -> np.float
6262
return res
6363

6464
residuals = np.zeros(n_genes, dtype=np.float64)
65-
for gene in nb.prange(n_genes):
65+
for gene in numba.prange(n_genes):
6666
start_idx = indptr[gene]
6767
stop_idx = indptr[gene + 1]
6868

@@ -113,7 +113,7 @@ def clac_clipped_res_dense(gene: int, cell: int) -> np.float64:
113113

114114
residuals = np.zeros(n_genes, dtype=np.float64)
115115

116-
for gene in nb.prange(n_genes):
116+
for gene in numba.prange(n_genes):
117117
sum_clipped_res = np.float64(0.0)
118118
for cell in range(n_cells):
119119
sum_clipped_res += clac_clipped_res_dense(gene, cell)

src/scanpy/tools/_rank_genes_groups.py

Lines changed: 46 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44

55
from typing import TYPE_CHECKING, Literal
66

7+
import numba
78
import numpy as np
89
import pandas as pd
910
from scipy.sparse import issparse, vstack
1011

1112
from .. import _utils
1213
from .. import logging as logg
13-
from .._compat import old_positionals
14+
from .._compat import njit, old_positionals
1415
from .._utils import (
1516
check_nonnegative_integers,
1617
get_literal_vals,
@@ -46,11 +47,50 @@ def _select_top_n(scores: NDArray, n_top: int):
4647
return global_indices
4748

4849

50+
@njit
def rankdata(data: NDArray[np.number]) -> NDArray[np.float64]:
    """Rank every column of `data`, giving tied values their average rank.

    Parallelized version of scipy.stats.rankdata, applied column by column.

    Parameters
    ----------
    data
        2D array; each column ``data[:, j]`` is ranked independently.

    Returns
    -------
    Array with the same shape as `data` holding 1-based ranks as ``float64``.
    Equal values within a column share the mean of the ranks they span.
    """
    ranked = np.empty(data.shape, dtype=np.float64)
    # With zero rows there is nothing to rank. The loop body below would still
    # build a one-element `obs` and then index `count` with a `dense` array
    # that was never written, so bail out early.
    if data.shape[0] == 0:
        return ranked

    for j in numba.prange(data.shape[1]):
        arr = np.ravel(data[:, j])
        sorter = np.argsort(arr)

        arr = arr[sorter]
        # True at the first element of every run of equal values
        obs = np.concatenate((np.array([True]), arr[1:] != arr[:-1]))

        # 1-based id of the run each original element belongs to
        dense = np.empty(obs.size, dtype=np.int64)
        dense[sorter] = obs.cumsum()

        # cumulative counts of each unique value
        count = np.concatenate((np.flatnonzero(obs), np.array([len(obs)])))
        # A run with id d covers sorted positions count[d - 1] .. count[d] - 1,
        # i.e. ranks count[d - 1] + 1 .. count[d]; take their mean.
        ranked[:, j] = 0.5 * (count[dense] + count[dense - 1] + 1)

    return ranked
69+
70+
71+
@njit
def _tiecorrect(rankvals: NDArray[np.number]) -> NDArray[np.float64]:
    """Compute the tie correction factor for every column of `rankvals`.

    Parallelized version of scipy.stats.tiecorrect.

    Parameters
    ----------
    rankvals
        2D array; each column ``rankvals[:, j]`` is treated as one sample.

    Returns
    -------
    1D ``float64`` array with one factor per column. A column with fewer than
    two entries keeps the initial factor of ``1.0``.
    """
    n_cols = rankvals.shape[1]
    factors = np.ones(n_cols, dtype=np.float64)
    for col in numba.prange(n_cols):
        ordered = np.sort(np.ravel(rankvals[:, col]))
        n = np.float64(ordered.size)
        if n >= 2:
            # Boundaries of runs of equal values, including both array ends
            boundaries = np.concatenate(
                (np.array([True]), ordered[1:] != ordered[:-1], np.array([True]))
            )
            edges = np.flatnonzero(boundaries)
            # Length of every run, i.e. size of every group of ties
            run_lengths = np.diff(edges).astype(np.float64)
            factors[col] = 1.0 - (run_lengths**3 - run_lengths).sum() / (n**3 - n)

    return factors
87+
88+
4989
def _ranks(
50-
X: np.ndarray | _CSMatrix,
90+
X: NDArray[np.number] | _CSMatrix,
5191
mask_obs: NDArray[np.bool_] | None = None,
5292
mask_obs_rest: NDArray[np.bool_] | None = None,
53-
) -> Generator[tuple[pd.DataFrame, int, int], None, None]:
93+
) -> Generator[tuple[NDArray[np.float64], int, int], None, None]:
5494
n_genes = X.shape[1]
5595

5696
if issparse(X):
@@ -77,25 +117,10 @@ def _ranks(
77117
for left in range(0, n_genes, max_chunk):
78118
right = min(left + max_chunk, n_genes)
79119

80-
df = pd.DataFrame(data=get_chunk(X, left, right))
81-
ranks = df.rank()
120+
ranks = rankdata(get_chunk(X, left, right))
82121
yield ranks, left, right
83122

84123

85-
def _tiecorrect(ranks: pd.DataFrame) -> np.float64:
86-
size = np.float64(ranks.shape[0])
87-
if size < 2:
88-
return np.repeat(ranks.shape[1], 1.0)
89-
90-
arr = np.sort(ranks, axis=0)
91-
tf = np.insert(arr[1:] != arr[:-1], (0, arr.shape[0] - 1), True, axis=0)
92-
idx = np.where(tf, np.arange(tf.shape[0])[:, None], 0)
93-
idx = np.sort(idx, axis=0)
94-
cnt = np.diff(idx, axis=0).astype(np.float64)
95-
96-
return 1.0 - (cnt**3 - cnt).sum(axis=0) / (size**3 - size)
97-
98-
99124
class _RankGenes:
100125
def __init__(
101126
self,
@@ -311,7 +336,7 @@ def wilcoxon(
311336

312337
# Calculate rank sums for each chunk for the current mask
313338
for ranks, left, right in _ranks(self.X, mask_obs, mask_obs_rest):
314-
scores[left:right] = ranks.iloc[0:n_active, :].sum(axis=0)
339+
scores[left:right] = ranks[0:n_active, :].sum(axis=0)
315340
if tie_correct:
316341
T[left:right] = _tiecorrect(ranks)
317342

@@ -339,9 +364,7 @@ def wilcoxon(
339364
for ranks, left, right in _ranks(self.X):
340365
# sum up adjusted_ranks to calculate W_m,n
341366
for group_index, mask_obs in enumerate(self.groups_masks_obs):
342-
scores[group_index, left:right] = ranks.iloc[mask_obs, :].sum(
343-
axis=0
344-
)
367+
scores[group_index, left:right] = ranks[mask_obs, :].sum(axis=0)
345368
if tie_correct:
346369
T[group_index, left:right] = _tiecorrect(ranks)
347370

0 commit comments

Comments
 (0)