SPMD training on multiple GPUs (How to) #2415
-
Hi, I am trying to understand how to do SPMD (single program multiple data) training on multiple GPUs.

Since it is a single program, we replicate the state and RNG on all devices so that every device runs exactly the same program. We can then easily compute metrics across all devices using `gathered_loss = jax.lax.pmean(loss, axis_name="batch")`.

However, since every device processes a different batch, each device will end up with a different set of weights/params/state as training goes on. In the line below (from the code further down), `replicated_state` will be different on each device, because each device is processing a different batch of the data:

`replicated_state = replicated_state.apply_gradients(grads=grads)`

How do I aggregate these states from all devices to get final weights that are learned from all batches?

Here is pseudo-code of the whole process I described above:

```python
import functools

import jax
import jax.numpy as jnp
import optax
import tensorflow_datasets as tfds
from flax import jax_utils
from flax.training import train_state
from jax import random

# Prepare data/batching (db is assumed to be a tfds dataset builder).
train_dataset = db.as_dataset(split=tfds.split_for_jax_process("train"))
test_dataset = db.as_dataset(split=tfds.split_for_jax_process("test"))
per_device_batch_size = 512
per_host_batch_size = per_device_batch_size * jax.local_device_count()
train_dataset = train_dataset.batch(per_host_batch_size, drop_remainder=False)
test_dataset = test_dataset.batch(per_host_batch_size, drop_remainder=False)

# Create a train state (model, params, optimizer, seed assumed defined elsewhere).
state = train_state.TrainState.create(
    apply_fn=model.apply,  # for some model
    params=params,  # from some model
    tx=optimizer,  # optax chained optimizer
)

# Since we are doing SPMD, replicate the state, rng, and metrics.
rng_key = jax.random.PRNGKey(seed)
replicated_state = jax_utils.replicate(state)
replicated_metrics = jax_utils.replicate(
    dict(loss=jnp.array(0, jnp.float32), acc=jnp.array(0, jnp.float32))
)
replicated_rng = jax_utils.replicate(rng_key)

# Define our train_step and pmap it along the batch axis.
@functools.partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(4, 5), backend="gpu")
def train_step(
    replicated_state,
    replicated_metrics,
    replicated_rng,
    batch,
    num_classes,
    train=True,
):
    def loss_fn(params, rng_key, images, labels):
        labels_onehot = jax.nn.one_hot(labels, num_classes=num_classes)
        rng, dropout_apply_rng = random.split(rng_key)
        logits = replicated_state.apply_fn(
            {"params": params},
            images,
            train=train,
            rngs={"dropout": dropout_apply_rng},
        )
        loss = optax.softmax_cross_entropy(logits, labels_onehot).mean()
        return loss, logits

    imgs = batch["image"]
    labels = batch["label"]
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    # Get loss, gradients for loss, and other outputs of loss function.
    (loss, logits), grads = grad_fn(
        replicated_state.params, replicated_rng, imgs, labels
    )
    # Average the metrics across devices; accuracy is computed per device
    # against that device's labels, then averaged with pmean.
    gathered_loss = jax.lax.pmean(loss, axis_name="batch")
    accuracy = jax.lax.pmean(
        jnp.mean(jnp.argmax(logits, -1) == labels), axis_name="batch"
    )
    # Update parameters and batch statistics.
    replicated_state = replicated_state.apply_gradients(grads=grads)
    replicated_metrics["loss"] = gathered_loss
    replicated_metrics["acc"] = accuracy
    return replicated_state, replicated_metrics

# Pad, shard, and unpad around the pmapped step.
train_step = jax_utils.pad_shard_unpad(
    train_step, static_argnums=(0, 1, 2, 4, 5), static_return=True
)
```
-
We normally make sure the state is synced by averaging the gradients before applying them in the optimizer:
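Concretely, inside a pmapped `train_step` like the one above (a minimal sketch reusing the question's `axis_name="batch"` and variable names), the gradients are averaged across devices with `jax.lax.pmean` right before `apply_gradients`:

```python
# Inside the pmapped train_step, after computing the per-device gradients:
(loss, logits), grads = grad_fn(
    replicated_state.params, replicated_rng, imgs, labels
)

# Average the gradients over all devices along the "batch" axis.
# Every replica then applies identical gradients, so the replicated
# params stay in sync across devices.
grads = jax.lax.pmean(grads, axis_name="batch")

replicated_state = replicated_state.apply_gradients(grads=grads)
```

Because the parameters never diverge, there is no state left to aggregate at the end of training: you can take the weights from any one replica, e.g. via `flax.jax_utils.unreplicate(replicated_state)`.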