Commit b7161ca

Zheng-Yong (Arsa) Ang authored and facebook-github-bot committed
enable gradient accumulation in SDD (#3462)
Summary:
Context: gradient accumulation is still not available in TorchRec, in particular for the SDD (sparse data dist) pipeline, which is used by many recommendation models.
This diff implements a user-facing option that enables gradient accumulation for SDD.

Differential Revision: D84915986
1 parent e6e5e8f commit b7161ca
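For context, a minimal usage sketch of the new option. Only the gradient_accumulation_steps argument comes from this change; the model, optimizer, and dataloader setup here is hypothetical.

import torch
from torchrec.distributed.train_pipeline import TrainPipelineSparseDist

# Hypothetical setup: `model` is a DistributedModelParallel-wrapped module,
# `optimizer` its optimizer, and `dataloader` yields training batches.
pipeline = TrainPipelineSparseDist(
    model=model,
    optimizer=optimizer,
    device=torch.device("cuda"),
    gradient_accumulation_steps=4,  # optimizer steps once per 4 batches
)

dataloader_iter = iter(dataloader)
try:
    while True:
        # forward/backward run on every call; zero_grad and optimizer.step()
        # only fire at the boundaries of each 4-batch window
        output = pipeline.progress(dataloader_iter)
except StopIteration:
    pass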

File tree

1 file changed (+51, −25)

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 51 additions & 25 deletions
@@ -418,6 +418,8 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
             (applicable to 2D sharding only)
             if set and DMP collection is enabled for 2D sharding,
             sync DMPs every N batches (default to 1, i.e. every batch, None to disable)
+        gradient_accumulation_steps (int): number of steps to accumulate gradients over
+            before performing an optimizer update. Default is 1 (no accumulation).
     """
 
     # The PipelinedForward class that is used in _rewrite_model
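The semantics the new docstring entry describes are the standard gradient-accumulation pattern. A minimal, pipeline-free sketch of the same idea in plain PyTorch (the toy model, optimizer, and batches are illustrative only):

import torch

# Toy setup so the sketch is runnable.
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batches = [torch.randn(16, 8) for _ in range(8)]

accumulation_steps = 4  # plays the role of gradient_accumulation_steps

for i, batch in enumerate(batches):
    if i % accumulation_steps == 0:
        optimizer.zero_grad()  # start of a new accumulation window
    loss = model(batch).pow(2).mean()
    loss.backward()  # grads keep summing into .grad across the window
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()  # one update per window of N batches

Note that neither this sketch nor the diff rescales the loss by the window length, so the accumulated gradient is the sum, not the mean, over the N batches.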
@@ -438,6 +440,7 @@ def __init__(
         ] = None,
         dmp_collection_sync_interval_batches: Optional[int] = 1,
         enqueue_batch_after_forward: bool = False,
+        gradient_accumulation_steps: int = 1,
     ) -> None:
         self._model = model
         self._optimizer = optimizer
@@ -503,6 +506,10 @@ def __init__(
             dmp_collection_sync_interval_batches
         )
 
+        self._accumulation_steps: int = gradient_accumulation_steps
+        self._accumulation_step_count: int = gradient_accumulation_steps - 1
+        self._is_first_step: bool = True
+
         if self._dmp_collection_sync_interval_batches is not None:
             logger.info(
                 f"{self.__class__.__name__}: [Sparse 2D] DMP collection will sync every "
@@ -680,7 +687,10 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         # TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only)
         self._set_module_context(self.contexts[0])
 
-        if self._model.training:
+        # only zero grads at the start of each accumulation window
+        if self._model.training and (
+            self._is_first_step or self._accumulation_step_count == 0
+        ):
             with record_function("## zero_grad ##"):
                 self._optimizer.zero_grad()
@@ -696,35 +706,51 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
             # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here
             self.enqueue_batch(dataloader_iter)
 
-        # forward
-        with record_function(f"## forward {self.contexts[0].index} ##"):
-            self._state = PipelineState.CALL_FWD
-            losses, output = self._model_fwd(self.batches[0])
+        # NOTE: the first step cannot be no_sync when DDP.static_graph = True,
+        # due to an unfortunate restriction in torch.distributed
+        no_sync = not self._is_first_step and (
+            self._model.training
+            and self._accumulation_step_count + 1 < self._accumulation_steps
+        )
+        with (
+            self._model._dmp_wrapped_module.no_sync()  # pyre-ignore[16]
+            if no_sync
+            else contextlib.nullcontext()
+        ):
+            # forward
+            with record_function(f"## forward {self.contexts[0].index} ##"):
+                self._state = PipelineState.CALL_FWD
+                losses, output = self._model_fwd(self.batches[0])
 
-        if self._enqueue_batch_after_forward:
-            # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here.
-            # Start this step after the forward of batch i, so that the H2D copy doesn't compete
-            # for pcie bandwidth with embedding lookup from UVM/UVM_CACHING.
-            self.enqueue_batch(dataloader_iter)
+            if self._enqueue_batch_after_forward:
+                # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here.
+                # Start this step after the forward of batch i, so that the H2D copy doesn't compete
+                # for pcie bandwidth with embedding lookup from UVM/UVM_CACHING.
+                self.enqueue_batch(dataloader_iter)
 
-        if len(self.batches) >= 2:
-            # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
-            self.wait_sparse_data_dist(self.contexts[1])
+            if len(self.batches) >= 2:
+                # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
+                self.wait_sparse_data_dist(self.contexts[1])
 
-        if self._model.training:
             # backward
-            self._state = PipelineState.CALL_BWD
-            self._backward(losses)
-
-            self.sync_embeddings(
-                self._model,
-                self._dmp_collection_sync_interval_batches,
-                self.contexts[0],
-            )
+            if self._model.training:
+                self._state = PipelineState.CALL_BWD
+                self._backward(losses)
 
-            # update
-            with record_function(f"## optimizer {self.contexts[0].index} ##"):
-                self._optimizer.step()
+                if no_sync:
+                    self._accumulation_step_count += 1
+                else:
+                    self.sync_embeddings(
+                        self._model,
+                        self._dmp_collection_sync_interval_batches,
+                        self.contexts[0],
+                    )
+                    # update
+                    with record_function(f"## optimizer {self.contexts[0].index} ##"):
+                        self._optimizer.step()
+                    self._accumulation_step_count = 0
+                    if self._is_first_step:
+                        self._is_first_step = False
 
         self.dequeue_batch()
         return output
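The no_sync() context used above is DistributedDataParallel's standard mechanism for skipping the gradient all-reduce on accumulation steps. A minimal generic sketch of that underlying mechanism, not TorchRec code (model, optimizer, and batches are assumed to be set up by the caller, with the process group already initialized):

import contextlib
from torch.nn.parallel import DistributedDataParallel as DDP

def train_with_accumulation(model: DDP, optimizer, batches, accumulation_steps: int = 4):
    # NOTE: with DDP(static_graph=True) the first backward must not run under
    # no_sync(), which is why the diff above special-cases _is_first_step.
    for i, batch in enumerate(batches):
        last_in_window = (i + 1) % accumulation_steps == 0
        # no_sync() skips the gradient all-reduce; DDP then all-reduces the
        # accumulated grads during the window's final backward.
        ctx = contextlib.nullcontext() if last_in_window else model.no_sync()
        with ctx:
            loss = model(batch).sum()  # hypothetical loss
            loss.backward()
        if last_in_window:
            optimizer.step()
            optimizer.zero_grad()

The pipeline reaches the wrapped module through _dmp_wrapped_module because the SDD pipeline holds the DistributedModelParallel wrapper, not the DDP instance directly; hence the pyre-ignore on that attribute access.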
