@@ -20,7 +20,7 @@
 #include "tensorrt_llm/common/cudaDriverWrapper.h"
 #include "tensorrt_llm/common/cudaUtils.h"
 #include "tensorrt_llm/common/logger.h"
-#include "tensorrt_llm/runtime/utils/mpiUtils.h"
+
 #include <cstddef>
 #include <cstdint>
 #include <cuda_runtime_api.h>
@@ -38,7 +38,7 @@ inline size_t roundUp(size_t val, size_t gran)
 } // namespace
 
 McastDeviceMemory::McastDeviceMemory(
-    size_t bufSize, uint32_t groupSize, uint32_t groupRank, int deviceIdx, bool mnNvlink)
+    size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, int deviceIdx, bool mnNvlink)
     : mIsMNNvlink(mnNvlink)
     , mDeviceIdx(deviceIdx)
     , mGroupSize(groupSize)
@@ -48,6 +48,7 @@ McastDeviceMemory::McastDeviceMemory(
     , mAllocationSize(0)
     , mMcPtr(0)
     , mMcHandle(0)
+    , mGroupComm(tensorrt_llm::mpi::MpiComm::session().split(splitColor, mGroupRank))
 {
 
     TLLM_CUDA_CHECK(cudaSetDevice(mDeviceIdx));
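For context on the new member: `MpiComm::split` wraps the standard `MPI_Comm_split` collective, so every rank that passes the same `splitColor` lands in the same sub-communicator, ordered within it by the key (here `mGroupRank`). A minimal sketch of those semantics using the raw MPI C API rather than TensorRT-LLM's `MpiComm` wrapper; the two-group sizing is illustrative only:

```cpp
#include <cstdio>
#include <mpi.h>

int main(int argc, char** argv)
{
    MPI_Init(&argc, &argv);
    int world_rank, world_size;
    MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
    MPI_Comm_size(MPI_COMM_WORLD, &world_size);

    // Illustrative: split an (even-sized) world into two equal groups.
    int const color = world_rank < world_size / 2 ? 0 : 1; // same color => same group
    int const key = world_rank; // preserves world ordering inside each group

    MPI_Comm group_comm;
    MPI_Comm_split(MPI_COMM_WORLD, color, key, &group_comm);

    int group_rank;
    MPI_Comm_rank(group_comm, &group_rank);
    std::printf("world rank %d -> color %d, group rank %d\n", world_rank, color, group_rank);

    MPI_Comm_free(&group_comm);
    MPI_Finalize();
    return 0;
}
```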
@@ -62,9 +63,12 @@ McastDeviceMemory::McastDeviceMemory(
     // From pytorch implementation for alignment
     constexpr size_t kSignalPadAlignment = 16UL;
     mSignalPadOffset = roundUp(mBufSize, kSignalPadAlignment);
+    int const world_rank{tensorrt_llm::mpi::MpiComm::session().getRank()};
+
     TLLM_LOG_DEBUG(
-        "[McastDeviceMemory] Rank: %u, Group size: %u, isMultiNode: %d, device_idx: %d, Signal pad offset: %zu",
-        mGroupRank, mGroupSize, mIsMNNvlink, mDeviceIdx, mSignalPadOffset);
+        "[McastDeviceMemory] World Rank: %u, Group Rank: %u, Group size: %u, GroupSplitColor: %u, isMultiNode: %d, "
+        "device_idx: %d, Signal pad offset: %zu",
+        world_rank, mGroupRank, mGroupSize, splitColor, mIsMNNvlink, mDeviceIdx, mSignalPadOffset);
 
     if (mIsMNNvlink)
     {
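A side note on the unchanged context above: the signal pad is placed immediately after the user buffer, with its start rounded up to a 16-byte boundary by the `roundUp` helper named in an earlier hunk header. Its body is not visible in this diff, so the sketch below assumes the common formulation:

```cpp
#include <cstddef>

// Assumed body of the roundUp helper (not shown in this diff):
// rounds val up to the next multiple of gran.
constexpr size_t roundUp(size_t val, size_t gran)
{
    return (val + gran - 1) / gran * gran;
}

// With kSignalPadAlignment = 16UL, a 1000-byte buffer puts the pad at 1008...
static_assert(roundUp(1000, 16) == 1008, "pad starts at the next 16B boundary");
// ...and already-aligned sizes pass through unchanged.
static_assert(roundUp(1024, 16) == 1024, "aligned sizes are unchanged");
```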
@@ -127,9 +131,6 @@ McastDeviceMemory::~McastDeviceMemory()
 
 void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
 {
-
-    auto const& mpi_comm = tensorrt_llm::mpi::MpiComm::session();
-
     CUmemAllocationHandleType const handle_type = CU_MEM_HANDLE_TYPE_FABRIC;
     CUmemAllocationProp prop = {};
     prop.requestedHandleTypes = handle_type;
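The fabric handle type requested here is what lets the allocation travel across nodes. A hedged sketch of the allocate-then-export pattern this function builds on, using only CUDA driver API calls; granularity rounding and error checking are elided, and the function name is hypothetical:

```cpp
#include <cstddef>
#include <cuda.h>

// Hypothetical helper: allocate device memory whose handle can be exported
// as an opaque CUmemFabricHandle and shared with other processes/nodes.
CUmemGenericAllocationHandle allocFabricExportable(int deviceIdx, size_t size, CUmemFabricHandle* outFabricHandle)
{
    CUmemAllocationProp prop = {};
    prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
    prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    prop.location.id = deviceIdx;
    prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC;

    // In real code, size must first be rounded up to the granularity
    // reported by cuMemGetAllocationGranularity for this prop.
    CUmemGenericAllocationHandle handle;
    cuMemCreate(&handle, size, &prop, /*flags=*/0);

    // Serialize the handle into a form another rank can import.
    cuMemExportToShareableHandle(outFabricHandle, handle, CU_MEM_HANDLE_TYPE_FABRIC, 0);
    return handle;
}
```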
@@ -156,7 +157,7 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
     // All gather
     cudaMallocHost(&exphndl, mGroupSize * sizeof(CUmemFabricHandle));
     memcpy(exphndl + mGroupRank * sizeof(CUmemFabricHandle), &myhndl, sizeof(CUmemFabricHandle));
-    mpi_comm.allgather(
+    mGroupComm.allgather(
         exphndl + mGroupRank * sizeof(CUmemFabricHandle), exphndl, sizeof(CUmemFabricHandle), mpi::MpiType::kCHAR);
     cudaDeviceSynchronize();
 
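The change above is the point of the patch: the handle exchange now runs over `mGroupComm` instead of the whole session, so fabric handles only circulate among ranks of the same split group. Each rank copies its own handle into its slot of the receive buffer, and the allgather fills in everyone else's. The same exchange sketched with the raw MPI C API (names are illustrative):

```cpp
#include <cuda.h>
#include <mpi.h>

// Exchange one fixed-size opaque fabric handle per rank over group_comm.
// recvbuf must hold group_size * sizeof(CUmemFabricHandle) bytes.
void allgatherFabricHandles(MPI_Comm group_comm, CUmemFabricHandle const& mine, void* recvbuf)
{
    // MPI_Allgather places rank i's contribution at byte offset
    // i * sizeof(CUmemFabricHandle) of recvbuf on every rank.
    MPI_Allgather(&mine, (int) sizeof(CUmemFabricHandle), MPI_CHAR,
        recvbuf, (int) sizeof(CUmemFabricHandle), MPI_CHAR, group_comm);
}
```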
@@ -175,7 +176,7 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
         TLLM_CU_CHECK(cuMemExportToShareableHandle((void*) fabric_handle, mMcHandle, CU_MEM_HANDLE_TYPE_FABRIC, 0));
     }
     // Broadcast
-    mpi_comm.bcast(fabric_handle, sizeof(CUmemFabricHandle), mpi::MpiType::kCHAR, 0);
+    mGroupComm.bcast(fabric_handle, sizeof(CUmemFabricHandle), mpi::MpiType::kCHAR, 0);
     cudaDeviceSynchronize();
     if (mGroupRank != 0)
     {
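Here group rank 0 creates and exports the multicast handle, the broadcast delivers it to the rest of the (now group-scoped) communicator, and the other members import it. A hedged sketch of that root-exports/others-import handshake with raw MPI and driver API calls; the helper name is hypothetical and TLLM_CU_CHECK-style error handling is omitted:

```cpp
#include <cuda.h>
#include <mpi.h>

// Hypothetical helper: deliver the root's exported fabric handle to the
// whole group, then reconstruct a usable handle on non-root ranks.
void bcastAndImportMcHandle(MPI_Comm group_comm, int group_rank, CUmemFabricHandle* fabric_handle,
    CUmemGenericAllocationHandle* mc_handle /* already valid on rank 0 */)
{
    MPI_Bcast(fabric_handle, (int) sizeof(CUmemFabricHandle), MPI_CHAR, /*root=*/0, group_comm);
    if (group_rank != 0)
    {
        // Turn the opaque fabric handle back into a generic allocation handle.
        cuMemImportFromShareableHandle(mc_handle, fabric_handle, CU_MEM_HANDLE_TYPE_FABRIC);
    }
}
```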
@@ -210,12 +211,9 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
 
 void McastDeviceMemory::allocNvlsMcastMem(size_t bufSize)
 {
-    // Create a std::set to include all ranks in range (0, group_size)
-    std::set<int> ranks;
-    for (uint32_t i = 0; i < mGroupSize; ++i)
-    {
-        ranks.insert(i);
-    }
+    // Get the world ranks for ranks in this group
+    auto ranks_ = tensorrt_llm::mpi::getWorldRanks(mGroupComm);
+    std::set<int> ranks(ranks_.begin(), ranks_.end());
     // Reuse existing implementation
     mNvlsHandle = tensorrt_llm::runtime::ipcNvlsAllocate(bufSize, ranks);
     mMcHandle = mNvlsHandle->mc_handle;
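Since `mGroupComm` may be a strict subset of the world communicator, the single-node path now has to hand `ipcNvlsAllocate` world ranks rather than the former 0..groupSize-1 range. The diff relies on TensorRT-LLM's `getWorldRanks` helper; a hedged sketch of one standard way to compute that mapping with raw MPI (the actual helper may be implemented differently):

```cpp
#include <mpi.h>
#include <numeric>
#include <vector>

// Map every rank of group_comm to its rank in MPI_COMM_WORLD.
std::vector<int> worldRanksOf(MPI_Comm group_comm)
{
    MPI_Group world_group, sub_group;
    MPI_Comm_group(MPI_COMM_WORLD, &world_group);
    MPI_Comm_group(group_comm, &sub_group);

    int group_size;
    MPI_Comm_size(group_comm, &group_size);

    std::vector<int> group_ranks(group_size);
    std::iota(group_ranks.begin(), group_ranks.end(), 0); // 0, 1, ..., group_size - 1

    std::vector<int> world_ranks(group_size);
    MPI_Group_translate_ranks(sub_group, group_size, group_ranks.data(), world_group, world_ranks.data());

    MPI_Group_free(&world_group);
    MPI_Group_free(&sub_group);
    return world_ranks;
}
```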