revamped val (bootstrap ci to support seq and tok intervals, many perplexity evals) #142

Merged · 10 commits · Dec 11, 2023
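For orientation, the sequence- and token-level intervals added in this PR are percentile bootstrap confidence intervals over the collected per-sequence and per-token losses. The sketch below illustrates that procedure and is not part of the diff; the function name and the random inputs are made up.

import numpy as np
import torch


def percentile_bootstrap_ci(points, num_samples=10_000, interval=95):
    # Resample the loss values with replacement, record the mean of each
    # resample, and take percentiles of those means as the interval bounds.
    num_points = points.shape[0]
    estimates = []
    for _ in range(num_samples):
        idx = np.random.choice(num_points, size=num_points)  # sample with replacement
        estimates.append(points[idx].mean().item())
    half = (100 - interval) / 2
    return np.percentile(estimates, half), np.percentile(estimates, 100 - half)


# Hypothetical per-token losses; the real run uses the values gathered in evaluate().
token_losses = torch.rand(4096) * 4.0
lower, upper = percentile_bootstrap_ci(token_losses)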
20 changes: 15 additions & 5 deletions open_lm/data.py
@@ -12,6 +12,7 @@
from multiprocessing import Value
from functools import partial
from itertools import islice
import copy

import numpy as np
import pandas as pd
@@ -557,7 +558,7 @@ def get_synthetic_dataset(args, is_train, epoch, tokenizer, data_key, floor):
return DataInfo(dataloader, sampler)


def get_dataset_fn(data_path, dataset_type):
def get_dataset_fn(dataset_type):
if dataset_type == "synthetic":
return get_synthetic_dataset
else:
@@ -571,13 +572,22 @@ def get_data(args, epoch=0, tokenizer=None, skip_train=False, floor=True):
data["train"] = None
else:
if args.train_data or args.dataset_type == "synthetic":
data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
# train data is treated as a shard list where all data is combined and trained on
data["train"] = get_dataset_fn(args.dataset_type)(
args, is_train=True, epoch=epoch, tokenizer=tokenizer, data_key=args.data_key, floor=floor
)

if args.val_data:
data["val"] = get_dataset_fn(args.val_data, args.dataset_type)(
args, is_train=False, tokenizer=tokenizer, data_key=args.val_data_key
)
# val data is treated as independent val sets to be evaluated
data["val_list"] = []
for i, val_data in enumerate(args.val_data):
args_copy = copy.deepcopy(args)
args_copy.val_data = [val_data]
data_val = {
"val": get_dataset_fn(args.dataset_type)(
args_copy, is_train=False, tokenizer=tokenizer, data_key=args.val_data_key[i]
)
}
data["val_list"].append(data_val)

return data
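A rough sketch of how the new "val_list" entries are consumed downstream (it mirrors evaluate_loop further below and is not part of this diff; it assumes args came from parse_args with two --val-data sources):

from open_lm.data import get_data

data = get_data(args, epoch=0, tokenizer=None, skip_train=True)
for i, entry in enumerate(data["val_list"]):
    # one independent loader per --val-data entry
    dataloader = entry["val"].dataloader
    print(f"{args.val_data[i]}: {dataloader.num_batches} batches")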
44 changes: 16 additions & 28 deletions open_lm/main.py
@@ -50,7 +50,7 @@
from open_lm.logger import setup_logging
from open_lm.params import parse_args
from open_lm.scheduler import cosine_lr
from open_lm.train import train_one_epoch, evaluate
from open_lm.train import train_one_epoch, evaluate_loop
from open_lm.file_utils import (
pt_load,
check_exists,
@@ -171,22 +171,11 @@ def save_checkpoint(
optim_state = FSDP.optim_state_dict(model, optimizer)

if args.save_logs:
loss_dict = {
"evaluation_loss": -1,
"evaluation_loss_lower_95": -1,
"evaluation_loss_upper_95": -1,
}

if evaluation_metrics is not None:
loss_dict["evaluation_loss"] = evaluation_metrics["loss"]
loss_dict["evaluation_loss_lower_95"] = evaluation_metrics["loss_lower_95"]
loss_dict["evaluation_loss_upper_95"] = evaluation_metrics["loss_upper_95"]

checkpoint_dict_model = {
"epoch": completed_epoch,
"name": args.name,
"state_dict": cpu_state if args.fsdp else model.state_dict(),
**loss_dict,
"evaluation_metrics": evaluation_metrics,
}
if next_shard_per_source is not None:
checkpoint_dict_model["next_shard_per_source"] = next_shard_per_source
@@ -201,7 +190,7 @@
"epoch": completed_epoch,
"name": args.name,
"optimizer": optim_state if args.fsdp else optimizer.state_dict(),
**loss_dict,
"evaluation_metrics": evaluation_metrics,
}

if scaler is not None:
@@ -211,7 +200,7 @@
"epoch": completed_epoch,
"name": args.name,
"is_final_checkpoint": is_final_checkpoint,
**loss_dict,
"evaluation_metrics": evaluation_metrics,
}

prefixes = {
@@ -569,8 +558,7 @@ def main(args):

# initialize datasets
# use tokenizer=None because the data is already pre-tokenized.
if args.val_data is not None:
args.val_data = [args.val_data]

data = get_data(
args,
epoch=start_epoch,
@@ -627,8 +615,7 @@ def main(args):
if args.wandb and is_master(args):
assert wandb is not None, "Please install wandb."
logging.debug("Starting wandb.")
if args.val_data is not None:
args.val_sz = data["val"].dataloader.num_samples

wandb.init(
project=args.wandb_project_name,
name=args.name,
@@ -645,15 +632,11 @@ def main(args):
if not requires_training:
checkpoint_root = os.path.dirname(args.resume)

metrics = evaluate(model, data, start_epoch, args, writer)
metrics["checkpoint_path"] = args.resume
metrics["val_data"] = args.val_data
metrics["model"] = args.hf_model if args.hf_model else args.model
metrics = evaluate_loop(model, data["val_list"], start_epoch, args, writer)

if is_master(args):
with fsspec.open(os.path.join(checkpoint_root, "results.jsonl"), "a") as f:
f.write(json.dumps(metrics))
f.write("\n")
f.write(f"{json.dumps(metrics)}\n")

return

@@ -730,11 +713,16 @@ def main(args):
break

epoch = epoch + 1
evaluation_metrics = None
if "val" in data and (epoch % args.val_frequency == 0 or done_training):
evaluation_metrics = []
if "val_list" in data and (epoch % args.val_frequency == 0 or done_training):
# validate based on frequency and always validate the last checkpoint
try:
evaluation_metrics = evaluate(model, data, epoch, args, writer)
evaluation_metrics = evaluate_loop(model, data["val_list"], epoch, args, writer)

if is_master(args):
with fsspec.open(os.path.join(args.checkpoint_path, "results.jsonl"), "a") as f:
f.write(f"{json.dumps(evaluation_metrics)}\n")

except Exception as e:
if is_master(args):
logging.error(e)
33 changes: 12 additions & 21 deletions open_lm/params.py
@@ -6,14 +6,6 @@
import yaml


def get_default_params(model_name):
model_name = model_name.lower()
if "vit" in model_name:
return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6}
else:
return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8}


class ParseKwargs(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
kw = {}
@@ -153,8 +145,12 @@ def parse_args(args):
parser.add_argument(
"--val-data",
type=str,
nargs="+",
default=None,
help="Path to file(s) with validation data",
help=(
"Path to file(s) with validation data. Note: each space seperated entry will be processed seperately and writen as seperate entries "
"in a results.jsonl file."
),
)
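As a hedged illustration of the new interface (the shard paths are placeholders, and the remaining options are assumed to keep their defaults):

from open_lm.params import parse_args

# Two validation sources; each one is evaluated separately and written as its
# own entry in results.jsonl. --val-data-key supplies one extension per source.
args = parse_args([
    "--val-data", "val_a/shard-0000000.tar", "val_b/shard-0000000.tar",
    "--val-data-key", "txt", "txt",
])
assert len(args.val_data) == len(args.val_data_key) == 2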
parser.add_argument(
"--data-key",
@@ -227,10 +223,10 @@
help="When scheduler w/ cooldown used, perform cooldown from total_epochs - cooldown_epochs onwards.",
)
parser.add_argument("--optimizer", default="adamw", help="Optimizer.")
parser.add_argument("--lr", type=float, default=None, help="Learning rate.")
parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.")
parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.")
parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.")
parser.add_argument("--lr", type=float, default=5.0e-4, help="Learning rate.")
parser.add_argument("--beta1", type=float, default=0.9, help="Adam beta 1.")
parser.add_argument("--beta2", type=float, default=0.95, help="Adam beta 2.")
Collaborator: We're changing the default beta2 here compared to before, right?

Collaborator (author): Yeah, but a beta2 of 0.95 has been pretty standard for most open_lm runs so far.
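For context on the values under discussion, here is an illustrative sketch of how these defaults would typically be handed to the optimizer; it is not a quote of open_lm's optimizer setup.

import torch

model = torch.nn.Linear(16, 16)  # stand-in for the language model
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=5.0e-4,            # --lr default
    betas=(0.9, 0.95),    # --beta1 / --beta2; 0.95 is the new beta2 default
    eps=1.0e-8,           # --eps default
    weight_decay=0.2,     # --wd default
)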

parser.add_argument("--eps", type=float, default=1.0e-8, help="Adam epsilon.")
parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.")
parser.add_argument("--warmup", type=int, default=10000, help="Number of steps to warmup for.")
parser.add_argument(
@@ -309,8 +305,9 @@ def parse_args(args):
parser.add_argument(
"--val-data-key",
type=str,
default="txt",
help="what is the extension for val-data.",
nargs="+",
default=None,
help="what is the extension fore each val-data source.",
)
parser.add_argument(
"--resume",
@@ -573,12 +570,6 @@ def parse_args(args):
else:
args = parser.parse_args(args)

# If some params are not passed, we use the default values based on model name.
default_params = get_default_params(args.model)
for name, val in default_params.items():
if getattr(args, name) is None:
setattr(args, name, val)

if args.dataset_type == "synthetic":
assert args.train_data is None, "--train-data must not be specified if --dataset-type='synthetic'"
assert args.dataset_manifest is None, "--dataset-manifest must not be specified if --dataset-type='synthetic'"
103 changes: 73 additions & 30 deletions open_lm/train.py
@@ -1,18 +1,14 @@
import ast
import itertools
import json
import logging
import math
import os
import time
from contextlib import nullcontext
import copy

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.distributed_c10d import ReduceOp
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

try:
@@ -48,29 +44,23 @@ def __init__(self):
self.reset()

def reset(self):
self.weights = []
self.sums = []
self.points = []

def update(self, val, n=1):
self.weights.append(n)
self.sums.append(val * n)
def update(self, val):
self.points.append(val)

def compute_bootstrap_ci(self, num_samples=2000, interval=95):
def compute_bootstrap_ci(self, num_samples=10_000, interval=95):
lower = None
upper = None

points_tensor = torch.cat(self.points)
num_points = points_tensor.shape[0]

estimates = []
for _ in range(num_samples):
acc = 0
denom = 0
i = np.random.choice(len(self.sums), size=len(self.sums)).tolist()
for ii in i:
acc += self.sums[ii]
denom += self.weights[ii]

acc /= denom

estimates.append(acc)
i = np.random.choice(num_points, size=num_points)
estimate = torch.sum(points_tensor[i]) / num_points
estimates.append(estimate.item())

half = (100 - interval) / 2

@@ -341,16 +331,23 @@ def evaluate(model, data, start_epoch, args, writer):
data["val"].set_epoch(start_epoch) # set epoch in process safe manner via sampler or shared_epoch
dataloader = data["val"].dataloader

# NOTE: max_num_batches = 0 corresponds to exhausting the iterator
max_num_batches = dataloader.num_batches

losses_m = AverageMeter()
batch_time_m = AverageMeter()
data_time_m = AverageMeter()
sps_m = AverageMeter()
spspg_m = AverageMeter()
losses_ci_m = ConfidenceIntervalMeter()
losses_seq_ci_m = ConfidenceIntervalMeter()
losses_tok_ci_m = ConfidenceIntervalMeter()

end = time.time()
loss = torch.nn.CrossEntropyLoss()
loss = torch.nn.CrossEntropyLoss(reduction="none")
for i, batch in enumerate(dataloader):
if i == max_num_batches and max_num_batches != 0:
break

(texts,) = batch
texts = torch.LongTensor(texts).to(device)

@@ -359,27 +356,53 @@
with autocast():
inputs, targets = sample_chunk(texts, args)

out, _ = model(inputs)
total_loss = loss(out.reshape(-1, args.vocab_size), targets.reshape(-1))
out, _ = model(inputs) # [bs, seq_len, vocab_size]

bs, seq_len = targets.shape

targets = targets.reshape(-1)
total_loss = loss(out.reshape(-1, args.vocab_size), targets) # [bs * seq_len]

# cross entropy ignores -100 values in loss computation
mask = targets != -100

# reshape and average for sequence losses
sum_loss_per_seq = torch.sum(total_loss.reshape(bs, seq_len), -1)
num_toks_per_seq = torch.sum(mask.reshape(bs, seq_len), -1).float()
losses_seq_ci_m.update(sum_loss_per_seq / num_toks_per_seq)

# individual token losses
losses_tok_ci_m.update(total_loss[mask])

# compute average loss for the mini-batch
total_loss = total_loss[mask].mean()
losses_m.update(total_loss.item(), n=inputs.shape[0])
losses_ci_m.update(total_loss.item(), n=inputs.shape[0])

batch_time_m.update(time.time() - end)
sps_m.update(inputs.numel() * args.world_size / batch_time_m.val)
spspg_m.update(inputs.numel() / batch_time_m.val)

lower, upper = losses_ci_m.compute_bootstrap_ci()
lower_seq, upper_seq = losses_seq_ci_m.compute_bootstrap_ci()
lower_tok, upper_tok = losses_tok_ci_m.compute_bootstrap_ci()
num_seqs = torch.cat(losses_seq_ci_m.points).shape[0]
num_toks = torch.cat(losses_tok_ci_m.points).shape[0]

# Save eval loss / etc.
log_data = {
"loss": losses_m.avg,
"data_time": data_time_m.avg,
"batch_time": batch_time_m.avg,
"samples_per_second": sps_m.avg,
"samples_per_second_per_gpu": spspg_m.avg,
"loss_lower_95": lower,
"loss_upper_95": upper,
"loss_sequences_lower_95": lower_seq,
"loss_sequences_upper_95": upper_seq,
"loss_tokens_lower_95": lower_tok,
Collaborator: Should we also save the loss tokens mean?

Collaborator (author): Given that we never see partial sequences, I think this will just be loss.

"loss_tokens_upper_95": upper_tok,
"sequences": num_seqs,
"tokens": num_toks,
}
if args.train_num_samples is not None:
log_data["tokens"] = start_epoch * args.train_num_samples * args.seq_len
log_data["train_tokens"] = start_epoch * args.train_num_samples * args.seq_len

for name, val in log_data.items():
name = "valid/" + name
@@ -389,6 +412,26 @@
assert wandb is not None, "Please install wandb."
wandb.log({name: val, "epoch": start_epoch, "tokens": log_data["tokens"]})
if is_master(args):
print(f"evaluation on: {args.val_data}")
print(f"evaluation loss: {losses_m.avg}")
print(f"evaluation perplexity: {math.exp(losses_m.avg)}")
print(f"num seqs: {num_seqs}")
print(f"num tokens: {num_toks}")

log_data["checkpoint_path"] = args.resume
log_data["val_data"] = args.val_data
log_data["model"] = args.hf_model if args.hf_model else args.model

return log_data
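On the review question above about also saving the mean token loss: a quick sanity check of the claim that, when there are no partial sequences (no -100 targets), the token-level mean coincides with the reported loss. The shapes and values below are made up.

import torch

torch.manual_seed(0)
bs, seq_len, vocab = 4, 8, 32
logits = torch.randn(bs, seq_len, vocab)
targets = torch.randint(0, vocab, (bs, seq_len))  # no -100 entries, i.e. no partial sequences

loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
per_token = loss_fn(logits.reshape(-1, vocab), targets.reshape(-1))  # [bs * seq_len]

# The mean over all tokens equals the mean of per-sequence means when every
# sequence contributes the same number of unmasked tokens.
token_mean = per_token.mean()
seq_mean = per_token.reshape(bs, seq_len).mean(dim=-1).mean()
assert torch.allclose(token_mean, seq_mean)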


def evaluate_loop(model, data_list, start_epoch, args, writer):
log_data_list = []
for i, data in enumerate(data_list):
args_copy = copy.deepcopy(args)
args_copy.val_data = [args.val_data[i]]
args_copy.val_data_key = args.val_data_key[i]

log_data_list.append(evaluate(model, data, start_epoch, args_copy, writer))

return log_data_list
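Since the script reports perplexity as exp(loss), the entries written to results.jsonl can be turned into perplexities (and perplexity intervals) after the fact. A small sketch, assuming the file was produced by the loop above (each line holds the list returned by evaluate_loop):

import json
import math

with open("results.jsonl") as f:
    for line in f:
        for entry in json.loads(line):  # one dict per evaluated --val-data source
            ppl = math.exp(entry["loss"])
            ppl_lo = math.exp(entry["loss_lower_95"])
            ppl_hi = math.exp(entry["loss_upper_95"])
            print(entry["val_data"], f"ppl={ppl:.2f} [{ppl_lo:.2f}, {ppl_hi:.2f}]")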