Test train save load, with and without fsdp
jmercat committed Dec 12, 2023
1 parent 73e7b37 commit 7086109
Showing 3 changed files with 198 additions and 54 deletions.
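For orientation, the new test boils down to a save/load round trip: train one model, check that it has diverged from an independently initialized copy, save a checkpoint, load it into the copy, and assert that the parameters now match. Below is a minimal, self-contained sketch of that final check using a plain nn.Linear stand-in and a hypothetical /tmp/ckpt.pt path; the actual test goes through open_lm's save_checkpoint and load_model, shown in the diff of tests/test_save_load.py further down.

    import torch
    from torch import nn

    # Stand-in modules; the real test builds open_lm models via create_model(args).
    model_a = nn.Linear(8, 8)
    model_b = nn.Linear(8, 8)  # independently initialized, so its weights differ

    # Save one model's weights, then load them into the other.
    torch.save({"state_dict": model_a.state_dict()}, "/tmp/ckpt.pt")
    checkpoint = torch.load("/tmp/ckpt.pt", map_location="cpu")
    model_b.load_state_dict(checkpoint["state_dict"])

    # After loading, every parameter should agree to within a small tolerance.
    for (_, p1), (_, p2) in zip(model_a.named_parameters(), model_b.named_parameters()):
        assert torch.allclose(p1, p2, atol=1e-6)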
4 changes: 3 additions & 1 deletion open_lm/main.py
@@ -215,9 +215,11 @@ def save_checkpoint(
or (args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0)
):
for prefix in prefixes:
path = os.path.join(args.checkpoint_path, f"{prefix}{completed_epoch}.pt")
print(f"Saving {prefix}{completed_epoch} in {path}...")
torch.save(
prefixes[prefix],
os.path.join(args.checkpoint_path, f"{prefix}{completed_epoch}.pt"),
path,
)

if args.delete_previous_checkpoint:
115 changes: 62 additions & 53 deletions tests/shared.py
@@ -2,68 +2,77 @@
from torch import optim
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from open_lm.data import get_data
from open_lm.distributed import init_distributed_device
from open_lm.main import random_seed
from open_lm.model import create_model
from open_lm.data import get_data
from open_lm.scheduler import cosine_lr
from tests.utils import download_val_data


class MockTrainArgs:
def __init__(self, model):
def __init__(self, model, **kwargs):
data_path = download_val_data("shard_00000000.tar", "./tests/assets/")

self.model = model # part of model config
self.model_norm = "gain_only_layer_norm"
self.rotary_old = False
self.qk_norm = False
self.model_norm = kwargs.get("model_norm", "gain_only_layer_norm")
self.qk_norm = kwargs.get("qk_norm", False)
self.train_data = [
data_path,
]
self.log_logit_mean = False
self.device = 0
self.precision = "float32"
self.wd = 0.033
self.lr = 3e-3
self.beta1 = 0.9
self.beta2 = 0.95
self.eps = 1e-8
self.warmup = 2
self.skip_scheduler = False
self.accum_freq = 1
self.batch_size = 8
self.grad_clip_norm = 1.0
self.rank = 0
self.local_rank = 0
self.log_every_n_steps = 1e8
self.dataset_type = "webdataset"
self.data_key = "json"
self.ffn_type = "swiglu"
self.train_num_samples = 250000
self.train_data_mix_weights = None
self.train_data_upsampling_factors = None
self.disable_buffer = False
self.seed = 1
self.vocab_size = 50432
self.seq_len = 300
self.workers = 1
self.world_size = 1
self.val_data = None
self.lr_cooldown_end = 3e-5
self.force_min_lr = 0.0
self.scaler = None
self.accum_freq = 1
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
self.wandb = False
self.fsdp = False
self.fsdp_amp = False
self.positional_embedding_type = "rotary"
self.dist_backend = "nccl"
self.dist_url = "env://"
self.dataset_manifest = None
self.target_mask_left = None
self.target_mask_individual = None
self.ignore_parse_errors = False
self.log_logit_mean = kwargs.get("log_logit_mean", False)
self.device = kwargs.get("device", "cpu")
self.precision = kwargs.get("precision", "float32")
self.wd = kwargs.get("wd", 0.033)
self.lr = kwargs.get("lr", 3e-3)
self.beta1 = kwargs.get("beta1", 0.9)
self.beta2 = kwargs.get("beta2", 0.95)
self.eps = kwargs.get("eps", 1e-8)
self.warmup = kwargs.get("warmup", 2)
self.skip_scheduler = kwargs.get("skip_scheduler", False)
self.accum_freq = kwargs.get("accum_freq", 1)
self.batch_size = kwargs.get("batch_size", 8)
self.grad_clip_norm = kwargs.get("grad_clip_norm", 1.0)
self.rank = kwargs.get("rank", 0)
self.local_rank = kwargs.get("local_rank", 0)
self.log_every_n_steps = kwargs.get("log_every_n_steps", 1e8)
self.save_logs = kwargs.get("save_logs", True)
self.logs = kwargs.get("logs", None)
self.name = kwargs.get("name", "test_model_name")
self.dataset_type = kwargs.get("dataset_type", "webdataset")
self.data_key = kwargs.get("data_key", "json")
self.ffn_type = kwargs.get("ffn_type", "swiglu")
self.train_num_samples = kwargs.get("train_num_samples", 250000)
self.train_data_mix_weights = kwargs.get("train_data_mix_weights", None)
self.train_data_upsampling_factors = kwargs.get("train_data_upsampling_factors", None)
self.disable_buffer = kwargs.get("disable_buffer", False)
self.seed = kwargs.get("seed", 1)
self.vocab_size = kwargs.get("vocab_size", 50432)
self.seq_len = kwargs.get("seq_len", 300)
self.epochs = kwargs.get("epochs", 1)
self.save_frequency = kwargs.get("save_frequency", 1)
self.checkpoint_path = kwargs.get("checkpoint_path", "./tests/assets/checkpoints/")
self.resume = kwargs.get("resume", None)
self.distributed = kwargs.get("distributed", False)
self.delete_previous_checkpoint = kwargs.get("delete_previous_checkpoint", False)
self.workers = kwargs.get("workers", 1)
self.world_size = kwargs.get("world_size", 1)
self.val_data = kwargs.get("val_data", None)
self.lr_cooldown_end = kwargs.get("lr_cooldown_end", 3e-5)
self.force_min_lr = kwargs.get("force_min_lr", 0.0)
self.scaler = kwargs.get("scaler", None)
self.accum_freq = kwargs.get("accum_freq", 1)
self.device = kwargs.get("device", "cuda:0" if torch.cuda.is_available() else "cpu")
self.wandb = kwargs.get("wandb", False)
self.fsdp = kwargs.get("fsdp", False)
self.fsdp_amp = kwargs.get("fsdp_amp", False)
self.positional_embedding_type = kwargs.get("positional_embedding_type", "rotary")
self.dist_backend = kwargs.get("dist_backend", "nccl")
self.dist_url = kwargs.get("dist_url", "env://")
self.dataset_manifest = kwargs.get("dataset_manifest", None)
self.target_mask_left = kwargs.get("target_mask_left", None)
self.target_mask_individual = kwargs.get("target_mask_individual", None)
self.ignore_parse_errors = kwargs.get("ignore_parse_errors", False)


class MockDataArgs(object):
@@ -92,9 +101,9 @@ def __init__(self):
self.ignore_parse_errors = False


def create_train_fixtures(model="open_lm_11m", fsdp=False):
def create_train_fixtures(model="open_lm_11m", fsdp=False, **kwargs):
# Setup data, optimizer, and other basic settings
args = MockTrainArgs(model)
args = MockTrainArgs(model, **kwargs)

# only want to look at one batch
args.train_num_samples = args.batch_size
@@ -106,8 +115,8 @@ def create_train_fixtures(model="open_lm_11m", fsdp=False):
# create base models
random_seed()
if fsdp:
with torch.device("meta"):
model = create_model(args)
init_distributed_device(args)
model = create_model(args)
model = FSDP(model)
else:
model = create_model(args)
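With the kwargs plumbing above, a test can override individual MockTrainArgs fields without touching the other defaults, which is exactly how the new save/load test passes its override_params. A sketch of the calling pattern (the override values here are illustrative, not taken from the commit):

    from tests.shared import create_train_fixtures

    # Any attribute that MockTrainArgs reads from kwargs can be overridden by keyword.
    args, model, data, optimizer, scheduler, loss = create_train_fixtures(
        "open_lm_11m",
        fsdp=False,
        batch_size=4,              # default is 8
        device="cpu",              # keep the fixture on CPU
        dataset_type="synthetic",  # as in the new test's override_params
    )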
133 changes: 133 additions & 0 deletions tests/test_save_load.py
@@ -0,0 +1,133 @@
import pytest
import argparse
import torch
import os

import torch.multiprocessing as mp

from open_lm.utils.transformers.hf_model import OpenLMforCausalLM
from open_lm.utils.transformers.hf_config import OpenLMConfig
from open_lm.model import create_params
from open_lm.train import train_one_epoch
from open_lm.main import save_checkpoint, load_model
from open_lm.losses import CrossEntropyLossWithZLoss
from open_lm.distributed import is_using_distributed

from tests.shared import create_train_fixtures
from tests.utils import download_dl_test_data


@pytest.fixture(scope="module")
def tiny_args():
args = argparse.Namespace(
**{
"model": "open_lm_test_tiny",
"vocab_size": 16,
"sequence_length": 16,
"train_num_samples": 64,
"batch_size": 4,
# Model params that might not be in config:
"model_norm": "default_layer_norm",
"qk_norm": False,
"positional_embedding_type": "rotary",
"ffn_type": "swiglu",
}
)
return args


def test_tiny_save_load(tiny_args, fsdp=False):
"""
This test checks that the model can be saved and loaded without changing the parameters.
"""
scaler = None
epoch = 0
evaluation_metrics = None
global_step = 0
done_training = False
download_dl_test_data()
override_params = dict(
seq_len=tiny_args.sequence_length,
vocab_size=tiny_args.vocab_size,
train_num_samples=tiny_args.train_num_samples,
batch_size=tiny_args.batch_size,
checkpoint_path="./tests/assets/checkpoints/tiny_model/",
device="cpu" if not torch.cuda.is_available() else "cuda",
dataset_type="synthetic",
)

args, model, data, optimizer, scheduler, loss = create_train_fixtures(
tiny_args.model, fsdp=False, **override_params
)
model = model.to(args.device)
args2, model2, data2, optimizer2, scheduler2, loss2 = create_train_fixtures(
tiny_args.model, fsdp=False, **override_params
)
model2 = model2.to(args2.device)
os.makedirs(args.checkpoint_path, exist_ok=True)

# print("Training tiny model")
train_one_epoch(
model,
data,
CrossEntropyLossWithZLoss(),
epoch=epoch,
step=global_step,
optimizer=optimizer,
scaler=scaler,
scheduler=scheduler,
total_steps=args.train_num_samples // args.batch_size,
args=args,
)
epoch += 1
threshold = 1e-6
# Checking that the tiny models diverged after training.
allclose = True
for (n1, p1), (n2, p2) in zip(model.named_parameters(), model2.named_parameters()):
allclose = allclose and torch.allclose(p1, p2, atol=threshold)
assert not allclose

args.distributed = is_using_distributed()
# print("Saving tiny model")
# Saving checkpoints.
save_checkpoint(
args,
model,
optimizer,
scaler,
epoch,
evaluation_metrics,
step=global_step,
is_final_checkpoint=done_training,
next_shard_per_source=None,
samples_seen=None,
)

# Loading saved tiny model
args.resume = "./tests/assets/checkpoints/tiny_model/epoch_1.pt"
load_model(args, model2)

# Checking that loaded tiny model is the same as the original tiny model
for (n1, p1), (n2, p2) in zip(model.named_parameters(), model2.named_parameters()):
assert torch.allclose(p1, p2, atol=threshold)


def _save_load_helper_fsdp(rank, world_size, tiny_args):
# Initialize distributed training
torch.distributed.init_process_group(
backend="nccl" if torch.cuda.is_available() else "gloo",
init_method="tcp://127.0.0.1:29501",
rank=rank,
world_size=world_size,
)
test_tiny_save_load(tiny_args, fsdp=True)
torch.distributed.destroy_process_group()


def test_tiny_save_load_fsdp(tiny_args):
world_size = 1
mp.spawn(_save_load_helper_fsdp, args=(world_size, tiny_args), nprocs=world_size, join=True)


if __name__ == "__main__":
pytest.main([__file__])
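Assuming the repository layout in this commit, the new tests can be invoked directly through pytest; the FSDP variant spawns its own single-rank process group, so no external launcher is needed. One way to run just this file from Python:

    import pytest

    # Run only the new save/load tests; -q keeps the output terse.
    pytest.main(["tests/test_save_load.py", "-q"])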
