# Distributed Strategy Survey
PyTorch provides a module called `torch.distributed` to support communication between processes in a distributed training system. `torch.distributed` works well with the `torch.nn` and `torch.autograd` modules, and its modular design makes it easy for users to define their own flexible training logic. The communication APIs in `torch.distributed` are documented in:
- API doc: https://pytorch.org/docs/stable/distributed.html
- tutorial doc: https://pytorch.org/tutorials/intermediate/dist_tuto.html
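Every process first needs to join a process group before calling any of these APIs. Below is a minimal initialization sketch, assuming a single-machine run with the `gloo` backend and the default `env://` rendezvous; the address, port, and helper name are placeholders, not part of the survey itself.

```python
import os

import torch.distributed as dist


def init_process(rank, world_size, backend="gloo"):
    """Join the default process group so the communication APIs below can be used."""
    # Rendezvous information for the default env:// init method
    # (placeholder values for a single-machine run).
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group(backend, rank=rank, world_size=world_size)
```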
There are two kinds of APIs:
- Point-to-Point Communication: supports both blocking and non-blocking APIs.
  - blocking: `send`/`recv`
  - non-blocking: `isend`/`irecv`
- Collective Communication: the following collective primitives are supported.
  - `reduce`
  - `broadcast`
  - `all_reduce`
  - `scatter`
  - `gather`
  - `all_gather`
  - `barrier`
We could use these low-level APIs as basic building blocks to implement our own distributed strategies.
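As a minimal illustration (assuming the process group has been initialized as above, with at least two processes; the function name is ours), the sketch below combines one point-to-point call with two collectives:

```python
import torch
import torch.distributed as dist


def demo(rank):
    # Point-to-point: rank 0 sends a tensor to rank 1 (blocking send/recv).
    tensor = torch.zeros(1)
    if rank == 0:
        tensor += 1
        dist.send(tensor, dst=1)
    elif rank == 1:
        dist.recv(tensor, src=0)

    # Collective: every rank contributes its rank id; all ranks receive the sum.
    value = torch.tensor([float(rank)])
    dist.all_reduce(value, op=dist.ReduceOp.SUM)

    # Collective: wait until every rank reaches this point.
    dist.barrier()
```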
## Example 1

Implement our own ring-allreduce using the point-to-point communication APIs: https://pytorch.org/tutorials/intermediate/dist_tuto.html#our-own-ring-allreduce
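A simplified sketch in the spirit of that tutorial, built only from `isend`/`recv` (the function name is ours, and this version passes whole tensors around the ring rather than the bandwidth-optimal chunked variant):

```python
import torch
import torch.distributed as dist


def ring_allreduce(tensor):
    """After world_size - 1 steps, every rank holds the sum of all ranks' tensors."""
    rank = dist.get_rank()
    size = dist.get_world_size()
    left = (rank - 1) % size
    right = (rank + 1) % size

    send_buff = tensor.clone()
    recv_buff = torch.zeros_like(tensor)
    accum = tensor.clone()

    for _ in range(size - 1):
        # Non-blocking send to the right neighbor, blocking receive from the left.
        req = dist.isend(send_buff, right)
        dist.recv(recv_buff, left)
        req.wait()
        accum += recv_buff
        # Forward the tensor we just received on the next step.
        send_buff = recv_buff.clone()

    tensor.copy_(accum)
```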
## Example 2

`DistributedDataParallel`, a high-level API of PyTorch: https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html
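A minimal sketch of how `DistributedDataParallel` is typically used; the model, data, and hyperparameters below are placeholders. Gradient averaging across processes happens automatically during `backward()`.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


def train(rank, world_size):
    # Assumes MASTER_ADDR / MASTER_PORT are set as in the init sketch above.
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    model = nn.Linear(10, 1)
    # DDP registers autograd hooks that all-reduce gradients during backward().
    ddp_model = DistributedDataParallel(model)
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)

    inputs, labels = torch.randn(20, 10), torch.randn(20, 1)
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(ddp_model(inputs), labels)
    loss.backward()  # gradients are averaged across all processes here
    optimizer.step()
```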
TensorFlow 2.0 provides the `tf.distribute.Strategy` module, which tries to hide the communication details from users. The following strategies are available (a setup sketch follows the list):
- MirroredStrategy
- TPUStrategy
- MultiWorkerMirroredStrategy
- CentralStorageStrategy
- ParameterServerStrategy
- OneDeviceStrategy
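The example that follows assumes a strategy, model, optimizer, and distributed dataset already exist (the names `mirrored_strategy`, `model`, `optimizer`, `global_batch_size`, and `dist_dataset`). A minimal setup sketch for `MirroredStrategy` might look like this; the model, batch size, and dataset are placeholders.

```python
import tensorflow as tf

mirrored_strategy = tf.distribute.MirroredStrategy()
global_batch_size = 64

with mirrored_strategy.scope():
    # Variables created under the scope are mirrored across devices.
    model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

dataset = tf.data.Dataset.from_tensors(
    (tf.random.normal([global_batch_size, 20]),
     tf.one_hot(tf.zeros([global_batch_size], dtype=tf.int32), depth=10))
).repeat(100)
dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)
```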
## Example

```python
@tf.function
def train_step(dist_inputs):
    def step_fn(inputs):
        features, labels = inputs
        with tf.GradientTape() as tape:
            logits = model(features)
            cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
                logits=logits, labels=labels)
            loss = tf.reduce_sum(cross_entropy) * (1.0 / global_batch_size)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
        return cross_entropy

    per_example_losses = mirrored_strategy.experimental_run_v2(
        step_fn, args=(dist_inputs,))
    mean_loss = mirrored_strategy.reduce(
        tf.distribute.ReduceOp.MEAN, per_example_losses, axis=0)
    return mean_loss

with mirrored_strategy.scope():
    for inputs in dist_dataset:
        print(train_step(inputs))
```
From the example, we can see that the communication details are hidden from the user.
ElasticDL makes customized optimizations for some models (such as models with large embedding tables). Besides, it is a Kubernetes-native framework that focuses on fault tolerance and elastic scheduling. So, we decided to implement the communication strategy in Python ourselves.