
How to extract intermediate values from "underneath" vmap? #1934

Answered by marcvanzee
jatentaki asked this question in Q&A

You should also include the intermediates variable collection in the variable_axes of the lifted vmap call. That is, replace variable_axes={'params': None} with variable_axes={'params': None, 'intermediates': 0}.

We should probably document this in our HOWTO.

@jheek

Answer selected by jatentaki

This discussion was converted from issue #1930 on February 23, 2022 09:49.