From 384867252aff91a8630d2bf31deddd557bec4d86 Mon Sep 17 00:00:00 2001
From: Hongkun Yu
Date: Wed, 25 Mar 2020 14:58:22 -0700
Subject: [PATCH] Internal change

PiperOrigin-RevId: 302977474
---
 official/nlp/bert/README.md         | 5 +++++
 official/nlp/bert/common_flags.py   | 2 +-
 official/nlp/bert/run_classifier.py | 8 +++++++-
 3 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/official/nlp/bert/README.md b/official/nlp/bert/README.md
index 98f397f35a7..c3c829776fe 100644
--- a/official/nlp/bert/README.md
+++ b/official/nlp/bert/README.md
@@ -269,6 +269,7 @@ python run_classifier.py \
   --init_checkpoint=${BERT_DIR}/bert_model.ckpt \
   --train_batch_size=32 \
   --eval_batch_size=32 \
+  --steps_per_loop=1000 \
   --learning_rate=2e-5 \
   --num_train_epochs=3 \
   --model_dir=${MODEL_DIR} \
@@ -276,6 +277,10 @@ python run_classifier.py \
   --tpu=grpc://${TPU_IP_ADDRESS}:8470
 ```

+Note that we specify `steps_per_loop=1000` for TPU, because running a loop of
+training steps inside a `tf.function` can significantly increase TPU
+utilization. Callbacks will not be called inside the loop.
+
 ### SQuAD 1.1

 The Stanford Question Answering Dataset (SQuAD) is a popular question answering
diff --git a/official/nlp/bert/common_flags.py b/official/nlp/bert/common_flags.py
index 4a3be35ccd4..a44b4090506 100644
--- a/official/nlp/bert/common_flags.py
+++ b/official/nlp/bert/common_flags.py
@@ -57,7 +57,7 @@ def define_common_bert_flags():
   flags.DEFINE_integer('num_train_epochs', 3,
                        'Total number of training epochs to perform.')
   flags.DEFINE_integer(
-      'steps_per_loop', 200,
+      'steps_per_loop', 1,
       'Number of steps per graph-mode loop. Only training step '
       'happens inside the loop. Callbacks will not be called '
       'inside.')
diff --git a/official/nlp/bert/run_classifier.py b/official/nlp/bert/run_classifier.py
index 2734d028805..d45a78820b6 100644
--- a/official/nlp/bert/run_classifier.py
+++ b/official/nlp/bert/run_classifier.py
@@ -156,6 +156,7 @@ def metric_fn():
         init_checkpoint,
         epochs,
         steps_per_epoch,
+        steps_per_loop,
         eval_steps,
         custom_callbacks=custom_callbacks)

@@ -189,6 +190,7 @@ def run_keras_compile_fit(model_dir,
                           init_checkpoint,
                           epochs,
                           steps_per_epoch,
+                          steps_per_loop,
                           eval_steps,
                           custom_callbacks=None):
   """Runs BERT classifier model using Keras compile/fit API."""
@@ -203,7 +205,11 @@ def run_keras_compile_fit(model_dir,
     checkpoint = tf.train.Checkpoint(model=sub_model)
     checkpoint.restore(init_checkpoint).assert_existing_objects_matched()

-  bert_model.compile(optimizer=optimizer, loss=loss_fn, metrics=[metric_fn()])
+  bert_model.compile(
+      optimizer=optimizer,
+      loss=loss_fn,
+      metrics=[metric_fn()],
+      experimental_steps_per_execution=steps_per_loop)

   summary_dir = os.path.join(model_dir, 'summaries')
   summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
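For context, a minimal sketch of how `steps_per_loop` takes effect once this change is applied: `run_classifier.py` forwards it to Keras `compile()` as `experimental_steps_per_execution` (the argument name at the time of the patch; later TensorFlow releases expose it as `steps_per_execution`), so many training steps run inside a single `tf.function` call. The toy model, dataset, and the value `1000` below are illustrative assumptions, not code from the repository.

```python
# Illustrative sketch (not part of the patch): a larger steps_per_loop is
# handed to Keras compile() so that many training steps execute inside one
# tf.function call, which is what raises TPU utilization.
import tensorflow as tf

steps_per_loop = 1000  # value the README suggests for TPU runs

# Toy stand-ins for the BERT classifier model and its input pipeline.
model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
features = tf.random.uniform([8, 4])
labels = tf.random.uniform([8], maxval=2, dtype=tf.int32)
dataset = (tf.data.Dataset.from_tensor_slices((features, labels))
           .repeat()
           .batch(4))

compile_kwargs = dict(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
try:
  # Argument name used by this patch (TF 2.1/2.2 era).
  model.compile(experimental_steps_per_execution=steps_per_loop,
                **compile_kwargs)
except TypeError:
  # Later TF releases renamed the same knob to `steps_per_execution`.
  model.compile(steps_per_execution=steps_per_loop, **compile_kwargs)

# Per-batch callback hooks are not invoked for steps that run inside the
# compiled loop; callbacks only observe loop boundaries.
model.fit(dataset, epochs=1, steps_per_epoch=steps_per_loop)
```

Note that the patch also drops the flag's default from 200 to 1, so callbacks fire every step unless a user explicitly opts into a larger loop, as the README now recommends for TPU runs.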