
Commit 5dad728

add grad scale for optim_manager
1 parent 6670f5c commit 5dad728

3 files changed: +15 −6 lines changed


.github/workflows/build.yml

Lines changed: 4 additions & 3 deletions
@@ -6,14 +6,15 @@ on:
     branches:
       - 'dev'
       - 'main'
+  push:
+    branches:
+      - 'dev'

 jobs:
   build-archive-wheel:

     uses: OpenBMB/BMTrain/.github/workflows/build_whl.yml@main
-    secrets:
-      DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }}
-      DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
+    secrets: inherit

   publish:
     needs: build-archive-wheel
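
Note: secrets: inherit forwards all of the calling workflow's secrets to the reusable build_whl.yml workflow, so the DockerHub credentials no longer have to be listed one by one. The added push trigger additionally runs the build on direct pushes to the dev branch.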

bmtrain/optim/optim_manager.py

Lines changed: 5 additions & 1 deletion
@@ -52,6 +52,7 @@ def __init__(self,
                  loss_scale_steps : int = 1024,
                  min_loss_scale = 1,
                  max_loss_scale = float("inf"),
+                 grad_scale : Optional[int] = None,
             ):
         if loss_scale is not None:
             self.loss_scale = loss_scale
@@ -64,6 +65,9 @@ def __init__(self,
         self.loss_scale_steps = loss_scale_steps
         self.min_loss_scale = min_loss_scale
         self.max_loss_scale = max_loss_scale
+        if grad_scale is None:
+            grad_scale = config['zero_size']
+        self.grad_scale = grad_scale

         self.optimizers = []
         self.lr_schedulers = []
@@ -85,7 +89,7 @@ def add_optimizer(

     def scale_loss(self, loss : torch.Tensor) -> torch.Tensor:

-        return loss * (self.loss_scale / (config['world_size']//(config['tp_size']*config['pipe_size']))) # loss scale
+        return loss * ( self.loss_scale / self.grad_scale ) # loss scale

     def backward(self, loss : torch.Tensor):
         """

bmtrain/synchronize.py

Lines changed: 6 additions & 2 deletions
@@ -2,6 +2,7 @@
 from . import distributed, nccl
 from .global_var import config
 import warnings
+from typing import Optional

 def synchronize():
     """
@@ -24,14 +25,17 @@ def wait_loader():
     config['calc_stream'].record_event(config['load_event'])


-def sum_loss(loss : torch.Tensor):
+def sum_loss(loss : torch.Tensor, comm: Optional[nccl.Communicator] = None):
     """
     Sum the loss across all workers.

     This is a helper function to reduce the loss across all workers.
     """
+    if comm is None:
+        comm = config['comm']
     warnings.warn("bmtrain.sum_loss is deprecated and will be removed in later version. Use bmtrain.distributed.all_reduce instead.", DeprecationWarning)
-    return distributed.all_reduce(loss, "sum") / config['world_size']
+
+    return distributed.all_reduce(loss, "avg", comm)

 def gather_result(result: torch.Tensor):
     warnings.warn("bmtrain.gather_result is deprecated and will be removed in later version. Use bmtrain.distributed.all_gather instead.", DeprecationWarning)
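
For reference, a short sketch of the updated call. The explicit communicator argument below simply mirrors the default; the grounded changes are that comm falls back to config['comm'] and that the reduction now uses the "avg" op rather than summing and dividing by world_size.

    import torch
    import bmtrain as bmt
    from bmtrain.global_var import config

    bmt.init_distributed()
    loss = torch.tensor(1.5, device="cuda")

    # Omitting comm is equivalent to passing config['comm'] explicitly;
    # the call still emits a DeprecationWarning.
    avg_loss = bmt.sum_loss(loss)

    # Preferred replacement named in the warning:
    avg_loss = bmt.distributed.all_reduce(loss, "avg", config['comm'])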
