Can we use flax.linen.vjp to calculate the HVP of a loss wrt a network's weights?
#1748
The JAX documentation has a nice example of calculating the Hessian-vector product of a function using `jax.jvp` and `jax.grad`. Is it possible to use `flax.linen.jvp` in the same way?

```python
import jax
import flax

def hvp(loss, primals, tangents):
    return flax.linen.jvp(jax.grad(loss), primals, tangents)[1]
```

where …

Thank you!
The Linen versions of JAX transforms are only needed when you want to use them inside a Linen `Module`. In this case you should simply use `jax.jvp` and `jax.grad`.