Replies: 2 comments 2 replies
Hi Simon, are you wrapping your backward pass in …?
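(The suggestion above is truncated in the thread, so the following is an assumption, not the commenter's actual words.) If it refers to rematerialization via `jax.checkpoint` (also available as `jax.remat`), a minimal sketch with hypothetical names would look like this:

```python
import jax
import jax.numpy as jnp

# Hypothetical per-iteration update (not from the original post):
# one dense transformation of the node representations.
def step(params, nodes):
    return jax.nn.relu(nodes @ params["w"] + params["b"])

# jax.checkpoint (alias: jax.remat) recomputes the intermediates of
# `step` during the backward pass instead of storing them, trading
# extra compute for lower peak memory.
step = jax.checkpoint(step)

def model(params, nodes, num_iters=8):
    for _ in range(num_iters):
        nodes = step(params, nodes)
    return nodes

def loss(params, nodes):
    return jnp.sum(model(params, nodes) ** 2)

params = {"w": jnp.ones((16, 16)), "b": jnp.zeros((16,))}
nodes = jnp.ones((128, 16))
grads = jax.grad(loss)(params, nodes)
```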
Hey @SimonPre, consider using https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.remat.html over your …
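For reference, a minimal sketch of the suggested `flax.linen.remat` applied to an iterative Dense update; the module, shapes, and iteration count here are hypothetical, not taken from Simon's model:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class IterativeGNN(nn.Module):
    features: int = 16
    num_iters: int = 8

    @nn.compact
    def __call__(self, nodes):
        # nn.remat wraps the module so its intermediates are recomputed
        # during the backward pass rather than stored per iteration.
        dense = nn.remat(nn.Dense)(self.features)
        for _ in range(self.num_iters):
            nodes = jax.nn.relu(dense(nodes))
        return nodes

model = IterativeGNN()
nodes = jnp.ones((128, 16))
params = model.init(jax.random.PRNGKey(0), nodes)

def loss(params, nodes):
    return jnp.sum(model.apply(params, nodes) ** 2)

grads = jax.grad(loss)(params, nodes)
```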
Dear Community,

I am currently working on replicating a GNN model for knowledge graph completion using JAX and Flax. The model computes new node representations iteratively using linen.Dense layers. This works fine during the forward pass, but during the backward pass the memory consumption is substantial: it looks as if a copy of all the inputs to my dense layer is saved to memory for each iteration. Is there a way to reduce this memory consumption during the backward pass?

Best regards and thanks in advance,
Simon
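The code block from the original post did not survive. Purely to illustrate the pattern described (all names hypothetical): under reverse-mode autodiff, each application of the dense layer keeps its input alive as a residual, so peak memory grows linearly with the number of iterations unless rematerialization is used:

```python
import jax
import jax.numpy as jnp

def iterate(params, nodes, num_iters=8):
    for _ in range(num_iters):
        # Under jax.grad, the input to every matmul below is kept as a
        # residual for the backward pass, so memory scales with num_iters.
        nodes = jax.nn.relu(nodes @ params["w"] + params["b"])
    return nodes

params = {"w": jnp.ones((16, 16)), "b": jnp.zeros((16,))}
nodes = jnp.ones((128, 16))
grads = jax.grad(lambda p, n: jnp.sum(iterate(p, n)))(params, nodes)
```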