Hi everyone, I'm reading the Flax documentation on converting PyTorch models to Flax and tried testing some code myself. However, I observed an inconsistency between Flax nn.Dense and PyTorch torch.nn.Linear: the outputs differ by up to about 1e-3. This confuses me, because the two layers are mathematically equivalent, so the numerical difference should not be this big. Are there any details I'm missing here? I'd like to understand this issue. Any input will be appreciated!

```
AssertionError:
Arrays are not almost equal to 7 decimals

Mismatched elements: 1024 / 1024 (100%)
Max absolute difference: 0.00105295
Max relative difference: 0.39149174
 x: array([[-0.5895869, -1.6013966, -0.3044982, ...,  0.691458 , -0.2171268,
        -0.3255446],
       [ 1.3270106,  1.0085726,  0.5121558, ...,  0.4522566,  2.0366738,
        -1.3362962]], dtype=float32)
 y: array([[-0.5898346, -1.6013664, -0.3044372, ...,  0.6914209, -0.2169722,
        -0.3249587],
       [ 1.3268726,  1.0090327,  0.5124962, ...,  0.4526754,  2.0368207,
        -1.3363303]], dtype=float32)
```

The following is the minimal code to reproduce it:

```python
import torch
import jax
import jax.numpy as jnp
import numpy as onp
from jax import random
import flax.linen as nn

B, Cin, Cout = 2, 768, 512
key1, key2 = random.split(random.PRNGKey(0))
x = onp.random.normal(size=(B, Cin)).astype(onp.float32)  # input

dense = nn.Dense(Cout, name='dense')  # Flax module
params = dense.init(key1, jnp.ones((B, Cin)))
torch_dense = torch.nn.Linear(Cin, Cout)  # PyTorch module

# Copy the Flax parameters into the PyTorch layer.
# Flax stores the kernel as (Cin, Cout); PyTorch stores the weight as (Cout, Cin).
state_dict = jax.device_get(params)
with torch.no_grad():
    torch_dense.weight.copy_(torch.from_numpy(state_dict['params']['kernel'].T))
    torch_dense.bias.copy_(torch.from_numpy(state_dict['params']['bias']))

jax_x = jnp.asarray(x)
torch_x = torch.from_numpy(x)

# Run the same input through both layers.
dense_out = dense.apply(params, jax_x)
with torch.no_grad():
    torch_out = torch_dense(torch_x).numpy()

onp.testing.assert_almost_equal(jax.device_get(dense_out), torch_out)
```

Environment:
Answered by chiamp, Oct 10, 2023
Related issue: #3128
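A commonly cited cause of float32 mismatches of this size (around 1e-3) between JAX and PyTorch is matmul precision. On GPU and TPU backends, JAX's default precision for float32 matrix multiplication may use reduced-precision passes (TF32 on recent NVIDIA GPUs, bfloat16 on TPU), while recent PyTorch releases keep TF32 off for matmuls by default. The environment above is not shown, so treating this as the cause here is an assumption, not a confirmed diagnosis. The sketch below uses made-up shapes and random weights, and plain jax.numpy instead of Flax. It computes the same dense layer (x @ kernel + bias, which is what torch.nn.Linear computes as x @ weight.T + bias) at Precision.DEFAULT and Precision.HIGHEST and compares both against a float64 NumPy reference. On CPU both errors should be tiny; on a GPU or TPU the DEFAULT error may be much larger.

```python
import numpy as np
import jax
import jax.numpy as jnp

B, Cin, Cout = 2, 768, 512
rng = np.random.default_rng(0)
x = rng.normal(size=(B, Cin)).astype(np.float32)

# Flax-style kernel layout (Cin, Cout); a torch.nn.Linear weight would be its transpose.
kernel = (rng.normal(size=(Cin, Cout)) / np.sqrt(Cin)).astype(np.float32)
bias = np.zeros(Cout, dtype=np.float32)


def dense(x, kernel, bias, precision):
    # Same computation as a dense layer: x @ kernel + bias, with explicit matmul precision.
    return jnp.dot(x, kernel, precision=precision) + bias


# Float64 reference computed by NumPy, used as "ground truth".
reference = x.astype(np.float64) @ kernel.astype(np.float64) + bias

out_default = np.asarray(dense(x, kernel, bias, jax.lax.Precision.DEFAULT))
out_highest = np.asarray(dense(x, kernel, bias, jax.lax.Precision.HIGHEST))

err_default = np.max(np.abs(out_default - reference))
err_highest = np.max(np.abs(out_highest - reference))

print("backend:", jax.default_backend())
print(f"max |error| with Precision.DEFAULT: {err_default:.2e}")
print(f"max |error| with Precision.HIGHEST: {err_highest:.2e}")
```

If precision turns out to be the cause, full float32 precision can be requested per layer with nn.Dense(Cout, precision=jax.lax.Precision.HIGHEST), or globally with jax.config.update("jax_default_matmul_precision", "float32"). Separately, assert_almost_equal with its default decimal=7 requires every absolute difference to be below 1.5e-7, which two correct float32 implementations summing 768 products in different orders may still exceed; onp.testing.assert_allclose with an rtol around 1e-5 is a more realistic check.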
Answer selected by devzhk