How to resolve pytrees issues when using a Flax optimizer with a Haiku model? #997
-
Original question by @j5b: Does a Flax `optim` optimizer work with a Haiku model (and its parameter representation)? I'm running into pytree issues, and I'm wondering if there's some incompatibility that I didn't consider. Here's code reproducing the issue:

```python
import haiku as hk
import jax
from jax import numpy as jnp
from flax import optim


class SomeModule(hk.Module):
    def __call__(self, x):
        w = hk.get_parameter("w", [], init=jnp.zeros)
        return x + w


def loss_fn(params):
    return f.apply(params, None, 1)


f = hk.transform(lambda x: SomeModule()(x))

# Get params from Haiku.
params = f.init(None, 1)

# Create a Flax optimizer from the Haiku parameters.
optimizer = optim.Adam(learning_rate=1.0).create(params)
grad = jax.grad(loss_fn)(optimizer.target)
optimizer.apply_gradient(grad)
```

This gives the following error:
-
Answer by @jheek: This is actually a Haiku bug. The `FlatMap` implementation breaks the assumptions of pytrees. Here is an example that doesn't use Flax at all:

```python
double_params = jax.tree_map(lambda x: (x, x), params)
treedef = jax.tree_structure(params)
treedef.flatten_up_to(double_params)
```

`flatten_up_to` simply doesn't work with the pytree implementation of `FlatMap`.
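For comparison, the same operations succeed on an ordinary Python dict. The sketch below uses a hypothetical hand-built dict (not the real Haiku params) just to illustrate that the failure is specific to how `FlatMap` is registered as a pytree:

```python
import jax

# Hypothetical plain-dict stand-in for the Haiku params.
plain_params = {"some_module": {"w": 0.0}}
double_plain = jax.tree_map(lambda x: (x, x), plain_params)
plain_treedef = jax.tree_structure(plain_params)
print(plain_treedef.flatten_up_to(double_plain))  # [(0.0, 0.0)] -- no error
```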
You can flatten the representation manually by using `flat_params, treedef = jax.tree_flatten(params)` and then use `treedef.unflatten(flat_params)` before passing the parameters to Haiku. So your code would look like this:

```python
...
def loss(flat_params):
    params = treedef.unflatten(flat_params)
    return f.apply(params, None, 1)
...
params = f.init(None, 1)
flat_params, treedef = jax.tree_flatten(params)
optimizer_def = optim.Adam(learning_rate=1.0)
optimizer = optimizer_def.create(flat_params)
```
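Putting it together, here is a fuller sketch of the workaround, assuming the same `SomeModule`, `f`, and `optim.Adam` setup as in the question: the optimizer only ever sees the flat list of leaves, and the parameters are unflattened back into Haiku's structure right before `apply`.

```python
# Assumes haiku, jax, and flax.optim are imported and f is the transformed
# model from the question above.
params = f.init(None, 1)
flat_params, treedef = jax.tree_flatten(params)

def loss(flat_params):
    # Rebuild Haiku's parameter structure only when calling the model.
    params = treedef.unflatten(flat_params)
    return f.apply(params, None, 1)

optimizer_def = optim.Adam(learning_rate=1.0)
optimizer = optimizer_def.create(flat_params)

# One update step: gradients are computed w.r.t. the flat representation,
# so flax.optim never has to traverse the FlatMap pytree.
grad = jax.grad(loss)(optimizer.target)
optimizer = optimizer.apply_gradient(grad)
```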
-
I recently stumbled upon that same issue.