Skip to content

Commit 9191843

Browse files
authored
Merge pull request jax-ml#2654 from google/pfix
fix jaxpr invar avals
2 parents f37f235 + 5f1f29e commit 9191843

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

jax/interpreters/partial_eval.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -446,8 +446,8 @@ def tracers_to_jaxpr(in_tracers, out_tracers):
446446
def getvar(t):
447447
var = t_to_var.get(id(t))
448448
if var is None:
449-
var = newvar(partial_val_aval(*t.pval))
450-
t_to_var[id(t)] = var
449+
aval = t.pval[0] if t.pval[0] is not None else abstract_unit
450+
var = t_to_var[id(t)] = newvar(aval)
451451
return var
452452
sorted_tracers = toposort(out_tracers)
453453
invars = map(getvar, in_tracers)
@@ -458,8 +458,7 @@ def getvar(t):
458458
def getconstvar(c):
459459
var = const_to_var.get(id(c))
460460
if var is None:
461-
var = newvar(get_aval(c))
462-
const_to_var[id(c)] = var
461+
var = const_to_var[id(c)] = newvar(get_aval(c))
463462
return var
464463
processed_eqn_ids = set()
465464
for t in sorted_tracers:

0 commit comments

Comments
 (0)