Skip to content

Commit

Permalink
reuse commonly-typed input/output values when joining partially evalu…
Browse files Browse the repository at this point in the history
…ated conditional branches
  • Loading branch information
froystig committed Jun 12, 2020
1 parent dd040de commit ddea95e
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 43 deletions.
110 changes: 67 additions & 43 deletions jax/lax/lax_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,14 +847,19 @@ def _cond_partial_eval(trace, *tracers, branches, linear):

num_outs = len(branches_2[0].out_avals)

branches_1 = _join_cond_outputs(branches_1, branch_res_avals, num_outs)
branches_2 = _join_cond_pe_staged_jaxpr_inputs(branches_2, branch_res_avals)
all_res_avals, res_avals_per_branch = _merge_branch_residuals(
branch_res_avals)

branches_1 = _join_cond_outputs(
branches_1, all_res_avals, res_avals_per_branch, num_outs)
branches_2 = _join_cond_pe_staged_jaxpr_inputs(
branches_2, all_res_avals, res_avals_per_branch)

# TODO(frostig,mattjj): reinstate this assertion once pe.partial_eval_jaxpr
# raises to shaped avals
# for j in branches_1[1:]:
# assert j.out_avals == branches_1[0].out_avals
num_res = sum(_map(len, branch_res_avals))
num_res = len(all_res_avals)

_, in_consts = unzip2([t.pval for t in tracers])
out_consts_res = cond_p.bind(*in_consts, branches=branches_1, linear=linear)
Expand Down Expand Up @@ -888,65 +893,84 @@ def _cond_partial_eval(trace, *tracers, branches, linear):
# staged jaxpr that accepts those residuals as its first few inputs. The
# residual-producing branches are staged as jaxprs and bound right away in a
# conditional. The residual-consuming jaxprs are assembled together in a jaxpr
# conditional. The following two helper functions are used to ensure that both
# collections of jaxprs (those evaluated and those staged) are valid for joint
# use under their respective conditionals.

# Because every branch might produce different residuals, the branches' output
# signatures might not match. But we need branch signatures to match in order to
# bind them in a conditional. This function "joins" the residual outputs of the
# branches by concatenation. Each augmented branch returns zero-filled values in
# the place of all other branches' residuals.
def _join_cond_outputs(jaxprs, res_avals_per_jaxpr, num_non_res_outputs):
def augment_jaxpr(i, jaxpr):
res_avals_prefix = util.concatenate(res_avals_per_jaxpr[:i])
res_avals_suffix = util.concatenate(res_avals_per_jaxpr[i+1:])

# conditional. The following helper functions ensure that both collections of
# jaxprs (those evaluated and those staged) are valid for joint use under their
# respective conditionals.
#
# In particular, the residuals derived from each original branch may have
# distinct types. Because the branches of conditionals must have identical type
# signatures, we join residuals together across branches into a common format.

# In order to set up a type signature that all branches can conform to, it would
# suffice to concatenate all branches' residuals. But concatenation can result
# in redundant inputs and outputs, and might lead to memory allocation that
# scales unnecessarily with the branch count. This function finds common
# residual types across branches for reuse, so as to avoid redundant
# allocation. It returns a list L of types (avals) representing the collection
# of residuals merged according to type, and, for each branch, a lookup table to
# match its residuals to their positions/types in L. Example input/output:
#
# [x], [y], [x, x] -> [x, y, x], [[0], [1], [0, 2]]
# [x], [x], [x, x] -> [x, x], [[0], [0], [0, 1]]
# [y, x, x], [x, z, y], [z, x] -> [y, x, x, z], [[0, 1, 2], [1, 3, 0], [3, 1]]
def _merge_branch_residuals(branch_res_avals):
def enumerate_equal(xs):
counts = {v: itertools.count() for v in set(xs)}
return [(x, next(counts[x])) for x in xs]
branch_res_tagged_avals = _map(enumerate_equal, branch_res_avals)
all_tagged_avals = _ordered_unique(util.concatenate(branch_res_tagged_avals))
indices = {v: i for i, v in enumerate(all_tagged_avals)}
branch_indices = [
[indices[aval] for aval in avals] for avals in branch_res_tagged_avals]
all_avals = [x for x, _ in all_tagged_avals]
return all_avals, branch_indices

# This function augments branch outputs to agree with the merged residual
# format: each branch is made to return zero-filled values in the places of
# residual outputs that it does not populate.
def _join_cond_outputs(jaxprs, all_res_avals, res_aval_indices_per_jaxpr,
num_non_res_outputs):
def augment_jaxpr(jaxpr, res_indices):
@lu.wrap_init
def f_aug(*args):
outs_and_residuals = core.jaxpr_as_fun(jaxpr)(*args)
outs, residuals = split_list(outs_and_residuals, [num_non_res_outputs])
zeros_prefix = _map(ad_util.zeros_like_aval, res_avals_prefix)
zeros_suffix = _map(ad_util.zeros_like_aval, res_avals_suffix)
return outs + zeros_prefix + residuals + zeros_suffix
aug_residuals = _map(ad_util.zeros_like_aval, all_res_avals)
aug_residuals = util.subvals(aug_residuals, zip(res_indices, residuals))
return outs + list(aug_residuals)

return _make_typed_jaxpr(f_aug, jaxpr.in_avals)

return tuple(augment_jaxpr(i, jaxpr) for i, jaxpr in enumerate(jaxprs))
return tuple(_map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))

# To use these staged jaxprs as the branches of another conditional, we need for
# their (input) signatures to match. This function "joins" the staged jaxprs:
# for each one, it makes another that accepts *all* residuals, but still only
# uses those that it needs (dropping the rest).
def _join_cond_pe_staged_jaxpr_inputs(jaxprs, res_avals_per_jaxpr):
# This function augments branch inputs to agree with the merged residual format:
# each branch is made to accept all residuals, even though it will ignore those
# that it does not read.
def _join_cond_pe_staged_jaxpr_inputs(jaxprs, all_res_avals,
res_aval_indices_per_jaxpr):
newvar = core.gensym([j.jaxpr for j in jaxprs], suffix='_')
unused_res_vars = tuple(
tuple(newvar(aval) for aval in res_avals)
for res_avals in res_avals_per_jaxpr)

def pad_jaxpr_res_avals(i, jaxpr):
res_vars_prefix = util.concatenate(unused_res_vars[:i])
res_vars_suffix = util.concatenate(unused_res_vars[i+1:])
res_avals_prefix = util.concatenate(res_avals_per_jaxpr[:i])
res_avals_suffix = util.concatenate(res_avals_per_jaxpr[i+1:])

res_avals = res_avals_per_jaxpr[i]
num_res = len(res_avals)
res_vars = jaxpr.jaxpr.invars[:num_res]
all_res_vars = _map(newvar, all_res_avals)

def augment_jaxpr(jaxpr, res_indices):
num_res = len(res_indices)
res_vars = jaxpr.jaxpr.invars[:num_res]
non_res_vars = jaxpr.jaxpr.invars[num_res:]
non_res_avals = jaxpr.in_avals[num_res:]

aug_invars = res_vars_prefix + res_vars + res_vars_suffix + non_res_vars
aug_avals = res_avals_prefix + res_avals + res_avals_suffix + non_res_avals

aug_res_vars = list(util.subvals(all_res_vars, zip(res_indices, res_vars)))
aug_invars = aug_res_vars + non_res_vars
aug_avals = all_res_avals + non_res_avals
jaxpr_aug = core.Jaxpr(jaxpr.jaxpr.constvars, aug_invars,
jaxpr.jaxpr.outvars, jaxpr.jaxpr.eqns)
jaxpr_aug = core.TypedJaxpr(jaxpr_aug, jaxpr.literals, aug_avals,
jaxpr.out_avals)
return jaxpr_aug

return tuple(pad_jaxpr_res_avals(i, jaxpr) for i, jaxpr in enumerate(jaxprs))
return tuple(_map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))

def _ordered_unique(xs):
d = collections.OrderedDict((x, None) for x in xs)
return list(d.keys())

def _transpose_cond_jaxpr(jaxpr, num_res):
res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res])
Expand Down
45 changes: 45 additions & 0 deletions tests/lax_control_flow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,51 @@ def cfun(x):
self.assertEqual(fun(2), cfun(2))
self.assertEqual(fun(3), cfun(3))

def testSwitchResidualsMerge(self):
def get_conds(fun):
jaxpr = api.make_jaxpr(api.grad(fun))(0., 0)
return [eqn for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == 'cond']

def branch_invars_len(cond_eqn):
lens = [len(jaxpr.jaxpr.invars) for jaxpr in cond_eqn.params['branches']]
assert len(set(lens)) == 1
return lens[0]

def branch_outvars_len(cond_eqn):
lens = [len(jaxpr.jaxpr.outvars) for jaxpr in cond_eqn.params['branches']]
assert len(set(lens)) == 1
return lens[0]

branches1 = [
lambda x: jnp.sin(x),
lambda x: jnp.cos(x)] # branch residuals overlap, should be reused
branches2 = branches1 + [
lambda x: jnp.sinh(x)] # another overlapping residual, expect reuse
branches3 = branches2 + [
lambda x: jnp.sin(x) + jnp.cos(x)] # requires one more residual slot
def fun1(x, i):
return lax.switch(i + 1, branches1, x)
def fun2(x, i):
return lax.switch(i + 1, branches2, x)
def fun3(x, i):
return lax.switch(i + 1, branches3, x)

fwd1, bwd1 = get_conds(fun1)
fwd2, bwd2 = get_conds(fun2)
fwd3, bwd3 = get_conds(fun3)

fwd1_num_out = branch_outvars_len(fwd1)
fwd2_num_out = branch_outvars_len(fwd2)
fwd3_num_out = branch_outvars_len(fwd3)
assert fwd1_num_out == fwd2_num_out
assert fwd3_num_out == fwd2_num_out + 1

bwd1_num_in = branch_invars_len(bwd1)
bwd2_num_in = branch_invars_len(bwd2)
bwd3_num_in = branch_invars_len(bwd3)
assert bwd1_num_in == bwd2_num_in
assert bwd3_num_in == bwd2_num_in + 1

def testOneBranchSwitch(self):
branch = lambda x: -x
f = lambda i, x: lax.switch(i, [branch], x)
Expand Down

0 comments on commit ddea95e

Please sign in to comment.