Commit 8ad5118

Author: jax authors
Merge pull request jax-ml#5762 from hawkinsp:padbatch
PiperOrigin-RevId: 358070219
2 parents: 347481b + 0dd1b55

File tree: 3 files changed (+18 −8 lines)

  docs/CHANGELOG.md
  jax/_src/lax/lax.py
  tests/lax_vmap_test.py

docs/CHANGELOG.md (+1)

@@ -16,6 +16,7 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
 * Added {func}`jax.experimental.jax2tf.call_tf` to call TensorFlow functions
   from JAX ([#5627](https://github.com/google/jax/pull/5627)
   and [README](https://github.com/google/jax/blob/master/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax)).
+* Extended the batching rule for `lax.pad` to support batching of the padding values.
 * Bug fixes:
   * {func}`jax.numpy.take` properly handles negative indices
     ([#5768](https://github.com/google/jax/pull/5768))
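
In user terms, this change lets `jax.vmap` map over the padding value passed to `lax.pad`, which previously raised `NotImplementedError`. A minimal sketch of the new capability; the shapes, values, and padding config here are illustrative, borrowed from the test change below:

```python
import jax
import jax.numpy as jnp
from jax import lax

xs = jnp.ones((5, 2, 3))   # batch of 5 operands
ps = jnp.arange(5.0)       # one scalar padding value per batch element

# One (low, high, interior) triple per operand dimension.
config = [(1, 2, 1), (0, 1, 0)]

# Batching over both the operand and the padding value now works; before
# this commit, a batched padding value hit the NotImplementedError branch.
out = jax.vmap(lambda x, p: lax.pad(x, p, config))(xs, ps)
print(out.shape)  # (5, 6, 4)
```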

jax/_src/lax/lax.py (+14 −5)

@@ -3536,13 +3536,22 @@ def t_op():
 def _pad_batch_rule(batched_args, batch_dims, *, padding_config):
   operand, padding_value = batched_args
   operand_bdim, padding_value_bdim = batch_dims
+  if operand_bdim is None:
+    operand_bdim = 0
+    operand = broadcast(operand, (padding_value.shape[padding_value_bdim],))
+
+  padding_config = list(padding_config)
+  padding_config.insert(operand_bdim, (0, 0, 0))
   if padding_value_bdim is None:
-    assert operand_bdim is not None
-    padding_config = list(padding_config)
-    padding_config.insert(operand_bdim, (0, 0, 0))
     return pad(operand, padding_value, padding_config), operand_bdim
-  else:
-    raise NotImplementedError  # loop and stack
+
+  assert padding_value_bdim == 0, padding_value_bdim
+
+  x = pad(operand, _zero(operand), padding_config)
+  mask = pad(full_like(operand, True, np.bool_), False, padding_config)
+  broadcasted_padding = broadcast_in_dim(padding_value, x.shape,
+                                         (operand_bdim,))
+  return select(mask, x, broadcasted_padding), operand_bdim
 
 
 def _pad_translation_rule(c, operand, padding_value, *, padding_config):
   return xops.Pad(operand, padding_value,
tests/lax_vmap_test.py (+3 −3)

@@ -369,13 +369,13 @@ def testReshape(self, arg_shape, out_shape, dtype, dimensions, bdims):
       .format(jtu.format_shape_dtype_string(shape, dtype), pads, bdims),
      "shape": shape, "dtype": dtype, "pads": pads, "bdims": bdims}
     for shape in [(2, 3)]
-    for bdims in all_bdims(shape)
+    for bdims in all_bdims(shape, ())
     for dtype in default_dtypes
     for pads in [[(1, 2, 1), (0, 1, 0)]]))
   def testPad(self, shape, dtype, pads, bdims):
     rng = jtu.rand_small(self.rng())
-    fun = lambda operand: lax.pad(operand, np.array(0, dtype), pads)
-    self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng)
+    fun = lambda operand, padding: lax.pad(operand, padding, pads)
+    self._CheckBatching(fun, 5, bdims, (shape, ()), (dtype, dtype), rng)
 
   @parameterized.named_parameters(jtu.cases_from_list(
     {"testcase_name": "_predshape={}_argshapes={}_bdims={}".format(