Recommended Approach for Gradient Accumulation #2030
Unanswered
sanchit-gandhi asked this question in Q&A
Replies: 1 comment 1 reply
Hi @sanchit-gandhi, right now we don't have any simple examples of gradient accumulation. @melissatan is planning to look into creating a HOWTO explaining how to do this. In the meantime, have you seen Discussion #2008? Someone seems to be running into the same problem as you when using Optax MultiSteps, and they documented their solution in google-deepmind/optax#320. Please let us know if this helps resolve the issue!
I'm working on a training script for a speech model in Flax, and was wondering if I could get the community's opinion on the best way to implement gradient accumulation in JAX/Flax. To my understanding, there are two viable options: Optax MultiSteps, or accumulating gradients by hand.
I initially trialled the simpler of the two, Optax MultiSteps. Keeping the per-device batch size fixed, I was not able to increase the number of gradient accumulation steps beyond 1 (equivalent to no gradient accumulation!). I then implemented a version of gradient accumulation by hand (see here). Once again, I was not able to increase the number of accumulation steps beyond 1 with the per-device batch size fixed. Are there any memory-related caveats with Optax MultiSteps that people have experienced before? And for the custom approach, is there a 'standard' way of going about it? Many thanks for all your help!