From 13d44a051751808ec11eb075cc3b3adab87a49f3 Mon Sep 17 00:00:00 2001 From: Hongkun Yu Date: Sun, 22 Mar 2020 14:13:04 -0700 Subject: [PATCH] Fix controller bugs. Add tests for optional args. PiperOrigin-RevId: 302323163 --- official/staging/training/controller.py | 21 ++++++--- official/staging/training/controller_test.py | 46 ++++++++++++++++++++ 2 files changed, 60 insertions(+), 7 deletions(-) diff --git a/official/staging/training/controller.py b/official/staging/training/controller.py index 939aa09d262..a07be66329a 100644 --- a/official/staging/training/controller.py +++ b/official/staging/training/controller.py @@ -117,11 +117,18 @@ def __init__( if self.train_fn is not None: self.train_steps = train_steps self.steps_per_loop = steps_per_loop - self.summary_dir = summary_dir or checkpoint_manager.directory + if summary_dir: + self.summary_dir = summary_dir + elif checkpoint_manager: + self.summary_dir = checkpoint_manager.directory + else: + self.summary_dir = None self.summary_interval = summary_interval - summary_writer = tf.summary.create_file_writer( - self.summary_dir) if self.summary_interval else None + if self.summary_dir and self.summary_interval: + summary_writer = tf.summary.create_file_writer(self.summary_dir) + else: + summary_writer = None # TODO(rxsang): Consider pass SummaryManager directly into Controller for # maximum customizability. self.summary_manager = utils.SummaryManager( @@ -140,14 +147,14 @@ def __init__( self.eval_steps = eval_steps self.eval_interval = eval_interval - # Create and initialize the interval triggers. + # Creates and initializes the interval triggers. self.eval_trigger = utils.IntervalTrigger(self.eval_interval, - self.global_step.numpy()) + self.global_step.numpy()) # pytype: disable=attribute-error if self.global_step: tf.summary.experimental.set_step(self.global_step) - # Restore Model if needed. + # Restores the model if needed. if self.checkpoint_manager is not None: model_restored = self._restore_model() if not model_restored and self.checkpoint_manager.checkpoint_interval: @@ -192,7 +199,7 @@ def _evaluate_once(self, current_step): self.eval_summary_manager.flush() def _maybe_save_checkpoints(self, current_step, force_trigger=False): - if self.checkpoint_manager.checkpoint_interval: + if self.checkpoint_manager and self.checkpoint_manager.checkpoint_interval: ckpt_path = self.checkpoint_manager.save( checkpoint_number=current_step, check_interval=not force_trigger) if ckpt_path is not None: diff --git a/official/staging/training/controller_test.py b/official/staging/training/controller_test.py index d7a7282fc99..eeaa191c04d 100644 --- a/official/staging/training/controller_test.py +++ b/official/staging/training/controller_test.py @@ -143,6 +143,52 @@ def setUp(self): super(ControllerTest, self).setUp() self.model_dir = self.get_temp_dir() + def test_no_checkpoint(self): + test_runnable = TestRunnable() + # No checkpoint manager and no strategy. + test_controller = controller.Controller( + train_fn=test_runnable.train, + eval_fn=test_runnable.evaluate, + global_step=test_runnable.global_step, + train_steps=10, + steps_per_loop=2, + summary_dir=os.path.join(self.model_dir, "summaries/train"), + summary_interval=2, + eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"), + eval_steps=2, + eval_interval=5) + test_controller.train(evaluate=True) + self.assertEqual(test_runnable.global_step.numpy(), 10) + # Loss and accuracy values should be written into summaries. + self.assertNotEmpty( + tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train"))) + self.assertTrue( + check_eventfile_for_keyword( + "loss", os.path.join(self.model_dir, "summaries/train"))) + self.assertNotEmpty( + tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval"))) + self.assertTrue( + check_eventfile_for_keyword( + "eval_loss", os.path.join(self.model_dir, "summaries/eval"))) + # No checkpoint, so global step starts from 0. + test_runnable.global_step.assign(0) + test_controller.train(evaluate=True) + self.assertEqual(test_runnable.global_step.numpy(), 10) + + def test_no_checkpoint_and_summaries(self): + test_runnable = TestRunnable() + # No checkpoint + summary directories. + test_controller = controller.Controller( + train_fn=test_runnable.train, + eval_fn=test_runnable.evaluate, + global_step=test_runnable.global_step, + train_steps=10, + steps_per_loop=2, + eval_steps=2, + eval_interval=5) + test_controller.train(evaluate=True) + self.assertEqual(test_runnable.global_step.numpy(), 10) + @combinations.generate(all_strategy_combinations()) def test_train_and_evaluate(self, strategy): with strategy.scope():