16
16
from collections import namedtuple
17
17
from collections .abc import Callable , Sequence , Hashable
18
18
import contextlib
19
+ from dataclasses import dataclass
19
20
from functools import partial
20
21
import itertools as it
21
22
import operator as op
42
43
mapped_aval , unmapped_aval , DBIdx , InDBIdx , OutDBIdx ,
43
44
InputType , OutputType , get_referent , JaxprEqnContext )
44
45
from jax ._src .source_info_util import SourceInfo
45
- from jax ._src .state .types import AbstractRef , ReadEffect
46
+ from jax ._src .state .types import AbstractRef , ReadEffect , RefEffect
46
47
from jax ._src .tree_util import (PyTreeDef , treedef_tuple , tree_flatten ,
47
48
tree_structure , register_static )
48
49
from jax ._src .util import (unzip2 , safe_zip , safe_map , toposort , split_list ,
@@ -147,6 +148,10 @@ def get_aval(self) -> AbstractValue:
147
148
else :
148
149
return self [0 ]
149
150
151
+ @dataclass (frozen = True )
152
+ class EffectHandle :
153
+ parents : list [Tracer ]
154
+ recipe : JaxprEqnRecipe
150
155
151
156
class JaxprTrace (Trace ['JaxprTracer' ]):
152
157
@@ -156,6 +161,8 @@ def __init__(self, parent_trace:Trace, name_stack: source_info_util.NameStack, t
156
161
self .tag = tag
157
162
self .parent_trace = parent_trace
158
163
self .requires_low = False
164
+ self .effect_handles : list [EffectHandle ] = []
165
+ self .counter = it .count ()
159
166
160
167
def to_jaxpr_tracer (self , x ):
161
168
if isinstance (x , JaxprTracer ) and x ._trace .tag is self .tag :
@@ -239,14 +246,19 @@ def default_process_primitive(self, primitive, tracers, params):
239
246
if primitive .multiple_results :
240
247
out_tracers = [JaxprTracer (self , PartialVal .unknown (aval ), None )
241
248
for aval in out_aval ]
242
- eqn = new_eqn_recipe (tracers , out_tracers , primitive , params , effects ,
249
+ eqn = new_eqn_recipe (self , tracers , out_tracers , primitive , params , effects ,
243
250
source )
251
+ if any (isinstance (e , RefEffect ) for e in effects ):
252
+ self .effect_handles .append (EffectHandle (tracers , eqn ))
244
253
for t in out_tracers : t .recipe = eqn
245
254
return out_tracers
246
255
else :
247
256
out_tracer = JaxprTracer (self , PartialVal .unknown (out_aval ), None )
248
- out_tracer .recipe = new_eqn_recipe (tracers , [out_tracer ], primitive ,
249
- params , effects , source )
257
+ eqn = new_eqn_recipe (self , tracers , [out_tracer ], primitive ,
258
+ params , effects , source )
259
+ if any (isinstance (e , RefEffect ) for e in effects ):
260
+ self .effect_handles .append (EffectHandle (tracers , eqn ))
261
+ out_tracer .recipe = eqn
250
262
return out_tracer
251
263
252
264
def process_call (self , primitive , f : lu .WrappedFun , tracers , params ):
@@ -321,7 +333,7 @@ def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
321
333
for a in out_type ]
322
334
name_stack = self ._current_truncated_name_stack ()
323
335
source = source_info_util .current ().replace (name_stack = name_stack )
324
- eqn = new_eqn_recipe ((* res_tracers , * env_tracers , * unknown_arg_tracers ),
336
+ eqn = new_eqn_recipe (self , (* res_tracers , * env_tracers , * unknown_arg_tracers ),
325
337
out_tracers , primitive , staged_params , jaxpr .effects ,
326
338
source )
327
339
for t in out_tracers : t .recipe = eqn
@@ -390,7 +402,7 @@ def const_out_axes_thunk():
390
402
for a in out_avals ]
391
403
effs = core .filter_named_axis_effects (jaxpr .effects , {params ['axis_name' ]})
392
404
src_info = source_info_util .current ()
393
- eqn = new_eqn_recipe ((* const_tracers , * env_tracers , * unknown_arg_tracers ),
405
+ eqn = new_eqn_recipe (self , (* const_tracers , * env_tracers , * unknown_arg_tracers ),
394
406
out_tracers , primitive , staged_params , effs , src_info )
395
407
for t in out_tracers : t .recipe = eqn
396
408
@@ -425,7 +437,7 @@ def process_custom_transpose(self, prim, call, tracers, **params):
425
437
for aval in params ['out_types' ]]
426
438
in_tracers = map (self .instantiate_const , tracers )
427
439
new_params = dict (params , call = call )
428
- eqn = new_eqn_recipe (in_tracers , out_tracers , prim , new_params ,
440
+ eqn = new_eqn_recipe (self , in_tracers , out_tracers , prim , new_params ,
429
441
core .no_effects , source_info_util .current ())
430
442
for t in out_tracers : t .recipe = eqn
431
443
return out_tracers
@@ -470,7 +482,7 @@ def fwd_jaxpr_thunk(*zeros):
470
482
out_trees = out_trees ,
471
483
symbolic_zeros = symbolic_zeros
472
484
)
473
- eqn = new_eqn_recipe ((* res_tracers , * env_tracers , * tracers ),
485
+ eqn = new_eqn_recipe (self , (* res_tracers , * env_tracers , * tracers ),
474
486
out_tracers , prim , params , jaxpr .effects , source )
475
487
for t in out_tracers : t .recipe = eqn
476
488
return out_tracers
@@ -657,7 +669,7 @@ def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace,
657
669
out_tracers = [trace .instantiate_const (t ) if inst else t
658
670
for inst , t in zip (instantiate , out_tracers )]
659
671
out_tracers_ = [t for t in out_tracers if not t .is_known ()]
660
- jaxpr , out_consts , env = tracers_to_jaxpr (in_tracers , out_tracers_ , debug_info )
672
+ jaxpr , out_consts , env = tracers_to_jaxpr (in_tracers , out_tracers_ , trace . effect_handles , debug_info )
661
673
return out_tracers , jaxpr , out_consts , env
662
674
663
675
# The below variant implements an optimization where residuals which are also
@@ -739,7 +751,8 @@ class JaxprEqnRecipe(NamedTuple):
739
751
source_info : source_info_util .SourceInfo
740
752
ctx : JaxprEqnContext
741
753
742
- def new_eqn_recipe (in_tracers : Sequence [JaxprTracer ],
754
+ def new_eqn_recipe (trace : JaxprTrace ,
755
+ in_tracers : Sequence [JaxprTracer ],
743
756
out_tracers : Sequence [JaxprTracer ],
744
757
primitive : Primitive ,
745
758
params : dict [str , Any ],
@@ -762,7 +775,7 @@ def new_eqn_recipe(in_tracers: Sequence[JaxprTracer],
762
775
config .threefry_partitionable .value ,
763
776
xla_metadata_lib .current_xla_metadata (),
764
777
)
765
- return JaxprEqnRecipe (object ( ), tuple (in_tracers ), map (ref , out_tracers ),
778
+ return JaxprEqnRecipe (next ( trace . counter ), tuple (in_tracers ), map (ref , out_tracers ),
766
779
out_avals , primitive , params , effects , source_info ,
767
780
ctx )
768
781
@@ -780,6 +793,7 @@ def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom],
780
793
def tracers_to_jaxpr (
781
794
in_tracers : Sequence [JaxprTracer ],
782
795
out_tracers : Sequence [JaxprTracer ],
796
+ effect_handles : Sequence [Any ],
783
797
debug_info : core .DebugInfo ,
784
798
) -> tuple [Jaxpr , tuple [Any , ...], tuple [Any , ...]]:
785
799
"""Constructs Jaxpr given tracers for inputs and outputs.
@@ -821,7 +835,15 @@ def type_substitute(aval: AbstractValue) -> AbstractValue:
821
835
822
836
processed_eqn_ids = set ()
823
837
eqns : list [core .JaxprEqn ] = []
824
- for t in toposort ((* in_tracers , * out_tracers )):
838
+
839
+ reachable = toposort
840
+ tracers = reachable ((* in_tracers , * out_tracers , * effect_handles ))
841
+ def sort_key (t ):
842
+ r = t .recipe
843
+ return r .eqn_id if isinstance (r , JaxprEqnRecipe ) else - 1
844
+ tracers = sorted (tracers , key = sort_key )
845
+
846
+ for t in tracers :
825
847
r = t .recipe
826
848
if isinstance (r , JaxprEqnRecipe ):
827
849
# TODO broadcast_in_dim can create a new tracer, not present in parents
0 commit comments