-
Notifications
You must be signed in to change notification settings - Fork 70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
revamped val (bootstrap ci to support seq and tok intervals, many perplexity evals) #142
Changes from 9 commits
26e3d39
0b81088
676c1b6
0903ce9
dedd7a3
58330f9
35a2523
59dafb6
946cf0c
1e704e9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,14 +6,6 @@ | |
import yaml | ||
|
||
|
||
def get_default_params(model_name):
    """Return the default Adam hyperparameters for the given model name.

    The lookup is case-insensitive. Names containing "vit" get a lower
    beta2 and a larger epsilon; every other name gets the generic values.
    The learning rate and beta1 are the same in both cases.
    """
    is_vit = "vit" in model_name.lower()
    return {
        "lr": 5.0e-4,
        "beta1": 0.9,
        "beta2": 0.98 if is_vit else 0.999,
        "eps": 1.0e-6 if is_vit else 1.0e-8,
    }
|
||
|
||
class ParseKwargs(argparse.Action): | ||
def __call__(self, parser, namespace, values, option_string=None): | ||
kw = {} | ||
|
@@ -153,8 +145,12 @@ def parse_args(args): | |
parser.add_argument( | ||
"--val-data", | ||
type=str, | ||
nargs="+", | ||
default=None, | ||
help="Path to file(s) with validation data", | ||
help=( | ||
"Path to file(s) with validation data. Note: each space seperated entry will be processed seperately and writen as seperate entries " | ||
"in a results.jsonl file." | ||
), | ||
) | ||
parser.add_argument( | ||
"--data-key", | ||
|
@@ -227,10 +223,10 @@ def parse_args(args): | |
help="When scheduler w/ cooldown used, perform cooldown from total_epochs - cooldown_epochs onwards.", | ||
) | ||
parser.add_argument("--optimizer", default="adamw", help="Optimizer.") | ||
parser.add_argument("--lr", type=float, default=None, help="Learning rate.") | ||
parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.") | ||
parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.") | ||
parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.") | ||
parser.add_argument("--lr", type=float, default=5.0e-4, help="Learning rate.") | ||
parser.add_argument("--beta1", type=float, default=0.9, help="Adam beta 1.") | ||
parser.add_argument("--beta2", type=float, default=0.95, help="Adam beta 2.") | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We're changing the default beta2 here compared to before, right? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, but a beta2 of 0.95 has been pretty standard for most open_lm runs so far. |
||
parser.add_argument("--eps", type=float, default=1.0e-8, help="Adam epsilon.") | ||
parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.") | ||
parser.add_argument("--warmup", type=int, default=10000, help="Number of steps to warmup for.") | ||
parser.add_argument( | ||
|
@@ -309,8 +305,9 @@ def parse_args(args): | |
parser.add_argument( | ||
"--val-data-key", | ||
type=str, | ||
default="txt", | ||
help="what is the extension for val-data.", | ||
nargs="+", | ||
default=None, | ||
help="what is the extension fore each val-data source.", | ||
) | ||
parser.add_argument( | ||
"--resume", | ||
|
@@ -573,12 +570,6 @@ def parse_args(args): | |
else: | ||
args = parser.parse_args(args) | ||
|
||
# If some params are not passed, we use the default values based on model name. | ||
default_params = get_default_params(args.model) | ||
for name, val in default_params.items(): | ||
if getattr(args, name) is None: | ||
setattr(args, name, val) | ||
|
||
if args.dataset_type == "synthetic": | ||
assert args.train_data is None, "--train-data must not be specified if --dataset-type='synthetic'" | ||
assert args.dataset_manifest is None, "--dataset-manifest must not be specified if --dataset-type='synthetic'" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,14 @@ | ||
import ast | ||
import itertools | ||
import json | ||
import logging | ||
import math | ||
import os | ||
import time | ||
from contextlib import nullcontext | ||
import copy | ||
|
||
import numpy as np | ||
import torch | ||
import torch.distributed as dist | ||
import torch.nn.functional as F | ||
from torch.distributed.distributed_c10d import ReduceOp | ||
from torch.nn.parallel.distributed import DistributedDataParallel | ||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP | ||
|
||
try: | ||
|
@@ -48,29 +44,23 @@ def __init__(self): | |
self.reset() | ||
|
||
def reset(self): | ||
self.weights = [] | ||
self.sums = [] | ||
self.points = [] | ||
|
||
def update(self, val, n=1): | ||
self.weights.append(n) | ||
self.sums.append(val * n) | ||
def update(self, val): | ||
self.points.append(val) | ||
|
||
def compute_bootstrap_ci(self, num_samples=2000, interval=95): | ||
def compute_bootstrap_ci(self, num_samples=10_000, interval=95): | ||
lower = None | ||
upper = None | ||
|
||
points_tensor = torch.cat(self.points) | ||
num_points = self.points.shape[0] | ||
|
||
estimates = [] | ||
for _ in range(num_samples): | ||
acc = 0 | ||
denom = 0 | ||
i = np.random.choice(len(self.sums), size=len(self.sums)).tolist() | ||
for ii in i: | ||
acc += self.sums[ii] | ||
denom += self.weights[ii] | ||
|
||
acc /= denom | ||
|
||
estimates.append(acc) | ||
i = np.random.choice(num_points, size=num_points) | ||
estimate = torch.sum(points_tensor[i]) / num_points | ||
estimates.append(estimate.item()) | ||
|
||
half = (100 - interval) / 2 | ||
|
||
|
@@ -341,16 +331,23 @@ def evaluate(model, data, start_epoch, args, writer): | |
data["val"].set_epoch(start_epoch) # set epoch in process safe manner via sampler or shared_epoch | ||
dataloader = data["val"].dataloader | ||
|
||
# NOTE: max_num_batches = 0 corresponds to exhausting iterator | ||
max_num_batches = dataloader.num_batches | ||
|
||
losses_m = AverageMeter() | ||
batch_time_m = AverageMeter() | ||
data_time_m = AverageMeter() | ||
sps_m = AverageMeter() | ||
spspg_m = AverageMeter() | ||
losses_ci_m = ConfidenceIntervalMeter() | ||
losses_seq_ci_m = ConfidenceIntervalMeter() | ||
losses_tok_ci_m = ConfidenceIntervalMeter() | ||
|
||
end = time.time() | ||
loss = torch.nn.CrossEntropyLoss() | ||
loss = torch.nn.CrossEntropyLoss(reduction="none") | ||
for i, batch in enumerate(dataloader): | ||
if i == max_num_batches and max_num_batches != 0: | ||
break | ||
|
||
(texts,) = batch | ||
texts = torch.LongTensor(texts).to(device) | ||
|
||
|
@@ -359,27 +356,53 @@ def evaluate(model, data, start_epoch, args, writer): | |
with autocast(): | ||
inputs, targets = sample_chunk(texts, args) | ||
|
||
out, _ = model(inputs) | ||
total_loss = loss(out.reshape(-1, args.vocab_size), targets.reshape(-1)) | ||
out, _ = model(inputs) # [bs, seq_len, vocab_size] | ||
|
||
bs, seq_len = targets.shape | ||
|
||
targets = targets.reshape(-1) | ||
total_loss = loss(out.reshape(-1, args.vocab_size), targets) # [bs * seq_len] | ||
|
||
# cross entropy ignores -100 values in loss computation | ||
mask = targets != -100 | ||
|
||
# reshape and average for sequence losses | ||
sum_loss_per_seq = torch.sum(total_loss.reshape(bs, seq_len), -1) | ||
num_toks_per_seq = torch.sum(mask.reshape(bs, seq_len), -1).float() | ||
losses_seq_ci_m.update(sum_loss_per_seq / num_toks_per_seq) | ||
|
||
# individual token losses | ||
losses_tok_ci_m.update(total_loss[mask]) | ||
|
||
# compute average loss for the mini-batch | ||
total_loss = total_loss[mask].mean() | ||
losses_m.update(total_loss.item(), n=inputs.shape[0]) | ||
losses_ci_m.update(total_loss.item(), n=inputs.shape[0]) | ||
|
||
batch_time_m.update(time.time() - end) | ||
sps_m.update(inputs.numel() * args.world_size / batch_time_m.val) | ||
spspg_m.update(inputs.numel() / batch_time_m.val) | ||
|
||
lower, upper = losses_ci_m.compute_bootstrap_ci() | ||
lower_seq, upper_seq = losses_seq_ci_m.compute_bootstrap_ci() | ||
lower_tok, upper_tok = losses_tok_ci_m.compute_bootstrap_ci() | ||
num_seqs = losses_seq_ci_m.points.shape[0] | ||
num_toks = losses_tok_ci_m.points.shape[0] | ||
|
||
# Save eval loss / etc. | ||
log_data = { | ||
"loss": losses_m.avg, | ||
"data_time": data_time_m.avg, | ||
"batch_time": batch_time_m.avg, | ||
"samples_per_second": sps_m.avg, | ||
"samples_per_second_per_gpu": spspg_m.avg, | ||
"loss_lower_95": lower, | ||
"loss_upper_95": upper, | ||
"loss_sequences_lower_95": lower_seq, | ||
"loss_sequences_upper_95": upper_seq, | ||
"loss_tokens_lower_95": lower_tok, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should we also save the mean of the token losses? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Given that we never see partial sequences, I think this will just be the same as loss. |
||
"loss_tokens_upper_95": upper_tok, | ||
"sequences": num_seqs, | ||
"tokens": num_toks, | ||
} | ||
if args.train_num_samples is not None: | ||
log_data["tokens"] = start_epoch * args.train_num_samples * args.seq_len | ||
log_data["train_tokens"] = start_epoch * args.train_num_samples * args.seq_len | ||
|
||
for name, val in log_data.items(): | ||
name = "valid/" + name | ||
|
@@ -389,6 +412,26 @@ def evaluate(model, data, start_epoch, args, writer): | |
assert wandb is not None, "Please install wandb." | ||
wandb.log({name: val, "epoch": start_epoch, "tokens": log_data["tokens"]}) | ||
if is_master(args): | ||
print(f"evaluation on: {args.val_data}") | ||
print(f"evaluation loss: {losses_m.avg}") | ||
print(f"evaluation perplexity: {math.exp(losses_m.avg)}") | ||
print(f"num seqs: {num_seqs}") | ||
print(f"num tokens: {num_toks}") | ||
|
||
log_data["checkpoint_path"] = args.resume | ||
log_data["val_data"] = args.val_data | ||
log_data["model"] = args.hf_model if args.hf_model else args.model | ||
|
||
return log_data | ||
|
||
|
||
def evaluate_loop(model, data_list, start_epoch, args, writer):
    """Run `evaluate` once per validation source and collect the results.

    Each entry of `data_list` is evaluated with its own deep copy of `args`,
    so the per-source overrides below never leak into the caller's `args` or
    into the next iteration.

    Args:
        model: model forwarded unchanged to `evaluate`.
        data_list: one data object per validation source, in the same order
            as `args.val_data`.
        start_epoch: epoch value forwarded unchanged to `evaluate`.
        args: parsed arguments; `args.val_data` must have an entry for every
            element of `data_list`. `args.val_data_key` may be None.
        writer: writer forwarded unchanged to `evaluate`.

    Returns:
        A list with the value returned by `evaluate` for each source, in
        `data_list` order.
    """
    log_data_list = []
    for i, data in enumerate(data_list):
        args_copy = copy.deepcopy(args)
        # Keep val_data a one-element list so it has the same type as the
        # multi-source argument it was taken from.
        args_copy.val_data = [args.val_data[i]]

        # --val-data-key defaults to None; indexing it unconditionally would
        # raise TypeError whenever the flag is omitted, so only narrow it
        # when keys were actually provided.
        if args.val_data_key is not None:
            args_copy.val_data_key = args.val_data_key[i]

        log_data_list.append(evaluate(model, data, start_epoch, args_copy, writer))

    return log_data_list
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be good to do this as a single write: since we are writing over fsspec, a network error could leave us with a corrupt jsonl file (for example, a record without its trailing newline).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That is, do the dumps + \n in one write so that it's idempotent and keeps the jsonl consistent. In general, append mode is scary.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sg