Gradient Accumulation & Optax MultiSteps #2008
-
Hello, I am testing `optax.MultiSteps` for gradient accumulation on a Colab TPU v2, but every time I set the number of accumulation steps to anything above 1, I get an OOM. It seems to increase the memory requirements as much as increasing the batch size would, which should not be the case. My understanding is that all I need to do is wrap the optimizer with it and increase the training batch size by multiplying it by the number of gradient-accumulation steps.
The rest of the code should stay the same. Is my understanding correct, or is there something else we should take care of when using the MultiSteps function? I have posted this question on the Optax repo, but it would be great to hear your feedback.
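For reference, here is a minimal sketch of how I understand the intended usage (the model, shapes, learning rate, and `k` below are made-up placeholders, not my actual training code):

```python
import jax
import jax.numpy as jnp
import optax

# Hypothetical toy setup: a single weight matrix and a dummy quadratic loss.
params = {"w": jnp.zeros((128, 128))}

def loss_fn(p, batch):
    return jnp.mean((batch @ p["w"]) ** 2)

# Wrap the base optimizer with MultiSteps: gradients are accumulated across
# k calls to update(), and the inner optimizer only applies an update on
# every k-th call (gradients are averaged by default).
k = 4  # assumed number of gradient-accumulation steps
base_opt = optax.adam(1e-3)
opt = optax.MultiSteps(base_opt, every_k_schedule=k)
opt_state = opt.init(params)

@jax.jit
def train_step(params, opt_state, batch):
    grads = jax.grad(loss_fn)(params, batch)
    updates, opt_state = opt.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state

# Each call uses a small micro-batch; after k calls the parameters have been
# updated once with the averaged gradient (effective batch = k * micro-batch).
for _ in range(k):
    batch = jnp.ones((8, 128))
    params, opt_state = train_step(params, opt_state, batch)
```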
-
Thanks. The optax team found the problem.