Skip to content

Commit

Permalink
Merge pull request cupy#8346 from ev-br/np2.0_imports_2
Browse files Browse the repository at this point in the history
make CuPy import under NumPy 2.0
  • Loading branch information
takagi authored and chainer-ci committed Jun 7, 2024
1 parent 73e6c30 commit 187f965
Show file tree
Hide file tree
Showing 27 changed files with 272 additions and 71 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pretest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ jobs:
# TODO: Test against various NumPy/SciPy versions.
pip install 'numpy==1.25.*' 'scipy==1.10.*'
pip install 'mypy==1.5.*' 'types-setuptools==57.4.14' 'pytest>=7.2'
sed -i s/error::cupy.exceptions./error::numpy./ setup.cfg # Because no cupy here. See cupy#8346
pytest -v tests/typing_tests/test_typing.py
build-cuda:
Expand Down
207 changes: 184 additions & 23 deletions cupy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ def is_available():
from numpy import half # NOQA
from numpy import single # NOQA
from numpy import double # NOQA
from numpy import float_ # NOQA
from numpy import longfloat # NOQA
from numpy import float64 as float_ # NOQA
# from numpy import longfloat # NOQA # XXX
from numpy import float16 # NOQA
from numpy import float32 # NOQA
from numpy import float64 # NOQA
Expand All @@ -147,10 +147,10 @@ def is_available():
# Complex floating-point numbers
# -----------------------------------------------------------------------------
from numpy import csingle # NOQA
from numpy import singlecomplex # NOQA
from numpy import complex64 as singlecomplex # NOQA
from numpy import cdouble # NOQA
from numpy import cfloat # NOQA
from numpy import complex_ # NOQA
from numpy import complex128 as cfloat # NOQA
from numpy import complex128 as complex_ # NOQA
from numpy import complex64 # NOQA
from numpy import complex128 # NOQA

Expand Down Expand Up @@ -360,23 +360,16 @@ def result_type(*arrays_and_dtypes):

from cupy._core.core import min_scalar_type # NOQA

from numpy import obj2sctype # NOQA
from numpy import promote_types # NOQA

from numpy import dtype # NOQA
from numpy import format_parser # NOQA

from numpy import finfo # NOQA
from numpy import iinfo # NOQA

from numpy import find_common_type # NOQA
from numpy import issctype # NOQA
from numpy import issubclass_ # NOQA
from numpy import issubdtype # NOQA
from numpy import issubsctype # NOQA

from numpy import mintypecode # NOQA
from numpy import sctype2char # NOQA
from numpy import typename # NOQA

# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -422,7 +415,6 @@ def result_type(*arrays_and_dtypes):
from cupy._indexing.iterate import flatiter # NOQA

# Borrowed from NumPy
from numpy import get_array_wrap # NOQA
from numpy import index_exp # NOQA
from numpy import ndindex # NOQA
from numpy import s_ # NOQA
Expand Down Expand Up @@ -453,11 +445,9 @@ def base_repr(number, base=2, padding=0): # NOQA (needed to avoid redefinition


# Borrowed from NumPy
from numpy import DataSource # NOQA
from numpy import get_printoptions # NOQA
from numpy import set_printoptions # NOQA
from numpy import printoptions # NOQA
from numpy import set_string_function # NOQA


# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -545,7 +535,7 @@ def isscalar(element):
from cupy.lib._routines_poly import roots # NOQA

# Borrowed from NumPy
from numpy import RankWarning # NOQA
from cupy.exceptions import RankWarning # NOQA

# -----------------------------------------------------------------------------
# Mathematical functions
Expand Down Expand Up @@ -672,10 +662,8 @@ def isscalar(element):
from cupy._misc.who import who # NOQA

# Borrowed from NumPy
from numpy import disp # NOQA
from numpy import iterable # NOQA
from numpy import safe_eval # NOQA
from numpy import AxisError # NOQA
from cupy.exceptions import AxisError # NOQA


# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -743,10 +731,10 @@ def isscalar(element):
# -----------------------------------------------------------------------------
# Classes without their own docs
# -----------------------------------------------------------------------------
from numpy import ComplexWarning # NOQA
from numpy import ModuleDeprecationWarning # NOQA
from numpy import TooHardError # NOQA
from numpy import VisibleDeprecationWarning # NOQA
from cupy.exceptions import ComplexWarning # NOQA
from cupy.exceptions import ModuleDeprecationWarning # NOQA
from cupy.exceptions import TooHardError # NOQA
from cupy.exceptions import VisibleDeprecationWarning # NOQA


# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -920,6 +908,179 @@ def show_config(*, _full=False):
]


# np 2.0: XXX shims for things removed in np 2.0

# https://github.com/numpy/numpy/blob/v1.26.4/numpy/core/numerictypes.py#L283-L322 # NOQA
def issubclass_(arg1, arg2):
try:
return issubclass(arg1, arg2)
except TypeError:
return False

# https://github.com/numpy/numpy/blob/v1.26.0/numpy/core/numerictypes.py#L229-L280 # NOQA


def obj2sctype(rep, default=None):
"""
Return the scalar dtype or NumPy equivalent of Python type of an object.
Parameters
----------
rep : any
The object of which the type is returned.
default : any, optional
If given, this is returned for objects whose types can not be
determined. If not given, None is returned for those objects.
Returns
-------
dtype : dtype or Python type
The data type of `rep`.
"""
# prevent abstract classes being upcast
if isinstance(rep, type) and issubclass(rep, _numpy.generic):
return rep
# extract dtype from arrays
if isinstance(rep, _numpy.ndarray):
return rep.dtype.type
# fall back on dtype to convert
try:
res = _numpy.dtype(rep)
except Exception:
return default
else:
return res.type


# https://github.com/numpy/numpy/blob/v1.26.0/numpy/core/numerictypes.py#L326C1-L355C1 # NOQA
def issubsctype(arg1, arg2):
"""
Determine if the first argument is a subclass of the second argument.
Parameters
----------
arg1, arg2 : dtype or dtype specifier
Data-types.
Returns
-------
out : bool
The result.
"""
return issubclass(obj2sctype(arg1), obj2sctype(arg2))


# https://github.com/numpy/numpy/blob/v1.26.0/numpy/core/numerictypes.py#L457 # NOQA
def sctype2char(sctype):
"""
Return the string representation of a scalar dtype.
Parameters
----------
sctype : scalar dtype or object
If a scalar dtype, the corresponding string character is
returned. If an object, `sctype2char` tries to infer its scalar type
and then return the corresponding string character.
Returns
-------
typechar : str
The string character corresponding to the scalar type.
Raises
------
ValueError
If `sctype` is an object for which the type can not be inferred.
"""
sctype = obj2sctype(sctype)
if sctype is None:
raise ValueError("unrecognized type")
return _numpy.dtype(sctype).char


# https://github.com/numpy/numpy/blob/v1.26.0/numpy/core/numerictypes.py#L184 # NOQA
def issctype(rep):
"""
Determines whether the given object represents a scalar data-type.
Parameters
----------
rep : any
If `rep` is an instance of a scalar dtype, True is returned. If not,
False is returned.
Returns
-------
out : bool
Boolean result of check whether `rep` is a scalar dtype.
"""
if not isinstance(rep, (type, _numpy.dtype)):
return False
try:
res = obj2sctype(rep)
if res and res != _numpy.object_:
return True
return False
except Exception:
return False


# np 2.0: XXX shims for things moved in np 2.0
if _numpy.__version__ < "2":
from numpy import format_parser # NOQA
from numpy import DataSource # NOQA
else:
from numpy.rec import format_parser # type: ignore [no-redef] # NOQA
from numpy.lib.npyio import DataSource # NOQA


# np 2.0: XXX shims for things removed without replacement
if _numpy.__version__ < "2":
from numpy import find_common_type # NOQA
from numpy import set_string_function # NOQA
from numpy import get_array_wrap # NOQA
from numpy import disp # NOQA
from numpy import safe_eval # NOQA
else:

_template = '''\
''This function has been removed in NumPy v2.
Use {recommendation} instead.
CuPy has been providing this function as an alias to the NumPy
implementation, so it cannot be used in environments with NumPy
v2 installed. If you rely on this function and you cannot modify
the code to use {recommendation}, please downgrade NumPy to v1.26
or earlier.
'''

def find_common_type(*args, **kwds):
mesg = _template.format(
recommendation='`promote_types` or `result_type`'
)
raise RuntimeError(mesg)

def set_string_function(*args, **kwds): # type: ignore [misc]
mesg = _template.format(recommendation='`np.set_printoptions`')
raise RuntimeError(mesg)

def get_array_wrap(*args, **kwds): # type: ignore [no-redef]
mesg = _template.format(recommendation="<no replacement>")
raise RuntimeError(mesg)

def disp(*args, **kwds): # type: ignore [misc]
mesg = _template.format(recommendation="your own print function")
raise RuntimeError(mesg)

def safe_eval(*args, **kwds): # type: ignore [misc]
mesg = _template.format(recommendation="`ast.literal_eval`")
raise RuntimeError(mesg)


def __getattr__(name):
if name in _deprecated_apis:
return getattr(_numpy, name)
Expand Down
3 changes: 2 additions & 1 deletion cupy/_core/_routines_indexing.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import numpy

import cupy
import cupy._core.core as core
from cupy.exceptions import AxisError
from cupy._core._kernel import ElementwiseKernel, _get_warpsize
from cupy._core._ufuncs import elementwise_copy

Expand Down Expand Up @@ -1054,7 +1055,7 @@ cdef _ndarray_base _diagonal(
Py_ssize_t axis2=1):
cdef Py_ssize_t ndim = a.ndim
if not (-ndim <= axis1 < ndim and -ndim <= axis2 < ndim):
raise numpy.AxisError(
raise AxisError(
'axis1(={0}) and axis2(={1}) must be within range '
'(ndim={2})'.format(axis1, axis2, ndim))

Expand Down
3 changes: 2 additions & 1 deletion cupy/_core/_routines_manipulation.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import numpy
from cupy._core._kernel import ElementwiseKernel
from cupy._core._ufuncs import elementwise_copy
import cupy._core.core as core
from cupy.exceptions import AxisError

cimport cpython # NOQA
cimport cython # NOQA
Expand Down Expand Up @@ -390,7 +391,7 @@ cpdef _ndarray_base _transpose(
for i in range(axes_size):
axis = axes[i]
if axis < -ndim or axis >= ndim:
raise numpy.AxisError(axis, ndim)
raise AxisError(axis, ndim)
axis %= ndim
a_axes.push_back(axis)
if axis_flags[axis]:
Expand Down
9 changes: 5 additions & 4 deletions cupy/_core/_routines_sorting.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import string
import numpy

import cupy
from cupy.exceptions import AxisError
from cupy._core._scalar import get_typename as _get_typename
from cupy._core._ufuncs import elementwise_copy
import cupy._core.core as core
Expand All @@ -25,8 +26,8 @@ cdef _ndarray_sort(_ndarray_base self, int axis):
'reinstall CuPy after uninstalling it.')

if ndim == 0:
raise numpy.AxisError('Sorting arrays with the rank of zero is not '
'supported') # as numpy.sort() raises
raise AxisError('Sorting arrays with the rank of zero is not '
'supported') # as numpy.sort() raises

# TODO(takagi): Support sorting views
if not self._c_contiguous:
Expand Down Expand Up @@ -128,8 +129,8 @@ cdef _ndarray_partition(_ndarray_base self, kth, int axis):
cdef _ndarray_base data

if ndim == 0:
raise numpy.AxisError('Sorting arrays with the rank of zero is not '
'supported')
raise AxisError('Sorting arrays with the rank of zero is not '
'supported')

if not self._c_contiguous:
raise NotImplementedError('Sorting non-contiguous array is not '
Expand Down
3 changes: 2 additions & 1 deletion cupy/_core/_routines_statistics.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import numpy
from numpy import nan

import cupy
from cupy.exceptions import AxisError
from cupy._core import _reduction
from cupy._core._reduction import create_reduction_func
from cupy._core._reduction import ReductionKernel
Expand Down Expand Up @@ -412,7 +413,7 @@ cpdef _ndarray_base _median(
sz = a.size
else:
if axis < -keep_ndim or axis >= keep_ndim:
raise numpy.AxisError('Axis overrun')
raise AxisError('Axis overrun')
sz = a.shape[axis]
if sz % 2 == 0:
szh = sz // 2
Expand Down
6 changes: 5 additions & 1 deletion cupy/_core/core.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ from cupy_backends.cuda cimport stream as _stream_module
from cupy_backends.cuda.api cimport runtime


NUMPY_1x = numpy.__version__ < '2'


# If rop of cupy.ndarray is called, cupy's op is the last chance.
# If op of cupy.ndarray is called and the `other` is cupy.ndarray, too,
# it is safe to call cupy's op.
Expand Down Expand Up @@ -2538,7 +2541,8 @@ cdef _ndarray_base _array_default(
order = 'F'
else:
order = 'C'
a_cpu = numpy.array(obj, dtype=dtype, copy=False, order=order,
copy = False if NUMPY_1x else None
a_cpu = numpy.array(obj, dtype=dtype, copy=copy, order=order,
ndmin=ndmin)
if a_cpu.dtype.char not in '?bhilqBHILQefdFD':
raise ValueError('Unsupported dtype %s' % a_cpu.dtype)
Expand Down
Loading

0 comments on commit 187f965

Please sign in to comment.