@@ -420,6 +420,10 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
             (applicable to 2D sharding only)
             if set and DMP collection is enabled for 2D sharding,
             sync DMPs every N batches (default to 1, i.e. every batch, None to disable)
+        gradient_accumulation_steps (int): number of steps over which gradients are accumulated
+            before the optimizer update is performed. Default is 1 (no accumulation).
+        gradient_accumulation_start_step (Optional[int]): the step at which gradient accumulation
+            starts. If not set, gradient accumulation will start from the second step.
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -440,6 +444,8 @@ def __init__(
         ] = None,
         dmp_collection_sync_interval_batches: Optional[int] = 1,
         enqueue_batch_after_forward: bool = False,
+        gradient_accumulation_steps: int = 1,
+        gradient_accumulation_start_step: Optional[int] = None,
     ) -> None:
         self._model = model
         self._optimizer = optimizer
@@ -505,6 +511,14 @@ def __init__(
             dmp_collection_sync_interval_batches
         )
 
+        # NOTE: the first step cannot be no_sync when DDP.static_graph = True, due to an unfortunate restriction in torch.distributed.
+        # Thus, unless explicitly set, we set gradient_accumulation_start_step to 1 if gradient_accumulation_steps > 1.
+        self._pre_accumulation_steps: int = gradient_accumulation_start_step or (
+            1 if gradient_accumulation_steps > 1 else 0
+        )
+        self._accumulation_steps: int = gradient_accumulation_steps
+        self._accumulation_step_count: int = 0
+
         if self._dmp_collection_sync_interval_batches is not None:
             logger.info(
                 f"{self.__class__.__name__}: [Sparse 2D] DMP collection will sync every "
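
To make the accumulation schedule concrete, here is a small standalone sketch (not part of the patch) that replays the counter logic above and prints which steps run under `no_sync` and which perform the allreduce plus optimizer step. It assumes the model is always in training mode; the variable names mirror the new attributes without the `self._` prefix.

```python
# Standalone sketch (not part of the patch): replays the accumulation counters to show
# which steps skip gradient sync and which run allreduce + optimizer.step().
gradient_accumulation_steps = 4
gradient_accumulation_start_step = None  # patch default: first step stays synced when accumulating

pre_accumulation_steps = gradient_accumulation_start_step or (
    1 if gradient_accumulation_steps > 1 else 0
)
accumulation_step_count = 0

for step in range(9):
    no_sync = pre_accumulation_steps <= 0 and (
        accumulation_step_count + 1 < gradient_accumulation_steps
    )
    if no_sync:
        accumulation_step_count += 1
        print(f"step {step}: no_sync (accumulate, count={accumulation_step_count})")
    else:
        print(f"step {step}: allreduce + optimizer.step()")
        if pre_accumulation_steps > 0:
            pre_accumulation_steps -= 1
        else:
            accumulation_step_count = 0
```

With these defaults, step 0 is fully synced (the static_graph caveat in the NOTE above), steps 1-3 accumulate under `no_sync`, step 4 syncs and runs the optimizer, and the pattern then repeats every 4 batches.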
@@ -682,7 +696,10 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         # TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only)
         self._set_module_context(self.contexts[0])
 
-        if self._model.training:
+        # only zero grad at the start of each accumulation window
+        if self._model.training and (
+            self._pre_accumulation_steps > 0 or self._accumulation_step_count == 0
+        ):
             with record_function("## zero_grad ##"):
                 self._optimizer.zero_grad()
 
@@ -698,35 +715,50 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here
         self.enqueue_batch(dataloader_iter)
 
-        # forward
-        with record_function(f"## forward {self.contexts[0].index} ##"):
-            self._state = PipelineState.CALL_FWD
-            losses, output = self._model_fwd(self.batches[0])
+        no_sync = self._pre_accumulation_steps <= 0 and (
+            self._model.training
+            and self._accumulation_step_count + 1 < self._accumulation_steps
+        )
+        with (
+            self._model._dmp_wrapped_module.no_sync()  # pyre-ignore[16]
+            if no_sync
+            else contextlib.nullcontext()
+        ):
+            # forward
+            with record_function(f"## forward {self.contexts[0].index} ##"):
+                self._state = PipelineState.CALL_FWD
+                losses, output = self._model_fwd(self.batches[0])
 
-        if self._enqueue_batch_after_forward:
-            # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here.
-            # Start this step after the forward of batch i, so that the H2D copy doesn't compete
-            # for pcie bandwidth with embedding lookup from UVM/UVM_CACHING.
-            self.enqueue_batch(dataloader_iter)
+            if self._enqueue_batch_after_forward:
+                # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here.
+                # Start this step after the forward of batch i, so that the H2D copy doesn't compete
+                # for pcie bandwidth with embedding lookup from UVM/UVM_CACHING.
+                self.enqueue_batch(dataloader_iter)
 
-        if len(self.batches) >= 2:
-            # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
-            self.wait_sparse_data_dist(self.contexts[1])
+            if len(self.batches) >= 2:
+                # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
+                self.wait_sparse_data_dist(self.contexts[1])
 
-        if self._model.training:
             # backward
-            self._state = PipelineState.CALL_BWD
-            self._backward(losses)
-
-            self.sync_embeddings(
-                self._model,
-                self._dmp_collection_sync_interval_batches,
-                self.contexts[0],
-            )
+            if self._model.training:
+                self._state = PipelineState.CALL_BWD
+                self._backward(losses)
 
-            # update
-            with record_function(f"## optimizer {self.contexts[0].index} ##"):
-                self._optimizer.step()
+                if no_sync:
+                    self._accumulation_step_count += 1
+                else:
+                    self.sync_embeddings(
+                        self._model,
+                        self._dmp_collection_sync_interval_batches,
+                        self.contexts[0],
+                    )
+                    # update
+                    with record_function(f"## optimizer {self.contexts[0].index} ##"):
+                        self._optimizer.step()
+                    if self._pre_accumulation_steps > 0:
+                        self._pre_accumulation_steps -= 1
+                    else:
+                        self._accumulation_step_count = 0
 
         self.dequeue_batch()
         return output
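
Usage sketch for the new arguments, illustrative only: `model`, `optimizer`, and `dataloader` are placeholders for an already sharded torchrec model, its optimizer, and a pipelineable data loader, and the non-accumulation constructor arguments follow the existing `TrainPipelineSparseDist` signature rather than anything introduced by this patch.

```python
import torch
from torchrec.distributed.train_pipeline import TrainPipelineSparseDist

# Placeholders: `model` is the DistributedModelParallel-wrapped model, `optimizer` its
# optimizer, `dataloader` an iterable of batches that the pipeline can prefetch.
pipeline = TrainPipelineSparseDist(
    model=model,
    optimizer=optimizer,
    device=torch.device("cuda"),
    gradient_accumulation_steps=4,  # allreduce + optimizer.step() on every 4th training batch
    # gradient_accumulation_start_step left unset: the first step stays synced (static_graph caveat)
)

dataloader_iter = iter(dataloader)
while True:
    try:
        # zero_grad / forward / backward / optimizer.step() follow the accumulation schedule above
        pipeline.progress(dataloader_iter)
    except StopIteration:
        break
```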