Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test train save load, with and without fsdp #148

Merged
merged 1 commit into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion open_lm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,11 @@ def save_checkpoint(
or (args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0)
):
for prefix in prefixes:
path = os.path.join(args.checkpoint_path, f"{prefix}{completed_epoch}.pt")
print(f"Saving {prefix}{completed_epoch} in {path}...")
torch.save(
prefixes[prefix],
os.path.join(args.checkpoint_path, f"{prefix}{completed_epoch}.pt"),
path,
)

if args.delete_previous_checkpoint:
Expand Down
29 changes: 21 additions & 8 deletions tests/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,26 @@
from torch import optim
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from open_lm.data import get_data
from open_lm.distributed import init_distributed_device
from open_lm.main import random_seed
from open_lm.model import create_model
from open_lm.data import get_data
from open_lm.scheduler import cosine_lr
from tests.utils import download_val_data


class MockTrainArgs:
def __init__(self, model):
def __init__(self, model, **kwargs):
data_path = download_val_data("shard_00000000.tar", "./tests/assets/")

self.model = model # part of model config
self.model_norm = "gain_only_layer_norm"
self.rotary_old = False
self.qk_norm = False
self.train_data = [
data_path,
]
self.log_logit_mean = False
self.device = 0
self.device = "cpu"
self.precision = "float32"
self.wd = 0.033
self.lr = 3e-3
Expand All @@ -36,6 +36,9 @@ def __init__(self, model):
self.rank = 0
self.local_rank = 0
self.log_every_n_steps = 1e8
self.save_logs = False
self.logs = None
self.name = "test_model_name"
self.dataset_type = "webdataset"
self.data_key = "json"
self.ffn_type = "swiglu"
Expand All @@ -46,6 +49,12 @@ def __init__(self, model):
self.seed = 1
self.vocab_size = 50432
self.seq_len = 300
self.epochs = 1
self.save_frequency = 1
self.checkpoint_path = "./tests/assets/checkpoints/"
self.resume = None
self.distributed = False
self.delete_previous_checkpoint = False
self.workers = 1
self.world_size = 1
self.val_data = None
Expand All @@ -65,6 +74,10 @@ def __init__(self, model):
self.target_mask_individual = None
self.ignore_parse_errors = False

for k, v in kwargs.items():
if hasattr(self, k):
setattr(self, k, v)


class MockDataArgs(object):
def __init__(self):
Expand Down Expand Up @@ -92,9 +105,9 @@ def __init__(self):
self.ignore_parse_errors = False


def create_train_fixtures(model="open_lm_11m", fsdp=False):
def create_train_fixtures(model="open_lm_11m", fsdp=False, **kwargs):
# Setup data, optimizer, and other basic settings
args = MockTrainArgs(model)
args = MockTrainArgs(model, **kwargs)

# only want to look at one batch
args.train_num_samples = args.batch_size
Expand All @@ -106,8 +119,8 @@ def create_train_fixtures(model="open_lm_11m", fsdp=False):
# create base models
random_seed()
if fsdp:
with torch.device("meta"):
model = create_model(args)
init_distributed_device(args)
model = create_model(args)
model = FSDP(model)
else:
model = create_model(args)
Expand Down
134 changes: 134 additions & 0 deletions tests/test_save_load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import pytest
import argparse
import torch
import os

import torch.multiprocessing as mp

from open_lm.utils.transformers.hf_model import OpenLMforCausalLM
from open_lm.utils.transformers.hf_config import OpenLMConfig
from open_lm.model import create_params
from open_lm.train import train_one_epoch
from open_lm.main import save_checkpoint, load_model
from open_lm.losses import CrossEntropyLossWithZLoss
from open_lm.distributed import is_using_distributed

from tests.shared import create_train_fixtures
from tests.utils import download_dl_test_data


@pytest.fixture(scope="module")
def tiny_args():
    """Provide the module-scoped settings used to build the tiny test model.

    Returns:
        argparse.Namespace: model name, vocabulary size, sequence length,
        number of training samples, batch size, and the architecture options
        that might be absent from the model config.
    """
    return argparse.Namespace(
        model="open_lm_test_tiny",
        vocab_size=16,
        sequence_length=16,
        train_num_samples=64,
        batch_size=4,
        # Model params that might not be in config:
        model_norm="default_layer_norm",
        qk_norm=False,
        positional_embedding_type="rotary",
        ffn_type="swiglu",
    )


def test_tiny_save_load(tiny_args, fsdp=False):
    """Check that saving and re-loading a checkpoint preserves the parameters.

    Two models are created from identical fixture settings. The first is
    trained with ``train_one_epoch`` and checkpointed; the checkpoint is then
    loaded into the second. The parameters must differ after training and
    match again after loading.

    Args:
        tiny_args: Namespace produced by the ``tiny_args`` fixture.
        fsdp: When true, request FSDP-wrapped models from
            ``create_train_fixtures`` (only honoured when CUDA is available).
    """
    scaler = None
    epoch = 0
    evaluation_metrics = None
    global_step = 0
    done_training = False
    download_dl_test_data()

    has_cuda = torch.cuda.is_available()
    # The ``fsdp`` argument used to be dropped (``fsdp=False`` was hard-coded
    # in both fixture calls), so the FSDP entry point never wrapped the model.
    # It is now forwarded when CUDA is available; CPU-only runs keep the
    # previous non-FSDP path.
    # NOTE(review): lift this guard if the targeted torch version supports
    # FSDP on CPU -- confirm against the CI environment.
    use_fsdp = fsdp and has_cuda

    override_params = dict(
        seq_len=tiny_args.sequence_length,
        vocab_size=tiny_args.vocab_size,
        train_num_samples=tiny_args.train_num_samples,
        batch_size=tiny_args.batch_size,
        checkpoint_path="./tests/assets/checkpoints/tiny_model/",
        device="cuda" if has_cuda else "cpu",
        dataset_type="synthetic",
        save_logs=True,
    )

    # ``model`` gets trained; ``model2`` stays untouched until the checkpoint
    # is loaded into it.
    args, model, data, optimizer, scheduler, _ = create_train_fixtures(
        tiny_args.model, fsdp=use_fsdp, **override_params
    )
    model = model.to(args.device)
    args2, model2, _, _, _, _ = create_train_fixtures(
        tiny_args.model, fsdp=use_fsdp, **override_params
    )
    model2 = model2.to(args2.device)
    os.makedirs(args.checkpoint_path, exist_ok=True)

    train_one_epoch(
        model,
        data,
        CrossEntropyLossWithZLoss(),
        epoch=epoch,
        step=global_step,
        optimizer=optimizer,
        scaler=scaler,
        scheduler=scheduler,
        total_steps=args.train_num_samples // args.batch_size,
        args=args,
    )
    epoch += 1
    threshold = 1e-6

    # Check that the tiny models diverged after training.
    unchanged = all(
        torch.allclose(p1, p2, atol=threshold)
        for (_, p1), (_, p2) in zip(model.named_parameters(), model2.named_parameters())
    )
    assert not unchanged

    args.distributed = is_using_distributed()
    # Save the checkpoint of the trained model.
    save_checkpoint(
        args,
        model,
        optimizer,
        scaler,
        epoch,
        evaluation_metrics,
        step=global_step,
        is_final_checkpoint=done_training,
        next_shard_per_source=None,
        samples_seen=None,
    )

    # Load the saved checkpoint into the untrained model. The path is derived
    # from the checkpoint directory and the epoch counter instead of repeating
    # them as a literal, so it cannot drift from what was just written.
    args.resume = os.path.join(args.checkpoint_path, f"epoch_{epoch}.pt")
    load_model(args, model2)

    # The loaded model must now match the trained model parameter by parameter.
    for (n1, p1), (_, p2) in zip(model.named_parameters(), model2.named_parameters()):
        assert torch.allclose(p1, p2, atol=threshold), f"parameter {n1} differs after save/load"


def _save_load_helper_fsdp(rank, world_size, tiny_args):
    """Worker entry point: run the save/load test inside a process group.

    Args:
        rank: Rank of this worker within the process group.
        world_size: Total number of workers in the process group.
        tiny_args: Namespace forwarded to ``test_tiny_save_load``.
    """
    # Initialize distributed training: NCCL when CUDA is available, Gloo otherwise.
    torch.distributed.init_process_group(
        backend="nccl" if torch.cuda.is_available() else "gloo",
        init_method="tcp://127.0.0.1:29501",
        rank=rank,
        world_size=world_size,
    )
    try:
        test_tiny_save_load(tiny_args, fsdp=True)
    finally:
        # Tear the group down even when the test body raises, so a failing
        # assertion does not leave the process group initialized.
        torch.distributed.destroy_process_group()


def test_tiny_save_load_fsdp(tiny_args):
    """Run the save/load check in a spawned worker with a process group set up.

    Args:
        tiny_args: Namespace produced by the ``tiny_args`` fixture.
    """
    num_workers = 1
    mp.spawn(
        _save_load_helper_fsdp,
        args=(num_workers, tiny_args),
        nprocs=num_workers,
        join=True,
    )


# Allow running this test module directly as a script.
if __name__ == "__main__":
    pytest.main([__file__])
Loading