Why does Flax serialization convert tuples in pytrees to dicts, and how can I prevent this? #2043
marcvanzee asked this question in Q&A.
(Original question by jmgilmer@) I keep running into bugs in my code because Flax serialization converts tuples in pytrees into dictionaries keyed by the element index ('0', '1', ...).
Specifically, I am calling this function:

```python
import jax
import numpy as np


def array_append(full_array, new_leaf):
  # Helper not shown in the original question; assumed to stack new_leaf
  # onto full_array along the leading axis.
  return np.concatenate([full_array, np.expand_dims(new_leaf, axis=0)], axis=0)


def append_pytree_leaves(full_pytree, to_append):
  """Appends all leaves in the to_append pytree to the full_pytree.

  We assume full_pytree and to_append have the same structure. The leaves of
  full_pytree will have shape (num_previous_appends, *to_append_leaf_shape).
  For example, if full_pytree = {'a': np.ones((2, 10))} and
  to_append = {'a': np.ones(10)}, then append_pytree_leaves(full_pytree,
  to_append) returns {'a': np.ones((3, 10))}. If full_pytree is None, then in
  the above example it returns {'a': np.ones((1, 10))}.

  Args:
    full_pytree: pytree of all previously appended pytrees.
    to_append: pytree with the same structure, whose leaves are appended to
      full_pytree.

  Returns:
    A pytree where the leaves of to_append have been concatenated onto the
    leaves of full_pytree.
  """
  if not full_pytree:
    return jax.tree_util.tree_map(lambda x: np.expand_dims(x, axis=0), to_append)
  # tree_map accepts several trees; jax.tree_multimap was removed from JAX.
  return jax.tree_util.tree_map(lambda x, y: array_append(y, x), to_append,
                                full_pytree)
```

The running state is a pytree with the same structure as the output of `capture_intermediates`, but its leaves are concatenated arrays. This works until I restore the saved pytree, because of the tuple conversion that happens during serialization.
Answered by marcvanzee on Apr 11, 2022
Answer by @levskaya:
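```python
import jax


def check_tuple(x):
  # capture_intermediates stores each call site's outputs in a tuple;
  # match the common case of a layer that was called once.
  return isinstance(x, tuple) and len(x) == 1


def maybe_remove_tuple(x):
  return x[0] if check_tuple(x) else x


def remove_leaf_tuples(tree):
  # is_leaf makes tree_map treat 1-element tuples as leaves, so they are
  # unwrapped instead of traversed.
  return jax.tree_util.tree_map(maybe_remove_tuple, tree, is_leaf=check_tuple)
```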
In the same answer, @levskaya explained why the tuple is there: `capture_intermediates` returns a tuple in general because a layer can be called multiple times, so there may be multiple intermediates to report.