-
After saving and then restoring a checkpoint, I end up with NumPy arrays instead of JAX arrays in the model parameters. What am I doing wrong? Small reproducer:

```python
import jax
from jax import numpy as jnp
from flax import linen, optim
from flax.linen import compact, Dense
from flax.training import checkpoints

class MyModel(linen.Module):
    @compact
    def __call__(self, xs):
        return Dense(1)(xs)

model = MyModel()
params = model.init(jax.random.PRNGKey(0), jnp.zeros([0, 10]))
optimizer = optim.Adam().create(params)
checkpoints.save_checkpoint('./test', optimizer, 2)
optimizer = checkpoints.restore_checkpoint('./test', optimizer)
print(type(optimizer.target['params']['Dense_0']['kernel']))
```
-
This happened to me, too. I don't know if I was just doing something wrong, or if this is intentional, but to get my model working again, I just converted the pytree of parameters back to JAX arrays:

```python
optimizer = jax.tree_map(jnp.asarray, optimizer)
```
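For concreteness, a minimal sketch continuing the reproducer from the question (the `optimizer` variable and the `'./test'` path are the names used there):

```python
# Restore the checkpoint, then convert every NumPy leaf of the pytree
# back to a jax.numpy array.
restored = checkpoints.restore_checkpoint('./test', optimizer)
restored = jax.tree_map(jnp.asarray, restored)

# The kernel leaf is a JAX array again.
print(type(restored.target['params']['Dense_0']['kernel']))
```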
-
Checkpoints will indeed restore to NumPy values. Normally this shouldn't matter because JAX APIs accept both NumPy and JAX arrays as input. If you do want to prefetch the data to a device, we cannot automatically guess which device that should be. Typically you would use jax.device_put or jax.device_put_replicated, depending on whether you want to do single-device or data-parallel training.
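For example, a minimal sketch (not from the original reply; the restored `optimizer` comes from the question's reproducer, and the device choices are assumptions):

```python
import jax

restored = checkpoints.restore_checkpoint('./test', optimizer)

# Single-device training: put the whole restored pytree on one device.
restored_single = jax.device_put(restored, jax.devices()[0])

# Data-parallel training: replicate the pytree across all local devices,
# e.g. before passing it to a jax.pmap-ed train step.
restored_replicated = jax.device_put_replicated(restored, jax.local_devices())
```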