Skip to content

Commit

Permalink
Fix controller bugs. Add tests for optional args.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 302323163
  • Loading branch information
saberkun authored and tensorflower-gardener committed Mar 22, 2020
1 parent a348a90 commit 13d44a0
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 7 deletions.
21 changes: 14 additions & 7 deletions official/staging/training/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,18 @@ def __init__(
if self.train_fn is not None:
self.train_steps = train_steps
self.steps_per_loop = steps_per_loop
self.summary_dir = summary_dir or checkpoint_manager.directory
if summary_dir:
self.summary_dir = summary_dir
elif checkpoint_manager:
self.summary_dir = checkpoint_manager.directory
else:
self.summary_dir = None

self.summary_interval = summary_interval
summary_writer = tf.summary.create_file_writer(
self.summary_dir) if self.summary_interval else None
if self.summary_dir and self.summary_interval:
summary_writer = tf.summary.create_file_writer(self.summary_dir)
else:
summary_writer = None
# TODO(rxsang): Consider pass SummaryManager directly into Controller for
# maximum customizability.
self.summary_manager = utils.SummaryManager(
Expand All @@ -140,14 +147,14 @@ def __init__(
self.eval_steps = eval_steps
self.eval_interval = eval_interval

# Create and initialize the interval triggers.
# Creates and initializes the interval triggers.
self.eval_trigger = utils.IntervalTrigger(self.eval_interval,
self.global_step.numpy())
self.global_step.numpy()) # pytype: disable=attribute-error

if self.global_step:
tf.summary.experimental.set_step(self.global_step)

# Restore Model if needed.
# Restores the model if needed.
if self.checkpoint_manager is not None:
model_restored = self._restore_model()
if not model_restored and self.checkpoint_manager.checkpoint_interval:
Expand Down Expand Up @@ -192,7 +199,7 @@ def _evaluate_once(self, current_step):
self.eval_summary_manager.flush()

def _maybe_save_checkpoints(self, current_step, force_trigger=False):
if self.checkpoint_manager.checkpoint_interval:
if self.checkpoint_manager and self.checkpoint_manager.checkpoint_interval:
ckpt_path = self.checkpoint_manager.save(
checkpoint_number=current_step, check_interval=not force_trigger)
if ckpt_path is not None:
Expand Down
46 changes: 46 additions & 0 deletions official/staging/training/controller_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,52 @@ def setUp(self):
super(ControllerTest, self).setUp()
self.model_dir = self.get_temp_dir()

def test_no_checkpoint(self):
    """Trains and evaluates a Controller that has summaries but no checkpoint manager."""
    runnable = TestRunnable()
    train_summary_dir = os.path.join(self.model_dir, "summaries/train")
    eval_summary_dir = os.path.join(self.model_dir, "summaries/eval")
    # No checkpoint manager and no strategy.
    ctrl = controller.Controller(
        train_fn=runnable.train,
        eval_fn=runnable.evaluate,
        global_step=runnable.global_step,
        train_steps=10,
        steps_per_loop=2,
        summary_dir=train_summary_dir,
        summary_interval=2,
        eval_summary_dir=eval_summary_dir,
        eval_steps=2,
        eval_interval=5)
    ctrl.train(evaluate=True)
    self.assertEqual(runnable.global_step.numpy(), 10)
    # Loss and accuracy values should be written into summaries.
    for directory, keyword in ((train_summary_dir, "loss"),
                               (eval_summary_dir, "eval_loss")):
        self.assertNotEmpty(tf.io.gfile.listdir(directory))
        self.assertTrue(check_eventfile_for_keyword(keyword, directory))
    # No checkpoint, so global step starts from 0.
    runnable.global_step.assign(0)
    ctrl.train(evaluate=True)
    self.assertEqual(runnable.global_step.numpy(), 10)

def test_no_checkpoint_and_summaries(self):
    """Trains and evaluates a Controller with neither checkpoints nor summary dirs."""
    runnable = TestRunnable()
    # No checkpoint + summary directories.
    ctrl = controller.Controller(
        train_fn=runnable.train,
        eval_fn=runnable.evaluate,
        global_step=runnable.global_step,
        train_steps=10,
        steps_per_loop=2,
        eval_steps=2,
        eval_interval=5)
    ctrl.train(evaluate=True)
    self.assertEqual(runnable.global_step.numpy(), 10)

@combinations.generate(all_strategy_combinations())
def test_train_and_evaluate(self, strategy):
with strategy.scope():
Expand Down

0 comments on commit 13d44a0

Please sign in to comment.