From 54467bea274060ba0808a0586eebcc1082fe460c Mon Sep 17 00:00:00 2001 From: Zhihao Lin <36994684+LZHgrla@users.noreply.github.com> Date: Thu, 2 Nov 2023 15:47:08 +0800 Subject: [PATCH] [Feature] Support "auto" fp16/bf16 for DeepSpeed (#195) * support deepspeed auto dtype * update ds configs --- xtuner/configs/deepspeed/deepspeed_zero2.json | 12 +++++------ .../deepspeed/deepspeed_zero2_offload.json | 9 ++++----- xtuner/configs/deepspeed/deepspeed_zero3.json | 10 ++++------ .../deepspeed/deepspeed_zero3_offload.json | 9 ++++----- xtuner/tools/train.py | 4 +++- xtuner/tools/utils.py | 20 +++++++++++++++++++ 6 files changed, 41 insertions(+), 23 deletions(-) diff --git a/xtuner/configs/deepspeed/deepspeed_zero2.json b/xtuner/configs/deepspeed/deepspeed_zero2.json index b9f1ff37a..cf1fa0add 100644 --- a/xtuner/configs/deepspeed/deepspeed_zero2.json +++ b/xtuner/configs/deepspeed/deepspeed_zero2.json @@ -3,16 +3,16 @@ "train_micro_batch_size_per_gpu": "auto", "gradient_clipping": "auto", "zero_allow_untested_optimizer": true, + "zero_force_ds_cpu_optimizer": false, "zero_optimization": { "stage": 2, - "contiguous_gradients": false, - "allgather_bucket_size": 1e8, - "reduce_bucket_size": 1e8, - "overlap_comm": true, - "reduce_scatter": true + "overlap_comm": true }, "fp16": { - "enabled": true, + "enabled": "auto", "initial_scale_power": 16 + }, + "bf16": { + "enabled": "auto" } } diff --git a/xtuner/configs/deepspeed/deepspeed_zero2_offload.json b/xtuner/configs/deepspeed/deepspeed_zero2_offload.json index e46d0ef66..7f3c0671c 100644 --- a/xtuner/configs/deepspeed/deepspeed_zero2_offload.json +++ b/xtuner/configs/deepspeed/deepspeed_zero2_offload.json @@ -6,18 +6,17 @@ "zero_force_ds_cpu_optimizer": false, "zero_optimization": { "stage": 2, - "contiguous_gradients": false, - "allgather_bucket_size": 1e8, - "reduce_bucket_size": 1e8, "overlap_comm": true, - "reduce_scatter": true, "offload_optimizer": { "device": "cpu", "pin_memory": true } }, "fp16": { - "enabled": 
true, + "enabled": "auto", "initial_scale_power": 16 + }, + "bf16": { + "enabled": "auto" } } diff --git a/xtuner/configs/deepspeed/deepspeed_zero3.json b/xtuner/configs/deepspeed/deepspeed_zero3.json index c7a0e802b..1a2c666df 100644 --- a/xtuner/configs/deepspeed/deepspeed_zero3.json +++ b/xtuner/configs/deepspeed/deepspeed_zero3.json @@ -6,16 +6,14 @@ "zero_force_ds_cpu_optimizer": false, "zero_optimization": { "stage": 3, - "contiguous_gradients": false, - "allgather_bucket_size": 3e8, - "reduce_bucket_size": 3e8, "overlap_comm": true, - "reduce_scatter": true, "stage3_gather_16bit_weights_on_model_save": true }, - "low_cpu_mem_usage": false, "fp16": { - "enabled": true, + "enabled": "auto", "initial_scale_power": 16 + }, + "bf16": { + "enabled": "auto" } } diff --git a/xtuner/configs/deepspeed/deepspeed_zero3_offload.json b/xtuner/configs/deepspeed/deepspeed_zero3_offload.json index 64f41ce7a..3f3b9506b 100644 --- a/xtuner/configs/deepspeed/deepspeed_zero3_offload.json +++ b/xtuner/configs/deepspeed/deepspeed_zero3_offload.json @@ -6,11 +6,7 @@ "zero_force_ds_cpu_optimizer": false, "zero_optimization": { "stage": 3, - "contiguous_gradients": false, - "allgather_bucket_size": 3e8, - "reduce_bucket_size": 3e8, "overlap_comm": true, - "reduce_scatter": true, "offload_optimizer": { "device": "cpu", "pin_memory": true @@ -22,7 +18,10 @@ "stage3_gather_16bit_weights_on_model_save": true }, "fp16": { - "enabled": true, + "enabled": "auto", "initial_scale_power": 16 + }, + "bf16": { + "enabled": "auto" } } diff --git a/xtuner/tools/train.py b/xtuner/tools/train.py index a9b6eac98..c9449a41b 100644 --- a/xtuner/tools/train.py +++ b/xtuner/tools/train.py @@ -19,6 +19,7 @@ from xtuner.model.modules import dispatch_modules from xtuner.model.utils import LoadWoInit, find_all_linear_names, traverse_dict from xtuner.registry import BUILDER, MAP_FUNC +from xtuner.tools.utils import auto_dtype_of_deepspeed_config def parse_args(): @@ -222,9 +223,10 @@ def main(): 
def auto_dtype_of_deepspeed_config(ds_config):
    """Resolve ``"enabled": "auto"`` dtype placeholders in a DeepSpeed config.

    Mutates ``ds_config`` in place and returns it. An ``fp16`` section with
    ``enabled == 'auto'`` is resolved to ``torch.cuda.is_available()``; a
    ``bf16`` section with ``enabled == 'auto'`` is resolved to whether the
    current CUDA device supports bfloat16. If both sections end up enabled,
    bf16 is preferred and fp16 is switched off (DeepSpeed rejects both being
    on simultaneously).

    Args:
        ds_config (dict): Parsed DeepSpeed JSON config, possibly containing
            ``fp16`` / ``bf16`` sections with ``"enabled": "auto"``.

    Returns:
        dict: The same ``ds_config`` object, with dtypes resolved.
    """
    fp16 = ds_config.get('fp16')
    bf16 = ds_config.get('bf16')
    if fp16 is not None and fp16.get('enabled') == 'auto':
        fp16['enabled'] = torch.cuda.is_available()
    if bf16 is not None and bf16.get('enabled') == 'auto':
        # Guard with is_available(): on CPU-only builds,
        # torch.cuda.is_bf16_supported() can raise on older torch versions.
        bf16['enabled'] = (torch.cuda.is_available()
                           and torch.cuda.is_bf16_supported())
    # Use .get so a section lacking the 'enabled' key cannot raise KeyError.
    if (fp16 is not None and bf16 is not None
            and fp16.get('enabled') is True and bf16.get('enabled') is True):
        # Both cannot be enabled at once; prefer bf16 on supporting hardware.
        fp16['enabled'] = False
    return ds_config