@@ -418,6 +418,10 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
418418 (applicable to 2D sharding only)
419419 if set and DMP collection is enabled for 2D sharding,
420420 sync DMPs every N batches (default to 1, i.e. every batch, None to disable)
421+ gradient_accumulation_steps (int): number of steps to accumulate gradients before
422+ performing backward pass and optimizer update. Default is 1 (no accumulation).
423+ should_scale_losses (bool): if True, divide accumulated losses by
424+ gradient_accumulation_steps before the backward pass. Default is False.
421425 """
422426
423427 # The PipelinedForward class that is used in _rewrite_model
@@ -438,6 +442,8 @@ def __init__(
438442 ] = None ,
439443 dmp_collection_sync_interval_batches : Optional [int ] = 1 ,
440444 enqueue_batch_after_forward : bool = False ,
445+ gradient_accumulation_steps : int = 1 ,
446+ should_scale_losses : bool = False ,
441447 ) -> None :
442448 self ._model = model
443449 self ._optimizer = optimizer
@@ -503,6 +509,11 @@ def __init__(
503509 dmp_collection_sync_interval_batches
504510 )
505511
512+ self ._accumulation_steps : int = gradient_accumulation_steps
513+ self ._accumulation_step_count : int = gradient_accumulation_steps - 1
514+ self ._should_scale_losses : bool = should_scale_losses
515+ self ._is_first_step : bool = True
516+
506517 if self ._dmp_collection_sync_interval_batches is not None :
507518 logger .info (
508519 f"{ self .__class__ .__name__ } : [Sparse 2D] DMP collection will sync every "
@@ -680,7 +691,8 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
680691 # TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only)
681692 self ._set_module_context (self .contexts [0 ])
682693
683- if self ._model .training :
694+ # only zero grad at the start of each accumulation
695+ if self ._model .training and self ._accumulation_step_count == 0 :
684696 with record_function ("## zero_grad ##" ):
685697 self ._optimizer .zero_grad ()
686698
@@ -696,35 +708,51 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
696708 # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here
697709 self .enqueue_batch (dataloader_iter )
698710
699- # forward
700- with record_function (f"## forward { self .contexts [0 ].index } ##" ):
701- self ._state = PipelineState .CALL_FWD
702- losses , output = self ._model_fwd (self .batches [0 ])
711+ no_sync = (
712+ self ._model .training
713+ and self ._accumulation_step_count + 1 < self ._accumulation_steps
714+ ) or not self ._is_first_step # the first step cannot be no_sync, due to a "feature" in distributed.py
715+ with (
716+ self ._model ._dmp_wrapped_module .no_sync () # pyre-ignore[16]
717+ if no_sync
718+ else contextlib .nullcontext ()
719+ ):
720+ # forward
721+ with record_function (f"## forward { self .contexts [0 ].index } ##" ):
722+ self ._state = PipelineState .CALL_FWD
723+ losses , output = self ._model_fwd (self .batches [0 ])
703724
704- if self ._enqueue_batch_after_forward :
705- # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here.
706- # Start this step after the forward of batch i, so that the H2D copy doesn't compete
707- # for pcie bandwidth with embedding lookup from UVM/UVM_CACHING.
708- self .enqueue_batch (dataloader_iter )
725+ if self ._enqueue_batch_after_forward :
726+ # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here.
727+ # Start this step after the forward of batch i, so that the H2D copy doesn't compete
728+ # for pcie bandwidth with embedding lookup from UVM/UVM_CACHING.
729+ self .enqueue_batch (dataloader_iter )
709730
710- if len (self .batches ) >= 2 :
711- # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
712- self .wait_sparse_data_dist (self .contexts [1 ])
731+ if len (self .batches ) >= 2 :
732+ # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
733+ self .wait_sparse_data_dist (self .contexts [1 ])
713734
714- if self ._model .training :
715735 # backward
716- self ._state = PipelineState .CALL_BWD
717- self ._backward (losses )
718-
719- self .sync_embeddings (
720- self ._model ,
721- self ._dmp_collection_sync_interval_batches ,
722- self .contexts [0 ],
723- )
736+ if self ._model .training :
737+ self ._state = PipelineState .CALL_BWD
738+ if self ._should_scale_losses and not self ._is_first_step :
739+ losses = losses / self ._accumulation_steps
740+ self ._backward (losses )
724741
725- # update
726- with record_function (f"## optimizer { self .contexts [0 ].index } ##" ):
727- self ._optimizer .step ()
742+ if no_sync :
743+ self ._accumulation_step_count += 1
744+ else :
745+ self .sync_embeddings (
746+ self ._model ,
747+ self ._dmp_collection_sync_interval_batches ,
748+ self .contexts [0 ],
749+ )
750+ # update
751+ with record_function (f"## optimizer { self .contexts [0 ].index } ##" ):
752+ self ._optimizer .step ()
753+ self ._accumulation_step_count = 0
754+ if self ._is_first_step :
755+ self ._is_first_step = False
728756
729757 self .dequeue_batch ()
730758 return output
0 commit comments