Skip to content

Add a linearization rule for cond_p. #28919

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions jax/_src/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,15 +159,19 @@ def convert_constvars_jaxpr_constvars_at_end(jaxpr: core.Jaxpr) -> core.Jaxpr:

def linearize_jaxpr(
    jaxpr: core.ClosedJaxpr,
    nonzeros: Sequence[bool],
    instantiate: bool | Sequence[bool] = False,
) -> tuple[core.ClosedJaxpr, int, Sequence[bool], core.ClosedJaxpr]:
  """Linearize a closed jaxpr, optionally instantiating zero output tangents.

  Thin wrapper that normalizes its arguments to tuples and forwards them to
  ``_linearize_jaxpr``.

  Args:
    jaxpr: the closed jaxpr to linearize.
    nonzeros: one flag per input; forwarded as a tuple.
    instantiate: either a single bool, which is broadcast to every output of
      ``jaxpr``, or one flag per output. Outputs whose flag is true have their
      tangent passed through ``instantiate_zeros``.

  Returns:
    Whatever ``_linearize_jaxpr`` returns: a 4-tuple typed as
    ``(ClosedJaxpr, int, Sequence[bool], ClosedJaxpr)``. Callers in this change
    unpack it as ``(primal_jaxpr, num_res, out_nonzeros, tangent_jaxpr)``,
    where the primal jaxpr returns its ``num_res`` residuals at the front and
    the tangent jaxpr accepts them at the back.
  """
  # A lone bool applies uniformly to every output.
  if isinstance(instantiate, bool):
    instantiate = (instantiate,) * len(jaxpr.out_avals)
  # NOTE(review): the tuple conversion is presumably needed because
  # _linearize_jaxpr is wrapped in weakref_lru_cache and so wants hashable
  # arguments -- confirm against the cache's requirements.
  return _linearize_jaxpr(jaxpr, tuple(nonzeros), tuple(instantiate))

@weakref_lru_cache
@source_info_util.reset_name_stack()
def _linearize_jaxpr(
jaxpr: core.ClosedJaxpr,
nonzeros: tuple[bool, ...]
nonzeros: tuple[bool, ...],
instantiate: tuple[bool, ...],
) -> tuple[core.ClosedJaxpr, int, Sequence[bool], core.ClosedJaxpr]:
dbg = jaxpr.jaxpr.debug_info
primal_trace = pe.DynamicJaxprTrace(dbg)
Expand All @@ -188,6 +192,8 @@ def new_arg(trace, primal_aval, nz, source_info):
with core.set_current_trace(lin_trace, check_leaks=True):
ans = core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *tracers)
out_primals, out_tangents = unzip2(map(lin_trace.to_primal_tangent_pair, ans))
out_tangents = tuple(instantiate_zeros(t) if inst else t
for t, inst in zip(out_tangents, instantiate))
del lin_trace, ans, tracers, new_arg

debug_info = jaxpr.jaxpr.debug_info
Expand Down
47 changes: 47 additions & 0 deletions jax/_src/lax/control_flow/conditionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,52 @@ def _cond_jvp(primals, tangents, *, branches, **params):
for p, nz in zip(out_primals, out_nz)]
return out_primals, out_tangents

def _cond_lin(nzs, *args, branches, **params):
  """Linearization rule for ``cond_p``.

  Args:
    nzs: one flag per argument (branch index first) telling whether that
      argument carries a nonzero tangent. The flag for the index must be
      ``False``.
    *args: the primal branch index followed by the primal operands.
    branches: the closed jaxprs of the conditional, one per branch.
    **params: remaining ``cond_p`` parameters, forwarded unchanged to both
      ``cond_p.bind`` calls made here.

  Returns:
    A 4-tuple ``(primals_out, nzs_out, residuals, tangent_fun)``.
    ``tangent_fun(res, *tangents)`` binds a second ``cond_p`` over the
    linearized branches and returns one tangent per output, using
    ``ad.Zero`` for outputs whose ``nzs_out`` flag is false.
  """
  index_nz, *operand_nzs = nzs
  assert index_nz is False
  index, *operands = args

  # First pass: find which output tangents are nonzero on each branch. An
  # output counts as nonzero if it is nonzero on at least one branch.
  per_branch_nz = [
      ad.linearize_jaxpr(branch, operand_nzs, False)[2] for branch in branches]
  nzs_out = [any(flags) for flags in zip(*per_branch_nz)]
  num_out = len(nzs_out)
  num_nz_in = sum(nzs)

  # Second pass: linearize every branch again, this time asking for the
  # tangents of all outputs flagged in nzs_out to be instantiated.
  res_avals_by_branch, primal_jaxprs, tangent_jaxprs = [], [], []
  for branch in branches:
    primal, num_res, _, tangent = ad.linearize_jaxpr(
        branch, operand_nzs, nzs_out)
    res_avals_by_branch.append(primal.out_avals[:num_res])
    # ad.linearize_jaxpr puts the residuals first among the primal jaxpr's
    # outputs and last among the tangent jaxpr's inputs. Flipping both
    # positions keeps the rest of this rule simpler.
    primal = pe.move_outvars_to_back(
        primal, [True] * num_res + [False] * num_out)
    tangent = pe.move_binders_to_front(
        tangent, [False] * num_nz_in + [True] * num_res)
    primal_jaxprs.append(primal)
    tangent_jaxprs.append(tangent)

  # Every branch must agree on the residual avals, so build the smallest set
  # of avals that covers what all branches need and rewrite each jaxpr to
  # return (primal side) or consume (tangent side) only its own subset.
  all_res_avals, res_avals_per_branch = _merge_branch_residuals(
      res_avals_by_branch)
  primal_jaxprs = _join_cond_outputs(
      primal_jaxprs, all_res_avals, res_avals_per_branch, num_out)
  tangent_jaxprs = _join_cond_pe_staged_jaxpr_inputs(
      tangent_jaxprs, all_res_avals, res_avals_per_branch)

  def tangent_fun(res, *tangents):
    # Only the tangents flagged nonzero are fed to the tangent conditional.
    live_in = [t for nz, t in zip(nzs, tangents) if nz]
    live_out = iter(cond_p.bind(index, *res, *live_in,
                                branches=tangent_jaxprs, **params))
    out_avals = [v.to_tangent_aval() for v in branches[0].out_avals]
    # Interleave the computed tangents with symbolic zeros, in output order.
    return [next(live_out) if nz else ad.Zero(aval)
            for aval, nz in zip(out_avals, nzs_out)]

  # The primal conditional yields the num_out primal outputs followed by the
  # merged residuals.
  outs = cond_p.bind(index, *operands, branches=primal_jaxprs, **params)
  primals_out, residuals = split_list(outs, [num_out])
  return primals_out, nzs_out, residuals, tangent_fun

def _cond_partial_eval(trace, *tracers, branches, **params):
in_unknowns = [t.pval[0] is not None for t in tracers]
index_uk, *ops_uk = in_unknowns
Expand Down Expand Up @@ -937,6 +983,7 @@ def _cond_typecheck(bind_time, *in_atoms, branches, **params):
# Register cond_p's implementation and abstract evaluation, then hook its
# rules into the transformation registries: JVP, linearization, transpose,
# partial evaluation and batching.
cond_p.def_impl(partial(dispatch.apply_primitive, cond_p))
cond_p.def_effectful_abstract_eval(_cond_abstract_eval)
ad.primitive_jvps[cond_p] = _cond_jvp
ad.primitive_linearizations[cond_p] = _cond_lin
ad.primitive_transposes[cond_p] = _cond_transpose
pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval
batching.fancy_primitive_batchers[cond_p] = _cond_batching_rule
Expand Down
Loading