@@ -59,7 +59,8 @@ def worker_fn():
                                      device=get_world_group().device)
     tensor = torch.ones(16, 1024, 1024,
                         dtype=torch.float32).cuda(pynccl_comm.rank)
-    tensor = pynccl_comm.all_reduce(tensor)
+    with pynccl_comm.change_state(enable=True):
+        tensor = pynccl_comm.all_reduce(tensor)
     torch.cuda.synchronize()
     assert torch.all(tensor == pynccl_comm.world_size).cpu().item()
 
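Each test now wraps its NCCL call in `pynccl_comm.change_state(enable=True)`. As a rough, self-contained sketch of that pattern (a hypothetical `FakeCommunicator`, not vLLM's `PyNcclCommunicator`), such a toggle can be written as a context manager that flips an internal flag and restores the previous state on exit:

```python
# Illustrative sketch only -- assumes the communicator gates its collectives
# on a single boolean `available` flag. Not the real PyNcclCommunicator.
from contextlib import contextmanager


class FakeCommunicator:

    def __init__(self):
        self.available = False  # collectives are no-ops until enabled

    @contextmanager
    def change_state(self, enable: bool):
        old = self.available
        self.available = enable
        try:
            yield
        finally:
            self.available = old  # restore previous state on exit

    def all_reduce(self, x):
        # Stand-in for a real NCCL all-reduce: pass through when disabled.
        return x * 2 if self.available else x


comm = FakeCommunicator()
with comm.change_state(enable=True):
    assert comm.all_reduce(1) == 2
assert comm.all_reduce(1) == 1  # disabled again outside the context
```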
@@ -80,16 +81,17 @@ def multiple_allreduce_worker_fn():
     group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
     pynccl_comm = PyNcclCommunicator(group=group, device=device)
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
-    # two groups can communicate independently
-    if torch.distributed.get_rank() in [0, 1]:
-        tensor = pynccl_comm.all_reduce(tensor)
-        tensor = pynccl_comm.all_reduce(tensor)
-        torch.cuda.synchronize()
-        assert torch.all(tensor == 4).cpu().item()
-    else:
-        tensor = pynccl_comm.all_reduce(tensor)
-        torch.cuda.synchronize()
-        assert torch.all(tensor == 2).cpu().item()
+    with pynccl_comm.change_state(enable=True):
+        # two groups can communicate independently
+        if torch.distributed.get_rank() in [0, 1]:
+            tensor = pynccl_comm.all_reduce(tensor)
+            tensor = pynccl_comm.all_reduce(tensor)
+            torch.cuda.synchronize()
+            assert torch.all(tensor == 4).cpu().item()
+        else:
+            tensor = pynccl_comm.all_reduce(tensor)
+            torch.cuda.synchronize()
+            assert torch.all(tensor == 2).cpu().item()
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 4,
@@ -135,7 +137,9 @@ def worker_fn_with_cudagraph():
     # run something in the default stream to initialize torch engine
     a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
     torch.cuda.synchronize()
-    with torch.cuda.graph(graph):
+    with torch.cuda.graph(
+            graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
+                enable=True):
         a_out = pynccl_comm.all_reduce(a)
     torch.cuda.synchronize()
     graph.replay()
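The CUDA graph test now captures on the communicator's stream and enables the communicator only for the duration of the capture. The general capture-and-replay pattern with `torch.cuda.graph` looks roughly like this single-GPU sketch (a freshly created stream stands in for `pynccl_comm.stream`; no NCCL involved, and a CUDA-capable device is assumed):

```python
# Minimal capture/replay sketch with torch.cuda.graph.
import torch

graph = torch.cuda.CUDAGraph()
stream = torch.cuda.Stream()            # stand-in for pynccl_comm.stream
static_in = torch.ones(4, 4, device="cuda")

static_in * 2                           # warm-up outside the graph
torch.cuda.synchronize()

with torch.cuda.graph(graph, stream=stream):
    # Work recorded here is not executed now; it only becomes part of the graph.
    static_out = static_in * 2

graph.replay()                          # runs the captured kernels
torch.cuda.synchronize()
assert torch.all(static_out == 2).item()
```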
@@ -164,7 +168,8 @@ def all_gather_worker_fn():
         for r in range(world_size)
     ]).to(device)
 
-    pynccl_comm.all_gather(result, tensor)
+    with pynccl_comm.change_state(enable=True):
+        pynccl_comm.all_gather(result, tensor)
     torch.cuda.synchronize()
     torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
 
@@ -201,7 +206,8 @@ def reduce_scatter_worker_fn():
     expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size]
                    for tensor in all_tensors).to(device)
 
-    pynccl_comm.reduce_scatter(result, tensor)
+    with pynccl_comm.change_state(enable=True):
+        pynccl_comm.reduce_scatter(result, tensor)
     torch.cuda.synchronize()
     torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
 
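The `expected` value in this test is the element-wise sum over every rank's input, restricted to the current rank's slice, which is exactly what reduce-scatter produces. A CPU-only sketch of that arithmetic (hypothetical sizes, no NCCL):

```python
# CPU illustration of reduce-scatter semantics: each rank ends up with the
# sum over all ranks of its own slice of the concatenated input.
import torch

world_size, scattered_size = 4, 8
all_tensors = [torch.full((world_size * scattered_size,), float(r + 1))
               for r in range(world_size)]

for rank in range(world_size):
    expected = sum(t[rank * scattered_size:(rank + 1) * scattered_size]
                   for t in all_tensors)
    # Every element equals 1 + 2 + 3 + 4 = 10, regardless of the rank.
    assert torch.all(expected == sum(range(1, world_size + 1)))
```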
@@ -228,13 +234,15 @@ def send_recv_worker_fn():
     else:
         tensor = torch.empty(16, 1024, 1024,
                              dtype=torch.float32).cuda(pynccl_comm.rank)
-
-    if pynccl_comm.rank == 0:
-        pynccl_comm.send(tensor,
-                         dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
-    else:
-        pynccl_comm.recv(tensor,
-                         src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
+    with pynccl_comm.change_state(enable=True):
+        if pynccl_comm.rank == 0:
+            pynccl_comm.send(tensor,
+                             dst=(pynccl_comm.rank + 1) %
+                             pynccl_comm.world_size)
+        else:
+            pynccl_comm.recv(tensor,
+                             src=(pynccl_comm.rank - 1) %
+                             pynccl_comm.world_size)
     torch.cuda.synchronize()
     assert torch.all(tensor == 1).cpu().item()
 
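Both send/recv tests pair ranks on a ring: each rank sends to `(rank + 1) % world_size` and receives from `(rank - 1) % world_size`. A quick CPU-only check of that neighbour arithmetic (hypothetical `world_size`):

```python
# Ring-neighbour arithmetic used by the send/recv tests: rank r sends to
# (r + 1) % world_size and receives from (r - 1) % world_size.
world_size = 2
for rank in range(world_size):
    dst = (rank + 1) % world_size
    src = (rank - 1) % world_size
    assert (dst - rank) % world_size == 1   # next rank on the ring
    assert (rank - src) % world_size == 1   # previous rank on the ring
```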
@@ -265,12 +273,15 @@ def multiple_send_recv_worker_fn():
                          1024,
                          dtype=torch.float32,
                          device=device)
-    if torch.distributed.get_rank() in [0, 1]:
-        pynccl_comm.send(tensor,
-                         dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
-    else:
-        pynccl_comm.recv(tensor,
-                         src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
+    with pynccl_comm.change_state(enable=True):
+        if torch.distributed.get_rank() in [0, 1]:
+            pynccl_comm.send(tensor,
+                             dst=(pynccl_comm.rank + 1) %
+                             pynccl_comm.world_size)
+        else:
+            pynccl_comm.recv(tensor,
+                             src=(pynccl_comm.rank - 1) %
+                             pynccl_comm.world_size)
     torch.cuda.synchronize()
     if torch.distributed.get_rank() in [0, 2]:
         assert torch.all(tensor == 1).cpu().item()