Hi Flax community,

I'd like to use `nn.remat` to save memory when using an LSTM to process very long sequences. However, the existing tutorial for `nn.remat` covers deep feedforward networks, not recurrent ones. I tried `nn.remat` on my LSTM, but I saw no change in memory usage, only an increase in runtime. Please see the demo code below. I tried both `seq2seq = True` and `seq2seq = False`, but neither of them works. Could you help me with that? Or should I use `nn.remat_scan` instead?
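(The demo code referenced above did not survive in this copy. The sketch below is a hypothetical reconstruction of the kind of setup being described, assuming a recent `flax.linen` API where `nn.OptimizedLSTMCell` takes `features` in its constructor; the module name `LSTM` is illustrative, not the original code.)

```python
import jax
import flax.linen as nn

class LSTM(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):  # x: [batch, time, in_features]
        # Scan an LSTM cell over the time axis.
        ScanLSTM = nn.scan(
            nn.OptimizedLSTMCell,
            variable_broadcast="params",
            split_rngs={"params": False},
            in_axes=1, out_axes=1,
        )
        lstm = ScanLSTM(self.features)
        carry = lstm.initialize_carry(jax.random.PRNGKey(0), x[:, 0].shape)
        carry, ys = lstm(carry, x)
        return ys

# remat wrapped around the whole module, i.e. around the entire scan loop:
RematLSTM = nn.remat(LSTM)
```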
Replies: 1 comment 1 reply

-

You are now rematerializing the entire LSTM loop, so the backward pass recomputes the whole loop while its peak memory stays roughly the same as the plain forward pass. For remat to reduce memory, you should rematerialize smaller chunks, for example each cell execution, e.g.:
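(The snippet that originally followed did not survive in this copy. Below is a minimal sketch of the suggested pattern, again assuming the recent `flax.linen` API; wrapping the cell in `nn.remat` before scanning is the part that matters, the rest mirrors the questioner's setup.)

```python
import jax
import flax.linen as nn

class RematCellLSTM(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):  # x: [batch, time, in_features]
        # remat each cell execution: the checkpoint sits *inside* the scan,
        # so per-step gate activations are recomputed in the backward pass
        # instead of being stored for every step of the sequence.
        ScanLSTM = nn.scan(
            nn.remat(nn.OptimizedLSTMCell),
            variable_broadcast="params",
            split_rngs={"params": False},
            in_axes=1, out_axes=1,
        )
        lstm = ScanLSTM(self.features)
        carry = lstm.initialize_carry(jax.random.PRNGKey(0), x[:, 0].shape)
        carry, ys = lstm(carry, x)
        return ys
```

With this arrangement `nn.scan` still keeps the per-step carries, but the intermediates inside each cell are recomputed rather than stored, which is where the memory savings on long sequences come from (at the cost of some extra compute in the backward pass).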