3 files changed: +15 −6 lines changed
     branches:
       - 'dev'
       - 'main'
+  push:
+    branches:
+      - 'dev'
 
 jobs:
   build-archive-wheel:
 
     uses: OpenBMB/BMTrain/.github/workflows/build_whl.yml@main
-    secrets:
-      DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }}
-      DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
+    secrets: inherit
 
   publish:
     needs: build-archive-wheel
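Note: `secrets: inherit` passes all of the calling workflow's secrets through to the reusable `build_whl.yml` workflow, so `DOCKERHUB_TOKEN` and `DOCKERHUB_USERNAME` no longer need to be forwarded one by one, and the added `push` trigger also runs the pipeline on pushes to `dev`. The next two files update the loss-scaling and reduction helpers.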
@@ -52,6 +52,7 @@ def __init__(self,
         loss_scale_steps: int = 1024,
         min_loss_scale = 1,
         max_loss_scale = float("inf"),
+        grad_scale: Optional[int] = None,
     ):
         if loss_scale is not None:
             self.loss_scale = loss_scale
@@ -64,6 +65,9 @@ def __init__(self,
         self.loss_scale_steps = loss_scale_steps
         self.min_loss_scale = min_loss_scale
         self.max_loss_scale = max_loss_scale
+        if grad_scale is None:
+            grad_scale = config['zero_size']
+        self.grad_scale = grad_scale
 
         self.optimizers = []
         self.lr_schedulers = []
@@ -85,7 +89,7 @@ def add_optimizer(
 
     def scale_loss(self, loss: torch.Tensor) -> torch.Tensor:
 
-        return loss * (self.loss_scale / (config['world_size'] // (config['tp_size'] * config['pipe_size'])))  # loss scale
+        return loss * (self.loss_scale / self.grad_scale)  # loss scale
 
     def backward(self, loss: torch.Tensor):
         """
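The net effect is that the hard-coded divisor `config['world_size'] // (config['tp_size'] * config['pipe_size'])` is replaced by a configurable `grad_scale`, defaulting to `config['zero_size']` (presumably the same quantity the removed expression computed). A minimal sketch of the resulting arithmetic, using a stand-in helper and made-up numbers rather than the library class itself:

```python
import torch

def scale_loss_sketch(loss: torch.Tensor, loss_scale: float, grad_scale: int) -> torch.Tensor:
    # Stand-in for the updated scale_loss: multiply by the loss scale and
    # divide by grad_scale, the size of the group over which gradients are averaged.
    return loss * (loss_scale / grad_scale)

loss = torch.tensor(0.5)
# Hypothetical values: loss_scale = 1024 and a 4-way data-parallel group.
print(scale_loss_sketch(loss, loss_scale=1024.0, grad_scale=4))  # tensor(128.)
```

The remaining hunks below come from the synchronization helpers, where `sum_loss` gains an optional communicator argument.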
 from . import distributed, nccl
 from .global_var import config
 import warnings
+from typing import Optional
 
 def synchronize():
     """
@@ -24,14 +25,17 @@ def wait_loader():
     config['calc_stream'].record_event(config['load_event'])
 
 
-def sum_loss(loss: torch.Tensor):
+def sum_loss(loss: torch.Tensor, comm: Optional[nccl.Communicator] = None):
     """
     Sum the loss across all workers.
 
     This is a helper function to reduce the loss across all workers.
     """
+    if comm is None:
+        comm = config['comm']
     warnings.warn("bmtrain.sum_loss is deprecated and will be removed in later version. Use bmtrain.distributed.all_reduce instead.", DeprecationWarning)
-    return distributed.all_reduce(loss, "sum") / config['world_size']
+
+    return distributed.all_reduce(loss, "avg", comm)
 
 def gather_result(result: torch.Tensor):
     warnings.warn("bmtrain.gather_result is deprecated and will be removed in later version. Use bmtrain.distributed.all_gather instead.", DeprecationWarning)
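Call sites stay backward compatible: omitting `comm` falls back to `config['comm']`, and the reduction now uses the `"avg"` op over the chosen communicator instead of summing and dividing by `world_size`, so with a sub-group communicator the mean is taken over that group's ranks. A hedged usage sketch, assuming the usual `bmt.init_distributed()` setup and a CUDA device; the sub-group communicator in the commented line is hypothetical:

```python
import torch
import bmtrain as bmt

bmt.init_distributed()              # sets up config['comm'] and the NCCL streams
loss = torch.tensor(1.0).cuda()

# Existing call sites keep working; the global communicator is used by default
# (a DeprecationWarning still points to bmtrain.distributed.all_reduce).
avg_loss = bmt.sum_loss(loss)

# A sub-group communicator can now be passed explicitly, for example:
# avg_loss = bmt.sum_loss(loss, comm=some_subgroup_comm)  # hypothetical communicator
```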