Porting PyTorch layer norm to Flax #2197

Answered by marcvanzee
gadgetsam asked this question in Q&A

Flax uses shape inference, so you do not have to pass the normalized shape to the LayerNorm constructor. Because you do pass an argument (12), it is interpreted as the first positional parameter, epsilon, which is then set to 12. It is also a good idea to use a manual seed so you get reproducible runs. Finally, I'd recommend np.testing.assert_allclose, since it gives you more information (the maximum absolute and relative differences) if your outputs don't match.

Putting this together gives the following code:

import flax
import torch
import jax.numpy as jnp
import numpy as np

# Seed PyTorch so parameter initialization and inputs are reproducible.
torch.manual_seed(0)

torch_layernorm = torch.nn.LayerNorm(12)
flax_layernorm = flax.linen.LayerNorm()  # shape is inferred from the input
torch_state_dict = torch_layernorm.state_dict()
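
The snippet above is truncated. As a rough sketch of how the comparison could be completed (the parameter mapping reflects Flax's LayerNorm parameter names, but the input shape and tolerances here are illustrative assumptions, not part of the original answer):

# Copy the PyTorch parameters into Flax's naming scheme:
# PyTorch's weight corresponds to Flax's scale; bias maps to bias.
flax_params = {
    'scale': jnp.asarray(torch_state_dict['weight'].numpy()),
    'bias': jnp.asarray(torch_state_dict['bias'].numpy()),
}

# Run the same input through both modules and compare.
x = torch.randn(2, 12)
torch_out = torch_layernorm(x).detach().numpy()
flax_out = flax_layernorm.apply({'params': flax_params}, jnp.asarray(x.numpy()))

np.testing.assert_allclose(torch_out, flax_out, rtol=1e-5, atol=1e-5)

Note that the epsilon defaults differ slightly (PyTorch's LayerNorm uses eps=1e-5, Flax's uses epsilon=1e-6), so for an exact match you can construct the Flax module as flax.linen.LayerNorm(epsilon=1e-5); otherwise a small tolerance like the one above is needed.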

This discussion was converted from issue #2195 on June 14, 2022 07:53.