Porting PyTorch layer norm to Flax #2197
-
I'm trying to port the layer norm module from PyTorch to Flax. I transformed the state dict from PyTorch to Flax, yet the layer norm is still not producing the same results. I tried this on macOS 12.4 (M1), Ubuntu 20.04, and Google Colab, and in all of them the outputs aren't equal. I created a minimal example:
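Roughly, a sketch of the conversion described (assuming the normalized shape 12 is also passed to the Flax constructor, which is what the reply below points out):

import flax
import torch
import jax.numpy as jnp
import numpy as np

torch_layernorm = torch.nn.LayerNorm(12)
flax_layernorm = flax.linen.LayerNorm(12)  # passing the shape here turns out to be the problem (see reply)
state_dict = torch_layernorm.state_dict()
params = {
    "scale": jnp.array(np.array(state_dict["weight"])),  # PyTorch "weight" becomes Flax "scale"
    "bias": jnp.array(np.array(state_dict["bias"])),
}
x = torch.randn((8, 12))
torch_out = torch_layernorm(x)
flax_out = flax_layernorm.apply({"params": params}, jnp.array(np.array(x)))
np.testing.assert_allclose(torch_out.detach().numpy(), flax_out, rtol=1e-5)  # fails: outputs differ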
I also created a Colab notebook: https://colab.research.google.com/drive/1wTIbWbM9LBjzlKC14aegLjWHrkoXdapF#scrollTo=hv0tS0BJ2_EE Thanks in advance for the help.
-
Flax uses shape inference, so you do not have to provide the normalized shape as an input to the constructor for LayerNorm. Because you do provide an argument (12), this means epsilon will be set to 12. Also, it is a good idea to use a manual seed so you get reproducible runs. Finally, I'd recommend using np.testing.assert_allclose since it will give you more information if your outputs don't match (absolute and relative tolerances). Putting this together gives the following code:
import flax
import torch
import jax.numpy as jnp
import numpy as np
torch.manual_seed(0)
torch_layernorm = torch.nn.LayerNorm(12)
flax_layernorm = flax.linen.LayerNorm()
torch_state_dict = torch_layernorm.state_dict()
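# PyTorch names the affine parameters "weight" and "bias"; Flax's LayerNorm expects "scale" and "bias".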
torch_state_dict["scale"] = jnp.array(np.array(torch_state_dict.pop("weight")))
torch_state_dict["bias"] = jnp.array(np.array(torch_state_dict.pop("bias")))
x = torch.randn((8, 12))
x_flax = jnp.array(np.array(x))
torch_out = torch_layernorm(x)
flax_out = flax_layernorm.apply(variables={"params": torch_state_dict}, x=x_flax)
np.testing.assert_allclose(torch_out.detach().numpy(), flax_out, rtol=1e-5)
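As a quick sanity check of the epsilon point above (assuming a recent flax.linen version, where epsilon is the first constructor argument):

print(flax.linen.LayerNorm(12).epsilon)  # 12: the positional argument lands in epsilon
print(flax.linen.LayerNorm().epsilon)    # the intended small default (1e-6 at the time of writing)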
-
Hi, I've been trying to do the same (porting pretrained ViT-based PyTorch models to Flax) and am facing a similar issue. I noticed the following (the code below fails the assertion) and also this issue. Is there a way to enforce PyTorch-style LayerNorm computation?
import flax
import torch
import jax.numpy as jnp
import numpy as np
torch.manual_seed(0)
torch_layernorm = torch.nn.LayerNorm(768)
flax_layernorm = flax.linen.LayerNorm(use_fast_variance=False)
torch_state_dict = torch_layernorm.state_dict()
torch_state_dict["scale"] = jnp.array(np.array(torch_state_dict.pop("weight")))
torch_state_dict["bias"] = jnp.array(np.array(torch_state_dict.pop("bias")))
x = torch.randn((1, 197, 768))
x_flax = jnp.array(np.array(x))
torch_out = torch_layernorm(x)
flax_out = flax_layernorm.apply(variables={"params": torch_state_dict}, x=x_flax)
np.testing.assert_almost_equal(torch_out.detach().numpy(), flax_out, decimal=5)
While the error here is low, it seems to accumulate and produce larger errors in later outputs. Thanks a lot in advance for any help!
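For reference, a hand-rolled JAX version that mirrors PyTorch's formula (biased variance of the centered values, eps added inside the square root) can serve as a ground truth when debugging. This is a sketch rather than Flax API, the function name is made up, and note that PyTorch defaults to eps=1e-5 while Flax's LayerNorm defaults, if I recall correctly, to 1e-6:

def torch_style_layer_norm(x, scale, bias, eps=1e-5):
    # Normalize over the last axis, the same way torch.nn.LayerNorm(d) does.
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)  # biased variance
    return (x - mean) / jnp.sqrt(var + eps) * scale + bias

# Reusing the arrays defined above:
manual_out = torch_style_layer_norm(x_flax, torch_state_dict["scale"], torch_state_dict["bias"])
np.testing.assert_almost_equal(torch_out.detach().numpy(), manual_out, decimal=5)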