Skip to content

Commit 6f2d22f

Browse files
author
Sam Schoenholz
committed
Tiny change to enable vmap with dimension numbers.
1 parent 20f167d commit 6f2d22f

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

jax/lax/lax.py

-1
Original file line numberDiff line numberDiff line change
@@ -2526,7 +2526,6 @@ def _reshape_batch_rule(batched_args, batch_dims, new_sizes, dimensions, **unuse
25262526
bdim, = batch_dims
25272527
operand = batching.moveaxis(operand, bdim, 0)
25282528
if dimensions is not None:
2529-
raise NotImplementedError # TODO(mattjj): handle reshape w/ dimensions
25302529
dimensions = (0,) + tuple(onp.add(1, dimensions))
25312530
return reshape(operand, operand.shape[:1] + new_sizes, dimensions), 0
25322531

tests/lax_test.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -2538,20 +2538,25 @@ def testBroadcastInDim(self, inshape, dtype, outshape, dimensions, bdims, rng):
25382538
self._CheckBatching(op, 5, bdims, (inshape,), dtype, rng)
25392539

25402540
@parameterized.named_parameters(jtu.cases_from_list(
2541-
{"testcase_name": "_inshape={}_outshape={}_bdims={}".format(
2541+
{"testcase_name": "_inshape={}_outshape={}_dims={}_bdims={}".format(
25422542
jtu.format_shape_dtype_string(arg_shape, dtype),
25432543
jtu.format_shape_dtype_string(out_shape, dtype),
2544-
bdims),
2544+
dimensions, bdims),
25452545
"arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
2546-
"bdims": bdims, "rng": rng}
2546+
"dimensions": dimensions, "bdims": bdims, "rng": rng}
25472547
for dtype in default_dtypes
2548-
for arg_shape, out_shape in [
2549-
[(3, 4), (12,)], [(2, 1, 4), (8,)], [(2, 2, 4), (2, 8)]
2548+
for arg_shape, dimensions, out_shape in [
2549+
[(3, 4), None, (12,)],
2550+
[(2, 1, 4), None, (8,)],
2551+
[(2, 2, 4), None, (2, 8)],
2552+
[(2, 2, 4), (0, 1, 2), (2, 8)],
2553+
[(2, 2, 4), (1, 0, 2), (8, 2)],
2554+
[(2, 2, 4), (2, 1, 0), (4, 2, 2)]
25502555
]
25512556
for bdims in all_bdims(arg_shape)
25522557
for rng in [jtu.rand_default()]))
2553-
def testReshape(self, arg_shape, out_shape, dtype, bdims, rng):
2554-
op = lambda x: lax.reshape(x, out_shape)
2558+
def testReshape(self, arg_shape, out_shape, dtype, dimensions, bdims, rng):
2559+
op = lambda x: lax.reshape(x, out_shape, dimensions=dimensions)
25552560
self._CheckBatching(op, 10, bdims, (arg_shape,), dtype, rng)
25562561

25572562
@parameterized.named_parameters(jtu.cases_from_list(

0 commit comments

Comments
 (0)