Initial implementation of variadic lax.reduce() (jax-ml#3342)
jakevdp authored Jun 9, 2020
1 parent d3ccf0a commit 99401c5
Showing 2 changed files with 100 additions and 29 deletions.
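For orientation before the diff: with this change, lax.reduce accepts a sequence of operands together with a matching sequence of init values, and a computation that consumes 2*N scalar arguments and returns an N-tuple. A minimal usage sketch, mirroring the testMultiReduce case added below (the concrete arrays and identity values here are illustrative, not taken from the commit):

import jax.numpy as jnp
from jax import lax

a = jnp.arange(6.0).reshape(2, 3)
b = jnp.arange(6.0).reshape(2, 3)

# The reduction computation takes 2*N arguments (N = 2 operands here)
# and returns an N-tuple: one result per operand.
op = lambda a1, b1, a2, b2: (lax.min(a1, a2), lax.max(b1, b2))

# One scalar init value per operand (identities for min and max).
init_val = (jnp.array(jnp.inf, a.dtype), jnp.array(-jnp.inf, b.dtype))

# Reduce both arrays in one pass over axis 0; the result is a tuple.
a_min, b_max = lax.reduce((a, b), init_val, op, (0,))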
105 changes: 76 additions & 29 deletions jax/lax/lax.py
@@ -1020,25 +1020,44 @@ def transpose(operand: Array, permutation: Sequence[int]) -> Array:
   else:
     return transpose_p.bind(operand, permutation=permutation)
 
-def reduce(operand: Array, init_value: Array, computation: Callable,
-           dimensions: Sequence[int]) -> Array:
+def reduce(operand: Union[Array, Sequence[Array]],
+           init_value: Union[Array, Sequence[Array]],
+           computation: Callable,
+           dimensions: Sequence[int]) -> Union[Array, Tuple[Array, ...]]:
   """Wraps XLA's `Reduce
   <https://www.tensorflow.org/xla/operation_semantics#reduce>`_
   operator.
   """
-  monoid_reducer = _get_monoid_reducer(computation, init_value)
-  if monoid_reducer:
-    return monoid_reducer(operand, dimensions)
+  return_tuple = isinstance(operand, Sequence)
+  if not return_tuple:
+    operand = (operand,)
+  if not isinstance(init_value, Sequence):
+    init_value = (init_value,)
+  if len(operand) == 0:
+    raise TypeError("reduce requires at least one operand")
+  if len(operand) != len(init_value):
+    raise TypeError("reduce: length of operands tuple must match length of init_values tuple; got "
+                    f"len(operand)={len(operand)}, len(init_value)={len(init_value)}.")
+
+  monoid_reducer = _get_monoid_reducer(computation, init_value[0])
+  if len(operand) == 1 and monoid_reducer:
+    out = (monoid_reducer(operand[0], dimensions),)
   else:
-    jaxpr, consts = _reduction_jaxpr(computation, _abstractify(init_value))
-    return reduce_p.bind(operand, init_value, computation=computation,
-                         jaxpr=jaxpr, consts=consts, dimensions=tuple(dimensions))
+    jaxpr, consts = _reduction_jaxpr(computation, *(_abstractify(v) for v in init_value))
+    # TODO(mattjj): handle consts correctly
+    # TODO(mattjj): don't pass computation
+    out = reduce_p.bind(*operand, *init_value, computation=computation,
+                        jaxpr=jaxpr, consts=consts, dimensions=tuple(dimensions))
+  return tuple(out) if return_tuple else out[0]
 
 @cache()
-def _reduction_jaxpr(computation, aval):
-  pval = pe.PartialVal.unknown(aval)
-  comp = lu.wrap_init(lambda x, y: (computation(x, y),))
-  jaxpr, _, consts = pe.trace_to_jaxpr(comp, (pval, pval), instantiate=False)
+def _reduction_jaxpr(computation, *avals):
+  pvals = tuple(pe.PartialVal.unknown(aval) for aval in avals)
+  if len(pvals) == 1:
+    comp = lu.wrap_init(lambda x, y: (computation(x, y),))
+  else:
+    comp = lu.wrap_init(computation)
+  jaxpr, _, consts = pe.trace_to_jaxpr(comp, 2 * pvals, instantiate=True)
   return jaxpr, consts
 
 def _get_monoid_reducer(monoid_op: Callable, x: Array) -> Optional[Callable]:
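Aside (not part of the diff): _reduction_jaxpr now traces the user computation over the init-value avals repeated twice, so a variadic reduction computation is a function of 2*N scalars returning an N-tuple. One way to inspect that traced form, using the standard jax.make_jaxpr helper (a sketch; the scalar arguments are arbitrary):

import jax
from jax import lax

op = lambda a1, b1, a2, b2: (lax.min(a1, a2), lax.max(b1, b2))
# Prints a jaxpr with four scalar inputs and two outputs, i.e. the
# 2*N-argument form used for an N = 2 operand reduction.
print(jax.make_jaxpr(op)(1.0, 2.0, 3.0, 4.0))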
@@ -4078,26 +4097,50 @@ def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts,
   partial(_scatter_batching_rule, scatter))
 
 
-def _reduce_shape_rule(operand, init_value, *, computation, jaxpr, consts,
-                       dimensions):
-  return tuple(onp.delete(operand.shape, dimensions))
-
-def _reduce_translation_rule(c, operand, init_value, *, computation, jaxpr,
-                             consts, dimensions):
-  xla_computation = _reduction_computation(c, jaxpr, consts, init_value)
-  return xops.Reduce(c, [operand], [init_value], xla_computation, dimensions)
+def _reduce_abstract_eval(*args, dimensions, **kwargs):
+  N = len(args) // 2
+  operands, init_values = args[:N], args[N:]
+  if len(operands) != len(init_values):
+    raise TypeError("Expected number of operands to equal number of init_values; "
+                    f"got {len(operands)} and {len(init_values)}")
+  if any(operand.shape != operands[0].shape for operand in operands[1:]):
+    shapes = " ".join(str(operand.shape) for operand in operands)
+    raise TypeError(f"Arguments to reduce must have equal shapes, got: {shapes}")
+  shape = tuple(onp.delete(operands[0].shape, dimensions))
+  return tuple(
+      ShapedArray(shape, dtype=dtypes.canonicalize_dtype(operand.dtype))
+      for operand in operands
+  )
+
+def _reduce_translation_rule(c, *args, computation, jaxpr, consts, dimensions):
+  N = len(args) // 2
+  operands, init_values = args[:N], args[N:]
+  assert len(operands) == len(init_values)
+  shapes = [c.get_shape(v) for v in init_values]
+  axis_env = xla.AxisEnv(1)  # no parallel primitives inside reductions
+  subc = xla_bridge.make_computation_builder("variadic_reduction_computation")
+  assert len(consts) == 0, "Reduction computations cannot have constants"
+  args = [xb.parameter(subc, 2 * i + j, shape)
+          for i, shape in enumerate(shapes) for j in range(2)]
+  out = xla.jaxpr_subcomp(subc, jaxpr, None, axis_env, consts, '', *args)
+  xla_computation = subc.build(xops.Tuple(subc, out))
+  return xops.Reduce(c, operands, init_values, xla_computation, dimensions)
 
 def _reduce_batch_rule(batched_args, batch_dims, *, computation, jaxpr, consts,
                        dimensions):
+  if len(batched_args) != 2:
+    # TODO(jakevdp): implement this after generalizing reduce implementation.
+    raise NotImplementedError("reduce batch rule for more than one array.")
   operand, init_value = batched_args
   operand_bdim, init_value_bdim = batch_dims
-  if init_value_bdim is None:
-    assert operand_bdim is not None
-    new_dimensions = [d + bool(d >= operand_bdim) for d in dimensions]
-    new_operand_bdim = operand_bdim - int(onp.sum(onp.less(dimensions, operand_bdim)))
-    return reduce(operand, init_value, computation, new_dimensions), new_operand_bdim
-  else:
-    raise NotImplementedError  # loop and stack
+  if init_value_bdim is not None:
+    # TODO(jakevdp): implement this via loop and stack.
+    raise NotImplementedError("batched reduce with different init_val per batch")
+  assert operand_bdim is not None
+  new_dimensions = [d + bool(d >= operand_bdim) for d in dimensions]
+  new_operand_bdim = operand_bdim - int(onp.sum(onp.less(dimensions, operand_bdim)))
+  out = reduce(operand, init_value, computation, new_dimensions)
+  return (out,), (new_operand_bdim,)
 
 def _reduction_computation(c, jaxpr, consts, init_value):
   shape = c.get_shape(init_value)
@@ -4122,8 +4165,12 @@ def _reducer_masking_rule(prim, identity, padded_vals, logical_shapes,
   bind = prim.bind if input_shape is None else partial(prim.bind, input_shape=padded_shape)
   return bind(masked_val, axes=axes)
 
-reduce_p = standard_primitive(_reduce_shape_rule, _input_dtype, 'reduce',
-                              _reduce_translation_rule)
+
+reduce_p = Primitive('reduce')
+reduce_p.multiple_results = True
+reduce_p.def_impl(partial(xla.apply_primitive, reduce_p))
+reduce_p.def_abstract_eval(_reduce_abstract_eval)
+xla.translations[reduce_p] = _reduce_translation_rule
 batching.primitive_batchers[reduce_p] = _reduce_batch_rule
 
 
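Note on the single-operand path above: a non-sequence operand is wrapped internally and the lone result is unwrapped on return, while a one-element sequence yields a one-element tuple (return tuple(out) if return_tuple else out[0]). A small sketch of both call forms (the array and init value are illustrative):

import jax.numpy as jnp
from jax import lax

x = jnp.arange(12.0).reshape(3, 4)

# Non-sequence operand: the result is a single array, as before this change.
total = lax.reduce(x, 0.0, lax.add, (0, 1))

# One-element sequence: return_tuple is True, so a 1-tuple comes back.
(total_again,) = lax.reduce((x,), (0.0,), lax.add, (0, 1))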
24 changes: 24 additions & 0 deletions tests/lax_test.py
@@ -32,6 +32,7 @@
 from jax import test_util as jtu
 from jax import lax_reference
 from jax.test_util import check_grads
+from jax.lax.lax import _get_min_identity, _get_max_identity
 from jax.lib import xla_client
 import jax.util

@@ -1221,6 +1222,29 @@ def testTransposeAgainstNumpy(self, shape, dtype, perm, rng_factory):
     numpy_op = lambda x: lax_reference.transpose(x, perm)
     self._CheckAgainstNumpy(op, numpy_op, args_maker)
 
+  @parameterized.named_parameters(jtu.cases_from_list(
+      {"testcase_name": "_inshape={}_reducedims={}"
+       .format(jtu.format_shape_dtype_string(shape, dtype), dims),
+       "shape": shape, "dtype": dtype, "dims": dims}
+      for dtype in default_dtypes
+      for shape, dims in [
+          [(3, 4, 5), (0,)], [(3, 4, 5), (1, 2)],
+          [(3, 4, 5), (0, 2)], [(3, 4, 5), (0, 1, 2)]
+      ]))
+  def testMultiReduce(self, shape, dtype, dims):
+    rng = jtu.rand_default(self.rng())
+    op = lambda a1, b1, a2, b2: (lax.min(a1, a2), lax.max(b1, b2))
+
+    np_fun = lambda a, b: (a.min(axis=dims), b.max(axis=dims))
+    def jnp_fun(a, b):
+      # device_put here to ensure dtype below is correct.
+      a, b = jax.device_put(a), jax.device_put(b)
+      init_val = (_get_min_identity(a.dtype), _get_max_identity(b.dtype))
+      return lax.reduce((a, b), init_val, op, dims)
+    args_maker = lambda: [rng(shape, dtype), rng(shape, dtype)]
+    self._CompileAndCheck(jnp_fun, args_maker)
+    self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
+
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_op={}_inshape={}_reducedims={}_initval={}"
        .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims,
