Commit c040f0e

Revert nccl changes (#351)
* Revert "[distributed] remove pynccl's redundant change_state (vllm-project#11749)". This reverts commit 9e764e7.
* Revert "[distributed] remove pynccl's redundant stream (vllm-project#11744)". This reverts commit 635b897.

1 parent 88e020d · commit c040f0e
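
The two reverted upstream changes had removed `PyNcclCommunicator`'s `change_state` context manager and its dedicated stream, leaving the communicator always-on and bound to `torch.cuda.current_stream()`. Restoring them changes every call site, as the test diff below shows. A minimal sketch of the restored pattern, assuming an already-constructed `PyNcclCommunicator`:

```python
import torch
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator

def allreduce_ones(pynccl_comm: PyNcclCommunicator) -> torch.Tensor:
    """Sketch of the call pattern this revert restores (see worker_fn below)."""
    tensor = torch.ones(16, 1024, 1024,
                        dtype=torch.float32).cuda(pynccl_comm.rank)
    # Reverted pattern: tensor = pynccl_comm.all_reduce(tensor), issued on the
    # current stream with no context manager.
    # Restored pattern: explicitly enable the communicator for the block; the
    # collective then defaults to the communicator's own stream.
    with pynccl_comm.change_state(enable=True):
        tensor = pynccl_comm.all_reduce(tensor)
    torch.cuda.synchronize()
    return tensor
```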

File tree: 3 files changed, +82 -36 lines

tests/distributed/test_pynccl.py (+38 -27)
@@ -59,7 +59,8 @@ def worker_fn():
                                      device=get_world_group().device)
     tensor = torch.ones(16, 1024, 1024,
                         dtype=torch.float32).cuda(pynccl_comm.rank)
-    tensor = pynccl_comm.all_reduce(tensor)
+    with pynccl_comm.change_state(enable=True):
+        tensor = pynccl_comm.all_reduce(tensor)
     torch.cuda.synchronize()
     assert torch.all(tensor == pynccl_comm.world_size).cpu().item()

@@ -80,16 +81,17 @@ def multiple_allreduce_worker_fn():
     group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
     pynccl_comm = PyNcclCommunicator(group=group, device=device)
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
-    # two groups can communicate independently
-    if torch.distributed.get_rank() in [0, 1]:
-        tensor = pynccl_comm.all_reduce(tensor)
-        tensor = pynccl_comm.all_reduce(tensor)
-        torch.cuda.synchronize()
-        assert torch.all(tensor == 4).cpu().item()
-    else:
-        tensor = pynccl_comm.all_reduce(tensor)
-        torch.cuda.synchronize()
-        assert torch.all(tensor == 2).cpu().item()
+    with pynccl_comm.change_state(enable=True):
+        # two groups can communicate independently
+        if torch.distributed.get_rank() in [0, 1]:
+            tensor = pynccl_comm.all_reduce(tensor)
+            tensor = pynccl_comm.all_reduce(tensor)
+            torch.cuda.synchronize()
+            assert torch.all(tensor == 4).cpu().item()
+        else:
+            tensor = pynccl_comm.all_reduce(tensor)
+            torch.cuda.synchronize()
+            assert torch.all(tensor == 2).cpu().item()


 @pytest.mark.skipif(torch.cuda.device_count() < 4,
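
The expected values in the asserts follow from the group layout: ranks 0 and 1 form a two-rank group that all-reduces a tensor of ones twice, while ranks 2 and 3 form a group that all-reduces once. A quick worked check of the arithmetic:

```python
# Each all_reduce over a 2-rank group sums two identical copies of the tensor.
start = 1
after_first = start * 2          # both groups: 1 + 1 = 2
after_second = after_first * 2   # only ranks 0 and 1 reduce again: 2 + 2 = 4
assert after_second == 4         # matches the assert for ranks 0 and 1
assert after_first == 2          # matches the assert for ranks 2 and 3
```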
@@ -135,7 +137,9 @@ def worker_fn_with_cudagraph():
     # run something in the default stream to initialize torch engine
     a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
     torch.cuda.synchronize()
-    with torch.cuda.graph(graph):
+    with torch.cuda.graph(
+            graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
+                enable=True):
         a_out = pynccl_comm.all_reduce(a)
     torch.cuda.synchronize()
     graph.replay()
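
With the revert, the graph is captured on the communicator's own stream (`stream=pynccl_comm.stream`) instead of the ambient capture stream, and the communicator is enabled for the capture. The standalone, single-GPU sketch below (illustrative names, no NCCL involved) shows the same `torch.cuda.graph(..., stream=...)` capture-and-replay mechanics:

```python
import torch

def capture_and_replay() -> None:
    """Minimal CUDA-graph capture on a dedicated side stream, then replay."""
    assert torch.cuda.is_available()
    side_stream = torch.cuda.Stream()      # stand-in for pynccl_comm.stream
    static_in = torch.ones(4, 4, device="cuda")
    _ = static_in * 2                      # warm up lazy CUDA init outside capture
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=side_stream):
        static_out = static_in * 2         # work recorded into the graph

    static_in.fill_(3.0)                   # update the captured input in place
    graph.replay()                         # re-runs the recorded kernels
    torch.cuda.synchronize()
    assert torch.all(static_out == 6.0)    # the graph re-read static_in on replay

if __name__ == "__main__":
    capture_and_replay()
```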
@@ -164,7 +168,8 @@ def all_gather_worker_fn():
         for r in range(world_size)
     ]).to(device)

-    pynccl_comm.all_gather(result, tensor)
+    with pynccl_comm.change_state(enable=True):
+        pynccl_comm.all_gather(result, tensor)
     torch.cuda.synchronize()
     torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)

@@ -201,7 +206,8 @@ def reduce_scatter_worker_fn():
     expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size]
                    for tensor in all_tensors).to(device)

-    pynccl_comm.reduce_scatter(result, tensor)
+    with pynccl_comm.change_state(enable=True):
+        pynccl_comm.reduce_scatter(result, tensor)
     torch.cuda.synchronize()
     torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)

@@ -228,13 +234,15 @@ def send_recv_worker_fn():
     else:
         tensor = torch.empty(16, 1024, 1024,
                              dtype=torch.float32).cuda(pynccl_comm.rank)
-
-    if pynccl_comm.rank == 0:
-        pynccl_comm.send(tensor,
-                         dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
-    else:
-        pynccl_comm.recv(tensor,
-                         src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
+    with pynccl_comm.change_state(enable=True):
+        if pynccl_comm.rank == 0:
+            pynccl_comm.send(tensor,
+                             dst=(pynccl_comm.rank + 1) %
+                             pynccl_comm.world_size)
+        else:
+            pynccl_comm.recv(tensor,
+                             src=(pynccl_comm.rank - 1) %
+                             pynccl_comm.world_size)
     torch.cuda.synchronize()
     assert torch.all(tensor == 1).cpu().item()

@@ -265,12 +273,15 @@ def multiple_send_recv_worker_fn():
                          1024,
                          dtype=torch.float32,
                          device=device)
-    if torch.distributed.get_rank() in [0, 1]:
-        pynccl_comm.send(tensor,
-                         dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
-    else:
-        pynccl_comm.recv(tensor,
-                         src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
+    with pynccl_comm.change_state(enable=True):
+        if torch.distributed.get_rank() in [0, 1]:
+            pynccl_comm.send(tensor,
+                             dst=(pynccl_comm.rank + 1) %
+                             pynccl_comm.world_size)
+        else:
+            pynccl_comm.recv(tensor,
+                             src=(pynccl_comm.rank - 1) %
+                             pynccl_comm.world_size)
     torch.cuda.synchronize()
     if torch.distributed.get_rank() in [0, 2]:
         assert torch.all(tensor == 1).cpu().item()
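
None of the hunks above show how these SPMD worker functions get launched; the test file's own multi-process helper is untouched by this diff. As a rough, hypothetical stand-in (not the helper test_pynccl.py actually uses), workers like these can be driven with `torch.multiprocessing` and the usual env:// rendezvous variables:

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def _run(rank: int, world_size: int, worker_fn) -> None:
    # Hypothetical launcher: one process per GPU, gloo for the CPU process group.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    torch.cuda.set_device(rank)
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
    worker_fn()
    dist.destroy_process_group()

def launch(worker_fn, world_size: int = 2) -> None:
    """Spawn world_size processes, each running worker_fn on its own GPU."""
    mp.spawn(_run, args=(world_size, worker_fn), nprocs=world_size)
```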

vllm/distributed/device_communicators/pynccl.py (+35 -8)
@@ -1,3 +1,4 @@
+from contextlib import contextmanager
 from typing import Optional, Union

 # ===================== import region =====================
@@ -50,6 +51,7 @@ def __init__(
         if self.world_size == 1:
             self.available = False
             self.disabled = True
+            self.stream = None
             return
         try:
             self.nccl = NCCLLibrary(library_path)
@@ -58,6 +60,7 @@ def __init__(
             # e.g. in a non-GPU environment
             self.available = False
             self.disabled = True
+            self.stream = None
             return

         self.available = True
@@ -95,12 +98,12 @@ def __init__(
         with torch.cuda.device(device):
             self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                 self.world_size, self.unique_id, self.rank)
+            self.stream = torch.cuda.Stream()

-            stream = torch.cuda.current_stream()
             # A small all_reduce for warmup.
             data = torch.zeros(1, device=device)
             self.all_reduce(data)
-            stream.synchronize()
+            self.stream.synchronize()
             del data

     def all_reduce(self,
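
The restored `self.stream` matters because every wrapped NCCL call below hands the library a raw `cudaStream_t` taken from whichever stream ends up selected. A tiny illustration of where that handle comes from:

```python
import torch

# Illustration only: torch exposes the raw CUDA stream handle (cudaStream_t)
# as an integer via Stream.cuda_stream, which the ctypes wrapper passes through.
if torch.cuda.is_available():
    dedicated = torch.cuda.Stream()          # what this revert restores as self.stream
    print(hex(dedicated.cuda_stream))        # handle for the communicator's own stream
    print(hex(torch.cuda.current_stream().cuda_stream))  # handle the reverted code used
```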
@@ -119,7 +122,7 @@ def all_reduce(self,
         out_tensor = torch.empty_like(in_tensor)

         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
                                 buffer_type(out_tensor.data_ptr()),
                                 in_tensor.numel(),
@@ -141,7 +144,7 @@ def all_gather(self,
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclAllGather(
             buffer_type(input_tensor.data_ptr()),
             buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
@@ -162,7 +165,7 @@ def reduce_scatter(self,
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclReduceScatter(
             buffer_type(input_tensor.data_ptr()),
             buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
@@ -177,7 +180,7 @@ def send(self, tensor: torch.Tensor, dst: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                            self.comm, cudaStream_t(stream.cuda_stream))
@@ -189,7 +192,7 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), src,
                            self.comm, cudaStream_t(stream.cuda_stream))
@@ -201,7 +204,7 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         if src == self.rank:
             sendbuff = buffer_type(tensor.data_ptr())
             # NCCL requires the sender also to have a receive buffer
@@ -212,3 +215,27 @@
         self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
                                 ncclDataTypeEnum.from_torch(tensor.dtype), src,
                                 self.comm, cudaStream_t(stream.cuda_stream))
+
+    @contextmanager
+    def change_state(self,
+                     enable: Optional[bool] = None,
+                     stream: Optional[torch.cuda.Stream] = None):
+        """
+        A context manager to change the state of the communicator.
+        """
+        if enable is None:
+            # guess a default value when not specified
+            enable = self.available
+
+        if stream is None:
+            stream = self.stream
+
+        old_disable = self.disabled
+        old_stream = self.stream
+
+        self.stream = stream
+        self.disabled = not enable
+        yield
+
+        self.disabled = old_disable
+        self.stream = old_stream
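
The restored context manager simply swaps `disabled` and `stream` for the duration of the block and puts the old values back on exit. A hypothetical check of that behaviour, assuming `comm` is a constructed `PyNcclCommunicator` from this module with more than one rank:

```python
import torch

def show_state_swap(comm) -> None:
    # Hypothetical: `comm` is an initialized PyNcclCommunicator (world_size > 1).
    capture_stream = torch.cuda.Stream()
    old_disabled, old_stream = comm.disabled, comm.stream
    with comm.change_state(enable=True, stream=capture_stream):
        assert comm.disabled is False          # collectives are active in the block
        assert comm.stream is capture_stream   # and default to capture_stream
    assert comm.disabled == old_disabled       # previous state restored on exit
    assert comm.stream is old_stream
```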

vllm/distributed/parallel_state.py (+9 -1)
@@ -305,7 +305,15 @@ def graph_capture(
         stream.wait_stream(curr_stream)

         with torch.cuda.stream(stream), maybe_ca_context:
-            yield graph_capture_context
+            pynccl_comm = self.pynccl_comm
+            maybe_pynccl_context: Any
+            if not pynccl_comm:
+                maybe_pynccl_context = nullcontext()
+            else:
+                maybe_pynccl_context = pynccl_comm.change_state(
+                    stream=torch.cuda.current_stream())
+            with maybe_pynccl_context:
+                yield graph_capture_context

     def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         """

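Note that the new pynccl context is built inside the `with torch.cuda.stream(stream)` block, so `torch.cuda.current_stream()` already resolves to the graph-capture stream and the communicator gets pointed at it. A minimal check of that stream-context behaviour:

```python
import torch

# current_stream() is context-dependent: inside torch.cuda.stream(s) it returns
# s, which is why change_state(stream=torch.cuda.current_stream()) above picks
# up the capture stream rather than the default stream.
if torch.cuda.is_available():
    side = torch.cuda.Stream()
    assert torch.cuda.current_stream() != side
    with torch.cuda.stream(side):
        assert torch.cuda.current_stream() == side
    assert torch.cuda.current_stream() != side
```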