From 73e7b3709057d62408d700dcaad4ebc6317c56b9 Mon Sep 17 00:00:00 2001
From: Samir Yitzhak Gadre
Date: Mon, 11 Dec 2023 13:45:15 -0500
Subject: [PATCH] revamped val (bootstrap ci to support seq and tok intervals,
 many perplexity evals) (#142)

* modified val bootstrap ci to support seq and tok intervals

* support eval on many val sets independently

* version bump

* minor

* embed all eval metrics

* reduce num calls to cat

* version bump

* nits

---
 open_lm/data.py   |  20 ++++++---
 open_lm/main.py   |  44 +++++++-------------
 open_lm/params.py |  33 ++++++---------
 open_lm/train.py  | 103 ++++++++++++++++++++++++++++++++--------------
 setup.py          |   2 +-
 5 files changed, 117 insertions(+), 85 deletions(-)

diff --git a/open_lm/data.py b/open_lm/data.py
index 725703c6..1a67731d 100644
--- a/open_lm/data.py
+++ b/open_lm/data.py
@@ -12,6 +12,7 @@
 from multiprocessing import Value
 from functools import partial
 from itertools import islice
+import copy
 
 import numpy as np
 import pandas as pd
@@ -557,7 +558,7 @@ def get_synthetic_dataset(args, is_train, epoch, tokenizer, data_key, floor):
     return DataInfo(dataloader, sampler)
 
 
-def get_dataset_fn(data_path, dataset_type):
+def get_dataset_fn(dataset_type):
     if dataset_type == "synthetic":
         return get_synthetic_dataset
     else:
@@ -571,13 +572,22 @@ def get_data(args, epoch=0, tokenizer=None, skip_train=False, floor=True):
         data["train"] = None
     else:
         if args.train_data or args.dataset_type == "synthetic":
-            data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
+            # train data is treated as a single shard list; all data is combined and trained on
+            data["train"] = get_dataset_fn(args.dataset_type)(
                 args, is_train=True, epoch=epoch, tokenizer=tokenizer, data_key=args.data_key, floor=floor
             )
     if args.val_data:
-        data["val"] = get_dataset_fn(args.val_data, args.dataset_type)(
-            args, is_train=False, tokenizer=tokenizer, data_key=args.val_data_key
-        )
+        # val data is treated as independent val sets, each evaluated separately
+        data["val_list"] = []
+        for i, val_data in enumerate(args.val_data):
+            args_copy = copy.deepcopy(args)
+            args_copy.val_data = [val_data]
+            data_val = {
+                "val": get_dataset_fn(args.dataset_type)(
+                    args_copy, is_train=False, tokenizer=tokenizer, data_key=args.val_data_key[i]
+                )
+            }
+            data["val_list"].append(data_val)
     return data
diff --git a/open_lm/main.py b/open_lm/main.py
index 898c9c53..b29f7ed5 100644
--- a/open_lm/main.py
+++ b/open_lm/main.py
@@ -50,7 +50,7 @@
 from open_lm.logger import setup_logging
 from open_lm.params import parse_args
 from open_lm.scheduler import cosine_lr
-from open_lm.train import train_one_epoch, evaluate
+from open_lm.train import train_one_epoch, evaluate_loop
 from open_lm.file_utils import (
     pt_load,
     check_exists,
@@ -171,22 +171,11 @@ def save_checkpoint(
         optim_state = FSDP.optim_state_dict(model, optimizer)
 
     if args.save_logs:
-        loss_dict = {
-            "evaluation_loss": -1,
-            "evaluation_loss_lower_95": -1,
-            "evaluation_loss_upper_95": -1,
-        }
-
-        if evaluation_metrics is not None:
-            loss_dict["evaluation_loss"] = evaluation_metrics["loss"]
-            loss_dict["evaluation_loss_lower_95"] = evaluation_metrics["loss_lower_95"]
-            loss_dict["evaluation_loss_upper_95"] = evaluation_metrics["loss_upper_95"]
-
         checkpoint_dict_model = {
             "epoch": completed_epoch,
             "name": args.name,
             "state_dict": cpu_state if args.fsdp else model.state_dict(),
-            **loss_dict,
+            "evaluation_metrics": evaluation_metrics,
         }
         if next_shard_per_source is not None:
             checkpoint_dict_model["next_shard_per_source"] = next_shard_per_source
@@ -201,7 +190,7 @@ def save_checkpoint(
             "epoch": completed_epoch,
             "name": args.name,
             "optimizer": optim_state if args.fsdp else optimizer.state_dict(),
-            **loss_dict,
+            "evaluation_metrics": evaluation_metrics,
         }
 
         if scaler is not None:
@@ -211,7 +200,7 @@ def save_checkpoint(
             "epoch": completed_epoch,
             "name": args.name,
             "is_final_checkpoint": is_final_checkpoint,
-            **loss_dict,
+            "evaluation_metrics": evaluation_metrics,
         }
 
         prefixes = {
@@ -569,8 +558,7 @@ def main(args):
 
     # initialize datasets
    # use tokenizer=None because the data is already pre-tokenized.
-    if args.val_data is not None:
-        args.val_data = [args.val_data]
+
     data = get_data(
         args,
         epoch=start_epoch,
@@ -627,8 +615,7 @@ def main(args):
     if args.wandb and is_master(args):
         assert wandb is not None, "Please install wandb."
         logging.debug("Starting wandb.")
-        if args.val_data is not None:
-            args.val_sz = data["val"].dataloader.num_samples
+
         wandb.init(
             project=args.wandb_project_name,
             name=args.name,
@@ -645,15 +632,11 @@ def main(args):
 
     if not requires_training:
         checkpoint_root = os.path.dirname(args.resume)
-        metrics = evaluate(model, data, start_epoch, args, writer)
-        metrics["checkpoint_path"] = args.resume
-        metrics["val_data"] = args.val_data
-        metrics["model"] = args.hf_model if args.hf_model else args.model
+        metrics = evaluate_loop(model, data["val_list"], start_epoch, args, writer)
 
         if is_master(args):
             with fsspec.open(os.path.join(checkpoint_root, "results.jsonl"), "a") as f:
-                f.write(json.dumps(metrics))
-                f.write("\n")
+                f.write(f"{json.dumps(metrics)}\n")
 
         return
@@ -730,11 +713,16 @@ def main(args):
                 break
 
         epoch = epoch + 1
-        evaluation_metrics = None
-        if "val" in data and (epoch % args.val_frequency == 0 or done_training):
+        evaluation_metrics = []
+        if "val_list" in data and (epoch % args.val_frequency == 0 or done_training):
             # validate based on frequency and always validate the last checkpoint
             try:
-                evaluation_metrics = evaluate(model, data, epoch, args, writer)
+                evaluation_metrics = evaluate_loop(model, data["val_list"], epoch, args, writer)
+
+                if is_master(args):
+                    with fsspec.open(os.path.join(args.checkpoint_path, "results.jsonl"), "a") as f:
+                        f.write(f"{json.dumps(evaluation_metrics)}\n")
+
             except Exception as e:
                 if is_master(args):
                     logging.error(e)
diff --git a/open_lm/params.py b/open_lm/params.py
index 78cde04f..424e4e76 100644
--- a/open_lm/params.py
+++ b/open_lm/params.py
@@ -6,14 +6,6 @@
 import yaml
 
 
-def get_default_params(model_name):
-    model_name = model_name.lower()
-    if "vit" in model_name:
-        return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6}
-    else:
-        return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8}
-
-
 class ParseKwargs(argparse.Action):
     def __call__(self, parser, namespace, values, option_string=None):
         kw = {}
@@ -153,8 +145,12 @@ def parse_args(args):
     parser.add_argument(
         "--val-data",
         type=str,
+        nargs="+",
         default=None,
-        help="Path to file(s) with validation data",
+        help=(
+            "Path to file(s) with validation data. Note: each space-separated entry is processed separately and written as a separate entry "
+            "in a results.jsonl file."
+        ),
     )
     parser.add_argument(
         "--data-key",
@@ -227,10 +223,10 @@ def parse_args(args):
         help="When scheduler w/ cooldown used, perform cooldown from total_epochs - cooldown_epochs onwards.",
     )
     parser.add_argument("--optimizer", default="adamw", help="Optimizer.")
-    parser.add_argument("--lr", type=float, default=None, help="Learning rate.")
-    parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.")
-    parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.")
-    parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.")
+    parser.add_argument("--lr", type=float, default=5.0e-4, help="Learning rate.")
+    parser.add_argument("--beta1", type=float, default=0.9, help="Adam beta 1.")
+    parser.add_argument("--beta2", type=float, default=0.95, help="Adam beta 2.")
+    parser.add_argument("--eps", type=float, default=1.0e-8, help="Adam epsilon.")
     parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.")
     parser.add_argument("--warmup", type=int, default=10000, help="Number of steps to warmup for.")
     parser.add_argument(
@@ -309,8 +305,9 @@ def parse_args(args):
     parser.add_argument(
         "--val-data-key",
         type=str,
-        default="txt",
-        help="what is the extension for val-data.",
+        nargs="+",
+        default=None,
+        help="what is the extension for each val-data source.",
     )
     parser.add_argument(
         "--resume",
@@ -573,12 +570,6 @@ def parse_args(args):
     else:
         args = parser.parse_args(args)
 
-    # If some params are not passed, we use the default values based on model name.
-    default_params = get_default_params(args.model)
-    for name, val in default_params.items():
-        if getattr(args, name) is None:
-            setattr(args, name, val)
-
     if args.dataset_type == "synthetic":
         assert args.train_data is None, "--train-data must not be specified if --dataset-type='synthetic'"
         assert args.dataset_manifest is None, "--dataset-manifest must not be specified if --dataset-type='synthetic'"
diff --git a/open_lm/train.py b/open_lm/train.py
index 2a0d178e..b671eb16 100644
--- a/open_lm/train.py
+++ b/open_lm/train.py
@@ -1,18 +1,14 @@
-import ast
 import itertools
-import json
 import logging
 import math
-import os
 import time
 from contextlib import nullcontext
+import copy
 
 import numpy as np
 import torch
 import torch.distributed as dist
-import torch.nn.functional as F
 from torch.distributed.distributed_c10d import ReduceOp
-from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
 try:
@@ -48,29 +44,23 @@ def __init__(self):
         self.reset()
 
     def reset(self):
-        self.weights = []
-        self.sums = []
+        self.points = []
 
-    def update(self, val, n=1):
-        self.weights.append(n)
-        self.sums.append(val * n)
+    def update(self, val):
+        self.points.append(val)
 
-    def compute_bootstrap_ci(self, num_samples=2000, interval=95):
+    def compute_bootstrap_ci(self, num_samples=10_000, interval=95):
         lower = None
         upper = None
 
+        points_tensor = torch.cat(self.points)
+        num_points = points_tensor.shape[0]
+
         estimates = []
         for _ in range(num_samples):
-            acc = 0
-            denom = 0
-            i = np.random.choice(len(self.sums), size=len(self.sums)).tolist()
-            for ii in i:
-                acc += self.sums[ii]
-                denom += self.weights[ii]
-
-            acc /= denom
-
-            estimates.append(acc)
+            i = np.random.choice(num_points, size=num_points)
+            estimate = torch.sum(points_tensor[i]) / num_points
+            estimates.append(estimate.item())
 
         half = (100 - interval) / 2
 
@@ -341,16 +331,23 @@ def evaluate(model, data, start_epoch, args, writer):
     data["val"].set_epoch(start_epoch)  # set epoch in process safe manner via sampler or shared_epoch
     dataloader = data["val"].dataloader
 
+    # NOTE: max_num_batches = 0 corresponds to exhausting iterator
+    max_num_batches = dataloader.num_batches
+
     losses_m = AverageMeter()
     batch_time_m = AverageMeter()
     data_time_m = AverageMeter()
     sps_m = AverageMeter()
     spspg_m = AverageMeter()
-    losses_ci_m = ConfidenceIntervalMeter()
+    losses_seq_ci_m = ConfidenceIntervalMeter()
+    losses_tok_ci_m = ConfidenceIntervalMeter()
 
     end = time.time()
-    loss = torch.nn.CrossEntropyLoss()
+    loss = torch.nn.CrossEntropyLoss(reduction="none")
     for i, batch in enumerate(dataloader):
+        if i == max_num_batches and max_num_batches != 0:
+            break
+
         (texts,) = batch
         texts = torch.LongTensor(texts).to(device)
 
@@ -359,15 +356,37 @@ def evaluate(model, data, start_epoch, args, writer):
 
         with autocast():
             inputs, targets = sample_chunk(texts, args)
-            out, _ = model(inputs)
 
-            total_loss = loss(out.reshape(-1, args.vocab_size), targets.reshape(-1))
+            out, _ = model(inputs)  # [bs, seq_len, vocab_size]
+
+            bs, seq_len = targets.shape
+
+            targets = targets.reshape(-1)
+            total_loss = loss(out.reshape(-1, args.vocab_size), targets)  # [bs * seq_len]
+
+            # cross entropy ignores -100 values in loss computation
+            mask = targets != -100
+
+            # reshape and average for sequence losses
+            sum_loss_per_seq = torch.sum(total_loss.reshape(bs, seq_len), -1)
+            num_toks_per_seq = torch.sum(mask.reshape(bs, seq_len), -1).float()
+            losses_seq_ci_m.update(sum_loss_per_seq / num_toks_per_seq)
+
+            # individual token losses
+            losses_tok_ci_m.update(total_loss[mask])
+
+            # compute average loss for the mini-batch
+            total_loss = total_loss[mask].mean()
 
         losses_m.update(total_loss.item(), n=inputs.shape[0])
-        losses_ci_m.update(total_loss.item(), n=inputs.shape[0])
+
         batch_time_m.update(time.time() - end)
         sps_m.update(inputs.numel() * args.world_size / batch_time_m.val)
         spspg_m.update(inputs.numel() / batch_time_m.val)
 
-    lower, upper = losses_ci_m.compute_bootstrap_ci()
+    lower_seq, upper_seq = losses_seq_ci_m.compute_bootstrap_ci()
+    lower_tok, upper_tok = losses_tok_ci_m.compute_bootstrap_ci()
+    num_seqs = torch.cat(losses_seq_ci_m.points).shape[0]
+    num_toks = torch.cat(losses_tok_ci_m.points).shape[0]
+
     # Save eval loss / etc.
     log_data = {
         "loss": losses_m.avg,
@@ -375,11 +394,15 @@
         "data_time": data_time_m.avg,
         "batch_time": batch_time_m.avg,
         "samples_per_second": sps_m.avg,
         "samples_per_second_per_gpu": spspg_m.avg,
-        "loss_lower_95": lower,
-        "loss_upper_95": upper,
+        "loss_sequences_lower_95": lower_seq,
+        "loss_sequences_upper_95": upper_seq,
+        "loss_tokens_lower_95": lower_tok,
+        "loss_tokens_upper_95": upper_tok,
+        "sequences": num_seqs,
+        "tokens": num_toks,
     }
     if args.train_num_samples is not None:
-        log_data["tokens"] = start_epoch * args.train_num_samples * args.seq_len
+        log_data["train_tokens"] = start_epoch * args.train_num_samples * args.seq_len
 
     for name, val in log_data.items():
         name = "valid/" + name
@@ -389,6 +412,26 @@
             assert wandb is not None, "Please install wandb."
             wandb.log({name: val, "epoch": start_epoch, "tokens": log_data["tokens"]})
     if is_master(args):
+        print(f"evaluation on: {args.val_data}")
         print(f"evaluation loss: {losses_m.avg}")
         print(f"evaluation perplexity: {math.exp(losses_m.avg)}")
+        print(f"num seqs: {num_seqs}")
+        print(f"num tokens: {num_toks}")
+
+    log_data["checkpoint_path"] = args.resume
+    log_data["val_data"] = args.val_data
+    log_data["model"] = args.hf_model if args.hf_model else args.model
+
     return log_data
+
+
+def evaluate_loop(model, data_list, start_epoch, args, writer):
+    log_data_list = []
+    for i, data in enumerate(data_list):
+        args_copy = copy.deepcopy(args)
+        args_copy.val_data = [args.val_data[i]]
+        args_copy.val_data_key = args.val_data_key[i]
+
+        log_data_list.append(evaluate(model, data, start_epoch, args_copy, writer))
+
+    return log_data_list
diff --git a/setup.py b/setup.py
index 7519f6a3..c0409630 100644
--- a/setup.py
+++ b/setup.py
@@ -19,7 +19,7 @@ def _read_reqs(relpath):
 
 setuptools.setup(
     name="open_lm",
-    version="0.0.20",
+    version="0.0.21",
    author=[
        "Suchin Gururangan*",
        "Mitchell Wortsman*",
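
---

Illustration (not part of the diff above): the new ConfidenceIntervalMeter pools per-sequence or per-token losses and reports a percentile bootstrap over their mean. A minimal standalone sketch of that estimator, assuming losses are collected as 1-D tensors the way update() appends them:

import numpy as np
import torch

def bootstrap_ci(points, num_samples=10_000, interval=95):
    # points is a list of 1-D loss tensors, as appended by update()
    points_tensor = torch.cat(points)
    num_points = points_tensor.shape[0]
    estimates = []
    for _ in range(num_samples):
        # resample the pooled losses with replacement, record the resample mean
        idx = np.random.choice(num_points, size=num_points)
        estimates.append((torch.sum(points_tensor[idx]) / num_points).item())
    # central interval: for interval=95, the 2.5th and 97.5th percentiles
    half = (100 - interval) / 2
    return np.percentile(estimates, half), np.percentile(estimates, 100 - half)

# token-level interval from two fake batches of per-token losses
lower, upper = bootstrap_ci([torch.rand(512), torch.rand(512)])

The sequence-level interval uses the same estimator applied to the per-sequence mean losses (sum_loss_per_seq / num_toks_per_seq) rather than to the raw token losses.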
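Reading the output (a sketch; the file path and numbers are hypothetical, the keys match the log_data dict built in evaluate()): each line of results.jsonl holds the list returned by evaluate_loop(), one dict per --val-data entry, and since exp() is monotonic the loss bounds translate directly into a perplexity interval:

import json
import math

with open("logs/my_run/checkpoints/results.jsonl") as f:  # hypothetical path
    for line in f:
        for m in json.loads(line):  # one dict per --val-data entry
            lo, hi = m["loss_tokens_lower_95"], m["loss_tokens_upper_95"]
            print(
                f'{m["val_data"]}: perplexity {math.exp(m["loss"]):.2f} '
                f'(95% CI [{math.exp(lo):.2f}, {math.exp(hi):.2f}])'
            )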