Replace batch_size with global_batch_size. #150

Merged
merged 11 commits on Dec 18, 2023
4 changes: 2 additions & 2 deletions open_lm/data.py
@@ -420,7 +420,7 @@ def get_wds_dataset(args, is_train, epoch=0, floor=True, tokenizer=None, data_ke
)

map_dict_handler = {"handler": log_and_continue} if args.ignore_parse_errors else {}
batch_size = args.batch_size if is_train else args.val_batch_size
batch_size = args.per_gpu_batch_size if is_train else args.per_gpu_val_batch_size

if data_key == "json" or data_key == "json.gz":
pipeline.extend(
@@ -545,7 +545,7 @@ def get_synthetic_dataset(args, is_train, epoch, tokenizer, data_key, floor):

dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
batch_size=args.per_gpu_batch_size,
shuffle=shuffle,
num_workers=args.workers,
pin_memory=True,
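For context, a minimal sketch of what the dataloaders now expect: each rank sizes its DataLoader with the per-GPU value rather than the old args.batch_size. The SimpleNamespace and toy dataset below are illustrative stand-ins, not open_lm code.

from types import SimpleNamespace

import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical args object carrying only the fields this diff touches.
args = SimpleNamespace(per_gpu_batch_size=8, per_gpu_val_batch_size=4, workers=2)

is_train = True
batch_size = args.per_gpu_batch_size if is_train else args.per_gpu_val_batch_size

dataset = TensorDataset(torch.arange(64))
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,   # per-GPU, not global
    shuffle=is_train,
    num_workers=args.workers,
    pin_memory=True,
)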
21 changes: 18 additions & 3 deletions open_lm/main.py
@@ -277,6 +277,21 @@ def main(args):

# fully initialize distributed device environment
device = init_distributed_device(args)

assert (
args.global_batch_size % args.world_size == 0
), f"Global batch size ({args.global_batch_size}) is not divisible by number of GPUs ({args.world_size}), and thus cannot be respected."

args.per_gpu_batch_size = max(args.global_batch_size // args.world_size, 1)
if args.val_data is not None:
args.per_gpu_val_batch_size = max(args.global_val_batch_size // args.world_size, 1)

if args.hf_model is not None and args.hf_seq_len is None:
raise ValueError("If passing --hf-model, must also pass --hf-seq-len to be used for training/fine-tuning.")

if args.hf_model is not None and args.fsdp and args.hf_fsdp_block is None:
raise ValueError("If passing --hf-model and --fsdp, must also pass --hf-fspd-block.")

if args.fsdp and not args.distributed:
raise ValueError(f"--fsdp can only be specified in distributed mode.")

@@ -301,7 +316,7 @@
date_str,
f"model_{model_name_safe}",
f"lr_{args.lr}",
f"b_{args.batch_size}",
f"b_{args.per_gpu_batch_size}", # Per gpu to respect old naming convention
]
)

@@ -607,7 +622,7 @@ def main(args):
scheduler = None
if requires_training:
if args.dataset_manifest is not None:
total_steps = (args.train_num_samples * args.epochs) // (args.batch_size * args.world_size)
total_steps = (args.train_num_samples * args.epochs) // args.global_batch_size
else:
total_steps = (data["train"].dataloader.num_batches) * args.epochs

@@ -720,7 +735,7 @@

done_training = global_step >= total_steps
steps_done_epoch = global_step - prev_step
samples_seen = samples_seen + steps_done_epoch * args.batch_size * args.world_size
samples_seen = samples_seen + steps_done_epoch * args.global_batch_size

if not success:
logging.info("Training exiting due to NaN value")
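Outside of the diff itself, the new bookkeeping in main() reduces to the following arithmetic; a standalone sketch with made-up values (256 samples per optimizer step across 8 ranks):

global_batch_size = 256   # --global-batch-size
world_size = 8            # ranks reported by the launcher

# The global batch must split evenly across ranks, otherwise it cannot be respected.
assert global_batch_size % world_size == 0, (
    f"Global batch size ({global_batch_size}) is not divisible by "
    f"number of GPUs ({world_size})."
)
per_gpu_batch_size = max(global_batch_size // world_size, 1)   # 32

# Step counts and samples seen are now derived from the global batch size directly.
train_num_samples, epochs = 1_000_000, 2
total_steps = (train_num_samples * epochs) // global_batch_size   # 7812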
13 changes: 6 additions & 7 deletions open_lm/params.py
@@ -215,7 +215,7 @@ def parse_args(args):
help="Optional identifier for the experiment when storing logs. Otherwise use current time.",
)
parser.add_argument("--workers", type=int, default=1, help="Number of dataloader workers per GPU.")
parser.add_argument("--batch-size", type=int, default=64, help="Batch size per GPU.")
parser.add_argument("--global-batch-size", type=int, default=64, help="Global batch size.")
parser.add_argument("--epochs", type=int, default=32, help="Number of epochs to train for.")
parser.add_argument(
"--epochs-cooldown",
@@ -298,7 +298,7 @@ def parse_args(args):
help="How often to run evaluation with val-data (in epochs). Last epoch validated if val-data provided.",
)
parser.add_argument(
"--val-batch-size",
"--global-val-batch-size",
type=int,
default=None,
help="Batch size to be used with val-data.",
@@ -380,7 +380,7 @@ def parse_args(args):
"--accum-freq",
type=int,
default=1,
help="Update the model every --acum-freq steps.",
help="Update the model every --accum-freq steps.",
)
# arguments for distributed training
parser.add_argument(
@@ -581,9 +581,8 @@ def parse_args(args):
assert args.train_data is None, "--train-data must not be specified if --dataset-type='synthetic'"
assert args.dataset_manifest is None, "--dataset-manifest must not be specified if --dataset-type='synthetic'"

if args.val_data is not None and args.val_batch_size is None:
# if not set explicitly make sure that the val batch size is set to the micro batch size

args.val_batch_size = args.batch_size // args.accum_freq
if args.val_data is not None and args.global_val_batch_size is None:
# Make sure that val batch size is set to micro batch size
args.global_val_batch_size = args.global_batch_size // args.accum_freq

return args
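When --global-val-batch-size is omitted, the parser falls back to the training micro batch; a hedged sketch of that default (the numbers are illustrative):

global_batch_size = 64        # --global-batch-size default
accum_freq = 4                # --accum-freq
global_val_batch_size = None  # --global-val-batch-size not passed

if global_val_batch_size is None:
    # Validation defaults to the training micro batch size.
    global_val_batch_size = global_batch_size // accum_freq
assert global_val_batch_size == 16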
6 changes: 3 additions & 3 deletions open_lm/train.py
@@ -202,8 +202,8 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler
else:
# split up batch into accum_freq chunks -- if you have --batch-size 8 and --accum-freq 4
# then you only process 2 items at a time. batch-size must be divisible by accum-freq.
assert args.batch_size % args.accum_freq == 0, "Batch size must be divisible by accum_freq"
per_batch = args.batch_size // args.accum_freq
assert args.per_gpu_batch_size % args.accum_freq == 0, "Per-GPU batch size must be divisible by accum_freq"
per_batch = args.per_gpu_batch_size // args.accum_freq

inputs, targets = sample_chunk(texts, args)

@@ -288,7 +288,7 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler
"samples_per_second": samples_per_second,
"samples_per_second_per_gpu": samples_per_second_per_gpu,
"lr": optimizer.param_groups[0]["lr"],
"tokens": (step + 1) * args.batch_size * args.seq_len * args.world_size,
"tokens": (step + 1) * args.global_batch_size * args.seq_len,
}

if args.log_logit_mean:
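Gradient accumulation now chunks the per-GPU batch, while token accounting uses the global batch size; a small sketch with illustrative values:

per_gpu_batch_size = 8
accum_freq = 4
world_size = 2
global_batch_size = per_gpu_batch_size * world_size   # 16
seq_len = 2048

assert per_gpu_batch_size % accum_freq == 0, "Per-GPU batch size must be divisible by accum_freq"
per_batch = per_gpu_batch_size // accum_freq   # 2 samples per micro-step

step = 99
tokens_logged = (step + 1) * global_batch_size * seq_len   # 3,276,800 tokens after 100 steps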
2 changes: 1 addition & 1 deletion setup.py
@@ -19,7 +19,7 @@ def _read_reqs(relpath):

setuptools.setup(
name="open_lm",
version="0.0.22",
version="0.0.23",
author=[
"Suchin Gururangan*",
"Mitchell Wortsman*",
8 changes: 5 additions & 3 deletions tests/shared.py
@@ -24,7 +24,7 @@ def __init__(self, model, **kwargs):
"--wd", "0.033",
"--lr", "3e-3",
"--warmup", "2",
"--batch-size", "8",
"--global-batch-size", "8",
"--accum", "1",
"--name", "test_model_name",
"--logs", "./tests/assets/",
@@ -43,6 +43,7 @@ def __init__(self, model, **kwargs):
self.vocab_size = 50432
self.seq_len = 300
self.wandb = False
self.per_gpu_batch_size = self.global_batch_size // self.world_size
self.distributed = False

for k, v in kwargs.items():
@@ -64,7 +65,7 @@ def __init__(self):
self.disable_buffer = True
self.seq_len = 300
self.vocab_size = 50432
self.batch_size = 64
self.global_batch_size = 64
self.world_size = 1
self.rank = 0
self.workers = 2
@@ -73,14 +74,15 @@ def __init__(self):
self.target_mask_left = None
self.target_mask_individual = None
self.ignore_parse_errors = False
self.per_gpu_batch_size = self.global_batch_size // self.world_size


def create_train_fixtures(model="open_lm_11m", fsdp=False, **kwargs):
# Setup data, optimizer, and other basic settings
args = MockTrainArgs(model, **kwargs)

# only want to look at one batch
args.train_num_samples = args.batch_size
args.train_num_samples = args.global_batch_size

# increase learning rate and remove warmup to maximize the change to model weights
args.lr = 1e-3
6 changes: 4 additions & 2 deletions tests/test_dataset_deterministic.py
@@ -36,14 +36,15 @@ def retrieve_dataset(epoch, next_shard, weights, seed, disable_buffer, min_shard
args.train_num_samples = NUM_SAMPLES
args.train_data = train_data_string_per_source
args.workers = 2
args.batch_size = 2
args.global_batch_size = 2
args.seed = seed
args.dataset_resampled = False
args.disable_buffer = disable_buffer
args.vocab_size = _MODEL_CONFIGS[args.model]["vocab_size"]
args.seq_len = _MODEL_CONFIGS[args.model]["seq_len"]
args.world_size = 1
args.rank = 0
args.per_gpu_batch_size = 2
data = get_wds_dataset(args, is_train=True, epoch=epoch, force_num_samples=num_samples_per_source)
dl = data.dataloader

@@ -58,13 +59,14 @@ def retrieve_dataset_resampled(epoch, next_shard, weights, seed, min_shards_need
args.train_num_samples = NUM_SAMPLES
args.train_data = train_data_string_per_source
args.num_workers = 2
args.batch_size = 2
args.global_batch_size = 2
args.seed = seed
args.dataset_resampled = True
args.vocab_size = _MODEL_CONFIGS[args.model]["vocab_size"]
args.seq_len = _MODEL_CONFIGS[args.model]["seq_len"]
args.world_size = 1
args.rank = 0
args.per_gpu_batch_size = 2
data = get_wds_dataset(args, is_train=True, epoch=epoch)
dl = data.dataloader

2 changes: 1 addition & 1 deletion tests/test_dataset_no_resample.py
@@ -84,7 +84,7 @@ def retrieve_dataset_once(
args.train_num_samples = total_seqs
args.train_data = train_data_string_per_source
args.workers = num_workers
args.batch_size = batch_size
args.global_batch_size = args.per_gpu_batch_size = batch_size
args.seed = seed
args.dataset_resampled = False
args.disable_buffer = disable_buffer
4 changes: 2 additions & 2 deletions tests/test_param_parsing.py
@@ -26,14 +26,14 @@ def get_cmdline_config1():
# fmt: off
cmdline = [
"--train-num-samples", str(samples),
"--batch-size", str(batch_size),
"--global-batch-size", str(batch_size),
"--dataset-type", "synthetic",
"--model", "open_lm_test_tiny",
"--epochs", "1",
]
config_dict = {
"train-num-samples": samples,
"batch-size": batch_size,
"global-batch-size": batch_size,
"dataset-type": "synthetic",
"model": "open_lm_test_tiny",
"epochs": 1,
2 changes: 1 addition & 1 deletion tests/test_training_simple.py
@@ -8,7 +8,7 @@ def test_train_simple():
# fmt: off
main([
"--train-num-samples", str(num_batches * seq_len),
"--batch-size", str(batch_size),
"--global-batch-size", str(batch_size),
"--dataset-type", "synthetic",
"--model", "open_lm_test_tiny",
"--epochs", "1",
6 changes: 3 additions & 3 deletions tests/test_training_tokens.py
@@ -33,16 +33,16 @@ def test_token_count(test_case):

download_dl_test_data()
args, model, _, optimizer, scheduler, loss = create_train_fixtures("open_lm_11m")
args.batch_size = batch_size
args.global_batch_size = batch_size
args.per_gpu_batch_size = args.global_batch_size // args.world_size
args.workers = workers
args.train_data = None
args.dataset_manifest = SOURCE_MANIFEST
args.epochs = desired_epochs
args.train_num_samples = desired_sequences_per_epoch

total_samples = desired_sequences_per_epoch * desired_epochs
global_batch_size = args.batch_size * args.world_size
total_steps = total_samples // (global_batch_size)
total_steps = total_samples // (args.global_batch_size)
global_step = 0
next_shard_per_source = [0]
epoch = 0