Flax model too big... OOM memory allocation error #2283
Hello, I am trying to train a simple encoder-decoder model (input: …). Here is my model:

```python
latent_dims = 20

class AE(nn.Module):
    @nn.compact
    def __call__(self, x):
        # encoder
        x = jnp.reshape(x, -1)
        x = nn.Dense(features=50176)(x)
        x = nn.relu(x)
        x = nn.Dense(features=512)(x)
        x = nn.relu(x)
        x = nn.Dense(features=latent_dims)(x)
        # decoder
        x = nn.Dense(features=512)(x)
        x = nn.relu(x)
        x = nn.Dense(features=50176 * 2)(x)
        return x.reshape(2, 224, 224)
```

When I run this cell:

```python
autoencoder = AE()
sample_batch = jnp.ones((1, 224, 224))
params = autoencoder.init(jax.random.PRNGKey(0), sample_batch)
```

the following error appears:
So it seems that the model is simply too big (9 GB) to fit in main memory, and I am not sure how to resolve this given my compute resources. Is it possible to reduce the model size in Flax? What about memory-mapping the model between disk and main memory if it is too large to fit? I just don't know how to progress from this point. The model is relatively simple, and many architectures take in images as large as 224 x 224, so I find it surprising that the model takes up this much space. Any recommendations? Thanks!
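For reference, the size is dominated by the first `Dense` layer: flattening a 224×224 input gives 50176 features, so that layer's kernel alone is 50176×50176 floats. A back-of-the-envelope check (my own arithmetic, assuming float32 at 4 bytes per parameter and `latent_dims = 20`):

```python
# Rough parameter count for the model above: each Dense layer has
# in*out weights plus `out` biases.
layers = [
    (50176, 50176),  # first encoder Dense: 224*224 -> 50176
    (50176, 512),
    (512, 20),       # latent_dims = 20
    (20, 512),
    (512, 100352),   # decoder output: 50176 * 2
]
n_params = sum(i * o + o for i, o in layers)
print(f"{n_params:,} params, ~{n_params * 4 / 1e9:.1f} GB in float32")
# 2,594,873,364 params, ~10.4 GB in float32
```

So the ~9–10 GB is expected for this architecture; most CNNs handle 224×224 inputs cheaply because they convolve rather than connect every pixel to a 50176-wide dense layer.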
When initializing the model you should `jit` your init function. If you don't do this, you will run a full forward pass, which consumes a lot of memory. Jitting will only initialize the parameters; everything else will be optimized away by XLA's dead code elimination. So you should try this:

```python
autoencoder = AE()
sample_batch = jnp.ones((1, 224, 224))
params = jax.jit(autoencoder.init)(jax.random.PRNGKey(0), sample_batch)
```
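If you only want to inspect parameter shapes, `jax.eval_shape` goes one step further and traces the function abstractly, so nothing is allocated at all. A minimal sketch (`init_like` is a hypothetical stand-in for `autoencoder.init` so the snippet runs without Flax; the same call works on the real init function):

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for autoencoder.init: returns an array shaped
# like the first Dense layer's kernel in the model above.
def init_like(rng, x):
    flat = x.size  # 1 * 224 * 224 = 50176
    return jnp.zeros((flat, 50176))

# eval_shape traces abstractly: the ~10 GB kernel is never materialized.
shapes = jax.eval_shape(init_like, jax.random.PRNGKey(0), jnp.ones((1, 224, 224)))
print(shapes.shape)  # (50176, 50176)
```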