Replace batch_size with global_batch_size. #150

Merged: 11 commits, Dec 18, 2023
4 changes: 2 additions & 2 deletions open_lm/data.py
@@ -420,7 +420,7 @@ def get_wds_dataset(args, is_train, epoch=0, floor=True, tokenizer=None, data_ke
)

map_dict_handler = {"handler": log_and_continue} if args.ignore_parse_errors else {}
batch_size = args.batch_size if is_train else args.val_batch_size
batch_size = args.per_gpu_batch_size if is_train else args.val_batch_size

if data_key == "json" or data_key == "json.gz":
pipeline.extend(
@@ -545,7 +545,7 @@ def get_synthetic_dataset(args, is_train, epoch, tokenizer, data_key, floor):

dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
batch_size=args.per_gpu_batch_size,
shuffle=shuffle,
num_workers=args.workers,
pin_memory=True,
12 changes: 9 additions & 3 deletions open_lm/main.py
@@ -243,6 +243,12 @@ def main(args):
# fully initialize distributed device environment
device = init_distributed_device(args)

assert (
args.global_batch_size % args.world_size == 0
), "Global batch size is not divisible by number of GPUs, and thus cannot be respected."

args.per_gpu_batch_size = args.global_batch_size // args.world_size

if args.hf_model is not None and args.hf_seq_len is None:
raise ValueError("If passing --hf-model, must also pass --hf-seq-len to be used for training/fine-tuning.")

@@ -270,7 +276,7 @@
date_str,
f"model_{model_name_safe}",
f"lr_{args.lr}",
f"b_{args.batch_size}",
f"b_{args.per_gpu_batch_size}", # Per gpu to respect old naming convention
]
)

@@ -587,7 +593,7 @@
scheduler = None
if requires_training:
if args.dataset_manifest is not None:
total_steps = (args.train_num_samples * args.epochs) // (args.batch_size * args.world_size)
total_steps = (args.train_num_samples * args.epochs) // args.global_batch_size
else:
total_steps = (data["train"].dataloader.num_batches) * args.epochs

@@ -706,7 +712,7 @@

done_training = global_step >= total_steps
steps_done_epoch = global_step - prev_step
samples_seen = samples_seen + steps_done_epoch * args.batch_size * args.world_size
samples_seen = samples_seen + steps_done_epoch * args.global_batch_size

if not success:
logging.info("Training exiting due to NaN value")
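For reference, a minimal standalone sketch of the bookkeeping this change introduces in `main()`: the user supplies a global batch size, the per-GPU batch size is derived from it, and the scheduler horizon is computed directly from the global value. Names follow the diff above; the concrete numbers are only illustrative.

```python
# Sketch of the batch-size bookkeeping introduced in main().
# Names follow the diff (global_batch_size, world_size, per_gpu_batch_size);
# the numbers are only illustrative.
global_batch_size = 512   # e.g. --global-batch-size 512
world_size = 8            # number of GPUs / ranks

# The global batch must split evenly across ranks, hence the new assert.
assert global_batch_size % world_size == 0, (
    "Global batch size is not divisible by number of GPUs, and thus cannot be respected."
)
per_gpu_batch_size = global_batch_size // world_size  # 64 samples per rank per step

# Scheduler horizon: with a dataset manifest, total_steps is now expressed
# directly in terms of the global batch size.
train_num_samples = 1_000_000
epochs = 2
total_steps = (train_num_samples * epochs) // global_batch_size
print(per_gpu_batch_size, total_steps)  # 64, 3906
```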
6 changes: 3 additions & 3 deletions open_lm/params.py
@@ -214,7 +214,7 @@ def parse_args(args):
help="Optional identifier for the experiment when storing logs. Otherwise use current time.",
)
parser.add_argument("--workers", type=int, default=1, help="Number of dataloader workers per GPU.")
parser.add_argument("--batch-size", type=int, default=64, help="Batch size per GPU.")
parser.add_argument("--global-batch-size", type=int, default=64, help="Global batch size.")
parser.add_argument("--epochs", type=int, default=32, help="Number of epochs to train for.")
parser.add_argument(
"--epochs-cooldown",
@@ -576,7 +576,7 @@ def parse_args(args):

if args.val_data is not None and args.val_batch_size is None:
# if not set explicitly make sure that the val batch size is set to the micro batch size

args.val_batch_size = args.batch_size // args.accum_freq
# TODO: is this correct with global batch size?
Collaborator Author:
I'm not 100% sure about this part; I think it is assumed that we are running eval with only 1 GPU.

Collaborator:
I would do

args.per_gpu_batch_size_val = (args.global_batch_size // world_size // args.accum_freq)

and change references to args.val_batch_size to args.per_gpu_batch_size_val (I know, it's a mouthful).

Other than that this LGTM!

args.val_batch_size = args.global_batch_size // args.accum_freq

return args
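A hypothetical sketch of the reviewer's suggestion above, not the merged behavior: derive a dedicated per-GPU validation batch size instead of dividing the global batch size by accum_freq alone. The name per_gpu_batch_size_val comes from the review comment, and the helper below is invented for illustration only.

```python
# Hypothetical sketch of the reviewer's suggestion (not what was merged):
# derive a per-GPU validation batch size rather than reusing the global value.
def resolve_val_batch_size(global_batch_size: int, world_size: int, accum_freq: int) -> int:
    # Mirrors the expression proposed in the review comment:
    # args.per_gpu_batch_size_val = global_batch_size // world_size // accum_freq
    return global_batch_size // world_size // accum_freq

per_gpu_batch_size_val = resolve_val_batch_size(global_batch_size=64, world_size=4, accum_freq=2)
print(per_gpu_batch_size_val)  # 8
```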
6 changes: 3 additions & 3 deletions open_lm/train.py
@@ -201,8 +201,8 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler
else:
# split up batch into accum_freq chunks -- if you have --batch-size 8 and --accum-freq 4
# then you only process 2 items at a time. batch-size must be divisible by accum-freq.
assert args.batch_size % args.accum_freq == 0, "Batch size must be divisible by accum_freq"
per_batch = args.batch_size // args.accum_freq
assert args.per_gpu_batch_size % args.accum_freq == 0, "Per-GPU batch size must be divisible by accum_freq"
per_batch = args.per_gpu_batch_size // args.accum_freq

inputs, targets = sample_chunk(texts, args)

@@ -287,7 +287,7 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler
"samples_per_second": samples_per_second,
"samples_per_second_per_gpu": samples_per_second_per_gpu,
"lr": optimizer.param_groups[0]["lr"],
"tokens": (step + 1) * args.batch_size * args.seq_len * args.world_size,
"tokens": (step + 1) * args.global_batch_size * args.seq_len,
}

if args.log_logit_mean:
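For reference, a small illustrative sketch of the gradient-accumulation chunking and token accounting touched above: each rank splits its per-GPU batch into accum_freq micro-batches, and the logged token count is now expressed in terms of the global batch size. Shapes and values are hypothetical; names follow the diff.

```python
import torch

# Each rank's per-GPU batch is processed in accum_freq micro-batches of size per_batch.
per_gpu_batch_size = 8
accum_freq = 4
seq_len = 16

assert per_gpu_batch_size % accum_freq == 0, "Per-GPU batch size must be divisible by accum_freq"
per_batch = per_gpu_batch_size // accum_freq  # 2 samples per micro-batch

texts = torch.randint(0, 100, (per_gpu_batch_size, seq_len))
micro_batches = torch.split(texts, per_batch)  # accum_freq chunks of shape (per_batch, seq_len)
assert len(micro_batches) == accum_freq

# Logged token count after (step + 1) optimizer steps is now computed from the
# global batch size rather than per-GPU batch size * world_size.
global_batch_size = 64
step = 9
tokens = (step + 1) * global_batch_size * seq_len  # 10 * 64 * 16 = 10240
print(per_batch, tokens)
```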
7 changes: 5 additions & 2 deletions tests/shared.py
@@ -31,7 +31,7 @@ def __init__(self, model):
self.warmup = 2
self.skip_scheduler = False
self.accum_freq = 1
self.batch_size = 8
self.global_batch_size = 8
self.grad_clip_norm = 1.0
self.rank = 0
self.local_rank = 0
@@ -64,6 +64,8 @@ def __init__(self, model):
self.target_mask_left = None
self.target_mask_individual = None
self.ignore_parse_errors = False
self.per_gpu_batch_size = self.global_batch_size // self.world_size



class MockDataArgs(object):
@@ -81,7 +83,7 @@ def __init__(self):
self.disable_buffer = True
self.seq_len = 2048
self.vocab_size = 50432
self.batch_size = 64
self.global_batch_size = 64
self.world_size = 1
self.rank = 0
self.workers = 2
@@ -90,6 +92,7 @@ def __init__(self):
self.target_mask_left = None
self.target_mask_individual = None
self.ignore_parse_errors = False
self.per_gpu_batch_size = self.global_batch_size // self.world_size


def create_train_fixtures(model="open_lm_11m", fsdp=False):
6 changes: 4 additions & 2 deletions tests/test_dataset_deterministic.py
@@ -36,14 +36,15 @@ def retrieve_dataset(epoch, next_shard, weights, seed, disable_buffer, min_shard
args.train_num_samples = NUM_SAMPLES
args.train_data = train_data_string_per_source
args.workers = 2
args.batch_size = 2
args.global_batch_size = 2
args.seed = seed
args.dataset_resampled = False
args.disable_buffer = disable_buffer
args.vocab_size = _MODEL_CONFIGS[args.model]["vocab_size"]
args.seq_len = _MODEL_CONFIGS[args.model]["seq_len"]
args.world_size = 1
args.rank = 0
args.per_gpu_batch_size = 2
data = get_wds_dataset(args, is_train=True, epoch=epoch, force_num_samples=num_samples_per_source)
dl = data.dataloader

@@ -58,13 +59,14 @@ def retrieve_dataset_resampled(epoch, next_shard, weights, seed, min_shards_need
args.train_num_samples = NUM_SAMPLES
args.train_data = train_data_string_per_source
args.num_workers = 2
args.batch_size = 2
args.global_batch_size = 2
args.seed = seed
args.dataset_resampled = True
args.vocab_size = _MODEL_CONFIGS[args.model]["vocab_size"]
args.seq_len = _MODEL_CONFIGS[args.model]["seq_len"]
args.world_size = 1
args.rank = 0
args.per_gpu_batch_size = 2
data = get_wds_dataset(args, is_train=True, epoch=epoch)
dl = data.dataloader

2 changes: 1 addition & 1 deletion tests/test_dataset_no_resample.py
@@ -84,7 +84,7 @@ def retrieve_dataset_once(
args.train_num_samples = total_seqs
args.train_data = train_data_string_per_source
args.workers = num_workers
args.batch_size = batch_size
args.global_batch_size = args.per_gpu_batch_size = batch_size
args.seed = seed
args.dataset_resampled = False
args.disable_buffer = disable_buffer
4 changes: 2 additions & 2 deletions tests/test_param_parsing.py
@@ -26,14 +26,14 @@ def get_cmdline_config1():
# fmt: off
cmdline = [
"--train-num-samples", str(samples),
"--batch-size", str(batch_size),
"--global-batch-size", str(batch_size),
"--dataset-type", "synthetic",
"--model", "open_lm_test_tiny",
"--epochs", "1",
]
config_dict = {
"train-num-samples": samples,
"batch-size": batch_size,
"global-batch-size": batch_size,
"dataset-type": "synthetic",
"model": "open_lm_test_tiny",
"epochs": 1,
2 changes: 1 addition & 1 deletion tests/test_training_simple.py
@@ -8,7 +8,7 @@ def test_train_simple():
# fmt: off
main([
"--train-num-samples", str(num_batches * seq_len),
"--batch-size", str(batch_size),
"--global-batch-size", str(batch_size),
"--dataset-type", "synthetic",
"--model", "open_lm_test_tiny",
"--epochs", "1",
6 changes: 3 additions & 3 deletions tests/test_training_tokens.py
@@ -33,16 +33,16 @@ def test_token_count(test_case):

download_dl_test_data()
args, model, _, optimizer, scheduler, loss = create_train_fixtures("open_lm_11m")
args.batch_size = batch_size
args.global_batch_size = batch_size
args.per_gpu_batch_size = args.global_batch_size // args.world_size
args.workers = workers
args.train_data = None
args.dataset_manifest = SOURCE_MANIFEST
args.epochs = desired_epochs
args.train_num_samples = desired_sequences_per_epoch

total_samples = desired_sequences_per_epoch * desired_epochs
global_batch_size = args.batch_size * args.world_size
total_steps = total_samples // (global_batch_size)
total_steps = total_samples // (args.global_batch_size)
global_step = 0
next_shard_per_source = [0]
epoch = 0