Efficiently backpropagate subset of minibatch #2262
-
Hi, I have a batch size of 2048, and based on the results of the forward pass I only want to backpropagate half of the minibatch (I have 1024 indices). However, if I just return the mean of the losses at those indices and let JAX compute the gradient, it's just as slow as backpropagating the whole minibatch. It's actually faster to do an additional forward pass with only the selected images and then take the gradient. Do you have any ideas how to make this more efficient, or is this just an inherent limitation of the JAX compiler?
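For reference, the two approaches described in the question can be sketched roughly as follows. The model, the shapes, and the name `per_example_loss` are hypothetical stand-ins for the real network. Both functions produce the same gradient, but the masked version differentiates through the full-batch forward pass, so its backward pass still operates on 2048-row tensors (the unselected rows just receive zero cotangents). That is consistent with it being as slow as backpropagating the whole minibatch.

```python
import jax
import jax.numpy as jnp

# Hypothetical two-layer model standing in for the real network.
def per_example_loss(params, x, y):
    hidden = jnp.tanh(x @ params["w1"])
    pred = hidden @ params["w2"]
    return jnp.mean((pred - y) ** 2, axis=-1)  # shape: (batch,)

# Approach 1: full-batch forward pass, then average only the selected losses.
# The backward pass still works on full 2048-row tensors; the unselected rows
# simply receive zero cotangents.
def masked_loss(params, x, y, idx):
    losses = per_example_loss(params, x, y)
    return jnp.mean(losses[idx])

# Approach 2: gather the selected rows and run a second forward pass on them,
# so the differentiated computation only sees 1024 rows.
def subset_loss(params, x, y, idx):
    return jnp.mean(per_example_loss(params, x[idx], y[idx]))

masked_grad = jax.jit(jax.grad(masked_loss))
subset_grad = jax.jit(jax.grad(subset_loss))

# Toy data: a batch of 2048, keeping 1024 randomly chosen indices.
kx, ky, k1, k2, kp = jax.random.split(jax.random.PRNGKey(0), 5)
x = jax.random.normal(kx, (2048, 32))
y = jax.random.normal(ky, (2048, 8))
params = {"w1": 0.1 * jax.random.normal(k1, (32, 64)),
          "w2": 0.1 * jax.random.normal(k2, (64, 8))}
idx = jax.random.permutation(kp, 2048)[:1024]

g_masked = masked_grad(params, x, y, idx)
g_subset = subset_grad(params, x, y, idx)
```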
Replies: 1 comment
-
I think that you can split your batch and use `jax.lax.stop_gradient` on the half of the batch you want to exclude.
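A minimal sketch of this suggestion, assuming the kept and excluded indices are already known (e.g. from the selection forward pass) and using a hypothetical `per_example_loss` in place of the real model. Each half gets its own forward pass, and `jax.lax.stop_gradient` marks the excluded half's losses as constant, so `jax.grad` does not backpropagate through that half even though its loss value still appears in the objective.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-example loss standing in for the real model.
def per_example_loss(params, x, y):
    hidden = jnp.tanh(x @ params["w1"])
    pred = hidden @ params["w2"]
    return jnp.mean((pred - y) ** 2, axis=-1)

def split_loss(params, x, y, keep_idx, drop_idx):
    # Each half of the batch gets its own forward pass.
    keep_losses = per_example_loss(params, x[keep_idx], y[keep_idx])
    # stop_gradient treats the excluded half as a constant: its forward pass
    # still runs (its value is part of the loss), but no gradient flows back.
    drop_losses = jax.lax.stop_gradient(
        per_example_loss(params, x[drop_idx], y[drop_idx]))
    # The value is the full-batch mean; only the kept half drives the gradient.
    return jnp.mean(jnp.concatenate([keep_losses, drop_losses]))

loss_and_grad = jax.jit(jax.value_and_grad(split_loss))

# Toy data: a batch of 2048 split into two halves of 1024.
kx, ky, k1, k2, kp = jax.random.split(jax.random.PRNGKey(0), 5)
x = jax.random.normal(kx, (2048, 32))
y = jax.random.normal(ky, (2048, 8))
params = {"w1": 0.1 * jax.random.normal(k1, (32, 64)),
          "w2": 0.1 * jax.random.normal(k2, (64, 8))}
perm = jax.random.permutation(kp, 2048)
keep_idx, drop_idx = perm[:1024], perm[1024:]

loss, grads = loss_and_grad(params, x, y, keep_idx, drop_idx)
```

Because the concatenated mean divides by the full batch size (2048), the resulting gradient is half the gradient of the mean loss over the 1024 kept examples; rescale it if you want the per-kept-example mean. If you don't need the excluded half's loss value at all, you can skip its forward pass entirely. Both index arrays have a fixed length, so the jitted function sees the same shapes every step and should not need to recompile.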