@@ -31,6 +31,7 @@

 import torch
 from torch.autograd.profiler import record_function
+from torch.nn.parallel import DistributedDataParallel
 from torchrec.distributed.dist_data import KJTAllToAllTensorsAwaitable
 from torchrec.distributed.model_parallel import ShardedModule
 from torchrec.distributed.train_pipeline.pipeline_context import (
@@ -418,6 +419,10 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
             (applicable to 2D sharding only)
             if set and DMP collection is enabled for 2D sharding,
             sync DMPs every N batches (default to 1, i.e. every batch, None to disable)
+        gradient_accumulation_steps (int): number of steps to accumulate gradients before
+            performing backward pass and optimizer update. Default is 1 (no accumulation).
+        should_scale_losses (bool): whether to scale accumulated losses by
+            gradient_accumulation_steps. Default is False.
     """

     # The PipelinedForward class that is used in _rewrite_model
@@ -438,6 +443,8 @@ def __init__(
         ] = None,
         dmp_collection_sync_interval_batches: Optional[int] = 1,
         enqueue_batch_after_forward: bool = False,
+        gradient_accumulation_steps: int = 1,
+        should_scale_losses: bool = False,
     ) -> None:
         self._model = model
         self._optimizer = optimizer
@@ -503,6 +510,10 @@ def __init__(
             dmp_collection_sync_interval_batches
         )

+        self._accumulation_steps: int = gradient_accumulation_steps
+        self._accumulation_step_count: int = 0
+        self._should_scale_losses: bool = should_scale_losses
+
         if self._dmp_collection_sync_interval_batches is not None:
             logger.info(
                 f"{self.__class__.__name__}: [Sparse 2D] DMP collection will sync every "
@@ -680,7 +691,8 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         # TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only)
         self._set_module_context(self.contexts[0])

-        if self._model.training:
+        # only zero grad at the start of each accumulation
+        if self._model.training and self._accumulation_step_count == 0:
             with record_function("## zero_grad ##"):
                 self._optimizer.zero_grad()

@@ -696,35 +708,49 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here
         self.enqueue_batch(dataloader_iter)

-        # forward
-        with record_function(f"## forward {self.contexts[0].index} ##"):
-            self._state = PipelineState.CALL_FWD
-            losses, output = self._model_fwd(self.batches[0])
+        no_sync = (
+            self._model.training
+            and self._accumulation_step_count + 1 < self._accumulation_steps
+        )
+        with (
+            self._model._dmp_wrapped_module.no_sync()  # pyre-ignore[16]
+            if no_sync
+            else contextlib.nullcontext()
+        ):
+            # forward
+            with record_function(f"## forward {self.contexts[0].index} ##"):
+                self._state = PipelineState.CALL_FWD
+                losses, output = self._model_fwd(self.batches[0])

-        if self._enqueue_batch_after_forward:
-            # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here.
-            # Start this step after the forward of batch i, so that the H2D copy doesn't compete
-            # for pcie bandwidth with embedding lookup from UVM/UVM_CACHING.
-            self.enqueue_batch(dataloader_iter)
+            if self._enqueue_batch_after_forward:
+                # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here.
+                # Start this step after the forward of batch i, so that the H2D copy doesn't compete
+                # for pcie bandwidth with embedding lookup from UVM/UVM_CACHING.
+                self.enqueue_batch(dataloader_iter)

-        if len(self.batches) >= 2:
-            # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
-            self.wait_sparse_data_dist(self.contexts[1])
+            if len(self.batches) >= 2:
+                # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
+                self.wait_sparse_data_dist(self.contexts[1])

-        if self._model.training:
             # backward
-            self._state = PipelineState.CALL_BWD
-            self._backward(losses)
-
-            self.sync_embeddings(
-                self._model,
-                self._dmp_collection_sync_interval_batches,
-                self.contexts[0],
-            )
+            if self._model.training:
+                self._state = PipelineState.CALL_BWD
+                if self._should_scale_losses:
+                    losses = losses / self._accumulation_steps
+                self._backward(losses)

-            # update
-            with record_function(f"## optimizer {self.contexts[0].index} ##"):
-                self._optimizer.step()
+                if no_sync:
+                    self._accumulation_step_count += 1
+                else:
+                    self.sync_embeddings(
+                        self._model,
+                        self._dmp_collection_sync_interval_batches,
+                        self.contexts[0],
+                    )
+                    # update
+                    with record_function(f"## optimizer {self.contexts[0].index} ##"):
+                        self._optimizer.step()
+                    self._accumulation_step_count = 0

         self.dequeue_batch()
         return output
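
For reviewers, a minimal usage sketch of the new arguments (not part of the diff): it assumes `model` is already wrapped in DistributedModelParallel and that `optimizer`, `device`, and `train_dataloader` are created elsewhere; the import path and the two keyword arguments follow the signature added above, but treat this as an illustration rather than tested code.

# Hypothetical usage sketch: `model` (a DistributedModelParallel instance),
# `optimizer`, `device`, and `train_dataloader` are assumed to exist;
# only the two new keyword arguments come from this change.
from torchrec.distributed.train_pipeline import TrainPipelineSparseDist

pipeline = TrainPipelineSparseDist(
    model=model,
    optimizer=optimizer,
    device=device,
    gradient_accumulation_steps=4,  # backward-only on 3 batches, step on the 4th
    should_scale_losses=True,       # divide each batch's losses by 4 before backward
)

dataloader_iter = iter(train_dataloader)
while True:
    try:
        # zero_grad runs only at the start of an accumulation window; DDP gradient
        # all-reduce, sync_embeddings, and optimizer.step() run only on the last
        # micro-batch of the window.
        output = pipeline.progress(dataloader_iter)
    except StopIteration:
        break

With gradient_accumulation_steps=4, zero_grad runs on the first of every four progress() calls, the DDP all-reduce is skipped via no_sync() on the first three, and sync_embeddings plus optimizer.step() run on the fourth; should_scale_losses=True divides each batch's losses by 4 before backward so the accumulated gradient approximates a single large-batch step.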