@@ -611,6 +611,112 @@ def reduce_scatter_tensor_coalesced(
             )
 
 
+class _ParallelWork(Work):
+    def __init__(self, works: List[Work]) -> None:
+        super().__init__()
+        self._works = works
+
+    def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        for work in self._works:
+            if timeout is not None:
+                work.wait(timeout=timeout)
+            else:
+                work.wait()
+        return True
+
+    def get_future(self) -> torch.futures.Future[object]:
+        futures = [work.get_future() for work in self._works]
+        return torch.futures.collect_all(futures)
+
+
+class ParallelProcessGroup(ProcessGroupWrapper):
+    def __init__(
+        self,
+        base: ProcessGroupWrapper,
+        timeout: timedelta = timedelta(seconds=60),
+        count: int = 10,
+    ) -> None:
+        super().__init__(timeout=timeout)
+
+        self._base = base
+        self._count = count
+        self._pgs = []
+
+        self._create_pg = base._create_pg
+
+    def configure(self, store_addr: str, rank: int, world_size: int) -> None:
+        # abort if already initialized
+        self.abort()
+
+        for i in range(self._count):
+            store = create_store_client(
+                f"{store_addr}/parallel{i}", timeout=self._timeout
+            )
+
+            self._pgs.append(self._create_pg(store, rank, world_size))
+
+        self._pg = self._pgs[0]
+
+    def getBackendName(self) -> str:
+        return f"{self._base.getBackendName()}-parallel"
+
+    def _split_tensors(self, tensors: List[torch.Tensor]) -> List[List[torch.Tensor]]:
+        if not isinstance(tensors, (list, tuple)):
+            tensors = [tensors]
+
+        tensor_lists = [[] for _ in range(self._count)]
+        for t in tensors:
+            chunks = torch.tensor_split(t.view(-1), self._count, dim=0)
+            for i, chunk in enumerate(chunks):
+                tensor_lists[i].append(chunk)
+
+        return tensor_lists
+
+    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(
+                    self._pgs[i].allreduce(tensor_lists[i], self._opts_hook(opts))
+                )
+
+        return self._wrap_work(_ParallelWork(works), opts)
+
+    def reduce(self, tensors: List[torch.Tensor], dst: int, opts: object) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(
+                    self._pgs[i].reduce(tensor_lists[i], dst, self._opts_hook(opts))
+                )
+
+        return self._wrap_work(_ParallelWork(works), opts)
+
+    def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(self._pgs[i].send(tensor_lists[i], dst_rank, tag))
+
+        return self._wrap_work(_ParallelWork(works), None)
+
+    def recv(self, tensors: List[torch.Tensor], src_rank: int, tag: int) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(self._pgs[i].recv(tensor_lists[i], src_rank, tag))
+
+        return self._wrap_work(_ParallelWork(works), None)
+
+
 class _WorkCUDATimeout(Work):
     def __init__(self, pg: ProcessGroup, work: Work, timeout: timedelta) -> None:
         super().__init__()
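
For context, a minimal usage sketch of the new wrapper (not taken from the diff). The ProcessGroupGloo base, the locally hosted TCPStore, the "host:port/prefix" store-address format, and the chosen count/rank/world_size values are illustrative assumptions about how the surrounding library is typically driven, not something this change prescribes.

import torch
import torch.distributed as dist
from torch.distributed import TCPStore

from torchft.process_group import ParallelProcessGroup, ProcessGroupGloo

# Host a throwaway TCPStore so configure() has something to connect its
# per-chunk sub-stores to; in real use the store address would come from
# the surrounding framework.
store = TCPStore("localhost", 0, is_master=True, wait_for_workers=False)

# Each collective is flattened, split into `count` chunks, and issued on
# `count` independent sub process groups; the returned work object waits
# on all of them.
pg = ParallelProcessGroup(base=ProcessGroupGloo(), count=4)
pg.configure(f"localhost:{store.port}/prefix", rank=0, world_size=1)

t = torch.arange(16, dtype=torch.float32)
work = pg.allreduce([t], dist.ReduceOp.SUM)
work.wait()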