From 987b3e7adfa68b9d2ad73c553bde583ae5e77333 Mon Sep 17 00:00:00 2001
From: Sun Quan
Date: Mon, 28 Nov 2022 11:35:13 +0800
Subject: [PATCH] Add DeepSpeed ZeRO stage-1 support

Signed-off-by: Sun Quan
---
 src/open_clip/transformer.py |   6 +-
 src/training/distributed.py  |  66 ++++++++++++++++++-
 src/training/main.py         | 122 +++++++++++++++++++++++------------
 src/training/params.py       |  19 +++++-
 src/training/train.py        |   9 ++-
 5 files changed, 177 insertions(+), 45 deletions(-)

diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py
index a36fa5f5d..22762a9e4 100644
--- a/src/open_clip/transformer.py
+++ b/src/open_clip/transformer.py
@@ -1,3 +1,4 @@
+import os
 from collections import OrderedDict
 import math
 from typing import Callable, Optional, Sequence
@@ -5,7 +6,10 @@
 import torch
 from torch import nn
 from torch.nn import functional as F
-from torch.utils.checkpoint import checkpoint
+if os.getenv('ENV_TYPE') == 'deepspeed':
+    from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
+else:
+    from torch.utils.checkpoint import checkpoint
 
 from .utils import to_2tuple
 
diff --git a/src/training/distributed.py b/src/training/distributed.py
index 27bdd9b11..bbf118206 100644
--- a/src/training/distributed.py
+++ b/src/training/distributed.py
@@ -1,5 +1,5 @@
 import os
-
+import json
 import torch
 
 try:
@@ -111,3 +111,67 @@ def init_distributed_device(args):
     args.device = device
     device = torch.device(device)
     return device
+
+
+def create_deepspeed_config(args):
+    args.deepspeed_config = os.path.join(os.getcwd(), "deepspeed_config.json")
+    _, _, world_size = world_info_from_env()
+    with open(args.deepspeed_config, mode="w") as writer:
+        ds_config = {
+            "train_batch_size": args.batch_size * world_size,
+            "train_micro_batch_size_per_gpu": args.batch_size,
+            "steps_per_print": 1000,
+            "optimizer": {
+                "type": "Adam",
+                "adam_w_mode": True,
+                "params": {
+                    "bias_correction": True,
+                    "betas": [
+                        args.beta1,
+                        args.beta2
+                    ],
+                    "eps": args.eps
+                }
+            },
+            "fp16": {
+                "enabled": True,
+                "loss_scale": 0,
+                "initial_scale_power": 16,
+                "loss_scale_window": 1000,
+                "hysteresis": 2,
+                "min_loss_scale": 1
+            },
+            # "bf16": {
+            #     "enabled": True
+            # },
+            "amp": {
+                "enabled": False,
+                "opt_level": "O2"
+            },
+            "flops_profiler": {
+                "enabled": True,
+                "profile_step": -1,
+                "module_depth": -1,
+                "top_modules": 1,
+                "detailed": True,
+            },
+        }
+
+        if args.grad_clip_norm is not None:
+            ds_config.update({'gradient_clipping': args.grad_clip_norm})
+
+        if args.zero_stage == 1:
+            ds_config.update(
+                {
+                    "zero_optimization": {
+                        "stage": 1,
+                        "reduce_bucket_size": 5e8,
+                        # "offload_optimizer": {
+                        #     "device": "cpu"
+                        # }
+                    }
+                }
+            )
+        elif args.zero_stage > 1:
+            raise NotImplementedError("only ZeRO stage 1 is implemented")
+
+        writer.write(json.dumps(ds_config, indent=2))
\ No newline at end of file
diff --git a/src/training/main.py b/src/training/main.py
index b0e56de27..91891c109 100644
--- a/src/training/main.py
+++ b/src/training/main.py
@@ -26,7 +26,7 @@
 from open_clip import create_model_and_transforms, trace_model, get_tokenizer
 from training.data import get_data
-from training.distributed import is_master, init_distributed_device, world_info_from_env
+from training.distributed import is_master, init_distributed_device, world_info_from_env, create_deepspeed_config
 from training.logger import setup_logging
 from training.params import parse_args
 from training.scheduler import cosine_lr
@@ -40,7 +40,7 @@ def random_seed(seed=42, rank=0):
 
 
 def main(args):
-    args = parse_args(args)
+    args, ds_init = parse_args(args)
 
     if torch.cuda.is_available():
         # This enables tf32 on Ampere GPUs which is only 8% slower than
@@ -87,6 +87,10 @@ def main(args):
     # fully initialize distributed device environment
     device = init_distributed_device(args)
 
+    # create deepspeed config
+    if ds_init is not None:
+        create_deepspeed_config(args)
+
     args.wandb = 'wandb' in args.report_to or 'all' in args.report_to
     args.tensorboard = 'tensorboard' in args.report_to or 'all' in args.report_to
     if is_master(args):
@@ -163,11 +167,12 @@ def main(args):
     if args.distributed and not args.horovod:
         if args.use_bn_sync:
             model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
-        ddp_args = {}
-        if args.ddp_static_graph:
-            # this doesn't exist in older PyTorch, arg only added if enabled
-            ddp_args['static_graph'] = True
-        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args)
+        if not args.enable_deepspeed:
+            ddp_args = {}
+            if args.ddp_static_graph:
+                # this doesn't exist in older PyTorch, arg only added if enabled
+                ddp_args['static_graph'] = True
+            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args)
 
     # create optimizer and scaler
     optimizer = None
@@ -183,45 +188,74 @@ def main(args):
         gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad]
         rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad]
 
-        optimizer = optim.AdamW(
-            [
-                {"params": gain_or_bias_params, "weight_decay": 0.},
-                {"params": rest_params, "weight_decay": args.wd},
-            ],
-            lr=args.lr,
-            betas=(args.beta1, args.beta2),
-            eps=args.eps,
-        )
-        if args.horovod:
-            optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
-            hvd.broadcast_parameters(model.state_dict(), root_rank=0)
-            hvd.broadcast_optimizer_state(optimizer, root_rank=0)
+        optimizer_params = [
+            {"params": gain_or_bias_params, "weight_decay": 0.},
+            {"params": rest_params, "weight_decay": args.wd},
+        ]
+
+        if args.enable_deepspeed:
+            scaler = None
+            model, optimizer, _, _ = ds_init(
+                args=args,
+                model=model,
+                model_parameters=optimizer_params,
+                dist_init_required=not args.distributed,
+            )
+        else:
+            optimizer = optim.AdamW(
+                optimizer_params,
+                lr=args.lr,
+                betas=(args.beta1, args.beta2),
+                eps=args.eps,
+            )
+            if args.horovod:
+                optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
+                hvd.broadcast_parameters(model.state_dict(), root_rank=0)
+                hvd.broadcast_optimizer_state(optimizer, root_rank=0)
 
-        scaler = GradScaler() if args.precision == "amp" else None
+            scaler = GradScaler() if args.precision == "amp" else None
 
     # optionally resume from a checkpoint
     start_epoch = 0
     if args.resume is not None:
-        if os.path.isfile(args.resume):
-            checkpoint = torch.load(args.resume, map_location='cpu')
-            if 'epoch' in checkpoint:
-                # resuming a train checkpoint w/ epoch and optimizer state
-                start_epoch = checkpoint["epoch"]
-                sd = checkpoint["state_dict"]
-                if not args.distributed and next(iter(sd.items()))[0].startswith('module'):
-                    sd = {k[len('module.'):]: v for k, v in sd.items()}
-                model.load_state_dict(sd)
-                if optimizer is not None:
-                    optimizer.load_state_dict(checkpoint["optimizer"])
-                if scaler is not None and 'scaler' in checkpoint:
-                    scaler.load_state_dict(checkpoint['scaler'])
-                logging.info(f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})")
+        if args.enable_deepspeed:
+            if os.path.exists(args.resume):
+                import glob
+                all_checkpoints = glob.glob(os.path.join(args.resume, 'epoch_*'))
+                latest_ckpt = -1
+                for ckpt in all_checkpoints:
+                    t = os.path.basename(ckpt).split('_')[1]
+                    if t.isdigit():
+                        latest_ckpt = max(int(t), latest_ckpt)
+                if latest_ckpt >= 0:
+                    start_epoch = latest_ckpt
+                    _, client_states = model.load_checkpoint(args.resume, tag='epoch_%d' % latest_ckpt)
+                    logging.info(f"=> resuming checkpoint '{args.resume}' (epoch {latest_ckpt})")
+                else:
+                    logging.info("=> no checkpoint found at '{}'".format(args.resume))
             else:
-                # loading a bare (model only) checkpoint for fine-tune or evaluation
-                model.load_state_dict(checkpoint)
-                logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})")
+                logging.info("=> resume directory '{}' does not exist".format(args.resume))
         else:
-            logging.info("=> no checkpoint found at '{}'".format(args.resume))
+            if os.path.isfile(args.resume):
+                checkpoint = torch.load(args.resume, map_location='cpu')
+                if 'epoch' in checkpoint:
+                    # resuming a train checkpoint w/ epoch and optimizer state
+                    start_epoch = checkpoint["epoch"]
+                    sd = checkpoint["state_dict"]
+                    if not args.distributed and next(iter(sd.items()))[0].startswith('module'):
+                        sd = {k[len('module.'):]: v for k, v in sd.items()}
+                    model.load_state_dict(sd)
+                    if optimizer is not None:
+                        optimizer.load_state_dict(checkpoint["optimizer"])
+                    if scaler is not None and 'scaler' in checkpoint:
+                        scaler.load_state_dict(checkpoint['scaler'])
+                    logging.info(f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})")
+                else:
+                    # loading a bare (model only) checkpoint for fine-tune or evaluation
+                    model.load_state_dict(checkpoint)
+                    logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})")
+            else:
+                logging.info("=> no checkpoint found at '{}'".format(args.resume))
 
     # initialize datasets
     data = get_data(args, (preprocess_train, preprocess_val), epoch=start_epoch, tokenizer=get_tokenizer(args.model))
@@ -274,7 +308,15 @@ def main(args):
             evaluate(model, data, completed_epoch, args, writer)
 
         # Saving checkpoints.
-        if args.save_logs:
+        # do not wrap this in is_master(args) when using deepspeed: every rank must call save_checkpoint, or checkpoints are not saved
+        if args.logs and args.logs.lower() != 'none' and args.enable_deepspeed:
+            deepspeed_checkpoint_path = os.path.join(args.logs, args.name, "checkpoints")
+            if completed_epoch == args.epochs or (
+                args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0
+            ):
+                client_state = {'epoch': completed_epoch}
+                model.save_checkpoint(save_dir=deepspeed_checkpoint_path, tag="epoch_%d" % completed_epoch, client_state=client_state)
+        elif args.save_logs:
             checkpoint_dict = {
                 "epoch": completed_epoch,
                 "name": args.name,
diff --git a/src/training/params.py b/src/training/params.py
index da16cfd3c..d5019ae57 100644
--- a/src/training/params.py
+++ b/src/training/params.py
@@ -1,5 +1,5 @@
 import argparse
-
+import os
 
 def get_default_params(model_name):
     # Params from paper (https://arxiv.org/pdf/2103.00020.pdf)
@@ -317,6 +317,8 @@ def parse_args(args):
         default=100,
         help="Log every n steps to tensorboard/console/wandb.",
     )
+    parser.add_argument('--enable-deepspeed', action='store_true', default=False, help="Train with DeepSpeed instead of DDP.")
+    parser.add_argument('--zero-stage', type=int, default=1, help='ZeRO optimization stage (only stage 1 is implemented)')
 
     args = parser.parse_args(args)
 
@@ -327,4 +329,17 @@ def parse_args(args):
         if getattr(args, name) is None:
             setattr(args, name, val)
 
-    return args
+    if args.enable_deepspeed:
+        try:
+            import deepspeed
+            os.environ['ENV_TYPE'] = "deepspeed"
+            parser = deepspeed.add_config_arguments(parser)
+            ds_init = deepspeed.initialize
+        except ImportError:
+            print("Please install deepspeed: pip install deepspeed")
+            exit(1)
+    else:
+        os.environ['ENV_TYPE'] = "pytorch"
+        ds_init = None
+
+    return args, ds_init
diff --git a/src/training/train.py b/src/training/train.py
index c3b953a4a..981437c3f 100644
--- a/src/training/train.py
+++ b/src/training/train.py
@@ -78,7 +78,11 @@ def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None):
         texts = texts.to(device=device, non_blocking=True)
 
         data_time_m.update(time.time() - end)
-        optimizer.zero_grad()
+        if args.enable_deepspeed:
+            model.zero_grad()
+            model.micro_steps = 0
+        else:
+            optimizer.zero_grad()
 
         with autocast():
             image_features, text_features, logit_scale = model(images, texts)
@@ -99,6 +103,9 @@ def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None):
                 torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
             scaler.step(optimizer)
             scaler.update()
+        elif args.enable_deepspeed:
+            model.backward(total_loss)
+            model.step()
         else:
             total_loss.backward()
             if args.grad_clip_norm is not None:
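---

Usage note (not part of the diff): with --enable-deepspeed, parse_args() returns
deepspeed.initialize as ds_init, main() wraps the model in a DeepSpeed engine, and
train_one_epoch() drives that engine via model.backward(total_loss) / model.step()
in place of the GradScaler/optimizer pair. The sketch below shows the same engine
contract in isolation; the toy model, file name, and config values are illustrative
only, not part of open_clip, and a CUDA device plus an installed deepspeed package
are assumed.

    # toy_engine.py -- launch with: deepspeed --num_gpus=1 toy_engine.py
    # (the deepspeed launcher injects --local_rank and the distributed env vars)
    import argparse
    import json

    import deepspeed
    import torch

    # a config in the same shape create_deepspeed_config() writes: fp16 + ZeRO stage 1
    ds_config = {
        "train_batch_size": 4,
        "train_micro_batch_size_per_gpu": 4,
        "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
        "fp16": {"enabled": True},
        "zero_optimization": {"stage": 1},
    }
    with open("deepspeed_config.json", "w") as f:
        json.dump(ds_config, f)

    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()
    args.deepspeed_config = "deepspeed_config.json"  # deepspeed.initialize reads this off args

    model = torch.nn.Linear(8, 2)
    # same call the patch reaches through ds_init(...); returns (engine, optimizer, dataloader, scheduler)
    model_engine, optimizer, _, _ = deepspeed.initialize(
        args=args, model=model, model_parameters=model.parameters()
    )

    x = torch.randn(4, 8, device=model_engine.device, dtype=torch.half)  # fp16 inputs
    loss = model_engine(x).float().pow(2).mean()
    model_engine.backward(loss)  # replaces loss.backward(); handles fp16 loss scaling
    model_engine.step()          # replaces optimizer.step(); applies the ZeRO stage-1 partitioned update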