diff --git a/official/benchmark/bert_squad_benchmark.py b/official/benchmark/bert_squad_benchmark.py
index 23bb8123dbc..c950d4703c4 100644
--- a/official/benchmark/bert_squad_benchmark.py
+++ b/official/benchmark/bert_squad_benchmark.py
@@ -24,12 +24,12 @@
 
 # pylint: disable=g-bad-import-order
 from absl import flags
+from absl import logging
 from absl.testing import flagsaver
 import tensorflow as tf
 # pylint: enable=g-bad-import-order
 
 from official.benchmark import bert_benchmark_utils as benchmark_utils
-from official.benchmark import squad_evaluate_v1_1
 from official.nlp.bert import run_squad
 from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
@@ -70,18 +70,6 @@ def _read_input_meta_data_from_file(self):
     with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
       return json.loads(reader.read().decode('utf-8'))
 
-  def _read_predictions_dataset_from_file(self):
-    """Reads the predictions dataset from a file."""
-    with tf.io.gfile.GFile(SQUAD_PREDICT_FILE, 'r') as reader:
-      dataset_json = json.load(reader)
-    return dataset_json['data']
-
-  def _read_predictions_from_file(self):
-    """Reads the predictions from a file."""
-    predictions_file = os.path.join(FLAGS.model_dir, 'predictions.json')
-    with tf.io.gfile.GFile(predictions_file, 'r') as reader:
-      return json.load(reader)
-
   def _get_distribution_strategy(self, ds_type='mirrored'):
     """Gets the distribution strategy.
 
@@ -135,12 +123,10 @@ def _evaluate_squad(self, ds_type='mirrored'):
     input_meta_data = self._read_input_meta_data_from_file()
     strategy = self._get_distribution_strategy(ds_type)
 
-    run_squad.predict_squad(strategy=strategy, input_meta_data=input_meta_data)
-
-    dataset = self._read_predictions_dataset_from_file()
-    predictions = self._read_predictions_from_file()
-
-    eval_metrics = squad_evaluate_v1_1.evaluate(dataset, predictions)
+    if input_meta_data.get('version_2_with_negative', False):
+      logging.error('In-memory evaluation result for SQuAD v2 is not accurate')
+    eval_metrics = run_squad.eval_squad(strategy=strategy,
+                                        input_meta_data=input_meta_data)
 
     # Use F1 score as reported evaluation metric.
     self.eval_metrics = eval_metrics['f1']
diff --git a/official/nlp/bert/run_squad.py b/official/nlp/bert/run_squad.py
index f39b73162d1..d88aa85f371 100644
--- a/official/nlp/bert/run_squad.py
+++ b/official/nlp/bert/run_squad.py
@@ -20,8 +20,11 @@
 
 import json
+import os
+import tempfile
 
 from absl import app
 from absl import flags
+from absl import logging
 import tensorflow as tf
 
 from official.nlp.bert import configs as bert_configs
@@ -52,12 +55,22 @@ def train_squad(strategy,
 
 
 def predict_squad(strategy, input_meta_data):
-  """Makes predictions for a squad dataset."""
+  """Makes predictions for the squad dataset."""
   bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
   tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
-  run_squad_helper.predict_squad(strategy, input_meta_data, tokenizer,
-                                 bert_config, squad_lib_wp)
+  run_squad_helper.predict_squad(
+      strategy, input_meta_data, tokenizer, bert_config, squad_lib_wp)
+
+
+def eval_squad(strategy, input_meta_data):
+  """Evaluates on the squad dataset."""
+  bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
+  tokenizer = tokenization.FullTokenizer(
+      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
+  eval_metrics = run_squad_helper.eval_squad(
+      strategy, input_meta_data, tokenizer, bert_config, squad_lib_wp)
+  return eval_metrics
 
 
 def export_squad(model_export_path, input_meta_data):
@@ -93,7 +106,8 @@ def main(_):
       num_gpus=FLAGS.num_gpus,
       all_reduce_alg=FLAGS.all_reduce_alg,
       tpu_address=FLAGS.tpu)
-  if FLAGS.mode in ('train', 'train_and_predict'):
+
+  if 'train' in FLAGS.mode:
     if FLAGS.log_steps:
       custom_callbacks = [keras_utils.TimeHistory(
           batch_size=FLAGS.train_batch_size,
@@ -109,8 +123,27 @@ def main(_):
         custom_callbacks=custom_callbacks,
         run_eagerly=FLAGS.run_eagerly,
     )
-  if FLAGS.mode in ('predict', 'train_and_predict'):
+  if 'predict' in FLAGS.mode:
     predict_squad(strategy, input_meta_data)
+  if 'eval' in FLAGS.mode:
+    if input_meta_data.get('version_2_with_negative', False):
+      logging.error('SQuAD v2 eval is not supported. '
+                    'Falling back to predict mode.')
+      predict_squad(strategy, input_meta_data)
+    else:
+      eval_metrics = eval_squad(strategy, input_meta_data)
+      f1_score = eval_metrics['f1']
+      logging.info('SQuAD eval F1-score: %f', f1_score)
+      if (not strategy) or strategy.extended.should_save_summary:
+        summary_dir = os.path.join(FLAGS.model_dir, 'summaries')
+      else:
+        summary_dir = tempfile.mkdtemp()
+      summary_writer = tf.summary.create_file_writer(
+          os.path.join(summary_dir, 'eval'))
+      with summary_writer.as_default():
+        # TODO(lehou): write to the correct step number.
+        tf.summary.scalar('F1-score', f1_score, step=0)
+        summary_writer.flush()
 
 
 if __name__ == '__main__':
diff --git a/official/nlp/bert/run_squad_helper.py b/official/nlp/bert/run_squad_helper.py
index 4fc62d44ead..51f3731ef46 100644
--- a/official/nlp/bert/run_squad_helper.py
+++ b/official/nlp/bert/run_squad_helper.py
@@ -18,6 +18,7 @@
 from __future__ import print_function
 
 import collections
+import json
 import os
 from absl import flags
 from absl import logging
@@ -30,6 +31,7 @@
 from official.nlp.bert import common_flags
 from official.nlp.bert import input_pipeline
 from official.nlp.bert import model_saving_utils
+from official.nlp.bert import squad_evaluate_v1_1
 from official.nlp.data import squad_lib_sp
 from official.utils.misc import keras_utils
 
@@ -37,11 +39,15 @@
 def define_common_squad_flags():
   """Defines common flags used by SQuAD tasks."""
   flags.DEFINE_enum(
-      'mode', 'train_and_predict',
-      ['train_and_predict', 'train', 'predict', 'export_only'],
-      'One of {"train_and_predict", "train", "predict", "export_only"}. '
-      '`train_and_predict`: both train and predict to a json file. '
+      'mode', 'train_and_eval',
+      ['train_and_eval', 'train_and_predict',
+       'train', 'eval', 'predict', 'export_only'],
+      'One of {"train_and_eval", "train_and_predict", '
+      '"train", "eval", "predict", "export_only"}. '
+      '`train_and_eval`: train & predict to json files & compute eval metrics. '
+      '`train_and_predict`: train & predict to json files. '
       '`train`: only trains the model. '
+      '`eval`: predict answers from squad json file & compute eval metrics. '
       '`predict`: predict answers from the squad json file. '
       '`export_only`: will take the latest checkpoint inside '
       'model_dir and export a `SavedModel`.')
@@ -271,7 +277,8 @@ def clip_by_global_norm_callback(grads_and_vars):
       post_allreduce_callbacks=[clip_by_global_norm_callback])
 
 
-def predict_squad(strategy, input_meta_data, tokenizer, bert_config, squad_lib):
+def prediction_output_squad(
+    strategy, input_meta_data, tokenizer, bert_config, squad_lib):
   """Makes predictions for a squad dataset."""
   doc_stride = input_meta_data['doc_stride']
   max_query_length = input_meta_data['max_query_length']
@@ -322,23 +329,61 @@ def _append_feature(feature, is_padding):
 
   all_results = predict_squad_customized(strategy, input_meta_data,
                                          bert_config, eval_writer.filename,
                                          num_steps)
 
+  all_predictions, all_nbest_json, scores_diff_json = (
+      squad_lib.postprocess_output(
+          eval_examples,
+          eval_features,
+          all_results,
+          FLAGS.n_best_size,
+          FLAGS.max_answer_length,
+          FLAGS.do_lower_case,
+          version_2_with_negative=version_2_with_negative,
+          null_score_diff_threshold=FLAGS.null_score_diff_threshold,
+          verbose=FLAGS.verbose_logging))
+
+  return all_predictions, all_nbest_json, scores_diff_json
+
+
+def dump_to_files(all_predictions, all_nbest_json, scores_diff_json,
+                  squad_lib, version_2_with_negative):
+  """Saves output to json files."""
   output_prediction_file = os.path.join(FLAGS.model_dir, 'predictions.json')
   output_nbest_file = os.path.join(FLAGS.model_dir, 'nbest_predictions.json')
   output_null_log_odds_file = os.path.join(FLAGS.model_dir, 'null_odds.json')
+  logging.info('Writing predictions to: %s', (output_prediction_file))
+  logging.info('Writing nbest to: %s', (output_nbest_file))
 
-  squad_lib.write_predictions(
-      eval_examples,
-      eval_features,
-      all_results,
-      FLAGS.n_best_size,
-      FLAGS.max_answer_length,
-      FLAGS.do_lower_case,
-      output_prediction_file,
-      output_nbest_file,
-      output_null_log_odds_file,
-      version_2_with_negative=version_2_with_negative,
-      null_score_diff_threshold=FLAGS.null_score_diff_threshold,
-      verbose=FLAGS.verbose_logging)
+  squad_lib.write_to_json_files(all_predictions, output_prediction_file)
+  squad_lib.write_to_json_files(all_nbest_json, output_nbest_file)
+  if version_2_with_negative:
+    squad_lib.write_to_json_files(scores_diff_json, output_null_log_odds_file)
+
+
+def predict_squad(strategy, input_meta_data, tokenizer, bert_config, squad_lib):
+  """Gets prediction results and dumps them to the hard drive."""
+  all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
+      strategy, input_meta_data, tokenizer, bert_config, squad_lib)
+  dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
+                input_meta_data.get('version_2_with_negative', False))
+
+
+def eval_squad(strategy, input_meta_data, tokenizer, bert_config, squad_lib):
+  """Gets prediction results and evaluates them against the ground truth."""
+  all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
+      strategy, input_meta_data, tokenizer, bert_config, squad_lib)
+  dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
+                input_meta_data.get('version_2_with_negative', False))
+
+  if input_meta_data.get('version_2_with_negative', False):
+    # TODO(lehou): support in memory evaluation for SQuAD v2.
+    logging.error('SQuAD v2 eval is not supported. Skipping eval.')
+    return None
+  else:
+    with tf.io.gfile.GFile(FLAGS.predict_file, 'r') as reader:
+      dataset_json = json.load(reader)
+    pred_dataset = dataset_json['data']
+    eval_metrics = squad_evaluate_v1_1.evaluate(pred_dataset, all_predictions)
+    return eval_metrics
 
 
 def export_squad(model_export_path, input_meta_data, bert_config):
diff --git a/official/nlp/bert/squad_evaluate_v1_1.py b/official/nlp/bert/squad_evaluate_v1_1.py
new file mode 100644
index 00000000000..2495d285aea
--- /dev/null
+++ b/official/nlp/bert/squad_evaluate_v1_1.py
@@ -0,0 +1,108 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Evaluation of SQuAD predictions (version 1.1).
+
+The functions are copied from
+https://worksheets.codalab.org/rest/bundles/0xbcd57bee090b421c982906709c8c27e1/contents/blob/.
+
+The SQuAD dataset is described in this paper:
+SQuAD: 100,000+ Questions for Machine Comprehension of Text
+Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, Percy Liang
+https://nlp.stanford.edu/pubs/rajpurkar2016squad.pdf
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import re
+import string
+
+# pylint: disable=g-bad-import-order
+from absl import logging
+# pylint: enable=g-bad-import-order
+
+
+def _normalize_answer(s):
+  """Lowers text and removes punctuation, articles and extra whitespace."""
+
+  def remove_articles(text):
+    return re.sub(r"\b(a|an|the)\b", " ", text)
+
+  def white_space_fix(text):
+    return " ".join(text.split())
+
+  def remove_punc(text):
+    exclude = set(string.punctuation)
+    return "".join(ch for ch in text if ch not in exclude)
+
+  def lower(text):
+    return text.lower()
+
+  return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+def _f1_score(prediction, ground_truth):
+  """Computes F1 score by comparing prediction to ground truth."""
+  prediction_tokens = _normalize_answer(prediction).split()
+  ground_truth_tokens = _normalize_answer(ground_truth).split()
+  prediction_counter = collections.Counter(prediction_tokens)
+  ground_truth_counter = collections.Counter(ground_truth_tokens)
+  common = prediction_counter & ground_truth_counter
+  num_same = sum(common.values())
+  if num_same == 0:
+    return 0
+  precision = 1.0 * num_same / len(prediction_tokens)
+  recall = 1.0 * num_same / len(ground_truth_tokens)
+  f1 = (2 * precision * recall) / (precision + recall)
+  return f1
+
+
+def _exact_match_score(prediction, ground_truth):
+  """Checks if predicted answer exactly matches ground truth answer."""
+  return _normalize_answer(prediction) == _normalize_answer(ground_truth)
+
+
+def _metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
+  """Computes the max over all metric scores."""
+  scores_for_ground_truths = []
+  for ground_truth in ground_truths:
+    score = metric_fn(prediction, ground_truth)
+    scores_for_ground_truths.append(score)
+  return max(scores_for_ground_truths)
+
+
+def evaluate(dataset, predictions):
+  """Evaluates predictions for a dataset."""
+  f1 = exact_match = total = 0
+  for article in dataset:
+    for paragraph in article["paragraphs"]:
+      for qa in paragraph["qas"]:
+        total += 1
+        if qa["id"] not in predictions:
+          message = ("Unanswered question " + qa["id"] +
+                     " will receive score 0.")
+          logging.error(message)
+          continue
+        ground_truths = [entry["text"] for entry in qa["answers"]]
+        prediction = predictions[qa["id"]]
+        exact_match += _metric_max_over_ground_truths(_exact_match_score,
+                                                      prediction, ground_truths)
+        f1 += _metric_max_over_ground_truths(_f1_score, prediction,
+                                             ground_truths)
+
+  exact_match = exact_match / total
+  f1 = f1 / total
+
+  return {"exact_match": exact_match, "f1": f1}
diff --git a/official/nlp/data/squad_lib.py b/official/nlp/data/squad_lib.py
index 605087b0751..54f75a660b7 100644
--- a/official/nlp/data/squad_lib.py
+++ b/official/nlp/data/squad_lib.py
@@ -506,6 +506,34 @@ def write_predictions(all_examples,
   logging.info("Writing predictions to: %s", (output_prediction_file))
   logging.info("Writing nbest to: %s", (output_nbest_file))
 
+  all_predictions, all_nbest_json, scores_diff_json = (
+      postprocess_output(all_examples=all_examples,
+                         all_features=all_features,
+                         all_results=all_results,
+                         n_best_size=n_best_size,
+                         max_answer_length=max_answer_length,
+                         do_lower_case=do_lower_case,
+                         version_2_with_negative=version_2_with_negative,
+                         null_score_diff_threshold=null_score_diff_threshold,
+                         verbose=verbose))
+
+  write_to_json_files(all_predictions, output_prediction_file)
+  write_to_json_files(all_nbest_json, output_nbest_file)
+  if version_2_with_negative:
+    write_to_json_files(scores_diff_json, output_null_log_odds_file)
+
+
+def postprocess_output(all_examples,
+                       all_features,
+                       all_results,
+                       n_best_size,
+                       max_answer_length,
+                       do_lower_case,
+                       version_2_with_negative=False,
+                       null_score_diff_threshold=0.0,
+                       verbose=False):
+  """Postprocesses model output to form prediction results."""
+
   example_index_to_features = collections.defaultdict(list)
   for feature in all_features:
     example_index_to_features[feature.example_index].append(feature)
@@ -676,15 +704,12 @@ def write_predictions(all_examples,
 
     all_nbest_json[example.qas_id] = nbest_json
 
-  with tf.io.gfile.GFile(output_prediction_file, "w") as writer:
-    writer.write(json.dumps(all_predictions, indent=4) + "\n")
+  return all_predictions, all_nbest_json, scores_diff_json
 
-  with tf.io.gfile.GFile(output_nbest_file, "w") as writer:
-    writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
-
-  if version_2_with_negative:
-    with tf.io.gfile.GFile(output_null_log_odds_file, "w") as writer:
-      writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
+
+def write_to_json_files(json_records, json_file):
+  with tf.io.gfile.GFile(json_file, "w") as writer:
+    writer.write(json.dumps(json_records, indent=4) + "\n")
 
 
 def get_final_text(pred_text, orig_text, do_lower_case, verbose=False):
diff --git a/official/nlp/data/squad_lib_sp.py b/official/nlp/data/squad_lib_sp.py
index f64d85f0aa2..6caa6e1876c 100644
--- a/official/nlp/data/squad_lib_sp.py
+++ b/official/nlp/data/squad_lib_sp.py
@@ -575,10 +575,39 @@ def write_predictions(all_examples,
                       null_score_diff_threshold=0.0,
                       verbose=False):
   """Write final predictions to the json file and log-odds of null if needed."""
-  del do_lower_case, verbose
   logging.info("Writing predictions to: %s", (output_prediction_file))
   logging.info("Writing nbest to: %s", (output_nbest_file))
 
+  all_predictions, all_nbest_json, scores_diff_json = (
+      postprocess_output(all_examples=all_examples,
+                         all_features=all_features,
+                         all_results=all_results,
+                         n_best_size=n_best_size,
+                         max_answer_length=max_answer_length,
+                         do_lower_case=do_lower_case,
+                         version_2_with_negative=version_2_with_negative,
+                         null_score_diff_threshold=null_score_diff_threshold,
+                         verbose=verbose))
+
+  write_to_json_files(all_predictions, output_prediction_file)
+  write_to_json_files(all_nbest_json, output_nbest_file)
+  if version_2_with_negative:
+    write_to_json_files(scores_diff_json, output_null_log_odds_file)
+
+
+def postprocess_output(all_examples,
+                       all_features,
+                       all_results,
+                       n_best_size,
+                       max_answer_length,
+                       do_lower_case,
+                       version_2_with_negative=False,
+                       null_score_diff_threshold=0.0,
+                       verbose=False):
+  """Postprocesses model output to form prediction results."""
+
+  del do_lower_case, verbose
+
   example_index_to_features = collections.defaultdict(list)
   for feature in all_features:
     example_index_to_features[feature.example_index].append(feature)
@@ -740,15 +769,12 @@ def write_predictions(all_examples,
 
     all_nbest_json[example.qas_id] = nbest_json
 
-  with tf.io.gfile.GFile(output_prediction_file, "w") as writer:
-    writer.write(json.dumps(all_predictions, indent=4) + "\n")
+  return all_predictions, all_nbest_json, scores_diff_json
 
-  with tf.io.gfile.GFile(output_nbest_file, "w") as writer:
-    writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
-
-  if version_2_with_negative:
-    with tf.io.gfile.GFile(output_null_log_odds_file, "w") as writer:
-      writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
+
+def write_to_json_files(json_records, json_file):
+  with tf.io.gfile.GFile(json_file, "w") as writer:
+    writer.write(json.dumps(json_records, indent=4) + "\n")
 
 
 def _get_best_indexes(logits, n_best_size):
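
Reviewer note: a minimal smoke test for the new `official/nlp/bert/squad_evaluate_v1_1.py` module, runnable on its own once the `official` package is importable. The toy records below are invented for illustration; they only mirror the `dataset_json['data']` layout that `evaluate` walks (article, paragraphs, qas, answers), and the expected numbers follow directly from the token-overlap F1 in `_f1_score`.

```python
# Smoke test sketch; all data below is invented for illustration.
from official.nlp.bert import squad_evaluate_v1_1

dataset = [{  # mirrors dataset_json['data'] of a SQuAD v1.1 predict file
    "paragraphs": [{
        "qas": [
            {"id": "q1", "answers": [{"text": "Denver Broncos"}]},
            {"id": "q2", "answers": [{"text": "Santa Clara, California"}]},
        ],
    }],
}]
predictions = {
    "q1": "Denver Broncos",  # exact match after normalization: EM = 1, F1 = 1
    "q2": "Santa Clara",     # 2 of 3 truth tokens: P = 1, R = 2/3, F1 = 0.8
}

metrics = squad_evaluate_v1_1.evaluate(dataset, predictions)
# Note this copy returns fractions rather than the percentages of the
# original SQuAD script: {'exact_match': 0.5, 'f1': 0.9} for the inputs above.
print(metrics)
```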
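And a sketch of the call pattern the `write_predictions` split enables: `postprocess_output` returns predictions in memory, and writing to disk becomes an explicit, separate step via `write_to_json_files` (or `dump_to_files` in `run_squad_helper`). This is what lets `eval_squad` avoid the round-trip through `predictions.json` that `bert_squad_benchmark.py` previously relied on. The helper below is hypothetical: its first three arguments stand in for the objects produced by the existing feature-conversion and predict pipeline, and the `n_best_size`/`max_answer_length` values are illustrative, not necessarily the flags' defaults.

```python
# Hypothetical helper, not part of the patch; shows the in-memory flow the
# postprocess_output/write_to_json_files split makes possible.
from official.nlp.bert import squad_evaluate_v1_1
from official.nlp.data import squad_lib


def postprocess_and_eval(eval_examples, eval_features, all_results,
                         pred_dataset):
  """Computes SQuAD v1.1 metrics without writing predictions.json first."""
  all_predictions, all_nbest_json, scores_diff_json = (
      squad_lib.postprocess_output(
          all_examples=eval_examples,
          all_features=eval_features,
          all_results=all_results,
          n_best_size=20,        # illustrative value
          max_answer_length=30,  # illustrative value
          do_lower_case=True,
          version_2_with_negative=False))
  # The n-best and null-odds records are only needed when dumping to disk.
  del all_nbest_json, scores_diff_json
  return squad_evaluate_v1_1.evaluate(pred_dataset, all_predictions)
```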