
Commit c4f0e72

Introduce custom all-reduce that communicates in FP8 precision (#201)
* Introduce custom all-reduce that communicates in FP8 precision and reduces in FP32 precision
* Address linter and test failures
* Address linter and test failures
* Address linter and test failures
* Address linter and test failures
* Address linter and test failures
* Address pytest seg fault
* Address pytest seg fault
* Address review feedback
1 parent b84c5a6 commit c4f0e72

6 files changed: +404 −69 lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ Issues = "https://github.com/pytorch-labs/torchft/issues"
 
 [project.optional-dependencies]
 dev = [
-    "pytest",
+    "pytest==8.3.4",
     "pytest-timeout",
     "black",
     "pyre-check",

torchft/_test_utils.py

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch


def any_nan(ts: list[torch.Tensor]) -> bool:
    """
    Check if any tensor in the list contains NaN values.

    Args:
        ts: List of tensors to check for NaN values

    Returns:
        True if any tensor contains NaN values, False otherwise
    """
    for t in ts:
        if torch.isnan(t).any():
            return True
    return False


def combine_views(
    views: list[list[tuple[int, ...]]],
    combinations: list[list[tuple[int, ...]]],
    tmp: list[tuple[int, ...]],
    i: int,
) -> None:
    """
    Recursively generate all possible combinations of views from a list of
    lists of views.

    This function uses backtracking to generate all possible combinations by
    selecting one view from each inner list of the input. The results are
    stored in the combinations list.

    Args:
        views: A list of lists, where each inner list contains possible view
            shapes (tuples)
        combinations: Output list where all combinations will be stored
        tmp: Temporary list to build the current combination
        i: Current index in the views list being processed

    Returns:
        None. Results are stored in the combinations list passed as
        an argument.
    """
    if i == len(views):
        combinations.append(tmp.copy())
        return

    for j in range(len(views[i])):
        tmp.append(views[i][j])
        combine_views(views, combinations, tmp, i + 1)
        tmp.pop()


def gen_views(inp: torch.Tensor) -> list[tuple[int, ...]]:
    """
    Generate possible 2D views (shapes) for a tensor with a given number
    of elements.

    This function finds pairs of integers (m, n) such that m * n equals the
    total number of elements in the input tensor. These pairs represent 2D
    shapes that the tensor can be reshaped into. Note that m ranges over
    [1, size) for even sizes and [2, size) for odd sizes, so the (size, 1)
    view is never generated and (1, size) is skipped for odd sizes.

    Args:
        inp: Input tensor

    Returns:
        A list of tuples, where each tuple (m, n) represents a possible 2D shape
        such that m * n equals the total number of elements in the input tensor
    """
    size = inp.numel()

    views = []
    for m in range(1 if size % 2 == 0 else 2, size):
        if size % m == 0:
            views.append((m, size // m))

    return views


def gen_splits(inp: torch.Tensor, split_size: int) -> list[list[tuple[int, ...]]]:
    """
    Split a tensor into chunks and generate all possible combinations of views.

    This function first splits the input tensor into chunks of the specified size,
    then generates all possible 2D views for each chunk, and finally computes all
    possible combinations of these views across all chunks.

    Args:
        inp: Input tensor to be split
        split_size: Size of each chunk

    Returns:
        A list of lists, where each inner list contains a combination of view
        shapes, one for each chunk of the input tensor
    """
    views = []

    for split in torch.split(inp, split_size):
        views.append(gen_views(split))

    combinations = []
    combine_views(views, combinations, [], 0)

    return combinations
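
For a concrete sense of what these helpers produce, here is a minimal sketch (illustrative values, runnable on CPU):

import torch

from torchft._test_utils import gen_splits

# An 8-element tensor split into two chunks of 4. gen_views offers (1, 4)
# and (2, 2) for each chunk (m == size is never generated), so gen_splits
# yields the 2 x 2 = 4 cross-product combinations, one shape per chunk.
inp = torch.arange(8, dtype=torch.float32)
for combo in gen_splits(inp, 4):
    print(combo)
# [(1, 4), (1, 4)]
# [(1, 4), (2, 2)]
# [(2, 2), (1, 4)]
# [(2, 2), (2, 2)]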

torchft/collectives.py

Lines changed: 158 additions & 0 deletions
@@ -0,0 +1,158 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

# pyre-ignore[21]: Could not find a module corresponding to import `triton`
import triton
from torch import cuda
from torch.distributed.distributed_c10d import (
    AllgatherOptions,
    AllreduceOptions,
    AllToAllOptions,
    ReduceOp,
)
from torch.futures import Future

from torchft.process_group import ProcessGroup
from torchft.quantization import (
    fused_dequantize_from_fp8,
    fused_quantize_into_fp8,
    fused_reduce_fp8,
)


def _to_alltoall_options(opts: AllreduceOptions) -> AllToAllOptions:
    alltoall_opts = AllToAllOptions()
    alltoall_opts.timeout = opts.timeout
    return alltoall_opts


def _to_allgather_options(opts: AllreduceOptions) -> AllgatherOptions:
    allgather_opts = AllgatherOptions()
    allgather_opts.timeout = opts.timeout
    return allgather_opts


def allreduce_quantized(
    tensors: list[torch.Tensor],
    opts: AllreduceOptions | ReduceOp,
    process_group: ProcessGroup,
    sync_stream: cuda.Stream | None = None,
) -> Future[None]:
    """
    Performs a quantized all-reduce operation on a list of tensors.

    This function implements an optimized all-reduce that reduces communication
    overhead by quantizing tensors to FP8 format before sending them over the
    network. The algorithm works as follows:

    1. Quantize input tensors to FP8 format
    2. Distribute chunks of quantized tensors to all ranks using all-to-all
    3. Reduce chunks locally in higher precision after dequantization
    4. Collect reduced chunks from all ranks using all-gather
    5. Dequantize the result back to the original precision

    This implementation only supports the AVG reduce operation.

    Args:
        tensors: List of tensors to be reduced. All tensors must be on the same
            CUDA device and have the same dtype.
        opts: Options for the all-reduce operation. Can be either an
            AllreduceOptions object or a ReduceOp enum. If a ReduceOp is
            provided, it must be ReduceOp.AVG.
        process_group: The process group to perform the all-reduce on.
        sync_stream: Optional CUDA stream to use for synchronization. If None,
            a new stream will be created.

    Returns:
        A Future that can be used to wait for the operation to complete and
        clean up intermediate buffers.

    Raises:
        NotImplementedError: If the reduce operation is not ReduceOp.AVG.
    """
    if isinstance(opts, ReduceOp):
        allreduce_opts = AllreduceOptions()
        allreduce_opts.reduceOp = opts
    else:
        allreduce_opts = opts

    # Check if the reduceOp is AVG, as only AVG is supported
    if allreduce_opts.reduceOp != ReduceOp.AVG:
        raise NotImplementedError(
            f"ReduceOp {allreduce_opts.reduceOp} is not supported "
            f"for quantized allreduce, only AVG is supported"
        )

    rank = process_group.rank()
    world_size = process_group.size()

    if sync_stream is None:
        sync_stream = cuda.Stream()

    assert sync_stream is not None
    # Ensure that all operations are completed on the current stream
    # before proceeding with the all-reduce
    sync_stream.wait_stream(cuda.current_stream())
    with cuda.stream(sync_stream):
        # Quantize tensors and compute their scales, all inlined in the
        # output tensor.
        quantized_tensors = fused_quantize_into_fp8(tensors, world_size)

        # Allocate output tensor where all-reduce results will be stored
        quantized_tensors_out = torch.zeros_like(quantized_tensors)
        # Collect chunks and their scales from other ranks
        process_group.alltoall_base(
            quantized_tensors_out.view(world_size, -1),
            quantized_tensors.view(world_size, -1),
            [],
            [],
            _to_alltoall_options(allreduce_opts),
        ).wait()

        # Reduce chunks locally in higher precision after dequantization.
        # The output is again quantized.
        fused_reduce_fp8(
            tensors,
            quantized_tensors_out,
            world_size,
            rank,
        )

        # Collect reduced chunks from other ranks.
        process_group.allgather_into_tensor_coalesced(
            [quantized_tensors.view(world_size, -1)],
            [torch.split(quantized_tensors_out.view(world_size, -1), 1)[rank]],
            _to_allgather_options(allreduce_opts),
        ).wait()

        # Dequantize and copy to output buffer.
        fused_dequantize_from_fp8(tensors, quantized_tensors, world_size)

    class QuantizedAllReduceFuture(Future[None]):
        def __init__(
            self,
            sync_stream: cuda.Stream,
            quantized_tensors: torch.Tensor,
            quantized_tensors_out: torch.Tensor,
        ) -> None:
            super().__init__()
            self._sync_stream = sync_stream
            self._quantized_tensors = quantized_tensors
            self._quantized_tensors_out = quantized_tensors_out

        def wait(self) -> None:
            # Wait for the synchronization to complete.
            cuda.current_stream().wait_stream(self._sync_stream)
            # Clean up intermediate buffers.
            del self._quantized_tensors_out
            del self._quantized_tensors

    # pyre-ignore[29]
    return QuantizedAllReduceFuture(
        sync_stream, quantized_tensors, quantized_tensors_out
    )
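
A minimal usage sketch, assuming an already-initialized torchft ProcessGroup (the name `pg` below is a stand-in) and tensors on the current CUDA device; this mirrors how the test file below drives the API:

import torch
from torch.distributed.distributed_c10d import ReduceOp

from torchft.collectives import allreduce_quantized

# `pg` is assumed to be an initialized torchft ProcessGroup (e.g. NCCL-backed).
grads = [torch.randn(1024, 1024, device="cuda") for _ in range(3)]

# Kick off the quantized all-reduce; chunks travel over the wire in FP8.
fut = allreduce_quantized(grads, ReduceOp.AVG, pg)

# wait() synchronizes the current stream with the collective's stream and
# frees the intermediate FP8 buffers; `grads` then holds the averaged
# values in their original dtype.
fut.wait()

Conceptually, the fused Triton kernels (fused_quantize_into_fp8 and friends) implement scaled FP8 quantization with the scales carried inline alongside the payload. A plain-PyTorch sketch of the underlying idea, not the actual kernels:

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3

def quantize_fp8(t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Scale so the largest magnitude maps onto the representable FP8 range.
    scale = t.abs().amax().clamp(min=1e-12) / FP8_MAX
    return (t / scale).to(torch.float8_e4m3fn), scale

def dequantize_fp8(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Reduction happens in higher precision on dequantized values (this
    # commit reduces in FP32), and the result is re-quantized before the
    # all-gather step.
    return q.to(torch.float32) * scale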

torchft/collectives_test.py

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Callable
from unittest import TestCase, skipUnless

import torch
from parameterized import parameterized
from torch import cuda
from torch.distributed import AllreduceOptions, ReduceOp
from torch.distributed.distributed_c10d import ReduceOp

from torchft import _test_utils
from torchft.process_group import ProcessGroup
from torchft.process_group_test import MultiPgBaseTest

try:
    # pyre-ignore[21]: Could not find a module corresponding to import `triton`
    import triton
except ImportError:
    pass
else:
    from torchft.collectives import allreduce_quantized

    @skipUnless(
        torch.cuda.is_available() and torch.cuda.device_count() >= 2,
        "2 CUDA devices are required for this test",
    )
    class QuantizedAllReduceTest(MultiPgBaseTest):
        BACKEND = "nccl"
        WORLD_SIZE = 2

        def _run_parallel_collectives(
            self, collective: Callable[[ProcessGroup, int, str], None]
        ) -> None:
            futures = []
            for rank in range(self.WORLD_SIZE):
                pg = self.pg_pool[rank]
                device = f"cuda:{rank}"
                fut = self.executor.submit(collective, pg, rank, device)
                futures.append(fut)

            self._collect(futures)

        def _run_collective(
            self,
            pg: ProcessGroup,
            rank: int,
            device: str,
            tensors_num: int,
            tensor_size: int,
            multiplier: float,
            tolerance: float,
        ) -> None:
            cuda.set_device(device)
            inp = (
                torch.rand(
                    tensors_num * tensor_size,
                    dtype=torch.float32,
                    device=device,
                )
                * multiplier
            )
            for split in _test_utils.gen_splits(inp, tensor_size):
                actual = inp.clone()
                expected = inp.clone()
                tensors = [
                    i.view(*s)
                    for s, i in zip(
                        split,
                        torch.split(actual, tensor_size),
                    )
                ]

                fut = allreduce_quantized(tensors, ReduceOp.AVG, pg)
                fut.wait()

                work = pg.allreduce([expected], ReduceOp.AVG)
                work.get_future().wait()

                diff = torch.abs((expected - actual).div(expected))
                mean_diff = diff.mean().item()

                if mean_diff > tolerance:
                    raise AssertionError(f"Results not within tolerance {tolerance}")

        END_TO_END_CONFIGS: list[tuple[int, float]] = [
            (ts, m)
            for ts in [128, 512, 1024, 2048, 4096]
            for m in [1.0, 10.0, 100.0, 1000.0]
        ]

        @parameterized.expand(END_TO_END_CONFIGS)
        def test_collective(self, tensor_size: int, multiplier: float) -> None:
            self._run_parallel_collectives(
                lambda pg, rank, device: self._run_collective(
                    pg,
                    rank,
                    device,
                    3,
                    tensor_size,
                    multiplier,
                    3.0,
                )
            )
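
For reference, the accuracy criterion above is a mean elementwise relative error. In isolation, with made-up numbers:

import torch

expected = torch.tensor([1.0, 2.0, 4.0])
actual = torch.tensor([1.01, 1.98, 4.04])

# Matches the check in _run_collective: mean of |expected - actual| / |expected|.
diff = torch.abs((expected - actual).div(expected))
print(diff.mean().item())  # ~0.01, comfortably under the tolerance of 3.0 used above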
