diff --git a/open_lm/main.py b/open_lm/main.py
index b29f7ed5..f0b3ac41 100644
--- a/open_lm/main.py
+++ b/open_lm/main.py
@@ -215,9 +215,11 @@ def save_checkpoint(
             or (args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0)
         ):
             for prefix in prefixes:
+                path = os.path.join(args.checkpoint_path, f"{prefix}{completed_epoch}.pt")
+                print(f"Saving {prefix}{completed_epoch} in {path}...")
                 torch.save(
                     prefixes[prefix],
-                    os.path.join(args.checkpoint_path, f"{prefix}{completed_epoch}.pt"),
+                    path,
                 )
 
         if args.delete_previous_checkpoint:
diff --git a/tests/shared.py b/tests/shared.py
index fb7bdbff..e4b4621b 100644
--- a/tests/shared.py
+++ b/tests/shared.py
@@ -2,68 +2,77 @@
 from torch import optim
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
+from open_lm.data import get_data
+from open_lm.distributed import init_distributed_device
 from open_lm.main import random_seed
 from open_lm.model import create_model
-from open_lm.data import get_data
 from open_lm.scheduler import cosine_lr
 from tests.utils import download_val_data
 
 
 class MockTrainArgs:
-    def __init__(self, model):
+    def __init__(self, model, **kwargs):
         data_path = download_val_data("shard_00000000.tar", "./tests/assets/")
         self.model = model
         # part of model config
-        self.model_norm = "gain_only_layer_norm"
-        self.rotary_old = False
-        self.qk_norm = False
+        self.model_norm = kwargs.get("model_norm", "gain_only_layer_norm")
+        self.qk_norm = kwargs.get("qk_norm", False)
         self.train_data = [
             data_path,
         ]
-        self.log_logit_mean = False
-        self.device = 0
-        self.precision = "float32"
-        self.wd = 0.033
-        self.lr = 3e-3
-        self.beta1 = 0.9
-        self.beta2 = 0.95
-        self.eps = 1e-8
-        self.warmup = 2
-        self.skip_scheduler = False
-        self.accum_freq = 1
-        self.batch_size = 8
-        self.grad_clip_norm = 1.0
-        self.rank = 0
-        self.local_rank = 0
-        self.log_every_n_steps = 1e8
-        self.dataset_type = "webdataset"
-        self.data_key = "json"
-        self.ffn_type = "swiglu"
-        self.train_num_samples = 250000
-        self.train_data_mix_weights = None
-        self.train_data_upsampling_factors = None
-        self.disable_buffer = False
-        self.seed = 1
-        self.vocab_size = 50432
-        self.seq_len = 300
-        self.workers = 1
-        self.world_size = 1
-        self.val_data = None
-        self.lr_cooldown_end = 3e-5
-        self.force_min_lr = 0.0
-        self.scaler = None
-        self.accum_freq = 1
-        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
-        self.wandb = False
-        self.fsdp = False
-        self.fsdp_amp = False
-        self.positional_embedding_type = "rotary"
-        self.dist_backend = "nccl"
-        self.dist_url = "env://"
-        self.dataset_manifest = None
-        self.target_mask_left = None
-        self.target_mask_individual = None
-        self.ignore_parse_errors = False
+        self.log_logit_mean = kwargs.get("log_logit_mean", False)
+        self.device = kwargs.get("device", "cpu")
+        self.precision = kwargs.get("precision", "float32")
+        self.wd = kwargs.get("wd", 0.033)
+        self.lr = kwargs.get("lr", 3e-3)
+        self.beta1 = kwargs.get("beta1", 0.9)
+        self.beta2 = kwargs.get("beta2", 0.95)
+        self.eps = kwargs.get("eps", 1e-8)
+        self.warmup = kwargs.get("warmup", 2)
+        self.skip_scheduler = kwargs.get("skip_scheduler", False)
+        self.accum_freq = kwargs.get("accum_freq", 1)
+        self.batch_size = kwargs.get("batch_size", 8)
+        self.grad_clip_norm = kwargs.get("grad_clip_norm", 1.0)
+        self.rank = kwargs.get("rank", 0)
+        self.local_rank = kwargs.get("local_rank", 0)
+        self.log_every_n_steps = kwargs.get("log_every_n_steps", 1e8)
+        self.save_logs = kwargs.get("save_logs", True)
+        self.logs = kwargs.get("logs", None)
+        self.name = kwargs.get("name", "test_model_name")
kwargs.get("name", "test_model_name") + self.dataset_type = kwargs.get("dataset_type", "webdataset") + self.data_key = kwargs.get("data_key", "json") + self.ffn_type = kwargs.get("ffn_type", "swiglu") + self.train_num_samples = kwargs.get("train_num_samples", 250000) + self.train_data_mix_weights = kwargs.get("train_data_mix_weights", None) + self.train_data_upsampling_factors = kwargs.get("train_data_upsampling_factors", None) + self.disable_buffer = kwargs.get("disable_buffer", False) + self.seed = kwargs.get("seed", 1) + self.vocab_size = kwargs.get("vocab_size", 50432) + self.seq_len = kwargs.get("seq_len", 300) + self.epochs = kwargs.get("epochs", 1) + self.save_frequency = kwargs.get("save_frequency", 1) + self.checkpoint_path = kwargs.get("checkpoint_path", "./tests/assets/checkpoints/") + self.resume = kwargs.get("resume", None) + self.distributed = kwargs.get("distributed", False) + self.delete_previous_checkpoint = kwargs.get("delete_previous_checkpoint", False) + self.workers = kwargs.get("workers", 1) + self.world_size = kwargs.get("world_size", 1) + self.val_data = kwargs.get("val_data", None) + self.lr_cooldown_end = kwargs.get("lr_cooldown_end", 3e-5) + self.force_min_lr = kwargs.get("force_min_lr", 0.0) + self.scaler = kwargs.get("scaler", None) + self.accum_freq = kwargs.get("accum_freq", 1) + self.device = kwargs.get("device", "cuda:0" if torch.cuda.is_available() else "cpu") + self.wandb = kwargs.get("wandb", False) + self.fsdp = kwargs.get("fsdp", False) + self.fsdp_amp = kwargs.get("fsdp_amp", False) + self.positional_embedding_type = kwargs.get("positional_embedding_type", "rotary") + self.dist_backend = kwargs.get("dist_backend", "nccl") + self.dist_url = kwargs.get("dist_url", "env://") + self.dataset_manifest = kwargs.get("dataset_manifest", None) + self.target_mask_left = kwargs.get("target_mask_left", None) + self.target_mask_individual = kwargs.get("target_mask_individual", None) + self.ignore_parse_errors = kwargs.get("ignore_parse_errors", False) class MockDataArgs(object): @@ -92,9 +101,9 @@ def __init__(self): self.ignore_parse_errors = False -def create_train_fixtures(model="open_lm_11m", fsdp=False): +def create_train_fixtures(model="open_lm_11m", fsdp=False, **kwargs): # Setup data, optimizer, and other basic settings - args = MockTrainArgs(model) + args = MockTrainArgs(model, **kwargs) # only want to look at one batch args.train_num_samples = args.batch_size @@ -106,8 +115,8 @@ def create_train_fixtures(model="open_lm_11m", fsdp=False): # create base models random_seed() if fsdp: - with torch.device("meta"): - model = create_model(args) + init_distributed_device(args) + model = create_model(args) model = FSDP(model) else: model = create_model(args) diff --git a/tests/test_save_load.py b/tests/test_save_load.py new file mode 100644 index 00000000..ade85213 --- /dev/null +++ b/tests/test_save_load.py @@ -0,0 +1,133 @@ +import pytest +import argparse +import torch +import os + +import torch.multiprocessing as mp + +from open_lm.utils.transformers.hf_model import OpenLMforCausalLM +from open_lm.utils.transformers.hf_config import OpenLMConfig +from open_lm.model import create_params +from open_lm.train import train_one_epoch +from open_lm.main import save_checkpoint, load_model +from open_lm.losses import CrossEntropyLossWithZLoss +from open_lm.distributed import is_using_distributed + +from tests.shared import create_train_fixtures +from tests.utils import download_dl_test_data + + +@pytest.fixture(scope="module") +def tiny_args(): + args = 
+        **{
+            "model": "open_lm_test_tiny",
+            "vocab_size": 16,
+            "sequence_length": 16,
+            "train_num_samples": 64,
+            "batch_size": 4,
+            # Model params that might not be in config:
+            "model_norm": "default_layer_norm",
+            "qk_norm": False,
+            "positional_embedding_type": "rotary",
+            "ffn_type": "swiglu",
+        }
+    )
+    return args
+
+
+def test_tiny_save_load(tiny_args, fsdp=False):
+    """
+    This test checks that the model can be saved and loaded without changing the parameters.
+    """
+    scaler = None
+    epoch = 0
+    evaluation_metrics = None
+    global_step = 0
+    done_training = False
+    download_dl_test_data()
+    override_params = dict(
+        seq_len=tiny_args.sequence_length,
+        vocab_size=tiny_args.vocab_size,
+        train_num_samples=tiny_args.train_num_samples,
+        batch_size=tiny_args.batch_size,
+        checkpoint_path="./tests/assets/checkpoints/tiny_model/",
+        device="cpu" if not torch.cuda.is_available() else "cuda",
+        dataset_type="synthetic",
+    )
+
+    args, model, data, optimizer, scheduler, loss = create_train_fixtures(
+        tiny_args.model, fsdp=False, **override_params
+    )
+    model = model.to(args.device)
+    args2, model2, data2, optimizer2, scheduler2, loss2 = create_train_fixtures(
+        tiny_args.model, fsdp=False, **override_params
+    )
+    model2 = model2.to(args2.device)
+    os.makedirs(args.checkpoint_path, exist_ok=True)
+
+    # print("Training tiny model")
+    train_one_epoch(
+        model,
+        data,
+        CrossEntropyLossWithZLoss(),
+        epoch=epoch,
+        step=global_step,
+        optimizer=optimizer,
+        scaler=scaler,
+        scheduler=scheduler,
+        total_steps=args.train_num_samples // args.batch_size,
+        args=args,
+    )
+    epoch += 1
+    threshold = 1e-6
+    # Checking that tiny models diverged after training.
+    allclose = True
+    for (n1, p1), (n2, p2) in zip(model.named_parameters(), model2.named_parameters()):
+        allclose = allclose and torch.allclose(p1, p2, atol=threshold)
+    assert not allclose
+
+    args.distributed = is_using_distributed()
+    # print("Saving tiny model")
+    # Saving checkpoints.
+    save_checkpoint(
+        args,
+        model,
+        optimizer,
+        scaler,
+        epoch,
+        evaluation_metrics,
+        step=global_step,
+        is_final_checkpoint=done_training,
+        next_shard_per_source=None,
+        samples_seen=None,
+    )
+
+    # Loading saved tiny model
+    args.resume = "./tests/assets/checkpoints/tiny_model/epoch_1.pt"
+    load_model(args, model2)
+
+    # Checking that loaded tiny model is the same as the original tiny model
+    for (n1, p1), (n2, p2) in zip(model.named_parameters(), model2.named_parameters()):
+        assert torch.allclose(p1, p2, atol=threshold)
+
+
+def _save_load_helper_fsdp(rank, world_size, tiny_args):
+    # Initialize distributed training
+    torch.distributed.init_process_group(
+        backend="nccl" if torch.cuda.is_available() else "gloo",
+        init_method="tcp://127.0.0.1:29501",
+        rank=rank,
+        world_size=world_size,
+    )
+    test_tiny_save_load(tiny_args, fsdp=True)
+    torch.distributed.destroy_process_group()
+
+
+def test_tiny_save_load_fsdp(tiny_args):
+    world_size = 1
+    mp.spawn(_save_load_helper_fsdp, args=(world_size, tiny_args), nprocs=world_size, join=True)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
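
Note (not part of the diff): a minimal sketch of how the new **kwargs overrides in MockTrainArgs/create_train_fixtures are meant to be used from another test. The keyword names mirror the defaults added above, "open_lm_11m" is the existing default fixture model, and the concrete values and checkpoint path here are illustrative only.

# Illustrative usage sketch, assuming the repo's test assets are available:
# override a few MockTrainArgs defaults at fixture-creation time instead of
# mutating `args` after construction.
from tests.shared import create_train_fixtures

args, model, data, optimizer, scheduler, loss = create_train_fixtures(
    "open_lm_11m",             # existing default fixture model
    fsdp=False,
    batch_size=2,              # overrides the MockTrainArgs default of 8
    device="cpu",              # keep the sketch CPU-only
    checkpoint_path="./tests/assets/checkpoints/sketch/",  # hypothetical path
)
assert args.batch_size == 2 and args.device == "cpu"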