@@ -1042,8 +1042,7 @@ def _backward(_loss):
10421042 loss = _forward ()
10431043 _backward (loss )
10441044
1045- if not args .distributed :
1046- losses_m .update (loss .item () * accum_steps , input .size (0 ))
1045+ losses_m .update (loss .item () * accum_steps , input .size (0 ))
10471046 update_sample_count += input .size (0 )
10481047
10491048 if not need_update :
@@ -1068,16 +1067,18 @@ def _backward(_loss):
10681067 lrl = [param_group ['lr' ] for param_group in optimizer .param_groups ]
10691068 lr = sum (lrl ) / len (lrl )
10701069
1070+ loss_avg , loss_now = losses_m .avg , losses_m .val
10711071 if args .distributed :
1072- reduced_loss = utils .reduce_tensor (loss .data , args .world_size )
1073- losses_m .update (reduced_loss .item () * accum_steps , input .size (0 ))
1072+ # synchronize current step and avg loss, each process keeps its own running avg
1073+ loss_avg = utils .reduce_tensor (loss .new ([loss_avg ]), args .world_size ).item ()
1074+ loss_now = utils .reduce_tensor (loss .new ([loss_now ]), args .world_size ).item ()
10741075 update_sample_count *= args .world_size
10751076
10761077 if utils .is_primary (args ):
10771078 _logger .info (
10781079 f'Train: { epoch } [{ update_idx :>4d} /{ updates_per_epoch } '
10791080 f'({ 100. * (update_idx + 1 ) / updates_per_epoch :>3.0f} %)] '
1080- f'Loss: { losses_m . val :#.3g} ({ losses_m . avg :#.3g} ) '
1081+ f'Loss: { loss_now :#.3g} ({ loss_avg :#.3g} ) '
10811082 f'Time: { update_time_m .val :.3f} s, { update_sample_count / update_time_m .val :>7.2f} /s '
10821083 f'({ update_time_m .avg :.3f} s, { update_sample_count / update_time_m .avg :>7.2f} /s) '
10831084 f'LR: { lr :.3e} '
@@ -1106,7 +1107,12 @@ def _backward(_loss):
11061107 if hasattr (optimizer , 'sync_lookahead' ):
11071108 optimizer .sync_lookahead ()
11081109
1109- return OrderedDict ([('loss' , losses_m .avg )])
1110+ loss_avg = losses_m .avg
1111+ if args .distributed :
1112+ # synchronize avg loss, each process keeps its own running avg
1113+ loss_avg = torch .tensor ([loss_avg ], device = device , dtype = torch .float32 )
1114+ loss_avg = utils .reduce_tensor (loss_avg , args .world_size ).item ()
1115+ return OrderedDict ([('loss' , loss_avg )])
11101116
11111117
11121118def validate (
0 commit comments