|
34 | 34 | from packaging.version import Version |
35 | 35 |
|
36 | 36 | from .. import logging as logg |
37 | | -from .._compat import CSBase, DaskArray, _CSArray, _CSMatrix, pkg_version |
| 37 | +from .._compat import CSBase, DaskArray, _CSArray, pkg_version |
38 | 38 | from .._settings import settings |
39 | | -from .compute.is_constant import is_constant # noqa: F401 |
40 | 39 |
|
41 | 40 | if Version(anndata_version) >= Version("0.10.0"): |
42 | 41 | from anndata._core.sparse_dataset import ( |
|
53 | 52 |
|
54 | 53 | from anndata import AnnData |
55 | 54 | from igraph import Graph |
56 | | - from numpy.typing import ArrayLike, DTypeLike, NDArray |
| 55 | + from numpy.typing import ArrayLike, NDArray |
57 | 56 |
|
58 | 57 | from .._compat import CSRBase |
59 | 58 | from ..neighbors import NeighborsParams, RPForestDict |
@@ -546,27 +545,6 @@ def get_literal_vals(typ: UnionType | Any) -> KeysView[Any]: |
546 | 545 | # -------------------------------------------------------------------------------- |
547 | 546 |
|
548 | 547 |
|
549 | | -@singledispatch |
550 | | -def elem_mul(x: _SupportedArray, y: _SupportedArray) -> _SupportedArray: |
551 | | - raise NotImplementedError |
552 | | - |
553 | | - |
554 | | -@elem_mul.register(np.ndarray) |
555 | | -@elem_mul.register(CSBase) |
556 | | -def _elem_mul_in_mem(x: _MemoryArray, y: _MemoryArray) -> _MemoryArray: |
557 | | - if isinstance(x, CSBase): |
558 | | - # returns coo_matrix, so cast back to input type |
559 | | - return type(x)(x.multiply(y)) |
560 | | - return x * y |
561 | | - |
562 | | - |
563 | | -@elem_mul.register(DaskArray) |
564 | | -def _elem_mul_dask(x: DaskArray, y: DaskArray) -> DaskArray: |
565 | | - import dask.array as da |
566 | | - |
567 | | - return da.map_blocks(elem_mul, x, y) |
568 | | - |
569 | | - |
570 | 548 | if TYPE_CHECKING: |
571 | 549 | Scaling_T = TypeVar("Scaling_T", DaskArray, np.ndarray) |
572 | 550 |
|
@@ -606,7 +584,7 @@ def axis_mul_or_truediv( |
606 | 584 | @axis_mul_or_truediv.register(CSBase) |
607 | 585 | def _( |
608 | 586 | X: CSBase, |
609 | | - scaling_array, |
| 587 | + scaling_array: np.ndarray, |
610 | 588 | axis: Literal[0, 1], |
611 | 589 | op: Callable[[Any, Any], Any], |
612 | 590 | *, |
@@ -746,78 +724,6 @@ def _(X: DaskArray, axis: Literal[0, 1]) -> DaskArray: |
746 | 724 | ) |
747 | 725 |
|
748 | 726 |
|
749 | | -@overload |
750 | | -def axis_sum( |
751 | | - X: _CSMatrix, |
752 | | - *, |
753 | | - axis: tuple[Literal[0, 1], ...] | Literal[0, 1] | None = None, |
754 | | - dtype: DTypeLike | None = None, |
755 | | -) -> np.matrix: ... |
756 | | - |
757 | | - |
758 | | -@overload |
759 | | -def axis_sum( |
760 | | - X: np.ndarray, # TODO: or sparray |
761 | | - *, |
762 | | - axis: tuple[Literal[0, 1], ...] | Literal[0, 1] | None = None, |
763 | | - dtype: DTypeLike | None = None, |
764 | | -) -> np.ndarray: ... |
765 | | - |
766 | | - |
767 | | -@singledispatch |
768 | | -def axis_sum( |
769 | | - X: np.ndarray | CSBase, |
770 | | - *, |
771 | | - axis: tuple[Literal[0, 1], ...] | Literal[0, 1] | None = None, |
772 | | - dtype: DTypeLike | None = None, |
773 | | -) -> np.ndarray | np.matrix: |
774 | | - return np.sum(X, axis=axis, dtype=dtype) |
775 | | - |
776 | | - |
777 | | -@axis_sum.register(DaskArray) |
778 | | -def _( |
779 | | - X: DaskArray, |
780 | | - *, |
781 | | - axis: tuple[Literal[0, 1], ...] | Literal[0, 1] | None = None, |
782 | | - dtype: DTypeLike | None = None, |
783 | | -) -> DaskArray: |
784 | | - import dask.array as da |
785 | | - |
786 | | - if dtype is None: |
787 | | - dtype = getattr(np.zeros(1, dtype=X.dtype).sum(), "dtype", object) |
788 | | - |
789 | | - if isinstance(X._meta, np.ndarray) and not isinstance(X._meta, np.matrix): |
790 | | - return X.sum(axis=axis, dtype=dtype) |
791 | | - |
792 | | - def sum_drop_keepdims(*args, **kwargs): |
793 | | - kwargs.pop("computing_meta", None) |
794 | | - # masked operations on sparse produce which numpy matrices gives the same API issues handled here |
795 | | - if isinstance(X._meta, _CSMatrix | np.matrix) or isinstance( |
796 | | - args[0], _CSMatrix | np.matrix |
797 | | - ): |
798 | | - kwargs.pop("keepdims", None) |
799 | | - axis = kwargs["axis"] |
800 | | - if isinstance(axis, tuple): |
801 | | - if len(axis) != 1: |
802 | | - msg = f"`axis_sum` can only sum over one axis when `axis` arg is provided but got {axis} instead" |
803 | | - raise ValueError(msg) |
804 | | - kwargs["axis"] = axis[0] |
805 | | - # returns a np.matrix normally, which is undesireable |
806 | | - return np.array(np.sum(*args, dtype=dtype, **kwargs)) |
807 | | - |
808 | | - def aggregate_sum(*args, **kwargs): |
809 | | - return np.sum(args[0], dtype=dtype, **kwargs) |
810 | | - |
811 | | - return da.reduction( |
812 | | - X, |
813 | | - sum_drop_keepdims, |
814 | | - aggregate_sum, |
815 | | - axis=axis, |
816 | | - dtype=dtype, |
817 | | - meta=np.array([], dtype=dtype), |
818 | | - ) |
819 | | - |
820 | | - |
821 | 727 | @singledispatch |
822 | 728 | def check_nonnegative_integers(X: _SupportedArray) -> bool | DaskArray: |
823 | 729 | """Check values of X to ensure it is count data.""" |
|
0 commit comments