Commit 358f33d

Zheng-Yong (Arsa) Ang authored and facebook-github-bot committed
enable gradient accumulation in SDD (#3462)
Summary:
Context: gradient accumulation is still not available in TorchRec, especially for the SDD (sparse data dist) pipeline, which is used by many recommendation models.
This diff implements a user-facing API enabling gradient accumulation for SDD.
Differential Revision: D84915986
1 parent 2f4b794 commit 358f33d

File tree

1 file changed: +57 additions, -25 deletions

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 57 additions & 25 deletions
```diff
@@ -420,6 +420,10 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
             (applicable to 2D sharding only)
             if set and DMP collection is enabled for 2D sharding,
             sync DMPs every N batches (default to 1, i.e. every batch, None to disable)
+        gradient_accumulation_steps (int): number of steps to accumulate gradients before
+            performing backward pass and optimizer update. Default is 1 (no accumulation).
+        gradient_accumulation_start_step (Optional[int]): the step to start gradient accumulation.
+            if not set, gradient accumulation will start from the second step.
     """

     # The PipelinedForward class that is used in _rewrite_model
```
```diff
@@ -440,6 +444,8 @@ def __init__(
         ] = None,
         dmp_collection_sync_interval_batches: Optional[int] = 1,
         enqueue_batch_after_forward: bool = False,
+        gradient_accumulation_steps: int = 1,
+        gradient_accumulation_start_step: Optional[int] = None,
     ) -> None:
         self._model = model
         self._optimizer = optimizer
```
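The new arguments slot directly into the existing `TrainPipelineSparseDist` constructor. Below is a minimal sketch of how a caller might opt in; `model` (a `DistributedModelParallel`-wrapped module), `optimizer`, and `device` are assumed to come from an existing TorchRec setup and are not constructed here:

```python
from torchrec.distributed.train_pipeline.train_pipelines import TrainPipelineSparseDist

# Sketch only: `model`, `optimizer`, and `device` are assumed to exist already.
pipeline = TrainPipelineSparseDist(
    model=model,
    optimizer=optimizer,
    device=device,
    # accumulate gradients over 4 batches per optimizer.step()
    gradient_accumulation_steps=4,
    # optional; when left unset and accumulation is enabled, the pipeline
    # defaults the start step to 1 (see the NOTE in the next hunk)
    gradient_accumulation_start_step=1,
)
```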
```diff
@@ -505,6 +511,14 @@ def __init__(
             dmp_collection_sync_interval_batches
         )

+        # NOTE: the first step cannot be no_sync when DDP.static_graph = True, due to an unfortunate restriction in torch.distributed.
+        # thus, unless explicitly set, we set gradient_accumulation_start_step to 1 if gradient_accumulation_steps > 1.
+        self._pre_accumulation_steps: int = gradient_accumulation_start_step or (
+            1 if gradient_accumulation_steps > 1 else 0
+        )
+        self._accumulation_steps: int = gradient_accumulation_steps
+        self._accumulation_step_count: int = 0
+
         if self._dmp_collection_sync_interval_batches is not None:
             logger.info(
                 f"{self.__class__.__name__}: [Sparse 2D] DMP collection will sync every "
```
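To see how this bookkeeping plays out step by step, here is a small standalone simulation (not TorchRec code) of the three counters and the decisions they drive in `progress()` below; it assumes the model is in training mode throughout.

```python
from typing import Optional


def simulate(
    gradient_accumulation_steps: int,
    gradient_accumulation_start_step: Optional[int] = None,
    num_steps: int = 8,
) -> None:
    # mirror the defaulting logic: the first step must sync when static_graph=True
    pre_accumulation_steps = gradient_accumulation_start_step or (
        1 if gradient_accumulation_steps > 1 else 0
    )
    accumulation_step_count = 0

    for step in range(num_steps):
        # (assumes model.training is True)
        no_sync = pre_accumulation_steps <= 0 and (
            accumulation_step_count + 1 < gradient_accumulation_steps
        )
        if no_sync:
            accumulation_step_count += 1
            print(f"step {step}: accumulate (no_sync)")
        else:
            print(f"step {step}: allreduce + optimizer.step()")
            if pre_accumulation_steps > 0:
                pre_accumulation_steps -= 1
            else:
                accumulation_step_count = 0


simulate(gradient_accumulation_steps=4)
# step 0 syncs (warm-up before accumulation starts), steps 1-3 accumulate,
# step 4 syncs and steps the optimizer, steps 5-7 accumulate, and so on.
```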
```diff
@@ -682,7 +696,10 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         # TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only)
         self._set_module_context(self.contexts[0])

-        if self._model.training:
+        # only zero grad at the start of each accumulation
+        if self._model.training and (
+            self._pre_accumulation_steps > 0 or self._accumulation_step_count == 0
+        ):
             with record_function("## zero_grad ##"):
                 self._optimizer.zero_grad()
```
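For comparison, the same zero-grad/step cadence in a plain single-process PyTorch loop (an illustration of the general technique, not code from this commit): gradients are cleared only at the start of each accumulation window, and the optimizer steps only at its end.

```python
import torch
from torch import nn

model = nn.Linear(16, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulation_steps = 4

for step in range(12):
    batch, target = torch.randn(8, 16), torch.randn(8, 1)

    if step % accumulation_steps == 0:
        # only zero grad at the start of each accumulation window
        optimizer.zero_grad()

    loss = nn.functional.mse_loss(model(batch), target)
    # scaling by accumulation_steps makes the accumulated gradient approximate
    # the gradient of one large batch (a common convention, not required)
    (loss / accumulation_steps).backward()

    if (step + 1) % accumulation_steps == 0:
        # update only once the window is complete
        optimizer.step()
```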

```diff
@@ -698,35 +715,50 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here
         self.enqueue_batch(dataloader_iter)

-        # forward
-        with record_function(f"## forward {self.contexts[0].index} ##"):
-            self._state = PipelineState.CALL_FWD
-            losses, output = self._model_fwd(self.batches[0])
+        no_sync = self._pre_accumulation_steps <= 0 and (
+            self._model.training
+            and self._accumulation_step_count + 1 < self._accumulation_steps
+        )
+        with (
+            self._model._dmp_wrapped_module.no_sync()  # pyre-ignore[16]
+            if no_sync
+            else contextlib.nullcontext()
+        ):
+            # forward
+            with record_function(f"## forward {self.contexts[0].index} ##"):
+                self._state = PipelineState.CALL_FWD
+                losses, output = self._model_fwd(self.batches[0])

-        if self._enqueue_batch_after_forward:
-            # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here.
-            # Start this step after the forward of batch i, so that the H2D copy doesn't compete
-            # for pcie bandwidth with embedding lookup from UVM/UVM_CACHING.
-            self.enqueue_batch(dataloader_iter)
+            if self._enqueue_batch_after_forward:
+                # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here.
+                # Start this step after the forward of batch i, so that the H2D copy doesn't compete
+                # for pcie bandwidth with embedding lookup from UVM/UVM_CACHING.
+                self.enqueue_batch(dataloader_iter)

-        if len(self.batches) >= 2:
-            # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
-            self.wait_sparse_data_dist(self.contexts[1])
+            if len(self.batches) >= 2:
+                # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
+                self.wait_sparse_data_dist(self.contexts[1])

-        if self._model.training:
             # backward
-            self._state = PipelineState.CALL_BWD
-            self._backward(losses)
-
-            self.sync_embeddings(
-                self._model,
-                self._dmp_collection_sync_interval_batches,
-                self.contexts[0],
-            )
+            if self._model.training:
+                self._state = PipelineState.CALL_BWD
+                self._backward(losses)

-            # update
-            with record_function(f"## optimizer {self.contexts[0].index} ##"):
-                self._optimizer.step()
+                if no_sync:
+                    self._accumulation_step_count += 1
+                else:
+                    self.sync_embeddings(
+                        self._model,
+                        self._dmp_collection_sync_interval_batches,
+                        self.contexts[0],
+                    )
+                    # update
+                    with record_function(f"## optimizer {self.contexts[0].index} ##"):
+                        self._optimizer.step()
+                    if self._pre_accumulation_steps > 0:
+                        self._pre_accumulation_steps -= 1
+                    else:
+                        self._accumulation_step_count = 0

         self.dequeue_batch()
         return output
```
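The `no_sync()` reached here through `self._model._dmp_wrapped_module` is DistributedDataParallel's context manager, which skips the gradient allreduce so local gradients simply accumulate. Below is a standalone sketch of that pattern with plain DDP on CPU; the single-process gloo group and toy data are chosen only to make the example runnable, not taken from this commit.

```python
import contextlib

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

# single-process group just to make the example runnable
dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1
)

model = DDP(nn.Linear(16, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulation_steps = 4

for step in range(8):
    batch, target = torch.randn(8, 16), torch.randn(8, 1)
    sync_this_step = (step + 1) % accumulation_steps == 0
    # skip the allreduce on all but the last micro-batch of the window
    ctx = contextlib.nullcontext() if sync_this_step else model.no_sync()
    with ctx:
        loss = nn.functional.mse_loss(model(batch), target)
        loss.backward()
    if sync_this_step:
        optimizer.step()
        optimizer.zero_grad()

dist.destroy_process_group()
```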
