
How to move train_state object to CPU? #1783

Answered by jheek
mattiasmar asked this question in Q&A

The easiest way to do this is to create the `train_state` on the CPU in the first place. In this case, that means calling `model = jax.jit(train.create_train_state, static_argnums=(1,), backend="cpu")(init_rng, config)`.
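
A minimal sketch of the same idea, assuming a recent JAX release in which the `backend` argument to `jax.jit` has been deprecated. Here the `jax.default_device` context manager is swapped in to make the CPU the default placement. `create_train_state` is a hypothetical stand-in for `train.create_train_state`, and the static integer `3` plays the role of `config`.

```python
import jax
import jax.numpy as jnp


def create_train_state(rng, num_features):
    # Hypothetical stand-in for train.create_train_state: builds a small
    # pytree of parameters plus a step counter.
    w = jax.random.normal(rng, (num_features, 1))
    b = jnp.zeros((1,))
    return {"params": {"w": w, "b": b}, "step": jnp.array(0)}


cpu = jax.devices("cpu")[0]

# Inside this context, newly created arrays and jitted computations with
# uncommitted inputs are placed on the CPU instead of the default accelerator.
with jax.default_device(cpu):
    init_rng = jax.random.PRNGKey(0)
    # num_features is static because it determines array shapes.
    state = jax.jit(create_train_state, static_argnums=(1,))(init_rng, 3)
```

Every leaf of `state` should then report the CPU device via `leaf.devices()`, without any explicit copy from an accelerator.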

By the way, JAX currently lacks a way to control device placement without using `jit`; see jax-ml/jax#8879.
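
If the state already exists on an accelerator, a sketch of moving it explicitly is below, assuming a JAX version where `jax.device_put` accepts a pytree and a target device. The dict `state` is a hypothetical placeholder; a Flax `TrainState` is also a pytree, so the same call should apply to it.

```python
import jax
import jax.numpy as jnp

# Hypothetical train state; a Flax TrainState is also a pytree and should be
# handled the same way.
state = {"params": {"w": jnp.ones((3, 1)), "b": jnp.zeros((1,))}, "step": jnp.array(0)}

cpu = jax.devices("cpu")[0]

# device_put maps over every leaf of the pytree and copies it to the CPU
# device, returning a new pytree of jax.Array values committed to that device.
cpu_state = jax.device_put(state, cpu)

# Alternatively, device_get copies every leaf to host memory as NumPy arrays.
host_state = jax.device_get(state)
```

`device_put` keeps the leaves as JAX arrays (usable in further JAX computations on the CPU), while `device_get` returns plain NumPy arrays, which is convenient for checkpointing or inspection.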


Answer selected by mattiasmar