
Sharding PRNG Keys Across Devices #2021

Answered by jheek
sanchit-gandhi asked this question in Q&A

This leads me to my overall question: should we not set the same PRNG key for the model across all devices, to maintain our overall batch size and equivalence to single-device training?

The missing detail here is that even with the same PRNGKey, the randomness is not the same for different items in the batch: if you do dropout(rng, batch), it will drop out different features for each item in the batch. By extension, if you want pmap(dropout(rng, batch)) to behave like the single-device case, the mini-batches on each device should have different dropout masks.
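
A minimal sketch of this point (the shapes, the dropout helper, and the key-splitting pattern are illustrative assumptions, not code from this thread): with a single key, the Bernoulli mask already differs across batch items, but under pmap the same key on every device would reproduce the same per-shard mask, so the usual pattern is to split one key per device.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
batch = jnp.ones((8, 4))  # (batch, features); shapes are made up for the example

# One key, one call: the Bernoulli mask is drawn with the full batch shape,
# so every item in the batch already gets a different mask.
mask = jax.random.bernoulli(key, 0.5, batch.shape)

def dropout(rng, x, rate=0.5):
    # Plain dropout driven by an explicit key.
    keep = jax.random.bernoulli(rng, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)

# Under pmap, reusing the *same* key on every device would give every device
# the same per-shard mask. Splitting the key once per device gives each
# mini-batch its own dropout mask, as in single-device training.
n_dev = jax.local_device_count()
device_keys = jax.random.split(key, n_dev)   # one key per device
sharded = jnp.ones((n_dev, 2, 4))            # (devices, per-device batch, features)
out = jax.pmap(dropout)(device_keys, sharded)
```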

If you want to make sure that the random noise is generated in the same way irrespective of the global and local batch size, you should use vmap. For a single input you can then…
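
One possible shape of that vmap approach, sketched under assumptions (the dropout_one helper and deriving per-example keys from global indices are illustrative, not the original reply's code): give each example its own key and vmap the per-example dropout, so the noise an example sees does not depend on how the batch is sized or sharded.

```python
import jax
import jax.numpy as jnp

def dropout_one(rng, x, rate=0.5):
    # Dropout for a single example, driven by that example's own key.
    keep = jax.random.bernoulli(rng, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)

key = jax.random.PRNGKey(0)
batch = jnp.ones((8, 4))        # (batch, features)
example_ids = jnp.arange(8)     # e.g. global dataset indices of the examples

# One key per example, derived from its global index, then vmap the
# per-example dropout: the mask an example gets no longer depends on which
# other examples happen to share its (local or global) batch.
per_example_keys = jax.vmap(jax.random.fold_in, in_axes=(None, 0))(key, example_ids)
out = jax.vmap(dropout_one)(per_example_keys, batch)
```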

Answer selected by marcvanzee