
Accumulation of Monte-Carlo gradients within a flax module to avoid OOM error #2301


I had mistakenly posted this question in google/jax's Q&A originally, because I wasn't aware of the "transfer discussion" functionality. I therefore copied the post here and deleted the content of the original post.

In the meantime, @YouJiacheng had posted an answer in the original thread; see jax-ml/jax#11528.
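For readers landing here without the original thread, below is a minimal sketch of one common way to accumulate Monte-Carlo gradients in chunks with `jax.lax.scan`, so that only one chunk of samples (and its activations) is alive during each backward pass. It is not the answer from jax-ml/jax#11528. The names `mc_loss` and `accumulated_grad` and the toy objective are hypothetical; with a Flax module you would assume the body of `mc_loss` calls `model.apply(variables, ...)` instead. Because accumulation uses `jax.tree_util.tree_map`, `params` may be a pytree such as a Flax variable dict.

```python
import jax
import jax.numpy as jnp


def mc_loss(params, key, num_samples):
    # Hypothetical Monte-Carlo objective: mean over samples of ||params * eps||^2,
    # with eps ~ N(0, I). Stand-in for a Flax `model.apply(...)` call.
    eps = jax.random.normal(key, (num_samples,) + params.shape)
    return jnp.mean(jnp.sum((params * eps) ** 2, axis=-1))


def accumulated_grad(params, key, num_chunks, samples_per_chunk):
    """Average the gradient over `num_chunks` independent Monte-Carlo chunks.

    Each scan step draws `samples_per_chunk` samples and runs a full
    forward/backward pass, so peak memory scales with one chunk rather than
    with num_chunks * samples_per_chunk samples.
    """
    keys = jax.random.split(key, num_chunks)  # one PRNG key per chunk
    grad_fn = jax.grad(mc_loss)  # differentiates w.r.t. params only

    def step(acc, chunk_key):
        g = grad_fn(params, chunk_key, samples_per_chunk)
        # Add this chunk's gradient to the running sum, leaf by leaf.
        return jax.tree_util.tree_map(jnp.add, acc, g), None

    init = jax.tree_util.tree_map(jnp.zeros_like, params)
    total, _ = jax.lax.scan(step, init, keys)
    # Divide the sum by the number of chunks to get the mean gradient.
    return jax.tree_util.tree_map(lambda t: t / num_chunks, total)
```

With equal-sized chunks, the mean of the per-chunk gradients equals the gradient of the mean loss over all samples, so the result matches a single large batch drawn with the same keys, at the cost of running the chunks sequentially.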

Replies: 1 comment

Answer selected by marcvanzee
Category: Q&A