Test train save load, with and without fsdp
Showing 3 changed files with 130 additions and 3 deletions.
@@ -0,0 +1,114 @@
import pytest
import argparse
import torch
import os

import torch.multiprocessing as mp

from open_lm.utils.transformers.hf_model import OpenLMforCausalLM
from open_lm.utils.transformers.hf_config import OpenLMConfig
from open_lm.model import create_params
from open_lm.train import train_one_epoch
from open_lm.main import save_checkpoint, load_model
from open_lm.losses import CrossEntropyLossWithZLoss
from open_lm.distributed import is_using_distributed

from tests.shared import create_train_fixtures
from tests.utils import download_dl_test_data


@pytest.fixture(scope="module")
def tiny_args():
    args = argparse.Namespace(
        **{
            "model": "open_lm_test_tiny",
            # Model params that might not be in config:
            "model_norm": "default_layer_norm",
            "qk_norm": False,
            "positional_embedding_type": "rotary",
            "ffn_type": "swiglu",
        }
    )
    return args


@pytest.fixture(scope="module")
def tiny_open_lm(tiny_args):
    tiny_open_lm = OpenLMforCausalLM(OpenLMConfig(create_params(tiny_args)))
    return tiny_open_lm

def test_tiny_save_load(tiny_open_lm, tiny_args, fsdp=False):
    """
    This test checks that the model can be saved and loaded without changing the parameters.
    """
    scaler = None
    epoch = 0
    evaluation_metrics = None
    global_step = 0
    done_training = False
    download_dl_test_data()
    # Build two independent sets of fixtures so the checkpoint saved from the first
    # can be loaded into a fresh model. Pass the fsdp flag through so the FSDP
    # variant of this test actually exercises FSDP.
    args, model, data, optimizer, scheduler, loss = create_train_fixtures(tiny_args.model, fsdp=fsdp)
    args2, model2, data2, optimizer2, scheduler2, loss2 = create_train_fixtures(tiny_args.model, fsdp=fsdp)
    args.vocab_size = tiny_open_lm.config.vocab_size
    args.seq_len = tiny_open_lm.config.seq_len
    args.train_num_samples = 16
    args.batch_size = 4
    args.checkpoint_path = "./tests/assets/checkpoints/tiny_model/"
    os.makedirs(args.checkpoint_path, exist_ok=True)

    print("Training tiny model")
    train_one_epoch(
        tiny_open_lm,
        data,
        CrossEntropyLossWithZLoss(),
        epoch=epoch,
        step=global_step,
        optimizer=optimizer,
        scaler=scaler,
        scheduler=scheduler,
        total_steps=-1,
        args=args,
    )
    epoch += 1
    args.distributed = is_using_distributed()

    print("Saving tiny model")
    # Saving checkpoints.
    save_checkpoint(
        args,
        model,
        optimizer,
        scaler,
        epoch,
        evaluation_metrics,
        step=global_step,
        is_final_checkpoint=done_training,
        next_shard_per_source=None,
        samples_seen=None,
    )

    args.resume = "./tests/assets/checkpoints/tiny_model/epoch_1.pt"

    print("Loading saved tiny model")
    load_model(args, model2)
    threshold = 1e-3

    print("Checking that loaded tiny model is the same as the original tiny model")
    for (n1, p1), (n2, p2) in zip(model.named_parameters(), model2.named_parameters()):
        assert torch.allclose(p1, p2, atol=threshold)

def _save_load_helper_fsdp(rank, world_size, tiny_open_lm, tiny_args):
    # Initialize distributed training
    torch.distributed.init_process_group(
        backend="nccl" if torch.cuda.is_available() else "gloo",
        init_method="tcp://127.0.0.1:29501",
        rank=rank,
        world_size=world_size,
    )
    test_tiny_save_load(tiny_open_lm, tiny_args, fsdp=True)
    torch.distributed.destroy_process_group()


@pytest.mark.gpu
def test_tiny_save_load_fsdp(tiny_open_lm, tiny_args):
    world_size = 1
    mp.spawn(
        _save_load_helper_fsdp,
        args=(world_size, tiny_open_lm, tiny_args),
        nprocs=world_size,
        join=True,
    )


# def test_s3_load(

if __name__ == "__main__":
    pytest.main([__file__])
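For reference, the __main__ guard above runs the whole module through pytest.main. A minimal sketch of running only the non-FSDP test by filtering out the gpu marker is shown below; the module path is a placeholder, since the commit page does not show the file name of the added test.

import pytest

# Placeholder path -- substitute the actual test file added in this commit.
# "-m 'not gpu'" deselects test_tiny_save_load_fsdp, which spawns a
# distributed process group and is intended for GPU-capable environments.
pytest.main(["tests/path_to_this_test_file.py", "-m", "not gpu"])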