Test train save load, with and without fsdp
jmercat committed Dec 11, 2023
1 parent 158124d commit e50f2ec
Showing 3 changed files with 130 additions and 3 deletions.
4 changes: 3 additions & 1 deletion open_lm/main.py
@@ -226,9 +226,11 @@ def save_checkpoint(
or (args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0)
):
for prefix in prefixes:
path = os.path.join(args.checkpoint_path, f"{prefix}{completed_epoch}.pt")
print(f"Saving {prefix} {completed_epoch} in {path}...")
torch.save(
prefixes[prefix],
os.path.join(args.checkpoint_path, f"{prefix}{completed_epoch}.pt"),
path,
)

if args.delete_previous_checkpoint:
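For reference, a minimal sketch of reading back a checkpoint written with this naming scheme; the directory, prefix, and epoch values are illustrative (they mirror the checkpoint_path and epoch_1.pt resume path used in the tests below), and map_location="cpu" is just a safe default:

import os
import torch

checkpoint_dir = "./tests/assets/checkpoints/"  # same directory the tests below write into
prefix, completed_epoch = "epoch_", 1           # illustrative; matches the epoch_1.pt resume path in the new test

path = os.path.join(checkpoint_dir, f"{prefix}{completed_epoch}.pt")
checkpoint = torch.load(path, map_location="cpu")  # load on CPU to avoid device mismatches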
15 changes: 13 additions & 2 deletions tests/shared.py
@@ -2,9 +2,10 @@
from torch import optim
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from open_lm.data import get_data
from open_lm.distributed import init_distributed_device
from open_lm.main import random_seed
from open_lm.model import create_model
from open_lm.data import get_data
from open_lm.scheduler import cosine_lr
from tests.utils import download_val_data

@@ -36,6 +37,9 @@ def __init__(self, model):
self.rank = 0
self.local_rank = 0
self.log_every_n_steps = 1e8
self.save_logs = True
self.logs = None
self.name = "test_model_name"
self.dataset_type = "webdataset"
self.data_key = "json"
self.ffn_type = "swiglu"
@@ -46,6 +50,12 @@ def __init__(self, model):
self.seed = 1
self.vocab_size = 50432
self.seq_len = 300
self.epochs = 1
self.save_frequency = 1
self.checkpoint_path = "./tests/assets/checkpoints/"
self.resume = None
self.distributed = False
self.delete_previous_checkpoint = False
self.workers = 1
self.world_size = 1
self.val_data = None
@@ -106,9 +116,10 @@ def create_train_fixtures(model="open_lm_11m", fsdp=False):
# create base models
random_seed()
if fsdp:
init_distributed_device(args)
with torch.device("meta"):
model = create_model(args)
model = FSDP(model)
model = FSDP(model)
else:
model = create_model(args)
model.reset_parameters()
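For context, a rough sketch of the meta-device initialization pattern the FSDP branch above relies on; it assumes a single CUDA device and an NCCL process group (matching the gpu-marked FSDP test added below), and the tiny nn.Linear stands in for the model returned by create_model:

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Single-process group standing in for what init_distributed_device(args) sets up.
dist.init_process_group(backend="nccl", init_method="tcp://127.0.0.1:29502", rank=0, world_size=1)
torch.cuda.set_device(0)

with torch.device("meta"):
    # Parameters are created as shape-only meta tensors; no real memory is allocated yet.
    module = nn.Linear(16, 16)

# With no param_init_fn given, FSDP materializes the meta parameters on the compute
# device and calls reset_parameters() on modules that define it (nn.Linear does).
model = FSDP(module, device_id=torch.cuda.current_device())

dist.destroy_process_group()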
114 changes: 114 additions & 0 deletions tests/test_save_load.py
@@ -0,0 +1,114 @@
import pytest
import argparse
import torch
import os

import torch.multiprocessing as mp

from open_lm.utils.transformers.hf_model import OpenLMforCausalLM
from open_lm.utils.transformers.hf_config import OpenLMConfig
from open_lm.model import create_params
from open_lm.train import train_one_epoch
from open_lm.main import save_checkpoint, load_model
from open_lm.losses import CrossEntropyLossWithZLoss
from open_lm.distributed import is_using_distributed

from tests.shared import create_train_fixtures
from tests.utils import download_dl_test_data

@pytest.fixture(scope="module")
def tiny_args():
args = argparse.Namespace(
**{
"model": "open_lm_test_tiny",
# Model params that might not be in config:
"model_norm": "default_layer_norm",
"qk_norm": False,
"positional_embedding_type": "rotary",
"ffn_type": "swiglu",
}
)
return args


@pytest.fixture(scope="module")
def tiny_open_lm(tiny_args):
tiny_open_lm = OpenLMforCausalLM(OpenLMConfig(create_params(tiny_args)))
return tiny_open_lm


def test_tiny_save_load(tiny_open_lm, tiny_args, fsdp=False):
"""
This test checks that the model can be saved and loaded without changing the parameters.
"""
scaler = None
epoch = 0
evaluation_metrics = None
global_step = 0
done_training = False
download_dl_test_data()
args, model, data, optimizer, scheduler, loss = create_train_fixtures(tiny_args.model, fsdp=fsdp)
args2, model2, data2, optimizer2, scheduler2, loss2 = create_train_fixtures(tiny_args.model, fsdp=fsdp)
args.vocab_size = tiny_open_lm.config.vocab_size
args.seq_len = tiny_open_lm.config.seq_len
args.train_num_samples = 16
args.batch_size = 4
args.checkpoint_path = "./tests/assets/checkpoints/tiny_model/"
os.makedirs(args.checkpoint_path, exist_ok=True)

print("Training tiny model")
train_one_epoch(tiny_open_lm, data, CrossEntropyLossWithZLoss(), epoch=epoch, step=global_step, optimizer=optimizer, scaler=scaler, scheduler=scheduler, total_steps=-1, args=args)
epoch += 1
args.distributed = is_using_distributed()

print("Saving tiny model")
# Saving checkpoints.
save_checkpoint(
args,
model,
optimizer,
scaler,
epoch,
evaluation_metrics,
step=global_step,
is_final_checkpoint=done_training,
next_shard_per_source=None,
samples_seen=None,
)

args.resume = "./tests/assets/checkpoints/tiny_model/epoch_1.pt"

print("Loading saved tiny model")
load_model(args, model2)
threshold = 1e-3

print("Checking that loaded tiny model is the same as the original tiny model")
for (n1, p1), (n2, p2) in zip(model.named_parameters(), model2.named_parameters()):
assert torch.allclose(p1, p2, atol=threshold)

def _save_load_helper_fsdp(rank, world_size, tiny_open_lm, tiny_args):
# Initialize distributed training
torch.distributed.init_process_group(
backend="nccl" if torch.cuda.is_available() else "gloo",
init_method="tcp://127.0.0.1:29501",
rank=rank,
world_size=world_size,
)
test_tiny_save_load(tiny_open_lm, tiny_args, fsdp=True)
torch.distributed.destroy_process_group()


@pytest.mark.gpu
def test_tiny_save_load_fsdp(tiny_open_lm, tiny_args):
world_size = 1
mp.spawn(
_save_load_helper_fsdp,
args=(world_size, tiny_open_lm, tiny_args),
nprocs=world_size,
join=True
)

# def test_s3_load(

if __name__ == "__main__":
pytest.main([__file__])
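A small usage note, under the assumption that the gpu marker is registered in the project's pytest configuration: deselecting it lets the non-FSDP test run on machines without a GPU, for example:

import pytest

# Run only this file; the gpu-marked FSDP variant is deselected.
pytest.main(["tests/test_save_load.py", "-m", "not gpu", "-v"])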
