Why does Flax serialization convert tuples in pytree to dicts, and how can I prevent this? #2043

Answered by marcvanzee
marcvanzee asked this question in Q&A
Answer by @levskaya:

  1. Serialization normalizes everything to nested dicts instead of tuples, so that there is always a unique "path" of names to any particular leaf object.
  2. capture_intermediates returns a tuple in general because a layer can be called multiple times, so there may be several intermediates to report for it.
  3. This manual code should resolve the problem by unwrapping single-element tuples:
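import jax

def check_tuple(x):
  # True only for single-element tuples, e.g. (array,).
  return isinstance(x, tuple) and len(x) == 1

def maybe_remove_tuple(x):
  # Unwrap a single-element tuple; leave everything else unchanged.
  return x[0] if check_tuple(x) else x

def remove_leaf_tuples(tree):
  # is_leaf stops traversal at single-element tuples, so each one is passed
  # whole to maybe_remove_tuple. jax.tree_util.tree_map is used here because
  # the jax.tree_map alias is deprecated.
  return jax.tree_util.tree_map(maybe_remove_tuple,
                                tree,
                                is_leaf=check_tuple)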
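
To make the effect of the helper concrete, here is a minimal usage sketch. The `intermediates` dict is a hypothetical, hand-written stand-in for what capture_intermediates might return (the layer names and float values are made up for illustration, not taken from a real model), and `jax.tree_util.tree_map` is used in place of `jax.tree_map`.

```python
import jax


def check_tuple(x):
    # True only for single-element tuples.
    return isinstance(x, tuple) and len(x) == 1


def maybe_remove_tuple(x):
    # Unwrap a single-element tuple; pass anything else through unchanged.
    return x[0] if check_tuple(x) else x


def remove_leaf_tuples(tree):
    # Treat single-element tuples as leaves so they reach maybe_remove_tuple whole.
    return jax.tree_util.tree_map(maybe_remove_tuple, tree, is_leaf=check_tuple)


# Hypothetical tree shaped like a capture_intermediates result:
# "Dense_0" was called once, "Dense_1" was called twice.
intermediates = {
    "Dense_0": {"__call__": (1.0,)},
    "Dense_1": {"__call__": (2.0, 3.0)},
}

cleaned = remove_leaf_tuples(intermediates)
print(cleaned)
```

Note that only single-element tuples are unwrapped. In this sketch the two-element tuple under "Dense_1" is left as a tuple, so it would presumably still be normalized to a dict on serialization, which matches point 2 above: a layer called more than once has several intermediates to report.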

Answer selected by marcvanzee