Numpy arrays in model parameters after restoring checkpoint file #1199

Answered by jheek
PgLoLo asked this question in Q&A

Checkpoints do indeed restore to NumPy arrays. Normally this shouldn't matter, because JAX APIs accept both NumPy and JAX arrays as input. If you do want to prefetch the data to a device, we cannot automatically guess which device that should be. Typically you would use jax.device_put or jax.device_put_replicated, depending on whether you are doing single-device or data-parallel training.
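
As a rough illustration of the single-device case, here is a minimal sketch. The `restored_params` dict is a hypothetical stand-in for whatever your checkpoint restore call returns (it is built by hand here, not loaded from a file), and `apply` is a made-up toy function, not part of Flax.

```python
import jax
import numpy as np

# Hypothetical stand-in for a restored checkpoint: a pytree whose leaves
# are NumPy arrays (built by hand here, not loaded from disk).
restored_params = {
    "dense": {
        "kernel": np.ones((3, 2), dtype=np.float32),
        "bias": np.zeros((2,), dtype=np.float32),
    }
}


def apply(params, x):
    # Toy linear layer, used only to show that NumPy leaves are accepted.
    return x @ params["dense"]["kernel"] + params["dense"]["bias"]


x = np.ones((4, 3), dtype=np.float32)

# JAX functions accept the NumPy pytree directly.
y = jax.jit(apply)(restored_params, x)

# To place the parameters on a device explicitly, pick the device yourself
# and pass the whole pytree to jax.device_put.
device = jax.devices()[0]
params_on_device = jax.device_put(restored_params, device)

# Every leaf is now a JAX array instead of a NumPy array.
leaves = jax.tree_util.tree_leaves(params_on_device)
all_jax = all(isinstance(leaf, jax.Array) for leaf in leaves)
```

For data-parallel training, the call mentioned above, jax.device_put_replicated(restored_params, jax.local_devices()), would take the place of the jax.device_put line. It is left out of the runnable sketch, so check that your JAX version still provides it.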

Answer selected by marcvanzee