Sharding PRNG Keys Across Devices #2021
-
flax/flax/training/common_utils.py, Lines 29 to 34 in cd5c4d7

I have a somewhat conceptual question concerning this use of different PRNG keys across local devices. I'll rely on several simple examples throughout this discussion to help convey my points.

Single Device Training
First, let's assume that we are training a stochastic model on a single device. Given a batch of inputs …

Multiple Device Training
Single PRNG Key …
Sharded PRNG Keys …

This leads me on to my overall question: should we not set the same PRNG key for the model over all devices, to maintain our overall batch size and equivalence to single-device training? For reference, the Hugging Face training script …
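For context, the referenced lines in common_utils.py define shard_prng_key. A minimal sketch of what that helper does (the exact wording at the pinned revision may differ slightly):

```python
import jax

# Roughly what flax.training.common_utils.shard_prng_key does:
# split one PRNG key into one key per local device, so that each
# device draws different random numbers (e.g. for dropout) when the
# keys are sharded alongside the batch under pmap.
def shard_prng_key(prng_key):
    return jax.random.split(prng_key, num=jax.local_device_count())
```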
Replies: 1 comment 3 replies
-
The missing detail here is that even with the same PRNGKey, the randomness is not the same for different items in the batch: if you do dropout(rng, batch), it will drop out different features for each item in the batch. By extension, if you do pmap(dropout(rng, batch)), the mini-batches on each device should also get different dropout masks. If you want to make sure that the random noise is generated in the same way irrespective of the global and local batch size, you should use vmap. For a single input you can then calculate the dropout_rng as …
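The end of the reply is cut off, so the exact per-input formula is lost. As one possible illustration of the vmap approach it describes, the sketch below derives a per-example key from a base key and a hypothetical global example index via jax.random.fold_in, so each example's dropout mask depends only on (base key, index) and not on how examples are grouped into local or global batches; the hand-rolled dropout_single helper is also just for illustration.

```python
import jax
import jax.numpy as jnp

def dropout_single(rng, x, rate=0.1):
    # Dropout for a single (unbatched) input, driven by its own key.
    keep = jax.random.bernoulli(rng, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)

def dropout_batch(base_rng, xs, example_ids, rate=0.1):
    # One key per example, independent of batch layout: fold the
    # global example index into the base key, then vmap the
    # single-input dropout over the batch.
    per_example_rngs = jax.vmap(lambda i: jax.random.fold_in(base_rng, i))(example_ids)
    return jax.vmap(dropout_single, in_axes=(0, 0, None))(per_example_rngs, xs, rate)

# Usage: an example with the same index gets the same mask regardless
# of which device or mini-batch it ends up in.
base_rng = jax.random.PRNGKey(0)
xs = jnp.ones((8, 16))
ids = jnp.arange(8)
out = dropout_batch(base_rng, xs, ids)
```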