# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

# pyre-ignore[21]: Could not find a module corresponding to import `triton`
import triton
from torch import cuda
from torch.distributed.distributed_c10d import (
    AllgatherOptions,
    AllreduceOptions,
    AllToAllOptions,
    ReduceOp,
)
from torch.futures import Future

from torchft.process_group import ProcessGroup
from torchft.quantization import (
    fused_dequantize_from_fp8,
    fused_quantize_into_fp8,
    fused_reduce_fp8,
)


def _to_alltoall_options(opts: AllreduceOptions) -> AllToAllOptions:
    alltoall_opts = AllToAllOptions()
    alltoall_opts.timeout = opts.timeout
    return alltoall_opts


def _to_allgather_options(opts: AllreduceOptions) -> AllgatherOptions:
    allgather_opts = AllgatherOptions()
    allgather_opts.timeout = opts.timeout
    return allgather_opts


def allreduce_quantized(
    tensors: list[torch.Tensor],
    opts: AllreduceOptions | ReduceOp,
    process_group: ProcessGroup,
    sync_stream: cuda.Stream | None = None,
) -> Future[None]:
    """
    Performs a quantized all-reduce operation on a list of tensors.

    This function implements an optimized all-reduce that reduces communication
    overhead by quantizing tensors to FP8 format before sending them over the
    network. The algorithm works as follows:

    1. Quantize input tensors to FP8 format
    2. Distribute chunks of quantized tensors to all ranks using all-to-all
    3. Reduce chunks locally in higher precision after dequantization
    4. Collect reduced chunks from all ranks using all-gather
    5. Dequantize the result back to the original precision

    This implementation only supports the AVG reduce operation.

    Args:
        tensors: List of tensors to be reduced. All tensors must be on the same
            CUDA device and have the same dtype.
        opts: Options for the all-reduce operation. Can be either an
            AllreduceOptions object or a ReduceOp enum. If a ReduceOp is
            provided, it must be ReduceOp.AVG.
        process_group: The process group to perform the all-reduce on.
        sync_stream: Optional CUDA stream to use for synchronization. If None,
            a new stream will be created.

    Returns:
        A Future that can be used to wait for the operation to complete and
        clean up intermediate buffers.

    Raises:
        NotImplementedError: If the reduce operation is not ReduceOp.AVG.
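
    Example (illustrative usage sketch; ``pg`` and ``grads`` are assumed to be
    an already-initialized torchft ``ProcessGroup`` and a list of CUDA gradient
    tensors)::

        fut = allreduce_quantized(grads, ReduceOp.AVG, pg)
        # Makes the current stream wait for the reduce and drops the
        # references to the FP8 scratch buffers.
        fut.wait()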
    """
    if isinstance(opts, ReduceOp):
        allreduce_opts = AllreduceOptions()
        allreduce_opts.reduceOp = opts
    else:
        allreduce_opts = opts

    # Check that the reduceOp is AVG, as only AVG is supported
    if allreduce_opts.reduceOp != ReduceOp.AVG:
        raise NotImplementedError(
            f"ReduceOp {allreduce_opts.reduceOp} is not supported "
            f"for quantized allreduce, only AVG is supported"
        )

    rank = process_group.rank()
    world_size = process_group.size()

    if sync_stream is None:
        sync_stream = cuda.Stream()

    assert sync_stream is not None
    # Ensure that all operations on the current stream are completed
    # before proceeding with the all-reduce
    sync_stream.wait_stream(cuda.current_stream())
    with cuda.stream(sync_stream):
        # Quantize the tensors and compute their scales, both packed inline
        # into the output tensor.
        quantized_tensors = fused_quantize_into_fp8(tensors, world_size)

        # Allocate the buffer that will receive the all-to-all results
        quantized_tensors_out = torch.zeros_like(quantized_tensors)
        # Exchange chunks and their scales with the other ranks
        process_group.alltoall_base(
            quantized_tensors_out.view(world_size, -1),
            quantized_tensors.view(world_size, -1),
            [],
            [],
            _to_alltoall_options(allreduce_opts),
        ).wait()
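        # After the exchange, row j of quantized_tensors_out holds rank j's
        # copy of the chunk assigned to this rank (a reduce-scatter style
        # layout), assuming fused_quantize_into_fp8 packs the data into
        # world_size equally sized chunks as the (world_size, -1) view requires.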

        # Reduce chunks locally in higher precision after dequantization.
        # The output is again quantized.
        fused_reduce_fp8(
            tensors,
            quantized_tensors_out,
            world_size,
            rank,
        )

        # Collect reduced chunks from other ranks.
        process_group.allgather_into_tensor_coalesced(
            [quantized_tensors.view(world_size, -1)],
            [torch.split(quantized_tensors_out.view(world_size, -1), 1)[rank]],
            _to_allgather_options(allreduce_opts),
        ).wait()
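        # quantized_tensors now holds, in row j, the reduced FP8 chunk owned
        # by rank j, so the full reduced result can be dequantized locally.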

        # Dequantize and copy to output buffer.
        fused_dequantize_from_fp8(tensors, quantized_tensors, world_size)

    class QuantizedAllReduceFuture(Future[None]):
        def __init__(
            self,
            sync_stream: cuda.Stream,
            quantized_tensors: torch.Tensor,
            quantized_tensors_out: torch.Tensor,
        ) -> None:
            super().__init__()
            self._sync_stream = sync_stream
            self._quantized_tensors = quantized_tensors
            self._quantized_tensors_out = quantized_tensors_out

        def wait(self) -> None:
            # Wait for the synchronization to complete.
            cuda.current_stream().wait_stream(self._sync_stream)
            # Clean up intermediate buffers.
            del self._quantized_tensors_out
            del self._quantized_tensors

    # pyre-ignore[29]
    return QuantizedAllReduceFuture(
        sync_stream, quantized_tensors, quantized_tensors_out
    )