Skip to content

Commit 0f4da0c

Browse files
Merge pull request #28955 from jax-ml:prevent-partial-eval-dce-effects
PiperOrigin-RevId: 763886950
2 parents 1d10a48 + 859e120 commit 0f4da0c

File tree

12 files changed

+52
-36
lines changed

12 files changed

+52
-36
lines changed

jax/_src/ad_checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,7 @@ def remat_partial_eval(trace: pe.JaxprTrace, *tracers: core.Tracer,
578578
out_jaxpr_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(x.aval), None)
579579
for x in jaxpr_unknown.outvars]
580580
new_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True)
581-
recipe = pe.new_eqn_recipe(in_jaxpr_tracers, out_jaxpr_tracers, remat_p,
581+
recipe = pe.new_eqn_recipe(trace, in_jaxpr_tracers, out_jaxpr_tracers, remat_p,
582582
new_params, jaxpr_unknown.effects,
583583
source_info_util.current())
584584

jax/_src/debugging.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def debug_callback_jvp_rule(primals, tangents, **params):
127127
ad.primitive_jvps[debug_callback_p] = debug_callback_jvp_rule
128128

129129
def debug_callback_transpose_rule(*flat_args, callback: Callable[..., Any],
130-
effect: DebugEffect):
130+
effect: DebugEffect, partitioned):
131131
del flat_args, callback, effect
132132
raise ValueError("Transpose doesn't support debugging callbacks.")
133133
ad.primitive_transposes[debug_callback_p] = debug_callback_transpose_rule

jax/_src/interpreters/ad.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -885,7 +885,7 @@ def make_zero(aval):
885885
out_nz_tracers = [trace.to_jaxpr_tracer(r)
886886
for (r, nz) in zip(out_tangents, out_nzs) if nz]
887887
in_tracers = [t for t, nz in zip(tangent_args, nonzeros) if nz]
888-
jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, jvp.debug_info)
888+
jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, [], jvp.debug_info)
889889
jaxpr, used_consts, _ = pe.dce_jaxpr_consts(
890890
jaxpr, [True] * len(jaxpr.outvars),
891891
[False] * len(jaxpr.constvars) + [True] * len(jaxpr.invars))

jax/_src/interpreters/partial_eval.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from collections import namedtuple
1717
from collections.abc import Callable, Sequence, Hashable
1818
import contextlib
19+
from dataclasses import dataclass
1920
from functools import partial
2021
import itertools as it
2122
import operator as op
@@ -42,7 +43,7 @@
4243
mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
4344
InputType, OutputType, get_referent, JaxprEqnContext)
4445
from jax._src.source_info_util import SourceInfo
45-
from jax._src.state.types import AbstractRef, ReadEffect
46+
from jax._src.state.types import AbstractRef, ReadEffect, RefEffect
4647
from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_flatten,
4748
tree_structure, register_static)
4849
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
@@ -147,6 +148,10 @@ def get_aval(self) -> AbstractValue:
147148
else:
148149
return self[0]
149150

151+
@dataclass(frozen=True)
152+
class EffectHandle:
153+
parents : list[Tracer]
154+
recipe : JaxprEqnRecipe
150155

151156
class JaxprTrace(Trace['JaxprTracer']):
152157

@@ -156,6 +161,8 @@ def __init__(self, parent_trace:Trace, name_stack: source_info_util.NameStack, t
156161
self.tag = tag
157162
self.parent_trace = parent_trace
158163
self.requires_low = False
164+
self.effect_handles : list[EffectHandle] = []
165+
self.counter = it.count()
159166

160167
def to_jaxpr_tracer(self, x):
161168
if isinstance(x, JaxprTracer) and x._trace.tag is self.tag:
@@ -239,14 +246,19 @@ def default_process_primitive(self, primitive, tracers, params):
239246
if primitive.multiple_results:
240247
out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None)
241248
for aval in out_aval]
242-
eqn = new_eqn_recipe(tracers, out_tracers, primitive, params, effects,
249+
eqn = new_eqn_recipe(self, tracers, out_tracers, primitive, params, effects,
243250
source)
251+
if any(isinstance(e, RefEffect) for e in effects):
252+
self.effect_handles.append(EffectHandle(tracers, eqn))
244253
for t in out_tracers: t.recipe = eqn
245254
return out_tracers
246255
else:
247256
out_tracer = JaxprTracer(self, PartialVal.unknown(out_aval), None)
248-
out_tracer.recipe = new_eqn_recipe(tracers, [out_tracer], primitive,
249-
params, effects, source)
257+
eqn = new_eqn_recipe(self, tracers, [out_tracer], primitive,
258+
params, effects, source)
259+
if any(isinstance(e, RefEffect) for e in effects):
260+
self.effect_handles.append(EffectHandle(tracers, eqn))
261+
out_tracer.recipe = eqn
250262
return out_tracer
251263

252264
def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
@@ -321,7 +333,7 @@ def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
321333
for a in out_type]
322334
name_stack = self._current_truncated_name_stack()
323335
source = source_info_util.current().replace(name_stack=name_stack)
324-
eqn = new_eqn_recipe((*res_tracers, *env_tracers, *unknown_arg_tracers),
336+
eqn = new_eqn_recipe(self, (*res_tracers, *env_tracers, *unknown_arg_tracers),
325337
out_tracers, primitive, staged_params, jaxpr.effects,
326338
source)
327339
for t in out_tracers: t.recipe = eqn
@@ -390,7 +402,7 @@ def const_out_axes_thunk():
390402
for a in out_avals]
391403
effs = core.filter_named_axis_effects(jaxpr.effects, {params['axis_name']})
392404
src_info = source_info_util.current()
393-
eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers),
405+
eqn = new_eqn_recipe(self, (*const_tracers, *env_tracers, *unknown_arg_tracers),
394406
out_tracers, primitive, staged_params, effs, src_info)
395407
for t in out_tracers: t.recipe = eqn
396408

@@ -425,7 +437,7 @@ def process_custom_transpose(self, prim, call, tracers, **params):
425437
for aval in params['out_types']]
426438
in_tracers = map(self.instantiate_const, tracers)
427439
new_params = dict(params, call=call)
428-
eqn = new_eqn_recipe(in_tracers, out_tracers, prim, new_params,
440+
eqn = new_eqn_recipe(self, in_tracers, out_tracers, prim, new_params,
429441
core.no_effects, source_info_util.current())
430442
for t in out_tracers: t.recipe = eqn
431443
return out_tracers
@@ -470,7 +482,7 @@ def fwd_jaxpr_thunk(*zeros):
470482
out_trees=out_trees,
471483
symbolic_zeros=symbolic_zeros
472484
)
473-
eqn = new_eqn_recipe((*res_tracers, *env_tracers, *tracers),
485+
eqn = new_eqn_recipe(self, (*res_tracers, *env_tracers, *tracers),
474486
out_tracers, prim, params, jaxpr.effects, source)
475487
for t in out_tracers: t.recipe = eqn
476488
return out_tracers
@@ -657,7 +669,7 @@ def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace,
657669
out_tracers = [trace.instantiate_const(t) if inst else t
658670
for inst, t in zip(instantiate, out_tracers)]
659671
out_tracers_ = [t for t in out_tracers if not t.is_known()]
660-
jaxpr, out_consts, env = tracers_to_jaxpr(in_tracers, out_tracers_, debug_info)
672+
jaxpr, out_consts, env = tracers_to_jaxpr(in_tracers, out_tracers_, trace.effect_handles, debug_info)
661673
return out_tracers, jaxpr, out_consts, env
662674

663675
# The below variant implements an optimization where residuals which are also
@@ -739,7 +751,8 @@ class JaxprEqnRecipe(NamedTuple):
739751
source_info: source_info_util.SourceInfo
740752
ctx: JaxprEqnContext
741753

742-
def new_eqn_recipe(in_tracers: Sequence[JaxprTracer],
754+
def new_eqn_recipe(trace: JaxprTrace,
755+
in_tracers: Sequence[JaxprTracer],
743756
out_tracers: Sequence[JaxprTracer],
744757
primitive: Primitive,
745758
params: dict[str, Any],
@@ -762,7 +775,7 @@ def new_eqn_recipe(in_tracers: Sequence[JaxprTracer],
762775
config.threefry_partitionable.value,
763776
xla_metadata_lib.current_xla_metadata(),
764777
)
765-
return JaxprEqnRecipe(object(), tuple(in_tracers), map(ref, out_tracers),
778+
return JaxprEqnRecipe(next(trace.counter), tuple(in_tracers), map(ref, out_tracers),
766779
out_avals, primitive, params, effects, source_info,
767780
ctx)
768781

@@ -780,6 +793,7 @@ def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom],
780793
def tracers_to_jaxpr(
781794
in_tracers: Sequence[JaxprTracer],
782795
out_tracers: Sequence[JaxprTracer],
796+
effect_handles: Sequence[Any],
783797
debug_info: core.DebugInfo,
784798
) -> tuple[Jaxpr, tuple[Any, ...], tuple[Any, ...]]:
785799
"""Constructs Jaxpr given tracers for inputs and outputs.
@@ -821,7 +835,15 @@ def type_substitute(aval: AbstractValue) -> AbstractValue:
821835

822836
processed_eqn_ids = set()
823837
eqns: list[core.JaxprEqn] = []
824-
for t in toposort((*in_tracers, *out_tracers)):
838+
839+
reachable = toposort
840+
tracers = reachable((*in_tracers, *out_tracers, *effect_handles))
841+
def sort_key(t):
842+
r = t.recipe
843+
return r.eqn_id if isinstance(r, JaxprEqnRecipe) else -1
844+
tracers = sorted(tracers, key=sort_key)
845+
846+
for t in tracers:
825847
r = t.recipe
826848
if isinstance(r, JaxprEqnRecipe):
827849
# TODO broadcast_in_dim can create a new tracer, not present in parents

jax/_src/lax/control_flow/conditionals.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,7 @@ def _cond_partial_eval(trace, *tracers, branches, **params):
617617
name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
618618
source = source_info_util.current().replace(name_stack=name_stack)
619619
eqn = pe.new_eqn_recipe(
620-
[index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params,
620+
trace, [index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params,
621621
core.join_effects(*(j.effects for j in branches_unknown)), source)
622622
for t in out_tracers: t.recipe = eqn
623623
return util.merge_lists(out_uks, out_consts, out_tracers)

jax/_src/lax/control_flow/for_loop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ def _for_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer,
498498

499499
assert len(unknown_inputs) == len(res_ref_unknown_outputs)
500500
assert len(unknown_inputs) == len(jaxpr_unknown.invars) - 1
501-
eqn = pe.new_eqn_recipe(unknown_inputs, res_ref_unknown_outputs,
501+
eqn = pe.new_eqn_recipe(trace, unknown_inputs, res_ref_unknown_outputs,
502502
for_p, dict(jaxpr=jaxpr_unknown, nsteps=nsteps,
503503
reverse=reverse,
504504
which_linear=which_linear_unknown,

jax/_src/lax/control_flow/loops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -920,7 +920,7 @@ def _scan_partial_eval(trace, *tracers, reverse: bool,
920920
name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
921921
source = source_info_util.current().replace(name_stack=name_stack)
922922
assert len(out_tracers) == len(jaxpr_unknown.out_avals)
923-
eqn = pe.new_eqn_recipe([*intensive_res, *unknown_inputs, *extensive_res],
923+
eqn = pe.new_eqn_recipe(trace, [*intensive_res, *unknown_inputs, *extensive_res],
924924
out_tracers, scan_p,
925925
dict(reverse=reverse, length=length, unroll=unroll,
926926
jaxpr=jaxpr_unknown, linear=linear_unknown,

jax/_src/lax/lax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6550,7 +6550,7 @@ def _broadcast_in_dim_partial_eval(
65506550
out_aval = core.DShapedArray(tuple(shape_), operand.dtype, operand.weak_type)
65516551
out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None)
65526552
eqn = pe.new_eqn_recipe(
6553-
[operand_tracer, *dyn_shape_tracers], [out_tracer], broadcast_in_dim_p,
6553+
trace, [operand_tracer, *dyn_shape_tracers], [out_tracer], broadcast_in_dim_p,
65546554
dict(shape=shape, broadcast_dimensions=broadcast_dimensions,
65556555
sharding=None),
65566556
core.no_effects, source_info_util.current())

jax/_src/pjit.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2324,18 +2324,8 @@ def _pjit_partial_eval(trace: pe.JaxprTrace,
23242324

23252325
known_ins = tuple(pv.is_known() for pv in in_pvals)
23262326
unknown_ins = tuple(not k for k in known_ins)
2327-
if any(isinstance(e, (RefEffect, core.InternalMutableArrayEffect))
2328-
for e in jaxpr.effects):
2329-
known_jaxpr_, unknown_jaxpr_, unknown_outs, _, num_res_val, num_res_ref = \
2330-
pe.partial_eval_jaxpr_stateful(jaxpr.jaxpr, unknown_ins, unknown_ins,
2331-
False, False, None)
2332-
if num_res_ref: raise NotImplementedError
2333-
known_jaxpr = pe.ClosedJaxpr(known_jaxpr_, jaxpr.consts)
2334-
unknown_jaxpr = pe.ClosedJaxpr(unknown_jaxpr_, jaxpr.consts)
2335-
res_avals = unknown_jaxpr.in_avals[:num_res_val]
2336-
else:
2337-
known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \
2338-
pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False)
2327+
known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \
2328+
pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False)
23392329
unknown_outs = tuple(unknown_outs) # type: ignore[assignment]
23402330
known_outs = tuple(not uk for uk in unknown_outs)
23412331
num_residuals = len(res_avals)
@@ -2431,7 +2421,7 @@ def keep_where(l, should_keep):
24312421
pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
24322422
for aval in unknown_out_avals
24332423
]
2434-
eqn = pe.new_eqn_recipe((*unknown_tracers_in, *residual_tracers),
2424+
eqn = pe.new_eqn_recipe(trace, (*unknown_tracers_in, *residual_tracers),
24352425
unknown_tracers_out,
24362426
pjit_p,
24372427
unknown_params,

jax/_src/shard_map.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1392,7 +1392,7 @@ def known_out_specs():
13921392
out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
13931393
for a in out_avals]
13941394
effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
1395-
eqn = pe.new_eqn_recipe((*const_tracers, *env_tracers, *unk_arg_tracers),
1395+
eqn = pe.new_eqn_recipe(trace, (*const_tracers, *env_tracers, *unk_arg_tracers),
13961396
out_tracers, shard_map_p, unk_params,
13971397
effs, source_info_util.current())
13981398
for t in out_tracers: t.recipe = eqn

jax/_src/state/discharge.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -828,7 +828,7 @@ def _run_state_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer,
828828
is_initialized=(True,) * len(jaxpr_unknown.invars))
829829
_, eqn_effects = run_state_p.abstract_eval(*[v.aval for v in unknown_inputs],
830830
**uk_params)
831-
eqn = pe.new_eqn_recipe(unknown_inputs, res_ref_unknown_outputs,
831+
eqn = pe.new_eqn_recipe(trace, unknown_inputs, res_ref_unknown_outputs,
832832
run_state_p, uk_params,
833833
eqn_effects, source)
834834
for t in res_ref_unknown_outputs: t.recipe = eqn

tests/mutable_array_test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,14 +192,18 @@ def f():
192192
x = f()
193193
self.assertArraysEqual(x, jnp.zeros(8))
194194

195-
def test_grad_mutable_array(self):
196-
@jax.jit
195+
@parameterized.parameters([False, True])
196+
def test_grad_mutable_array(self, jit):
197+
197198
def f(x):
198199
x_ = core.mutable_array(x)
199200
x_[()] = x_[()] + x_[()]
200201
y = core.freeze(x_)
201202
return y
202203

204+
if jit:
205+
f = jax.jit(f)
206+
203207
ans = jax.grad(f)(1.)
204208
expected = 2.0
205209
self.assertAllClose(ans, expected, check_dtypes=False)

0 commit comments

Comments
 (0)