Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 302921283
  • Loading branch information
will-cromar authored and tensorflower-gardener committed Mar 25, 2020
1 parent 9fc4fd0 commit 7c83a9d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
5 changes: 4 additions & 1 deletion official/nlp/transformer/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,10 @@ def get_callbacks(steps_per_epoch):
"""Returns common callbacks."""
callbacks = []
if FLAGS.enable_time_history:
time_callback = keras_utils.TimeHistory(FLAGS.batch_size, FLAGS.log_steps)
time_callback = keras_utils.TimeHistory(
FLAGS.batch_size,
FLAGS.log_steps,
FLAGS.model_dir if FLAGS.enable_tensorboard else None)
callbacks.append(time_callback)

if FLAGS.enable_tensorboard:
Expand Down
18 changes: 18 additions & 0 deletions official/nlp/transformer/transformer_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,11 @@ def train(self):

callbacks = self._create_callbacks(flags_obj.model_dir, 0, params)

      # Only the TimeHistory callback is supported for the custom training loop (CTL).
if params["use_ctl"]:
callbacks = [cb for cb in callbacks
if isinstance(cb, keras_utils.TimeHistory)]

# TODO(b/139418525): Refactor the custom training loop logic.
@tf.function
def train_steps(iterator, steps):
Expand Down Expand Up @@ -299,8 +304,13 @@ def _step_fn(inputs):
if not self.use_tpu:
raise NotImplementedError(
"Custom training loop on GPUs is not implemented.")

# Runs training steps.
with summary_writer.as_default():
for cb in callbacks:
cb.on_epoch_begin(current_iteration)
cb.on_batch_begin(0)

train_steps(
train_ds_iterator,
tf.convert_to_tensor(train_steps_per_eval, dtype=tf.int32))
Expand All @@ -309,10 +319,18 @@ def _step_fn(inputs):
logging.info("Train Step: %d/%d / loss = %s", current_step,
flags_obj.train_steps, train_loss)

for cb in callbacks:
cb.on_batch_end(train_steps_per_eval - 1)
cb.on_epoch_end(current_iteration)

if params["enable_tensorboard"]:
for metric_obj in train_metrics:
tf.compat.v2.summary.scalar(metric_obj.name, metric_obj.result(),
current_step)
summary_writer.flush()

for cb in callbacks:
cb.on_train_end()

if flags_obj.enable_checkpointing:
# avoid check-pointing when running for benchmarking.
Expand Down

0 comments on commit 7c83a9d

Please sign in to comment.