Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 302552996
  • Loading branch information
tensorflower-gardener committed Mar 24, 2020
1 parent 13d44a0 commit 1ec383c
Show file tree
Hide file tree
Showing 6 changed files with 280 additions and 57 deletions.
24 changes: 5 additions & 19 deletions official/benchmark/bert_squad_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@

# pylint: disable=g-bad-import-order
from absl import flags
from absl import logging
from absl.testing import flagsaver
import tensorflow as tf
# pylint: enable=g-bad-import-order

from official.benchmark import bert_benchmark_utils as benchmark_utils
from official.benchmark import squad_evaluate_v1_1
from official.nlp.bert import run_squad
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
Expand Down Expand Up @@ -70,18 +70,6 @@ def _read_input_meta_data_from_file(self):
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
return json.loads(reader.read().decode('utf-8'))

def _read_predictions_dataset_from_file(self):
"""Reads the predictions dataset from a file."""
with tf.io.gfile.GFile(SQUAD_PREDICT_FILE, 'r') as reader:
dataset_json = json.load(reader)
return dataset_json['data']

def _read_predictions_from_file(self):
"""Reads the predictions from a file."""
predictions_file = os.path.join(FLAGS.model_dir, 'predictions.json')
with tf.io.gfile.GFile(predictions_file, 'r') as reader:
return json.load(reader)

def _get_distribution_strategy(self, ds_type='mirrored'):
"""Gets the distribution strategy.
Expand Down Expand Up @@ -135,12 +123,10 @@ def _evaluate_squad(self, ds_type='mirrored'):
input_meta_data = self._read_input_meta_data_from_file()
strategy = self._get_distribution_strategy(ds_type)

run_squad.predict_squad(strategy=strategy, input_meta_data=input_meta_data)

dataset = self._read_predictions_dataset_from_file()
predictions = self._read_predictions_from_file()

eval_metrics = squad_evaluate_v1_1.evaluate(dataset, predictions)
if input_meta_data.get('version_2_with_negative', False):
logging.error('In memory evaluation result for SQuAD v2 is not accurate')
eval_metrics = run_squad.eval_squad(strategy=strategy,
input_meta_data=input_meta_data)
# Use F1 score as reported evaluation metric.
self.eval_metrics = eval_metrics['f1']

Expand Down
43 changes: 38 additions & 5 deletions official/nlp/bert/run_squad.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@

import json

import os
import tempfile
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf

from official.nlp.bert import configs as bert_configs
Expand Down Expand Up @@ -52,12 +55,22 @@ def train_squad(strategy,


def predict_squad(strategy, input_meta_data):
"""Makes predictions for a squad dataset."""
"""Makes predictions for the squad dataset."""
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
run_squad_helper.predict_squad(strategy, input_meta_data, tokenizer,
bert_config, squad_lib_wp)
run_squad_helper.predict_squad(
strategy, input_meta_data, tokenizer, bert_config, squad_lib_wp)


def eval_squad(strategy, input_meta_data):
"""Evaluate on the squad dataset."""
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
eval_metrics = run_squad_helper.eval_squad(
strategy, input_meta_data, tokenizer, bert_config, squad_lib_wp)
return eval_metrics


def export_squad(model_export_path, input_meta_data):
Expand Down Expand Up @@ -93,7 +106,8 @@ def main(_):
num_gpus=FLAGS.num_gpus,
all_reduce_alg=FLAGS.all_reduce_alg,
tpu_address=FLAGS.tpu)
if FLAGS.mode in ('train', 'train_and_predict'):

if 'train' in FLAGS.mode:
if FLAGS.log_steps:
custom_callbacks = [keras_utils.TimeHistory(
batch_size=FLAGS.train_batch_size,
Expand All @@ -109,8 +123,27 @@ def main(_):
custom_callbacks=custom_callbacks,
run_eagerly=FLAGS.run_eagerly,
)
if FLAGS.mode in ('predict', 'train_and_predict'):
if 'predict' in FLAGS.mode:
predict_squad(strategy, input_meta_data)
if 'eval' in FLAGS.mode:
if input_meta_data.get('version_2_with_negative', False):
logging.error('SQuAD v2 eval is not supported. '
'Falling back to predict mode.')
predict_squad(strategy, input_meta_data)
else:
eval_metrics = eval_squad(strategy, input_meta_data)
f1_score = eval_metrics['f1']
logging.info('SQuAD eval F1-score: %f', f1_score)
if (not strategy) or strategy.extended.should_save_summary:
summary_dir = os.path.join(FLAGS.model_dir, 'summaries')
else:
summary_dir = tempfile.mkdtemp()
summary_writer = tf.summary.create_file_writer(
os.path.join(summary_dir, 'eval'))
with summary_writer.as_default():
# TODO(lehou): write to the correct step number.
tf.summary.scalar('F1-score', f1_score, step=0)
summary_writer.flush()


if __name__ == '__main__':
Expand Down
81 changes: 63 additions & 18 deletions official/nlp/bert/run_squad_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import print_function

import collections
import json
import os
from absl import flags
from absl import logging
Expand All @@ -30,18 +31,23 @@
from official.nlp.bert import common_flags
from official.nlp.bert import input_pipeline
from official.nlp.bert import model_saving_utils
from official.nlp.bert import squad_evaluate_v1_1
from official.nlp.data import squad_lib_sp
from official.utils.misc import keras_utils


def define_common_squad_flags():
"""Defines common flags used by SQuAD tasks."""
flags.DEFINE_enum(
'mode', 'train_and_predict',
['train_and_predict', 'train', 'predict', 'export_only'],
'One of {"train_and_predict", "train", "predict", "export_only"}. '
'`train_and_predict`: both train and predict to a json file. '
'mode', 'train_and_eval',
['train_and_eval', 'train_and_predict',
'train', 'eval', 'predict', 'export_only'],
'One of {"train_and_eval", "train_and_predict", '
'"train", "eval", "predict", "export_only"}. '
'`train_and_eval`: train & predict to json files & compute eval metrics. '
'`train_and_predict`: train & predict to json files. '
'`train`: only trains the model. '
'`eval`: predict answers from squad json file & compute eval metrics. '
'`predict`: predict answers from the squad json file. '
'`export_only`: will take the latest checkpoint inside '
'model_dir and export a `SavedModel`.')
Expand Down Expand Up @@ -271,7 +277,8 @@ def clip_by_global_norm_callback(grads_and_vars):
post_allreduce_callbacks=[clip_by_global_norm_callback])


def predict_squad(strategy, input_meta_data, tokenizer, bert_config, squad_lib):
def prediction_output_squad(
strategy, input_meta_data, tokenizer, bert_config, squad_lib):
"""Makes predictions for a squad dataset."""
doc_stride = input_meta_data['doc_stride']
max_query_length = input_meta_data['max_query_length']
Expand Down Expand Up @@ -322,23 +329,61 @@ def _append_feature(feature, is_padding):
all_results = predict_squad_customized(strategy, input_meta_data, bert_config,
eval_writer.filename, num_steps)

all_predictions, all_nbest_json, scores_diff_json = (
squad_lib.postprocess_output(
eval_examples,
eval_features,
all_results,
FLAGS.n_best_size,
FLAGS.max_answer_length,
FLAGS.do_lower_case,
version_2_with_negative=version_2_with_negative,
null_score_diff_threshold=FLAGS.null_score_diff_threshold,
verbose=FLAGS.verbose_logging))

return all_predictions, all_nbest_json, scores_diff_json


def dump_to_files(all_predictions, all_nbest_json, scores_diff_json,
squad_lib, version_2_with_negative):
"""Save output to json files."""
output_prediction_file = os.path.join(FLAGS.model_dir, 'predictions.json')
output_nbest_file = os.path.join(FLAGS.model_dir, 'nbest_predictions.json')
output_null_log_odds_file = os.path.join(FLAGS.model_dir, 'null_odds.json')
logging.info('Writing predictions to: %s', (output_prediction_file))
logging.info('Writing nbest to: %s', (output_nbest_file))

squad_lib.write_predictions(
eval_examples,
eval_features,
all_results,
FLAGS.n_best_size,
FLAGS.max_answer_length,
FLAGS.do_lower_case,
output_prediction_file,
output_nbest_file,
output_null_log_odds_file,
version_2_with_negative=version_2_with_negative,
null_score_diff_threshold=FLAGS.null_score_diff_threshold,
verbose=FLAGS.verbose_logging)
squad_lib.write_to_json_files(all_predictions, output_prediction_file)
squad_lib.write_to_json_files(all_nbest_json, output_nbest_file)
if version_2_with_negative:
squad_lib.write_to_json_files(scores_diff_json, output_null_log_odds_file)


def predict_squad(strategy, input_meta_data, tokenizer, bert_config, squad_lib):
"""Get prediction results and evaluate them to hard drive."""
all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
strategy, input_meta_data, tokenizer, bert_config, squad_lib)
dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
input_meta_data.get('version_2_with_negative', False))


def eval_squad(strategy, input_meta_data, tokenizer, bert_config, squad_lib):
"""Get prediction results and evaluate them against ground truth."""
all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
strategy, input_meta_data, tokenizer, bert_config, squad_lib)
dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
input_meta_data.get('version_2_with_negative', False))

if input_meta_data.get('version_2_with_negative', False):
# TODO(lehou): support in memory evaluation for SQuAD v2.
logging.error('SQuAD v2 eval is not supported. Skipping eval')
return None
else:
with tf.io.gfile.GFile(FLAGS.predict_file, 'r') as reader:
dataset_json = json.load(reader)
pred_dataset = dataset_json['data']
eval_metrics = squad_evaluate_v1_1.evaluate(pred_dataset, all_predictions)
return eval_metrics


def export_squad(model_export_path, input_meta_data, bert_config):
Expand Down
108 changes: 108 additions & 0 deletions official/nlp/bert/squad_evaluate_v1_1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluation of SQuAD predictions (version 1.1).
The functions are copied from
https://worksheets.codalab.org/rest/bundles/0xbcd57bee090b421c982906709c8c27e1/contents/blob/.
The SQuAD dataset is described in this paper:
SQuAD: 100,000+ Questions for Machine Comprehension of Text
Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, Percy Liang
https://nlp.stanford.edu/pubs/rajpurkar2016squad.pdf
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import re
import string

# pylint: disable=g-bad-import-order
from absl import logging
# pylint: enable=g-bad-import-order


def _normalize_answer(s):
"""Lowers text and remove punctuation, articles and extra whitespace."""

def remove_articles(text):
return re.sub(r"\b(a|an|the)\b", " ", text)

def white_space_fix(text):
return " ".join(text.split())

def remove_punc(text):
exclude = set(string.punctuation)
return "".join(ch for ch in text if ch not in exclude)

def lower(text):
return text.lower()

return white_space_fix(remove_articles(remove_punc(lower(s))))


def _f1_score(prediction, ground_truth):
"""Computes F1 score by comparing prediction to ground truth."""
prediction_tokens = _normalize_answer(prediction).split()
ground_truth_tokens = _normalize_answer(ground_truth).split()
prediction_counter = collections.Counter(prediction_tokens)
ground_truth_counter = collections.Counter(ground_truth_tokens)
common = prediction_counter & ground_truth_counter
num_same = sum(common.values())
if num_same == 0:
return 0
precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(ground_truth_tokens)
f1 = (2 * precision * recall) / (precision + recall)
return f1


def _exact_match_score(prediction, ground_truth):
"""Checks if predicted answer exactly matches ground truth answer."""
return _normalize_answer(prediction) == _normalize_answer(ground_truth)


def _metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
"""Computes the max over all metric scores."""
scores_for_ground_truths = []
for ground_truth in ground_truths:
score = metric_fn(prediction, ground_truth)
scores_for_ground_truths.append(score)
return max(scores_for_ground_truths)


def evaluate(dataset, predictions):
"""Evaluates predictions for a dataset."""
f1 = exact_match = total = 0
for article in dataset:
for paragraph in article["paragraphs"]:
for qa in paragraph["qas"]:
total += 1
if qa["id"] not in predictions:
message = "Unanswered question " + qa["id"] + " will receive score 0."
logging.error(message)
continue
ground_truths = [entry["text"] for entry in qa["answers"]]
prediction = predictions[qa["id"]]
exact_match += _metric_max_over_ground_truths(_exact_match_score,
prediction, ground_truths)
f1 += _metric_max_over_ground_truths(_f1_score, prediction,
ground_truths)

exact_match = exact_match / total
f1 = f1 / total

return {"exact_match": exact_match, "f1": f1}
Loading

0 comments on commit 1ec383c

Please sign in to comment.