
Is there a way to sum the parameters of flax.nn.Module and create a new module? #392

Answered by jheek
BoyuanJackChen asked this question in General

Hi,
If I understand correctly, you have a loop over multiple batches and want to average the gradients across them.
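Written out, the target is the mean of the per-batch gradients. The symbols here are illustrative and not from the original: $\theta$ for the model parameters, $\mathcal{L}$ for the loss, and $B_1, \dots, B_N$ for the batches, with $N$ playing the role of `num_batches` in the snippet:

$$
\bar{g} \;=\; \frac{1}{N} \sum_{i=1}^{N} \nabla_\theta \mathcal{L}(\theta; B_i) \;=\; \nabla_\theta \left( \frac{1}{N} \sum_{i=1}^{N} \mathcal{L}(\theta; B_i) \right)
$$

The second equality follows from the linearity of the gradient. When all batches have the same size, $\bar{g}$ is also the gradient of the mean loss over every example.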

Here's a code snippet that should give the basic idea of how to do this:

```python
import jax
import jax.numpy as jnp

def loss_fn(model, batch):  # Will be used repeatedly in the loop
    rays_o, rays_d, target = batch  # or whatever is in the batch
    rgb, depth, acc = render_rays(model, rays_o, rays_d, near=near, far=far,
                                  batchify_size=batchify_size, N_samples=N_samples, rand=True)
    loss = jnp.mean(jnp.square(rgb - target))
    return loss, rgb

grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

avg_grad = None
num_batches = len(batches)
fo…
```
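The snippet above is cut off at the loop, so here is a minimal, self-contained sketch of the accumulation it sets up. It rests on a few assumptions: `render_rays` is replaced by a hypothetical linear model, and the parameters are a plain dict pytree rather than a `flax.nn` model. The helper names `average_grads` and `sgd_step` are also illustrative. Each batch's gradient is added leaf by leaf with `jax.tree_util.tree_map`, and the total is then divided by the number of batches.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the real loss: a linear model on a dict pytree
# (the original differentiates through render_rays instead).
def loss_fn(params, batch):
    x, target = batch
    pred = x @ params["w"] + params["b"]
    loss = jnp.mean(jnp.square(pred - target))
    return loss, pred

grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

def average_grads(params, batches):
    """Average the gradients of loss_fn over a list of batches."""
    avg_grad = None
    for batch in batches:
        (loss, _), grad = grad_fn(params, batch)
        if avg_grad is None:
            avg_grad = grad
        else:
            # Leaf-wise sum of two gradient pytrees with the same structure.
            avg_grad = jax.tree_util.tree_map(lambda a, g: a + g, avg_grad, grad)
    # Divide every leaf by the number of batches.
    return jax.tree_util.tree_map(lambda g: g / len(batches), avg_grad)

def sgd_step(params, grads, lr=0.1):
    """Plain SGD update applied leaf by leaf (hypothetical helper)."""
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# Usage: three small batches whose targets are the row sums of x.
params = {"w": jnp.zeros((3,)), "b": jnp.zeros(())}
xs = [jnp.arange(6.0).reshape(2, 3) + i for i in range(3)]
batches = [(x, x.sum(axis=1)) for x in xs]
grads = average_grads(params, batches)
new_params = sgd_step(params, grads)
```

The same `tree_map(lambda a, b: a + b, p1, p2)` pattern also sums the parameters of two models into a new parameter tree, which is what the title asks about. This assumes both models' parameters are pytrees with identical structure.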

Answer selected by jheek

This discussion was converted from issue #392 on August 14, 2020 09:29.