Skip to content

Commit 93af51c

Browse files
committed
Avoid dead-code elimination (DCE) of effectful ops and reordering of equations in partial eval.
1 parent 0e690c1 commit 93af51c

File tree

11 files changed

+46
-22
lines changed

11 files changed

+46
-22
lines changed

jax/_src/ad_checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,7 @@ def remat_partial_eval(trace: pe.JaxprTrace, *tracers: core.Tracer,
578578
out_jaxpr_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(x.aval), None)
579579
for x in jaxpr_unknown.outvars]
580580
new_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True)
581-
recipe = pe.new_eqn_recipe(in_jaxpr_tracers, out_jaxpr_tracers, remat_p,
581+
recipe = pe.new_eqn_recipe(trace, in_jaxpr_tracers, out_jaxpr_tracers, remat_p,
582582
new_params, jaxpr_unknown.effects,
583583
source_info_util.current())
584584

jax/_src/interpreters/ad.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -885,7 +885,7 @@ def make_zero(aval):
885885
out_nz_tracers = [trace.to_jaxpr_tracer(r)
886886
for (r, nz) in zip(out_tangents, out_nzs) if nz]
887887
in_tracers = [t for t, nz in zip(tangent_args, nonzeros) if nz]
888-
jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, jvp.debug_info)
888+
jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, [], jvp.debug_info)
889889
jaxpr, used_consts, _ = pe.dce_jaxpr_consts(
890890
jaxpr, [True] * len(jaxpr.outvars),
891891
[False] * len(jaxpr.constvars) + [True] * len(jaxpr.invars))

jax/_src/interpreters/partial_eval.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from collections import namedtuple
1717
from collections.abc import Callable, Sequence, Hashable
1818
import contextlib
19+
from dataclasses import dataclass
1920
from functools import partial
2021
import itertools as it
2122
import operator as op
@@ -147,6 +148,10 @@ def get_aval(self) -> AbstractValue:
147148
else:
148149
return self[0]
149150

151+
@dataclass(frozen=True)
152+
class EffectHandle:
153+
parents : list[Tracer]
154+
recipe : JaxprEqnRecipe
150155

151156
class JaxprTrace(Trace['JaxprTracer']):
152157

@@ -156,6 +161,8 @@ def __init__(self, parent_trace:Trace, name_stack: source_info_util.NameStack, t
156161
self.tag = tag
157162
self.parent_trace = parent_trace
158163
self.requires_low = False
164+
self.effect_handles : list[EffectHandle] = []
165+
self.counter = it.count()
159166

160167
def to_jaxpr_tracer(self, x):
161168
if isinstance(x, JaxprTracer) and x._trace.tag is self.tag:
@@ -239,14 +246,17 @@ def default_process_primitive(self, primitive, tracers, params):
239246
if primitive.multiple_results:
240247
out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None)
241248
for aval in out_aval]
242-
eqn = new_eqn_recipe(tracers, out_tracers, primitive, params, effects,
249+
eqn = new_eqn_recipe(self, tracers, out_tracers, primitive, params, effects,
243250
source)
251+
if effects: self.effect_handles.append(EffectHandle(tracers, eqn))
244252
for t in out_tracers: t.recipe = eqn
245253
return out_tracers
246254
else:
247255
out_tracer = JaxprTracer(self, PartialVal.unknown(out_aval), None)
248-
out_tracer.recipe = new_eqn_recipe(tracers, [out_tracer], primitive,
249-
params, effects, source)
256+
eqn = new_eqn_recipe(self, tracers, [out_tracer], primitive,
257+
params, effects, source)
258+
if effects: self.effect_handles.append(EffectHandle(tracers, eqn))
259+
out_tracer.recipe = eqn
250260
return out_tracer
251261

252262
def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
@@ -321,7 +331,7 @@ def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
321331
for a in out_type]
322332
name_stack = self._current_truncated_name_stack()
323333
source = source_info_util.current().replace(name_stack=name_stack)
324-
eqn = new_eqn_recipe((*res_tracers, *env_tracers, *unknown_arg_tracers),
334+
eqn = new_eqn_recipe(self, (*res_tracers, *env_tracers, *unknown_arg_tracers),
325335
out_tracers, primitive, staged_params, jaxpr.effects,
326336
source)
327337
for t in out_tracers: t.recipe = eqn
@@ -390,7 +400,7 @@ def const_out_axes_thunk():
390400
for a in out_avals]
391401
effs = core.filter_named_axis_effects(jaxpr.effects, {params['axis_name']})
392402
src_info = source_info_util.current()
393-
eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers),
403+
eqn = new_eqn_recipe(self, (*const_tracers, *env_tracers, *unknown_arg_tracers),
394404
out_tracers, primitive, staged_params, effs, src_info)
395405
for t in out_tracers: t.recipe = eqn
396406

@@ -425,7 +435,7 @@ def process_custom_transpose(self, prim, call, tracers, **params):
425435
for aval in params['out_types']]
426436
in_tracers = map(self.instantiate_const, tracers)
427437
new_params = dict(params, call=call)
428-
eqn = new_eqn_recipe(in_tracers, out_tracers, prim, new_params,
438+
eqn = new_eqn_recipe(self, in_tracers, out_tracers, prim, new_params,
429439
core.no_effects, source_info_util.current())
430440
for t in out_tracers: t.recipe = eqn
431441
return out_tracers
@@ -470,7 +480,7 @@ def fwd_jaxpr_thunk(*zeros):
470480
out_trees=out_trees,
471481
symbolic_zeros=symbolic_zeros
472482
)
473-
eqn = new_eqn_recipe((*res_tracers, *env_tracers, *tracers),
483+
eqn = new_eqn_recipe(self, (*res_tracers, *env_tracers, *tracers),
474484
out_tracers, prim, params, jaxpr.effects, source)
475485
for t in out_tracers: t.recipe = eqn
476486
return out_tracers
@@ -657,7 +667,7 @@ def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace,
657667
out_tracers = [trace.instantiate_const(t) if inst else t
658668
for inst, t in zip(instantiate, out_tracers)]
659669
out_tracers_ = [t for t in out_tracers if not t.is_known()]
660-
jaxpr, out_consts, env = tracers_to_jaxpr(in_tracers, out_tracers_, debug_info)
670+
jaxpr, out_consts, env = tracers_to_jaxpr(in_tracers, out_tracers_, trace.effect_handles, debug_info)
661671
return out_tracers, jaxpr, out_consts, env
662672

663673
# The below variant implements an optimization where residuals which are also
@@ -739,7 +749,8 @@ class JaxprEqnRecipe(NamedTuple):
739749
source_info: source_info_util.SourceInfo
740750
ctx: JaxprEqnContext
741751

742-
def new_eqn_recipe(in_tracers: Sequence[JaxprTracer],
752+
def new_eqn_recipe(trace: JaxprTrace,
753+
in_tracers: Sequence[JaxprTracer],
743754
out_tracers: Sequence[JaxprTracer],
744755
primitive: Primitive,
745756
params: dict[str, Any],
@@ -762,7 +773,7 @@ def new_eqn_recipe(in_tracers: Sequence[JaxprTracer],
762773
config.threefry_partitionable.value,
763774
xla_metadata_lib.current_xla_metadata(),
764775
)
765-
return JaxprEqnRecipe(object(), tuple(in_tracers), map(ref, out_tracers),
776+
return JaxprEqnRecipe(next(trace.counter), tuple(in_tracers), map(ref, out_tracers),
766777
out_avals, primitive, params, effects, source_info,
767778
ctx)
768779

@@ -780,6 +791,7 @@ def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom],
780791
def tracers_to_jaxpr(
781792
in_tracers: Sequence[JaxprTracer],
782793
out_tracers: Sequence[JaxprTracer],
794+
effect_handles: Sequence[Any],
783795
debug_info: core.DebugInfo,
784796
) -> tuple[Jaxpr, tuple[Any, ...], tuple[Any, ...]]:
785797
"""Constructs Jaxpr given tracers for inputs and outputs.
@@ -821,7 +833,15 @@ def type_substitute(aval: AbstractValue) -> AbstractValue:
821833

822834
processed_eqn_ids = set()
823835
eqns: list[core.JaxprEqn] = []
824-
for t in toposort((*in_tracers, *out_tracers)):
836+
837+
reachable = toposort
838+
tracers = reachable((*in_tracers, *out_tracers, *effect_handles))
839+
def sort_key(t):
840+
r = t.recipe
841+
return r.eqn_id if isinstance(r, JaxprEqnRecipe) else -1
842+
tracers = sorted(tracers, key=sort_key)
843+
844+
for t in tracers:
825845
r = t.recipe
826846
if isinstance(r, JaxprEqnRecipe):
827847
# TODO broadcast_in_dim can create a new tracer, not present in parents

jax/_src/lax/control_flow/conditionals.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,7 @@ def _cond_partial_eval(trace, *tracers, branches, **params):
617617
name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
618618
source = source_info_util.current().replace(name_stack=name_stack)
619619
eqn = pe.new_eqn_recipe(
620-
[index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params,
620+
trace, [index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params,
621621
core.join_effects(*(j.effects for j in branches_unknown)), source)
622622
for t in out_tracers: t.recipe = eqn
623623
return util.merge_lists(out_uks, out_consts, out_tracers)

jax/_src/lax/control_flow/for_loop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ def _for_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer,
498498

499499
assert len(unknown_inputs) == len(res_ref_unknown_outputs)
500500
assert len(unknown_inputs) == len(jaxpr_unknown.invars) - 1
501-
eqn = pe.new_eqn_recipe(unknown_inputs, res_ref_unknown_outputs,
501+
eqn = pe.new_eqn_recipe(trace, unknown_inputs, res_ref_unknown_outputs,
502502
for_p, dict(jaxpr=jaxpr_unknown, nsteps=nsteps,
503503
reverse=reverse,
504504
which_linear=which_linear_unknown,

jax/_src/lax/control_flow/loops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -920,7 +920,7 @@ def _scan_partial_eval(trace, *tracers, reverse: bool,
920920
name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
921921
source = source_info_util.current().replace(name_stack=name_stack)
922922
assert len(out_tracers) == len(jaxpr_unknown.out_avals)
923-
eqn = pe.new_eqn_recipe([*intensive_res, *unknown_inputs, *extensive_res],
923+
eqn = pe.new_eqn_recipe(trace, [*intensive_res, *unknown_inputs, *extensive_res],
924924
out_tracers, scan_p,
925925
dict(reverse=reverse, length=length, unroll=unroll,
926926
jaxpr=jaxpr_unknown, linear=linear_unknown,

jax/_src/lax/lax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6550,7 +6550,7 @@ def _broadcast_in_dim_partial_eval(
65506550
out_aval = core.DShapedArray(tuple(shape_), operand.dtype, operand.weak_type)
65516551
out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None)
65526552
eqn = pe.new_eqn_recipe(
6553-
[operand_tracer, *dyn_shape_tracers], [out_tracer], broadcast_in_dim_p,
6553+
trace, [operand_tracer, *dyn_shape_tracers], [out_tracer], broadcast_in_dim_p,
65546554
dict(shape=shape, broadcast_dimensions=broadcast_dimensions,
65556555
sharding=None),
65566556
core.no_effects, source_info_util.current())

jax/_src/pjit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2431,7 +2431,7 @@ def keep_where(l, should_keep):
24312431
pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
24322432
for aval in unknown_out_avals
24332433
]
2434-
eqn = pe.new_eqn_recipe((*unknown_tracers_in, *residual_tracers),
2434+
eqn = pe.new_eqn_recipe(trace, (*unknown_tracers_in, *residual_tracers),
24352435
unknown_tracers_out,
24362436
pjit_p,
24372437
unknown_params,

jax/_src/shard_map.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1369,7 +1369,7 @@ def known_out_specs():
13691369
out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
13701370
for a in out_avals]
13711371
effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
1372-
eqn = pe.new_eqn_recipe((*const_tracers, *env_tracers, *unk_arg_tracers),
1372+
eqn = pe.new_eqn_recipe(trace, (*const_tracers, *env_tracers, *unk_arg_tracers),
13731373
out_tracers, shard_map_p, unk_params,
13741374
effs, source_info_util.current())
13751375
for t in out_tracers: t.recipe = eqn

jax/_src/state/discharge.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -828,7 +828,7 @@ def _run_state_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer,
828828
is_initialized=(True,) * len(jaxpr_unknown.invars))
829829
_, eqn_effects = run_state_p.abstract_eval(*[v.aval for v in unknown_inputs],
830830
**uk_params)
831-
eqn = pe.new_eqn_recipe(unknown_inputs, res_ref_unknown_outputs,
831+
eqn = pe.new_eqn_recipe(trace, unknown_inputs, res_ref_unknown_outputs,
832832
run_state_p, uk_params,
833833
eqn_effects, source)
834834
for t in res_ref_unknown_outputs: t.recipe = eqn

tests/mutable_array_test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,14 +192,18 @@ def f():
192192
x = f()
193193
self.assertArraysEqual(x, jnp.zeros(8))
194194

195-
def test_grad_mutable_array(self):
196-
@jax.jit
195+
@parameterized.parameters([False, True])
196+
def test_grad_mutable_array(self, jit):
197+
197198
def f(x):
198199
x_ = core.mutable_array(x)
199200
x_[()] = x_[()] + x_[()]
200201
y = core.freeze(x_)
201202
return y
202203

204+
if jit:
205+
f = jax.jit(f)
206+
203207
ans = jax.grad(f)(1.)
204208
expected = 2.0
205209
self.assertAllClose(ans, expected, check_dtypes=False)

0 commit comments

Comments
 (0)