Skip to content

Commit

Permalink
Add np.triu_indices_from function (jax-ml#3346)
Browse files Browse the repository at this point in the history
  • Loading branch information
ayush-1506 authored Jun 10, 2020
1 parent b3c348c commit 832431d
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 2 deletions.
2 changes: 2 additions & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,10 @@ Not every function in NumPy is implemented; contributions are welcome!
tri
tril
tril_indices
tril_indices_from
triu
triu_indices
triu_indices_from
true_divide
trunc
unique
Expand Down
4 changes: 2 additions & 2 deletions jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@
save, savez, searchsorted, select, set_printoptions, shape, sign, signbit,
signedinteger, sin, sinc, single, sinh, size, sometrue, sort, split, sqrt,
square, squeeze, stack, std, subtract, sum, swapaxes, take, take_along_axis,
tan, tanh, tensordot, tile, trace, trapz, transpose, tri, tril, tril_indices,
triu, triu_indices, true_divide, trunc, uint16, uint32, uint64, uint8, unique,
tan, tanh, tensordot, tile, trace, trapz, transpose, tri, tril, tril_indices, tril_indices_from,
triu, triu_indices, triu_indices_from, true_divide, trunc, uint16, uint32, uint64, uint8, unique,
unpackbits, unravel_index, unsignedinteger, vander, var, vdot, vsplit,
vstack, where, zeros, zeros_like)

Expand Down
11 changes: 11 additions & 0 deletions jax/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2542,6 +2542,17 @@ def wrapper(*args, **kwargs):
triu_indices = _wrap_indices_function(np.triu_indices)
mask_indices = _wrap_indices_function(np.mask_indices)


@_wraps(np.triu_indices_from)
def triu_indices_from(arr, k=0):
return triu_indices(arr.shape[-2], k=k, m=arr.shape[-1])


@_wraps(np.tril_indices_from)
def tril_indices_from(arr, k=0):
return tril_indices(arr.shape[-2], k=k, m=arr.shape[-1])


@_wraps(np.diag_indices)
def diag_indices(n, ndim=2):
if n < 0:
Expand Down
52 changes: 52 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1481,6 +1481,58 @@ def testTriLU(self, dtype, shape, op, k, rng_factory):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "n={}_k={}_m={}".format(n, k, m),
"n": n, "k": k, "m": m}
for n in range(1, 5)
for k in [-1, 0, 1]
for m in range(1, 5)))
def testTrilIndices(self, n, k, m):
np_fun = lambda n, k, m: np.tril_indices(n, k=k, m=m)
jnp_fun = lambda n, k, m: jnp.tril_indices(n, k=k, m=m)
args_maker = lambda: [n, k, m]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "n={}_k={}_m={}".format(n, k, m),
"n": n, "k": k, "m": m}
for n in range(1, 5)
for k in [-1, 0, 1]
for m in range(1, 5)))
def testTriuIndices(self, n, k, m):
np_fun = lambda n, k, m: np.triu_indices(n, k=k, m=m)
jnp_fun = lambda n, k, m: jnp.triu_indices(n, k=k, m=m)
args_maker = lambda: [n, k, m]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_k={}".format(
jtu.format_shape_dtype_string(shape, dtype), k),
"dtype": dtype, "shape": shape, "k": k, "rng_factory": jtu.rand_default}
for dtype in default_dtypes
for shape in [(1,1), (1,2), (2,2), (2,3), (3,2), (3,3), (4,4)]
for k in [-1, 0, 1]))
def testTriuIndicesFrom(self, shape, dtype, k, rng_factory):
rng = rng_factory(self.rng())
np_fun = lambda arr, k: np.triu_indices_from(arr, k=k)
jnp_fun = lambda arr, k: jnp.triu_indices_from(arr, k=k)
args_maker = lambda: [rng(shape, dtype), k]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_k={}".format(
jtu.format_shape_dtype_string(shape, dtype), k),
"dtype": dtype, "shape": shape, "k": k, "rng_factory": jtu.rand_default}
for dtype in default_dtypes
for shape in [(1,1), (1,2), (2,2), (2,3), (3,2), (3,3), (4,4)]
for k in [-1, 0, 1]))
def testTrilIndicesFrom(self, shape, dtype, k, rng_factory):
rng = rng_factory(self.rng())
np_fun = lambda arr, k: np.tril_indices_from(arr, k=k)
jnp_fun = lambda arr, k: jnp.tril_indices_from(arr, k=k)
args_maker = lambda: [rng(shape, dtype), k]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_ndim={}_n={}".format(ndim, n),
"ndim": ndim, "n": n}
Expand Down

0 comments on commit 832431d

Please sign in to comment.