wip hang #148

Draft: wants to merge 1 commit into main
10 changes: 9 additions & 1 deletion torchft/local_sgd.py
@@ -255,7 +255,15 @@ def _perform_sync(self) -> None:
             pseudogradient = p.data - self.original_parameters[name]
             p.grad = pseudogradient
 
+        for name, p in self._model.named_parameters():
+            print(
+                f"BEFORE average, replica {self._manager._replica_id} [{name} {p.data=} {self.original_parameters[name]=} {p.grad}]"
+            )
         self._average_grads()
+        for name, p in self._model.named_parameters():
+            print(
+                f"AFTER average, replica {self._manager._replica_id} [{name} {p.grad}]"
+            )
         # Restore the parameters back to the previous state
         self._restore_parameters()
         if self._manager.should_commit():
@@ -272,7 +280,7 @@ def _average_grads(self) -> None:
         for p in self._model.parameters():
             # Perform allreduce on the pseudogradients
             assert p.grad is not None
-            work = self._manager.allreduce(p.grad)
+            work = self._manager.allreduce(p.grad.data)
             works.append(work)
         # Wait for all allreduce operations to complete
         for work in works:
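Note: the hunk above adds BEFORE/AFTER debug prints around the averaging step and feeds `p.grad.data` (the same storage as `p.grad`, but detached from autograd bookkeeping) to the allreduce. A rough, simplified sketch of the sync path being instrumented, with `manager.allreduce` and `original_parameters` borrowed from the diff and everything else illustrative rather than the actual torchft implementation:

```python
import torch


def sync_pseudogradients(model: torch.nn.Module, original_parameters, manager) -> None:
    # Pseudogradient: how far each weight has drifted since the last sync point.
    for name, p in model.named_parameters():
        p.grad = p.data - original_parameters[name]

    # Launch one async allreduce per pseudogradient; the division by the number
    # of participants happens inside the manager's continuation, not here.
    works = [manager.allreduce(p.grad.data) for p in model.parameters()]

    # Wait for every collective before restoring parameters and stepping the
    # outer optimizer.
    for work in works:
        work.wait()
```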
13 changes: 9 additions & 4 deletions torchft/local_sgd_integ_test.py
@@ -2,7 +2,7 @@
 import logging
 import re
 import traceback
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import as_completed, ThreadPoolExecutor
 from contextlib import ExitStack
 from datetime import timedelta
 from typing import Any, Dict
@@ -16,7 +16,11 @@
 from torchft.local_sgd import DiLoCo, LocalSGD
 from torchft.manager import Manager
 from torchft.manager_integ_test import FailureInjector, MyModel, Runner
-from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo
+from torchft.process_group import (
+    ProcessGroupBabyNCCL,
+    ProcessGroupGloo,
+    ProcessGroupNCCL,
+)
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -42,7 +46,7 @@ def state_dict() -> Dict[str, Dict[str, object]]:
         print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")
 
         if device.type == "cuda":
-            pg = ProcessGroupBabyNCCL()
+            pg = ProcessGroupNCCL()
         else:
             pg = ProcessGroupGloo()
         manager = Manager(
@@ -133,7 +137,8 @@ def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
         print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")
 
         if device.type == "cuda":
-            pg = ProcessGroupBabyNCCL()
+            pg = ProcessGroupNCCL()
+            # pg = ProcessGroupBabyNCCL()
         else:
             pg = ProcessGroupGloo()
         manager = Manager(
39 changes: 34 additions & 5 deletions torchft/manager.py
@@ -34,7 +34,7 @@
 from contextlib import nullcontext
 from datetime import timedelta
 from enum import Enum
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast
+from typing import Callable, cast, Dict, List, Optional, TYPE_CHECKING, TypeVar
 
 import torch
 from torch.distributed import ReduceOp, TCPStore
@@ -186,6 +186,7 @@ def __init__(
         self._recovery_stream: Optional["torch.cuda.Stream"] = (
             torch.cuda.Stream() if torch.cuda.is_available() else None
         )
+        self._replica_id = replica_id
 
         if rank == 0:
             if port is None:
@@ -272,6 +273,7 @@ def allreduce(self, tensor: torch.Tensor) -> torch.futures.Future[torch.Tensor]:
             fut.set_result(tensor)
             return fut
 
+        print(f"{self._replica_id} in allreduce")
         self.wait_quorum()
 
         if not self.is_participating():
@@ -283,17 +285,41 @@ def allreduce(self, tensor: torch.Tensor) -> torch.futures.Future[torch.Tensor]:
             # it later.
             work = self._pg.allreduce([tensor], ReduceOp.SUM)
             fut = work.get_future()
+            # torch.cuda.synchronize()
+
+            event = torch.cuda.Event()
+
+            # def record_event_callback(
+            #     fut: torch.futures.Future[List[torch.Tensor]],
+            # ) -> None:
+            #     # We need to record the event after the allreduce completes
+            #     # which is why this is its own callback
+            #     event.record()
+
+            # fut = fut.then(record_event_callback)
 
             # schedule grad normalization as a continuation
             # on the Future
             def callback(
                 fut: torch.futures.Future[List[torch.Tensor]],
             ) -> torch.Tensor:
-                nonlocal tensor
+                # work.wait()
+                print(f"{torch.cuda.current_stream()=}")
+                # torch.cuda.synchronize()
+                # event.record()
+                # event.synchronize()
 
                 # check for exceptions
-                fut.value()
-
+                value = fut.value()
+                # print(value)
+                # val_tensor = value[0]
+
+                # if val_tensor is None:
+                #     # edge case for process group baby nccl
+                #     nonlocal tensor
+                # else:
+                #     tensor = val_tensor
+                nonlocal tensor
                 tensor /= self.num_participants()
 
                 return tensor
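The commented-out `record_event_callback` above sketches the idea being tested: record a CUDA event once the allreduce has produced its output, then make the consumer synchronize on that event before reading the tensor. A rough illustration of that wiring only, not torchft's actual API (the function name, the `num_participants` argument, and the assumption that the callbacks run with a usable current stream are all mine):

```python
from typing import List

import torch


def normalize_after_allreduce(work, tensor: torch.Tensor, num_participants: int):
    # `work` is assumed to be the handle returned by pg.allreduce([tensor], ...).
    fut = work.get_future()
    event = torch.cuda.Event()

    def record_event(f: torch.futures.Future) -> List[torch.Tensor]:
        # Records on whichever stream is current in the callback thread; whether
        # that is the stream the collective ran on is exactly the question this
        # draft is probing.
        event.record()
        return f.value()

    def normalize(f: torch.futures.Future) -> torch.Tensor:
        # Don't touch the result until the recorded point has been reached.
        event.synchronize()
        f.value()  # surface any error from the collective
        tensor /= num_participants
        return tensor

    return fut.then(record_event).then(normalize)
```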
@@ -436,7 +462,9 @@ def wait_quorum(self) -> None:
         assert (
             self._quorum_future is not None
         ), "must call start_quorum before wait_quorum"
-        self._quorum_future.result()
+        print(f"{self._replica_id} Called wait_quorum()")
+        result = self._quorum_future.result()
+        print(f"{self._replica_id} Finished wait_quorum() {result=}")
 
     def _async_quorum(
         self,
@@ -721,6 +749,7 @@ def num_participants(self) -> int:
         self.wait_quorum()
 
         assert self._participating_world_size >= 0, "internal error"
+        print(f"{self._replica_id} returning {self._participating_world_size}")
         return self._participating_world_size
 
     def is_participating(self) -> bool:
31 changes: 21 additions & 10 deletions torchft/manager_integ_test.py
@@ -3,8 +3,8 @@
 import threading
 import time
 import traceback
-from concurrent.futures import ThreadPoolExecutor, as_completed
-from contextlib import ExitStack, contextmanager
+from concurrent.futures import as_completed, ThreadPoolExecutor
+from contextlib import contextmanager, ExitStack
 from dataclasses import dataclass, field
 from datetime import timedelta
 from typing import Any, Dict, Generator, List, Optional, Protocol, Set, Tuple, TypeVar
@@ -21,7 +21,7 @@
 from torchft.local_sgd import DiLoCo, LocalSGD
 from torchft.manager import Manager
 from torchft.optim import OptimizerWrapper
-from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo
+from torchft.process_group import ProcessGroupGloo, ProcessGroupNCCL
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -118,6 +118,7 @@ def _replica_main(self) -> List[object]:
                 else:
                     device = torch.device("cpu")
 
+                # try on a separate stream
                 futures.append(
                     executor.submit(
                         self.train_loop,
@@ -401,7 +402,7 @@ def test_quorum_timeout(self) -> None:
     @parameterized.expand(
         [
             (True,),  # Test with CUDA
-            (False,),  # Test without CUDA (CPU)
+            # (False,),  # Test without CUDA (CPU)
         ]
     )
     def test_manager_allreduce(self, use_cuda: bool) -> None:
@@ -457,7 +458,7 @@ def all_reduce_callback(
         print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")
 
         if device.type == "cuda":
-            pg = ProcessGroupBabyNCCL()
+            pg = ProcessGroupNCCL()
         else:
             pg = ProcessGroupGloo()
         manager = Manager(
@@ -480,9 +481,19 @@
         )
         stack.callback(lambda: manager.shutdown(wait=False))
 
-        manager.start_quorum()
-        t1 = torch.ones((1, 3), device=device)
-        fut = manager.allreduce(t1)
-        fut.wait()
-        return t1
+        with torch.cuda.stream(torch.cuda.Stream()):
+            results = []
+            manager.start_quorum()
+            futs = []
+            for _ in range(1):
+                t1 = torch.randn((1, 3), device=device)
+                fut = manager.allreduce(t1)
+                futs.append(fut)
+                results.append(t1)
+            # Wait for all allreduce operations to complete
+            for fut in futs:
+                fut.wait()
+            results_tensor = torch.stack(results)
+            total_sum = results_tensor.sum(dim=0)
+            return total_sum
     return None
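The rewritten test drives the allreduce from a freshly created side stream ("try on a separate stream" above), which is a common way to surface missing synchronization between a collective and the code that reads its output. Independent of torchft, the usual pattern for handing work done on a side stream back to the default stream looks roughly like this (a generic sketch, assuming a CUDA device is available):

```python
import torch


def run_on_side_stream(x: torch.Tensor) -> torch.Tensor:
    side = torch.cuda.Stream()
    with torch.cuda.stream(side):
        # Kernels queued here run on `side`, not on the default stream.
        y = x * 2
    # Make the default stream wait for the side stream before anything
    # downstream reads `y`; without this the read can race the kernel.
    torch.cuda.current_stream().wait_stream(side)
    # `y` was allocated while `side` was current; tell the caching allocator
    # that the default stream now uses it so its memory isn't recycled early.
    y.record_stream(torch.cuda.current_stream())
    return y
```

In the test itself, the `fut.wait()` calls are what are supposed to provide this ordering; the hang in the PR title suggests that guarantee is what is under investigation.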
8 changes: 4 additions & 4 deletions torchft/process_group.py
@@ -23,17 +23,17 @@
 from datetime import timedelta
 from multiprocessing.connection import Connection
 from typing import (
-    TYPE_CHECKING,
     Any,
     Callable,
+    cast,
     Dict,
     Generator,
     List,
     Optional,
     Tuple,
+    TYPE_CHECKING,
     TypeVar,
     Union,
-    cast,
 )
 
 import torch
@@ -44,14 +44,14 @@
 # pyre-fixme[21]: no attribute ProcessGroupGloo
 from torch.distributed import (
     DeviceMesh,
+    get_rank,
+    init_device_mesh,
     PrefixStore,
     ProcessGroup as BaseProcessGroup,
     ProcessGroupGloo as BaseProcessGroupGloo,
     ProcessGroupNCCL as BaseProcessGroupNCCL,
     Store,
     TCPStore,
-    get_rank,
-    init_device_mesh,
 )
 from torch.distributed.distributed_c10d import (
     AllgatherOptions,