16
16
from collections import namedtuple
17
17
from collections .abc import Callable , Sequence , Hashable
18
18
import contextlib
19
+ from dataclasses import dataclass
19
20
from functools import partial
20
21
import itertools as it
21
22
import operator as op
@@ -147,6 +148,10 @@ def get_aval(self) -> AbstractValue:
147
148
else :
148
149
return self [0 ]
149
150
151
@dataclass(frozen=True, eq=False)
class EffectHandle:
  """Graph node that keeps an effectful equation reachable from the roots.

  A handle is recorded by the trace whenever a staged primitive reports
  effects, and the collected handles are passed to ``tracers_to_jaxpr``
  alongside the input and output tracers. There they act as additional roots
  of the topological sort, so the equation is visited even when none of its
  output tracers is reachable from the requested outputs.

  The handle exposes the same two attributes the traversal reads from a
  tracer: ``parents`` and ``recipe``.

  ``eq=False`` is deliberate. With the dataclass default (``eq=True``) a
  frozen dataclass gets a field-based ``__hash__``, which raises ``TypeError``
  here because ``parents`` is a list, and a field-based ``__eq__`` that
  compares the parent tracers element by element. Each handle is a distinct
  node, so identity-based equality and hashing are what graph algorithms
  need.
  """
  # Input tracers of the effectful equation (its predecessors in the graph).
  parents: list[Tracer]
  # Recipe of the effectful equation this handle stands for.
  recipe: JaxprEqnRecipe
150
155
151
156
class JaxprTrace (Trace ['JaxprTracer' ]):
152
157
@@ -156,6 +161,8 @@ def __init__(self, parent_trace:Trace, name_stack: source_info_util.NameStack, t
156
161
self .tag = tag
157
162
self .parent_trace = parent_trace
158
163
self .requires_low = False
164
+ self .effect_handles : list [EffectHandle ] = []
165
+ self .counter = it .count ()
159
166
160
167
def to_jaxpr_tracer (self , x ):
161
168
if isinstance (x , JaxprTracer ) and x ._trace .tag is self .tag :
@@ -239,14 +246,17 @@ def default_process_primitive(self, primitive, tracers, params):
239
246
if primitive .multiple_results :
240
247
out_tracers = [JaxprTracer (self , PartialVal .unknown (aval ), None )
241
248
for aval in out_aval ]
242
- eqn = new_eqn_recipe (tracers , out_tracers , primitive , params , effects ,
249
+ eqn = new_eqn_recipe (self , tracers , out_tracers , primitive , params , effects ,
243
250
source )
251
+ if effects : self .effect_handles .append (EffectHandle (tracers , eqn ))
244
252
for t in out_tracers : t .recipe = eqn
245
253
return out_tracers
246
254
else :
247
255
out_tracer = JaxprTracer (self , PartialVal .unknown (out_aval ), None )
248
- out_tracer .recipe = new_eqn_recipe (tracers , [out_tracer ], primitive ,
249
- params , effects , source )
256
+ eqn = new_eqn_recipe (self , tracers , [out_tracer ], primitive ,
257
+ params , effects , source )
258
+ if effects : self .effect_handles .append (EffectHandle (tracers , eqn ))
259
+ out_tracer .recipe = eqn
250
260
return out_tracer
251
261
252
262
def process_call (self , primitive , f : lu .WrappedFun , tracers , params ):
@@ -321,7 +331,7 @@ def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
321
331
for a in out_type ]
322
332
name_stack = self ._current_truncated_name_stack ()
323
333
source = source_info_util .current ().replace (name_stack = name_stack )
324
- eqn = new_eqn_recipe ((* res_tracers , * env_tracers , * unknown_arg_tracers ),
334
+ eqn = new_eqn_recipe (self , (* res_tracers , * env_tracers , * unknown_arg_tracers ),
325
335
out_tracers , primitive , staged_params , jaxpr .effects ,
326
336
source )
327
337
for t in out_tracers : t .recipe = eqn
@@ -390,7 +400,7 @@ def const_out_axes_thunk():
390
400
for a in out_avals ]
391
401
effs = core .filter_named_axis_effects (jaxpr .effects , {params ['axis_name' ]})
392
402
src_info = source_info_util .current ()
393
- eqn = new_eqn_recipe ((* const_tracers , * env_tracers , * unknown_arg_tracers ),
403
+ eqn = new_eqn_recipe (self , (* const_tracers , * env_tracers , * unknown_arg_tracers ),
394
404
out_tracers , primitive , staged_params , effs , src_info )
395
405
for t in out_tracers : t .recipe = eqn
396
406
@@ -425,7 +435,7 @@ def process_custom_transpose(self, prim, call, tracers, **params):
425
435
for aval in params ['out_types' ]]
426
436
in_tracers = map (self .instantiate_const , tracers )
427
437
new_params = dict (params , call = call )
428
- eqn = new_eqn_recipe (in_tracers , out_tracers , prim , new_params ,
438
+ eqn = new_eqn_recipe (self , in_tracers , out_tracers , prim , new_params ,
429
439
core .no_effects , source_info_util .current ())
430
440
for t in out_tracers : t .recipe = eqn
431
441
return out_tracers
@@ -470,7 +480,7 @@ def fwd_jaxpr_thunk(*zeros):
470
480
out_trees = out_trees ,
471
481
symbolic_zeros = symbolic_zeros
472
482
)
473
- eqn = new_eqn_recipe ((* res_tracers , * env_tracers , * tracers ),
483
+ eqn = new_eqn_recipe (self , (* res_tracers , * env_tracers , * tracers ),
474
484
out_tracers , prim , params , jaxpr .effects , source )
475
485
for t in out_tracers : t .recipe = eqn
476
486
return out_tracers
@@ -657,7 +667,7 @@ def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace,
657
667
out_tracers = [trace .instantiate_const (t ) if inst else t
658
668
for inst , t in zip (instantiate , out_tracers )]
659
669
out_tracers_ = [t for t in out_tracers if not t .is_known ()]
660
- jaxpr , out_consts , env = tracers_to_jaxpr (in_tracers , out_tracers_ , debug_info )
670
+ jaxpr , out_consts , env = tracers_to_jaxpr (in_tracers , out_tracers_ , trace . effect_handles , debug_info )
661
671
return out_tracers , jaxpr , out_consts , env
662
672
663
673
# The below variant implements an optimization where residuals which are also
@@ -739,7 +749,8 @@ class JaxprEqnRecipe(NamedTuple):
739
749
source_info : source_info_util .SourceInfo
740
750
ctx : JaxprEqnContext
741
751
742
- def new_eqn_recipe (in_tracers : Sequence [JaxprTracer ],
752
+ def new_eqn_recipe (trace : JaxprTrace ,
753
+ in_tracers : Sequence [JaxprTracer ],
743
754
out_tracers : Sequence [JaxprTracer ],
744
755
primitive : Primitive ,
745
756
params : dict [str , Any ],
@@ -762,7 +773,7 @@ def new_eqn_recipe(in_tracers: Sequence[JaxprTracer],
762
773
config .threefry_partitionable .value ,
763
774
xla_metadata_lib .current_xla_metadata (),
764
775
)
765
- return JaxprEqnRecipe (object ( ), tuple (in_tracers ), map (ref , out_tracers ),
776
+ return JaxprEqnRecipe (next ( trace . counter ), tuple (in_tracers ), map (ref , out_tracers ),
766
777
out_avals , primitive , params , effects , source_info ,
767
778
ctx )
768
779
@@ -780,6 +791,7 @@ def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom],
780
791
def tracers_to_jaxpr (
781
792
in_tracers : Sequence [JaxprTracer ],
782
793
out_tracers : Sequence [JaxprTracer ],
794
+ effect_handles : Sequence [Any ],
783
795
debug_info : core .DebugInfo ,
784
796
) -> tuple [Jaxpr , tuple [Any , ...], tuple [Any , ...]]:
785
797
"""Constructs Jaxpr given tracers for inputs and outputs.
@@ -821,7 +833,15 @@ def type_substitute(aval: AbstractValue) -> AbstractValue:
821
833
822
834
processed_eqn_ids = set ()
823
835
eqns : list [core .JaxprEqn ] = []
824
- for t in toposort ((* in_tracers , * out_tracers )):
836
+
837
+ reachable = toposort
838
+ tracers = reachable ((* in_tracers , * out_tracers , * effect_handles ))
839
def sort_key(t):
  """Return the ordering key for a node taken from the topological sort.

  Nodes whose recipe is a ``JaxprEqnRecipe`` are keyed by that recipe's
  ``eqn_id``; every other node gets ``-1``.
  """
  recipe = t.recipe
  if not isinstance(recipe, JaxprEqnRecipe):
    # NOTE(review): presumably eqn ids are non-negative (drawn from the
    # trace's counter), so -1 places recipe-less nodes first — confirm.
    return -1
  return recipe.eqn_id
842
+ tracers = sorted (tracers , key = sort_key )
843
+
844
+ for t in tracers :
825
845
r = t .recipe
826
846
if isinstance (r , JaxprEqnRecipe ):
827
847
# TODO broadcast_in_dim can create a new tracer, not present in parents
0 commit comments