How to resolve pytrees issues when using a Flax optimizer with a Haiku model? #997

Answered by marcvanzee
marcvanzee asked this question in Q&A
Answer by @jheek:

This is actually a Haiku bug. The FlatMap implementation breaks the assumptions of pytrees.

Here is an example that doesn't use Flax at all:

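import jax
# Assumes `params` is the FlatMap returned by a Haiku init call.
double_params = jax.tree_util.tree_map(lambda x: (x, x), params)
treedef = jax.tree_util.tree_structure(params)
treedef.flatten_up_to(double_params)  # fails with FlatMap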

flatten_up_to simply doesn't work with the pytree implementation of FlatMap.
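For comparison, here is a minimal sketch of what the same three calls do on a well-behaved pytree. The nested dict is a hypothetical stand-in for model parameters (it is not a Haiku FlatMap), so this only illustrates the behaviour that pytree-based code such as an optimizer relies on.

```python
import jax

# Hypothetical stand-in for model parameters: a plain nested dict.
params = {"linear": {"b": 0.0, "w": 1.0}}

# Replace every leaf with a (leaf, leaf) tuple; the outer structure is unchanged.
double_params = jax.tree_util.tree_map(lambda x: (x, x), params)

# The treedef of the original params describes only the outer dict structure.
treedef = jax.tree_util.tree_structure(params)

# Flatten only as deep as `treedef` goes: one subtree per original leaf.
subtrees = treedef.flatten_up_to(double_params)
print(subtrees)
```

With a plain dict, each entry of `subtrees` is the tuple that replaced the corresponding leaf.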

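As a workaround, you can flatten the representation manually with flat_params, treedef = jax.tree_util.tree_flatten(params), keep the flat list of leaves as the optimizer target, and call treedef.unflatten(flat_params) to restore the original structure before passing the parameters to Haiku.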

So your code would look like this:

...
def loss(flat_params):
  # Use the argument (not optimizer.target) so gradients flow to the input.
  params = treedef.unflatten(flat_params)
  return f.apply(params, None, 1)

...

params = f.init(
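The snippet above is cut off, so here is a self-contained sketch of the same flatten/unflatten pattern rather than a reconstruction of the original code. Everything in it is an assumption made for illustration: a plain nested dict stands in for the Haiku parameters, `apply_fn` is a hypothetical replacement for `f.apply`, and the squared-sum loss is arbitrary.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for Haiku parameters: a plain nested dict.
params = {"linear": {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))}}

# Flatten once: the leaves become a plain list, the treedef keeps the structure.
flat_params, treedef = jax.tree_util.tree_flatten(params)

def apply_fn(params, x):
    # Hypothetical model standing in for f.apply.
    return x @ params["linear"]["w"] + params["linear"]["b"]

def loss(flat_params, x):
    # Rebuild the nested structure just before calling the model.
    params = treedef.unflatten(flat_params)
    return jnp.sum(apply_fn(params, x) ** 2)

x = jnp.ones((4, 3))

# Gradients are taken with respect to the flat list, so they are a flat list too.
grads = jax.grad(loss)(flat_params, x)

# The original structure can be recovered at any time.
restored = treedef.unflatten(flat_params)
```

Because the optimizer would only ever see `flat_params`, a plain Python list, it never has to traverse the problematic container type; the nested structure exists only inside `loss`.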

marcvanzee
Feb 5, 2021
Maintainer Author

Answer selected by marcvanzee
jheek Mar 8, 2021
Maintainer
