Accumulation of Monte-Carlo gradients within a flax module to avoid OOM error #2301
-
I am currently training a model with an encoder-decoder architecture whose task is to reconstruct a corrupted (e.g. cropped, masked) version of the input. The model is stochastic, and the likelihood is estimated with Monte-Carlo sampling of the latent representation distribution. A much-simplified, lightweight version of the model and train script is provided below to make things concrete.

The problem I face is that, in reality and unlike the given example, the inputs and activations of my model are of very high dimensionality. Running a training step while drawing a single Monte-Carlo sample (n_samples=1) barely fits on an A100 card with 80 GB of memory. I have investigated how to bring this memory consumption down, and will be wrapping the encoder with jax's gradient checkpointing (rematerialization) decorator to somewhat mitigate memory consumption under reverse-mode automatic differentiation. More critically, however, I need to draw a significant number of samples from the latent representation, say 16, to accurately estimate the likelihood objective. This implies roughly a 16x memory increase compared to the single-sample case, and will inevitably lead to an OOM error.

I would like to circumvent this issue by carrying out the (16) forward calls to the decoder in sequence (as opposed to in parallel) and similarly accumulating the associated gradients during the backward pass. In practice, how can I replace the vmap-based line below from the class Model to carry this out? It seems this could be done with jax's custom_vjp / custom_jvp functionality, e.g. jax-ml/jax#10131. However, I am not sure how to do this within the flax framework, in particular inside a flax module, and how it ties into model parameter management: I would need to accumulate both the gradients of the decoder's parameters and those of the decoder's inputs (i.e. the samples z), and complete the back-propagation towards the encoder...
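One direction I have considered is to move the Monte-Carlo loop into the loss function and combine `jax.lax.scan` with `jax.checkpoint`, so that only one sample's decoder activations are live at a time while gradients still flow through z back to the encoder. Below is a rough sketch of that idea; the module, dimensions, and loss are placeholders, not my actual model, and I am not sure this is the right way to tie it to flax parameter management.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class Model(nn.Module):
    """Toy stand-in for the real encoder-decoder (placeholder dimensions)."""
    latent_dim: int = 8
    out_dim: int = 32

    def setup(self):
        self.encoder = nn.Dense(2 * self.latent_dim)  # stand-in encoder
        self.decoder = nn.Dense(self.out_dim)         # stand-in decoder

    def encode(self, x):
        mean, log_std = jnp.split(self.encoder(x), 2, axis=-1)
        return mean, log_std

    def decode(self, z):
        return self.decoder(z)

    def __call__(self, x):
        # Only used by init() so that both encoder and decoder parameters exist.
        mean, _ = self.encode(x)
        return self.decode(mean)


def loss_fn(params, model, x, rng, n_samples=16):
    mean, log_std = model.apply({'params': params}, x, method=Model.encode)

    # Rematerialized per-sample reconstruction term: the decoder's activations
    # are recomputed during the backward pass instead of being stored for all
    # n_samples at once.
    @jax.checkpoint
    def decode_and_score(z):
        recon = model.apply({'params': params}, z, method=Model.decode)
        return jnp.mean((recon - x) ** 2)  # placeholder likelihood term

    def body(acc, key):
        eps = jax.random.normal(key, mean.shape)
        z = mean + jnp.exp(log_std) * eps  # reparameterized latent sample
        return acc + decode_and_score(z), None

    total, _ = jax.lax.scan(body, jnp.zeros(()), jax.random.split(rng, n_samples))
    return total / n_samples


# A single grad call then accumulates the decoder gradients across the scanned
# samples and back-propagates through z into the encoder parameters.
model = Model()
x = jnp.ones((4, 32))
params = model.init(jax.random.PRNGKey(0), x)['params']
loss, grads = jax.value_and_grad(loss_fn)(params, model, x, jax.random.PRNGKey(1))
```

If this is roughly the right direction, I would still like to understand how it interacts with flax's lifted transforms (nn.scan / nn.remat) when the loop lives inside a module rather than in the loss function.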
Replies: 1 comment
-
I had mistakenly posted this question on google/jax's Q&A originally. I wasn't aware of the "transfer discussion" functionality, hence I copied the post here and deleted the content of the initial post.
In the meantime, @YouJiacheng had posted an answer in the original post, see jax-ml/jax#11528