add deepspeed zero-stage-1 #264

Open: wants to merge 1 commit into base: main
6 changes: 5 additions & 1 deletion src/open_clip/transformer.py
@@ -1,11 +1,15 @@
import os
from collections import OrderedDict
import math
from typing import Callable, Optional, Sequence

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint
if os.getenv('ENV_TYPE') == 'deepspeed':
from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
else:
from torch.utils.checkpoint import checkpoint

from .utils import to_2tuple

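Note on the import switch above: DeepSpeed's activation-checkpointing checkpoint accepts the same positional form (a callable followed by its inputs) as torch.utils.checkpoint.checkpoint, which is why a drop-in import swap is enough and no call sites in this file change. A minimal illustration with placeholder names (block, hidden_states, attn_mask are not the file's actual identifiers):

    # Either implementation is invoked the same way; inputs are recomputed during the backward pass.
    out = checkpoint(block, hidden_states, attn_mask)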
66 changes: 65 additions & 1 deletion src/training/distributed.py
@@ -1,5 +1,5 @@
import os

import json
import torch

try:
@@ -111,3 +111,67 @@ def init_distributed_device(args):
args.device = device
device = torch.device(device)
return device

def create_deepspeed_config(args):
args.deepspeed_config = os.path.join(os.getcwd(), "deepspeed_config.json")
_, _, world_size = world_info_from_env()
with open(args.deepspeed_config, mode="w") as writer:
ds_config = {
"train_batch_size": args.batch_size * world_size,
"train_micro_batch_size_per_gpu": args.batch_size,
"steps_per_print": 1000,
"optimizer": {
"type": "Adam",
"adam_w_mode": True,
"params": {
"bias_correction": True,
"betas": [
args.beta1,
args.beta2
],
"eps": args.eps
}
},
"fp16": {
"enabled": True,
"loss_scale": 0,
"initial_scale_power": 16,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
# "bf16": {
# "enabled": True
# },
"amp": {
"enabled": False,
"opt_level": "O2"
},
"flops_profiler": {
"enabled": True,
"profile_step": -1,
"module_depth": -1,
"top_modules": 1,
"detailed": True,
},
}

if args.grad_clip_norm is not None:
ds_config.update({'gradient_clipping': args.grad_clip_norm})

if args.zero_stage == 1:
ds_config.update(
{
"zero_optimization": {
"stage": 1,
"reduce_bucket_size": 5e8,
# "offload_optimizer": {
# "device": "cpu"
# }
}
}
)
elif args.zero_stage > 1:
raise NotImplementedError("only ZeRO stage 1 is currently supported")

writer.write(json.dumps(ds_config, indent=2))
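For reference, DeepSpeed validates that the two batch-size fields written above agree with the world size and the number of gradient-accumulation steps (which defaults to 1 because the config does not set it). A small sketch of that invariant; the function name is illustrative:

    def expected_train_batch_size(micro_batch_per_gpu, world_size, grad_accum_steps=1):
        # DeepSpeed checks: train_batch_size == micro_batch_per_gpu * world_size * grad_accum_steps,
        # which matches how train_batch_size is derived from args.batch_size above.
        return micro_batch_per_gpu * world_size * grad_accum_steps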
122 changes: 82 additions & 40 deletions src/training/main.py
@@ -26,7 +26,7 @@

from open_clip import create_model_and_transforms, trace_model, get_tokenizer
from training.data import get_data
from training.distributed import is_master, init_distributed_device, world_info_from_env
from training.distributed import is_master, init_distributed_device, world_info_from_env, create_deepspeed_config
from training.logger import setup_logging
from training.params import parse_args
from training.scheduler import cosine_lr
Expand All @@ -40,7 +40,7 @@ def random_seed(seed=42, rank=0):


def main(args):
args = parse_args(args)
args, ds_init = parse_args(args)

if torch.cuda.is_available():
# This enables tf32 on Ampere GPUs which is only 8% slower than
@@ -87,6 +87,10 @@ def main(args):
# fully initialize distributed device environment
device = init_distributed_device(args)

# create deepspeed config
if ds_init is not None:
create_deepspeed_config(args)

args.wandb = 'wandb' in args.report_to or 'all' in args.report_to
args.tensorboard = 'tensorboard' in args.report_to or 'all' in args.report_to
if is_master(args):
@@ -163,11 +167,12 @@ def main(args):
if args.distributed and not args.horovod:
if args.use_bn_sync:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
ddp_args = {}
if args.ddp_static_graph:
# this doesn't exist in older PyTorch, arg only added if enabled
ddp_args['static_graph'] = True
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args)
if not args.enable_deepspeed:
ddp_args = {}
if args.ddp_static_graph:
# this doesn't exist in older PyTorch, arg only added if enabled
ddp_args['static_graph'] = True
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args)

# create optimizer and scaler
optimizer = None
@@ -183,45 +188,74 @@ def main(args):
gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad]
rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad]

optimizer = optim.AdamW(
[
{"params": gain_or_bias_params, "weight_decay": 0.},
{"params": rest_params, "weight_decay": args.wd},
],
lr=args.lr,
betas=(args.beta1, args.beta2),
eps=args.eps,
)
if args.horovod:
optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)
optimizer_params = [
{"params": gain_or_bias_params, "weight_decay": 0.},
{"params": rest_params, "weight_decay": args.wd},
]

if args.enable_deepspeed:
scaler = None
model, optimizer, _, _ = ds_init(
args=args,
model=model,
model_parameters=optimizer_params,
dist_init_required=not args.distributed,
)
else:
optimizer = optim.AdamW(
optimizer_params,
lr=args.lr,
betas=(args.beta1, args.beta2),
eps=args.eps,
)
if args.horovod:
optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)

scaler = GradScaler() if args.precision == "amp" else None
scaler = GradScaler() if args.precision == "amp" else None

# optionally resume from a checkpoint
start_epoch = 0
if args.resume is not None:
if os.path.isfile(args.resume):
checkpoint = torch.load(args.resume, map_location='cpu')
if 'epoch' in checkpoint:
# resuming a train checkpoint w/ epoch and optimizer state
start_epoch = checkpoint["epoch"]
sd = checkpoint["state_dict"]
if not args.distributed and next(iter(sd.items()))[0].startswith('module'):
sd = {k[len('module.'):]: v for k, v in sd.items()}
model.load_state_dict(sd)
if optimizer is not None:
optimizer.load_state_dict(checkpoint["optimizer"])
if scaler is not None and 'scaler' in checkpoint:
scaler.load_state_dict(checkpoint['scaler'])
logging.info(f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})")
if args.enable_deepspeed:
if os.path.exists(args.resume):
import glob
all_checkpoints = glob.glob(os.path.join(args.resume, 'epoch_*'))
latest_ckpt = -1
for ckpt in all_checkpoints:
t = ckpt.split('/')[-1].split('_')[1]
if t.isdigit():
latest_ckpt = max(int(t), latest_ckpt)
if latest_ckpt >= 0:
start_epoch = latest_ckpt
_, client_states = model.load_checkpoint(args.resume, tag='epoch_%d' % latest_ckpt)
logging.info(f"=> resuming checkpoint '{args.resume}' (epoch {latest_ckpt})")
else:
logging.info("=> no checkpoint found at '{}'".format(args.resume))
else:
# loading a bare (model only) checkpoint for fine-tune or evaluation
model.load_state_dict(checkpoint)
logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})")
logging.info("=> checkpoint path '{}' does not exist".format(args.resume))
else:
logging.info("=> no checkpoint found at '{}'".format(args.resume))
if os.path.isfile(args.resume):
checkpoint = torch.load(args.resume, map_location='cpu')
if 'epoch' in checkpoint:
# resuming a train checkpoint w/ epoch and optimizer state
start_epoch = checkpoint["epoch"]
sd = checkpoint["state_dict"]
if not args.distributed and next(iter(sd.items()))[0].startswith('module'):
sd = {k[len('module.'):]: v for k, v in sd.items()}
model.load_state_dict(sd)
if optimizer is not None:
optimizer.load_state_dict(checkpoint["optimizer"])
if scaler is not None and 'scaler' in checkpoint:
scaler.load_state_dict(checkpoint['scaler'])
logging.info(f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})")
else:
# loading a bare (model only) checkpoint for fine-tune or evaluation
model.load_state_dict(checkpoint)
logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})")
else:
logging.info("=> no checkpoint found at '{}'".format(args.resume))

# initialize datasets
data = get_data(args, (preprocess_train, preprocess_val), epoch=start_epoch, tokenizer=get_tokenizer(args.model))
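A note on the ds_init(...) call in the block above: ds_init is deepspeed.initialize, which returns a four-tuple of (engine, optimizer, training_dataloader, lr_scheduler); only the first two are kept. The engine wraps the model and takes over distributed synchronization, which is why the DistributedDataParallel branch is skipped when --enable-deepspeed is set. A sketch of the call shape with annotations (no new arguments beyond what the diff already passes):

    model, optimizer, _, _ = ds_init(
        args=args,                          # deepspeed.initialize reads args.deepspeed_config written earlier
        model=model,
        model_parameters=optimizer_params,  # the same weight-decay parameter groups as the AdamW path
        dist_init_required=not args.distributed,
    )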
@@ -274,7 +308,15 @@ def main(args):
evaluate(model, data, completed_epoch, args, writer)

# Saving checkpoints.
if args.save_logs:
# is_master(args) must not gate this branch when using DeepSpeed, otherwise checkpoints cannot be saved (save_checkpoint has to run on every rank)
if args.logs and args.logs.lower() != 'none' and args.enable_deepspeed:
deepspeed_checkpoint_path = os.path.join(args.logs, args.name, "checkpoints")
if completed_epoch == args.epochs or (
args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0
):
client_state = {'epoch': completed_epoch}
model.save_checkpoint(save_dir=deepspeed_checkpoint_path, tag="epoch_%s" % str(completed_epoch), client_state=client_state)
elif args.save_logs:
checkpoint_dict = {
"epoch": completed_epoch,
"name": args.name,
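For context, the DeepSpeed save path above pairs with the resume path earlier through the epoch_<N> tag; a hedged sketch of the engine API as used here (ckpt_dir is a placeholder for the checkpoints directory, and DeepSpeed also writes a 'latest' marker file under it by default):

    # Saving: every rank must call this; states are sharded under <ckpt_dir>/<tag>/.
    model.save_checkpoint(save_dir=ckpt_dir, tag="epoch_%d" % completed_epoch,
                          client_state={"epoch": completed_epoch})

    # Resuming: returns (load_path, client_state); the PR rebuilds start_epoch from the tag name instead.
    load_path, client_state = model.load_checkpoint(ckpt_dir, tag="epoch_%d" % latest_ckpt)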
19 changes: 17 additions & 2 deletions src/training/params.py
@@ -1,5 +1,5 @@
import argparse

import os

def get_default_params(model_name):
# Params from paper (https://arxiv.org/pdf/2103.00020.pdf)
@@ -317,6 +317,8 @@ def parse_args(args):
default=100,
help="Log every n steps to tensorboard/console/wandb.",
)
parser.add_argument('--enable-deepspeed', action='store_true', default=False)
parser.add_argument('--zero-stage', type=int, default=1, help='ZeRO optimization stage')


args = parser.parse_args(args)
@@ -327,4 +329,17 @@ def parse_args(args):
if getattr(args, name) is None:
setattr(args, name, val)

return args
if args.enable_deepspeed:
try:
import deepspeed
os.environ['ENV_TYPE'] = "deepspeed"
parser = deepspeed.add_config_arguments(parser)
ds_init = deepspeed.initialize
except ImportError:
print("DeepSpeed is required for --enable-deepspeed; install it with 'pip install deepspeed'")
exit(1)
else:
os.environ['ENV_TYPE'] = "pytorch"
ds_init = None

return args, ds_init
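Callers of parse_args must now unpack a pair; a minimal sketch of the new convention (the argument list is illustrative, and deepspeed is assumed to be installed when the flag is passed):

    args, ds_init = parse_args(["--enable-deepspeed", "--zero-stage", "1"])
    # ds_init is deepspeed.initialize when --enable-deepspeed is given, otherwise None;
    # ENV_TYPE is exported so transformer.py selects the matching checkpoint implementation.
    assert os.environ["ENV_TYPE"] == ("deepspeed" if args.enable_deepspeed else "pytorch")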
9 changes: 8 additions & 1 deletion src/training/train.py
@@ -78,7 +78,11 @@ def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_w
texts = texts.to(device=device, non_blocking=True)

data_time_m.update(time.time() - end)
optimizer.zero_grad()
if args.enable_deepspeed:
model.zero_grad()
model.micro_steps = 0
else:
optimizer.zero_grad()

with autocast():
image_features, text_features, logit_scale = model(images, texts)
@@ -99,6 +103,9 @@ def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_w
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
scaler.step(optimizer)
scaler.update()
elif args.enable_deepspeed:
model.backward(total_loss)
model.step()
else:
total_loss.backward()
if args.grad_clip_norm is not None:
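A compact view of the step logic introduced above (sketch only; the DeepSpeed engine reads fp16 loss scaling and gradient clipping from the generated JSON config, so neither GradScaler nor clip_grad_norm_ is used on this path):

    if args.enable_deepspeed:
        model.zero_grad()
        model.micro_steps = 0        # reset the engine's micro-step counter each iteration
        # ... forward pass and loss as usual ...
        model.backward(total_loss)   # replaces scaler.scale(total_loss).backward()
        model.step()                 # replaces scaler.step(optimizer); scaler.update()
    else:
        optimizer.zero_grad()
        # existing AMP / fp32 branches are unchanged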