
Can we use flax.linen.vjp to calculate the HVP of a loss wrt a network's weights? #1748

Answered by jheek
newalexander asked this question in Q&A

The linen versions of JAX transforms are only required when you want to use them inside a linen Module. In this case, you should simply use jax.jvp and jax.grad.

jheek (Maintainer), Jan 4, 2022


Answer selected by newalexander