Flax model too big... OOM memory allocation error #2283

Answered by marcvanzee
rosikand asked this question in Q&A
When initializing the model, you should jit your init function. Without jit, init runs a full forward pass eagerly, which consumes a lot of memory. Under jit, only the parameter initialization is kept: the rest of the forward pass is removed by XLA's dead code elimination, because its results are not needed for the returned parameters. So you should try this:

import jax
import jax.numpy as jnp

autoencoder = AE()  # AE is the Flax module from the question
sample_batch = jnp.ones((1, 224, 224))
params = jax.jit(autoencoder.init)(jax.random.PRNGKey(0), sample_batch)
