diff --git a/jax/experimental/optimizers.py b/jax/experimental/optimizers.py index d1449329ed18..2b92ae24bb49 100644 --- a/jax/experimental/optimizers.py +++ b/jax/experimental/optimizers.py @@ -89,10 +89,10 @@ # lists (with no further nesting). OptimizerState = namedtuple("OptimizerState", - ["states_flat", "tree_def", "subtree_defs"]) + ["packed_state", "tree_def", "subtree_defs"]) register_pytree_node( OptimizerState, - lambda xs: ((xs.states_flat,), (xs.tree_def, xs.subtree_defs)), + lambda xs: ((xs.packed_state,), (xs.tree_def, xs.subtree_defs)), lambda data, xs: OptimizerState(xs[0], data[0], data[1])) def optimizer(opt_maker):