Commit a4c8aeb

Zheng-Yong (Arsa) Ang authored and facebook-github-bot committed

enable gradient accumulation in SDD

Summary: Context: gradient accumulation is still not available in TorchRec, especially for the SDD pipeline, which is used by many recommendation models. This diff implements a UI enabling gradient accumulation for SDD.

Differential Revision: D84915986
1 parent 24bb848 commit a4c8aeb
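
For orientation, the change applies the standard gradient-accumulation recipe to the SDD pipeline: gradients are zeroed only at the start of each accumulation window, each micro-batch loss can optionally be scaled by the window size, and the optimizer runs once per window. A minimal, framework-agnostic sketch in plain PyTorch (not TorchRec code; model, optimizer, and loader are hypothetical placeholders):

import torch

# Minimal gradient-accumulation sketch; `model`, `optimizer`, and `loader` are placeholders.
def train_with_accumulation(model, optimizer, loader, accumulation_steps: int = 4, scale_losses: bool = True) -> None:
    step_count = 0
    for batch, target in loader:
        if step_count == 0:
            # zero gradients only at the start of each accumulation window
            optimizer.zero_grad()

        loss = torch.nn.functional.mse_loss(model(batch), target)
        if scale_losses:
            # keep gradient magnitude comparable to one large batch of the same total size
            loss = loss / accumulation_steps
        loss.backward()

        if step_count + 1 < accumulation_steps:
            # still accumulating: skip the optimizer update
            step_count += 1
        else:
            # window boundary: apply the accumulated gradients and reset the counter
            optimizer.step()
            step_count = 0

The diff below follows the same structure and, in addition, wraps the non-final micro-batches in DistributedDataParallel.no_sync() so that dense gradients are only all-reduced on the step that actually calls optimizer.step().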

File tree

1 file changed: +52 −25 lines

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 52 additions & 25 deletions
@@ -31,6 +31,7 @@
 
 import torch
 from torch.autograd.profiler import record_function
+from torch.nn.parallel import DistributedDataParallel
 from torchrec.distributed.dist_data import KJTAllToAllTensorsAwaitable
 from torchrec.distributed.model_parallel import ShardedModule
 from torchrec.distributed.train_pipeline.pipeline_context import (
@@ -418,6 +419,10 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
             (applicable to 2D sharding only)
             if set and DMP collection is enabled for 2D sharding,
             sync DMPs every N batches (default to 1, i.e. every batch, None to disable)
+        gradient_accumulation_steps (int): number of steps to accumulate gradients before
+            performing backward pass and optimizer update. Default is 1 (no accumulation).
+        should_scale_losses (bool): whether to scale accumulated losses by
+            gradient_accumulation_steps. Default is False.
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -438,6 +443,8 @@ def __init__(
         ] = None,
         dmp_collection_sync_interval_batches: Optional[int] = 1,
         enqueue_batch_after_forward: bool = False,
+        gradient_accumulation_steps: int = 1,
+        should_scale_losses: bool = False,
     ) -> None:
         self._model = model
         self._optimizer = optimizer
@@ -503,6 +510,10 @@ def __init__(
             dmp_collection_sync_interval_batches
         )
 
+        self._accumulation_steps: int = gradient_accumulation_steps
+        self._accumulation_step_count: int = 0
+        self._should_scale_losses: bool = should_scale_losses
+
         if self._dmp_collection_sync_interval_batches is not None:
             logger.info(
                 f"{self.__class__.__name__}: [Sparse 2D] DMP collection will sync every "
@@ -680,7 +691,8 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         # TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only)
         self._set_module_context(self.contexts[0])
 
-        if self._model.training:
+        # only zero grad at the start of each accumulation
+        if self._model.training and self._accumulation_step_count == 0:
             with record_function("## zero_grad ##"):
                 self._optimizer.zero_grad()
 
@@ -696,35 +708,50 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here
         self.enqueue_batch(dataloader_iter)
 
-        # forward
-        with record_function(f"## forward {self.contexts[0].index} ##"):
-            self._state = PipelineState.CALL_FWD
-            losses, output = self._model_fwd(self.batches[0])
+        no_sync = (
+            self._model.training
+            and len(self.batches) >= 2
+            and self._accumulation_step_count + 1 < self._accumulation_steps
+        )
+        with (
+            self._model._dmp_wrapped_module.no_sync()  # pyre-ignore[16]
+            if no_sync
+            else contextlib.nullcontext()
+        ):
+            # forward
+            with record_function(f"## forward {self.contexts[0].index} ##"):
+                self._state = PipelineState.CALL_FWD
+                losses, output = self._model_fwd(self.batches[0])
 
-        if self._enqueue_batch_after_forward:
-            # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here.
-            # Start this step after the forward of batch i, so that the H2D copy doesn't compete
-            # for pcie bandwidth with embedding lookup from UVM/UVM_CACHING.
-            self.enqueue_batch(dataloader_iter)
+            if self._enqueue_batch_after_forward:
+                # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here.
+                # Start this step after the forward of batch i, so that the H2D copy doesn't compete
+                # for pcie bandwidth with embedding lookup from UVM/UVM_CACHING.
+                self.enqueue_batch(dataloader_iter)
 
-        if len(self.batches) >= 2:
-            # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
-            self.wait_sparse_data_dist(self.contexts[1])
+            if len(self.batches) >= 2:
+                # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
+                self.wait_sparse_data_dist(self.contexts[1])
 
-        if self._model.training:
             # backward
-            self._state = PipelineState.CALL_BWD
-            self._backward(losses)
-
-            self.sync_embeddings(
-                self._model,
-                self._dmp_collection_sync_interval_batches,
-                self.contexts[0],
-            )
+            if self._model.training:
+                self._state = PipelineState.CALL_BWD
+                if self._should_scale_losses:
+                    losses = losses / self._accumulation_steps
+                self._backward(losses)
 
-            # update
-            with record_function(f"## optimizer {self.contexts[0].index} ##"):
-                self._optimizer.step()
+                if no_sync:
+                    self._accumulation_step_count += 1
+                else:
+                    self.sync_embeddings(
+                        self._model,
+                        self._dmp_collection_sync_interval_batches,
+                        self.contexts[0],
+                    )
+                    # update
+                    with record_function(f"## optimizer {self.contexts[0].index} ##"):
+                        self._optimizer.step()
+                    self._accumulation_step_count = 0
 
         self.dequeue_batch()
         return output
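
With this change, gradient accumulation in the SDD pipeline would be driven through the two new constructor arguments. A hedged usage sketch: the positional model/optimizer/device arguments and the progress() loop follow the existing TrainPipelineSparseDist API, while the sharded model, fused optimizer, and dataloader are assumed to be set up elsewhere and are not part of this diff:

from typing import Iterable

import torch
from torchrec.distributed.train_pipeline.train_pipelines import TrainPipelineSparseDist


def run_with_accumulation(
    model: torch.nn.Module,            # assumed: DMP-wrapped sharded model
    optimizer: torch.optim.Optimizer,  # assumed: optimizer over the model's dense params
    device: torch.device,
    dataloader: Iterable,              # assumed: yields pipeline-compatible batches
) -> None:
    pipeline = TrainPipelineSparseDist(
        model,
        optimizer,
        device,
        gradient_accumulation_steps=4,  # new in this diff: accumulate 4 micro-batches
        should_scale_losses=True,       # new in this diff: divide each loss by 4
    )
    dataloader_iter = iter(dataloader)
    while True:
        try:
            # optimizer.step() fires only on every 4th training call;
            # intermediate calls run the backward under DDP no_sync().
            pipeline.progress(dataloader_iter)
        except StopIteration:
            break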
