
Commit

Replace batch_size with global_batch_size. (#150)
* Replace batch_size with global_batch_size.

* Formatting.

* Version bump.

* Small test fix.

* Update val batch size.

* Bugfix.

* Change handling of accum_freq.

* Version bump + no 0 bsz.

* No 0 bsz in training as well.

* Formatting.

---------

Co-authored-by: George Smyrnis <[email protected]>
GeorgiosSmyrnis authored Dec 18, 2023
1 parent 0503a56 commit 41ca9a0
Showing 11 changed files with 46 additions and 28 deletions.
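In short, the CLI moves from a per-GPU --batch-size to a --global-batch-size that main() splits across ranks. A minimal sketch of the new derivation, mirroring the main.py and params.py hunks below (the world size and accumulation factor are made-up example values):

# Sketch only: how the renamed flags map onto the old per-GPU quantity.
world_size = 4             # hypothetical number of ranks
accum_freq = 2             # hypothetical --accum-freq value
global_batch_size = 64     # the new --global-batch-size flag

assert global_batch_size % world_size == 0, "global batch size must split evenly across GPUs"
per_gpu_batch_size = max(global_batch_size // world_size, 1)   # 16 -- what --batch-size used to mean
per_batch = per_gpu_batch_size // accum_freq                   # 8 samples per accumulation chunk per GPU
global_val_batch_size = global_batch_size // accum_freq        # 32 -- default when --global-val-batch-size is unset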
4 changes: 2 additions & 2 deletions open_lm/data.py
@@ -420,7 +420,7 @@ def get_wds_dataset(args, is_train, epoch=0, floor=True, tokenizer=None, data_ke
)

map_dict_handler = {"handler": log_and_continue} if args.ignore_parse_errors else {}
batch_size = args.batch_size if is_train else args.val_batch_size
batch_size = args.per_gpu_batch_size if is_train else args.per_gpu_val_batch_size

if data_key == "json" or data_key == "json.gz":
pipeline.extend(
@@ -545,7 +545,7 @@ def get_synthetic_dataset(args, is_train, epoch, tokenizer, data_key, floor):

dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
batch_size=args.per_gpu_batch_size,
shuffle=shuffle,
num_workers=args.workers,
pin_memory=True,
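The dataloaders keep taking a per-process value because torch.utils.data.DataLoader batches within a single process; each rank therefore gets the per-GPU share of the global batch. A small self-contained sketch under that assumption (the dataset here is a dummy stand-in for the real pipeline):

import torch
from torch.utils.data import DataLoader, TensorDataset

per_gpu_batch_size = 16                       # global_batch_size // world_size, derived in main()
dataset = TensorDataset(torch.zeros(256, 8))  # dummy data standing in for the webdataset pipeline
loader = DataLoader(dataset, batch_size=per_gpu_batch_size, shuffle=True, num_workers=0, pin_memory=True)
# len(loader) == 16 batches on this rank; one global step still covers 16 * world_size samples.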
21 changes: 18 additions & 3 deletions open_lm/main.py
@@ -277,6 +277,21 @@ def main(args):

# fully initialize distributed device environment
device = init_distributed_device(args)

assert (
args.global_batch_size % args.world_size == 0
), f"Global batch size ({args.global_batch_size}) is not divisible by number of GPUs ({args.world_size}), and thus cannot be respected."

args.per_gpu_batch_size = max(args.global_batch_size // args.world_size, 1)
if args.val_data is not None:
args.per_gpu_val_batch_size = max(args.global_val_batch_size // args.world_size, 1)

if args.hf_model is not None and args.hf_seq_len is None:
raise ValueError("If passing --hf-model, must also pass --hf-seq-len to be used for training/fine-tuning.")

if args.hf_model is not None and args.fsdp and args.hf_fsdp_block is None:
raise ValueError("If passing --hf-model and --fsdp, must also pass --hf-fspd-block.")

if args.fsdp and not args.distributed:
raise ValueError(f"--fsdp can only be specified in distributed mode.")

@@ -301,7 +316,7 @@ def main(args):
date_str,
f"model_{model_name_safe}",
f"lr_{args.lr}",
f"b_{args.batch_size}",
f"b_{args.per_gpu_batch_size}", # Per gpu to respect old naming convention
]
)

@@ -607,7 +622,7 @@ def main(args):
scheduler = None
if requires_training:
if args.dataset_manifest is not None:
total_steps = (args.train_num_samples * args.epochs) // (args.batch_size * args.world_size)
total_steps = (args.train_num_samples * args.epochs) // args.global_batch_size
else:
total_steps = (data["train"].dataloader.num_batches) * args.epochs

@@ -720,7 +735,7 @@ def main(args):

done_training = global_step >= total_steps
steps_done_epoch = global_step - prev_step
samples_seen = samples_seen + steps_done_epoch * args.batch_size * args.world_size
samples_seen = samples_seen + steps_done_epoch * args.global_batch_size

if not success:
logging.info("Training exiting due to NaN value")
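With the global value stored on args, the scheduler length and sample accounting in main.py no longer multiply by world_size by hand. A worked sketch with illustrative numbers (not taken from the repo):

# Before: total_steps = (train_num_samples * epochs) // (batch_size * world_size)
# After:  the global batch size already covers every rank.
train_num_samples, epochs, global_batch_size = 4096, 2, 64
total_steps = (train_num_samples * epochs) // global_batch_size   # 8192 // 64 == 128
steps_done_epoch = 64                                             # steps finished in some epoch
samples_seen = steps_done_epoch * global_batch_size               # 4096 samples, all ranks combined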
13 changes: 6 additions & 7 deletions open_lm/params.py
@@ -215,7 +215,7 @@ def parse_args(args):
help="Optional identifier for the experiment when storing logs. Otherwise use current time.",
)
parser.add_argument("--workers", type=int, default=1, help="Number of dataloader workers per GPU.")
parser.add_argument("--batch-size", type=int, default=64, help="Batch size per GPU.")
parser.add_argument("--global-batch-size", type=int, default=64, help="Global batch size.")
parser.add_argument("--epochs", type=int, default=32, help="Number of epochs to train for.")
parser.add_argument(
"--epochs-cooldown",
@@ -298,7 +298,7 @@ def parse_args(args):
help="How often to run evaluation with val-data (in epochs). Last epoch validated if val-data provided.",
)
parser.add_argument(
"--val-batch-size",
"--global-val-batch-size",
type=int,
default=None,
help="Batch size to be used with val-data.",
@@ -380,7 +380,7 @@ def parse_args(args):
"--accum-freq",
type=int,
default=1,
help="Update the model every --acum-freq steps.",
help="Update the model every --accum-freq steps.",
)
# arguments for distributed training
parser.add_argument(
@@ -581,9 +581,8 @@ def parse_args(args):
assert args.train_data is None, "--train-data must not be specified if --dataset-type='synthetic'"
assert args.dataset_manifest is None, "--dataset-manifest must not be specified if --dataset-type='synthetic'"

if args.val_data is not None and args.val_batch_size is None:
# if not set explicitly make sure that the val batch size is set to the micro batch size

args.val_batch_size = args.batch_size // args.accum_freq
if args.val_data is not None and args.global_val_batch_size is None:
# Make sure that val batch size is set to micro batch size
args.global_val_batch_size = args.global_batch_size // args.accum_freq

return args
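Call sites change accordingly; a hedged before/after launch example (the torchrun invocation and entry point are illustrative, not taken from the repo's docs):

# Before this commit (per-GPU semantics): 4 ranks x 16 per GPU = 64 global
#   torchrun --nproc_per_node 4 -m open_lm.main --batch-size 16 ...
# After this commit (global semantics): main() derives the per-GPU share
#   torchrun --nproc_per_node 4 -m open_lm.main --global-batch-size 64 ...
nproc_per_node, global_batch_size = 4, 64
assert global_batch_size // nproc_per_node == 16   # per-GPU share each rank sees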
6 changes: 3 additions & 3 deletions open_lm/train.py
@@ -202,8 +202,8 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler
else:
# split up batch into accum_freq chunks -- if you have --batch-size 8 and --accum-freq 4
# then you only process 2 items at a time. batch-size must be divisible by accum-freq.
assert args.batch_size % args.accum_freq == 0, "Batch size must be divisible by accum_freq"
per_batch = args.batch_size // args.accum_freq
assert args.per_gpu_batch_size % args.accum_freq == 0, "Per-GPU batch size must be divisible by accum_freq"
per_batch = args.per_gpu_batch_size // args.accum_freq

inputs, targets = sample_chunk(texts, args)

@@ -288,7 +288,7 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler
"samples_per_second": samples_per_second,
"samples_per_second_per_gpu": samples_per_second_per_gpu,
"lr": optimizer.param_groups[0]["lr"],
"tokens": (step + 1) * args.batch_size * args.seq_len * args.world_size,
"tokens": (step + 1) * args.global_batch_size * args.seq_len,
}

if args.log_logit_mean:
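Inside train_one_epoch the accumulation chunking and the token counter now start from the per-GPU and global sizes respectively. A short sketch with illustrative numbers:

per_gpu_batch_size = 8     # derived in main() from --global-batch-size
accum_freq = 4             # --accum-freq
assert per_gpu_batch_size % accum_freq == 0, "Per-GPU batch size must be divisible by accum_freq"
per_batch = per_gpu_batch_size // accum_freq        # 2 samples per forward/backward chunk

# Token accounting drops the explicit world_size factor:
step, seq_len, global_batch_size = 99, 2048, 32     # 8 per GPU x 4 ranks = 32 global
tokens = (step + 1) * global_batch_size * seq_len   # 100 * 32 * 2048 = 6,553,600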
2 changes: 1 addition & 1 deletion setup.py
@@ -19,7 +19,7 @@ def _read_reqs(relpath):

setuptools.setup(
name="open_lm",
version="0.0.22",
version="0.0.23",
author=[
"Suchin Gururangan*",
"Mitchell Wortsman*",
8 changes: 5 additions & 3 deletions tests/shared.py
@@ -24,7 +24,7 @@ def __init__(self, model, **kwargs):
"--wd", "0.033",
"--lr", "3e-3",
"--warmup", "2",
"--batch-size", "8",
"--global-batch-size", "8",
"--accum", "1",
"--name", "test_model_name",
"--logs", "./tests/assets/",
@@ -43,6 +43,7 @@ def __init__(self, model, **kwargs):
self.vocab_size = 50432
self.seq_len = 300
self.wandb = False
self.per_gpu_batch_size = self.global_batch_size // self.world_size
self.distributed = False

for k, v in kwargs.items():
@@ -64,7 +65,7 @@ def __init__(self):
self.disable_buffer = True
self.seq_len = 300
self.vocab_size = 50432
self.batch_size = 64
self.global_batch_size = 64
self.world_size = 1
self.rank = 0
self.workers = 2
@@ -73,14 +74,15 @@ def __init__(self):
self.target_mask_left = None
self.target_mask_individual = None
self.ignore_parse_errors = False
self.per_gpu_batch_size = self.global_batch_size // self.world_size


def create_train_fixtures(model="open_lm_11m", fsdp=False, **kwargs):
# Setup data, optimizer, and other basic settings
args = MockTrainArgs(model, **kwargs)

# only want to look at one batch
args.train_num_samples = args.batch_size
args.train_num_samples = args.global_batch_size

# increase learning rate and remove warmup for maximize change to model weights
args.lr = 1e-3
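Because the per-GPU value is normally derived inside main(), fixtures that build an args object directly (as the mocks above do) have to mirror that split themselves; the pattern in isolation:

class _MockArgs:                      # stand-in for the test mocks above
    pass

args = _MockArgs()
args.world_size = 1
args.global_batch_size = 8
# main() normally performs this split; code that bypasses main() must do it itself.
args.per_gpu_batch_size = args.global_batch_size // args.world_size   # 8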
6 changes: 4 additions & 2 deletions tests/test_dataset_deterministic.py
@@ -36,14 +36,15 @@ def retrieve_dataset(epoch, next_shard, weights, seed, disable_buffer, min_shard
args.train_num_samples = NUM_SAMPLES
args.train_data = train_data_string_per_source
args.workers = 2
args.batch_size = 2
args.global_batch_size = 2
args.seed = seed
args.dataset_resampled = False
args.disable_buffer = disable_buffer
args.vocab_size = _MODEL_CONFIGS[args.model]["vocab_size"]
args.seq_len = _MODEL_CONFIGS[args.model]["seq_len"]
args.world_size = 1
args.rank = 0
args.per_gpu_batch_size = 2
data = get_wds_dataset(args, is_train=True, epoch=epoch, force_num_samples=num_samples_per_source)
dl = data.dataloader

@@ -58,13 +59,14 @@ def retrieve_dataset_resampled(epoch, next_shard, weights, seed, min_shards_need
args.train_num_samples = NUM_SAMPLES
args.train_data = train_data_string_per_source
args.num_workers = 2
args.batch_size = 2
args.global_batch_size = 2
args.seed = seed
args.dataset_resampled = True
args.vocab_size = _MODEL_CONFIGS[args.model]["vocab_size"]
args.seq_len = _MODEL_CONFIGS[args.model]["seq_len"]
args.world_size = 1
args.rank = 0
args.per_gpu_batch_size = 2
data = get_wds_dataset(args, is_train=True, epoch=epoch)
dl = data.dataloader

2 changes: 1 addition & 1 deletion tests/test_dataset_no_resample.py
@@ -84,7 +84,7 @@ def retrieve_dataset_once(
args.train_num_samples = total_seqs
args.train_data = train_data_string_per_source
args.workers = num_workers
args.batch_size = batch_size
args.global_batch_size = args.per_gpu_batch_size = batch_size
args.seed = seed
args.dataset_resampled = False
args.disable_buffer = disable_buffer
4 changes: 2 additions & 2 deletions tests/test_param_parsing.py
@@ -26,14 +26,14 @@ def get_cmdline_config1():
# fmt: off
cmdline = [
"--train-num-samples", str(samples),
"--batch-size", str(batch_size),
"--global-batch-size", str(batch_size),
"--dataset-type", "synthetic",
"--model", "open_lm_test_tiny",
"--epochs", "1",
]
config_dict = {
"train-num-samples": samples,
"batch-size": batch_size,
"global-batch-size": batch_size,
"dataset-type": "synthetic",
"model": "open_lm_test_tiny",
"epochs": 1,
2 changes: 1 addition & 1 deletion tests/test_training_simple.py
@@ -8,7 +8,7 @@ def test_train_simple():
# fmt: off
main([
"--train-num-samples", str(num_batches * seq_len),
"--batch-size", str(batch_size),
"--global-batch-size", str(batch_size),
"--dataset-type", "synthetic",
"--model", "open_lm_test_tiny",
"--epochs", "1",
6 changes: 3 additions & 3 deletions tests/test_training_tokens.py
@@ -33,16 +33,16 @@ def test_token_count(test_case):

download_dl_test_data()
args, model, _, optimizer, scheduler, loss = create_train_fixtures("open_lm_11m")
args.batch_size = batch_size
args.global_batch_size = batch_size
args.per_gpu_batch_size = args.global_batch_size // args.world_size
args.workers = workers
args.train_data = None
args.dataset_manifest = SOURCE_MANIFEST
args.epochs = desired_epochs
args.train_num_samples = desired_sequences_per_epoch

total_samples = desired_sequences_per_epoch * desired_epochs
global_batch_size = args.batch_size * args.world_size
total_steps = total_samples // (global_batch_size)
total_steps = total_samples // (args.global_batch_size)
global_step = 0
next_shard_per_source = [0]
epoch = 0
