How to effectively do gradient accumulation with flax.training.train_state?
#1989 · Unanswered
reshinthadithyan asked this question in Q&A
Replies: 0 comments
I have a `train_step` that looks roughly like this. What is the most effective way to add gradient accumulation to it? I implemented a rough version with `jax.lax.cond` [based on this], writing an additional accumulation attribute to `flax.training.train_state`. Is this the right way to do it? Thanks much; an example with gradient accumulation would help.
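Since no replies were posted, here is one common pattern, as a hedged sketch rather than the official flax recipe: split each batch into microbatches, accumulate per-microbatch gradients with `jax.lax.scan`, and average before applying the update (in flax you would then pass the averaged gradients to `TrainState.apply_gradients`). The loss, shapes, and names below are illustrative assumptions, not from the original question.

```python
# Sketch of gradient accumulation in plain JAX (flax-independent core idea).
# Assumption: a simple linear model with mean-squared-error loss; with
# equal-sized microbatches, averaged microbatch gradients equal the
# full-batch gradient exactly.
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    pred = x @ params
    return jnp.mean((pred - y) ** 2)

def accumulated_grads(params, xs, ys, num_micro):
    # Reshape the batch into `num_micro` equal microbatches.
    xs = xs.reshape(num_micro, -1, xs.shape[-1])
    ys = ys.reshape(num_micro, -1)
    grad_fn = jax.grad(loss_fn)

    def body(acc, batch):
        x, y = batch
        g = grad_fn(params, x, y)
        # Running sum of gradients across microbatches.
        return jax.tree_util.tree_map(jnp.add, acc, g), None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    total, _ = jax.lax.scan(body, zeros, (xs, ys))
    # Average so the result matches a single full-batch gradient.
    return jax.tree_util.tree_map(lambda g: g / num_micro, total)

params = jax.random.normal(jax.random.PRNGKey(0), (4,))
xs = jax.random.normal(jax.random.PRNGKey(1), (8, 4))
ys = jax.random.normal(jax.random.PRNGKey(2), (8,))

g_acc = accumulated_grads(params, xs, ys, num_micro=4)
g_full = jax.grad(loss_fn)(params, xs, ys)
```

With this structure, a flax version only needs the averaged gradients fed to `state.apply_gradients(grads=...)` every `num_micro` steps; storing the running sum on a custom `TrainState` attribute, as the question describes, is a reasonable way to carry it between `train_step` calls.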