
Commit 641ef87

MengqingCao authored and amitm02 committed

[Platform][Dist] Make torch distributed process group extendable (vllm-project#18763)

Signed-off-by: Mengqing Cao <[email protected]>
Signed-off-by: amit <[email protected]>

1 parent: 882f8d8

File tree

4 files changed: +134 −37

vllm/distributed/utils.py
vllm/platforms/cuda.py
vllm/platforms/interface.py
vllm/platforms/rocm.py


vllm/distributed/utils.py (52 additions, 37 deletions)

@@ -5,7 +5,6 @@
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 import dataclasses
-import datetime
 import os
 import pickle
 import socket
@@ -14,14 +13,14 @@
 import uuid
 from collections import deque
 from collections.abc import Sequence
+from datetime import timedelta
 from typing import Any, Optional
 
 import torch
 from torch.distributed import ProcessGroup, TCPStore
 from torch.distributed.distributed_c10d import (Backend, PrefixStore,
                                                 _get_default_timeout,
-                                                _unregister_process_group,
-                                                is_nccl_available)
+                                                _unregister_process_group)
 from torch.distributed.rendezvous import rendezvous
 
 import vllm.envs as envs
@@ -406,7 +405,7 @@ def create(
         port=port,
         world_size=world_size,
         is_master=launch_server,
-        timeout=datetime.timedelta(seconds=store_timeout),
+        timeout=timedelta(seconds=store_timeout),
         use_libuv=False,  # for now: github.com/pytorch/pytorch/pull/150215
         master_listen_fd=listen_fd,
     )
@@ -419,6 +418,43 @@ def create(
         data_expiration_seconds=data_expiration_seconds)
 
 
+def init_gloo_process_group(backend: Backend, prefix_store: PrefixStore,
+                            group_rank: int, group_size: int,
+                            timeout: timedelta) -> ProcessGroup:
+    """
+    Stateless init ProcessGroup with gloo backend compatible with
+    different torch versions.
+    """
+    if is_torch_equal_or_newer("2.6"):
+        pg = ProcessGroup(
+            prefix_store,
+            group_rank,
+            group_size,
+        )
+    else:
+        options = ProcessGroup.Options(backend=backend)
+        pg = ProcessGroup(
+            prefix_store,
+            group_rank,
+            group_size,
+            options,
+        )
+    from torch.distributed.distributed_c10d import ProcessGroupGloo
+    backend_class = ProcessGroupGloo(prefix_store,
+                                     group_rank,
+                                     group_size,
+                                     timeout=timeout)
+    backend_type = ProcessGroup.BackendType.GLOO
+    device = torch.device("cpu")
+    if is_torch_equal_or_newer("2.6"):
+        # _set_default_backend is supported in torch >= 2.6
+        pg._set_default_backend(backend_type)
+    backend_class._set_sequence_number_for_group()
+
+    pg._register_backend(device, backend_type, backend_class)
+    return pg
+
+
 def stateless_init_torch_distributed_process_group(
         host: str, port: int, rank: int, world_size: int,
         backend: str) -> ProcessGroup:
@@ -468,40 +504,19 @@ def stateless_init_torch_distributed_process_group(
     # different systems (e.g. RPC) in case the store is multi-tenant.
     prefix_store = PrefixStore(init_method, store)
 
-    pg: ProcessGroup = ProcessGroup(
-        prefix_store,
-        group_rank,
-        group_size,
-    )
-
     if backend == "gloo":
-        from torch.distributed.distributed_c10d import ProcessGroupGloo
-        backend_class = ProcessGroupGloo(prefix_store,
-                                         group_rank,
-                                         group_size,
-                                         timeout=timeout)
-        backend_type = ProcessGroup.BackendType.GLOO
-        device = torch.device("cpu")
-    elif backend == "nccl":
-        assert is_nccl_available()
-        from torch.distributed.distributed_c10d import ProcessGroupNCCL
-
-        backend_options = ProcessGroupNCCL.Options()
-        backend_options._timeout = timeout
-
-        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
-                                         backend_options)
-        backend_type = ProcessGroup.BackendType.NCCL
-        device = torch.device("cuda")
-    else:
-        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
-
-    pg._set_default_backend(backend_type)
-    backend_class._set_sequence_number_for_group()
-
-    pg._register_backend(device, backend_type, backend_class)
-
-    return pg
+        return init_gloo_process_group(backend=backend,
+                                       prefix_store=prefix_store,
+                                       group_rank=group_rank,
+                                       group_size=group_size,
+                                       timeout=timeout)
+    from vllm.platforms import current_platform
+    return current_platform.stateless_init_device_torch_dist_pg(
+        backend=backend,
+        prefix_store=prefix_store,
+        group_rank=group_rank,
+        group_size=group_size,
+        timeout=timeout)
 
 
 def stateless_destroy_torch_distributed_process_group(
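
The refactor keeps a single entry point: "gloo" goes through the new init_gloo_process_group, and every other backend is delegated to current_platform. A minimal usage sketch of the gloo path, run once per rank — the host/port endpoint and tensor values are illustrative, not from the commit:

import torch
import torch.distributed as dist

from vllm.distributed.utils import (
    stateless_init_torch_distributed_process_group)


def run(rank: int, world_size: int) -> None:
    # Illustrative endpoint; any free port reachable by all ranks works.
    pg = stateless_init_torch_distributed_process_group(
        host="127.0.0.1",
        port=29600,
        rank=rank,
        world_size=world_size,
        backend="gloo")
    t = torch.ones(4)
    # The group is stateless: it is never installed as the default
    # process group, so collectives must receive it explicitly.
    dist.all_reduce(t, group=pg)
    assert torch.allclose(t, torch.full((4, ), float(world_size)))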

vllm/platforms/cuda.py (33 additions, 0 deletions)

@@ -4,10 +4,13 @@
 """
 
 import os
+from datetime import timedelta
 from functools import wraps
 from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
 
 import torch
+from torch.distributed import PrefixStore, ProcessGroup
+from torch.distributed.distributed_c10d import is_nccl_available
 from typing_extensions import ParamSpec
 
 # import custom ops, trigger op registration
@@ -316,6 +319,36 @@ def use_custom_allreduce(cls) -> bool:
     def get_piecewise_backend_cls(cls) -> str:
         return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend"  # noqa
 
+    @classmethod
+    def stateless_init_device_torch_dist_pg(
+        cls,
+        backend: str,
+        prefix_store: PrefixStore,
+        group_rank: int,
+        group_size: int,
+        timeout: timedelta,
+    ) -> ProcessGroup:
+        assert is_nccl_available()
+        pg: ProcessGroup = ProcessGroup(
+            prefix_store,
+            group_rank,
+            group_size,
+        )
+        from torch.distributed.distributed_c10d import ProcessGroupNCCL
+
+        backend_options = ProcessGroupNCCL.Options()
+        backend_options._timeout = timeout
+
+        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
+                                         backend_options)
+        backend_type = ProcessGroup.BackendType.NCCL
+        device = torch.device("cuda")
+        pg._set_default_backend(backend_type)
+        backend_class._set_sequence_number_for_group()
+
+        pg._register_backend(device, backend_type, backend_class)
+        return pg
+
 
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
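
The wiring above leans on private c10d hooks: _set_default_backend records NCCL as the group's default backend type, and _register_backend attaches the ProcessGroupNCCL instance for CUDA devices. A small sanity-check sketch of that wiring, assuming torch's private _get_backend accessor is available (it is in recent releases, but treat that as an assumption):

import torch
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import ProcessGroupNCCL


def check_nccl_wiring(pg: ProcessGroup) -> None:
    # pg is expected to come from stateless_init_device_torch_dist_pg.
    # _get_backend is private torch API, mirroring the private hooks
    # the implementation itself uses.
    backend_impl = pg._get_backend(torch.device("cuda"))
    assert isinstance(backend_impl, ProcessGroupNCCL)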

vllm/platforms/interface.py (16 additions, 0 deletions)

@@ -3,11 +3,13 @@
 import os
 import platform
 import random
+from datetime import timedelta
 from platform import uname
 from typing import TYPE_CHECKING, NamedTuple, Optional, Union
 
 import numpy as np
 import torch
+from torch.distributed import PrefixStore, ProcessGroup
 
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.logger import init_logger
@@ -486,6 +488,20 @@ def get_piecewise_backend_cls(cls) -> str:
         """
         return "vllm.compilation.base_piecewise_backend.AbstractPiecewiseBackend"  # noqa
 
+    @classmethod
+    def stateless_init_device_torch_dist_pg(
+        cls,
+        backend: str,
+        prefix_store: PrefixStore,
+        group_rank: int,
+        group_size: int,
+        timeout: timedelta,
+    ) -> ProcessGroup:
+        """
+        Init platform-specific torch distributed process group.
+        """
+        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
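
This base-class default is the extension point named in the commit title: platforms that do not override the hook keep the old "Unsupported torch distributed backend" failure, while an out-of-tree platform can now construct its own stateless group without patching vllm/distributed/utils.py. A minimal sketch — HypotheticalPlatform and ProcessGroupFoo are invented stand-ins for a vendor plugin and its c10d backend, and the body assumes torch >= 2.6 constructor semantics:

from datetime import timedelta

import torch
from torch.distributed import PrefixStore, ProcessGroup

from vllm.platforms.interface import Platform


class HypotheticalPlatform(Platform):
    """Stand-in for an out-of-tree platform plugin."""

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        # torch >= 2.6: ProcessGroup no longer takes an Options object.
        pg = ProcessGroup(prefix_store, group_rank, group_size)
        # ProcessGroupFoo is a hypothetical vendor c10d backend.
        backend_class = ProcessGroupFoo(prefix_store, group_rank,
                                        group_size, timeout=timeout)
        # BackendType.CUSTOM is torch's bucket for third-party backends.
        backend_type = ProcessGroup.BackendType.CUSTOM
        pg._set_default_backend(backend_type)
        backend_class._set_sequence_number_for_group()
        # "privateuseone" is torch's default out-of-tree device type;
        # vendors typically rename it.
        pg._register_backend(torch.device("privateuseone"), backend_type,
                             backend_class)
        return pg

Once registered through vLLM's platform plugin mechanism, this class is what current_platform resolves to, so stateless_init_torch_distributed_process_group reaches it with no further changes.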

vllm/platforms/rocm.py (33 additions, 0 deletions)

@@ -1,10 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import os
+from datetime import timedelta
 from functools import cache, lru_cache, wraps
 from typing import TYPE_CHECKING, Optional
 
 import torch
+from torch.distributed import PrefixStore, ProcessGroup
+from torch.distributed.distributed_c10d import is_nccl_available
 
 import vllm.envs as envs
 from vllm.logger import init_logger
@@ -387,3 +390,33 @@ def is_navi(cls) -> bool:
     @classmethod
     def get_piecewise_backend_cls(cls) -> str:
         return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend"  # noqa
+
+    @classmethod
+    def stateless_init_device_torch_dist_pg(
+        cls,
+        backend: str,
+        prefix_store: PrefixStore,
+        group_rank: int,
+        group_size: int,
+        timeout: timedelta,
+    ) -> ProcessGroup:
+        assert is_nccl_available()
+        pg: ProcessGroup = ProcessGroup(
+            prefix_store,
+            group_rank,
+            group_size,
+        )
+        from torch.distributed.distributed_c10d import ProcessGroupNCCL
+
+        backend_options = ProcessGroupNCCL.Options()
+        backend_options._timeout = timeout
+
+        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
+                                         backend_options)
+        backend_type = ProcessGroup.BackendType.NCCL
+        device = torch.device("cuda")
+        pg._set_default_backend(backend_type)
+        backend_class._set_sequence_number_for_group()
+
+        pg._register_backend(device, backend_type, backend_class)
+        return pg
