Commit

Internal change
PiperOrigin-RevId: 302977474
saberkun authored and tensorflower-gardener committed Mar 25, 2020
1 parent 7e6167a commit 3848672
Showing 3 changed files with 13 additions and 2 deletions.
5 changes: 5 additions & 0 deletions official/nlp/bert/README.md
@@ -269,13 +269,18 @@ python run_classifier.py \
--init_checkpoint=${BERT_DIR}/bert_model.ckpt \
--train_batch_size=32 \
--eval_batch_size=32 \
--steps_per_loop=1000 \
--learning_rate=2e-5 \
--num_train_epochs=3 \
--model_dir=${MODEL_DIR} \
--distribution_strategy=tpu \
--tpu=grpc://${TPU_IP_ADDRESS}:8470
```

Note that we specify `steps_per_loop=1000` for TPU: running a loop of training
steps inside a `tf.function` can significantly increase TPU utilization.
Callbacks will not be called inside the loop.

### SQuAD 1.1

The Stanford Question Answering Dataset (SQuAD) is a popular question answering
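
The README note above is the heart of the change: dispatching one training step
per `tf.function` call leaves the TPU waiting on host round-trips, while looping
over several steps inside the traced function amortizes that launch overhead.
Below is a minimal, illustrative sketch of such a multi-step loop (not code from
this commit; `strategy`, `model`, `optimizer`, `loss_fn` and `dist_dataset` are
assumed to be built elsewhere, e.g. by the BERT library):

```python
import tensorflow as tf


def run_training(strategy, model, optimizer, loss_fn, dist_dataset,
                 total_steps, steps_per_loop):
  """Runs `total_steps` of training, `steps_per_loop` steps per dispatch."""

  @tf.function
  def train_loop(iterator, steps):
    # The Python loop is traced into a single on-device loop, so one
    # host->TPU dispatch covers `steps` training steps.
    for _ in tf.range(steps):

      def step_fn(inputs):
        features, labels = inputs
        with tf.GradientTape() as tape:
          loss = loss_fn(labels, model(features, training=True))
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

      strategy.run(step_fn, args=(next(iterator),))

  iterator = iter(dist_dataset)
  for _ in range(total_steps // steps_per_loop):
    train_loop(iterator, tf.constant(steps_per_loop))
    # Callbacks, logging and checkpointing can only happen here, between
    # dispatches, which is why they are not called inside the loop.
```
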
2 changes: 1 addition & 1 deletion official/nlp/bert/common_flags.py
@@ -57,7 +57,7 @@ def define_common_bert_flags():
flags.DEFINE_integer('num_train_epochs', 3,
'Total number of training epochs to perform.')
flags.DEFINE_integer(
- 'steps_per_loop', 200,
+ 'steps_per_loop', 1,
'Number of steps per graph-mode loop. Only training step '
'happens inside the loop. Callbacks will not be called '
'inside.')
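
Lowering the default to 1 means Keras callbacks fire after every step unless the
user explicitly raises `--steps_per_loop` (for example to 1000 on TPU, as the
README change above recommends). For readers unfamiliar with `absl` flags, here
is an illustrative, standalone sketch of defining and reading such a flag; the
`main` wiring is hypothetical and not how the repository consumes it:

```python
from absl import app, flags

flags.DEFINE_integer(
    'steps_per_loop', 1,
    'Number of steps per graph-mode loop. Only training step '
    'happens inside the loop. Callbacks will not be called inside.')

FLAGS = flags.FLAGS


def main(_argv):
  # The real trainer threads this value into model.compile() or its custom
  # training loop; here we only show how the parsed flag is read.
  print('steps_per_loop =', FLAGS.steps_per_loop)


if __name__ == '__main__':
  app.run(main)
```
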
8 changes: 7 additions & 1 deletion official/nlp/bert/run_classifier.py
@@ -156,6 +156,7 @@ def metric_fn():
init_checkpoint,
epochs,
steps_per_epoch,
steps_per_loop,
eval_steps,
custom_callbacks=custom_callbacks)

@@ -189,6 +190,7 @@ def run_keras_compile_fit(model_dir,
init_checkpoint,
epochs,
steps_per_epoch,
steps_per_loop,
eval_steps,
custom_callbacks=None):
"""Runs BERT classifier model using Keras compile/fit API."""
@@ -203,7 +205,11 @@ def run_keras_compile_fit(model_dir,
checkpoint = tf.train.Checkpoint(model=sub_model)
checkpoint.restore(init_checkpoint).assert_existing_objects_matched()

- bert_model.compile(optimizer=optimizer, loss=loss_fn, metrics=[metric_fn()])
+ bert_model.compile(
+     optimizer=optimizer,
+     loss=loss_fn,
+     metrics=[metric_fn()],
+     experimental_steps_per_execution=steps_per_loop)

summary_dir = os.path.join(model_dir, 'summaries')
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
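
The `experimental_steps_per_execution` compile argument (renamed
`steps_per_execution` in later TensorFlow releases) is how `steps_per_loop`
reaches the Keras compile/fit path: Keras runs that many batches per
`tf.function` call, and batch-level callbacks are invoked only once per block.
A self-contained, illustrative example with a toy model and random data (not
from this repository):

```python
import tensorflow as tf

# Toy classifier; the BERT model in this repo is built elsewhere.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dense(2, activation='softmax'),
])
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    # In the repo this value comes from FLAGS.steps_per_loop; large values
    # mainly pay off on TPU.
    experimental_steps_per_execution=50)

x = tf.random.normal([1000, 8])
y = tf.random.uniform([1000], maxval=2, dtype=tf.int32)

# 100 steps per epoch; with 50 steps per execution, batch-level callbacks
# (e.g. per-batch TensorBoard logging) run only twice per epoch.
model.fit(x, y, batch_size=10, epochs=1)
```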
