Distributed Strategy Survey
PyTorch provides a module called torch.distributed to support communication between processes in a distributed training system. The torch.distributed module works well with the torch.nn and torch.autograd modules, and its modular design makes it easy for users to define their own flexible training logic.
The communication APIs of the torch.distributed module are documented here:
- API doc: https://pytorch.org/docs/stable/distributed.html
- tutorial doc: https://pytorch.org/tutorials/intermediate/dist_tuto.html
There are two kinds of APIs:
- Point-to-Point Communication
Supports both blocking and non-blocking APIs.
- blocking: send/recv
- non-blocking: isend/irecv
- Collective Communication
The following are some of the supported collective communication primitives:
- reduce
- broadcast
- all_reduce
- scatter
- gather
- all_gather
- barrier
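The semantics of these collectives can be illustrated without any distributed setup. The sketch below models the data of each rank as one entry of a Python list. It is a single-process illustration only, not the torch.distributed API; the function names merely mirror the primitives listed above, and barrier is omitted because it only synchronizes ranks and moves no data.

```python
from functools import reduce as fold
from operator import add
from typing import Callable, List, Optional, Sequence

# values[r] stands for the data held by rank r. Each function returns
# what every rank holds after the collective completes.

def broadcast(values: Sequence[int], src: int) -> List[int]:
    """Every rank ends up with the value of rank `src`."""
    return [values[src] for _ in values]

def reduce(values: Sequence[int], dst: int,
           op: Callable[[int, int], int] = add) -> List[int]:
    """Only rank `dst` receives the combined value.

    In this model the other ranks simply keep their own value.
    """
    total = fold(op, values)
    return [total if r == dst else v for r, v in enumerate(values)]

def all_reduce(values: Sequence[int],
               op: Callable[[int, int], int] = add) -> List[int]:
    """Every rank receives the combined value."""
    total = fold(op, values)
    return [total for _ in values]

def scatter(chunks: Sequence[int]) -> List[int]:
    """The source rank holds `chunks`; rank r receives chunks[r]."""
    return list(chunks)

def gather(values: Sequence[int], dst: int) -> List[Optional[List[int]]]:
    """Rank `dst` receives the values of all ranks; the others receive nothing."""
    return [list(values) if r == dst else None for r in range(len(values))]

def all_gather(values: Sequence[int]) -> List[List[int]]:
    """Every rank receives the values of all ranks."""
    return [list(values) for _ in values]
```

For three ranks holding 1, 2 and 3, reduce with dst=1 leaves the sum only on rank 1, while all_reduce leaves it on every rank.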
We can use these low-level APIs as basic building blocks to implement our own distributed strategies.
Example 1
Implement our own ring-allreduce API using point-to-point communication APIs: https://pytorch.org/tutorials/intermediate/dist_tuto.html#our-own-ring-allreduce
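To make the idea concrete, here is a minimal single-process simulation of a sum ring-allreduce in its common two-phase form (reduce-scatter followed by all-gather). It is not the code from the linked tutorial: the function name ring_allreduce and the list-of-lists representation of per-rank buffers are assumptions made for illustration, and each "send" is modelled as a plain list copy rather than a real send/recv call.

```python
from typing import List

def ring_allreduce(buffers: List[List[int]]) -> List[List[int]]:
    """Simulate a sum ring-allreduce; buffers[r] is the vector held by rank r."""
    n = len(buffers)
    length = len(buffers[0])
    data = [list(b) for b in buffers]  # work on copies
    # Split every vector into n contiguous chunks (some may be empty).
    bounds = [(i * length) // n for i in range(n + 1)]

    def chunk(c: int) -> range:
        return range(bounds[c], bounds[c + 1])

    def exchange(first_chunk_offset: int, step: int, accumulate: bool) -> None:
        # Collect all messages first so the sends of one step are simultaneous.
        messages = []
        for r in range(n):
            c = (r + first_chunk_offset - step) % n
            messages.append(((r + 1) % n, c, [data[r][i] for i in chunk(c)]))
        for dst, c, payload in messages:
            for i, v in zip(chunk(c), payload):
                data[dst][i] = data[dst][i] + v if accumulate else v

    # Phase 1 (reduce-scatter): after n - 1 steps, rank r holds the
    # complete sum of chunk (r + 1) % n.
    for step in range(n - 1):
        exchange(first_chunk_offset=0, step=step, accumulate=True)

    # Phase 2 (all-gather): the completed chunks travel around the ring
    # and overwrite the partial sums on every other rank.
    for step in range(n - 1):
        exchange(first_chunk_offset=1, step=step, accumulate=False)

    return data
```

Each rank sends only one chunk per step to its right-hand neighbour, so the amount of data a rank transfers stays roughly constant as the number of ranks grows.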
Example 2
DistributedDataParallel, the high-level API of PyTorch, is itself built on these communication primitives: https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html
TensorFlow 2.0 provides the tf.distribute.Strategy API, which tries to hide the communication details from users.
It offers the following strategies:
- MirroredStrategy
- TPUStrategy
- MultiWorkerMirroredStrategy
- CentralStorageStrategy
- ParameterServerStrategy
- OneDeviceStrategy
Example
@tf.function
def train_step(dist_inputs):
    def step_fn(inputs):
        features, labels = inputs

        with tf.GradientTape() as tape:
            logits = model(features)
            cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
                logits=logits, labels=labels)
            loss = tf.reduce_sum(cross_entropy) * (1.0 / global_batch_size)

        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
        return cross_entropy

    per_example_losses = mirrored_strategy.experimental_run_v2(
        step_fn, args=(dist_inputs,))
    mean_loss = mirrored_strategy.reduce(
        tf.distribute.ReduceOp.MEAN, per_example_losses, axis=0)
    return mean_loss

with mirrored_strategy.scope():
    for inputs in dist_dataset:
        print(train_step(inputs))
The example shows that the communication details are hidden from the user.
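As a rough illustration of what a synchronous strategy does behind the scenes, the sketch below simulates one data-parallel step for a one-parameter model in plain Python. It is not TensorFlow code: the names local_grad and mirrored_step, the squared-error loss and the scalar weight are all assumptions made for illustration. It also shows why the example scales the loss by 1.0 / global_batch_size: summing the per-replica gradients then yields the gradient of the mean loss over the whole global batch.

```python
import math
from typing import List, Tuple

Batch = List[Tuple[float, float]]  # (feature, label) pairs

def local_grad(w: float, batch: Batch, global_batch_size: int) -> float:
    """Gradient of sum((w * x - y) ** 2) / global_batch_size on one replica."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / global_batch_size

def mirrored_step(w: float, replica_batches: List[Batch], lr: float) -> float:
    """One synchronous step: every replica holds the same w and gets the same update."""
    global_batch_size = sum(len(b) for b in replica_batches)
    # Each replica computes a gradient on its own shard of the global batch.
    grads = [local_grad(w, b, global_batch_size) for b in replica_batches]
    # The hidden communication step: a sum all-reduce across replicas.
    total_grad = sum(grads)
    # Every replica applies the identical update, so the copies stay in sync.
    return w - lr * total_grad

shards = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0), (4.0, 8.0)]]
w_parallel = mirrored_step(0.0, shards, lr=0.01)

# Reference: the same step on a single device that sees the whole batch.
w_single = mirrored_step(0.0, [shards[0] + shards[1]], lr=0.01)
assert math.isclose(w_parallel, w_single)
```

The final assertion checks that the two-replica update matches the update a single device would compute on the combined batch.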
ElasticDL applies customized optimizations to some models (such as a model with a large embedding table). It is also a Kubernetes-native framework that focuses on fault tolerance and elastic scheduling. We therefore decided to implement the communication strategy in Python ourselves.
In TensorFlow 2.0, the tape.gradient API returns a list of future objects:
grads = tape.gradient(loss, model.trainable_variables)
Only when we call tensor.numpy() do we get the value of the corresponding gradient tensor. The order of grads is the same as that of model.trainable_variables. We write the following code to overlap backward computation with gradient communication.
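grads = tape.gradient(loss, model.trainable_variables)
# Backward computation produces the gradients of the last layers first,
# so send them in reverse order of model.trainable_variables.
grads.reverse()
for grad in grads:
    grad_t = grad.numpy()  # blocks until this gradient is available
    send_to_ps_in_another_thread(grad_t)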
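The snippet above leaves send_to_ps_in_another_thread undefined. One possible shape for such a helper, sketched under assumptions, is a background thread that drains a FIFO queue, so the training loop can fetch the next gradient while the previous one is still being sent. The AsyncSender class and its send_fn callback are hypothetical names introduced here, not part of ElasticDL or TensorFlow, and the actual transport to the parameter server is left to the callback.

```python
import queue
import threading
from typing import Any, Callable

class AsyncSender:
    """Sends items from a background thread in the order they were queued."""

    _STOP = object()  # sentinel that tells the worker to exit

    def __init__(self, send_fn: Callable[[Any], None]) -> None:
        self._send_fn = send_fn  # hypothetical transport, e.g. an RPC call
        self._queue: "queue.Queue[Any]" = queue.Queue()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def _run(self) -> None:
        while True:
            item = self._queue.get()  # blocks until an item is queued
            if item is self._STOP:
                break
            self._send_fn(item)

    def send(self, item: Any) -> None:
        """Queue an item and return immediately, without waiting for the send."""
        self._queue.put(item)

    def close(self) -> None:
        """Wait until every queued item has been sent, then stop the worker."""
        self._queue.put(self._STOP)
        self._thread.join()

# Usage: plain lists stand in for the arrays returned by grad.numpy().
received = []
sender = AsyncSender(received.append)
for grad_t in reversed([[0.1], [0.2], [0.3]]):
    sender.send(grad_t)
sender.close()
```

A single worker thread with a FIFO queue keeps the gradients in the order they were produced, which is a design choice that simplifies matching them to variables on the receiving side.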
Note
In the allreduce strategy, we usually train models on GPUs. Note that grad.numpy() copies the gradients from GPU to CPU, so we have to find another solution that avoids this time-consuming copy.