
Distributed Strategy Survey


PyTorch

PyTorch provides the torch.distributed module to support communication between processes in a distributed training system.

The torch.distributed module works well with the torch.nn and torch.autograd modules. This modular design lets users define their own flexible training logic.

Low Level API

PyTorch provides several communication APIs in the torch.distributed module.

There are two kinds of APIs:

  1. Point-to-Point Communication

Both blocking and non-blocking APIs are supported:

  • blocking: send/recv
  • non-blocking: isend/irecv

  2. Collective Communication

The following collective communication primitives are supported (a usage sketch of both kinds of APIs follows the list):

  • reduce
  • broadcast
  • all_reduce
  • scatter
  • gather
  • all_gather
  • barrier
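
As an illustration, here is a minimal sketch of how these primitives can be used. The gloo backend, the env:// initialization via MASTER_ADDR/MASTER_PORT, and the two-process launch are assumptions chosen for the sketch, not part of the survey above.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, world_size):
    # Join the default process group (gloo backend assumed for a CPU-only example).
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Point-to-point: rank 0 sends a tensor to rank 1 with blocking send/recv.
    tensor = torch.zeros(1)
    if rank == 0:
        tensor += 1
        dist.send(tensor, dst=1)
    else:
        dist.recv(tensor, src=0)

    # Collective: sum the tensor across all ranks.
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    print(f"rank {rank}: {tensor}")

    dist.destroy_process_group()

if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size)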

High Level API

We could use the low-level APIs as basic building blocks to implement our own distributed strategies.

Example 1

Implement our own ring-allreduce using the point-to-point communication APIs: https://pytorch.org/tutorials/intermediate/dist_tuto.html#our-own-ring-allreduce
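
To illustrate the idea (this is a simplified sketch, not the tutorial's exact code), a ring-style allreduce can be built from isend/recv alone: each rank repeatedly sends its current buffer to the right neighbor and receives from the left neighbor, accumulating the incoming values. The function name ring_allreduce and the buffer handling below are assumptions for the sketch; whole tensors travel around the ring, so this is not the bandwidth-optimal chunked variant.

import torch
import torch.distributed as dist

def ring_allreduce(send, recv):
    """Sum `send` across all ranks into `recv` using only point-to-point calls."""
    rank = dist.get_rank()
    size = dist.get_world_size()
    left = (rank - 1) % size
    right = (rank + 1) % size

    accum = send.clone()
    buf = send.clone()
    tmp = torch.empty_like(send)

    for _ in range(size - 1):
        # Non-blocking send to the right neighbor, blocking receive from the left.
        req = dist.isend(buf, dst=right)
        dist.recv(tmp, src=left)
        req.wait()
        accum += tmp          # accumulate the tensor that has just arrived
        buf = tmp.clone()     # forward it to the right neighbor next round

    recv.copy_(accum)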

Example 2

The DistributedDataParallel high-level API of PyTorch: https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html
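
Below is a minimal sketch of how DistributedDataParallel is typically used once a process group exists; the toy linear model, the SGD optimizer, and the random data are assumptions for illustration only.

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def train(rank, world_size):
    # Assumes the same env:// initialization as in the earlier sketch.
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    model = nn.Linear(10, 1)              # toy model (assumption)
    ddp_model = DDP(model)                # wraps the model; gradients are all-reduced
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()

    for _ in range(5):
        optimizer.zero_grad()
        inputs = torch.randn(32, 10)      # random data (assumption)
        labels = torch.randn(32, 1)
        loss = loss_fn(ddp_model(inputs), labels)
        loss.backward()                   # DDP hooks all-reduce gradients here
        optimizer.step()

    dist.destroy_process_group()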

TensorFlow 2.0

TensorFlow 2.0 provides the tf.distribute.Strategy module, which tries to hide the communication details from users.

Following are several strategies:

  • MirroredStrategy
  • TPUStrategy
  • MultiWorkerMirroredStrategy
  • CentralStorageStrategy
  • ParameterServerStrategy
  • OneDeviceStrategy

Example
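
The code below refers to mirrored_strategy, model, optimizer, global_batch_size, and dist_dataset without defining them. A possible setup is sketched first; the toy Keras model, the SGD optimizer, and the random dataset are assumptions for illustration.

import tensorflow as tf

mirrored_strategy = tf.distribute.MirroredStrategy()
global_batch_size = 64

# Variables must be created under the strategy's scope so they are mirrored.
with mirrored_strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
  optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

# A toy dataset, split across replicas by the strategy.
features = tf.random.normal([1024, 20])
labels = tf.one_hot(
    tf.random.uniform([1024], maxval=10, dtype=tf.int32), depth=10)
dataset = tf.data.Dataset.from_tensor_slices(
    (features, labels)).batch(global_batch_size)
dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)

The training step itself then looks like this: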

@tf.function
def train_step(dist_inputs):
  def step_fn(inputs):
    features, labels = inputs

    # Per-replica computation: forward pass and loss scaled by the global batch size.
    with tf.GradientTape() as tape:
      logits = model(features)
      cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
          logits=logits, labels=labels)
      loss = tf.reduce_sum(cross_entropy) * (1.0 / global_batch_size)

    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
    return cross_entropy

  # Run step_fn on every replica; the result is a per-replica value.
  per_example_losses = mirrored_strategy.experimental_run_v2(
      step_fn, args=(dist_inputs,))
  # Reduce the per-replica, per-example losses to a single mean value.
  mean_loss = mirrored_strategy.reduce(
      tf.distribute.ReduceOp.MEAN, per_example_losses, axis=0)
  return mean_loss

# Iterate over the distributed dataset, running one training step per batch.
with mirrored_strategy.scope():
  for inputs in dist_dataset:
    print(train_step(inputs))

From the example, we can see that the communication details are hidden.

ElasticDL

ElasticDL applies customized optimizations for some models (such as models with big embedding tables). Besides, it is a Kubernetes-native framework that focuses on fault tolerance and elastic scheduling. So, we decide to implement the communication strategy in Python ourselves.