# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import dataclasses
-import datetime
import os
import pickle
import socket
import uuid
from collections import deque
from collections.abc import Sequence
+from datetime import timedelta
from typing import Any, Optional

import torch
from torch.distributed import ProcessGroup, TCPStore
from torch.distributed.distributed_c10d import (Backend, PrefixStore,
                                                _get_default_timeout,
-                                                _unregister_process_group,
-                                                is_nccl_available)
+                                                _unregister_process_group)
from torch.distributed.rendezvous import rendezvous

import vllm.envs as envs
@@ -406,7 +405,7 @@ def create(
            port=port,
            world_size=world_size,
            is_master=launch_server,
-            timeout=datetime.timedelta(seconds=store_timeout),
+            timeout=timedelta(seconds=store_timeout),
            use_libuv=False,  # for now: github.com/pytorch/pytorch/pull/150215
            master_listen_fd=listen_fd,
        )
@@ -419,6 +418,43 @@ def create(
            data_expiration_seconds=data_expiration_seconds)


+def init_gloo_process_group(backend: Backend, prefix_store: PrefixStore,
+                            group_rank: int, group_size: int,
+                            timeout: timedelta) -> ProcessGroup:
+    """
+    Stateless init ProcessGroup with gloo backend compatible with
+    different torch versions.
+    """
+    if is_torch_equal_or_newer("2.6"):
+        pg = ProcessGroup(
+            prefix_store,
+            group_rank,
+            group_size,
+        )
+    else:
+        options = ProcessGroup.Options(backend=backend)
+        pg = ProcessGroup(
+            prefix_store,
+            group_rank,
+            group_size,
+            options,
+        )
+    from torch.distributed.distributed_c10d import ProcessGroupGloo
+    backend_class = ProcessGroupGloo(prefix_store,
+                                     group_rank,
+                                     group_size,
+                                     timeout=timeout)
+    backend_type = ProcessGroup.BackendType.GLOO
+    device = torch.device("cpu")
+    if is_torch_equal_or_newer("2.6"):
+        # _set_default_backend is supported in torch >= 2.6
+        pg._set_default_backend(backend_type)
+    backend_class._set_sequence_number_for_group()
+
+    pg._register_backend(device, backend_type, backend_class)
+    return pg
+
+
def stateless_init_torch_distributed_process_group(
        host: str, port: int, rank: int, world_size: int,
        backend: str) -> ProcessGroup:
@@ -468,40 +504,19 @@ def stateless_init_torch_distributed_process_group(
    # different systems (e.g. RPC) in case the store is multi-tenant.
    prefix_store = PrefixStore(init_method, store)

-    pg: ProcessGroup = ProcessGroup(
-        prefix_store,
-        group_rank,
-        group_size,
-    )
-
    if backend == "gloo":
-        from torch.distributed.distributed_c10d import ProcessGroupGloo
-        backend_class = ProcessGroupGloo(prefix_store,
-                                         group_rank,
-                                         group_size,
-                                         timeout=timeout)
-        backend_type = ProcessGroup.BackendType.GLOO
-        device = torch.device("cpu")
-    elif backend == "nccl":
-        assert is_nccl_available()
-        from torch.distributed.distributed_c10d import ProcessGroupNCCL
-
-        backend_options = ProcessGroupNCCL.Options()
-        backend_options._timeout = timeout
-
-        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
-                                         backend_options)
-        backend_type = ProcessGroup.BackendType.NCCL
-        device = torch.device("cuda")
-    else:
-        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
-
-    pg._set_default_backend(backend_type)
-    backend_class._set_sequence_number_for_group()
-
-    pg._register_backend(device, backend_type, backend_class)
-
-    return pg
+        return init_gloo_process_group(backend=backend,
+                                       prefix_store=prefix_store,
+                                       group_rank=group_rank,
+                                       group_size=group_size,
+                                       timeout=timeout)
+    from vllm.platforms import current_platform
+    return current_platform.stateless_init_device_torch_dist_pg(
+        backend=backend,
+        prefix_store=prefix_store,
+        group_rank=group_rank,
+        group_size=group_size,
+        timeout=timeout)


def stateless_destroy_torch_distributed_process_group(
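
For reference, a minimal usage sketch of the refactored entry point. It is not part of the diff: the module path vllm.distributed.utils, the host/port values, and the per-rank RANK/WORLD_SIZE launch are assumptions for illustration; the function names and the gloo-vs-platform dispatch come from the code above.

# Illustrative sketch only, not part of this diff. Assumes the functions live in
# vllm.distributed.utils and that the script is started once per rank, e.g. with
# RANK=0 / RANK=1 and WORLD_SIZE=2 on the same host.
import os

import torch

from vllm.distributed.utils import (
    stateless_destroy_torch_distributed_process_group,
    stateless_init_torch_distributed_process_group)

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

# backend="gloo" now goes through init_gloo_process_group; any other backend is
# delegated to current_platform.stateless_init_device_torch_dist_pg.
pg = stateless_init_torch_distributed_process_group(host="127.0.0.1",
                                                    port=29600,
                                                    rank=rank,
                                                    world_size=world_size,
                                                    backend="gloo")

# The group does not touch torch.distributed's global state, so collectives are
# issued directly on the ProcessGroup object.
t = torch.ones(1)
pg.allreduce([t]).wait()  # every rank ends up with t == world_size
assert t.item() == world_size

stateless_destroy_torch_distributed_process_group(pg)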