diff --git a/torchft/local_sgd.py b/torchft/local_sgd.py
index 37602f7..cca86a9 100644
--- a/torchft/local_sgd.py
+++ b/torchft/local_sgd.py
@@ -255,7 +255,15 @@ def _perform_sync(self) -> None:
             pseudogradient = p.data - self.original_parameters[name]
             p.grad = pseudogradient
 
+        for name, p in self._model.named_parameters():
+            print(
+                f"BEFORE average, replica {self._manager._replica_id} [{name} {p.data=} {self.original_parameters[name]=} {p.grad}]"
+            )
         self._average_grads()
+        for name, p in self._model.named_parameters():
+            print(
+                f"AFTER average, replica {self._manager._replica_id} [{name} {p.grad}]"
+            )
         # Restore the parameters back to the previous state
         self._restore_parameters()
         if self._manager.should_commit():
@@ -272,7 +280,7 @@ def _average_grads(self) -> None:
         for p in self._model.parameters():
             # Perform allreduce on the pseudogradients
             assert p.grad is not None
-            work = self._manager.allreduce(p.grad)
+            work = self._manager.allreduce(p.grad.data)
             works.append(work)
         # Wait for all allreduce operations to complete
         for work in works:
diff --git a/torchft/local_sgd_integ_test.py b/torchft/local_sgd_integ_test.py
index 917e853..33e9a23 100644
--- a/torchft/local_sgd_integ_test.py
+++ b/torchft/local_sgd_integ_test.py
@@ -2,7 +2,7 @@
 import logging
 import re
 import traceback
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import as_completed, ThreadPoolExecutor
 from contextlib import ExitStack
 from datetime import timedelta
 from typing import Any, Dict
@@ -16,7 +16,11 @@
 from torchft.local_sgd import DiLoCo, LocalSGD
 from torchft.manager import Manager
 from torchft.manager_integ_test import FailureInjector, MyModel, Runner
-from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo
+from torchft.process_group import (
+    ProcessGroupBabyNCCL,
+    ProcessGroupGloo,
+    ProcessGroupNCCL,
+)
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -42,7 +46,7 @@ def state_dict() -> Dict[str, Dict[str, object]]:
     print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")
 
     if device.type == "cuda":
-        pg = ProcessGroupBabyNCCL()
+        pg = ProcessGroupNCCL()
     else:
         pg = ProcessGroupGloo()
     manager = Manager(
@@ -133,7 +137,8 @@ def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
     print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")
 
     if device.type == "cuda":
-        pg = ProcessGroupBabyNCCL()
+        pg = ProcessGroupNCCL()
+        # pg = ProcessGroupBabyNCCL()
     else:
         pg = ProcessGroupGloo()
     manager = Manager(
diff --git a/torchft/manager.py b/torchft/manager.py
index 85af623..9335bb8 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -34,7 +34,7 @@
 from contextlib import nullcontext
 from datetime import timedelta
 from enum import Enum
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast
+from typing import Callable, cast, Dict, List, Optional, TYPE_CHECKING, TypeVar
 
 import torch
 from torch.distributed import ReduceOp, TCPStore
@@ -186,6 +186,7 @@ def __init__(
         self._recovery_stream: Optional["torch.cuda.Stream"] = (
             torch.cuda.Stream() if torch.cuda.is_available() else None
         )
+        self._replica_id = replica_id
 
         if rank == 0:
             if port is None:
@@ -272,6 +273,7 @@ def allreduce(self, tensor: torch.Tensor) -> torch.futures.Future[torch.Tensor]:
             fut.set_result(tensor)
             return fut
 
+        print(f"{self._replica_id} in allreduce")
         self.wait_quorum()
 
         if not self.is_participating():
@@ -283,17 +285,41 @@ def allreduce(self, tensor: torch.Tensor) -> torch.futures.Future[torch.Tensor]:
             # it later.
             work = self._pg.allreduce([tensor], ReduceOp.SUM)
             fut = work.get_future()
+            # torch.cuda.synchronize()
+
+            event = torch.cuda.Event()
+
+            # def record_event_callback(
+            #     fut: torch.futures.Future[List[torch.Tensor]],
+            # ) -> None:
+            #     # We need to record the event after the allreduce completes
+            #     # which is why this is its own callback
+            #     event.record()
+
+            # fut = fut.then(record_event_callback)
 
             # schedule grad normalization as a continuation
             # on the Future
             def callback(
                 fut: torch.futures.Future[List[torch.Tensor]],
             ) -> torch.Tensor:
-                nonlocal tensor
+                # work.wait()
+                print(f"{torch.cuda.current_stream()=}")
+                # torch.cuda.synchronize()
+                # event.record()
+                # event.synchronize()
 
                 # check for exceptions
-                fut.value()
-
+                value = fut.value()
+                # print(value)
+                # val_tensor = value[0]
+
+                # if val_tensor is None:
+                #     # edge case for process group baby nccl
+                #     nonlocal tensor
+                # else:
+                #     tensor = val_tensor
+                nonlocal tensor
                 tensor /= self.num_participants()
 
                 return tensor
@@ -436,7 +462,9 @@ def wait_quorum(self) -> None:
         assert (
             self._quorum_future is not None
         ), "must call start_quorum before wait_quorum"
-        self._quorum_future.result()
+        print(f"{self._replica_id} Called wait_quorum()")
+        result = self._quorum_future.result()
+        print(f"{self._replica_id} Finished wait_quorum() {result=}")
 
     def _async_quorum(
         self,
@@ -721,6 +749,7 @@ def num_participants(self) -> int:
         self.wait_quorum()
 
         assert self._participating_world_size >= 0, "internal error"
+        print(f"{self._replica_id} returning {self._participating_world_size}")
         return self._participating_world_size
 
     def is_participating(self) -> bool:
diff --git a/torchft/manager_integ_test.py b/torchft/manager_integ_test.py
index d591d0d..2c7c871 100644
--- a/torchft/manager_integ_test.py
+++ b/torchft/manager_integ_test.py
@@ -3,8 +3,8 @@
 import threading
 import time
 import traceback
-from concurrent.futures import ThreadPoolExecutor, as_completed
-from contextlib import ExitStack, contextmanager
+from concurrent.futures import as_completed, ThreadPoolExecutor
+from contextlib import contextmanager, ExitStack
 from dataclasses import dataclass, field
 from datetime import timedelta
 from typing import Any, Dict, Generator, List, Optional, Protocol, Set, Tuple, TypeVar
@@ -21,7 +21,7 @@
 from torchft.local_sgd import DiLoCo, LocalSGD
 from torchft.manager import Manager
 from torchft.optim import OptimizerWrapper
-from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo
+from torchft.process_group import ProcessGroupGloo, ProcessGroupNCCL
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -118,6 +118,7 @@ def _replica_main(self) -> List[object]:
                 else:
                     device = torch.device("cpu")
 
+                # try on a separate stream
                 futures.append(
                     executor.submit(
                         self.train_loop,
@@ -401,7 +402,7 @@ def test_quorum_timeout(self) -> None:
     @parameterized.expand(
         [
             (True,),  # Test with CUDA
-            (False,),  # Test without CUDA (CPU)
+            # (False,),  # Test without CUDA (CPU)
         ]
     )
     def test_manager_allreduce(self, use_cuda: bool) -> None:
@@ -457,7 +458,7 @@ def all_reduce_callback(
     print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")
 
     if device.type == "cuda":
-        pg = ProcessGroupBabyNCCL()
+        pg = ProcessGroupNCCL()
     else:
         pg = ProcessGroupGloo()
     manager = Manager(
@@ -480,9 +481,19 @@ def all_reduce_callback(
         )
         stack.callback(lambda: manager.shutdown(wait=False))
 
-        manager.start_quorum()
-        t1 = torch.ones((1, 3), device=device)
-        fut = manager.allreduce(t1)
-        fut.wait()
-        return t1
+        with torch.cuda.stream(torch.cuda.Stream()):
+            results = []
+            manager.start_quorum()
+            futs = []
+            for _ in range(1):
+                t1 = torch.randn((1, 3), device=device)
+                fut = manager.allreduce(t1)
+                futs.append(fut)
+                results.append(t1)
+            # Wait for all allreduce operations to complete
+            for fut in futs:
+                fut.wait()
+            results_tensor = torch.stack(results)
+            total_sum = results_tensor.sum(dim=0)
+            return total_sum
     return None
diff --git a/torchft/process_group.py b/torchft/process_group.py
index 2b13593..dac79ad 100644
--- a/torchft/process_group.py
+++ b/torchft/process_group.py
@@ -23,17 +23,17 @@
 from datetime import timedelta
 from multiprocessing.connection import Connection
 from typing import (
-    TYPE_CHECKING,
     Any,
     Callable,
+    cast,
     Dict,
     Generator,
     List,
     Optional,
     Tuple,
+    TYPE_CHECKING,
     TypeVar,
     Union,
-    cast,
 )
 
 import torch
@@ -44,14 +44,14 @@
 # pyre-fixme[21]: no attribute ProcessGroupGloo
 from torch.distributed import (
     DeviceMesh,
+    get_rank,
+    init_device_mesh,
     PrefixStore,
     ProcessGroup as BaseProcessGroup,
     ProcessGroupGloo as BaseProcessGroupGloo,
     ProcessGroupNCCL as BaseProcessGroupNCCL,
     Store,
     TCPStore,
-    get_rank,
-    init_device_mesh,
 )
 from torch.distributed.distributed_c10d import (
     AllgatherOptions,