Is there a way to sum the parameters of flax.nn.Module and create a new module? #392
BoyuanJackChen asked this question in General. Answered by jheek.
I am training a model with a Flax Adam optimizer. The data is divided into batches, and I want to update the model with the gradient averaged over all batches. I checked that the type of grad is flax.nn.Module, and I am wondering how to sum these with JAX.
The model is a simple one with 8 layers:
The update line looks like this:
The model is trained on image data.
Answer from jheek (Aug 14, 2020):
Hi,
if I understand correctly, you have a loop over multiple batches and you want to average the gradients across the batches.
Here's a code snippet that should give the basic idea of how to do this:
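A minimal sketch of that idea, assuming the gradients are JAX pytrees: `jax.grad` returns gradients with the same tree structure as the argument it differentiates, and `jax.tree_util.tree_map` builds a new tree of that same structure, which is what lets you "sum two modules" leaf by leaf. The linear model, `loss_fn`, `averaged_grad`, and `averaged_grad_vmap` are hypothetical stand-ins for illustration; with legacy Flax the `params` dict would be replaced by the model object the gradient was taken with respect to.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, batch):
    # Hypothetical model for illustration: a linear layer with squared-error loss.
    x, y = batch
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


grad_fn = jax.grad(loss_fn)  # gradients share the tree structure of params


def averaged_grad(params, batches):
    """Loop over batches, summing gradients leaf by leaf, then divide by the count."""
    total = jax.tree_util.tree_map(jnp.zeros_like, params)
    for batch in batches:
        g = grad_fn(params, batch)
        # tree_map returns a new tree of the same structure: the leaf-wise sum.
        total = jax.tree_util.tree_map(lambda a, b: a + b, total, g)
    return jax.tree_util.tree_map(lambda t: t / len(batches), total)


def averaged_grad_vmap(params, xs, ys):
    """Same result without a Python loop; xs/ys are stacked as [num_batches, ...]."""
    per_batch = jax.vmap(grad_fn, in_axes=(None, 0))(params, (xs, ys))
    return jax.tree_util.tree_map(lambda g: jnp.mean(g, axis=0), per_batch)


# Usage with toy data: three batches of four examples each.
params = {"w": jnp.ones((3, 1)), "b": jnp.zeros((1,))}
batches = [(i * jnp.ones((4, 3)), jnp.zeros((4, 1))) for i in range(1, 4)]

avg_loop = averaged_grad(params, batches)
xs = jnp.stack([x for x, _ in batches])
ys = jnp.stack([y for _, y in batches])
avg_vmap = averaged_grad_vmap(params, xs, ys)
```

Because differentiation is linear, averaging the per-batch gradients gives the gradient of the mean of the per-batch losses, so the averaged tree can be handed to the optimizer's update step in place of a single batch's gradient (assuming the legacy `flax.optim` API, that step was `optimizer.apply_gradient(grad)`). The `vmap` variant avoids the Python loop but requires every batch to have the same shape so the batches can be stacked.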