diff --git a/official/nlp/transformer/misc.py b/official/nlp/transformer/misc.py
index 1baefe24dd6..3eb430f06f5 100644
--- a/official/nlp/transformer/misc.py
+++ b/official/nlp/transformer/misc.py
@@ -239,7 +239,10 @@ def get_callbacks(steps_per_epoch):
   """Returns common callbacks."""
   callbacks = []
   if FLAGS.enable_time_history:
-    time_callback = keras_utils.TimeHistory(FLAGS.batch_size, FLAGS.log_steps)
+    time_callback = keras_utils.TimeHistory(
+        FLAGS.batch_size,
+        FLAGS.log_steps,
+        FLAGS.model_dir if FLAGS.enable_tensorboard else None)
     callbacks.append(time_callback)
 
   if FLAGS.enable_tensorboard:
diff --git a/official/nlp/transformer/transformer_main.py b/official/nlp/transformer/transformer_main.py
index 126262fc6a7..f72cbe73762 100644
--- a/official/nlp/transformer/transformer_main.py
+++ b/official/nlp/transformer/transformer_main.py
@@ -246,6 +246,11 @@ def train(self):
     callbacks = self._create_callbacks(flags_obj.model_dir, 0, params)
 
+    # Only TimeHistory callback is supported for CTL
+    if params["use_ctl"]:
+      callbacks = [cb for cb in callbacks
+                   if isinstance(cb, keras_utils.TimeHistory)]
+
     # TODO(b/139418525): Refactor the custom training loop logic.
     @tf.function
     def train_steps(iterator, steps):
@@ -299,8 +304,13 @@ def _step_fn(inputs):
         if not self.use_tpu:
           raise NotImplementedError(
               "Custom training loop on GPUs is not implemented.")
+
         # Runs training steps.
         with summary_writer.as_default():
+          for cb in callbacks:
+            cb.on_epoch_begin(current_iteration)
+            cb.on_batch_begin(0)
+
           train_steps(
               train_ds_iterator,
               tf.convert_to_tensor(train_steps_per_eval, dtype=tf.int32))
@@ -309,10 +319,18 @@ def _step_fn(inputs):
           logging.info("Train Step: %d/%d / loss = %s", current_step,
                        flags_obj.train_steps, train_loss)
 
+          for cb in callbacks:
+            cb.on_batch_end(train_steps_per_eval - 1)
+            cb.on_epoch_end(current_iteration)
+
           if params["enable_tensorboard"]:
             for metric_obj in train_metrics:
               tf.compat.v2.summary.scalar(metric_obj.name, metric_obj.result(),
                                           current_step)
+            summary_writer.flush()
+
+        for cb in callbacks:
+          cb.on_train_end()
 
         if flags_obj.enable_checkpointing:
           # avoid check-pointing when running for benchmarking.
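
Context for the transformer_main.py hunks: under `use_ctl` the model is trained by a hand-written loop rather than `model.fit()`, so Keras callbacks are never invoked automatically. The patch therefore keeps only `keras_utils.TimeHistory` and drives its hooks by hand, treating each train/eval iteration as one "epoch" and the whole block of `train_steps_per_eval` steps as a single "batch". The sketch below is a minimal, standalone illustration of that pattern, not code from the patch; `ToyTimer`, the step counts, and the `time.sleep` stand-in for the training work are all hypothetical, and the repo's `TimeHistory` does considerably more (per-step throughput logging and optional TensorBoard summaries).

# Minimal sketch (assumptions noted above) of manually driving Keras callback
# hooks around a custom training loop, mirroring the structure of the patch.
import time

import tensorflow as tf


class ToyTimer(tf.keras.callbacks.Callback):
  """Records wall-clock time per CTL iteration (one 'epoch' in this scheme)."""

  def on_epoch_begin(self, epoch, logs=None):
    self._start = time.time()

  def on_epoch_end(self, epoch, logs=None):
    print("iteration %d took %.3fs" % (epoch, time.time() - self._start))


callbacks = [ToyTimer()]
train_steps_per_eval = 100  # hypothetical value for illustration

for current_iteration in range(3):  # stand-in for the outer while loop
  for cb in callbacks:
    cb.on_epoch_begin(current_iteration)
    cb.on_batch_begin(0)

  # ... run train_steps_per_eval training steps here (train_steps in the patch) ...
  time.sleep(0.01)

  for cb in callbacks:
    # The whole block of steps is reported as one batch, matching the patch.
    cb.on_batch_end(train_steps_per_eval - 1)
    cb.on_epoch_end(current_iteration)

for cb in callbacks:
  cb.on_train_end()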