Skip to content

Commit 5bbb449

Browse files
author
jax authors
committed
Merge pull request jax-ml#5777 from google:issue5776
PiperOrigin-RevId: 358211915
2 parents 8ad5118 + 5a97eab commit 5bbb449

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

jax/api_util.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,11 @@ def apply_flat_fun_nokwargs(fun, io_tree, py_args):
7575
@lu.transformation_with_aux
7676
def flatten_fun_nokwargs2(in_tree, *args_flat):
7777
py_args = tree_unflatten(in_tree, args_flat)
78-
ans, aux = yield py_args, {}
78+
pair = yield py_args, {}
79+
if not isinstance(pair, (list, tuple)) or len(pair) != 2:
80+
raise TypeError("expected function with aux output to return a two-element "
81+
f"tuple, but got type {type(pair)} with value {repr(pair)}")
82+
ans, aux = pair
7983
ans_flat, ans_tree = tree_flatten(ans)
8084
aux_flat, aux_tree = tree_flatten(aux)
8185
yield (ans_flat, aux_flat), (ans_tree, aux_tree)

tests/api_test.py

+11
Original file line numberDiff line numberDiff line change
@@ -876,6 +876,16 @@ def test_grad_and_aux_basic(self):
876876
self.assertAllClose(g, grad(lambda x: x**3)(3.))
877877
self.assertAllClose(aux, [9.], check_dtypes=False)
878878

879+
def test_grad_and_aux_error(self):
880+
with self.assertRaisesRegex(TypeError, "two-element tuple"):
881+
grad(lambda x: (1, 2, 3), has_aux=True)(1.)
882+
883+
with self.assertRaisesRegex(TypeError, "two-element tuple"):
884+
grad(lambda x: x, has_aux=True)(1.)
885+
886+
with self.assertRaisesRegex(TypeError, "two-element tuple"):
887+
grad(lambda x: (x,), has_aux=True)(1.)
888+
879889
def test_grad_and_aux_nested(self):
880890
def f(x):
881891
g, aux = grad(lambda x: (x**3, [x**3]), has_aux=True)(x)
@@ -2319,6 +2329,7 @@ def __jax_array__(self):
23192329
self.assertEqual(f(x), f(a))
23202330

23212331

2332+
23222333
class RematTest(jtu.JaxTestCase):
23232334

23242335
def test_remat_basic(self):

0 commit comments

Comments
 (0)