@@ -418,6 +418,8 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
418418 (applicable to 2D sharding only)
419419 if set and DMP collection is enabled for 2D sharding,
420420 sync DMPs every N batches (default to 1, i.e. every batch, None to disable)
421+ gradient_accumulation_steps (int): number of steps to accumulate gradients before
422+ performing backward pass and optimizer update. Default is 1 (no accumulation).
421423 """
422424
423425 # The PipelinedForward class that is used in _rewrite_model
@@ -438,6 +440,7 @@ def __init__(
438440 ] = None ,
439441 dmp_collection_sync_interval_batches : Optional [int ] = 1 ,
440442 enqueue_batch_after_forward : bool = False ,
443+ gradient_accumulation_steps : int = 1 ,
441444 ) -> None :
442445 self ._model = model
443446 self ._optimizer = optimizer
@@ -503,6 +506,10 @@ def __init__(
503506 dmp_collection_sync_interval_batches
504507 )
505508
509+ self ._accumulation_steps : int = gradient_accumulation_steps
510+ self ._accumulation_step_count : int = gradient_accumulation_steps - 1
511+ self ._is_first_step : bool = True
512+
506513 if self ._dmp_collection_sync_interval_batches is not None :
507514 logger .info (
508515 f"{ self .__class__ .__name__ } : [Sparse 2D] DMP collection will sync every "
@@ -680,7 +687,10 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
680687 # TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only)
681688 self ._set_module_context (self .contexts [0 ])
682689
683- if self ._model .training :
690+ # only zero grad at the start of each accumulation
691+ if self ._model .training and (
692+ self ._is_first_step or self ._accumulation_step_count == 0
693+ ):
684694 with record_function ("## zero_grad ##" ):
685695 self ._optimizer .zero_grad ()
686696
@@ -696,35 +706,51 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
696706 # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here
697707 self .enqueue_batch (dataloader_iter )
698708
699- # forward
700- with record_function (f"## forward { self .contexts [0 ].index } ##" ):
701- self ._state = PipelineState .CALL_FWD
702- losses , output = self ._model_fwd (self .batches [0 ])
709+ # NOTE: the first step cannot be no_sync when DDP.static_graph = True,
710+ # due to an unfortunate restriction in torch.distributed
711+ no_sync = not self ._is_first_step and (
712+ self ._model .training
713+ and self ._accumulation_step_count + 1 < self ._accumulation_steps
714+ )
715+ with (
716+ self ._model ._dmp_wrapped_module .no_sync () # pyre-ignore[16]
717+ if no_sync
718+ else contextlib .nullcontext ()
719+ ):
720+ # forward
721+ with record_function (f"## forward { self .contexts [0 ].index } ##" ):
722+ self ._state = PipelineState .CALL_FWD
723+ losses , output = self ._model_fwd (self .batches [0 ])
703724
704- if self ._enqueue_batch_after_forward :
705- # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here.
706- # Start this step after the forward of batch i, so that the H2D copy doesn't compete
707- # for pcie bandwidth with embedding lookup from UVM/UVM_CACHING.
708- self .enqueue_batch (dataloader_iter )
725+ if self ._enqueue_batch_after_forward :
726+ # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here.
727+ # Start this step after the forward of batch i, so that the H2D copy doesn't compete
728+ # for pcie bandwidth with embedding lookup from UVM/UVM_CACHING.
729+ self .enqueue_batch (dataloader_iter )
709730
710- if len (self .batches ) >= 2 :
711- # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
712- self .wait_sparse_data_dist (self .contexts [1 ])
731+ if len (self .batches ) >= 2 :
732+ # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
733+ self .wait_sparse_data_dist (self .contexts [1 ])
713734
714- if self ._model .training :
715735 # backward
716- self ._state = PipelineState .CALL_BWD
717- self ._backward (losses )
718-
719- self .sync_embeddings (
720- self ._model ,
721- self ._dmp_collection_sync_interval_batches ,
722- self .contexts [0 ],
723- )
736+ if self ._model .training :
737+ self ._state = PipelineState .CALL_BWD
738+ self ._backward (losses )
724739
725- # update
726- with record_function (f"## optimizer { self .contexts [0 ].index } ##" ):
727- self ._optimizer .step ()
740+ if no_sync :
741+ self ._accumulation_step_count += 1
742+ else :
743+ self .sync_embeddings (
744+ self ._model ,
745+ self ._dmp_collection_sync_interval_batches ,
746+ self .contexts [0 ],
747+ )
748+ # update
749+ with record_function (f"## optimizer { self .contexts [0 ].index } ##" ):
750+ self ._optimizer .step ()
751+ self ._accumulation_step_count = 0
752+ if self ._is_first_step :
753+ self ._is_first_step = False
728754
729755 self .dequeue_batch ()
730756 return output