@@ -31,6 +31,7 @@
 
 import torch
 from torch.autograd.profiler import record_function
+from torch.nn.parallel import DistributedDataParallel
 from torchrec.distributed.dist_data import KJTAllToAllTensorsAwaitable
 from torchrec.distributed.model_parallel import ShardedModule
 from torchrec.distributed.train_pipeline.pipeline_context import (
@@ -418,6 +419,10 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
             (applicable to 2D sharding only)
             if set and DMP collection is enabled for 2D sharding,
             sync DMPs every N batches (default to 1, i.e. every batch, None to disable)
+        gradient_accumulation_steps (int): number of micro-batches over which to accumulate
+            gradients before performing the optimizer update. Default is 1 (no accumulation).
+        should_scale_losses (bool): whether to scale accumulated losses by
+            gradient_accumulation_steps. Default is False.
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -438,6 +443,8 @@ def __init__(
         ] = None,
         dmp_collection_sync_interval_batches: Optional[int] = 1,
         enqueue_batch_after_forward: bool = False,
+        gradient_accumulation_steps: int = 1,
+        should_scale_losses: bool = False,
     ) -> None:
         self._model = model
         self._optimizer = optimizer
@@ -503,6 +510,10 @@ def __init__(
             dmp_collection_sync_interval_batches
         )
 
+        self._accumulation_steps: int = gradient_accumulation_steps
+        self._accumulation_step_count: int = 0
+        self._should_scale_losses: bool = should_scale_losses
+
         if self._dmp_collection_sync_interval_batches is not None:
             logger.info(
                 f"{self.__class__.__name__}: [Sparse 2D] DMP collection will sync every "
@@ -680,7 +691,8 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         # TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only)
         self._set_module_context(self.contexts[0])
 
-        if self._model.training:
+        # only zero grad at the start of each accumulation
+        if self._model.training and self._accumulation_step_count == 0:
             with record_function("## zero_grad ##"):
                 self._optimizer.zero_grad()
 
@@ -696,35 +708,50 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here
         self.enqueue_batch(dataloader_iter)
 
-        # forward
-        with record_function(f"## forward {self.contexts[0].index} ##"):
-            self._state = PipelineState.CALL_FWD
-            losses, output = self._model_fwd(self.batches[0])
+        no_sync = (
+            self._model.training
+            and len(self.batches) >= 2
+            and self._accumulation_step_count + 1 < self._accumulation_steps
+        )
+        with (
+            self._model._dmp_wrapped_module.no_sync()  # pyre-ignore[16]
+            if no_sync
+            else contextlib.nullcontext()
+        ):
+            # forward
+            with record_function(f"## forward {self.contexts[0].index} ##"):
+                self._state = PipelineState.CALL_FWD
+                losses, output = self._model_fwd(self.batches[0])
 
-        if self._enqueue_batch_after_forward:
-            # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here.
-            # Start this step after the forward of batch i, so that the H2D copy doesn't compete
-            # for pcie bandwidth with embedding lookup from UVM/UVM_CACHING.
-            self.enqueue_batch(dataloader_iter)
+            if self._enqueue_batch_after_forward:
+                # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here.
+                # Start this step after the forward of batch i, so that the H2D copy doesn't compete
+                # for pcie bandwidth with embedding lookup from UVM/UVM_CACHING.
+                self.enqueue_batch(dataloader_iter)
 
-        if len(self.batches) >= 2:
-            # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
-            self.wait_sparse_data_dist(self.contexts[1])
+            if len(self.batches) >= 2:
+                # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
+                self.wait_sparse_data_dist(self.contexts[1])
 
-        if self._model.training:
             # backward
-            self._state = PipelineState.CALL_BWD
-            self._backward(losses)
-
-            self.sync_embeddings(
-                self._model,
-                self._dmp_collection_sync_interval_batches,
-                self.contexts[0],
-            )
+            if self._model.training:
+                self._state = PipelineState.CALL_BWD
+                if self._should_scale_losses:
+                    losses = losses / self._accumulation_steps
+                self._backward(losses)
 
-            # update
-            with record_function(f"## optimizer {self.contexts[0].index} ##"):
-                self._optimizer.step()
+                if no_sync:
+                    self._accumulation_step_count += 1
+                else:
+                    self.sync_embeddings(
+                        self._model,
+                        self._dmp_collection_sync_interval_batches,
+                        self.contexts[0],
+                    )
+                    # update
+                    with record_function(f"## optimizer {self.contexts[0].index} ##"):
+                        self._optimizer.step()
+                        self._accumulation_step_count = 0
 
         self.dequeue_batch()
         return output
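The new with-block follows PyTorch's standard DDP gradient-accumulation idiom: wrap the intermediate micro-batches in no_sync() so the gradient all-reduce only fires on the step that will call optimizer.step(). For reference, a minimal self-contained sketch of that idiom with plain DistributedDataParallel; the model, optimizer, and batches here are hypothetical placeholders, not part of this commit:

```python
import contextlib

import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP


def train_with_accumulation(model: DDP, optimizer, batches, accumulation_steps: int = 4) -> None:
    # Plain-DDP version of the pattern used above: skip the gradient
    # all-reduce on intermediate micro-batches, sync and step on the last one.
    optimizer.zero_grad()
    for step, (inputs, targets) in enumerate(batches):
        is_sync_step = (step + 1) % accumulation_steps == 0
        # no_sync() suppresses DDP's gradient all-reduce; nullcontext() lets it run.
        ctx = contextlib.nullcontext() if is_sync_step else model.no_sync()
        with ctx:
            loss = F.mse_loss(model(inputs), targets)
            # Scale so the accumulated gradient matches one large batch.
            (loss / accumulation_steps).backward()
        if is_sync_step:
            optimizer.step()
            optimizer.zero_grad()
```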
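And a hedged sketch of how a caller might enable the new constructor arguments. Only gradient_accumulation_steps and should_scale_losses come from this commit; the model, optimizer, and dataloader iterator are assumed to be built elsewhere (e.g. a DistributedModelParallel model with its optimizer):

```python
from typing import Iterator

import torch
from torchrec.distributed.train_pipeline import TrainPipelineSparseDist


def run_training(model, optimizer, dataloader_iter: Iterator) -> None:
    # `model` and `optimizer` are placeholders for whatever the caller built.
    pipeline = TrainPipelineSparseDist(
        model=model,
        optimizer=optimizer,
        device=torch.device("cuda"),
        gradient_accumulation_steps=4,  # optimizer.step() fires on every 4th batch
        should_scale_losses=True,       # each loss is divided by 4 before backward
    )
    while True:
        try:
            # progress() zeroes grads only at the start of an accumulation window,
            # runs intermediate backwards under no_sync(), and steps the optimizer
            # on the final micro-batch of the window.
            pipeline.progress(dataloader_iter)
        except StopIteration:
            break
```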