[WIP] Support FSDP #358

Draft · wants to merge 37 commits into base: main
Changes from 1 commit
Commits (37)
67a0e13
resolve conflicts
mehdidc Jan 31, 2023
173cba4
show before fsdp memory usage
mehdidc Jan 6, 2023
a45acae
add ddp again
mehdidc Jan 6, 2023
fa80396
resolve conflicts
mehdidc Jan 31, 2023
9f967b7
resolve conflicts
mehdidc Jan 31, 2023
1832e13
resolve conflicts
mehdidc Jan 31, 2023
9d5369e
minor
mehdidc Jan 7, 2023
08016d0
fix logit scale and eval issues on FSDP
mehdidc Jan 7, 2023
8820831
support cpu offload
mehdidc Jan 7, 2023
188bc9c
wrap residual blocks with FSDP
mehdidc Jan 7, 2023
2782ab1
add forward trick to CustomCLIP
mehdidc Jan 8, 2023
afd8ef3
test_training_clip_with_jit test error
mehdidc Jan 31, 2023
6627268
select layers to wrap in FSDP and grad checkpointing
mehdidc Jan 31, 2023
fd42631
support unlocking
mehdidc Feb 4, 2023
4f65c85
fix hang after epoch finish
mehdidc Feb 18, 2023
3bada34
use `use_orig_params=True` (thanks to @nkflash) to use original param…
mehdidc Feb 19, 2023
f495986
fix distill
mehdidc Mar 7, 2023
397b8fc
fix FSDP optim state save/load so that we save the full optim state d…
mehdidc Mar 13, 2023
f2c72f8
offload to cpu when saving checkpoint to avoid OOM
mehdidc Mar 14, 2023
a69c0a7
- use the new ModuleWrapPolicy instead of transformer_auto_wrap_polic…
mehdidc May 17, 2023
62980cb
use ShardedGradScaler for fsdp, thanks to @nkflash
mehdidc May 17, 2023
9e47140
- FSDP printouts: use logging info.
mehdidc May 17, 2023
a8d644b
parametrize FSDP mixed precision
mehdidc May 17, 2023
16013c4
use a boolean param args.fsdp to match current args.horovod instead o…
mehdidc May 17, 2023
7735cac
replace last args.distributed_engine mention in the code
mehdidc May 17, 2023
f4165f7
fsdp log on rank zero only
mehdidc May 17, 2023
3aa42f4
minor
mehdidc May 17, 2023
5e167b2
minor
mehdidc May 17, 2023
5704ada
rank0 only and offload to cpu both true as recommended
mehdidc May 18, 2023
ffcf226
cli parameters description
mehdidc May 18, 2023
d3ab217
support CoCa models
mehdidc May 22, 2023
86799c2
fix optimizer resuming in FSDP and remove param/buffer precision
mehdidc May 24, 2023
0859c84
use original_model instead of model
mehdidc Nov 3, 2023
0a98da2
delete old import
mehdidc Nov 3, 2023
acd5af7
remove old zero shot classifier builder
mehdidc Nov 3, 2023
67bfcaa
fix again zero-shot eval
mehdidc Nov 3, 2023
4206d56
support sharded checkpointing for FSDP to handle large models, following
mehdidc Nov 4, 2023
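Several commit messages above (wrapping residual blocks, the switch to ModuleWrapPolicy, `use_orig_params=True`, ShardedGradScaler, parametrized mixed precision) describe the wrapping recipe this branch converges on. The snippet below is a minimal sketch of that recipe rather than the PR's exact code; the block class, the bfloat16 policy, and the wrap_with_fsdp helper name are illustrative assumptions, and it presumes PyTorch 2.0 or newer.

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

def wrap_with_fsdp(model, block_cls, device_id):
    # Each instance of block_cls (e.g. a residual attention block) becomes its own
    # FSDP unit; ModuleWrapPolicy replaces the older transformer_auto_wrap_policy.
    wrapper_kwargs = dict(
        auto_wrap_policy=ModuleWrapPolicy({block_cls}),
        mixed_precision=MixedPrecision(param_dtype=torch.bfloat16,
                                       reduce_dtype=torch.bfloat16),
        device_id=device_id,
        use_orig_params=True,    # keep original parameter names for the optimizer
        limit_all_gathers=True,
    )
    model = FSDP(model, **wrapper_kwargs)
    scaler = ShardedGradScaler()  # drop-in replacement for GradScaler under amp + FSDP
    return model, scaler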
fix hang after epoch finish
mehdidc committed Nov 3, 2023
commit 4f65c85ece4e5d93569e8126e8122e38c9039f7d
src/training/main.py: 56 changes (25 additions, 31 deletions)
@@ -41,6 +41,7 @@
from training.file_utils import pt_load, check_exists, start_sync_process, remote_sync



LATEST_CHECKPOINT_NAME = "epoch_latest.pt"


@@ -70,7 +71,6 @@ def get_latest_checkpoint(path: str, remote : bool):
return checkpoints[-1]
return None


def main(args):
args = parse_args(args)

@@ -323,6 +323,7 @@ def main(args):
if re.match(layer, name):
layers.add(module.__class__)
print("Wrapped layers", layers)

wrapper_kwargs = dict(
mixed_precision=mp,
limit_all_gathers=args.fsdp_limit_allgathers,
@@ -340,9 +341,6 @@ def main(args):
#model.visual = FSDP(model.visual, device_id=device)
#model.text_projection = FSDP(model.text_projection) ???
#model.ln_final = FSDP(model.ln_final, device_id=device)
model = FSDP(model, **wrapper_kwargs)
print(f"After FSTP parameter num: {sum(p.numel() for p in model.parameters())}")
print(f"After FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB")
if args.lock_image:
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
model.lock_image_tower(
@@ -352,6 +350,9 @@ def main(args):
model.lock_text_tower(
unlocked_layers=args.lock_text_unlocked_layers,
freeze_layer_norm=args.lock_text_freeze_layer_norm)
model = FSDP(model, **wrapper_kwargs)
print(f"After FSTP parameter num: {sum(p.numel() for p in model.parameters())}")
print(f"After FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB")
if args.grad_checkpointing:
#https://pytorch.org/blog/efficient-large-scale-training-with-pytorch/
layers_grad_checkpoint = set()
@@ -374,8 +375,6 @@ def main(args):
apply_activation_checkpointing(
model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
)


else:
print("--distrubted_engine should be either 'ddp or 'fsdp'")
sys.exit(1)
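For context, the apply_activation_checkpointing call in the hunk above depends on pieces defined outside the shown lines; the sketch below reconstructs them along the lines of the PyTorch blog post linked in the code comment. The non_reentrant_wrapper definition and the check_fn body are assumptions, since only the final call appears in the diff.

import functools
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

def enable_grad_checkpointing(model, layers_grad_checkpoint):
    # layers_grad_checkpoint: the set of block classes selected for checkpointing,
    # as built a few lines above in the diff.
    non_reentrant_wrapper = functools.partial(
        checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
    )
    check_fn = lambda submodule: isinstance(submodule, tuple(layers_grad_checkpoint))
    apply_activation_checkpointing(
        model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
    )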
@@ -409,7 +408,6 @@ def main(args):
hvd.broadcast_optimizer_state(optimizer, root_rank=0)

scaler = GradScaler() if args.precision == "amp" else None

# optionally resume from a checkpoint
start_epoch = 0
if args.resume is not None:
@@ -430,7 +428,6 @@ def main(args):
# loading a bare (model only) checkpoint for fine-tune or evaluation
model.load_state_dict(checkpoint)
logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})")

# initialize datasets
tokenizer = get_tokenizer(args.model)
data = get_data(
@@ -460,9 +457,8 @@ def main(args):
logging.error(
f'Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, const-cooldown.')
exit(1)

# determine if this worker should save logs and checkpoints. only do so if it is rank == 0
args.save_logs = args.logs and args.logs.lower() != 'none' and is_master(args)
args.save_logs = args.logs and args.logs.lower() != 'none' and (is_master(args) or args.distributed_engine == 'fsdp')
writer = None
if args.save_logs and args.tensorboard:
assert tensorboard is not None, "Please install tensorboard."
@@ -507,11 +503,9 @@ def main(args):
return

loss = create_loss(args)

for epoch in range(start_epoch, args.epochs):
if is_master(args):
logging.info(f'Start epoch {epoch}')

train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist_model, args, tb_writer=writer)
completed_epoch = epoch + 1

@@ -528,25 +522,25 @@ def main(args):
}
if scaler is not None:
checkpoint_dict["scaler"] = scaler.state_dict()

if completed_epoch == args.epochs or (
args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0
):
torch.save(
checkpoint_dict,
os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"),
)
if args.delete_previous_checkpoint:
previous_checkpoint = os.path.join(args.checkpoint_path, f"epoch_{completed_epoch - 1}.pt")
if os.path.exists(previous_checkpoint):
os.remove(previous_checkpoint)

if args.save_most_recent:
# try not to corrupt the latest checkpoint if save fails
tmp_save_path = os.path.join(args.checkpoint_path, "tmp.pt")
latest_save_path = os.path.join(args.checkpoint_path, LATEST_CHECKPOINT_NAME)
torch.save(checkpoint_dict, tmp_save_path)
os.replace(tmp_save_path, latest_save_path)
if is_master(args):
if completed_epoch == args.epochs or (
args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0
):
torch.save(
checkpoint_dict,
os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"),
)
if args.delete_previous_checkpoint:
previous_checkpoint = os.path.join(args.checkpoint_path, f"epoch_{completed_epoch - 1}.pt")
if os.path.exists(previous_checkpoint):
os.remove(previous_checkpoint)

if args.save_most_recent:
# try not to corrupt the latest checkpoint if save fails
tmp_save_path = os.path.join(args.checkpoint_path, "tmp.pt")
latest_save_path = os.path.join(args.checkpoint_path, LATEST_CHECKPOINT_NAME)
torch.save(checkpoint_dict, tmp_save_path)
os.replace(tmp_save_path, latest_save_path)

if args.wandb and is_master(args):
wandb.finish()
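The checkpointing change in this commit keeps the construction of checkpoint_dict on every rank and restricts only the file writes to the master process; later commits in the branch add rank-0-only full state dicts with CPU offload ("offload to cpu when saving checkpoint to avoid OOM", "rank0 only and offload to cpu both true as recommended"). Below is a minimal sketch of that combined pattern, assuming the PyTorch 2.0 FSDP state-dict API; the helper name, dictionary keys, and path handling are illustrative.

import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullOptimStateDictConfig,
    FullStateDictConfig,
    StateDictType,
)

def save_full_checkpoint(model, optimizer, path, is_master):
    # The state-dict calls are collectives, so every rank must reach them;
    # only the actual torch.save is restricted to rank 0.
    model_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    optim_cfg = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, model_cfg, optim_cfg):
        state_dict = model.state_dict()
        optim_state = FSDP.optim_state_dict(model, optimizer)
    if is_master:
        torch.save({"state_dict": state_dict, "optimizer": optim_state}, path)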