From 28ba9b56543deadf5620500108e3cba9147e7242 Mon Sep 17 00:00:00 2001
From: sagadre
Date: Mon, 11 Dec 2023 21:31:58 +0000
Subject: [PATCH] global batch size and val fix

---
 open_lm/main.py   |  7 +++++++
 open_lm/params.py |  6 +++++-
 open_lm/train.py  | 11 ++++++-----
 setup.py          |  2 +-
 4 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/open_lm/main.py b/open_lm/main.py
index b29f7ed5..b2f53536 100644
--- a/open_lm/main.py
+++ b/open_lm/main.py
@@ -243,6 +243,13 @@ def main(args):
     # fully initialize distributed device environment
     device = init_distributed_device(args)
 
+    if args.global_batch_size is not None:
+        if args.global_batch_size % args.world_size != 0:
+            raise ValueError(f"World size: {args.world_size} must divide global batch size: {args.global_batch_size}")
+
+        # set batch_size accordingly
+        args.batch_size = args.global_batch_size // args.world_size
+
     if args.hf_model is not None and args.hf_seq_len is None:
         raise ValueError("If passing --hf-model, must also pass --hf-seq-len to be used for training/fine-tuning.")
 
diff --git a/open_lm/params.py b/open_lm/params.py
index 424e4e76..91b74ce7 100644
--- a/open_lm/params.py
+++ b/open_lm/params.py
@@ -214,7 +214,8 @@ def parse_args(args):
         help="Optional identifier for the experiment when storing logs. Otherwise use current time.",
     )
     parser.add_argument("--workers", type=int, default=1, help="Number of dataloader workers per GPU.")
-    parser.add_argument("--batch-size", type=int, default=64, help="Batch size per GPU.")
+    parser.add_argument("--batch-size", type=int, default=None, help="Batch size per GPU.")
+    parser.add_argument("--global-batch-size", type=int, default=None, help="Batch size over all GPUs.")
     parser.add_argument("--epochs", type=int, default=32, help="Number of epochs to train for.")
     parser.add_argument(
         "--epochs-cooldown",
@@ -574,6 +575,9 @@ def parse_args(args):
         assert args.train_data is None, "--train-data must not be specified if --dataset-type='synthetic'"
         assert args.dataset_manifest is None, "--dataset-manifest must not be specified if --dataset-type='synthetic'"
 
+    if args.batch_size is not None and args.global_batch_size is not None:
+        assert False, "specify --batch-size or --global-batch-size but not both"
+
     if args.val_data is not None and args.val_batch_size is None:
         # if not set explicitly make sure that the val batch size is set to the micro batch size
diff --git a/open_lm/train.py b/open_lm/train.py
index b671eb16..6bb6edb1 100644
--- a/open_lm/train.py
+++ b/open_lm/train.py
@@ -45,6 +45,7 @@ def __init__(self):
 
     def reset(self):
         self.points = []
+        self.points_tensor = None
 
     def update(self, val):
         self.points.append(val)
@@ -53,13 +54,13 @@ def compute_bootstrap_ci(self, num_samples=10_000, interval=95):
         lower = None
         upper = None
 
-        points_tensor = torch.cat(self.points)
-        num_points = self.points.shape[0]
+        self.points_tensor = torch.cat(self.points)
+        num_points = self.points_tensor.shape[0]
 
         estimates = []
         for _ in range(num_samples):
             i = np.random.choice(num_points, size=num_points)
-            estimate = torch.sum(points_tensor[i]) / num_points
+            estimate = torch.sum(self.points_tensor[i]) / num_points
             estimates.append(estimate.item())
 
         half = (100 - interval) / 2
@@ -384,8 +385,8 @@ def evaluate(model, data, start_epoch, args, writer):
     lower_seq, upper_seq = losses_seq_ci_m.compute_bootstrap_ci()
     lower_tok, upper_tok = losses_tok_ci_m.compute_bootstrap_ci()
 
-    num_seqs = losses_seq_ci_m.points.shape[0]
-    num_toks = losses_tok_ci_m.points.shape[0]
+    num_seqs = losses_seq_ci_m.points_tensor.shape[0]
+    num_toks = losses_tok_ci_m.points_tensor.shape[0]
 
     # Save eval loss / etc.
     log_data = {
diff --git a/setup.py b/setup.py
index c0409630..bfe1e72a 100644
--- a/setup.py
+++ b/setup.py
@@ -19,7 +19,7 @@ def _read_reqs(relpath):
 
 setuptools.setup(
     name="open_lm",
-    version="0.0.21",
+    version="0.0.22",
     author=[
         "Suchin Gururangan*",
         "Mitchell Wortsman*",
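
Reviewer note (not part of the patch): a minimal sketch of the batch-size derivation this change introduces, mirroring the new block in open_lm/main.py. The helper name derive_per_gpu_batch_size and the example numbers are illustrative only, not part of the codebase.

    # Illustrative sketch: per-GPU batch size is the global batch size divided
    # by the number of ranks, and the world size must divide it exactly.
    def derive_per_gpu_batch_size(global_batch_size: int, world_size: int) -> int:
        if global_batch_size % world_size != 0:
            raise ValueError(f"World size: {world_size} must divide global batch size: {global_batch_size}")
        return global_batch_size // world_size

    # e.g. --global-batch-size 512 on 8 GPUs yields a per-GPU batch size of 64
    assert derive_per_gpu_batch_size(512, 8) == 64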