We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
2 parents f37f235 + 5f1f29e commit 9191843Copy full SHA for 9191843
jax/interpreters/partial_eval.py
@@ -446,8 +446,8 @@ def tracers_to_jaxpr(in_tracers, out_tracers):
446
def getvar(t):
447
var = t_to_var.get(id(t))
448
if var is None:
449
- var = newvar(partial_val_aval(*t.pval))
450
- t_to_var[id(t)] = var
+ aval = t.pval[0] if t.pval[0] is not None else abstract_unit
+ var = t_to_var[id(t)] = newvar(aval)
451
return var
452
sorted_tracers = toposort(out_tracers)
453
invars = map(getvar, in_tracers)
@@ -458,8 +458,7 @@ def getvar(t):
458
def getconstvar(c):
459
var = const_to_var.get(id(c))
460
461
- var = newvar(get_aval(c))
462
- const_to_var[id(c)] = var
+ var = const_to_var[id(c)] = newvar(get_aval(c))
463
464
processed_eqn_ids = set()
465
for t in sorted_tracers:
0 commit comments