From 832431db9eea95f2212370b47bdd0e6e891fcd0a Mon Sep 17 00:00:00 2001 From: Ayush Shridhar Date: Thu, 11 Jun 2020 02:57:35 +1000 Subject: [PATCH] Add np.triu_indices_from function (#3346) --- docs/jax.numpy.rst | 2 ++ jax/numpy/__init__.py | 4 ++-- jax/numpy/lax_numpy.py | 11 +++++++++ tests/lax_numpy_test.py | 52 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 67 insertions(+), 2 deletions(-) diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index 758d80d8405b..27117ba540ec 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -268,8 +268,10 @@ Not every function in NumPy is implemented; contributions are welcome! tri tril tril_indices + tril_indices_from triu triu_indices + triu_indices_from true_divide trunc unique diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index bdf7aea83cfa..cedd5bddbc47 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -54,8 +54,8 @@ save, savez, searchsorted, select, set_printoptions, shape, sign, signbit, signedinteger, sin, sinc, single, sinh, size, sometrue, sort, split, sqrt, square, squeeze, stack, std, subtract, sum, swapaxes, take, take_along_axis, - tan, tanh, tensordot, tile, trace, trapz, transpose, tri, tril, tril_indices, - triu, triu_indices, true_divide, trunc, uint16, uint32, uint64, uint8, unique, + tan, tanh, tensordot, tile, trace, trapz, transpose, tri, tril, tril_indices, tril_indices_from, + triu, triu_indices, triu_indices_from, true_divide, trunc, uint16, uint32, uint64, uint8, unique, unpackbits, unravel_index, unsignedinteger, vander, var, vdot, vsplit, vstack, where, zeros, zeros_like) diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index 5eb81f69b946..fa49f7e4dece 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -2542,6 +2542,17 @@ def wrapper(*args, **kwargs): triu_indices = _wrap_indices_function(np.triu_indices) mask_indices = _wrap_indices_function(np.mask_indices) + +@_wraps(np.triu_indices_from) +def triu_indices_from(arr, k=0): + return triu_indices(arr.shape[-2], k=k, m=arr.shape[-1]) + + +@_wraps(np.tril_indices_from) +def tril_indices_from(arr, k=0): + return tril_indices(arr.shape[-2], k=k, m=arr.shape[-1]) + + @_wraps(np.diag_indices) def diag_indices(n, ndim=2): if n < 0: diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 69d6d68e4b4f..1cabe46f9869 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -1481,6 +1481,58 @@ def testTriLU(self, dtype, shape, op, k, rng_factory): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "n={}_k={}_m={}".format(n, k, m), + "n": n, "k": k, "m": m} + for n in range(1, 5) + for k in [-1, 0, 1] + for m in range(1, 5))) + def testTrilIndices(self, n, k, m): + np_fun = lambda n, k, m: np.tril_indices(n, k=k, m=m) + jnp_fun = lambda n, k, m: jnp.tril_indices(n, k=k, m=m) + args_maker = lambda: [n, k, m] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "n={}_k={}_m={}".format(n, k, m), + "n": n, "k": k, "m": m} + for n in range(1, 5) + for k in [-1, 0, 1] + for m in range(1, 5))) + def testTriuIndices(self, n, k, m): + np_fun = lambda n, k, m: np.triu_indices(n, k=k, m=m) + jnp_fun = lambda n, k, m: jnp.triu_indices(n, k=k, m=m) + args_maker = lambda: [n, k, m] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_k={}".format( + jtu.format_shape_dtype_string(shape, dtype), k), + "dtype": dtype, "shape": shape, "k": k, "rng_factory": jtu.rand_default} + for dtype in default_dtypes + for shape in [(1,1), (1,2), (2,2), (2,3), (3,2), (3,3), (4,4)] + for k in [-1, 0, 1])) + def testTriuIndicesFrom(self, shape, dtype, k, rng_factory): + rng = rng_factory(self.rng()) + np_fun = lambda arr, k: np.triu_indices_from(arr, k=k) + jnp_fun = lambda arr, k: jnp.triu_indices_from(arr, k=k) + args_maker = lambda: [rng(shape, dtype), k] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_k={}".format( + jtu.format_shape_dtype_string(shape, dtype), k), + "dtype": dtype, "shape": shape, "k": k, "rng_factory": jtu.rand_default} + for dtype in default_dtypes + for shape in [(1,1), (1,2), (2,2), (2,3), (3,2), (3,3), (4,4)] + for k in [-1, 0, 1])) + def testTrilIndicesFrom(self, shape, dtype, k, rng_factory): + rng = rng_factory(self.rng()) + np_fun = lambda arr, k: np.tril_indices_from(arr, k=k) + jnp_fun = lambda arr, k: jnp.tril_indices_from(arr, k=k) + args_maker = lambda: [rng(shape, dtype), k] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_ndim={}_n={}".format(ndim, n), "ndim": ndim, "n": n}