@@ -29,12 +29,13 @@ class Buffer:
 
     num_sms: int = 20
 
-    def __init__(self, group: dist.ProcessGroup,
+    def __init__(self, group: Optional[dist.ProcessGroup],
                  num_nvl_bytes: int = 0, num_rdma_bytes: int = 0,
                  low_latency_mode: bool = False, num_qps_per_rank: int = 24,
                  allow_nvlink_for_low_latency_mode: bool = True,
                  allow_mnnvl: bool = False,
-                 explicitly_destroy: bool = False) -> None:
+                 explicitly_destroy: bool = False,
+                 comm: Optional["mpi4py.MPI.Comm"] = None) -> None:
         """
         Initialize the communication buffer.
 
@@ -53,28 +54,40 @@ def __init__(self, group: dist.ProcessGroup,
             explicitly_destroy: If this flag is set to True, you need to explicitly call `destroy()` to release resources;
                 otherwise, the resources will be released by the destructor.
                 Note: releasing resources in the destructor may cause Python's exception handling to hang.
+            comm: the mpi4py.MPI.Comm communicator to use when the `group` parameter is not provided.
         """
         check_nvlink_connections(group)
 
         # Initialize the CPP runtime
-        self.rank = group.rank()
-        self.group_size = group.size()
-        self.group = group
+        if group is not None:
+            self.rank = group.rank()
+            self.group_size = group.size()
+
+            def all_gather_object(obj):
+                object_list = [None] * self.group_size
+                dist.all_gather_object(object_list, obj, group)
+                return object_list
+        elif comm is not None:
+            self.rank = comm.Get_rank()
+            self.group_size = comm.Get_size()
+
+            def all_gather_object(obj):
+                return comm.allgather(obj)
+        else:
+            raise ValueError("Either 'group' or 'comm' must be provided.")
         self.num_nvl_bytes = num_nvl_bytes
         self.num_rdma_bytes = num_rdma_bytes
         self.low_latency_mode = low_latency_mode
         self.explicitly_destroy = explicitly_destroy
         self.runtime = deep_ep_cpp.Buffer(self.rank, self.group_size, num_nvl_bytes, num_rdma_bytes, low_latency_mode, explicitly_destroy)
 
         # Synchronize device IDs
-        device_ids = [None, ] * self.group_size
         local_device_id = self.runtime.get_local_device_id()
-        dist.all_gather_object(device_ids, local_device_id, group)
+        device_ids = all_gather_object(local_device_id)
 
         # Synchronize IPC handles
-        ipc_handles = [None, ] * self.group_size
         local_ipc_handle = self.runtime.get_local_ipc_handle()
-        dist.all_gather_object(ipc_handles, local_ipc_handle, group)
+        ipc_handles = all_gather_object(local_ipc_handle)
 
         # Synchronize NVSHMEM unique IDs
         root_unique_id = None
@@ -100,10 +113,9 @@ def __init__(self, group: dist.ProcessGroup,
                 os.environ['NVSHMEM_DISABLE_MNNVL'] = '1'
 
             # Synchronize using the root ID
-            nvshmem_unique_ids = [None, ] * self.group_size
             if (low_latency_mode and self.rank == 0) or (not low_latency_mode and self.runtime.get_rdma_rank() == 0):
                 root_unique_id = self.runtime.get_local_nvshmem_unique_id()
-            dist.all_gather_object(nvshmem_unique_ids, root_unique_id, group)
+            nvshmem_unique_ids = all_gather_object(root_unique_id)
             root_unique_id = nvshmem_unique_ids[0 if low_latency_mode else self.runtime.get_root_rdma_rank(True)]
 
         # Make CPP runtime available
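
For context, a hedged usage sketch of the new constructor path: it creates a Buffer from an mpi4py communicator with no torch.distributed process group initialized. The buffer size, the device-to-rank mapping, and the standalone-script framing are illustrative assumptions, not part of the patch.

# Hypothetical usage: run under mpiexec, one GPU per rank assumed.
from mpi4py import MPI
import torch
from deep_ep import Buffer

comm = MPI.COMM_WORLD
torch.cuda.set_device(comm.Get_rank() % torch.cuda.device_count())

buffer = Buffer(group=None,               # no ProcessGroup; fall back to MPI
                num_nvl_bytes=256 << 20,  # 256 MiB NVLink buffer (illustrative)
                comm=comm)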
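The local `all_gather_object` closure works because `comm.allgather` has the same semantics the removed `dist.all_gather_object` calls relied on: each rank contributes one picklable object and receives a rank-ordered list of length `comm.Get_size()`. A minimal self-contained check of that assumption (hypothetical file name; run with `mpiexec -n 4 python check_allgather.py`):

# check_allgather.py — hypothetical sanity check of mpi4py's allgather semantics.
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

# Gather an arbitrary picklable object from every rank.
gathered = comm.allgather({"rank": rank})

assert len(gathered) == comm.Get_size()   # one entry per rank
assert gathered[rank]["rank"] == rank     # results are rank-ordered
if rank == 0:
    print(gathered)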