Replies: 1 comment 1 reply
-
Hey @hr0nix, keep an eye on jax-ml/jax#15783. One idea would be to split + partition the …
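(The suggestion above is cut off. One plausible reading of "split + partition" is to pre-split the PRNG key into one key per example and shard the resulting key array along the batch axis, so each shard can draw its dropout masks from local keys. A minimal sketch of that idea, with illustrative names and sizes that are not from the thread:)

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical sketch: pre-split the PRNG key so each example gets its own
# key, then shard the key array along the batch axis just like the data.
mesh = Mesh(np.array(jax.devices()), ("batch",))
batch_size = 128  # illustrative; matches the shapes mentioned later in the thread

key = jax.random.PRNGKey(0)
per_example_keys = jax.random.split(key, batch_size)   # shape (batch, 2), uint32
per_example_keys = jax.device_put(per_example_keys, NamedSharding(mesh, P("batch")))

# Inside the model, each example's dropout mask is then drawn from its own
# key (e.g. via vmap over the leading axis), so generating random bits needs
# no cross-device communication.
```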
-
Hello!
I'm trying to figure out the best practices for training a model with a stochastic forward pass (say, a transformer decoder with dropout) using pjit.
My train step that I compile using pjit looks something like this:
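(A minimal sketch for illustration; the toy model, plain SGD update, and names like `loss_fn`/`train_step` are assumptions rather than the exact code from the post.)

```python
import jax
import jax.numpy as jnp

def dropout(x, rate, key):
    # Stochastic forward pass: a fresh mask is sampled from `key` at every step.
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)

def loss_fn(params, batch, dropout_key):
    # Toy stand-in for a transformer decoder: one dense layer plus dropout.
    hidden = batch["inputs"] @ params["w_in"]
    hidden = dropout(hidden, rate=0.1, key=dropout_key)
    preds = hidden @ params["w_out"]
    return jnp.mean((preds - batch["targets"]) ** 2)

@jax.jit  # jit and pjit are unified in recent JAX; shardings are taken from the inputs
def train_step(params, batch, rng):
    rng, dropout_key = jax.random.split(rng)
    loss, grads = jax.value_and_grad(loss_fn)(params, batch, dropout_key)
    # Plain SGD update, just to keep the sketch self-contained.
    params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
    return params, rng, loss
```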
For now I shard everything along the batch dimension only, to implement plain data parallelism.
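(Again an illustrative sketch, assuming a 1-D device mesh and the array shapes mentioned later in the post; `NamedSharding` and `PartitionSpec` are the standard JAX sharding APIs.)

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over all devices; only the batch dimension is partitioned.
mesh = Mesh(np.array(jax.devices()), ("batch",))
batch_sharding = NamedSharding(mesh, P("batch"))   # split dim 0, replicate the rest
replicated = NamedSharding(mesh, P())              # parameters live on every device

batch = {
    "inputs": jax.device_put(jnp.zeros((128, 1024, 768)), batch_sharding),
    "targets": jax.device_put(jnp.zeros((128, 1024, 768)), batch_sharding),
}
params = jax.device_put(
    {"w_in": jnp.zeros((768, 768)), "w_out": jnp.zeros((768, 768))}, replicated
)
```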
If I just compile the code as is, everything works but is very slow, because the XLA compiler inserts collective primitives for every dropout, as mentioned at https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
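(One way to check this, as a hedged aside: dump the compiled HLO and count collective ops. The lowering APIs used below exist in JAX; `train_step`, `params`, `batch`, and `rng` refer to the sketches above.)

```python
rng = jax.random.PRNGKey(0)

# Dump the compiled HLO and count collective ops such as all-gather.
hlo_text = train_step.lower(params, batch, rng).compile().as_text()
for op in ("all-gather", "all-reduce", "all-to-all", "collective-permute"):
    print(op, hlo_text.count(op))
```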
If I set the flag described in that documentation section (see the sketch below) before compiling the code, the collective primitives are no longer present. However, for some reason XLA then decides to allocate large i32 arrays, which I guess store RNG keys for every dropped-out activation. This increases memory consumption significantly, and I can no longer fit a batch of the same size in the memory of a single GPU. This is especially surprising because the transformer blocks are rematerialized, yet as far as I understand XLA intends to preserve these buffers anyway.
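(The exact line is not shown above; the setting that the linked documentation section describes for avoiding these collectives is the partitionable threefry PRNG, so presumably it was something along these lines:)

```python
import jax

# Generate random bits shard-locally instead of materializing the full
# random array in one layout and resharding it with collectives.
jax.config.update("jax_threefry_partitionable", True)
```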
The OOM message that led me to this conclusion (128 x 1024 x 768 is batch x seq_len x embedding_dim):

I have several questions: