Skip to content

Commit

Permalink
revamped val (bootstrap ci to support seq and tok intervals, many per…
Browse files Browse the repository at this point in the history
…plexity evals) (#142)

* modified  val bootstrap ci to support seq and tok intervals

* support eval on many val sets independently

* version bump

* minor

* embed all eval metrics

* reduce num calls to cat

* version bump

* nits
  • Loading branch information
sagadre authored Dec 11, 2023
1 parent 158124d commit 73e7b37
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 85 deletions.
20 changes: 15 additions & 5 deletions open_lm/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from multiprocessing import Value
from functools import partial
from itertools import islice
import copy

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -557,7 +558,7 @@ def get_synthetic_dataset(args, is_train, epoch, tokenizer, data_key, floor):
return DataInfo(dataloader, sampler)


def get_dataset_fn(data_path, dataset_type):
def get_dataset_fn(dataset_type):
if dataset_type == "synthetic":
return get_synthetic_dataset
else:
Expand All @@ -571,13 +572,22 @@ def get_data(args, epoch=0, tokenizer=None, skip_train=False, floor=True):
data["train"] = None
else:
if args.train_data or args.dataset_type == "synthetic":
data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
# train data is treated as a shard list where all data is combined and trained on
data["train"] = get_dataset_fn(args.dataset_type)(
args, is_train=True, epoch=epoch, tokenizer=tokenizer, data_key=args.data_key, floor=floor
)

if args.val_data:
data["val"] = get_dataset_fn(args.val_data, args.dataset_type)(
args, is_train=False, tokenizer=tokenizer, data_key=args.val_data_key
)
# val data is treated as independent val sets to be evaluated
data["val_list"] = []
for i, val_data in enumerate(args.val_data):
args_copy = copy.deepcopy(args)
args_copy.val_data = [val_data]
data_val = {
"val": get_dataset_fn(args.dataset_type)(
args_copy, is_train=False, tokenizer=tokenizer, data_key=args.val_data_key[i]
)
}
data["val_list"].append(data_val)

return data
44 changes: 16 additions & 28 deletions open_lm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from open_lm.logger import setup_logging
from open_lm.params import parse_args
from open_lm.scheduler import cosine_lr
from open_lm.train import train_one_epoch, evaluate
from open_lm.train import train_one_epoch, evaluate_loop
from open_lm.file_utils import (
pt_load,
check_exists,
Expand Down Expand Up @@ -171,22 +171,11 @@ def save_checkpoint(
optim_state = FSDP.optim_state_dict(model, optimizer)

if args.save_logs:
loss_dict = {
"evaluation_loss": -1,
"evaluation_loss_lower_95": -1,
"evaluation_loss_upper_95": -1,
}

if evaluation_metrics is not None:
loss_dict["evaluation_loss"] = evaluation_metrics["loss"]
loss_dict["evaluation_loss_lower_95"] = evaluation_metrics["loss_lower_95"]
loss_dict["evaluation_loss_upper_95"] = evaluation_metrics["loss_upper_95"]

checkpoint_dict_model = {
"epoch": completed_epoch,
"name": args.name,
"state_dict": cpu_state if args.fsdp else model.state_dict(),
**loss_dict,
"evaluation_metrics": evaluation_metrics,
}
if next_shard_per_source is not None:
checkpoint_dict_model["next_shard_per_source"] = next_shard_per_source
Expand All @@ -201,7 +190,7 @@ def save_checkpoint(
"epoch": completed_epoch,
"name": args.name,
"optimizer": optim_state if args.fsdp else optimizer.state_dict(),
**loss_dict,
"evaluation_metrics": evaluation_metrics,
}

if scaler is not None:
Expand All @@ -211,7 +200,7 @@ def save_checkpoint(
"epoch": completed_epoch,
"name": args.name,
"is_final_checkpoint": is_final_checkpoint,
**loss_dict,
"evaluation_metrics": evaluation_metrics,
}

prefixes = {
Expand Down Expand Up @@ -569,8 +558,7 @@ def main(args):

# initialize datasets
# use tokenizer=None because the data is already pre-tokenized.
if args.val_data is not None:
args.val_data = [args.val_data]

data = get_data(
args,
epoch=start_epoch,
Expand Down Expand Up @@ -627,8 +615,7 @@ def main(args):
if args.wandb and is_master(args):
assert wandb is not None, "Please install wandb."
logging.debug("Starting wandb.")
if args.val_data is not None:
args.val_sz = data["val"].dataloader.num_samples

wandb.init(
project=args.wandb_project_name,
name=args.name,
Expand All @@ -645,15 +632,11 @@ def main(args):
if not requires_training:
checkpoint_root = os.path.dirname(args.resume)

metrics = evaluate(model, data, start_epoch, args, writer)
metrics["checkpoint_path"] = args.resume
metrics["val_data"] = args.val_data
metrics["model"] = args.hf_model if args.hf_model else args.model
metrics = evaluate_loop(model, data["val_list"], start_epoch, args, writer)

if is_master(args):
with fsspec.open(os.path.join(checkpoint_root, "results.jsonl"), "a") as f:
f.write(json.dumps(metrics))
f.write("\n")
f.write(f"{json.dumps(metrics)}\n")

return

Expand Down Expand Up @@ -730,11 +713,16 @@ def main(args):
break

epoch = epoch + 1
evaluation_metrics = None
if "val" in data and (epoch % args.val_frequency == 0 or done_training):
evaluation_metrics = []
if "val_list" in data and (epoch % args.val_frequency == 0 or done_training):
# validate based on frequency and always validate the last checkpoint
try:
evaluation_metrics = evaluate(model, data, epoch, args, writer)
evaluation_metrics = evaluate_loop(model, data["val_list"], epoch, args, writer)

if is_master(args):
with fsspec.open(os.path.join(args.checkpoint_path, "results.jsonl"), "a") as f:
f.write(f"{json.dumps(evaluation_metrics)}\n")

except Exception as e:
if is_master(args):
logging.error(e)
Expand Down
33 changes: 12 additions & 21 deletions open_lm/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,6 @@
import yaml


def get_default_params(model_name):
    """Return default optimizer hyperparameters for the given model family.

    ViT-style models get a tighter epsilon and lower beta2; every other
    model falls back to the standard AdamW-style defaults.
    """
    lowered = model_name.lower()
    if "vit" in lowered:
        return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6}
    return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8}


class ParseKwargs(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
kw = {}
Expand Down Expand Up @@ -153,8 +145,12 @@ def parse_args(args):
parser.add_argument(
"--val-data",
type=str,
nargs="+",
default=None,
help="Path to file(s) with validation data",
help=(
    "Path to file(s) with validation data. Note: each space-separated entry will be processed separately and written as separate entries "
    "in a results.jsonl file."
),
)
parser.add_argument(
"--data-key",
Expand Down Expand Up @@ -227,10 +223,10 @@ def parse_args(args):
help="When scheduler w/ cooldown used, perform cooldown from total_epochs - cooldown_epochs onwards.",
)
parser.add_argument("--optimizer", default="adamw", help="Optimizer.")
parser.add_argument("--lr", type=float, default=None, help="Learning rate.")
parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.")
parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.")
parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.")
parser.add_argument("--lr", type=float, default=5.0e-4, help="Learning rate.")
parser.add_argument("--beta1", type=float, default=0.9, help="Adam beta 1.")
parser.add_argument("--beta2", type=float, default=0.95, help="Adam beta 2.")
parser.add_argument("--eps", type=float, default=1.0e-8, help="Adam epsilon.")
parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.")
parser.add_argument("--warmup", type=int, default=10000, help="Number of steps to warmup for.")
parser.add_argument(
Expand Down Expand Up @@ -309,8 +305,9 @@ def parse_args(args):
parser.add_argument(
"--val-data-key",
type=str,
default="txt",
help="what is the extension for val-data.",
nargs="+",
default=None,
help="what is the extension for each val-data source.",
)
parser.add_argument(
"--resume",
Expand Down Expand Up @@ -573,12 +570,6 @@ def parse_args(args):
else:
args = parser.parse_args(args)

# If some params are not passed, we use the default values based on model name.
default_params = get_default_params(args.model)
for name, val in default_params.items():
if getattr(args, name) is None:
setattr(args, name, val)

if args.dataset_type == "synthetic":
assert args.train_data is None, "--train-data must not be specified if --dataset-type='synthetic'"
assert args.dataset_manifest is None, "--dataset-manifest must not be specified if --dataset-type='synthetic'"
Expand Down
103 changes: 73 additions & 30 deletions open_lm/train.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
import ast
import itertools
import json
import logging
import math
import os
import time
from contextlib import nullcontext
import copy

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.distributed_c10d import ReduceOp
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

try:
Expand Down Expand Up @@ -48,29 +44,23 @@ def __init__(self):
self.reset()

def reset(self):
self.weights = []
self.sums = []
self.points = []

def update(self, val, n=1):
self.weights.append(n)
self.sums.append(val * n)
def update(self, val):
    # Accumulate a tensor of per-sample loss values; compute_bootstrap_ci()
    # later concatenates all stored tensors via torch.cat to resample from them.
    self.points.append(val)

def compute_bootstrap_ci(self, num_samples=2000, interval=95):
def compute_bootstrap_ci(self, num_samples=10_000, interval=95):
lower = None
upper = None

points_tensor = torch.cat(self.points)
num_points = self.points.shape[0]

estimates = []
for _ in range(num_samples):
acc = 0
denom = 0
i = np.random.choice(len(self.sums), size=len(self.sums)).tolist()
for ii in i:
acc += self.sums[ii]
denom += self.weights[ii]

acc /= denom

estimates.append(acc)
i = np.random.choice(num_points, size=num_points)
estimate = torch.sum(points_tensor[i]) / num_points
estimates.append(estimate.item())

half = (100 - interval) / 2

Expand Down Expand Up @@ -341,16 +331,23 @@ def evaluate(model, data, start_epoch, args, writer):
data["val"].set_epoch(start_epoch) # set epoch in process safe manner via sampler or shared_epoch
dataloader = data["val"].dataloader

# NOTE: max_num_batches = 0 corresponds to exhausting iterator
max_num_batches = dataloader.num_batches

losses_m = AverageMeter()
batch_time_m = AverageMeter()
data_time_m = AverageMeter()
sps_m = AverageMeter()
spspg_m = AverageMeter()
losses_ci_m = ConfidenceIntervalMeter()
losses_seq_ci_m = ConfidenceIntervalMeter()
losses_tok_ci_m = ConfidenceIntervalMeter()

end = time.time()
loss = torch.nn.CrossEntropyLoss()
loss = torch.nn.CrossEntropyLoss(reduction="none")
for i, batch in enumerate(dataloader):
if i == max_num_batches and max_num_batches != 0:
break

(texts,) = batch
texts = torch.LongTensor(texts).to(device)

Expand All @@ -359,27 +356,53 @@ def evaluate(model, data, start_epoch, args, writer):
with autocast():
inputs, targets = sample_chunk(texts, args)

out, _ = model(inputs)
total_loss = loss(out.reshape(-1, args.vocab_size), targets.reshape(-1))
out, _ = model(inputs) # [bs, seq_len, vocab_size]

bs, seq_len = targets.shape

targets = targets.reshape(-1)
total_loss = loss(out.reshape(-1, args.vocab_size), targets) # [bs * seq_len]

# cross entropy ignores -100 values in loss computation
mask = targets != -100

# reshape and average for sequence losses
sum_loss_per_seq = torch.sum(total_loss.reshape(bs, seq_len), -1)
num_toks_per_seq = torch.sum(mask.reshape(bs, seq_len), -1).float()
losses_seq_ci_m.update(sum_loss_per_seq / num_toks_per_seq)

# individual token losses
losses_tok_ci_m.update(total_loss[mask])

# compute average loss for the mini-batch
total_loss = total_loss[mask].mean()
losses_m.update(total_loss.item(), n=inputs.shape[0])
losses_ci_m.update(total_loss.item(), n=inputs.shape[0])

batch_time_m.update(time.time() - end)
sps_m.update(inputs.numel() * args.world_size / batch_time_m.val)
spspg_m.update(inputs.numel() / batch_time_m.val)

lower, upper = losses_ci_m.compute_bootstrap_ci()
lower_seq, upper_seq = losses_seq_ci_m.compute_bootstrap_ci()
lower_tok, upper_tok = losses_tok_ci_m.compute_bootstrap_ci()
num_seqs = losses_seq_ci_m.points.shape[0]
num_toks = losses_tok_ci_m.points.shape[0]

# Save eval loss / etc.
log_data = {
"loss": losses_m.avg,
"data_time": data_time_m.avg,
"batch_time": batch_time_m.avg,
"samples_per_second": sps_m.avg,
"samples_per_second_per_gpu": spspg_m.avg,
"loss_lower_95": lower,
"loss_upper_95": upper,
"loss_sequences_lower_95": lower_seq,
"loss_sequences_upper_95": upper_seq,
"loss_tokens_lower_95": lower_tok,
"loss_tokens_upper_95": upper_tok,
"sequences": num_seqs,
"tokens": num_toks,
}
if args.train_num_samples is not None:
log_data["tokens"] = start_epoch * args.train_num_samples * args.seq_len
log_data["train_tokens"] = start_epoch * args.train_num_samples * args.seq_len

for name, val in log_data.items():
name = "valid/" + name
Expand All @@ -389,6 +412,26 @@ def evaluate(model, data, start_epoch, args, writer):
assert wandb is not None, "Please install wandb."
wandb.log({name: val, "epoch": start_epoch, "tokens": log_data["tokens"]})
if is_master(args):
print(f"evaluation on: {args.val_data}")
print(f"evaluation loss: {losses_m.avg}")
print(f"evaluation perplexity: {math.exp(losses_m.avg)}")
print(f"num seqs: {num_seqs}")
print(f"num tokens: {num_toks}")

log_data["checkpoint_path"] = args.resume
log_data["val_data"] = args.val_data
log_data["model"] = args.hf_model if args.hf_model else args.model

return log_data


def evaluate_loop(model, data_list, start_epoch, args, writer):
    """Evaluate the model independently on each validation set.

    Args:
        model: the model to evaluate.
        data_list: list of dicts, each holding one validation DataInfo under "val";
            assumed to align 1:1 with args.val_data.
        start_epoch: epoch/checkpoint index being evaluated (used for logging).
        args: parsed CLI arguments.
        writer: tensorboard SummaryWriter (or None).

    Returns:
        List of per-val-set metric dicts, in the same order as data_list.
    """
    log_data_list = []
    for i, data in enumerate(data_list):
        # Give evaluate() a view of args scoped to this single val set so its
        # logged metadata (val_data, val_data_key) refers to the right source.
        args_copy = copy.deepcopy(args)
        args_copy.val_data = [args.val_data[i]]
        # --val-data-key defaults to None; guard so omitting it does not raise
        # a TypeError on subscripting None.
        args_copy.val_data_key = args.val_data_key[i] if args.val_data_key else None

        log_data_list.append(evaluate(model, data, start_epoch, args_copy, writer))

    return log_data_list
Loading

0 comments on commit 73e7b37

Please sign in to comment.