
Commit 5f863c6

ParallelProcessGroup
1 parent b84c5a6 commit 5f863c6

File tree

2 files changed: +130 -0 lines


torchft/process_group.py

Lines changed: 106 additions & 0 deletions
@@ -611,6 +611,112 @@ def reduce_scatter_tensor_coalesced(
             )
 
 
+class _ParallelWork(Work):
+    def __init__(self, works: List[Work]) -> None:
+        super().__init__()
+        self._works = works
+
+    def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        for work in self._works:
+            if timeout is not None:
+                work.wait(timeout=timeout)
+            else:
+                work.wait()
+        return True
+
+    def get_future(self) -> torch.futures.Future[object]:
+        futures = [work.get_future() for work in self._works]
+        return torch.futures.collect_all(futures)
+
+
+class ParallelProcessGroup(ProcessGroupWrapper):
+    def __init__(
+        self,
+        base: ProcessGroupWrapper,
+        timeout: timedelta = timedelta(seconds=60),
+        count: int = 10,
+    ) -> None:
+        super().__init__(timeout=timeout)
+
+        self._base = base
+        self._count = count
+        self._pgs = []
+
+        self._create_pg = base._create_pg
+
+    def configure(self, store_addr: str, rank: int, world_size: int) -> None:
+        # abort if already initialized
+        self.abort()
+
+        for i in range(self._count):
+            store = create_store_client(
+                f"{store_addr}/parallel{i}", timeout=self._timeout
+            )
+
+            self._pgs.append(self._create_pg(store, rank, world_size))
+
+        self._pg = self._pgs[0]
+
+    def getBackendName(self) -> str:
+        return f"{self._base.getBackendName()}-parallel"
+
+    def _split_tensors(self, tensors: List[torch.Tensor]) -> List[List[torch.Tensor]]:
+        if not isinstance(tensors, (list, tuple)):
+            tensors = [tensors]
+
+        tensor_lists = [[] for _ in range(self._count)]
+        for t in tensors:
+            chunks = torch.tensor_split(t.view(-1), self._count, dim=0)
+            for i, chunk in enumerate(chunks):
+                tensor_lists[i].append(chunk)
+
+        return tensor_lists
+
+    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(
+                    self._pgs[i].allreduce(tensor_lists[i], self._opts_hook(opts))
+                )
+
+            return self._wrap_work(_ParallelWork(works), opts)
+
+    def reduce(self, tensors: List[torch.Tensor], dst: int, opts: object) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(
+                    self._pgs[i].reduce(tensor_lists[i], dst, self._opts_hook(opts))
+                )
+
+            return self._wrap_work(_ParallelWork(works), opts)
+
+    def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(self._pgs[i].send(tensor_lists[i], dst_rank, tag))
+
+            return self._wrap_work(_ParallelWork(works), None)
+
+    def recv(self, tensors: List[torch.Tensor], src_rank: int, tag: int) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(self._pgs[i].recv(tensor_lists[i], src_rank, tag))
+
+            return self._wrap_work(_ParallelWork(works), None)
+
+
 class _WorkCUDATimeout(Work):
     def __init__(self, pg: ProcessGroup, work: Work, timeout: timedelta) -> None:
         super().__init__()
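ParallelProcessGroup fans each collective out across `count` child process groups: `_split_tensors` flattens every input tensor and carves it into `count` chunks with `torch.tensor_split`, each child group runs the collective on its own chunk, and `_ParallelWork` waits on all of the per-chunk works (or collects their futures). Below is a minimal, single-rank usage sketch modeled on the Gloo test added further down; the tensor, the allreduce call, and the count of 4 are illustrative and not part of the commit, and the sketch omits the dummy_init_pg()/register() steps the test performs.

from datetime import timedelta

import torch
from torch.distributed import TCPStore
from torch.distributed.distributed_c10d import AllreduceOptions

from torchft.process_group import ParallelProcessGroup, ProcessGroupGloo

# A single TCPStore is shared by all child groups; configure() derives a
# distinct prefix ("{store_addr}/parallel{i}") for each child group.
store = TCPStore(host_name="localhost", port=0, is_master=True, wait_for_workers=False)
store_addr = f"localhost:{store.port}/prefix"

pg = ParallelProcessGroup(
    base=ProcessGroupGloo(),
    timeout=timedelta(seconds=60),
    count=4,  # number of child process groups each tensor is split across
)
pg.configure(store_addr, rank=0, world_size=1)

t = torch.arange(16, dtype=torch.float32)

# allreduce() splits the flattened tensor into 4 chunks, issues one allreduce
# per child group, and returns a work object that waits on all of them.
# Because tensor_split returns views, the reduced values land back in `t` in place.
pg.allreduce([t], AllreduceOptions()).wait()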

torchft/process_group_test.py

Lines changed: 24 additions & 0 deletions
@@ -40,6 +40,7 @@
 from torchft.process_group import (
     ErrorSwallowingProcessGroupWrapper,
     ManagedProcessGroup,
+    ParallelProcessGroup,
     ProcessGroup,
     ProcessGroupBabyGloo,
     ProcessGroupBabyNCCL,
@@ -690,6 +691,29 @@ def test_baby_gloo_apis(self) -> None:
         with self.assertRaisesRegex(OSError, "handle is closed"):
             a.allreduce([t], AllreduceOptions()).wait()
 
+    def test_parallel_gloo_apis(self) -> None:
+        dummy_init_pg()
+
+        store = TCPStore(
+            host_name="localhost", port=0, is_master=True, wait_for_workers=False
+        )
+
+        store_addr = f"localhost:{store.port}/prefix"
+
+        a = ParallelProcessGroup(
+            base=ProcessGroupGloo(),
+            count=4,
+        )
+        a.configure(store_addr, 0, 1)
+        a.register("test_parallel_gloo_apis")
+
+        _test_pg(
+            a,
+            skip=("reduce_scatter_tensor_coalesced"),
+        )
+
+        a.unregister()
+
     # pyre-fixme[56]: Pyre was not able to infer the type of argument
     @skipUnless(torch.cuda.is_available(), "needs CUDA")
     def test_baby_nccl_apis(self) -> None:
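The splitting in _split_tensors relies on torch.tensor_split returning views of the flattened tensor, so each child group reduces directly into the original storage and no copy-back step is needed; chunk sizes may be uneven when the element count is not divisible by count. A standalone sketch of that behavior (plain PyTorch, no process group; the sizes are illustrative):

import torch

count = 4
t = torch.arange(10, dtype=torch.float32)

# tensor_split may produce unevenly sized chunks when the length is not
# divisible by `count` (here: 3, 3, 2, 2), mirroring _split_tensors.
chunks = torch.tensor_split(t.view(-1), count, dim=0)
assert [c.numel() for c in chunks] == [3, 3, 2, 2]

# The chunks are views: an in-place update (as a per-chunk allreduce would
# perform) is visible in the original tensor.
chunks[0].mul_(2)
assert torch.equal(t[:3], torch.tensor([0.0, 2.0, 4.0]))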
