From 5fadff4e002dfe643ccc60a9bd1cb00ef6b9e7b0 Mon Sep 17 00:00:00 2001
From: Jean Mercat
Date: Mon, 11 Dec 2023 12:01:01 -0800
Subject: [PATCH] Test train save load, with and without fsdp

---
 open_lm/main.py         |   4 +-
 tests/shared.py         |  29 ++++++---
 tests/test_save_load.py | 134 ++++++++++++++++++++++++++++++++++++++++
 3 files changed, 158 insertions(+), 9 deletions(-)
 create mode 100644 tests/test_save_load.py

diff --git a/open_lm/main.py b/open_lm/main.py
index b29f7ed5..f0b3ac41 100644
--- a/open_lm/main.py
+++ b/open_lm/main.py
@@ -215,9 +215,11 @@ def save_checkpoint(
         or (args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0)
     ):
         for prefix in prefixes:
+            path = os.path.join(args.checkpoint_path, f"{prefix}{completed_epoch}.pt")
+            print(f"Saving {prefix}{completed_epoch} in {path}...")
             torch.save(
                 prefixes[prefix],
-                os.path.join(args.checkpoint_path, f"{prefix}{completed_epoch}.pt"),
+                path,
             )
 
         if args.delete_previous_checkpoint:

diff --git a/tests/shared.py b/tests/shared.py
index fb7bdbff..01eb000e 100644
--- a/tests/shared.py
+++ b/tests/shared.py
@@ -2,26 +2,26 @@
 from torch import optim
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
+from open_lm.data import get_data
+from open_lm.distributed import init_distributed_device
 from open_lm.main import random_seed
 from open_lm.model import create_model
-from open_lm.data import get_data
 from open_lm.scheduler import cosine_lr
 from tests.utils import download_val_data
 
 
 class MockTrainArgs:
-    def __init__(self, model):
+    def __init__(self, model, **kwargs):
         data_path = download_val_data("shard_00000000.tar", "./tests/assets/")
 
         self.model = model  # part of model config
         self.model_norm = "gain_only_layer_norm"
-        self.rotary_old = False
         self.qk_norm = False
         self.train_data = [
             data_path,
         ]
         self.log_logit_mean = False
-        self.device = 0
+        self.device = "cpu"
         self.precision = "float32"
         self.wd = 0.033
         self.lr = 3e-3
@@ -36,6 +36,9 @@ def __init__(self, model):
         self.rank = 0
         self.local_rank = 0
         self.log_every_n_steps = 1e8
+        self.save_logs = False
+        self.logs = None
+        self.name = "test_model_name"
         self.dataset_type = "webdataset"
         self.data_key = "json"
         self.ffn_type = "swiglu"
@@ -46,6 +49,12 @@ def __init__(self, model):
         self.seed = 1
         self.vocab_size = 50432
         self.seq_len = 300
+        self.epochs = 1
+        self.save_frequency = 1
+        self.checkpoint_path = "./tests/assets/checkpoints/"
+        self.resume = None
+        self.distributed = False
+        self.delete_previous_checkpoint = False
         self.workers = 1
         self.world_size = 1
         self.val_data = None
@@ -65,6 +74,10 @@ def __init__(self, model):
         self.target_mask_individual = None
         self.ignore_parse_errors = False
 
+        for k, v in kwargs.items():
+            if hasattr(self, k):
+                setattr(self, k, v)
+
 
 class MockDataArgs(object):
     def __init__(self):
@@ -92,9 +105,9 @@ def __init__(self):
         self.ignore_parse_errors = False
 
 
-def create_train_fixtures(model="open_lm_11m", fsdp=False):
+def create_train_fixtures(model="open_lm_11m", fsdp=False, **kwargs):
     # Setup data, optimizer, and other basic settings
-    args = MockTrainArgs(model)
+    args = MockTrainArgs(model, **kwargs)
 
     # only want to look at one batch
     args.train_num_samples = args.batch_size
@@ -106,8 +119,8 @@ def create_train_fixtures(model="open_lm_11m", fsdp=False):
     # create base models
     random_seed()
     if fsdp:
-        with torch.device("meta"):
-            model = create_model(args)
+        init_distributed_device(args)
+        model = create_model(args)
         model = FSDP(model)
     else:
         model = create_model(args)
diff --git a/tests/test_save_load.py b/tests/test_save_load.py
new file mode 100644
index 00000000..850d0666
--- /dev/null
+++ b/tests/test_save_load.py
@@ -0,0 +1,134 @@
+import pytest
+import argparse
+import torch
+import os
+
+import torch.multiprocessing as mp
+
+from open_lm.utils.transformers.hf_model import OpenLMforCausalLM
+from open_lm.utils.transformers.hf_config import OpenLMConfig
+from open_lm.model import create_params
+from open_lm.train import train_one_epoch
+from open_lm.main import save_checkpoint, load_model
+from open_lm.losses import CrossEntropyLossWithZLoss
+from open_lm.distributed import is_using_distributed
+
+from tests.shared import create_train_fixtures
+from tests.utils import download_dl_test_data
+
+
+@pytest.fixture(scope="module")
+def tiny_args():
+    args = argparse.Namespace(
+        **{
+            "model": "open_lm_test_tiny",
+            "vocab_size": 16,
+            "sequence_length": 16,
+            "train_num_samples": 64,
+            "batch_size": 4,
+            # Model params that might not be in config:
+            "model_norm": "default_layer_norm",
+            "qk_norm": False,
+            "positional_embedding_type": "rotary",
+            "ffn_type": "swiglu",
+        }
+    )
+    return args
+
+
+def test_tiny_save_load(tiny_args, fsdp=False):
+    """
+    This test checks that the model can be saved and loaded without changing the parameters.
+    """
+    scaler = None
+    epoch = 0
+    evaluation_metrics = None
+    global_step = 0
+    done_training = False
+    download_dl_test_data()
+    override_params = dict(
+        seq_len=tiny_args.sequence_length,
+        vocab_size=tiny_args.vocab_size,
+        train_num_samples=tiny_args.train_num_samples,
+        batch_size=tiny_args.batch_size,
+        checkpoint_path="./tests/assets/checkpoints/tiny_model/",
+        device="cuda" if torch.cuda.is_available() else "cpu",
+        dataset_type="synthetic",
+        save_logs=True,
+    )
+
+    args, model, data, optimizer, scheduler, loss = create_train_fixtures(
+        tiny_args.model, fsdp=fsdp, **override_params
+    )
+    model = model.to(args.device)
+    args2, model2, data2, optimizer2, scheduler2, loss2 = create_train_fixtures(
+        tiny_args.model, fsdp=fsdp, **override_params
+    )
+    model2 = model2.to(args2.device)
+    os.makedirs(args.checkpoint_path, exist_ok=True)
+
+    # Train the first tiny model for one epoch.
+    train_one_epoch(
+        model,
+        data,
+        CrossEntropyLossWithZLoss(),
+        epoch=epoch,
+        step=global_step,
+        optimizer=optimizer,
+        scaler=scaler,
+        scheduler=scheduler,
+        total_steps=args.train_num_samples // args.batch_size,
+        args=args,
+    )
+    epoch += 1
+    threshold = 1e-6
+    # Check that the two tiny models diverged after training.
+    allclose = True
+    for (n1, p1), (n2, p2) in zip(model.named_parameters(), model2.named_parameters()):
+        allclose = allclose and torch.allclose(p1, p2, atol=threshold)
+    assert not allclose
+
+    args.distributed = is_using_distributed()
+    # Save a checkpoint of the trained model; with save_frequency=1 this
+    # writes epoch_1.pt under args.checkpoint_path.
+    save_checkpoint(
+        args,
+        model,
+        optimizer,
+        scaler,
+        epoch,
+        evaluation_metrics,
+        step=global_step,
+        is_final_checkpoint=done_training,
+        next_shard_per_source=None,
+        samples_seen=None,
+    )
+
+    # Loading saved tiny model
+    args.resume = "./tests/assets/checkpoints/tiny_model/epoch_1.pt"
+    load_model(args, model2)
+
+    # Checking that loaded tiny model is the same as the original tiny model
+    for (n1, p1), (n2, p2) in zip(model.named_parameters(), model2.named_parameters()):
+        assert torch.allclose(p1, p2, atol=threshold)
+
+
+def _save_load_helper_fsdp(rank, world_size, tiny_args):
+    # Initialize distributed training
+    torch.distributed.init_process_group(
+        backend="nccl" if torch.cuda.is_available() else "gloo",
+        init_method="tcp://127.0.0.1:29501",
+        rank=rank,
+        world_size=world_size,
+    )
+    test_tiny_save_load(tiny_args, fsdp=True)
+    torch.distributed.destroy_process_group()
+
+
+def test_tiny_save_load_fsdp(tiny_args):
+    world_size = 1
+    mp.spawn(_save_load_helper_fsdp, args=(world_size, tiny_args), nprocs=world_size, join=True)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
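
Note: the **kwargs plumbing this patch adds to MockTrainArgs and
create_train_fixtures is what lets test_save_load.py override fixture defaults
per test. A minimal sketch of how the override behaves (the values and the
unknown key below are illustrative, not part of the patch):

    from tests.shared import create_train_fixtures

    # Keyword arguments that match an existing MockTrainArgs attribute override
    # its default; anything else is silently dropped by the hasattr() guard.
    args, model, data, optimizer, scheduler, loss = create_train_fixtures(
        "open_lm_11m",
        fsdp=False,
        batch_size=4,                   # overrides MockTrainArgs.batch_size
        checkpoint_path="/tmp/ckpts/",  # illustrative path
        not_a_real_field=123,           # hypothetical key: ignored by the guard
    )
    assert args.batch_size == 4
    assert not hasattr(args, "not_a_real_field")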