Simplify eval and make it usable per-query (#4)
* Simplify evaluation by removing unused metrics
* Include changes to run per-query eval
vyaivo authored Oct 10, 2024
1 parent 98e7463 commit 2574c28
Showing 7 changed files with 111 additions and 331 deletions.
file_utils.py (2 changes: 1 addition & 1 deletion)
@@ -11,7 +11,7 @@ def save_json(data, filename, logger=None):
def load_json(filename, logger=None, sort_by_id = False):
if logger:
logger.info(f'loading from {filename}')
assert filename.endswith("json") or filename.endswith("json.score"), "file provided to load_json does not end with .json extension. Please recheck!"
assert filename.endswith("json") or filename.endswith(".score"), "file provided to load_json does not end with .json extension. Please recheck!"
data = json.load(open(filename))
if sort_by_id:
for d in data:
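For reference, the relaxed check now accepts both plain .json outputs and .score files (presumably the per-query score files this commit introduces). A minimal standalone sketch of the extension check, with purely illustrative filenames:

# Illustrative filenames only; the check mirrors the assertion in load_json above.
for name in ["reader_output.json", "reader_output.json.score"]:
    accepted = name.endswith("json") or name.endswith(".score")
    print(f"{name}: {'accepted' if accepted else 'rejected'}")  # both are accepted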
reader/eval.py (175 changes: 12 additions & 163 deletions)
@@ -12,10 +12,8 @@
from tqdm import tqdm

import string
from evaluate import load

from transformers import logging as t_log
t_log.set_verbosity_error() # suppress bert warning

from transformers import (
AutoModelForSeq2SeqLM,
@@ -34,8 +32,6 @@
- string exact match (str_em), following ALCE
- string hit (str_hit), following ALCE
- rougeLsum, following ALCE
- QA based accuracy, following ALCE
- mauve score, following ALCE
QAMPARI
- precision, following ALCE
@@ -44,12 +40,10 @@
- f1, following ALCE
- f1 top 5, following ALCE
NQ, BIOASQ
NQ
- substring match, following RAGGED
- f1, following RAGGED
- rougel f1/precision/recall
- bertscore f1/precision/recall, following RAGGED
"""


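The ASQA metrics listed in the docstring above reduce, per item, to checking whether each gold short answer appears in the normalized generation. A minimal sketch of ALCE-style string exact match and string hit (normalization here is simplified, and field names follow the docstring rather than the full implementation):

import string

def _normalize(text):
    # simplified ALCE-style normalization: lowercase, drop punctuation, squash whitespace
    text = text.lower().translate(str.maketrans("", "", string.punctuation))
    return " ".join(text.split())

def str_em_and_hit(item):
    # item needs `qa_pairs` (each with `short_answers`) and `generated_output`
    output = _normalize(item["generated_output"])
    covered = [any(_normalize(ans) in output for ans in qa["short_answers"])
               for qa in item["qa_pairs"]]
    str_em = sum(covered) / len(covered)  # fraction of sub-questions answered
    str_hit = float(all(covered))         # 1 only if every short answer appears
    return str_em, str_hit

compute_str_em then averages these per-item scores and, after this change, also reports their standard deviation (the str_em_mean/str_em_std and str_hit_mean/str_hit_std fields used in main below).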
@@ -64,7 +58,6 @@



QA_MODEL="gaotianyu1350/roberta-large-squad"
AUTOAIS_MODEL="google/t5_xxl_true_nli_mixture"

global autoais_model, autoais_tokenizer
@@ -240,91 +233,6 @@ def compute_len(data):
return res / cntr


def compute_qa(data):
"""Compute QA-based accuracy.
Args:
data: requires filed `qa_pairs/short_answers` and `generated_output`
Returns:
QA metrics (QA-EM, QA-F1, QA-Hit)
"""

if 'qa_pairs' not in data[0] or data[0]['qa_pairs'] is None:
logger.warn("Warning: no QA pairs found in data")
return {
'QA-EM': 0,
'QA-F1': 0,
'QA-Hit': 0,
}

# Load model
logger.info("Loading the RoBERTa-large SQuAD model for QA-based accuracy...")
qa_pipeline = pipeline("question-answering", model=QA_MODEL, device=0)
logger.info("Done")

# Get prediction
logger.info("Computing the QA-based accuracy...")
em, f1, bins = [], [], []
for item in tqdm(data):
question = [qa_pair['question'] for qa_pair in item['qa_pairs']]
context = item['generated_output'] if len(item['generated_output']) > 0 else " "
results = qa_pipeline(question=question, context=context, handle_impossible_answer=True)
loc_counter, loc_em, loc_f1 = 0, 0, 0

for idx, res in enumerate(results):
answers = item["qa_pairs"][idx]["short_answers"]
prediction = res["answer"]

loc_em += max([compute_exact(a, prediction) for a in answers])
loc_f1 += max([compute_f1(a, prediction) for a in answers])
loc_counter += 1

em.append(loc_em / loc_counter)
f1.append(loc_f1 / loc_counter)
bins.append(loc_em == loc_counter)

return {
'QA-EM': 100 * np.mean(em),
'QA-F1': 100 * np.mean(f1),
'QA-Hit': 100 * np.mean(bins)
}


def compute_mauve(data):
"""
Compute Mauve: "a measure of the gap between neural text and human text. It is computed using the Kullback–Leibler (KL) divergences between
the two distributions of text in a quantized embedding space of a large language model. MAUVE can identify differences in generation fluency."
https://arxiv.org/abs/2102.01454
Args:
data: requires 'question', 'answer' and 'generated_output' fields
Returns:
mauve score between human data and generated output
"""

logger.info("Computing MAUVE...")
human_data = []
model_data = []
for item in data:
# Remove ending punctuations
# Remove any new lines
# Truncate by 100 words
human_data.append(' '.join((item['question'] + " " + item['answer'].strip()).split()[:100]).rstrip(string.punctuation))
model_data.append(' '.join((item['question'] + " " + item['generated_output'].strip()).split()[:100]).rstrip(string.punctuation))

import mauve
out = mauve.compute_mauve(
p_text=human_data,
q_text=model_data,
device_id=0,
max_text_length=512,
verbose=True,
batch_size=8,
featurize_model_name="gpt2-large"
)
return out.mauve * 100


def _run_nli_autoais(passage, claim):
"""
Run inference for assessing AIS between a premise and hypothesis.
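The body of _run_nli_autoais is collapsed in this view. The usual TRUE-NLI pattern (as in ALCE) formats a premise/hypothesis pair for the t5_xxl_true_nli_mixture model and treats a decoded "1" as entailment; a rough sketch under that assumption, reusing the module-level globals declared above:

import torch

def _run_nli_autoais_sketch(passage, claim):
    # assumes autoais_model / autoais_tokenizer were loaded elsewhere, e.g.
    #   autoais_model = AutoModelForSeq2SeqLM.from_pretrained(AUTOAIS_MODEL, device_map="auto")
    #   autoais_tokenizer = AutoTokenizer.from_pretrained(AUTOAIS_MODEL)
    input_text = f"premise: {passage} hypothesis: {claim}"
    input_ids = autoais_tokenizer(input_text, return_tensors="pt").input_ids.to(autoais_model.device)
    with torch.inference_mode():
        outputs = autoais_model.generate(input_ids, max_new_tokens=10)
    result = autoais_tokenizer.decode(outputs[0], skip_special_tokens=True)
    return 1 if result == "1" else 0  # 1 = the passage entails the claim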
@@ -526,15 +434,14 @@ def _format_document(doc):
}


def compute_qampari_f1(data, cot=False):
def compute_qampari_f1(data):
"""
Compute qampari-specific f1: splits generation by comma and calculates precision and recall based on this and list of gold entities,
returns average over inputs
Args:
data: requires field `generated_output` and `answers`
- answers: comma separated list of entities
cot: whether answers were generated with cot prompting
"""

prec = []
@@ -545,13 +452,7 @@ def compute_qampari_f1(data, cot=False):

num_preds = []
for item in data:
if cot:
if ":" in item['generated_output']:
o = ':'.join(item['generated_output'].split(":")[1:]) # try to separate the COT part and the answer list part.
else:
o = ""
else:
o = item['generated_output']
o = item['generated_output']

# remove leading/trailing space, period or comma -> split by comma and normalize
preds = [normalize_answer(x.strip()) for x in o.rstrip().rstrip(".").rstrip(",").split(",")]
@@ -599,7 +500,7 @@ def compute_qampari_f1(data, cot=False):
"qampari_f1_top5_std": np.std(f1_top5)
}
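On a toy example, the QAMPARI scoring above works roughly as follows: the generation is split on commas and normalized, precision is the fraction of predicted entities found among the gold entities (including aliases), and recall is the fraction of gold entities covered (recall_top5 caps the denominator at 5). Numbers below are illustrative only:

# Toy illustration of the comma-split precision/recall; not the full implementation.
preds = ["paris", "berlin", "rome"]                  # normalized comma-split generation
answers = [{"paris"}, {"madrid"}, {"rome", "roma"}]  # each gold entity with its aliases
flat_gold = set().union(*answers)

precision = sum(p in flat_gold for p in preds) / len(preds)                        # 2/3
recall = sum(any(a in preds for a in alias) for alias in answers) / len(answers)   # 2/3
recall_top5 = sum(any(a in preds for a in alias) for alias in answers) / min(5, len(answers))
f1 = 0.0 if precision + recall == 0 else 2 * precision * recall / (precision + recall)  # 2/3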

def compute_ragged_metrics(normalized_data, NO_BERT, MERGE_LIST):
def compute_ragged_metrics(normalized_data):
"""
Original eval for NQ and BIOASQ from RAGGED paper
"""
@@ -619,19 +520,7 @@ def _metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
# get max entry by f1 score
return max(scores_for_ground_truths, key=lambda x:x['rougeLsum_f1'])
else:
return max(scores_for_ground_truths, key=lambda x:x["bertscore_f1"])

def _bertscore(prediction, ground_truth):
prediction = convert_textual_numbers_to_numeric(prediction)
ground_truth = [convert_textual_numbers_to_numeric(ans) for ans in ground_truth if ans]

bertscore = load("bertscore")
results = bertscore.compute(predictions=[normalize_answer(prediction)], references=[normalize_answer(ground_truth)], lang="en")
return {
"bertscore_precision" : results["precision"][0],
"bertscore_recall" : results["recall"][0],
"bertscore_f1" : results["f1"][0]
}
raise ValueError("Cannot take max over the given scores")

def _rougel_score(prediction, ground_truth):
# no normalization
@@ -651,7 +540,7 @@ def _rougel_score(prediction, ground_truth):
}


def kilt_eval(guess_answer, gold_candidate_answers, NO_BERT):
def kilt_eval(guess_answer, gold_candidate_answers):

# returns True if ANY of the gold candidates are in generated answer
substring_match = exact_presence(gold_candidate_answers, guess_answer)
@@ -667,29 +556,16 @@ def kilt_eval(guess_answer, gold_candidate_answers, NO_BERT):
_rougel_score, guess_answer, gold_candidate_answers
)

local_bertscore = None
if not NO_BERT:
local_bertscore = _metric_max_over_ground_truths(
_bertscore, guess_answer, gold_candidate_answers
)

return substring_match, local_f1, local_rougel, local_bertscore
return substring_match, local_f1, local_rougel

total_count = 0
normalized_substring_match = 0
normalized_f1 = 0
rougel_f1 = 0
rougel_p = 0
rougel_r = 0
if not NO_BERT:
bertscore_f1=0
bertscore_p=0
bertscore_r=0

if NO_BERT:
logger.info("Running kilt evaluation without BERT... ")
else:
logger.info("Running kilt evaluation with BERT... ")
logger.info("Running kilt evaluation...")

for reader_output_info in tqdm(normalized_data):
total_count+=1
@@ -698,15 +574,10 @@ def kilt_eval(guess_answer, gold_candidate_answers, NO_BERT):
gold_data = reader_output_info["output"]
gold_answer_list = [x for x in gold_data["answer_set"]] # considering only the short answers

# merge gold answer list if bioasq specifies answer_type==list
if MERGE_LIST and "question_type" in gold_data.keys() and gold_data["question_type"] == "list":
gold_answer_list = [" ".join(gold_answer_list)]

substring_match, local_f1, \
local_rougel, local_bertscore = kilt_eval(
local_rougel = kilt_eval(
guess_answer,
gold_answer_list,
NO_BERT
)

normalized_substring_match += substring_match
@@ -715,23 +586,13 @@ def kilt_eval(guess_answer, gold_candidate_answers, NO_BERT):
rougel_p += local_rougel["rougeLsum_p"]
rougel_r += local_rougel["rougeLsum_r"]

if not NO_BERT:
bertscore_f1 += local_bertscore["bertscore_f1"]
bertscore_p += local_bertscore["bertscore_precision"]
bertscore_r += local_bertscore["bertscore_recall"]

if total_count > 0:
normalized_substring_match /= total_count
normalized_f1 /= total_count
rougel_f1 /= total_count
rougel_p /= total_count
rougel_r /= total_count

if not NO_BERT:
bertscore_f1 /= total_count
bertscore_p /= total_count
bertscore_r /= total_count

method_metrics = {
"ragged_substring_match":round(normalized_substring_match, 4),
"ragged_f1": round(normalized_f1, 4),
@@ -740,10 +601,6 @@ def kilt_eval(guess_answer, gold_candidate_answers, NO_BERT):
"ragged_rougel_r": round(rougel_r, 4)

}
if not NO_BERT:
method_metrics["bertscore_f1"] = round(bertscore_f1, 4)
method_metrics["bertscore_p"] = round(bertscore_p, 4)
method_metrics["bertscore_r"] = round(bertscore_r, 4)

logger.info(f"total questions - dev: {total_count}/{len(gold_data)}")
logger.info("Reader metrics : ", method_metrics)
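For reference, the three per-answer scores that remain in the RAGGED-style evaluation (substring match, token-level F1, and rougeLsum precision/recall/F1) can be approximated as below. This sketch calls the rouge_score package directly and simplifies normalization, so it is an illustration of the logic above rather than a drop-in replacement:

from collections import Counter
from rouge_score import rouge_scorer

def token_f1(prediction, ground_truth):
    # token-level F1 over whitespace tokens (answer normalization omitted for brevity)
    pred_tokens, gold_tokens = prediction.split(), ground_truth.split()
    num_same = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)

def ragged_scores(prediction, gold_answers):
    # take the best score over all acceptable gold answers, as in _metric_max_over_ground_truths
    substring_match = any(ans.lower() in prediction.lower() for ans in gold_answers)
    f1 = max(token_f1(prediction, ans) for ans in gold_answers)
    scorer = rouge_scorer.RougeScorer(["rougeLsum"], use_stemmer=True)
    rouge = max((scorer.score(ans, prediction)["rougeLsum"] for ans in gold_answers),
                key=lambda s: s.fmeasure)
    return {
        "substring_match": substring_match,
        "f1": f1,
        "rougeLsum_f1": rouge.fmeasure,
        "rougeLsum_p": rouge.precision,
        "rougeLsum_r": rouge.recall,
    }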
@@ -783,15 +640,13 @@ def main(args):
result['str_em_mean'], result['str_em_std'], \
result['str_hit_mean'], result['str_hit_std'] = compute_str_em(normalized_data)
result['rougeLsum'] = compute_rouge(normalized_data)
result.update(compute_qa(normalized_data)) # QA based accuracy with RoBERTa-large SQuAD
result['mauve'] = compute_mauve(normalized_data)

elif 'qampari' in args.f:
result.update(compute_qampari_f1(normalized_data, cot=args.cot))
result.update(compute_qampari_f1(normalized_data))
qampari = True

elif 'nq' in args.f or 'bioasq' in args.f:
result.update(compute_ragged_metrics(normalized_data, args.no_bert, args.merge_list_answers))
elif 'nq' in args.f:
result.update(compute_ragged_metrics(normalized_data))

if args.citations:
result.update(compute_autoais(
@@ -811,8 +666,6 @@
parser.add_argument("--f", type=str, required=True, help="Name of reader output file to evaluate in $RESULTS_PATH/reader. Should have field `question`, `generated_output`, (ROUGE) `answer`, \
(accuracy) `qa_pairs`, (AIS) `docs`")

parser.add_argument("--no_bert", action="store_true", help="Add tag to not run bert during nq and bioasq eval (it can be time consuming)") # nq and bioasq
parser.add_argument("--merge_list_answers", action="store_true", help="for bioasq, merge short answers when answer_type==list")

parser.add_argument("--citations", action="store_true", help="Evaluation with citation")
parser.add_argument("--at_most_citations", type=int, default=3, help="At most take this many documents (mostly for precision)")
@@ -825,9 +678,5 @@
parser.add_argument("--noise_file", type=str, default=None, help="File from which noisy documents were added")



# QAMPARI
parser.add_argument("--cot", action="store_true", help="For QAMPARI, try to find colon and separate the COT and answer listing")

args = parser.parse_args()
main(args)