
Commit f0d34aa

Init buffer with mpi4py.MPI.Comm (#365)
Signed-off-by: Tailing Yuan <[email protected]>
1 parent e3908bf


deep_ep/buffer.py

Lines changed: 23 additions & 11 deletions
@@ -29,12 +29,13 @@ class Buffer:
 
     num_sms: int = 20
 
-    def __init__(self, group: dist.ProcessGroup,
+    def __init__(self, group: Optional[dist.ProcessGroup],
                  num_nvl_bytes: int = 0, num_rdma_bytes: int = 0,
                  low_latency_mode: bool = False, num_qps_per_rank: int = 24,
                  allow_nvlink_for_low_latency_mode: bool = True,
                  allow_mnnvl: bool = False,
-                 explicitly_destroy: bool = False) -> None:
+                 explicitly_destroy: bool = False,
+                 comm: Optional["mpi4py.MPI.Comm"] = None) -> None:
         """
         Initialize the communication buffer.
@@ -53,28 +54,40 @@ def __init__(self, group: dist.ProcessGroup,
         explicitly_destroy: If this flag is set to True, you need to explicitly call `destroy()` to release resources;
             otherwise, the resources will be released by the destructor.
             Note: Releasing resources in the destructor may cause Python's exception handling process to hang.
+        comm: the mpi4py.MPI.Comm communicator to use in case the group parameter is absent.
         """
         check_nvlink_connections(group)
 
         # Initialize the CPP runtime
-        self.rank = group.rank()
-        self.group_size = group.size()
-        self.group = group
+        if group is not None:
+            self.rank = group.rank()
+            self.group_size = group.size()
+
+            def all_gather_object(obj):
+                object_list = [None] * self.group_size
+                dist.all_gather_object(object_list, obj, group)
+                return object_list
+        elif comm is not None:
+            self.rank = comm.Get_rank()
+            self.group_size = comm.Get_size()
+
+            def all_gather_object(obj):
+                return comm.allgather(obj)
+        else:
+            raise ValueError("Either 'group' or 'comm' must be provided.")
         self.num_nvl_bytes = num_nvl_bytes
         self.num_rdma_bytes = num_rdma_bytes
         self.low_latency_mode = low_latency_mode
         self.explicitly_destroy = explicitly_destroy
         self.runtime = deep_ep_cpp.Buffer(self.rank, self.group_size, num_nvl_bytes, num_rdma_bytes, low_latency_mode, explicitly_destroy)
 
         # Synchronize device IDs
-        device_ids = [None, ] * self.group_size
         local_device_id = self.runtime.get_local_device_id()
-        dist.all_gather_object(device_ids, local_device_id, group)
+        device_ids = all_gather_object(local_device_id)
 
         # Synchronize IPC handles
-        ipc_handles = [None, ] * self.group_size
         local_ipc_handle = self.runtime.get_local_ipc_handle()
-        dist.all_gather_object(ipc_handles, local_ipc_handle, group)
+        ipc_handles = all_gather_object(local_ipc_handle)
 
         # Synchronize NVSHMEM unique IDs
         root_unique_id = None
@@ -100,10 +113,9 @@ def __init__(self, group: dist.ProcessGroup,
                 os.environ['NVSHMEM_DISABLE_MNNVL'] = '1'
 
             # Synchronize using the root ID
-            nvshmem_unique_ids = [None, ] * self.group_size
             if (low_latency_mode and self.rank == 0) or (not low_latency_mode and self.runtime.get_rdma_rank() == 0):
                 root_unique_id = self.runtime.get_local_nvshmem_unique_id()
-            dist.all_gather_object(nvshmem_unique_ids, root_unique_id, group)
+            nvshmem_unique_ids = all_gather_object(root_unique_id)
             root_unique_id = nvshmem_unique_ids[0 if low_latency_mode else self.runtime.get_root_rdma_rank(True)]
 
             # Make CPP runtime available
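
With this change, a Buffer can be initialized either from a torch.distributed process group (as before) or, when group is None, directly from an mpi4py communicator. Below is a minimal usage sketch of both paths; it assumes Buffer is importable as from deep_ep import Buffer (as in the project's README) and uses placeholder buffer sizes, neither of which is taken from this diff.

import torch.distributed as dist
from mpi4py import MPI
from deep_ep import Buffer

# Existing path: pass a torch.distributed process group
# (requires a prior dist.init_process_group(...) call).
# buffer = Buffer(group=dist.group.WORLD, num_nvl_bytes=1 << 30)

# New path: leave `group` as None and pass an mpi4py communicator.
buffer = Buffer(group=None, num_nvl_bytes=1 << 30, comm=MPI.COMM_WORLD)

Design-wise, both paths are unified behind the local all_gather_object closure defined in __init__, so the downstream device-ID, IPC-handle, and NVSHMEM unique-ID exchanges no longer care which backend supplied the collective.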
