Skip to content

Commit 9d06854

Browse files
committed
rewrite logics without prange (prange tries to rewrite one element in the same time
1 parent 968e093 commit 9d06854

File tree

1 file changed

+14
-30
lines changed

1 file changed

+14
-30
lines changed

src/scanpy/tools/_score_genes.py

Lines changed: 14 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44

55
from typing import TYPE_CHECKING
66

7-
import numba
7+
import math
88
import numpy as np
99
import pandas as pd
1010

1111
from .. import logging as logg
12-
from .._compat import CSBase, njit, old_positionals
12+
from .._compat import CSBase, old_positionals
1313
from .._utils import _check_use_raw, is_backed_type
1414
from ..get import _get_obs_rep
1515

@@ -29,39 +29,23 @@
2929
_GetSubset = Callable[[_StrIdx], np.ndarray | CSBase]
3030

3131

32-
@njit
3332
def _get_sparce_nanmean_columns(
34-
data: NDArray[Any], indicies: NDArray[np.int32], shape: tuple
33+
data: NDArray[Any], indices: NDArray[np.int32], shape: tuple
3534
) -> NDArray[np.float64]:
36-
sums = np.zeros(shape[1], dtype=np.float64)
37-
counts = np.repeat(float(shape[0]), shape[1])
38-
for data_index in numba.prange(len(data)):
39-
if np.isnan(data[data_index]):
40-
counts[indicies[data_index]] -= 1.0
41-
continue
42-
sums[indicies[data_index]] += data[data_index]
43-
# if we have row column nans return nan (not inf)
44-
counts[counts == 0.0] = np.nan
45-
return sums / counts
46-
47-
48-
@njit
35+
sum_arr = np.zeros(shape[1], dtype = np.float64)
36+
nans_arr = np.zeros(shape[1], dtype = np.float64)
37+
np.add.at(sum_arr, indices, np.nan_to_num(data, nan=0.0))
38+
np.add.at(nans_arr, indices, np.isnan(data))
39+
nans_arr[nans_arr==shape[0]] = np.nan
40+
return sum_arr/(shape[0] - nans_arr)
41+
42+
4943
def _get_sparce_nanmean_rows(
5044
data: NDArray[Any], indptr: NDArray[np.int32], shape: tuple
5145
) -> NDArray[np.float64]:
52-
sums = np.zeros(shape[0], dtype=np.float64)
53-
counts = np.repeat(float(shape[1]), shape[0])
54-
for cur_row_index in numba.prange(shape[0]):
55-
for data_index in numba.prange(
56-
indptr[cur_row_index], indptr[cur_row_index + 1]
57-
):
58-
if np.isnan(data[data_index]):
59-
counts[cur_row_index] -= 1.0
60-
continue
61-
sums[cur_row_index] += data[data_index]
62-
# if we have row from nans return nan (not inf)
63-
counts[counts == 0.0] = np.nan
64-
return sums / counts
46+
sum_arr = np.add.reduceat(np.nan_to_num(data, nan=0.0), indptr[:-1], dtype=np.float64)
47+
nans_arr = np.add.reduceat(np.isnan(data), indptr[:-1], dtype=np.float64)
48+
return sum_arr/(shape[1] - nans_arr)
6549

6650

6751
def _sparse_nanmean(X: CSBase, axis: Literal[0, 1]) -> NDArray[np.float64]:

0 commit comments

Comments
 (0)