Commit 9d6e87a

[None][fix] Cherry-Pick MNNVLAllreduce Fixes into release/1.1.0rc2 branch (#7487)
Signed-off-by: Shiyu Li <[email protected]>
1 parent 7776793 commit 9d6e87a

File tree: 7 files changed, +52 -40 lines

cpp/tensorrt_llm/nanobind/runtime/bindings.cpp

Lines changed: 3 additions & 1 deletion
@@ -340,7 +340,9 @@ void initBindings(nb::module_& m)
         "Reset the current virtual memory allocator and stop allocating virtual memory for CUDA allocations");

     nb::class_<tensorrt_llm::runtime::McastGPUBuffer>(m, "McastGPUBuffer")
-        .def(nb::init<size_t, uint32_t, uint32_t, at::Device, bool>())
+        .def(nb::init<size_t, uint32_t, uint32_t, uint32_t, uint32_t, bool>(), nb::arg("buf_size"),
+            nb::arg("group_size"), nb::arg("group_rank"), nb::arg("split_color"), nb::arg("device_idx"),
+            nb::arg("mn_nvlink"))
         .def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer)
         .def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer);

cpp/tensorrt_llm/pybind/runtime/bindings.cpp

Lines changed: 3 additions & 1 deletion
@@ -434,7 +434,9 @@ void initBindings(pybind11::module_& m)
         "Reset the current virtual memory allocator and stop allocating virtual memory for CUDA allocations");

     py::class_<tensorrt_llm::runtime::McastGPUBuffer>(m, "McastGPUBuffer")
-        .def(py::init<size_t, uint32_t, uint32_t, at::Device, bool>())
+        .def(py::init<size_t, uint32_t, uint32_t, uint32_t, uint32_t, bool>(), py::arg("buf_size"),
+            py::arg("group_size"), py::arg("group_rank"), py::arg("split_color"), py::arg("device_idx"),
+            py::arg("mn_nvlink"))
         .def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer)
         .def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer);
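
With the bindings rebound on keyword arguments, the constructor is callable by name from Python. A minimal construction sketch, illustration only: the argument names come from the nb::arg/py::arg lists above, while the values are placeholders for a single 2-rank group on device 0.

from tensorrt_llm.bindings.internal.runtime import McastGPUBuffer

buf = McastGPUBuffer(
    buf_size=16 * 1024 * 1024,  # total buffer size in bytes
    group_size=2,               # number of ranks in the communication group
    group_rank=0,               # this process's rank within the group
    split_color=0,              # topology split color, e.g. pp_rank * cp_size + cp_rank
    device_idx=0,               # CUDA device ordinal (replaces the old at::Device argument)
    mn_nvlink=True,             # multi-node NVLink path
)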

cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp

Lines changed: 13 additions & 15 deletions
@@ -20,7 +20,7 @@
 #include "tensorrt_llm/common/cudaDriverWrapper.h"
 #include "tensorrt_llm/common/cudaUtils.h"
 #include "tensorrt_llm/common/logger.h"
-#include "tensorrt_llm/runtime/utils/mpiUtils.h"
+
 #include <cstddef>
 #include <cstdint>
 #include <cuda_runtime_api.h>
@@ -38,7 +38,7 @@ inline size_t roundUp(size_t val, size_t gran)
 } // namespace

 McastDeviceMemory::McastDeviceMemory(
-    size_t bufSize, uint32_t groupSize, uint32_t groupRank, int deviceIdx, bool mnNvlink)
+    size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, int deviceIdx, bool mnNvlink)
     : mIsMNNvlink(mnNvlink)
     , mDeviceIdx(deviceIdx)
     , mGroupSize(groupSize)
@@ -48,6 +48,7 @@ McastDeviceMemory::McastDeviceMemory(
     , mAllocationSize(0)
     , mMcPtr(0)
     , mMcHandle(0)
+    , mGroupComm(tensorrt_llm::mpi::MpiComm::session().split(splitColor, mGroupRank))
 {

     TLLM_CUDA_CHECK(cudaSetDevice(mDeviceIdx));
@@ -62,9 +63,12 @@ McastDeviceMemory::McastDeviceMemory(
     // From pytorch implementation for alignment
     constexpr size_t kSignalPadAlignment = 16UL;
     mSignalPadOffset = roundUp(mBufSize, kSignalPadAlignment);
+    int const world_rank{tensorrt_llm::mpi::MpiComm::session().getRank()};
+
     TLLM_LOG_DEBUG(
-        "[McastDeviceMemory] Rank: %u, Group size: %u, isMultiNode: %d, device_idx: %d, Signal pad offset: %zu",
-        mGroupRank, mGroupSize, mIsMNNvlink, mDeviceIdx, mSignalPadOffset);
+        "[McastDeviceMemory] World Rank: %u, Group Rank: %u, Group size: %u, GroupSplitColor: %u, isMultiNode: %d, "
+        "device_idx: %d, Signal pad offset: %zu",
+        world_rank, mGroupRank, mGroupSize, splitColor, mIsMNNvlink, mDeviceIdx, mSignalPadOffset);

     if (mIsMNNvlink)
     {
@@ -127,9 +131,6 @@ McastDeviceMemory::~McastDeviceMemory()

 void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
 {
-
-    auto const& mpi_comm = tensorrt_llm::mpi::MpiComm::session();
-
     CUmemAllocationHandleType const handle_type = CU_MEM_HANDLE_TYPE_FABRIC;
     CUmemAllocationProp prop = {};
     prop.requestedHandleTypes = handle_type;
@@ -156,7 +157,7 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
     // All gather
     cudaMallocHost(&exphndl, mGroupSize * sizeof(CUmemFabricHandle));
     memcpy(exphndl + mGroupRank * sizeof(CUmemFabricHandle), &myhndl, sizeof(CUmemFabricHandle));
-    mpi_comm.allgather(
+    mGroupComm.allgather(
         exphndl + mGroupRank * sizeof(CUmemFabricHandle), exphndl, sizeof(CUmemFabricHandle), mpi::MpiType::kCHAR);
     cudaDeviceSynchronize();
@@ -175,7 +176,7 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
         TLLM_CU_CHECK(cuMemExportToShareableHandle((void*) fabric_handle, mMcHandle, CU_MEM_HANDLE_TYPE_FABRIC, 0));
     }
     // Broadcast
-    mpi_comm.bcast(fabric_handle, sizeof(CUmemFabricHandle), mpi::MpiType::kCHAR, 0);
+    mGroupComm.bcast(fabric_handle, sizeof(CUmemFabricHandle), mpi::MpiType::kCHAR, 0);
     cudaDeviceSynchronize();
     if (mGroupRank != 0)
     {
@@ -210,12 +211,9 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize)

 void McastDeviceMemory::allocNvlsMcastMem(size_t bufSize)
 {
-    // Create a std::set to include all ranks in range (0, group_size)
-    std::set<int> ranks;
-    for (uint32_t i = 0; i < mGroupSize; ++i)
-    {
-        ranks.insert(i);
-    }
+    // Get the world ranks for ranks in this group
+    auto ranks_ = tensorrt_llm::mpi::getWorldRanks(mGroupComm);
+    std::set<int> ranks(ranks_.begin(), ranks_.end());
     // Reuse existing implementation
     mNvlsHandle = tensorrt_llm::runtime::ipcNvlsAllocate(bufSize, ranks);
     mMcHandle = mNvlsHandle->mc_handle;
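
For reference, the color/key semantics that the new mGroupComm relies on can be sketched with mpi4py (illustration only, not the TRT-LLM mpi::MpiComm wrapper): ranks that pass the same color form one sub-communicator, ordered by key, and the sub-group ranks can be translated back to world ranks, which is what getWorldRanks(mGroupComm) provides to the NVLS path.

from mpi4py import MPI

world = MPI.COMM_WORLD
color = 0                 # e.g. pp_rank * cp_size + cp_rank on the caller's side
key = world.Get_rank()    # ordering within the sub-communicator (the group rank)

group_comm = world.Split(color, key)

# Translate the sub-group's ranks back to world ranks, mirroring getWorldRanks(mGroupComm).
world_ranks = MPI.Group.Translate_ranks(
    group_comm.Get_group(), list(range(group_comm.Get_size())), world.Get_group())
print(world.Get_rank(), group_comm.Get_rank(), world_ranks)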

cpp/tensorrt_llm/runtime/mcastDeviceMemory.h

Lines changed: 5 additions & 1 deletion
@@ -17,6 +17,7 @@

 #include "tensorrt_llm/common/mcastDevMemUtils.h"
 #include "tensorrt_llm/runtime/ipcNvlsMemory.h"
+#include "tensorrt_llm/runtime/utils/mpiUtils.h"
 #include <cstddef>
 #include <cstdint>
 #include <cuda.h>
@@ -42,7 +43,8 @@ class McastDeviceMemory
     McastDeviceMemory(McastDeviceMemory const&) = delete;
     McastDeviceMemory& operator=(McastDeviceMemory const&) = delete;

-    McastDeviceMemory(size_t bufSize, uint32_t groupSize, uint32_t groupRank, int deviceIdx, bool mnNvlink);
+    McastDeviceMemory(
+        size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, int deviceIdx, bool mnNvlink);

     // We don't register the pointer in these two functions since we don't expect any python-level code would call
     // to obtain the raw pointers.
@@ -98,6 +100,8 @@ class McastDeviceMemory
     CUmemGenericAllocationHandle mMcHandle;
     std::vector<CUmemGenericAllocationHandle> mUcHandles;

+    tensorrt_llm::mpi::MpiComm mGroupComm; //!< The MPI communicator for the group
+
     // Host array of pointers
     std::vector<CUdeviceptr> mUcPtrs;
    std::vector<CUdeviceptr> mSignalPads;

cpp/tensorrt_llm/runtime/mcastGPUBuffer.h

Lines changed: 15 additions & 7 deletions
@@ -34,12 +34,14 @@ class McastGPUBuffer
     //! \param bufSize The total size of the buffer in bytes.
     //! \param groupSize The number of ranks in the communication group.
     //! \param groupRank The rank of the local process within the group.
+    //! \param splitColor The color of the split for topology split.
     //! \param device The CUDA device for buffer allocation.
     //! \param mnNvlink Flag indicating if multi-node NVLink is used.
-    McastGPUBuffer(size_t bufSize, uint32_t groupSize, uint32_t groupRank, at::Device device, bool mnNvlink)
-        : mMcastDeviceMemory(bufSize, groupSize, groupRank, device.index(), mnNvlink)
+    McastGPUBuffer(
+        size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, uint32_t deviceIdx, bool mnNvlink)
+        : mMcastDeviceMemory(bufSize, groupSize, groupRank, splitColor, deviceIdx, mnNvlink)
         , mBufSize(bufSize)
-        , mLocalDevice(device)
+        , mLocalDevice(at::Device(at::DeviceType::CUDA, deviceIdx))
     {
     }

@@ -49,7 +51,7 @@ class McastGPUBuffer
     //! \param dtype The data type of the tensor elements.
     //! \param storageOffset The offset in elements from the start of the buffer.
     //! \return An ATen tensor wrapping the unicast buffer section.
-    at::Tensor getUCBuffer(uint32_t rank, c10::IntArrayRef sizes, c10::ScalarType dtype, int64_t storageOffset)
+    at::Tensor getUCBuffer(uint32_t rank, std::vector<long int> sizes, torch::ScalarType dtype, int64_t storageOffset)
     {
         size_t const numel = std::accumulate(sizes.begin(), sizes.end(), 1UL, std::multiplies<size_t>());
         size_t const elementSize = c10::elementSize(dtype);
@@ -59,15 +61,18 @@ class McastGPUBuffer
         auto* dataPtr = static_cast<uint8_t*>(mMcastDeviceMemory.getUnicastPtr(rank)) + storageOffset * elementSize;

         auto options = at::TensorOptions().dtype(dtype).device(mLocalDevice);
-        return at::for_blob(dataPtr, sizes).options(options).target_device(mLocalDevice).make_tensor();
+        return at::for_blob(dataPtr, c10::IntArrayRef(sizes))
+            .options(options)
+            .target_device(mLocalDevice)
+            .make_tensor();
     }

     //! \brief Returns a PyTorch tensor view of the multicast buffer portion.
     //! \param sizes The desired shape (dimensions) of the tensor.
     //! \param dtype The data type of the tensor elements.
     //! \param storageOffset The offset in elements from the start of the buffer.
     //! \return An ATen tensor wrapping the multicast buffer section.
-    at::Tensor getMCBuffer(c10::IntArrayRef sizes, c10::ScalarType dtype, int64_t storageOffset)
+    at::Tensor getMCBuffer(std::vector<long int> sizes, torch::ScalarType dtype, int64_t storageOffset)
     {
         size_t const numel = std::accumulate(sizes.begin(), sizes.end(), 1UL, std::multiplies<size_t>());
         size_t const elementSize = c10::elementSize(dtype);
@@ -77,7 +82,10 @@ class McastGPUBuffer
         auto* dataPtr = static_cast<uint8_t*>(mMcastDeviceMemory.getMulticastPtr()) + storageOffset * elementSize;

         auto options = at::TensorOptions().dtype(dtype).device(mLocalDevice);
-        return at::for_blob(dataPtr, sizes).options(options).target_device(mLocalDevice).make_tensor();
+        return at::for_blob(dataPtr, c10::IntArrayRef(sizes))
+            .options(options)
+            .target_device(mLocalDevice)
+            .make_tensor();
     }

 private:
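
With sizes and dtype relaxed to std::vector<long int> and torch::ScalarType, the tensor views can be requested from Python with plain integer lists and torch dtypes. A hypothetical read-back sketch, assuming buf is a McastGPUBuffer as constructed in the earlier sketch and that torch dtypes convert to the bound scalar type as they did for the previous signature:

import torch

# Arguments: rank, sizes, dtype, storage offset (in elements).
uc_view = buf.get_uc_buffer(0, [1024, 1024], torch.bfloat16, 0)
# Arguments: sizes, dtype, storage offset (in elements).
mc_view = buf.get_mc_buffer([1024, 1024], torch.bfloat16, 0)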

tensorrt_llm/_torch/distributed/ops.py

Lines changed: 11 additions & 12 deletions
@@ -1,4 +1,3 @@
-import logging
 import math
 import os
 import platform
@@ -8,7 +7,7 @@
 import torch
 from torch import nn

-from tensorrt_llm._utils import mpi_barrier
+from tensorrt_llm._utils import mpi_comm
 from tensorrt_llm.bindings.internal.runtime import McastGPUBuffer
 from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
                                      AllReduceStrategy, MoEAllReduceParams)
@@ -17,7 +16,6 @@
 from tensorrt_llm.plugin.plugin import CustomAllReduceHelper

 _thread_local = threading.local()
-logger = logging.getLogger(__name__)


 def get_allreduce_workspace(mapping: Mapping) -> torch.LongTensor:
@@ -55,11 +53,15 @@ def allocate_low_presicion_allreduce_workspace(mapping: Mapping) -> None:
 def get_allreduce_mnnvl_workspace(
     mapping: Mapping, dtype: torch.dtype
 ) -> Tuple[McastGPUBuffer, torch.Tensor, torch.Tensor, int]:
+
     if not hasattr(_thread_local,
                    f'allreduce_mnnvl_workspaces_{mapping.pp_rank}'):
         setattr(_thread_local, f'allreduce_mnnvl_workspaces_{mapping.pp_rank}',
                 {})
-
+    # Support topology split
+    comm = mpi_comm().Split(
+        int(mapping.pp_rank * mapping.cp_size + mapping.cp_rank),
+        mapping.tp_rank)
     force_mn = os.environ.get("TRTLLM_FORCE_MNNVL_AR", "0") == "1"

     allreduce_mnnvl_workspaces = getattr(
@@ -77,7 +79,9 @@ def get_allreduce_mnnvl_workspace(
             buffer_size_in_bytes,
             mapping.tp_size,
             mapping.tp_rank,
-            torch.device("cuda", mapping.local_rank),
+            # Split the communicator according to the topology
+            mapping.pp_rank * mapping.cp_size + mapping.cp_rank,
+            mapping.local_rank,
             True,  # mnNvlink
         )
@@ -87,7 +91,7 @@ def get_allreduce_mnnvl_workspace(
         buffer.fill_(-0.0)
         # CPU barrier since we assume this should not be called in cuda graph
         torch.cuda.synchronize()
-        mpi_barrier()
+        comm.Barrier()

         # This is a buffer to maintain the state of this allreduce Op
         # Should have the same lifetime with self._buffer
@@ -458,12 +462,7 @@ def __init__(self,
         # Initialize MNNVL AllReduce if needed
         if self.strategy in (AllReduceStrategy.AUTO,
                              AllReduceStrategy.MNNVL):
-            if self.mapping.tp_size != self.mapping.world_size:
-                logger.debug(
-                    f"MNNVLAllReduce is disabled due to tp_size:{self.mapping.tp_size} "
-                    f"!= world_size:{self.mapping.world_size}")
-                self.mnnvl_allreduce = None
-            elif MNNVLAllReduce.is_mnnvl(self.mapping, dtype):
+            if MNNVLAllReduce.is_mnnvl(self.mapping, dtype):
                 try:
                     self.mnnvl_allreduce = MNNVLAllReduce(
                         self.mapping, dtype) if dtype else None
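
A small worked example of the split color (plain Python, no MPI required), assuming the conventional rank layout where tp_rank varies fastest, then cp_rank, then pp_rank: every rank sharing a (pp_rank, cp_rank) pair gets the same color and therefore lands in the same communicator, keyed by tp_rank.

pp_size, cp_size, tp_size = 2, 1, 4   # 8 world ranks in total

for world_rank in range(pp_size * cp_size * tp_size):
    tp_rank = world_rank % tp_size
    cp_rank = (world_rank // tp_size) % cp_size
    pp_rank = world_rank // (tp_size * cp_size)
    color = pp_rank * cp_size + cp_rank
    print(f"world_rank={world_rank} -> color={color}, key(tp_rank)={tp_rank}")

# Ranks 0-3 share color 0 and ranks 4-7 share color 1, so comm.Barrier() and the
# workspace allocation only involve the ranks that actually allreduce together.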

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 2 additions & 3 deletions
@@ -761,12 +761,11 @@ def _compute_mlp_tp_size(self, intermediate_size: int,
             self.mapping.tp_size,
         )

-        if tp > self.mapping.gpus_per_node and not self.allreduce.is_mnnvl(
-        ):
+        if tp > self.mapping.gpus_per_node:
             mlp_tp_size = math.gcd(
                 tp,
                 self.mapping.gpus_per_node,
-            )  # Avoid costly inter-node TP when MNNVL is not supported
+            )  # Avoid costly inter-node TP
         else:
             mlp_tp_size = tp
         return mlp_tp_size
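
The retained fallback still caps MLP tensor parallelism at a single node regardless of MNNVL support; a quick worked example with tp=16 and 8 GPUs per node:

import math

tp, gpus_per_node = 16, 8
mlp_tp_size = math.gcd(tp, gpus_per_node) if tp > gpus_per_node else tp
print(mlp_tp_size)  # 8: the MLP TP group stays within one node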
