
SPMD training on multiple GPUs (How to) #2415

Answered by jheek
zjoukhadar asked this question in Q&A

We normally keep the replicated state in sync by averaging the gradients across devices before the optimizer applies them:

(loss, logits), grads = grad_fn(
    replicated_state.params, replicated_rng, imgs, labels
)
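# Average the per-device gradients over the "batch" device axis so
# every device applies the same update.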
grads = jax.lax.pmean(grads, "batch")
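
For context, here is a minimal end-to-end sketch of this pattern with jax.pmap. The toy model, optimizer, and batch shapes below are illustrative assumptions, not code from this thread, and the rng argument from the answer is omitted since this toy model has no dropout:

import jax
import jax.numpy as jnp
import optax
from flax import jax_utils
from flax import linen as nn
from flax.training import train_state

class MLP(nn.Module):  # hypothetical toy model
    @nn.compact
    def __call__(self, x):
        return nn.Dense(10)(x.reshape((x.shape[0], -1)))

def train_step(state, imgs, labels):
    def loss_fn(params):
        logits = state.apply_fn({"params": params}, imgs)
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits, labels
        ).mean()
        return loss, logits

    (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(
        state.params
    )
    # Average gradients (and the reported loss) across devices so every
    # replica applies the identical update and the states stay in sync.
    grads = jax.lax.pmean(grads, axis_name="batch")
    loss = jax.lax.pmean(loss, axis_name="batch")
    return state.apply_gradients(grads=grads), loss

p_train_step = jax.pmap(train_step, axis_name="batch")

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 28, 28, 1)))["params"]
state = train_state.TrainState.create(
    apply_fn=model.apply, params=params, tx=optax.sgd(0.1)
)

# Replicate the state across devices and shard the batch along a new
# leading device axis: (num_devices, per_device_batch, ...).
replicated_state = jax_utils.replicate(state)
n = jax.local_device_count()
imgs = jnp.ones((n, 8, 28, 28, 1))
labels = jnp.zeros((n, 8), dtype=jnp.int32)
replicated_state, loss = p_train_step(replicated_state, imgs, labels)

The pmean inside the pmapped step is what keeps the replicas identical: without it, each device would apply its own local gradients and the parameter copies would drift apart after the first update.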

Answer selected by zjoukhadar