Commit e7dcc7b

implement collective all_to_all op (#9442)

1 parent: 58592c4

3 files changed (+49, -4 lines)
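Before this commit, the xla backend's alltoall hook raised NotImplementedError. A minimal usage sketch of what the commit enables (not taken from the commit itself; it assumes a multi-process XLA job initialized the same way as the test below):

import torch
import torch.distributed as dist
import torch_xla
import torch_xla.runtime as xr

dist.init_process_group("xla", init_method='xla://')
device = torch_xla.device()
world_size = xr.world_size()

# One chunk per peer: chunk j of `inputs` is sent to rank j, and chunk j of
# `outputs` is filled with the data received from rank j.
inputs = list(
    torch.arange(world_size * 2, dtype=torch.float,
                 device=device).chunk(world_size))
outputs = list(
    torch.empty([world_size * 2], dtype=torch.float,
                device=device).chunk(world_size))
dist.all_to_all(outputs, inputs)  # now dispatches to ProcessGroupXla.alltoall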

test/pjrt/test_collective_ops_tpu.py (30 additions, 0 deletions)

@@ -336,6 +336,36 @@ def test_all_to_all_single(self, use_dynamo):
                                      expected.sort().values),
           f"Got {val}, expected {expected}")
 
+  @staticmethod
+  def _all_to_all():
+    dist.init_process_group("xla", init_method='xla://')
+    device = torch_xla.device()
+    world_size = xr.world_size()
+    rank = xr.global_ordinal()
+
+    input_tensors = list(
+        torch.full([world_size * 2],
+                   fill_value=rank,
+                   dtype=torch.float,
+                   device=device).chunk(world_size))
+    output_tensors = list(
+        torch.empty([world_size * 2], dtype=torch.float,
+                    device=device).chunk(world_size))
+    dist.all_to_all(output_tensors, input_tensors)
+
+    return [t.cpu() for t in output_tensors]
+
+  def test_all_to_all(self):
+    # Input on device i is ([i, i], [i, i], ...). After all_to_all,
+    # output on every device is ([0, 0], [1, 1], ...).
+    results = pjrt.run_multiprocess(self._all_to_all)
+    expected = [
+        torch.tensor([i, i], dtype=torch.float)
+        for i in range(tpu.num_expected_global_devices())
+    ]
+    for _, value in results.items():
+      torch.testing.assert_close(value, expected)
+
   @staticmethod
   def _scatter():
     dist.init_process_group("xla", init_method='xla://')
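The data movement this test asserts follows the standard all_to_all rule: output_tensor_list[j] on rank i is input_tensor_list[i] from rank j. A CPU-only sketch of that rule (not part of the commit; `world_size` stands in for xr.world_size()):

import torch

world_size = 4
# Rank r's inputs: world_size chunks of [r, r], as in _all_to_all above.
inputs = [
    list(
        torch.full([world_size * 2], fill_value=rank,
                   dtype=torch.float).chunk(world_size))
    for rank in range(world_size)
]
# Apply the rule: output chunk j on rank i comes from input chunk i on rank j.
outputs = [[inputs[j][i] for j in range(world_size)]
           for i in range(world_size)]

expected = [torch.tensor([j, j], dtype=torch.float) for j in range(world_size)]
for per_rank in outputs:
  torch.testing.assert_close(per_rank, expected)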

test/test_torch_distributed_xla_backend.py (0 additions, 1 deletion)

@@ -357,7 +357,6 @@ def test_barrier(self):
 
   @parameterized.parameters(
       'allreduce_coalesced',
-      'alltoall',
       'gather',
       'recv_anysource',
       'monitored_barrier',

torch_xla/distributed/xla_backend.py (19 additions, 3 deletions)

@@ -6,7 +6,7 @@
 from torch_xla._internal import rendezvous
 import logging
 import os
-from torch._C._distributed_c10d import ProcessGroup, ScatterOptions, ReduceScatterOptions, AllgatherOptions, ReduceOptions
+from torch._C._distributed_c10d import ProcessGroup, ScatterOptions, ReduceScatterOptions, AllgatherOptions, AllToAllOptions, ReduceOptions
 
 
 def _create_xla_process_group(prefix_store, rank, size, timeout):
@@ -247,8 +247,24 @@ def reduce(self, tensors: list[torch.Tensor], opts: ReduceOptions):
   def allreduce_coalesced(self, *args):
     raise NotImplementedError
 
-  def alltoall(self, *args):
-    raise NotImplementedError
+  # Called by torch.distributed.all_to_all. Call site example:
+  # https://github.com/pytorch/pytorch/blob/v2.7.1/torch/distributed/distributed_c10d.py#L4577
+  # The difference between this and all_to_all_single is that this works
+  # on a list of tensors while all_to_all_single works on a single tensor
+  # and splits/concats along dimension 0.
+  def alltoall(self, output_tensor_list: list[torch.Tensor],
+               input_tensor_list: list[torch.Tensor], opts: AllToAllOptions):
+    stacked_inputs = torch.stack(input_tensor_list, dim=0)
+    split_count = len(input_tensor_list)
+    stacked_results = xm.all_to_all(
+        stacked_inputs,
+        split_dimension=0,
+        concat_dimension=0,
+        split_count=split_count)
+    results = torch.chunk(stacked_results, split_count, dim=0)
+    for result, output_tensor in zip(results, output_tensor_list):
+      output_tensor.copy_(result.squeeze(dim=0))
+    return _ret_work(output_tensor_list)
 
   # handle the nondynamo path when call torch.distributed.all_to_all_single
   # call from https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/torch/distributed/distributed_c10d.py#L3996
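The implementation reduces the list-based collective to the single-tensor xm.all_to_all primitive: stack the list along a new dim 0, exchange with split_dimension=0 and concat_dimension=0, then chunk and squeeze back into the output list. A CPU-only sketch of that reshaping (not part of the commit; the device exchange itself is emulated with plain indexing):

import torch

world_size = 4
# Per-rank stacked inputs, as built by alltoall: shape [world_size, 2].
stacked = [
    torch.stack(
        list(torch.full([world_size * 2], float(rank)).chunk(world_size)),
        dim=0) for rank in range(world_size)
]

# Emulate xm.all_to_all(split_dimension=0, concat_dimension=0,
# split_count=world_size): row j of rank i's input becomes row i of
# rank j's result.
results = [
    torch.cat([stacked[src][dst:dst + 1] for src in range(world_size)], dim=0)
    for dst in range(world_size)
]

# chunk + squeeze recovers the per-peer output list on every rank.
for result in results:
  outputs = [t.squeeze(dim=0) for t in torch.chunk(result, world_size, dim=0)]
  assert [t.tolist() for t in outputs] == [[float(r)] * 2
                                           for r in range(world_size)]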
