From da5860f2bacbc2b9d875a4fef886e6aa8e4a0475 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Fri, 27 Mar 2020 10:25:26 -0700
Subject: [PATCH] Support Lamb optimizer within BERT.

PiperOrigin-RevId: 303356961
---
 official/nlp/bert/common_flags.py     |  2 ++
 official/nlp/bert/run_classifier.py   |  3 ++-
 official/nlp/bert/run_pretraining.py  |  3 ++-
 official/nlp/bert/run_squad_helper.py |  3 ++-
 official/nlp/optimization.py          | 34 ++++++++++++++++++++-------
 5 files changed, 34 insertions(+), 11 deletions(-)

diff --git a/official/nlp/bert/common_flags.py b/official/nlp/bert/common_flags.py
index a44b4090506..3102bb6cc1b 100644
--- a/official/nlp/bert/common_flags.py
+++ b/official/nlp/bert/common_flags.py
@@ -76,6 +76,8 @@ def define_common_bert_flags():
       'If specified, init_checkpoint flag should not be used.')
   flags.DEFINE_bool('hub_module_trainable', True,
                     'True to make keras layers in the hub module trainable.')
+  flags.DEFINE_string('optimizer_type', 'adamw',
+                      'The type of optimizer to use for training (adamw|lamb)')
 
   flags_core.define_log_steps()
 
diff --git a/official/nlp/bert/run_classifier.py b/official/nlp/bert/run_classifier.py
index 35c4388609e..84f41decc95 100644
--- a/official/nlp/bert/run_classifier.py
+++ b/official/nlp/bert/run_classifier.py
@@ -125,7 +125,8 @@ def _get_classifier_model():
         hub_module_url=FLAGS.hub_module_url,
         hub_module_trainable=FLAGS.hub_module_trainable))
     optimizer = optimization.create_optimizer(
-        initial_lr, steps_per_epoch * epochs, warmup_steps)
+        initial_lr, steps_per_epoch * epochs, warmup_steps,
+        FLAGS.optimizer_type)
     classifier_model.optimizer = performance.configure_optimizer(
         optimizer,
         use_float16=common_flags.use_float16(),
diff --git a/official/nlp/bert/run_pretraining.py b/official/nlp/bert/run_pretraining.py
index 822955d6006..5f31f692e19 100644
--- a/official/nlp/bert/run_pretraining.py
+++ b/official/nlp/bert/run_pretraining.py
@@ -105,7 +105,8 @@ def _get_pretrain_model():
     pretrain_model, core_model = bert_models.pretrain_model(
         bert_config, max_seq_length, max_predictions_per_seq)
     optimizer = optimization.create_optimizer(
-        initial_lr, steps_per_epoch * epochs, warmup_steps)
+        initial_lr, steps_per_epoch * epochs, warmup_steps,
+        FLAGS.optimizer_type)
     pretrain_model.optimizer = performance.configure_optimizer(
         optimizer,
         use_float16=common_flags.use_float16(),
diff --git a/official/nlp/bert/run_squad_helper.py b/official/nlp/bert/run_squad_helper.py
index 52e11800b25..ec54b73fda5 100644
--- a/official/nlp/bert/run_squad_helper.py
+++ b/official/nlp/bert/run_squad_helper.py
@@ -244,7 +244,8 @@ def _get_squad_model():
         hub_module_trainable=FLAGS.hub_module_trainable)
     optimizer = optimization.create_optimizer(FLAGS.learning_rate,
                                               steps_per_epoch * epochs,
-                                              warmup_steps)
+                                              warmup_steps,
+                                              FLAGS.optimizer_type)
 
     squad_model.optimizer = performance.configure_optimizer(
         optimizer,
diff --git a/official/nlp/optimization.py b/official/nlp/optimization.py
index 47f85039dc7..cfbf211d9e4 100644
--- a/official/nlp/optimization.py
+++ b/official/nlp/optimization.py
@@ -20,7 +20,9 @@
 
 import re
 
+from absl import logging
 import tensorflow as tf
+import tensorflow_addons.optimizers as tfa_optimizers
 
 
 class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
@@ -65,7 +67,8 @@ def get_config(self):
     }
 
 
-def create_optimizer(init_lr, num_train_steps, num_warmup_steps):
+def create_optimizer(init_lr, num_train_steps, num_warmup_steps,
+                     optimizer_type='adamw'):
   """Creates an optimizer with learning rate schedule."""
   # Implements linear decay of the learning rate.
   learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
@@ -76,13 +79,28 @@ def create_optimizer(init_lr, num_train_steps, num_warmup_steps):
     learning_rate_fn = WarmUp(initial_learning_rate=init_lr,
                               decay_schedule_fn=learning_rate_fn,
                               warmup_steps=num_warmup_steps)
-  optimizer = AdamWeightDecay(
-      learning_rate=learning_rate_fn,
-      weight_decay_rate=0.01,
-      beta_1=0.9,
-      beta_2=0.999,
-      epsilon=1e-6,
-      exclude_from_weight_decay=['layer_norm', 'bias'])
+
+  if optimizer_type == 'adamw':
+    logging.info('using Adamw optimizer')
+    optimizer = AdamWeightDecay(
+        learning_rate=learning_rate_fn,
+        weight_decay_rate=0.01,
+        beta_1=0.9,
+        beta_2=0.999,
+        epsilon=1e-6,
+        exclude_from_weight_decay=['layer_norm', 'bias'])
+  elif optimizer_type == 'lamb':
+    logging.info('using Lamb optimizer')
+    optimizer = tfa_optimizers.LAMB(
+        learning_rate=learning_rate_fn,
+        weight_decay_rate=0.01,
+        beta_1=0.9,
+        beta_2=0.999,
+        epsilon=1e-6,
+        exclude_from_weight_decay=['layer_norm', 'bias'])
+  else:
+    raise ValueError('Unsupported optimizer type: ', optimizer_type)
+
   return optimizer
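
Usage sketch (illustrative, not part of the patch): the snippet below shows how the new optimizer_type argument might be exercised once this change is applied. The learning rate and step counts are made-up example values, and it assumes the tf-models repository (the official/ package) is on PYTHONPATH; the 'lamb' path additionally requires the tensorflow_addons package.

    # Sketch only: hyperparameter values are illustrative assumptions.
    from official.nlp import optimization

    # AdamW remains the default when optimizer_type is omitted.
    adamw_opt = optimization.create_optimizer(
        init_lr=2e-5, num_train_steps=10000, num_warmup_steps=1000)

    # LAMB is selected explicitly; this path uses tensorflow_addons.optimizers.LAMB.
    lamb_opt = optimization.create_optimizer(
        init_lr=2e-5, num_train_steps=10000, num_warmup_steps=1000,
        optimizer_type='lamb')

From the command line, the same choice would presumably be made through the new flag, e.g. passing --optimizer_type=lamb to run_pretraining.py, run_classifier.py, or the SQuAD scripts (all other flags omitted here).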