Commit 1682257

test allreduce failures for diloco (#226)

Summary:
- test when allreduce fails but no new nodes join
- added another event of type `AllreduceFailure`
- this new event required modifying some manager code to inject the failure
1 parent 17c7cb9 · commit 1682257
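At a high level, the new pieces fit together as sketched below. This is a minimal illustration using the names introduced in this commit; running it standalone, and the `torchft.manager_integ_test` import path, are assumptions rather than part of the change:

```python
from datetime import timedelta

# EventInjector is defined in torchft/manager_integ_test.py (import path assumed).
from torchft.manager_integ_test import EventInjector
from torchft.process_group import FakeProcessGroupWrapper, ProcessGroupGloo

# Register an allreduce failure for rank 0 at step 1.
injector = EventInjector().fail_allreduce_at(0, 1)

# The train loop now wraps the real process group so errors can be injected into it.
pg = FakeProcessGroupWrapper(ProcessGroupGloo(timeout=timedelta(seconds=10)))
injector.set_pg(pg)

# When the harness reaches (rank=0, step=1), check() reports a RuntimeError to the
# wrapper, so the future of the next allreduce errors out instead of the rank crashing.
injector.check(0, 1)
```

Because only the allreduce future errors out and no new nodes join, the replicas are expected to recover and end up with equal global state, which is what the new test asserts.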

File tree: 3 files changed (+170, -4 lines)


torchft/local_sgd_integ_test.py

Lines changed: 85 additions & 3 deletions
@@ -28,7 +28,11 @@
     MyModel,
     Runner,
 )
-from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo
+from torchft.process_group import (
+    FakeProcessGroupWrapper,
+    ProcessGroupBabyNCCL,
+    ProcessGroupGloo,
+)
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -188,9 +192,11 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
         print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")
 
         if device.type == "cuda":
-            pg = ProcessGroupBabyNCCL()
+            pg = FakeProcessGroupWrapper(ProcessGroupBabyNCCL())
         else:
-            pg = ProcessGroupGloo(timeout=timedelta(seconds=10))
+            pg = FakeProcessGroupWrapper(
+                ProcessGroupGloo(timeout=timedelta(seconds=10))
+            )
         manager = Manager(
             pg=pg,
             min_replica_size=2,
@@ -210,6 +216,7 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
             # pyre-fixme[6]: Incompatible parameter type
             **runner.manager_args,
         )
+        runner.event_injector.set_pg(pg)
         stack.callback(manager.shutdown)
         # initialize default group for device mesh to work
         if not torch.distributed.is_initialized():
@@ -669,3 +676,78 @@ def test_streaming_diloco_upscale(
 
         for event_injector in event_injectors:
             self.assertEqual(event_injectors[1].count[EventInjectorEvent.Barrier], 1)
+
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @skipIf(sys.platform == "darwin", "not reliable on mac")
+    @parameterized.expand(CONFIG)
+    def test_streaming_diloco_commit_failure(
+        self, use_cuda: bool, n_fragments: int, fragment_sync_delay: int
+    ) -> None:
+        # Skip the test if use_cuda is True and there are not enough GPUs
+        if use_cuda and torch.cuda.device_count() < 2:
+            self.skipTest("Not enough GPUs for CUDA test")
+
+        lighthouse = LighthouseServer(
+            bind="[::]:0",
+            min_replicas=2,
+        )
+        num_replicas = 2
+        futures = []
+        executors = []
+
+        event_injectors = [
+            EventInjector().fail_allreduce_at(0, 1),
+            EventInjector().fail_allreduce_at(0, 1),
+        ]
+
+        torch.manual_seed(42)
+        # Initialize the model so we can pass in the state_dict
+        m: nn.Module = MultiMyModel(2, 3, n_fragments)
+
+        for replica_id, event_injector in zip(range(num_replicas), event_injectors):
+            executor = ThreadPoolExecutor(max_workers=1)
+            executors.append(executor)
+            runner = Runner(
+                replica_id=replica_id,
+                num_replicas=num_replicas,
+                lighthouse_address=lighthouse.address(),
+                event_injector=event_injector,
+                train_loop=diloco_train_loop,
+                train_loop_args={
+                    "model_state_dict": m.state_dict(),
+                    "n_fragments": n_fragments,
+                    "diloco_args": {
+                        "fragment_sync_delay": fragment_sync_delay,
+                        "sync_every": 4,
+                    },
+                },
+            )
+            futures.append(executor.submit(runner.run_replica))
+
+        state_dicts = []
+
+        for fut in as_completed(futures):
+            continue
+
+        for fut in futures:
+            try:
+                state_dicts.append(fut.result()[0])
+            except Exception as e:
+                print(e)
+                raise
+
+        lighthouse.shutdown()
+
+        rep0, rep1 = state_dicts
+
+        assert_equal_global_state(rep0, rep1)
+
+        for step in rep0.keys():
+            self.assertEqual(
+                rep0[step]["user"]["local_step"], rep1[step]["user"]["local_step"]
+            )
+
+        for event_injector in event_injectors:
+            self.assertEqual(
+                event_injector.count[EventInjectorEvent.AllreduceFailure], 1
+            )
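The wrapping pattern used in the train loop above can be read as a small helper. This is a hypothetical refactoring sketch for clarity, not code from the diff:

```python
from datetime import timedelta

import torch

from torchft.process_group import (
    FakeProcessGroupWrapper,
    ProcessGroupBabyNCCL,
    ProcessGroupGloo,
)


def make_test_pg(device: torch.device) -> FakeProcessGroupWrapper:
    # Always hand the Manager a FakeProcessGroupWrapper so the test's
    # EventInjector can later report errors into it via set_pg().
    if device.type == "cuda":
        return FakeProcessGroupWrapper(ProcessGroupBabyNCCL())
    return FakeProcessGroupWrapper(ProcessGroupGloo(timeout=timedelta(seconds=10)))
```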

torchft/manager_integ_test.py

Lines changed: 29 additions & 1 deletion
@@ -34,7 +34,11 @@
 from torchft.local_sgd import DiLoCo, LocalSGD
 from torchft.manager import Manager
 from torchft.optim import OptimizerWrapper
-from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo
+from torchft.process_group import (
+    FakeProcessGroupWrapper,
+    ProcessGroupBabyNCCL,
+    ProcessGroupGloo,
+)
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -76,10 +80,13 @@ class InjectedFailure(Exception):
 
 
 class EventInjectorEvent(Enum):
+    # Crashes a rank
     Failure = auto()
     # Used to wait for a rank to reach a certain step before continuing.
     # Users need to make sure the size of the barrier is appropriately set.
     Barrier = auto()
+    # Fails the allreduce call made by a rank
+    AllreduceFailure = auto()
 
 
 class EventInjectorInfo:
@@ -90,10 +97,15 @@ def __init__(self, event: EventInjectorEvent, data: object) -> None:
 
 class EventInjector:
     def __init__(self) -> None:
+        self._pg: Optional[FakeProcessGroupWrapper] = None
         self._lock = threading.Lock()
         self._events: Dict[Tuple[int, int], EventInjectorInfo] = {}
         self.count: dict[EventInjectorEvent, int] = defaultdict(int)
 
+    def set_pg(self, pg: FakeProcessGroupWrapper) -> None:
+        with self._lock:
+            self._pg = pg
+
     def fail_at(self, rank: int, step: int) -> "EventInjector":
         with self._lock:
             assert (rank, step) not in self._events
@@ -102,6 +114,14 @@ def fail_at(self, rank: int, step: int) -> "EventInjector":
            )
        return self
 
+    def fail_allreduce_at(self, rank: int, step: int) -> "EventInjector":
+        with self._lock:
+            assert (rank, step) not in self._events
+            self._events[(rank, step)] = EventInjectorInfo(
+                EventInjectorEvent.AllreduceFailure, None
+            )
+        return self
+
     def barrier_at(
         self, rank: int, step: int, barrier: threading.Barrier
     ) -> "EventInjector":
@@ -124,6 +144,14 @@ def check(self, rank: int, step: int) -> None:
                 print(f"injecting failure {rank=} {step=}")
                 raise InjectedFailure(f"injected failure {rank=} {step=}")
 
+            if event_info.event == EventInjectorEvent.AllreduceFailure:
+                print(f"injecting allreduce failure {rank=} {step=}")
+                assert self._pg is not None
+                self._pg.report_future_error(
+                    RuntimeError("injected allreduce error")
+                )
+                return
+
             if event_info.event == EventInjectorEvent.Barrier:
                 print(f"waiting for barrier {rank=} {step=}")
                 cast(threading.Barrier, event_info.data).wait()
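The new `check()` branch can be exercised in isolation. A minimal sketch, assuming the import paths and that `check()` increments the per-event counter the same way it does for the existing events:

```python
from datetime import timedelta

from torchft.manager_integ_test import EventInjector, EventInjectorEvent
from torchft.process_group import FakeProcessGroupWrapper, ProcessGroupGloo

injector = EventInjector().fail_allreduce_at(0, 1)
injector.set_pg(
    FakeProcessGroupWrapper(ProcessGroupGloo(timeout=timedelta(seconds=10)))
)

injector.check(0, 0)  # nothing registered for this (rank, step): expected no-op
injector.check(0, 1)  # reports RuntimeError("injected allreduce error") to the pg

# The integration test asserts exactly one injected failure per injector.
assert injector.count[EventInjectorEvent.AllreduceFailure] == 1
```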

torchft/process_group.py

Lines changed: 56 additions & 0 deletions
@@ -378,6 +378,13 @@ def __init__(
         self._pg: Optional[BaseProcessGroup] = pg
         self._timeout = timeout
 
+    def getBackendName(self) -> str:
+        pg = self._pg
+        if isinstance(pg, ProcessGroup):
+            return pg.getBackendName()
+
+        raise NotImplementedError("not implemented")
+
     def configure(self, store_addr: str, rank: int, world_size: int) -> None:
         pg = self._pg
         if isinstance(pg, ProcessGroup):
@@ -1016,6 +1023,55 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
         return _DummyWork(tensors)
 
 
+class FakeProcessGroupWrapper(ProcessGroupWrapper):
+    """
+    This is a wrapper around any ProcessGroup that can be used to inject
+    errors into the process group at various points.
+
+    This is intended to be used for tests so that they can test cases
+    in which process group operations error out.
+    """
+
+    def __init__(self, pg: ProcessGroup) -> None:
+        super().__init__(pg=pg)
+
+        self._future_error: Optional[Exception] = None
+
+    def configure(self, store_addr: str, rank: int, world_size: int) -> None:
+        self._future_error = None
+
+        super().configure(store_addr, rank, world_size)
+
+    def report_future_error(self, e: Exception) -> None:
+        """
+        Report an error to this process group. This will cause the
+        future attached to the next operation to error out.
+
+        Args:
+            e: exception to report
+        """
+        self._future_error = e
+
+    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
+        work = super().allreduce(tensors, opts)
+
+        if self._future_error is None:
+            return work
+
+        fut = work.get_future()
+
+        def callback(
+            fut: torch.futures.Future[List[torch.Tensor]],
+        ) -> List[torch.Tensor]:
+            future_error, self._future_error = self._future_error, None
+            assert future_error is not None
+            raise future_error
+
+        fut = fut.then(callback)
+
+        return work
+
+
 class _ManagedWork(Work):
     def __init__(self, manager: "Manager", work: Work, default_result: object) -> None:
         super().__init__()
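The error injection leans on standard `torch.futures` behavior: when a callback chained with `.then()` raises, the chained future completes with that error even though the parent future succeeded. A self-contained sketch of just that mechanism (illustrative, not torchft code):

```python
from typing import List

import torch


def chain_injected_error(
    fut: torch.futures.Future, error: Exception
) -> torch.futures.Future:
    # The future returned by .then() completes with `error` regardless of the
    # parent future's successful result.
    def callback(f: torch.futures.Future) -> List[torch.Tensor]:
        raise error

    return fut.then(callback)


fut: torch.futures.Future = torch.futures.Future()
chained = chain_injected_error(fut, RuntimeError("injected allreduce error"))

fut.set_result([torch.zeros(2)])  # the underlying collective "succeeds"
try:
    chained.wait()  # the consumer of the chained future sees the injected error
except RuntimeError as e:
    print(f"caught: {e}")
```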
