Skip to content

Commit

Permalink
Implement np.bincount (jax-ml#2986)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp authored May 7, 2020
1 parent d679ccd commit 9f04d98
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ Not every function in NumPy is implemented; contributions are welcome!
atleast_2d
atleast_3d
bartlett
bincount
bitwise_and
bitwise_not
bitwise_or
Expand Down
2 changes: 1 addition & 1 deletion jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
alltrue, amax, amin, angle, any, append, arange, arccos, arccosh, arcsin,
arcsinh, arctan, arctan2, arctanh, argmax, argmin, argsort, around, array,
array_equal, array_repr, array_str, asarray, atleast_1d, atleast_2d,
atleast_3d, average, bartlett, bfloat16, bitwise_and, bitwise_not,
atleast_3d, average, bartlett, bfloat16, bincount, bitwise_and, bitwise_not,
bitwise_or, bitwise_xor, blackman, block, bool_, broadcast_arrays,
broadcast_to, can_cast, cbrt, cdouble, ceil, character, clip, column_stack,
complex128, complex64, complex_, complexfloating, concatenate, conj,
Expand Down
25 changes: 25 additions & 0 deletions jax/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1242,6 +1242,31 @@ def select(condlist, choicelist, default=0):
return output


@_wraps(onp.bincount, lax_description="""\
Jax adds the optional `length` parameter which specifies the output length, and
defaults to ``x.max() + 1``. It must be specified for bincount to be compilable.
Values larger than the specified length will be discarded.
Additionally, while ``np.bincount`` raises an error if the input array contains
negative values, ``jax.numpy.bincount`` treats negative values as zero.
""")
def bincount(x, weights=None, minlength=0, *, length=None):
if not issubdtype(_dtype(x), integer):
msg = f"x argument to bincount must have an integer type; got {x.dtype}"
raise TypeError(msg)
if length is None:
length = max(x) + 1
length = _max(length, minlength)
if ndim(x) != 1:
raise ValueError("only 1-dimensional input supported.")
if weights is None:
weights = array(1, dtype=int32)
else:
if shape(x) != shape(weights):
raise ValueError("shape of weights must match shape of x.")
return ops.index_add(zeros((length,), _dtype(weights)), ops.index[clip(x, 0)], weights)


def broadcast_arrays(*args):
"""Like Numpy's broadcast_arrays but doesn't return views."""
shapes = [shape(arg) for arg in args]
Expand Down
46 changes: 46 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2525,6 +2525,52 @@ def testAtLeastNdLiterals(self, pytype, dtype, op):
self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)

@parameterized.named_parameters(jtu.cases_from_list(
{
"testcase_name": "_shape={}_dtype={}_weights={}_minlength={}_length={}".format(
shape, dtype, weights, minlength, length
),
"shape": shape,
"dtype": dtype,
"weights": weights,
"minlength": minlength,
"length": length,
"rng_factory": rng_factory}
for shape in [(5,), (10,)]
for dtype in int_dtypes
for weights in [True, False]
for minlength in [0, 20]
for length in [None, 10]
for rng_factory in [jtu.rand_positive]
))
def testBincount(self, shape, dtype, weights, minlength, length, rng_factory):
rng = rng_factory(self.rng())
args_maker = lambda: (rng(shape, dtype), (rng(shape, 'float32') if weights else None))

onp_fun = partial(onp.bincount, minlength=minlength)
jnp_fun = partial(jnp.bincount, minlength=minlength, length=length)

if length is not None:
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
if length is None:
self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=False)

def testBincountNegative(self):
# Test that jnp.bincount ignores negative values.
x_rng = jtu.rand_int(self.rng(), -100, 100)
w_rng = jtu.rand_uniform(self.rng())
shape = (1000,)
x = x_rng(shape, 'int32')
w = w_rng(shape, 'float32')

xn = onp.array(x)
xn[xn < 0] = 0
wn = onp.array(w)
onp_result = onp.bincount(xn[xn >= 0], wn[xn >= 0])
jnp_result = jnp.bincount(x, w)
self.assertAllClose(onp_result, jnp_result, check_dtypes=False)


@parameterized.named_parameters(*jtu.cases_from_list(
{"testcase_name": "_case={}".format(i),
"input": input}
Expand Down

0 comments on commit 9f04d98

Please sign in to comment.