Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support "auto" fp16/bf16 for DeepSpeed #195

Merged
merged 2 commits into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions xtuner/configs/deepspeed/deepspeed_zero2.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
"train_micro_batch_size_per_gpu": "auto",
"gradient_clipping": "auto",
"zero_allow_untested_optimizer": true,
"zero_force_ds_cpu_optimizer": false,
"zero_optimization": {
"stage": 2,
"contiguous_gradients": false,
"allgather_bucket_size": 1e8,
"reduce_bucket_size": 1e8,
"overlap_comm": true,
"reduce_scatter": true
"overlap_comm": true
},
"fp16": {
"enabled": true,
"enabled": "auto",
"initial_scale_power": 16
},
"bf16": {
"enabled": "auto"
}
}
9 changes: 4 additions & 5 deletions xtuner/configs/deepspeed/deepspeed_zero2_offload.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,17 @@
"zero_force_ds_cpu_optimizer": false,
"zero_optimization": {
"stage": 2,
"contiguous_gradients": false,
"allgather_bucket_size": 1e8,
"reduce_bucket_size": 1e8,
"overlap_comm": true,
"reduce_scatter": true,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
}
},
"fp16": {
"enabled": true,
"enabled": "auto",
"initial_scale_power": 16
},
"bf16": {
"enabled": "auto"
}
}
10 changes: 4 additions & 6 deletions xtuner/configs/deepspeed/deepspeed_zero3.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,14 @@
"zero_force_ds_cpu_optimizer": false,
"zero_optimization": {
"stage": 3,
"contiguous_gradients": false,
"allgather_bucket_size": 3e8,
"reduce_bucket_size": 3e8,
"overlap_comm": true,
"reduce_scatter": true,
"stage3_gather_16bit_weights_on_model_save": true
},
"low_cpu_mem_usage": false,
"fp16": {
"enabled": true,
"enabled": "auto",
"initial_scale_power": 16
},
"bf16": {
"enabled": "auto"
}
}
9 changes: 4 additions & 5 deletions xtuner/configs/deepspeed/deepspeed_zero3_offload.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,7 @@
"zero_force_ds_cpu_optimizer": false,
"zero_optimization": {
"stage": 3,
"contiguous_gradients": false,
"allgather_bucket_size": 3e8,
"reduce_bucket_size": 3e8,
"overlap_comm": true,
"reduce_scatter": true,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
Expand All @@ -22,7 +18,10 @@
"stage3_gather_16bit_weights_on_model_save": true
},
"fp16": {
"enabled": true,
"enabled": "auto",
"initial_scale_power": 16
},
"bf16": {
"enabled": "auto"
}
}
4 changes: 3 additions & 1 deletion xtuner/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from xtuner.model.modules import dispatch_modules
from xtuner.model.utils import LoadWoInit, find_all_linear_names, traverse_dict
from xtuner.registry import BUILDER, MAP_FUNC
from xtuner.tools.utils import auto_dtype_of_deepspeed_config


def parse_args():
Expand Down Expand Up @@ -222,9 +223,10 @@ def main():
logger='current',
level=logging.WARNING)
grad_clip = mm_max_norm
ds_cfg = auto_dtype_of_deepspeed_config(ds_cfg)
strategy = dict(
type='DeepSpeedStrategy',
config=args.deepspeed,
config=ds_cfg,
gradient_accumulation_steps=grad_accum,
train_micro_batch_size_per_gpu=train_bs,
gradient_clipping=grad_clip)
Expand Down
20 changes: 20 additions & 0 deletions xtuner/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import copy
import re

import torch
from transformers import (PreTrainedTokenizerFast, StoppingCriteria,
StoppingCriteriaList)
from transformers.generation.streamers import BaseStreamer
Expand Down Expand Up @@ -138,3 +139,22 @@ def update_stop_criteria(base,
if answer_stop_word is not None:
answer.append(StopWordStoppingCriteria(tokenizer, answer_stop_word))
return command, answer


def auto_dtype_of_deepspeed_config(ds_config):
if ds_config.get('fp16') and not ds_config.get('bf16'):
if ds_config.get('fp16').get('enabled') == 'auto':
ds_config['fp16']['enabled'] = torch.cuda.is_available()
elif not ds_config.get('fp16') and ds_config.get('bf16'):
if ds_config.get('bf16').get('enabled') == 'auto':
ds_config['bf16']['enabled'] = torch.cuda.is_bf16_supported()
elif ds_config.get('fp16') and ds_config.get('bf16'):
if ds_config.get('fp16').get('enabled') == 'auto':
ds_config['fp16']['enabled'] = torch.cuda.is_available()
if ds_config.get('bf16').get('enabled') == 'auto':
ds_config['bf16']['enabled'] = torch.cuda.is_bf16_supported()
if (ds_config['fp16']['enabled'] is True
and ds_config['bf16']['enabled'] is True):
ds_config['fp16']['enabled'] = False
ds_config['bf16']['enabled'] = True
return ds_config
Loading