Use argparse to avoid duplicating args in test code
achalddave committed Dec 16, 2023
1 parent 093f299 commit 92a7139
Showing 6 changed files with 40 additions and 71 deletions.
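The change follows one pattern across all six files: rather than hand-maintaining dozens of defaults on the mock args object (and again in each test's argparse.Namespace), the tests build their defaults through the same argparse entry point the trainer uses, copy the parsed namespace onto the mock, and let individual tests override fields via kwargs. A minimal, self-contained sketch of that idea (the parser, flags, and defaults below are illustrative stand-ins, not open_lm's actual parse_args or field set):

```python
import argparse


def parse_args(argv):
    # Stand-in for the project's CLI parser; the real defaults live in one place.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True)
    parser.add_argument("--lr", type=float, default=3e-3)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--precision", default="fp32")
    return parser.parse_args(argv)


class MockTrainArgs:
    """Test args seeded from the shared parser, then overridden per test."""

    def __init__(self, model, **kwargs):
        args = parse_args(["--model", model])
        # Copy every parsed default onto this object...
        for k, v in vars(args).items():
            setattr(self, k, v)
        # ...then let the test override or add whatever it needs.
        for k, v in kwargs.items():
            setattr(self, k, v)


args = MockTrainArgs("open_lm_160m", batch_size=2)
assert args.lr == 3e-3 and args.batch_size == 2 and args.precision == "fp32"
```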
84 changes: 27 additions & 57 deletions tests/shared.py
@@ -6,6 +6,7 @@
from open_lm.distributed import init_distributed_device
from open_lm.main import random_seed
from open_lm.model import create_model
from open_lm.params import parse_args
from open_lm.scheduler import cosine_lr
from tests.utils import download_val_data

@@ -14,69 +15,38 @@ class MockTrainArgs:
def __init__(self, model, **kwargs):
data_path = download_val_data("shard_00000000.tar", "./tests/assets/")

self.model = model # part of model config
self.model_norm = "gain_only_layer_norm"
self.qk_norm = False
self.train_data = [
data_path,
]
self.log_logit_mean = False
self.device = "cpu"
self.precision = "float32"
self.wd = 0.033
self.lr = 3e-3
self.beta1 = 0.9
self.beta2 = 0.95
self.eps = 1e-8
self.warmup = 2
self.skip_scheduler = False
self.accum_freq = 1
self.batch_size = 8
self.grad_clip_norm = 1.0
# fmt: off
args = parse_args([
"--model", model,
"--model-norm", "gain_only_layer_norm",
"--train-data", data_path,
"--precision", "fp32",
"--wd", "0.033",
"--lr", "3e-3",
"--warmup", "2",
"--batch-size", "8",
"--accum", "1",
"--name", "test_model_name",
"--logs", "./tests/assets/",
"--workers", "1",
"--data-key", "json",
"--seed", "1",
])
# fmt: on
for k, v in vars(args).items():
setattr(self, k, v)

self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
self.rank = 0
self.local_rank = 0
self.log_every_n_steps = 1e8
self.save_logs = False
self.logs = None
self.name = "test_model_name"
self.dataset_type = "webdataset"
self.data_key = "json"
self.ffn_type = "swiglu"
self.train_num_samples = 250000
self.train_data_mix_weights = None
self.train_data_upsampling_factors = None
self.disable_buffer = False
self.seed = 1
self.world_size = 1
self.vocab_size = 50432
self.seq_len = 300
self.epochs = 1
self.save_frequency = 1
self.checkpoint_path = "./tests/assets/checkpoints/"
self.resume = None
self.distributed = False
self.delete_previous_checkpoint = False
self.workers = 1
self.world_size = 1
self.val_data = None
self.lr_cooldown_end = 3e-5
self.force_min_lr = 0.0
self.scaler = None
self.accum_freq = 1
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
self.wandb = False
self.fsdp = False
self.fsdp_amp = False
self.positional_embedding_type = "rotary"
self.dist_backend = "nccl"
self.dist_url = "env://"
self.dataset_manifest = None
self.target_mask_left = None
self.target_mask_individual = None
self.ignore_parse_errors = False
self.distributed = False

for k, v in kwargs.items():
if hasattr(self, k):
setattr(self, k, v)
setattr(self, k, v)


class MockDataArgs(object):
@@ -92,7 +62,7 @@ def __init__(self):
self.train_data_upsampling_factors = None
self.train_num_samples = 512
self.disable_buffer = True
self.seq_len = 2048
self.seq_len = 300
self.vocab_size = 50432
self.batch_size = 64
self.world_size = 1
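With the rewritten MockTrainArgs, each test states only the fields it cares about and inherits everything else from the parser defaults. A rough sketch of the intended call pattern (the override values here are illustrative):

```python
from tests.shared import MockTrainArgs

# Defaults (precision, lr, warmup, workers, ...) come from parse_args;
# kwargs override only what this particular test needs.
args = MockTrainArgs(
    model="open_lm_test_tiny",
    vocab_size=16,
    seq_len=16,
)
```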
7 changes: 3 additions & 4 deletions tests/test_generate_kv_cache_time.py
@@ -1,13 +1,12 @@
import time
import pytest

import argparse

from transformers import GPTNeoXTokenizerFast

from open_lm.utils.transformers.hf_model import OpenLMforCausalLM
from open_lm.utils.transformers.hf_config import OpenLMConfig
from open_lm.model import create_params
from tests.shared import MockTrainArgs
from .utils import run_model


@@ -18,10 +17,10 @@
@pytest.mark.parametrize("max_gen_len", [1024, 1792])
def test_generate_kv_cache(wiki_page, context_len, max_gen_len):
"""Test that the model generates faster with cache than without."""
args = argparse.Namespace(
args = MockTrainArgs(
model="open_lm_160m",
**{
# Generation params:
"model": "open_lm_160m",
"input_text": "random",
"max_gen_len": max_gen_len,
"context_len": context_len,
6 changes: 3 additions & 3 deletions tests/test_generate_load_kv_cache_equal.py
@@ -2,7 +2,6 @@
import os

import pytest
import argparse
import torch

from huggingface_hub import hf_hub_download
@@ -11,6 +10,7 @@
from open_lm.utils.transformers.hf_model import OpenLMforCausalLM
from open_lm.utils.transformers.hf_config import OpenLMConfig
from open_lm.model import create_params
from tests.shared import MockTrainArgs
from .utils import run_model


@@ -26,10 +26,10 @@ def args():
model_path = hf_hub_download("mlfoundations/open_lm_1B", filename="open_lm_1b.pt")
shutil.copy2(model_path, "checkpoints/open_lm_1b_old.pt")

args = argparse.Namespace(
args = MockTrainArgs(
model="open_lm_1b_old",
**{
# Generation params:
"model": "open_lm_1b_old",
"input_text": "random",
"max_gen_len": None,
"context_len": None,
2 changes: 1 addition & 1 deletion tests/test_grad_accum.py
@@ -72,7 +72,7 @@ def _grad_acc_helper_single(test_fsdp, accs=[2, 1], threshold=1e-7):
epoch=0,
step=0,
optimizer=optimizer,
scaler=args.scaler,
scaler=None,
scheduler=scheduler,
total_steps=10,
args=args,
6 changes: 3 additions & 3 deletions tests/test_save_load.py
@@ -13,15 +13,15 @@
from open_lm.losses import CrossEntropyLossWithZLoss
from open_lm.distributed import is_using_distributed

from tests.shared import create_train_fixtures
from tests.shared import MockTrainArgs, create_train_fixtures
from tests.utils import download_dl_test_data


@pytest.fixture(scope="module")
def tiny_args():
args = argparse.Namespace(
args = MockTrainArgs(
model="open_lm_test_tiny",
**{
"model": "open_lm_test_tiny",
"vocab_size": 16,
"sequence_length": 16,
"train_num_samples": 64,
6 changes: 3 additions & 3 deletions tests/test_tiny_generate_kv_cache_equal.py
@@ -1,9 +1,9 @@
import pytest
import argparse

from open_lm.utils.transformers.hf_model import OpenLMforCausalLM
from open_lm.utils.transformers.hf_config import OpenLMConfig
from open_lm.model import create_params
from tests.shared import MockTrainArgs
from tests.utils import run_model, CharacterTokenizer


@@ -12,10 +12,10 @@
@pytest.mark.slow
@pytest.fixture(scope="module")
def args():
args = argparse.Namespace(
args = MockTrainArgs(
model="open_lm_test_tiny",
**{
# Generation params:
"model": "open_lm_test_tiny",
"input_text": "random",
"max_gen_len": None,
"context_len": None,