Skip to content

Commit f3c83ae

Browse files
Zheng-Yong (Arsa) Angfacebook-github-bot
authored andcommitted
enable gradient accumulation in SDD (#3462)
Summary: Context: gradient accumulation is still not available in TorchRec, especially for the SDD pipeline which is being used by many recommendation models. This diff: implements a UI enabling gradient accumulation for SDD. Differential Revision: D84915986
1 parent 24bb848 commit f3c83ae

File tree

1 file changed

+53
-25
lines changed

1 file changed

+53
-25
lines changed

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 53 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,10 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
418418
(applicable to 2D sharding only)
419419
if set and DMP collection is enabled for 2D sharding,
420420
sync DMPs every N batches (default to 1, i.e. every batch, None to disable)
421+
gradient_accumulation_steps (int): number of steps to accumulate gradients before
422+
performing backward pass and optimizer update. Default is 1 (no accumulation).
423+
should_scale_losses (bool): whether to scale accumulated losses by
424+
gradient_accumulation_steps. Default is False.
421425
"""
422426

423427
# The PipelinedForward class that is used in _rewrite_model
@@ -438,6 +442,8 @@ def __init__(
438442
] = None,
439443
dmp_collection_sync_interval_batches: Optional[int] = 1,
440444
enqueue_batch_after_forward: bool = False,
445+
gradient_accumulation_steps: int = 1,
446+
should_scale_losses: bool = False,
441447
) -> None:
442448
self._model = model
443449
self._optimizer = optimizer
@@ -503,6 +509,11 @@ def __init__(
503509
dmp_collection_sync_interval_batches
504510
)
505511

512+
self._accumulation_steps: int = gradient_accumulation_steps
513+
self._accumulation_step_count: int = gradient_accumulation_steps - 1
514+
self._should_scale_losses: bool = should_scale_losses
515+
self._is_first_step: bool = True
516+
506517
if self._dmp_collection_sync_interval_batches is not None:
507518
logger.info(
508519
f"{self.__class__.__name__}: [Sparse 2D] DMP collection will sync every "
@@ -680,7 +691,8 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
680691
# TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only)
681692
self._set_module_context(self.contexts[0])
682693

683-
if self._model.training:
694+
# only zero grad at the start of each accumulation
695+
if self._model.training and self._accumulation_step_count == 0:
684696
with record_function("## zero_grad ##"):
685697
self._optimizer.zero_grad()
686698

@@ -696,35 +708,51 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
696708
# batch i+2: load data and copy to gpu, the dataload iter will first exhaust here
697709
self.enqueue_batch(dataloader_iter)
698710

699-
# forward
700-
with record_function(f"## forward {self.contexts[0].index} ##"):
701-
self._state = PipelineState.CALL_FWD
702-
losses, output = self._model_fwd(self.batches[0])
711+
no_sync = (
712+
self._model.training
713+
and self._accumulation_step_count + 1 < self._accumulation_steps
714+
) or not self._is_first_step # the first step cannot be no_sync, due to a "feature" in distributed.py
715+
with (
716+
self._model._dmp_wrapped_module.no_sync() # pyre-ignore[16]
717+
if no_sync
718+
else contextlib.nullcontext()
719+
):
720+
# forward
721+
with record_function(f"## forward {self.contexts[0].index} ##"):
722+
self._state = PipelineState.CALL_FWD
723+
losses, output = self._model_fwd(self.batches[0])
703724

704-
if self._enqueue_batch_after_forward:
705-
# batch i+2: load data and copy to gpu, the dataload iter will first exhaust here.
706-
# Start this step after the forward of batch i, so that the H2D copy doesn't compete
707-
# for pcie bandwidth with embedding lookup from UVM/UVM_CACHING.
708-
self.enqueue_batch(dataloader_iter)
725+
if self._enqueue_batch_after_forward:
726+
# batch i+2: load data and copy to gpu, the dataload iter will first exhaust here.
727+
# Start this step after the forward of batch i, so that the H2D copy doesn't compete
728+
# for pcie bandwidth with embedding lookup from UVM/UVM_CACHING.
729+
self.enqueue_batch(dataloader_iter)
709730

710-
if len(self.batches) >= 2:
711-
# invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
712-
self.wait_sparse_data_dist(self.contexts[1])
731+
if len(self.batches) >= 2:
732+
# invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
733+
self.wait_sparse_data_dist(self.contexts[1])
713734

714-
if self._model.training:
715735
# backward
716-
self._state = PipelineState.CALL_BWD
717-
self._backward(losses)
718-
719-
self.sync_embeddings(
720-
self._model,
721-
self._dmp_collection_sync_interval_batches,
722-
self.contexts[0],
723-
)
736+
if self._model.training:
737+
self._state = PipelineState.CALL_BWD
738+
if self._should_scale_losses and not self._is_first_step:
739+
losses = losses / self._accumulation_steps
740+
self._backward(losses)
724741

725-
# update
726-
with record_function(f"## optimizer {self.contexts[0].index} ##"):
727-
self._optimizer.step()
742+
if no_sync:
743+
self._accumulation_step_count += 1
744+
else:
745+
self.sync_embeddings(
746+
self._model,
747+
self._dmp_collection_sync_interval_batches,
748+
self.contexts[0],
749+
)
750+
# update
751+
with record_function(f"## optimizer {self.contexts[0].index} ##"):
752+
self._optimizer.step()
753+
self._accumulation_step_count = 0
754+
if self._is_first_step:
755+
self._is_first_step = False
728756

729757
self.dequeue_batch()
730758
return output

0 commit comments

Comments
 (0)