How to get per-example gradients in Flax? #2084
Asked by marcvanzee in Q&A
How to get per-example gradients in Flax?
Answered by marcvanzee on Apr 27, 2022
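All Flax layers accept inputs without a batch dimension. So you can write the loss for a single example, take its gradient with jax.grad, and vectorize that gradient over the batch with jax.vmap. Example:

import jax
import jax.numpy as jnp

def loss_fn(param, example):
    # Loss for one example (no batch dimension)
    return jnp.square(param * example)

param = 3.
# in_axes=(None, 0): share param across examples, map over axis 0 of the batch
per_example_grad_fn = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0))
per_example_grad_fn(param, jnp.ones((3,)))  # one gradient per example: [6., 6., 6.]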