Skip to content

Commit

Permalink
Merge pull request cupy#8950 from grlee77/grelee/restore-cub-histogra…
Browse files Browse the repository at this point in the history
…m-and-bincount

restore CUB histogram and bincount
  • Loading branch information
kmaehashi authored and chainer-ci committed Feb 22, 2025
1 parent f5e36bf commit 54de19a
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 11 deletions.
64 changes: 56 additions & 8 deletions cupy/_statistics/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@

import cupy
from cupy import _core
from cupy._core import _accelerator
from cupy.cuda import cub
from cupy.cuda import common
from cupy.cuda import runtime


# rename builtin range for use in functions that take a range argument
Expand Down Expand Up @@ -216,10 +220,43 @@ def histogram(x, bins=10, range=None, weights=None, density=False):

if weights is None:
y = cupy.zeros(bin_edges.size - 1, dtype=cupy.int64)
# TODO(leofang): we temporarily remove CUB histogram support for now,
# see cupy/cupy#7698. When it's ready, revert the commit that checked
# in this comment to restore the support.
_histogram_kernel(x, bin_edges, bin_edges.size, y)
for accelerator in _accelerator.get_routine_accelerators():
# CUB uses int for bin counts
# TODO(leofang): support >= 2^31 elements in x?
if (accelerator == _accelerator.ACCELERATOR_CUB
and x.size <= 0x7fffffff and bin_edges.size <= 0x7fffffff):
# Need to ensure the dtype of bin_edges as it's needed for both
# the CUB call and the correction later
assert isinstance(bin_edges, cupy.ndarray)
if numpy.issubdtype(x.dtype, numpy.integer):
bin_type = float
else:
bin_type = numpy.result_type(bin_edges.dtype, x.dtype)
if (bin_type == numpy.float16 and
not common._is_fp16_supported()):
bin_type = numpy.float32
x = x.astype(bin_type, copy=False)
acc_bin_edge = bin_edges.astype(bin_type, copy=True)
# CUB's upper bin boundary is exclusive for all bins, including
# the last bin, so we must shift it to comply with NumPy
if x.dtype.kind in 'ui':
acc_bin_edge[-1] += 1
elif x.dtype.kind == 'f':
last = acc_bin_edge[-1]
acc_bin_edge[-1] = cupy.nextafter(last, last + 1)
if runtime.is_hip:
y = y.astype(cupy.uint64, copy=False)
out = cub.cub_histogram(x, y, acc_bin_edge)
if out is None:
# fallback to CuPy impl
continue
else:
y = out
if runtime.is_hip:
y = y.astype(cupy.int64, copy=False)
break
else:
_histogram_kernel(x, bin_edges, bin_edges.size, y)
else:
simple_weights = (
cupy.can_cast(weights.dtype, cupy.float64) or
Expand Down Expand Up @@ -519,10 +556,21 @@ def bincount(x, weights=None, minlength=None):

if weights is None:
b = cupy.zeros((size,), dtype=numpy.intp)
# TODO(leofang): we temporarily remove CUB histogram support for now,
# see cupy/cupy#7698. When it's ready, revert the commit that checked
# in this comment to restore the support.
_bincount_kernel(x, b)

for accelerator in _accelerator.get_routine_accelerators():
# CUB uses int for bin counts
# TODO(leofang): support >= 2^31 elements in x?
if (not runtime.is_hip
and accelerator == _accelerator.ACCELERATOR_CUB
and x.size <= 0x7fffffff and size <= 0x7fffffff):
out = cub.cub_histogram(x, b, size+1)
if out is None:
continue
else:
b = out
break
else:
_bincount_kernel(x, b)
else:
b = cupy.zeros((size,), dtype=numpy.float64)
_bincount_with_weight_kernel(x, weights, b)
Expand Down
102 changes: 99 additions & 3 deletions tests/cupy_tests/statistics_tests/test_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import cupy
from cupy import testing
from cupy._core import _accelerator


# Note that numpy.bincount does not support uint64 on 64-bit environment
Expand Down Expand Up @@ -328,9 +329,104 @@ def test_bincount_too_small_minlength(self, dtype):
xp.bincount(x, minlength=-1)


# TODO(leofang): we temporarily remove CUB histogram support for now,
# see cupy/cupy#7698. When it's ready, revert the commit that checked
# in this comment to restore the support.
# This class compares CUB results against NumPy's
@unittest.skipUnless(cupy.cuda.cub.available, 'The CUB routine is not enabled')
class TestCubHistogram(unittest.TestCase):
    """Compare ``histogram``/``bincount`` against NumPy with CUB enabled.

    ``setUp`` switches the routine accelerators to ``['cub']`` and
    ``tearDown`` restores whatever was configured before. Most tests first
    assert that ``cub.device_histogram`` is actually invoked, then return
    the result of a second call for the NumPy/CuPy equality decorator.
    """

    # Dotted path handed to ``testing.AssertFunctionIsCalled``.
    _CUB_FUNC = 'cupy._statistics.histogram.cub.device_histogram'

    def setUp(self):
        # Remember the current setting so that tearDown can put it back.
        self.old_accelerators = _accelerator.get_routine_accelerators()
        _accelerator.set_routine_accelerators(['cub'])

    def tearDown(self):
        _accelerator.set_routine_accelerators(self.old_accelerators)

    def _run_expecting_cub(self, func, *args):
        """Call ``func(*args)`` twice; the first call must hit CUB.

        The first invocation runs under ``AssertFunctionIsCalled`` and its
        result is discarded; the result of the second invocation is
        returned to the caller.
        """
        with testing.AssertFunctionIsCalled(self._CUB_FUNC):
            func(*args)
        return func(*args)

    @testing.for_all_dtypes(no_bool=True, no_complex=True)
    @testing.numpy_cupy_array_equal()
    def test_histogram(self, xp, dtype):
        """Default-bin histogram of a 10-element ``shaped_arange``."""
        data = testing.shaped_arange((10,), xp, dtype)

        if xp is numpy:
            return xp.histogram(data)

        # xp is cupy: make sure CUB is used, then compute for real
        return self._run_expecting_cub(xp.histogram, data)

    @testing.for_all_dtypes(no_bool=True, no_complex=True)
    @testing.numpy_cupy_array_equal()
    def test_histogram_range_float(self, xp, dtype):
        """Histogram with float64 bin edges; all 10 samples are counted."""
        data = testing.shaped_arange((10,), xp, dtype)
        edges = testing.shaped_arange((10,), xp, numpy.float64)
        counts, out_edges = xp.histogram(data, edges)
        assert int(counts.sum()) == 10
        return counts, out_edges

    @testing.for_all_dtypes_combination(['dtype_a', 'dtype_b'],
                                        no_bool=True, no_complex=True)
    @testing.numpy_cupy_array_equal()
    def test_histogram_with_bins(self, xp, dtype_a, dtype_b):
        """Compare the counts (element 0) for explicit bin edges."""
        data = testing.shaped_arange((10,), xp, dtype_a)
        edges = testing.shaped_arange((4,), xp, dtype_b)

        if xp is numpy:
            return xp.histogram(data, edges)[0]

        # xp is cupy: make sure CUB is used, then compute for real
        return self._run_expecting_cub(xp.histogram, data, edges)[0]

    @testing.for_all_dtypes_combination(['dtype_a', 'dtype_b'],
                                        no_bool=True, no_complex=True)
    @testing.numpy_cupy_array_equal()
    def test_histogram_with_bins2(self, xp, dtype_a, dtype_b):
        """Compare the returned bin edges (element 1) for explicit bins."""
        data = testing.shaped_arange((10,), xp, dtype_a)
        edges = testing.shaped_arange((4,), xp, dtype_b)

        if xp is numpy:
            return xp.histogram(data, edges)[1]

        # xp is cupy: make sure CUB is used, then compute for real
        return self._run_expecting_cub(xp.histogram, data, edges)[1]

    @testing.slow
    @testing.numpy_cupy_array_equal()
    def test_no_oom(self, xp):
        """Large-input histogram with as many bins as samples."""
        # ensure the workaround for NVIDIA/cub#613 kicks in
        upper = 28854312
        samples = xp.linspace(0, upper, num=upper,
                              endpoint=True, retstep=False, dtype=xp.int32)
        result = xp.histogram(samples, bins=upper, range=[0, upper])
        return result

    @testing.for_int_dtypes('dtype', no_bool=True)
    @testing.numpy_cupy_array_equal()
    def test_bincount_gh7698(self, xp, dtype):
        """Regression test for https://github.com/cupy/cupy/issues/7698."""
        dtype = xp.dtype(dtype)
        # Types narrower than 4 bytes go up to their own maximum; wider
        # ones are capped at 65536.
        if dtype.itemsize < 4:
            max_val = xp.iinfo(dtype).max
        else:
            max_val = 65536
        if dtype == xp.uint64:
            pytest.skip("only numpy raises exception on uint64 input")

        values = xp.arange(max_val, dtype=dtype)

        if xp is numpy:
            return xp.bincount(values)

        # xp is cupy: make sure CUB is used, then compute for real
        return self._run_expecting_cub(xp.bincount, values)


@testing.parameterize(*testing.product(
Expand Down

0 comments on commit 54de19a

Please sign in to comment.