Merge branch 'perplexity'
# Conflicts:
#	lm_eval/models/gpt2.py
#	lm_eval/models/gpt3.py
#	tests/test_misc.py
#	tests/test_models.py
leogao2 committed May 11, 2021
2 parents eb8456d + 5452fdd commit c77b60c
Showing 13 changed files with 684 additions and 20 deletions.
108 changes: 105 additions & 3 deletions lm_eval/base.py
@@ -1,8 +1,9 @@
import abc
import random
import numpy as np
import re

from lm_eval.metrics import mean
from lm_eval.metrics import mean, perplexity, weighted_mean


class LM(abc.ABC):
@@ -27,9 +28,51 @@ def loglikelihood(self, requests):
:return: list
A list of pairs (logprob, isgreedy)
logprob: float
The log probability of `contination`
The log probability of `continuation`
isgreedy:
Whether `contination` would be generated by greedy sampling from `context`
Whether `continuation` would be generated by greedy sampling from `context`
"""
pass

@abc.abstractmethod
def loglikelihood_rolling(self, requests):
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
- We will use the full max context length of the model.
- For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
the max context length.
- IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
which may simply concatenate multiple documents together.
- IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
multiple chunks, the last input will still have a full-sized context.
Example:
Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
Prefix: EOT
Max context length: 4
Resulting input/prediction pairs:
INPUT: EOT 0 1 2
PRED: 0 1 2 3
INPUT: 3 4 5 6
PRED: 4 5 6 7
INPUT: 5 6 7 8
PRED: 8 9
Observe that:
1. Each token is predicted exactly once
2. For the last pair, we provide the full context, but only score the last two tokens
:param requests: list
A list of strings
string: str
String for which we are computing per-token loglikelihood
:return: list
A list of floats
logprob: float
The total log probability of the string, summed over all rolling windows
"""
pass
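For readers checking the example above, here is a minimal standalone sketch of the windowing scheme the docstring describes. The repository's actual helper is `utils.get_rolling_token_windows` (called in lm_eval/models/gpt2.py and gpt3.py below); the function here is a reconstruction for illustration only, and its details may differ.

def rolling_token_windows(tokens, prefix_token, max_seq_len, context_len=1):
    # First window: predict every token it covers, using only the prefix token as extra context.
    pred_len = max_seq_len - context_len + 1
    first = min(max_seq_len, len(tokens))
    yield [prefix_token] + tokens[:first - 1], tokens[:first]
    predicted = first
    # Remaining windows: keep a full-sized input, but only score tokens not yet predicted.
    while predicted < len(tokens):
        n_pred = min(len(tokens) - predicted, pred_len)
        end = predicted + n_pred
        yield tokens[end - max_seq_len - 1:end - 1], tokens[end - n_pred:end]
        predicted += n_pred

EOT = -1  # stand-in for the end-of-text token id
print(list(rolling_token_windows(list(range(10)), EOT, max_seq_len=4)))
# [([-1, 0, 1, 2], [0, 1, 2, 3]), ([3, 4, 5, 6], [4, 5, 6, 7]), ([5, 6, 7, 8], [8, 9])]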

@@ -247,9 +290,68 @@ def aggregation(self):
}


class PerplexityTask(Task, abc.ABC):

def has_training_docs(self):
return False

def fewshot_description(self):
return ""

def fewshot_examples(self, k, rnd):
assert k == 0
return []

def fewshot_context(self, doc, num_fewshot, provide_description, rnd):
assert num_fewshot == 0
assert not provide_description
return ""

def higher_is_better(self):
return {
"word_perplexity": False,
"byte_perplexity": False,
"bits_per_byte": False,
}

def doc_to_text(self, doc):
return doc

def doc_to_target(self, doc):
raise NotImplementedError()

def construct_requests(self, doc, ctx):
assert not ctx
req = rf.loglikelihood_rolling(doc)
return req

def process_results(self, doc, results):
loglikelihood, = results
return {
"word_perplexity": loglikelihood / self.count_words(self.doc_to_text(doc)),
"byte_perplexity": loglikelihood / self.count_bytes(self.doc_to_text(doc)),
"bits_per_byte": (-loglikelihood, self.count_bytes(self.doc_to_text(doc)))
}

def aggregation(self):
return {
"word_perplexity": perplexity,
"byte_perplexity": perplexity,
"bits_per_byte": weighted_mean
}

def count_bytes(self, s):
return len(s.encode("utf-8"))

def count_words(self, s):
""" Downstream tasks with custom word boundaries should override this! """
return len(re.split(r"\s+", s))


req_ret_lens = {
'loglikelihood': 2,
'greedy_until': None,
'loglikelihood_rolling': None,
}

import os
1 change: 1 addition & 0 deletions lm_eval/evaluator.py
@@ -64,6 +64,7 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
# only in index. We could implement some kind of caching, but that would be more of a bandaid
# solution. we could also implement some kind of autogrouping here; they should end up next to each other.

print("Running", reqtype, "requests")
resps = getattr(lm, reqtype)([req.args for req in reqs])

resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)]
5 changes: 5 additions & 0 deletions lm_eval/metrics.py
@@ -94,6 +94,11 @@ def perplexity(items):
return math.exp(-mean(items))


def weighted_mean(items):
a, b = zip(*items)
return sum(a) / sum(b)
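A quick sanity check of this aggregator: the numerators and weights are summed separately before dividing, so

assert weighted_mean([(1.0, 2), (3.0, 6)]) == (1.0 + 3.0) / (2 + 6)  # == 0.5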


def bleu(items):
"""The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
for evaluating a generated sentence to a reference sentence. It counts matching
8 changes: 8 additions & 0 deletions lm_eval/models/dummy.py
@@ -26,3 +26,11 @@ def greedy_until(self, requests):
assert ctx.strip() != ''

return res

def loglikelihood_rolling(self, requests):
res = []

for _ in requests:
res.append(-random.random())

return res
81 changes: 69 additions & 12 deletions lm_eval/models/gpt2.py
Expand Up @@ -5,10 +5,13 @@
from lm_eval.base import LM
from lm_eval import utils
from tqdm import tqdm
import numpy as np


class GPT2LM(LM):
MAX_GEN_TOKS = 256
VOCAB_SIZE = 50257
EOT_TOKEN_ID = 50256

def __init__(self, device='cuda', pretrained='gpt2', batch_size=1):
super().__init__()
@@ -51,7 +54,7 @@ def loglikelihood(self, requests):
for context, continuation in requests:
if context == "":
# end of text as context
context_enc = [50256]
context_enc = [self.EOT_TOKEN_ID]
else:
context_enc = self.tokenizer.encode(context)

@@ -61,7 +64,36 @@

return self._loglikelihood_tokens(new_reqs)

def _loglikelihood_tokens(self, requests):
def loglikelihood_rolling(self, requests):
# TODO: Implement caching once we've confirmed the perplexity implementation
# TODO: automatic batch size detection for vectorization

loglikelihoods = []
with torch.no_grad():
for string, in tqdm(requests):
encoded = self.tokenizer.encode_plus(string)["input_ids"]

rolling_token_windows = list(map(utils.make_disjoint_window, utils.get_rolling_token_windows(
token_list=encoded,
prefix_token=self.EOT_TOKEN_ID,
max_seq_len=self.max_length,
context_len=1,
)))

rolling_token_windows = [(None,) + x for x in rolling_token_windows]

# TODO: extract out this call so it only gets called once and also somehow figure out partial caching for that
string_nll = self._loglikelihood_tokens(rolling_token_windows, disable_tqdm=True)

# discard is_greedy
string_nll = [x[0] for x in string_nll]

string_nll = sum(string_nll)
loglikelihoods.append(string_nll)

return loglikelihoods
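To make the request shape concrete: each rolling window is made disjoint and prefixed with a None cache key, producing the (cache_key, context_enc, continuation_enc) triples that _loglikelihood_tokens expects (the None appears to skip partial caching, consistent with the TODO above). The helper below is a reconstruction of utils.make_disjoint_window for illustration only and may differ in detail from the repository's version.

def make_disjoint_window(pair):
    # Assumed behavior: trim the overlap so context and continuation share no tokens.
    context, continuation = pair
    return context[:len(context) - (len(continuation) - 1)], continuation

windows = [([-1, 0, 1, 2], [0, 1, 2, 3]), ([3, 4, 5, 6], [4, 5, 6, 7]), ([5, 6, 7, 8], [8, 9])]
print([(None,) + make_disjoint_window(w) for w in windows])
# [(None, [-1], [0, 1, 2, 3]), (None, [3], [4, 5, 6, 7]), (None, [5, 6, 7], [8, 9])]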

def _loglikelihood_tokens(self, requests, disable_tqdm=False):
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
res = []
with torch.no_grad():
@@ -78,18 +110,38 @@ def _collate(x):

# TODO: automatic (variable) batch size detection for vectorization
reord = utils.Reorderer(requests, _collate)
for chunk in utils.chunks(tqdm(reord.get_reordered()), self.batch_size):
for chunk in utils.chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size):
inps = []
contlens = []
inplens = []
ctxlens = []

padding_length = None

# because vectorizing is annoying, we first convert each (context, continuation) pair to padded
# tensors, then we pack them together into a batch, call the model, and then pick it all apart
# again because vectorizing is annoying

for _, context_enc, continuation_enc in chunk:
# sanity check
assert len(context_enc) > 0
assert len(continuation_enc) > 0
assert len(continuation_enc) <= self.max_length

# how this all works:
# CTX CONT
# inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
# gpt2 \ \
# logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the [:, -len(continuation_enc):, :self.VOCAB_SIZE] slice
# cont_toks 4 5 6 7 8 9

# when too long to fit in context, truncate from the left
inp = torch.tensor((context_enc + continuation_enc)[-self.max_length:], dtype=torch.long).to(self.device)
ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.max_length)
inp = torch.tensor(
(context_enc + continuation_enc)[-(self.max_length+1):][:-1]
, dtype=torch.long).to(self.device)
inplen, = inp.shape

cont = continuation_enc

# since in _collate we make sure length is descending, the longest is always the first one.
padding_length = padding_length if padding_length is not None else inplen

@@ -100,19 +152,24 @@ def _collate(x):
], dim=0)

inps.append(inp.unsqueeze(0))
contlens.append(cont)
inplens.append(inplen)
ctxlens.append(ctxlen)

multi_logits = F.log_softmax(self.gpt2(torch.cat(inps, dim=0))[0][:, :, :50257], dim=-1) # [batch, seq, vocab]
multi_logits = F.log_softmax(self.gpt2(torch.cat(inps, dim=0))[0][:, :, :50257], dim=-1).cpu() # [batch, seq, vocab]

for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(chunk, multi_logits, inps, inplens, contlens):
contlen = len(cont_toks)

for (cache_key, _, _), logits, ctxlen, inp, inplen in zip(chunk, multi_logits, ctxlens, inps, inplens):
logits = logits[ctxlen - 1:inplen - 1].unsqueeze(0) # [1, seq, vocab]
logits = logits[inplen-contlen:inplen].unsqueeze(0) # [1, seq, vocab]

greedy_tokens = logits.argmax(dim=-1)
cont_toks = inp[:, ctxlen:inplen] # [1, seq]

# cont_toks :: [1, seq]
cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0)

max_equal = (greedy_tokens == cont_toks).all()

last_token_slice = logits[:, -1, :].squeeze(0).tolist()
#last_token_slice = logits[:, -1, :].squeeze(0).tolist()

logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [1, seq]

51 changes: 49 additions & 2 deletions lm_eval/models/gpt3.py
@@ -1,4 +1,5 @@
import os
import numpy as np
import transformers
from lm_eval.base import LM
from lm_eval import utils
@@ -58,6 +59,7 @@ def __init__(self, engine, truncate=False):
self.tokenizer.pad_token = "<|endoftext|>"
assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]
self.truncate = truncate
self.end_of_text_token_id = self.tokenizer.convert_tokens_to_ids(["<|endoftext|>"])[0]

# Read from environment variable OPENAI_API_SECRET_KEY
openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]
@@ -83,6 +85,30 @@ def loglikelihood(self, requests):

return self._loglikelihood_tokens(new_reqs)

def loglikelihood_rolling(self, requests):
# TODO: switch implementation to use _loglikelihood_tokens rather than having it do its own thing

loglikelihoods = []
for string, in tqdm(requests):
encoded = self.tokenizer.encode_plus(string)["input_ids"]
rolling_token_windows = utils.get_rolling_token_windows(
token_list=encoded,
prefix_token=self.end_of_text_token_id,
max_seq_len=self.MAX_LENGTH,
context_len=1,
)
string_loglikelihoods = []
for input_tokens, pred_tokens in rolling_token_windows:
block_output = self.get_token_logprobs(
input_tokens=input_tokens,
pred_tokens=pred_tokens,
)
string_loglikelihoods.append(block_output["logprobs"])
string_loglikelihoods = np.concatenate(string_loglikelihoods).sum()
loglikelihoods.append(string_loglikelihoods)

return loglikelihoods

def _loglikelihood_tokens(self, requests):
import openai
res = []
@@ -95,7 +121,7 @@ def _collate(x):
return (-len(toks), tuple(toks))

reord = utils.Reorderer(requests, _collate)

for chunk in tqdm(list(utils.chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))):
inps = []
ctxlens = []
@@ -122,9 +148,30 @@ def _collate(x):
# partial caching
if cache_key is not None:
self.cache_hook.add_partial("loglikelihood", cache_key, answer)

return reord.get_original(res)

def get_token_logprobs(self, input_tokens, pred_tokens):
pred_start = len(input_tokens) - len(pred_tokens) + 1
# We're going to stitch together the input_tokens and pred_tokens
# In the longest case, this gets us to length = max_seq_len+1 (which the API works with)
assert input_tokens[pred_start:] == pred_tokens[:-1]
token_ids = input_tokens + [pred_tokens[-1]]
response = oa_completion(
engine=self.engine,
prompt=token_ids,
max_tokens=0,
temperature=0.0,
logprobs=0,
echo=True,
)
logprobs = np.array(response["choices"][0]["logprobs"]["token_logprobs"][pred_start:])
positions = np.arange(pred_start-1, pred_start-1 + len(token_ids[pred_start:]))
return {
"logprobs": logprobs,
"positions": positions,
}
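The index arithmetic above can be checked without any API call, using the first window of the running example (EOT prefix, hypothetical token ids):

input_tokens = [50256, 0, 1, 2]                        # EOT prefix + context
pred_tokens = [0, 1, 2, 3]
pred_start = len(input_tokens) - len(pred_tokens) + 1  # == 1
assert input_tokens[pred_start:] == pred_tokens[:-1]   # the stitched tokens line up
token_ids = input_tokens + [pred_tokens[-1]]           # [50256, 0, 1, 2, 3], max_seq_len + 1 tokens
# token_logprobs[pred_start:] from the echoed completion then scores exactly tokens 0, 1, 2, 3,
# one logprob per predicted token.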

def greedy_until(self, requests):
if not requests: return []
import openai