diff --git a/jax/lax/lax.py b/jax/lax/lax.py
index 063307bae2df..7a6aaa25282b 100644
--- a/jax/lax/lax.py
+++ b/jax/lax/lax.py
@@ -1020,25 +1020,44 @@ def transpose(operand: Array, permutation: Sequence[int]) -> Array:
   else:
     return transpose_p.bind(operand, permutation=permutation)
 
-def reduce(operand: Array, init_value: Array, computation: Callable,
-           dimensions: Sequence[int]) -> Array:
+def reduce(operand: Union[Array, Sequence[Array]],
+           init_value: Union[Array, Sequence[Array]],
+           computation: Callable,
+           dimensions: Sequence[int]) -> Union[Array, Tuple[Array, ...]]:
   """Wraps XLA's `Reduce
   <https://www.tensorflow.org/xla/operation_semantics#reduce>`_
   operator.
   """
-  monoid_reducer = _get_monoid_reducer(computation, init_value)
-  if monoid_reducer:
-    return monoid_reducer(operand, dimensions)
+  return_tuple = isinstance(operand, Sequence)
+  if not return_tuple:
+    operand = (operand,)
+  if not isinstance(init_value, Sequence):
+    init_value = (init_value,)
+  if len(operand) == 0:
+    raise TypeError("reduce requires at least one operand")
+  if len(operand) != len(init_value):
+    raise TypeError("reduce: length of operands tuple must match length of init_values tuple; got "
+                    f"len(operand)={len(operand)}, len(init_value)={len(init_value)}.")
+
+  monoid_reducer = _get_monoid_reducer(computation, init_value[0])
+  if len(operand) == 1 and monoid_reducer:
+    out = (monoid_reducer(operand[0], dimensions),)
   else:
-    jaxpr, consts = _reduction_jaxpr(computation, _abstractify(init_value))
-    return reduce_p.bind(operand, init_value, computation=computation,
-                         jaxpr=jaxpr, consts=consts, dimensions=tuple(dimensions))
+    jaxpr, consts = _reduction_jaxpr(computation, *(_abstractify(v) for v in init_value))
+    # TODO(mattjj): handle consts correctly
+    # TODO(mattjj): don't pass computation
+    out = reduce_p.bind(*operand, *init_value, computation=computation,
+                        jaxpr=jaxpr, consts=consts, dimensions=tuple(dimensions))
+  return tuple(out) if return_tuple else out[0]
 
 @cache()
-def _reduction_jaxpr(computation, aval):
-  pval = pe.PartialVal.unknown(aval)
-  comp = lu.wrap_init(lambda x, y: (computation(x, y),))
-  jaxpr, _, consts = pe.trace_to_jaxpr(comp, (pval, pval), instantiate=False)
+def _reduction_jaxpr(computation, *avals):
+  pvals = tuple(pe.PartialVal.unknown(aval) for aval in avals)
+  if len(pvals) == 1:
+    comp = lu.wrap_init(lambda x, y: (computation(x, y),))
+  else:
+    comp = lu.wrap_init(computation)
+  jaxpr, _, consts = pe.trace_to_jaxpr(comp, 2 * pvals, instantiate=True)
   return jaxpr, consts
 
 def _get_monoid_reducer(monoid_op: Callable, x: Array) -> Optional[Callable]:
@@ -4078,26 +4097,50 @@ def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts,
   partial(_scatter_batching_rule, scatter))
 
 
-def _reduce_shape_rule(operand, init_value, *, computation, jaxpr, consts,
-                       dimensions):
-  return tuple(onp.delete(operand.shape, dimensions))
-
-def _reduce_translation_rule(c, operand, init_value, *, computation, jaxpr,
-                             consts, dimensions):
-  xla_computation = _reduction_computation(c, jaxpr, consts, init_value)
-  return xops.Reduce(c, [operand], [init_value], xla_computation, dimensions)
+def _reduce_abstract_eval(*args, dimensions, **kwargs):
+  N = len(args) // 2
+  operands, init_values = args[:N], args[N:]
+  if len(operands) != len(init_values):
+    raise TypeError("Expected number of operands to equal number of init_values; "
+                    f"got {len(operands)} and {len(init_values)}")
+  if any(operand.shape != operands[0].shape for operand in operands[1:]):
+    shapes = " ".join(str(operand.shape) for operand in operands)
+    raise TypeError(f"Arguments to reduce must have equal shapes, got: {shapes}")
+  shape = tuple(onp.delete(operands[0].shape, dimensions))
+  return tuple(
+      ShapedArray(shape, dtype=dtypes.canonicalize_dtype(operand.dtype))
+      for operand in operands
+  )
+
+def _reduce_translation_rule(c, *args, computation, jaxpr, consts, dimensions):
+  N = len(args) // 2
+  operands, init_values = args[:N], args[N:]
+  assert len(operands) == len(init_values)
+  shapes = [c.get_shape(v) for v in init_values]
+  axis_env = xla.AxisEnv(1)  # no parallel primitives inside reductions
+  subc = xla_bridge.make_computation_builder("variadic_reduction_computation")
+  assert len(consts) == 0, "Reduction computations cannot have constants"
+  args = [xb.parameter(subc, i, shape)  # params in jaxpr invar order
+          for i, shape in enumerate(shapes + shapes)]
+  out = xla.jaxpr_subcomp(subc, jaxpr, None, axis_env, consts, '', *args)
+  xla_computation = subc.build(xops.Tuple(subc, out))
+  return xops.Reduce(c, operands, init_values, xla_computation, dimensions)
 
 def _reduce_batch_rule(batched_args, batch_dims, *, computation, jaxpr, consts,
                        dimensions):
+  if len(batched_args) != 2:
+    # TODO(jakevdp): implement this after generalizing reduce implementation.
+    raise NotImplementedError("reduce batch rule for more than one array.")
   operand, init_value = batched_args
   operand_bdim, init_value_bdim = batch_dims
-  if init_value_bdim is None:
-    assert operand_bdim is not None
-    new_dimensions = [d + bool(d >= operand_bdim) for d in dimensions]
-    new_operand_bdim = operand_bdim - int(onp.sum(onp.less(dimensions, operand_bdim)))
-    return reduce(operand, init_value, computation, new_dimensions), new_operand_bdim
-  else:
-    raise NotImplementedError  # loop and stack
+  if init_value_bdim is not None:
+    # TODO(jakevdp): implement this via loop and stack.
+    raise NotImplementedError("batched reduce with different init_val per batch")
+  assert operand_bdim is not None
+  new_dimensions = [d + bool(d >= operand_bdim) for d in dimensions]
+  new_operand_bdim = operand_bdim - int(onp.sum(onp.less(dimensions, operand_bdim)))
+  out = reduce(operand, init_value, computation, new_dimensions)
+  return (out,), (new_operand_bdim,)
 
 def _reduction_computation(c, jaxpr, consts, init_value):
   shape = c.get_shape(init_value)
@@ -4122,8 +4165,12 @@ def _reducer_masking_rule(prim, identity, padded_vals, logical_shapes,
   bind = prim.bind if input_shape is None else partial(prim.bind, input_shape=padded_shape)
   return bind(masked_val, axes=axes)
 
-reduce_p = standard_primitive(_reduce_shape_rule, _input_dtype, 'reduce',
-                              _reduce_translation_rule)
+
+reduce_p = Primitive('reduce')
+reduce_p.multiple_results = True
+reduce_p.def_impl(partial(xla.apply_primitive, reduce_p))
+reduce_p.def_abstract_eval(_reduce_abstract_eval)
+xla.translations[reduce_p] = _reduce_translation_rule
 batching.primitive_batchers[reduce_p] = _reduce_batch_rule
 
 
diff --git a/tests/lax_test.py b/tests/lax_test.py
index e60fd8f4bcdb..f45e3981e105 100644
--- a/tests/lax_test.py
+++ b/tests/lax_test.py
@@ -32,6 +32,7 @@
 from jax import test_util as jtu
 from jax import lax_reference
 from jax.test_util import check_grads
+from jax.lax.lax import _get_min_identity, _get_max_identity
 from jax.lib import xla_client
 import jax.util
 
@@ -1221,6 +1222,29 @@ def testTransposeAgainstNumpy(self, shape, dtype, perm, rng_factory):
     numpy_op = lambda x: lax_reference.transpose(x, perm)
     self._CheckAgainstNumpy(op, numpy_op, args_maker)
 
+  @parameterized.named_parameters(jtu.cases_from_list(
+      {"testcase_name": "_inshape={}_reducedims={}"
+       .format(jtu.format_shape_dtype_string(shape, dtype), dims),
+       "shape": shape, "dtype": dtype, "dims": dims}
+      for dtype in default_dtypes
+      for shape, dims in [
+          [(3, 4, 5), (0,)], [(3, 4, 5), (1, 2)],
+          [(3, 4, 5), (0, 2)], [(3, 4, 5), (0, 1, 2)]
+      ]))
+  def testMultiReduce(self, shape, dtype, dims):
+    rng = jtu.rand_default(self.rng())
+    op = lambda a1, b1, a2, b2: (lax.min(a1, a2), lax.max(b1, b2))
+
+    np_fun = lambda a, b: (a.min(axis=dims), b.max(axis=dims))
+    def jnp_fun(a, b):
+      # device_put here to ensure dtype below is correct.
+      a, b = jax.device_put(a), jax.device_put(b)
+      init_val = (_get_min_identity(a.dtype), _get_max_identity(b.dtype))
+      return lax.reduce((a, b), init_val, op, dims)
+    args_maker = lambda: [rng(shape, dtype), rng(shape, dtype)]
+    self._CompileAndCheck(jnp_fun, args_maker)
+    self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
+
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_op={}_inshape={}_reducedims={}_initval={}"
        .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims,
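
Usage note (not part of the patch): a minimal sketch of calling the variadic lax.reduce added above, mirroring testMultiReduce. The array contents, the (0,) reduction axis, and the inf/-inf init values are illustrative only; the reducer receives all accumulators first, then all element values, and returns one output per operand.

# Sketch assuming the variadic lax.reduce from this patch.
import numpy as np
from jax import lax

x = np.arange(12, dtype=np.float32).reshape(3, 4)
y = np.arange(12, dtype=np.float32).reshape(3, 4)

# Reducer signature: (acc_x, acc_y, val_x, val_y) -> (new_acc_x, new_acc_y).
reducer = lambda ax, ay, vx, vy: (lax.min(ax, vx), lax.max(ay, vy))
init_vals = (np.float32(np.inf), np.float32(-np.inf))

mins, maxes = lax.reduce((x, y), init_vals, reducer, (0,))
# mins matches x.min(axis=0); maxes matches y.max(axis=0).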