global batch size and val fix #149

Closed
wants to merge 1 commit into from
7 changes: 7 additions & 0 deletions open_lm/main.py
@@ -243,6 +243,13 @@ def main(args):
     # fully initialize distributed device environment
     device = init_distributed_device(args)
 
+    if args.global_batch_size is not None:
+        if args.global_batch_size % args.world_size != 0:
+            raise ValueError(f"World size: {args.world_size} must divide global batch size: {args.global_batch_size}")
+
+        # set batch_size accordingly
+        args.batch_size = args.global_batch_size // args.world_size
+
     if args.hf_model is not None and args.hf_seq_len is None:
         raise ValueError("If passing --hf-model, must also pass --hf-seq-len to be used for training/fine-tuning.")
 
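A quick standalone illustration of the new logic (not part of the diff), using hypothetical values for world size and global batch size:

# Sketch of the per-GPU batch-size derivation with assumed values.
world_size = 8
global_batch_size = 256

if global_batch_size % world_size != 0:
    raise ValueError(f"World size: {world_size} must divide global batch size: {global_batch_size}")

batch_size = global_batch_size // world_size
print(batch_size)  # 32 sequences per GPU; a global batch size of 100 would instead raise the error above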
6 changes: 5 additions & 1 deletion open_lm/params.py
@@ -214,7 +214,8 @@ def parse_args(args):
         help="Optional identifier for the experiment when storing logs. Otherwise use current time.",
     )
     parser.add_argument("--workers", type=int, default=1, help="Number of dataloader workers per GPU.")
-    parser.add_argument("--batch-size", type=int, default=64, help="Batch size per GPU.")
+    parser.add_argument("--batch-size", type=int, default=None, help="Batch size per GPU.")
+    parser.add_argument("--global-batch-size", type=int, default=None, help="Batch size over all GPUs.")
     parser.add_argument("--epochs", type=int, default=32, help="Number of epochs to train for.")
     parser.add_argument(
         "--epochs-cooldown",
@@ -574,6 +575,9 @@ def parse_args(args):
         assert args.train_data is None, "--train-data must not be specified if --dataset-type='synthetic'"
         assert args.dataset_manifest is None, "--dataset-manifest must not be specified if --dataset-type='synthetic'"
 
+    if args.batch_size is not None and args.global_batch_size is not None:
+        assert False, "specify --batch-size or --global-batch-size but not both"
+
     if args.val_data is not None and args.val_batch_size is None:
         # if not set explicitly make sure that the val batch size is set to the micro batch size
 
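A self-contained sketch (not the real open_lm parser) of how the two flags interact under the new defaults; the flag names, defaults, and message mirror the diff, everything else is assumed:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, default=None, help="Batch size per GPU.")
parser.add_argument("--global-batch-size", type=int, default=None, help="Batch size over all GPUs.")

args = parser.parse_args(["--global-batch-size", "256"])

# Mirrors the new check: the two flags are mutually exclusive.
assert args.batch_size is None or args.global_batch_size is None, \
    "specify --batch-size or --global-batch-size but not both"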
11 changes: 6 additions & 5 deletions open_lm/train.py
@@ -45,6 +45,7 @@ def __init__(self):
 
     def reset(self):
         self.points = []
+        self.points_tensor = None
 
     def update(self, val):
         self.points.append(val)
@@ -53,13 +54,13 @@ def compute_bootstrap_ci(self, num_samples=10_000, interval=95):
         lower = None
         upper = None
 
-        points_tensor = torch.cat(self.points)
-        num_points = self.points.shape[0]
+        self.points_tensor = torch.cat(self.points)
+        num_points = self.points_tensor.shape[0]
 
         estimates = []
         for _ in range(num_samples):
             i = np.random.choice(num_points, size=num_points)
-            estimate = torch.sum(points_tensor[i]) / num_points
+            estimate = torch.sum(self.points_tensor[i]) / num_points
             estimates.append(estimate.item())
 
         half = (100 - interval) / 2
@@ -384,8 +385,8 @@ def evaluate(model, data, start_epoch, args, writer):
 
     lower_seq, upper_seq = losses_seq_ci_m.compute_bootstrap_ci()
     lower_tok, upper_tok = losses_tok_ci_m.compute_bootstrap_ci()
-    num_seqs = losses_seq_ci_m.points.shape[0]
-    num_toks = losses_tok_ci_m.points.shape[0]
+    num_seqs = losses_seq_ci_m.points_tensor.shape[0]
+    num_toks = losses_tok_ci_m.points_tensor.shape[0]
 
     # Save eval loss / etc.
     log_data = {
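For context, a condensed, runnable sketch of the fixed meter. The methods, signatures, and the points / points_tensor attributes come from the diff; the class name, comments, and the final percentile step are assumptions about the surrounding code:

# Condensed sketch of the fixed meter; class name and percentile return are assumptions.
import numpy as np
import torch

class BootstrapCIMeter:
    def __init__(self):
        self.reset()

    def reset(self):
        self.points = []           # one tensor of per-example losses appended per eval batch
        self.points_tensor = None  # concatenated copy, populated by compute_bootstrap_ci

    def update(self, val):
        self.points.append(val)

    def compute_bootstrap_ci(self, num_samples=10_000, interval=95):
        # Keep the concatenated tensor on the object so callers can read .points_tensor.shape[0];
        # self.points is a Python list and has no .shape, which is the bug being fixed.
        self.points_tensor = torch.cat(self.points)
        num_points = self.points_tensor.shape[0]
        estimates = []
        for _ in range(num_samples):
            i = np.random.choice(num_points, size=num_points)
            estimates.append((torch.sum(self.points_tensor[i]) / num_points).item())
        half = (100 - interval) / 2
        return np.percentile(estimates, half), np.percentile(estimates, 100 - half)

meter = BootstrapCIMeter()
meter.update(torch.randn(512))
lower, upper = meter.compute_bootstrap_ci(num_samples=1_000)
num_seqs = meter.points_tensor.shape[0]  # 512, matching how evaluate() now reads the count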
2 changes: 1 addition & 1 deletion setup.py
@@ -19,7 +19,7 @@ def _read_reqs(relpath):
 
 setuptools.setup(
     name="open_lm",
-    version="0.0.21",
+    version="0.0.22",
     author=[
         "Suchin Gururangan*",
         "Mitchell Wortsman*",