# Patch summary (free text ahead of the first `diff --git` header).
#
# jax/_src/interpreters/ad.py
#   * `linearize_jaxpr` gains an `instantiate` argument (default False). A bare
#     bool is expanded to one flag per entry of `jaxpr.out_avals`; the flags are
#     then passed to `_linearize_jaxpr` as a tuple, alongside `tuple(nonzeros)`.
#   * `_linearize_jaxpr` (decorated with `@weakref_lru_cache`) takes the new
#     `instantiate` tuple and, after splitting the evaluated outputs into
#     primal/tangent pairs, applies `instantiate_zeros` to every output tangent
#     whose flag is true. Other tangents are left untouched.
#
# jax/_src/lax/control_flow/conditionals.py
#   * Adds `_cond_lin` and registers it as
#     `ad.primitive_linearizations[cond_p]`. The rule:
#       1. splits `nzs`/`args` into the branch index and the operands, and
#          asserts that the index's nonzero flag is `False`;
#       2. linearizes every branch with `instantiate=False` only to read the
#          returned nonzero flags (element [2]), then ORs them across branches
#          into `nzs_out`;
#       3. linearizes every branch again, passing `nzs_out` as `instantiate`;
#       4. moves the residuals to the back of each primal jaxpr's outputs and
#          to the front of each tangent jaxpr's inputs;
#       5. merges the per-branch residual avals via `_merge_branch_residuals`,
#          `_join_cond_outputs` and `_join_cond_pe_staged_jaxpr_inputs`;
#       6. binds `cond_p` over the primal branches and returns
#          `(primals_out, nzs_out, residuals, tangent_fun)`.
#     `tangent_fun(res, *tangents)` binds `cond_p` over the tangent branches
#     with the residuals and the input tangents flagged nonzero in `nzs`, and
#     returns `ad.Zero(aval)` for each output whose `nzs_out` flag is false.
#
# NOTE(review): line breaks and indentation below were reconstructed from a
#   whitespace-collapsed copy of this patch; context-line whitespace should be
#   checked against the target files before applying.
# NOTE(review): `[False] * sum(nzs)` counts the index flag as well as the
#   operand flags; it equals `sum(ops_nz)` only because of the preceding
#   `assert idx_nz is False`.
# NOTE(review): `assert idx_nz is False` is an identity check, so a falsy value
#   that is not the `False` singleton would trip it -- confirm callers always
#   pass a Python bool.
# NOTE(review): `linearize_jaxpr` does not check `len(instantiate)` against
#   `len(jaxpr.out_avals)` when a sequence is supplied; whether the `zip` in
#   `_linearize_jaxpr` is length-checked is not visible in this patch.
#
diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py
index 29af03416a76..44e3f9f936c1 100644
--- a/jax/_src/interpreters/ad.py
+++ b/jax/_src/interpreters/ad.py
@@ -159,15 +159,19 @@ def convert_constvars_jaxpr_constvars_at_end(jaxpr: core.Jaxpr) -> core.Jaxpr:
 
 def linearize_jaxpr(
     jaxpr: core.ClosedJaxpr,
-    nonzeros: Sequence[bool]
+    nonzeros: Sequence[bool],
+    instantiate: bool | Sequence[bool] = False,
 ) -> tuple[core.ClosedJaxpr, int, Sequence[bool], core.ClosedJaxpr]:
-  return _linearize_jaxpr(jaxpr, tuple(nonzeros))
+  if type(instantiate) is bool:
+    instantiate = (instantiate,) * len(jaxpr.out_avals)
+  return _linearize_jaxpr(jaxpr, tuple(nonzeros), tuple(instantiate))
 
 @weakref_lru_cache
 @source_info_util.reset_name_stack()
 def _linearize_jaxpr(
     jaxpr: core.ClosedJaxpr,
-    nonzeros: tuple[bool, ...]
+    nonzeros: tuple[bool, ...],
+    instantiate: tuple[bool, ...],
 ) -> tuple[core.ClosedJaxpr, int, Sequence[bool], core.ClosedJaxpr]:
   dbg = jaxpr.jaxpr.debug_info
   primal_trace = pe.DynamicJaxprTrace(dbg)
@@ -188,6 +192,8 @@ def new_arg(trace, primal_aval, nz, source_info):
   with core.set_current_trace(lin_trace, check_leaks=True):
     ans = core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *tracers)
     out_primals, out_tangents = unzip2(map(lin_trace.to_primal_tangent_pair, ans))
+    out_tangents = tuple(instantiate_zeros(t) if inst else t
+                         for t, inst in zip(out_tangents, instantiate))
     del lin_trace, ans, tracers, new_arg
 
   debug_info = jaxpr.jaxpr.debug_info
diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py
index 741636c47e31..d7e2f7859184 100644
--- a/jax/_src/lax/control_flow/conditionals.py
+++ b/jax/_src/lax/control_flow/conditionals.py
@@ -562,6 +562,52 @@ def _cond_jvp(primals, tangents, *, branches, **params):
                   for p, nz in zip(out_primals, out_nz)]
   return out_primals, out_tangents
 
+def _cond_lin(nzs, *args, branches, **params):
+  idx_nz, *ops_nz = nzs
+  assert idx_nz is False
+  idx, *ops = args
+  branches_nz = [ad.linearize_jaxpr(b, ops_nz, False)[2] for b in branches]
+  nzs_out = [any(nz) for nz in zip(*branches_nz)]
+  num_out = len(nzs_out)
+  branch_res_avals, branch_primal_jaxprs, branch_tangent_jaxprs = [], [], []
+  for branch_jaxpr in branches:
+    primal_jaxpr, num_res, _, tangent_jaxpr = ad.linearize_jaxpr(
+        branch_jaxpr, ops_nz, nzs_out)
+    branch_res_avals.append(primal_jaxpr.out_avals[:num_res])
+    # ad.linearize_jaxpr uses a convention where the primal jaxpr returns the
+    # residuals at the front, and the tangent jaxpr accepts the residuals at
+    # the back. The logic below is simpler if we swap both of these locations.
+    primal_jaxpr = pe.move_outvars_to_back(
+        primal_jaxpr, [True] * num_res + [False] * num_out)
+    tangent_jaxpr = pe.move_binders_to_front(
+        tangent_jaxpr, [False] * sum(nzs) + [True] * num_res)
+    branch_primal_jaxprs.append(primal_jaxpr)
+    branch_tangent_jaxprs.append(tangent_jaxpr)
+
+  # Since the residuals must have the same avals on each branch, we construct
+  # the smallest set of residual avals that cover all values needed by all
+  # branches, then update the jaxprs to only return or use the appropriate
+  # subset.
+  all_res_avals, res_avals_per_branch = _merge_branch_residuals(branch_res_avals)
+  branch_primal_jaxprs = _join_cond_outputs(
+      branch_primal_jaxprs, all_res_avals, res_avals_per_branch, num_out)
+  branch_tangent_jaxprs = _join_cond_pe_staged_jaxpr_inputs(
+      branch_tangent_jaxprs, all_res_avals, res_avals_per_branch)
+
+  def tangent_fun(res, *tangents):
+    nz_tangents_in = [t for nz, t in zip(nzs, tangents) if nz]
+    nz_tangents_out = cond_p.bind(idx, *res, *nz_tangents_in,
+                                  branches=branch_tangent_jaxprs, **params)
+    tangent_avals_out = [v.to_tangent_aval() for v in branches[0].out_avals]
+    nz_tangents_out_ = iter(nz_tangents_out)
+    tangents_out = [next(nz_tangents_out_) if nz else ad.Zero(aval)
+                    for (aval, nz) in zip(tangent_avals_out, nzs_out)]
+    return tangents_out
+
+  ans = cond_p.bind(idx, *ops, branches=branch_primal_jaxprs, **params)
+  primals_out, residuals = split_list(ans, [num_out])
+  return primals_out, nzs_out, residuals, tangent_fun
+
 def _cond_partial_eval(trace, *tracers, branches, **params):
   in_unknowns = [t.pval[0] is not None for t in tracers]
   index_uk, *ops_uk = in_unknowns
@@ -937,6 +983,7 @@ def _cond_typecheck(bind_time, *in_atoms, branches, **params):
 cond_p.def_impl(partial(dispatch.apply_primitive, cond_p))
 cond_p.def_effectful_abstract_eval(_cond_abstract_eval)
 ad.primitive_jvps[cond_p] = _cond_jvp
+ad.primitive_linearizations[cond_p] = _cond_lin
 ad.primitive_transposes[cond_p] = _cond_transpose
 pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval
 batching.fancy_primitive_batchers[cond_p] = _cond_batching_rule