From 9f04d9817b20f3dc600555c4f57bf7a39863bb1c Mon Sep 17 00:00:00 2001
From: Jake Vanderplas
Date: Thu, 7 May 2020 13:17:43 -0700
Subject: [PATCH] Implement np.bincount (#2986)

Adds jax.numpy.bincount with an extra keyword-only `length` argument that
fixes the output size so the function can be compiled. Negative inputs are
clamped to zero, and empty input yields `minlength` zeros as in NumPy.

---
 docs/jax.numpy.rst      |  1 +
 jax/numpy/__init__.py   |  2 +-
 jax/numpy/lax_numpy.py  | 27 ++++++++++++++++++++
 tests/lax_numpy_test.py | 56 +++++++++++++++++++++++++++++++++++++++++
 4 files changed, 85 insertions(+), 1 deletion(-)

diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst
index 737d9986fdd0..5b7d660628c2 100644
--- a/docs/jax.numpy.rst
+++ b/docs/jax.numpy.rst
@@ -63,6 +63,7 @@ Not every function in NumPy is implemented; contributions are welcome!
     atleast_2d
     atleast_3d
     bartlett
+    bincount
     bitwise_and
     bitwise_not
     bitwise_or
diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py
index 2a54a904fee7..7b814ee4ce21 100644
--- a/jax/numpy/__init__.py
+++ b/jax/numpy/__init__.py
@@ -22,7 +22,7 @@
     alltrue, amax, amin, angle, any, append, arange, arccos, arccosh, arcsin,
     arcsinh, arctan, arctan2, arctanh, argmax, argmin, argsort, around, array,
     array_equal, array_repr, array_str, asarray, atleast_1d, atleast_2d,
-    atleast_3d, average, bartlett, bfloat16, bitwise_and, bitwise_not,
+    atleast_3d, average, bartlett, bfloat16, bincount, bitwise_and, bitwise_not,
     bitwise_or, bitwise_xor, blackman, block, bool_, broadcast_arrays,
     broadcast_to, can_cast, cbrt, cdouble, ceil, character, clip, column_stack,
     complex128, complex64, complex_, complexfloating, concatenate, conj,
diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py
index 5b804f7ba6d8..5744681d7a7b 100644
--- a/jax/numpy/lax_numpy.py
+++ b/jax/numpy/lax_numpy.py
@@ -1242,6 +1242,33 @@ def select(condlist, choicelist, default=0):
   return output
 
 
+@_wraps(onp.bincount, lax_description="""\
+Jax adds the optional `length` parameter which specifies the output length, and
+defaults to ``x.max() + 1``. It must be specified for bincount to be compilable.
+Values greater than or equal to the specified length will be discarded.
+
+Additionally, while ``np.bincount`` raises an error if the input array contains
+negative values, ``jax.numpy.bincount`` treats negative values as zero.
+""")
+def bincount(x, weights=None, minlength=0, *, length=None):
+  # Use _dtype(x) in the message: x may be a list with no .dtype attribute.
+  if not issubdtype(_dtype(x), integer):
+    msg = f"x argument to bincount must have an integer type; got {_dtype(x)}"
+    raise TypeError(msg)
+  # Validate the rank before reducing over x.
+  if ndim(x) != 1:
+    raise ValueError("only 1-dimensional input supported.")
+  if length is None:
+    # max() of an empty array is an error; np.bincount returns `minlength` zeros.
+    length = max(x) + 1 if shape(x)[0] else 0
+  length = _max(length, minlength)
+  if weights is None:
+    weights = array(1, dtype=int32)
+  elif shape(x) != shape(weights):
+    raise ValueError("shape of weights must match shape of x.")
+  return ops.index_add(zeros((length,), _dtype(weights)), ops.index[clip(x, 0)], weights)
+
+
 def broadcast_arrays(*args):
   """Like Numpy's broadcast_arrays but doesn't return views."""
   shapes = [shape(arg) for arg in args]
diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py
index 7eb2a0c506fe..627fc78bc5e7 100644
--- a/tests/lax_numpy_test.py
+++ b/tests/lax_numpy_test.py
@@ -2525,6 +2525,62 @@ def testAtLeastNdLiterals(self, pytype, dtype, op):
     self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True)
     self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
 
+  @parameterized.named_parameters(jtu.cases_from_list(
+    {
+      "testcase_name": "_shape={}_dtype={}_weights={}_minlength={}_length={}".format(
+        shape, dtype, weights, minlength, length
+      ),
+      "shape": shape,
+      "dtype": dtype,
+      "weights": weights,
+      "minlength": minlength,
+      "length": length,
+      "rng_factory": rng_factory}
+    for shape in [(5,), (10,)]
+    for dtype in int_dtypes
+    for weights in [True, False]
+    for minlength in [0, 20]
+    for length in [None, 10]
+    for rng_factory in [jtu.rand_positive]
+  ))
+  def testBincount(self, shape, dtype, weights, minlength, length, rng_factory):
+    rng = rng_factory(self.rng())
+    args_maker = lambda: (rng(shape, dtype), (rng(shape, 'float32') if weights else None))
+
+    onp_fun = partial(onp.bincount, minlength=minlength)
+    jnp_fun = partial(jnp.bincount, minlength=minlength, length=length)
+
+    if length is None:
+      # Without a static length the output shape depends on the data, so the
+      # function cannot be compiled; compare against NumPy instead.
+      self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=False)
+    else:
+      self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
+
+  def testBincountNegative(self):
+    # Test that jnp.bincount treats negative values as zero.
+    x_rng = jtu.rand_int(self.rng(), -100, 100)
+    w_rng = jtu.rand_uniform(self.rng())
+    shape = (1000,)
+    x = x_rng(shape, 'int32')
+    w = w_rng(shape, 'float32')
+
+    xn = onp.array(x)
+    xn[xn < 0] = 0
+    wn = onp.array(w)
+    onp_result = onp.bincount(xn, wn)
+    jnp_result = jnp.bincount(x, w)
+    self.assertAllClose(onp_result, jnp_result, check_dtypes=False)
+
+  def testBincountEmpty(self):
+    # np.bincount of an empty array returns `minlength` zeros.
+    x = onp.array([], dtype='int32')
+    for minlength in [0, 5]:
+      onp_result = onp.bincount(x, minlength=minlength)
+      jnp_result = jnp.bincount(x, minlength=minlength)
+      self.assertAllClose(onp_result, jnp_result, check_dtypes=False)
+
+
   @parameterized.named_parameters(*jtu.cases_from_list(
       {"testcase_name": "_case={}".format(i), "input": input}