Test checkpoint saving and loading, with and without fsdp (#148)
jmercat authored Dec 14, 2023
1 parent e016855 commit 813d501
Showing 3 changed files with 158 additions and 9 deletions.
4 changes: 3 additions & 1 deletion open_lm/main.py
@@ -217,9 +217,11 @@ def save_checkpoint(
or (args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0)
):
for prefix in prefixes:
path = os.path.join(args.checkpoint_path, f"{prefix}{completed_epoch}.pt")
print(f"Saving {prefix}{completed_epoch} in {path}...")
torch.save(
prefixes[prefix],
os.path.join(args.checkpoint_path, f"{prefix}{completed_epoch}.pt"),
path,
)

if args.delete_previous_checkpoint:
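For reference, a quick standalone illustration of the filename scheme the refactored loop builds and logs before calling torch.save. The epoch value below is hypothetical; the "epoch_" prefix and directory match the checkpoint the new test resumes from.

import os

checkpoint_path = "./tests/assets/checkpoints/tiny_model/"
prefix, completed_epoch = "epoch_", 1
path = os.path.join(checkpoint_path, f"{prefix}{completed_epoch}.pt")
print(f"Saving {prefix}{completed_epoch} in {path}...")
# -> Saving epoch_1 in ./tests/assets/checkpoints/tiny_model/epoch_1.pt...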
29 changes: 21 additions & 8 deletions tests/shared.py
@@ -2,26 +2,26 @@
from torch import optim
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from open_lm.data import get_data
from open_lm.distributed import init_distributed_device
from open_lm.main import random_seed
from open_lm.model import create_model
from open_lm.data import get_data
from open_lm.scheduler import cosine_lr
from tests.utils import download_val_data


class MockTrainArgs:
def __init__(self, model):
def __init__(self, model, **kwargs):
data_path = download_val_data("shard_00000000.tar", "./tests/assets/")

self.model = model # part of model config
self.model_norm = "gain_only_layer_norm"
self.rotary_old = False
self.qk_norm = False
self.train_data = [
data_path,
]
self.log_logit_mean = False
self.device = 0
self.device = "cpu"
self.precision = "float32"
self.wd = 0.033
self.lr = 3e-3
@@ -36,6 +36,9 @@ def __init__(self, model):
self.rank = 0
self.local_rank = 0
self.log_every_n_steps = 1e8
self.save_logs = False
self.logs = None
self.name = "test_model_name"
self.dataset_type = "webdataset"
self.data_key = "json"
self.ffn_type = "swiglu"
@@ -46,6 +49,12 @@ def __init__(self, model):
self.seed = 1
self.vocab_size = 50432
self.seq_len = 300
self.epochs = 1
self.save_frequency = 1
self.checkpoint_path = "./tests/assets/checkpoints/"
self.resume = None
self.distributed = False
self.delete_previous_checkpoint = False
self.workers = 1
self.world_size = 1
self.val_data = None
@@ -65,6 +74,10 @@ def __init__(self, model):
self.target_mask_individual = None
self.ignore_parse_errors = False

for k, v in kwargs.items():
if hasattr(self, k):
setattr(self, k, v)


class MockDataArgs(object):
def __init__(self):
@@ -92,9 +105,9 @@ def __init__(self):
self.ignore_parse_errors = False


def create_train_fixtures(model="open_lm_11m", fsdp=False):
def create_train_fixtures(model="open_lm_11m", fsdp=False, **kwargs):
# Setup data, optimizer, and other basic settings
args = MockTrainArgs(model)
args = MockTrainArgs(model, **kwargs)

# only want to look at one batch
args.train_num_samples = args.batch_size
@@ -106,8 +119,8 @@ def create_train_fixtures(model="open_lm_11m", fsdp=False):
# create base models
random_seed()
if fsdp:
with torch.device("meta"):
model = create_model(args)
init_distributed_device(args)
model = create_model(args)
model = FSDP(model)
else:
model = create_model(args)
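The keyword loop added to MockTrainArgs applies an override only when the attribute already exists, so create_train_fixtures can replace defaults (device, checkpoint_path, and so on) without accidentally creating new fields. A minimal standalone sketch of the same idiom, with illustrative class and attribute names that are not part of the repository:

class Defaults:
    def __init__(self, **kwargs):
        self.device = "cpu"
        self.batch_size = 4
        # Apply overrides only for attributes defined above; unknown keys are ignored.
        for k, v in kwargs.items():
            if hasattr(self, k):
                setattr(self, k, v)

args = Defaults(device="cuda", num_heads=8)  # num_heads is not a known field, so it is dropped
print(args.device, args.batch_size)  # -> cuda 4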
134 changes: 134 additions & 0 deletions tests/test_save_load.py
@@ -0,0 +1,134 @@
import pytest
import argparse
import torch
import os

import torch.multiprocessing as mp

from open_lm.utils.transformers.hf_model import OpenLMforCausalLM
from open_lm.utils.transformers.hf_config import OpenLMConfig
from open_lm.model import create_params
from open_lm.train import train_one_epoch
from open_lm.main import save_checkpoint, load_model
from open_lm.losses import CrossEntropyLossWithZLoss
from open_lm.distributed import is_using_distributed

from tests.shared import create_train_fixtures
from tests.utils import download_dl_test_data


@pytest.fixture(scope="module")
def tiny_args():
    args = argparse.Namespace(
        **{
            "model": "open_lm_test_tiny",
            "vocab_size": 16,
            "sequence_length": 16,
            "train_num_samples": 64,
            "batch_size": 4,
            # Model params that might not be in config:
            "model_norm": "default_layer_norm",
            "qk_norm": False,
            "positional_embedding_type": "rotary",
            "ffn_type": "swiglu",
        }
    )
    return args


def test_tiny_save_load(tiny_args, fsdp=False):
"""
This test checks that the model can be saved and loaded without changing the parameters.
"""
scaler = None
epoch = 0
evaluation_metrics = None
global_step = 0
done_training = False
download_dl_test_data()
override_params = dict(
seq_len=tiny_args.sequence_length,
vocab_size=tiny_args.vocab_size,
train_num_samples=tiny_args.train_num_samples,
batch_size=tiny_args.batch_size,
checkpoint_path="./tests/assets/checkpoints/tiny_model/",
device="cpu" if not torch.cuda.is_available() else "cuda",
dataset_type="synthetic",
save_logs=True,
)

args, model, data, optimizer, scheduler, loss = create_train_fixtures(
tiny_args.model, fsdp=False, **override_params
)
model = model.to(args.device)
args2, model2, data2, optimizer2, scheduler2, loss2 = create_train_fixtures(
tiny_args.model, fsdp=False, **override_params
)
model2 = model2.to(args2.device)
os.makedirs(args.checkpoint_path, exist_ok=True)

# print("Training tiny model")
train_one_epoch(
model,
data,
CrossEntropyLossWithZLoss(),
epoch=epoch,
step=global_step,
optimizer=optimizer,
scaler=scaler,
scheduler=scheduler,
total_steps=args.train_num_samples // args.batch_size,
args=args,
)
epoch += 1
threshold = 1e-6
    # Check that the tiny models have diverged after training (model was trained, model2 was not).
    allclose = True
    for (n1, p1), (n2, p2) in zip(model.named_parameters(), model2.named_parameters()):
        allclose = allclose and torch.allclose(p1, p2, atol=threshold)
    assert not allclose

    args.distributed = is_using_distributed()
    # print("Saving tiny model")
    # Saving checkpoints.
    save_checkpoint(
        args,
        model,
        optimizer,
        scaler,
        epoch,
        evaluation_metrics,
        step=global_step,
        is_final_checkpoint=done_training,
        next_shard_per_source=None,
        samples_seen=None,
    )

    # Loading saved tiny model
    args.resume = "./tests/assets/checkpoints/tiny_model/epoch_1.pt"
    load_model(args, model2)

    # Checking that loaded tiny model is the same as the original tiny model
    for (n1, p1), (n2, p2) in zip(model.named_parameters(), model2.named_parameters()):
        assert torch.allclose(p1, p2, atol=threshold)


def _save_load_helper_fsdp(rank, world_size, tiny_args):
    # Initialize distributed training
    torch.distributed.init_process_group(
        backend="nccl" if torch.cuda.is_available() else "gloo",
        init_method="tcp://127.0.0.1:29501",
        rank=rank,
        world_size=world_size,
    )
    test_tiny_save_load(tiny_args, fsdp=True)
    torch.distributed.destroy_process_group()


def test_tiny_save_load_fsdp(tiny_args):
    world_size = 1
    mp.spawn(_save_load_helper_fsdp, args=(world_size, tiny_args), nprocs=world_size, join=True)


if __name__ == "__main__":
    pytest.main([__file__])
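
The round-trip property these tests exercise is that saving a checkpoint and loading it back reproduces every parameter. A self-contained sketch of that check in plain PyTorch, without the open_lm helpers (file path and tolerance are arbitrary):

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
torch.save({"state_dict": model.state_dict()}, "/tmp/roundtrip.pt")

model2 = nn.Linear(4, 4)  # freshly initialized, so its weights differ from model's
model2.load_state_dict(torch.load("/tmp/roundtrip.pt")["state_dict"])

for (n1, p1), (n2, p2) in zip(model.named_parameters(), model2.named_parameters()):
    assert n1 == n2 and torch.allclose(p1, p2, atol=1e-6)

The new tests themselves would typically be run with something like pytest tests/test_save_load.py.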
