diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py
index 2d743bf06c6b..2a056d5c94f0 100644
--- a/jax/_src/ad_checkpoint.py
+++ b/jax/_src/ad_checkpoint.py
@@ -578,7 +578,7 @@ def remat_partial_eval(trace: pe.JaxprTrace, *tracers: core.Tracer,
   out_jaxpr_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(x.aval), None)
                        for x in jaxpr_unknown.outvars]
   new_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True)
-  recipe = pe.new_eqn_recipe(in_jaxpr_tracers, out_jaxpr_tracers, remat_p,
+  recipe = pe.new_eqn_recipe(trace, in_jaxpr_tracers, out_jaxpr_tracers, remat_p,
                              new_params, jaxpr_unknown.effects,
                              source_info_util.current())
diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py
index e931a6edb9b3..3490de5118e1 100644
--- a/jax/_src/debugging.py
+++ b/jax/_src/debugging.py
@@ -127,7 +127,7 @@ def debug_callback_jvp_rule(primals, tangents, **params):
 ad.primitive_jvps[debug_callback_p] = debug_callback_jvp_rule

 def debug_callback_transpose_rule(*flat_args, callback: Callable[..., Any],
-                                  effect: DebugEffect):
+                                  effect: DebugEffect, partitioned):
   del flat_args, callback, effect
   raise ValueError("Transpose doesn't support debugging callbacks.")
 ad.primitive_transposes[debug_callback_p] = debug_callback_transpose_rule
diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py
index 9366b91f8022..7cbdfff01462 100644
--- a/jax/_src/interpreters/ad.py
+++ b/jax/_src/interpreters/ad.py
@@ -885,7 +885,7 @@ def make_zero(aval):
   out_nz_tracers = [trace.to_jaxpr_tracer(r)
                     for (r, nz) in zip(out_tangents, out_nzs) if nz]
   in_tracers = [t for t, nz in zip(tangent_args, nonzeros) if nz]
-  jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, jvp.debug_info)
+  jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, [], jvp.debug_info)
   jaxpr, used_consts, _ = pe.dce_jaxpr_consts(
       jaxpr, [True] * len(jaxpr.outvars),
       [False] * len(jaxpr.constvars) + [True] * len(jaxpr.invars))
diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py
index f77db5443a86..6ea16ec8e8ba 100644
--- a/jax/_src/interpreters/partial_eval.py
+++ b/jax/_src/interpreters/partial_eval.py
@@ -16,6 +16,7 @@
 from collections import namedtuple
 from collections.abc import Callable, Sequence, Hashable
 import contextlib
+from dataclasses import dataclass
 from functools import partial
 import itertools as it
 import operator as op
@@ -42,7 +43,7 @@
     mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
     InputType, OutputType, get_referent, JaxprEqnContext)
 from jax._src.source_info_util import SourceInfo
-from jax._src.state.types import AbstractRef, ReadEffect
+from jax._src.state.types import AbstractRef, ReadEffect, RefEffect
 from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_flatten,
                                 tree_structure, register_static)
 from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
@@ -147,6 +148,10 @@ def get_aval(self) -> AbstractValue:
     else:
       return self[0]

+@dataclass(frozen=True)
+class EffectHandle:
+  parents : list[Tracer]
+  recipe : JaxprEqnRecipe

 class JaxprTrace(Trace['JaxprTracer']):
@@ -156,6 +161,8 @@ def __init__(self, parent_trace:Trace, name_stack: source_info_util.NameStack, t
     self.tag = tag
     self.parent_trace = parent_trace
     self.requires_low = False
+    self.effect_handles : list[EffectHandle] = []
+    self.counter = it.count()

   def to_jaxpr_tracer(self, x):
     if isinstance(x, JaxprTracer) and x._trace.tag is self.tag:
@@ -239,14 +246,19 @@ def default_process_primitive(self, primitive, tracers, params):
     if primitive.multiple_results:
       out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None)
                      for aval in out_aval]
-      eqn = new_eqn_recipe(tracers, out_tracers, primitive, params, effects,
+      eqn = new_eqn_recipe(self, tracers, out_tracers, primitive, params, effects,
                            source)
+      if any(isinstance(e, RefEffect) for e in effects):
+        self.effect_handles.append(EffectHandle(tracers, eqn))
       for t in out_tracers: t.recipe = eqn
       return out_tracers
     else:
       out_tracer = JaxprTracer(self, PartialVal.unknown(out_aval), None)
-      out_tracer.recipe = new_eqn_recipe(tracers, [out_tracer], primitive,
-                                         params, effects, source)
+      eqn = new_eqn_recipe(self, tracers, [out_tracer], primitive,
+                           params, effects, source)
+      if any(isinstance(e, RefEffect) for e in effects):
+        self.effect_handles.append(EffectHandle(tracers, eqn))
+      out_tracer.recipe = eqn
       return out_tracer

   def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
@@ -321,7 +333,7 @@ def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
                    for a in out_type]
     name_stack = self._current_truncated_name_stack()
     source = source_info_util.current().replace(name_stack=name_stack)
-    eqn = new_eqn_recipe((*res_tracers, *env_tracers, *unknown_arg_tracers),
+    eqn = new_eqn_recipe(self, (*res_tracers, *env_tracers, *unknown_arg_tracers),
                          out_tracers, primitive, staged_params, jaxpr.effects,
                          source)
     for t in out_tracers: t.recipe = eqn
@@ -390,7 +402,7 @@ def const_out_axes_thunk():
                    for a in out_avals]
     effs = core.filter_named_axis_effects(jaxpr.effects, {params['axis_name']})
     src_info = source_info_util.current()
-    eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers),
+    eqn = new_eqn_recipe(self, (*const_tracers, *env_tracers, *unknown_arg_tracers),
                          out_tracers, primitive, staged_params, effs, src_info)
     for t in out_tracers: t.recipe = eqn
@@ -425,7 +437,7 @@ def process_custom_transpose(self, prim, call, tracers, **params):
                    for aval in params['out_types']]
     in_tracers = map(self.instantiate_const, tracers)
     new_params = dict(params, call=call)
-    eqn = new_eqn_recipe(in_tracers, out_tracers, prim, new_params,
+    eqn = new_eqn_recipe(self, in_tracers, out_tracers, prim, new_params,
                          core.no_effects, source_info_util.current())
     for t in out_tracers: t.recipe = eqn
     return out_tracers
@@ -470,7 +482,7 @@ def fwd_jaxpr_thunk(*zeros):
         out_trees=out_trees,
         symbolic_zeros=symbolic_zeros
     )
-    eqn = new_eqn_recipe((*res_tracers, *env_tracers, *tracers),
+    eqn = new_eqn_recipe(self, (*res_tracers, *env_tracers, *tracers),
                          out_tracers, prim, params, jaxpr.effects, source)
     for t in out_tracers: t.recipe = eqn
     return out_tracers
@@ -657,7 +669,7 @@ def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace,
   out_tracers = [trace.instantiate_const(t) if inst else t
                  for inst, t in zip(instantiate, out_tracers)]
   out_tracers_ = [t for t in out_tracers if not t.is_known()]
-  jaxpr, out_consts, env = tracers_to_jaxpr(in_tracers, out_tracers_, debug_info)
+  jaxpr, out_consts, env = tracers_to_jaxpr(in_tracers, out_tracers_, trace.effect_handles, debug_info)
   return out_tracers, jaxpr, out_consts, env

 # The below variant implements an optimization where residuals which are also
@@ -739,7 +751,8 @@ class JaxprEqnRecipe(NamedTuple):
   source_info: source_info_util.SourceInfo
   ctx: JaxprEqnContext

-def new_eqn_recipe(in_tracers: Sequence[JaxprTracer],
+def new_eqn_recipe(trace: JaxprTrace,
+                   in_tracers: Sequence[JaxprTracer],
                    out_tracers: Sequence[JaxprTracer],
                    primitive: Primitive,
                    params: dict[str, Any],
@@ -762,7 +775,7 @@ def new_eqn_recipe(in_tracers: Sequence[JaxprTracer],
       config.threefry_partitionable.value,
       xla_metadata_lib.current_xla_metadata(),
   )
-  return JaxprEqnRecipe(object(), tuple(in_tracers), map(ref, out_tracers),
+  return JaxprEqnRecipe(next(trace.counter), tuple(in_tracers), map(ref, out_tracers),
                         out_avals, primitive, params, effects, source_info,
                         ctx)
@@ -780,6 +793,7 @@ def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom],

 def tracers_to_jaxpr(
   in_tracers: Sequence[JaxprTracer],
   out_tracers: Sequence[JaxprTracer],
+  effect_handles: Sequence[Any],
   debug_info: core.DebugInfo,
 ) -> tuple[Jaxpr, tuple[Any, ...], tuple[Any, ...]]:
   """Constructs Jaxpr given tracers for inputs and outputs.
@@ -821,7 +835,15 @@ def type_substitute(aval: AbstractValue) -> AbstractValue:
   processed_eqn_ids = set()
   eqns: list[core.JaxprEqn] = []
-  for t in toposort((*in_tracers, *out_tracers)):
+
+  reachable = toposort
+  tracers = reachable((*in_tracers, *out_tracers, *effect_handles))
+  def sort_key(t):
+    r = t.recipe
+    return r.eqn_id if isinstance(r, JaxprEqnRecipe) else -1
+  tracers = sorted(tracers, key=sort_key)
+
+  for t in tracers:
     r = t.recipe
     if isinstance(r, JaxprEqnRecipe):
       # TODO broadcast_in_dim can create a new tracer, not present in parents
diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py
index 741636c47e31..4e8368341d9f 100644
--- a/jax/_src/lax/control_flow/conditionals.py
+++ b/jax/_src/lax/control_flow/conditionals.py
@@ -617,7 +617,7 @@ def _cond_partial_eval(trace, *tracers, branches, **params):
   name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
   source = source_info_util.current().replace(name_stack=name_stack)
   eqn = pe.new_eqn_recipe(
-      [index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params,
+      trace, [index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params,
       core.join_effects(*(j.effects for j in branches_unknown)), source)
   for t in out_tracers: t.recipe = eqn
   return util.merge_lists(out_uks, out_consts, out_tracers)
diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py
index fc7ebde4cbea..90b81ae367aa 100644
--- a/jax/_src/lax/control_flow/for_loop.py
+++ b/jax/_src/lax/control_flow/for_loop.py
@@ -498,7 +498,7 @@ def _for_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer,
   assert len(unknown_inputs) == len(res_ref_unknown_outputs)
   assert len(unknown_inputs) == len(jaxpr_unknown.invars) - 1
-  eqn = pe.new_eqn_recipe(unknown_inputs, res_ref_unknown_outputs,
+  eqn = pe.new_eqn_recipe(trace, unknown_inputs, res_ref_unknown_outputs,
                           for_p, dict(jaxpr=jaxpr_unknown, nsteps=nsteps,
                                       reverse=reverse,
                                       which_linear=which_linear_unknown,
diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py
index 7efe3294fdca..83c31928d7cb 100644
--- a/jax/_src/lax/control_flow/loops.py
+++ b/jax/_src/lax/control_flow/loops.py
@@ -920,7 +920,7 @@ def _scan_partial_eval(trace, *tracers, reverse: bool,
   name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
   source = source_info_util.current().replace(name_stack=name_stack)
   assert len(out_tracers) == len(jaxpr_unknown.out_avals)
-  eqn = pe.new_eqn_recipe([*intensive_res, *unknown_inputs, *extensive_res],
+  eqn = pe.new_eqn_recipe(trace, [*intensive_res, *unknown_inputs, *extensive_res],
                           out_tracers, scan_p,
                           dict(reverse=reverse, length=length, unroll=unroll,
                                jaxpr=jaxpr_unknown, linear=linear_unknown,
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index a49c27d06eee..0e3695ba4506 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -6550,7 +6550,7 @@ def _broadcast_in_dim_partial_eval(
   out_aval = core.DShapedArray(tuple(shape_), operand.dtype, operand.weak_type)
   out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None)
   eqn = pe.new_eqn_recipe(
-      [operand_tracer, *dyn_shape_tracers], [out_tracer], broadcast_in_dim_p,
+      trace, [operand_tracer, *dyn_shape_tracers], [out_tracer], broadcast_in_dim_p,
       dict(shape=shape, broadcast_dimensions=broadcast_dimensions,
            sharding=None),
       core.no_effects, source_info_util.current())
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 0c55f3fe30ab..d5286be8e0c9 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -2324,18 +2324,8 @@ def _pjit_partial_eval(trace: pe.JaxprTrace,
   known_ins = tuple(pv.is_known() for pv in in_pvals)
   unknown_ins = tuple(not k for k in known_ins)
-  if any(isinstance(e, (RefEffect, core.InternalMutableArrayEffect))
-         for e in jaxpr.effects):
-    known_jaxpr_, unknown_jaxpr_, unknown_outs, _, num_res_val, num_res_ref = \
-        pe.partial_eval_jaxpr_stateful(jaxpr.jaxpr, unknown_ins, unknown_ins,
-                                       False, False, None)
-    if num_res_ref: raise NotImplementedError
-    known_jaxpr = pe.ClosedJaxpr(known_jaxpr_, jaxpr.consts)
-    unknown_jaxpr = pe.ClosedJaxpr(unknown_jaxpr_, jaxpr.consts)
-    res_avals = unknown_jaxpr.in_avals[:num_res_val]
-  else:
-    known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \
-        pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False)
+  known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \
+      pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False)
   unknown_outs = tuple(unknown_outs)  # type: ignore[assignment]
   known_outs = tuple(not uk for uk in unknown_outs)
   num_residuals = len(res_avals)
@@ -2431,7 +2421,7 @@ def keep_where(l, should_keep):
       pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
       for aval in unknown_out_avals
   ]
-  eqn = pe.new_eqn_recipe((*unknown_tracers_in, *residual_tracers),
+  eqn = pe.new_eqn_recipe(trace, (*unknown_tracers_in, *residual_tracers),
                           unknown_tracers_out,
                           pjit_p,
                           unknown_params,
diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py
index 66df2505100c..4f60a833429a 100644
--- a/jax/_src/shard_map.py
+++ b/jax/_src/shard_map.py
@@ -1369,7 +1369,7 @@ def known_out_specs():
     out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
                    for a in out_avals]
     effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
-    eqn = pe.new_eqn_recipe((*const_tracers, *env_tracers, *unk_arg_tracers),
+    eqn = pe.new_eqn_recipe(trace, (*const_tracers, *env_tracers, *unk_arg_tracers),
                             out_tracers, shard_map_p, unk_params,
                             effs, source_info_util.current())
     for t in out_tracers: t.recipe = eqn
diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py
index bc6a20a0a76e..100447f12d18 100644
--- a/jax/_src/state/discharge.py
+++ b/jax/_src/state/discharge.py
@@ -828,7 +828,7 @@ def _run_state_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer,
                   is_initialized=(True,) * len(jaxpr_unknown.invars))
   _, eqn_effects = run_state_p.abstract_eval(*[v.aval for v in unknown_inputs],
                                              **uk_params)
-  eqn = pe.new_eqn_recipe(unknown_inputs, res_ref_unknown_outputs,
+  eqn = pe.new_eqn_recipe(trace, unknown_inputs, res_ref_unknown_outputs,
                           run_state_p, uk_params,
                           eqn_effects, source)
   for t in res_ref_unknown_outputs: t.recipe = eqn
diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py
index 865d4f8520f1..0da335e2fac5 100644
--- a/tests/mutable_array_test.py
+++ b/tests/mutable_array_test.py
@@ -192,14 +192,18 @@ def f():
     x = f()
     self.assertArraysEqual(x, jnp.zeros(8))

-  def test_grad_mutable_array(self):
-    @jax.jit
+  @parameterized.parameters([False, True])
+  def test_grad_mutable_array(self, jit):
+
     def f(x):
       x_ = core.mutable_array(x)
       x_[()] = x_[()] + x_[()]
       y = core.freeze(x_)
       return y

+    if jit:
+      f = jax.jit(f)
+
     ans = jax.grad(f)(1.)
     expected = 2.0
     self.assertAllClose(ans, expected, check_dtypes=False)
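For reference, a minimal usage sketch of the behavior the parameterized test above exercises: differentiating through a function that creates, mutates, and freezes a mutable array, both with and without jax.jit. This is an illustration of the tested behavior, not part of the patch; it assumes mutable_array and freeze are reachable via jax._src.core, as in the test file.

import jax
from jax._src import core

def f(x):
  x_ = core.mutable_array(x)     # ref initialized to x
  x_[()] = x_[()] + x_[()]       # in-place update: the ref now holds 2*x
  return core.freeze(x_)         # read the ref back out as an ordinary value

print(jax.grad(f)(1.))           # 2.0
print(jax.grad(jax.jit(f))(1.))  # 2.0, the jitted case covered by this change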