How to move train_state object to CPU? #1783
Building on Flax's MNIST example, the [...] This allocates memory on the GPU. I tried moving this memory buffer to the CPU by calling: [...] and also by calling: [...] The former didn't release the memory on the GPU, and the latter gave [...]
Replies: 1 comment 1 reply
The easiest way to do this is to simply create the train_state on the CPU to begin with, in this case by using:
model = jax.jit(train.create_train_state, static_argnums=(1,), backend="cpu")(init_rng, config)
By the way, JAX is currently missing a way to control device placement without using jit; see jax-ml/jax#8879.
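
For readers on a JAX version that provides `jax.default_device` and `jax.device_put`, here is a minimal sketch of two other ways to end up with CPU-resident state. It is an illustration under stated assumptions, not the method from the MNIST example: `create_params` is a hypothetical stand-in for `train.create_train_state`, and the parameter shapes are made up.

```python
import jax
import jax.numpy as jnp

# The CPU backend is always available, even on GPU/TPU machines.
cpu = jax.devices("cpu")[0]


def create_params(rng):
    # Hypothetical stand-in for train.create_train_state: returns a small pytree.
    w_key, _ = jax.random.split(rng)
    return {"w": jax.random.normal(w_key, (4, 2)), "b": jnp.zeros((2,))}


# Option 1: create the state on the CPU from the start.
# Arrays created inside this context are placed on `cpu` by default.
with jax.default_device(cpu):
    params_cpu = create_params(jax.random.PRNGKey(0))

# Option 2: copy an existing pytree to the CPU.
params = create_params(jax.random.PRNGKey(0))
params_moved = jax.device_put(params, cpu)
# device_put makes a copy; the original buffers can only be freed
# once nothing references them any more.
del params
```

Option 2 may explain why a copy alone appears not to release GPU memory: as long as the original object (or anything derived from it) is still referenced, its buffers stay allocated on the accelerator.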