44
55from typing import TYPE_CHECKING , Literal
66
7+ import numba
78import numpy as np
89import pandas as pd
910from scipy .sparse import issparse , vstack
1011
1112from .. import _utils
1213from .. import logging as logg
13- from .._compat import old_positionals
14+ from .._compat import njit , old_positionals
1415from .._utils import (
1516 check_nonnegative_integers ,
1617 get_literal_vals ,
@@ -46,11 +47,50 @@ def _select_top_n(scores: NDArray, n_top: int):
4647 return global_indices
4748
4849
50+ @njit
51+ def rankdata (data : NDArray [np .number ]) -> NDArray [np .float64 ]:
52+ """Parallelized version of scipy.stats.rankdata."""
53+ ranked = np .empty (data .shape , dtype = np .float64 )
54+ for j in numba .prange (data .shape [1 ]):
55+ arr = np .ravel (data [:, j ])
56+ sorter = np .argsort (arr )
57+
58+ arr = arr [sorter ]
59+ obs = np .concatenate ((np .array ([True ]), arr [1 :] != arr [:- 1 ]))
60+
61+ dense = np .empty (obs .size , dtype = np .int64 )
62+ dense [sorter ] = obs .cumsum ()
63+
64+ # cumulative counts of each unique value
65+ count = np .concatenate ((np .flatnonzero (obs ), np .array ([len (obs )])))
66+ ranked [:, j ] = 0.5 * (count [dense ] + count [dense - 1 ] + 1 )
67+
68+ return ranked
69+
70+
71+ @njit
72+ def _tiecorrect (rankvals : NDArray [np .number ]) -> NDArray [np .float64 ]:
73+ """Parallelized version of scipy.stats.tiecorrect."""
74+ tc = np .ones (rankvals .shape [1 ], dtype = np .float64 )
75+ for j in numba .prange (rankvals .shape [1 ]):
76+ arr = np .sort (np .ravel (rankvals [:, j ]))
77+ idx = np .flatnonzero (
78+ np .concatenate ((np .array ([True ]), arr [1 :] != arr [:- 1 ], np .array ([True ])))
79+ )
80+ cnt = np .diff (idx ).astype (np .float64 )
81+
82+ size = np .float64 (arr .size )
83+ if size >= 2 :
84+ tc [j ] = 1.0 - (cnt ** 3 - cnt ).sum () / (size ** 3 - size )
85+
86+ return tc
87+
88+
4989def _ranks (
50- X : np .ndarray | _CSMatrix ,
90+ X : NDArray [ np .number ] | _CSMatrix ,
5191 mask_obs : NDArray [np .bool_ ] | None = None ,
5292 mask_obs_rest : NDArray [np .bool_ ] | None = None ,
53- ) -> Generator [tuple [pd . DataFrame , int , int ], None , None ]:
93+ ) -> Generator [tuple [NDArray [ np . float64 ] , int , int ], None , None ]:
5494 n_genes = X .shape [1 ]
5595
5696 if issparse (X ):
@@ -77,25 +117,10 @@ def _ranks(
77117 for left in range (0 , n_genes , max_chunk ):
78118 right = min (left + max_chunk , n_genes )
79119
80- df = pd .DataFrame (data = get_chunk (X , left , right ))
81- ranks = df .rank ()
120+ ranks = rankdata (get_chunk (X , left , right ))
82121 yield ranks , left , right
83122
84123
85- def _tiecorrect (ranks : pd .DataFrame ) -> np .float64 :
86- size = np .float64 (ranks .shape [0 ])
87- if size < 2 :
88- return np .repeat (ranks .shape [1 ], 1.0 )
89-
90- arr = np .sort (ranks , axis = 0 )
91- tf = np .insert (arr [1 :] != arr [:- 1 ], (0 , arr .shape [0 ] - 1 ), True , axis = 0 )
92- idx = np .where (tf , np .arange (tf .shape [0 ])[:, None ], 0 )
93- idx = np .sort (idx , axis = 0 )
94- cnt = np .diff (idx , axis = 0 ).astype (np .float64 )
95-
96- return 1.0 - (cnt ** 3 - cnt ).sum (axis = 0 ) / (size ** 3 - size )
97-
98-
99124class _RankGenes :
100125 def __init__ (
101126 self ,
@@ -311,7 +336,7 @@ def wilcoxon(
311336
312337 # Calculate rank sums for each chunk for the current mask
313338 for ranks , left , right in _ranks (self .X , mask_obs , mask_obs_rest ):
314- scores [left :right ] = ranks . iloc [0 :n_active , :].sum (axis = 0 )
339+ scores [left :right ] = ranks [0 :n_active , :].sum (axis = 0 )
315340 if tie_correct :
316341 T [left :right ] = _tiecorrect (ranks )
317342
@@ -339,9 +364,7 @@ def wilcoxon(
339364 for ranks , left , right in _ranks (self .X ):
340365 # sum up adjusted_ranks to calculate W_m,n
341366 for group_index , mask_obs in enumerate (self .groups_masks_obs ):
342- scores [group_index , left :right ] = ranks .iloc [mask_obs , :].sum (
343- axis = 0
344- )
367+ scores [group_index , left :right ] = ranks [mask_obs , :].sum (axis = 0 )
345368 if tie_correct :
346369 T [group_index , left :right ] = _tiecorrect (ranks )
347370
0 commit comments