Commit

Support Lamb optimizer within BERT.
PiperOrigin-RevId: 303356961
tensorflower-gardener committed Mar 27, 2020
1 parent 0265f59 commit da5860f
Showing 5 changed files with 34 additions and 11 deletions.
2 changes: 2 additions & 0 deletions official/nlp/bert/common_flags.py
@@ -76,6 +76,8 @@ def define_common_bert_flags():
       'If specified, init_checkpoint flag should not be used.')
   flags.DEFINE_bool('hub_module_trainable', True,
                     'True to make keras layers in the hub module trainable.')
+  flags.DEFINE_string('optimizer_type', 'adamw',
+                      'The type of optimizer to use for training (adamw|lamb)')
 
   flags_core.define_log_steps()
 
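For readers who want to try the new flag in isolation, here is a minimal, self-contained sketch that mirrors the DEFINE_string call added above; it assumes only that absl-py is installed and does not depend on the rest of the Model Garden.

# Illustrative sketch only; mirrors the flag definition added in this commit.
from absl import app
from absl import flags

flags.DEFINE_string('optimizer_type', 'adamw',
                    'The type of optimizer to use for training (adamw|lamb)')

FLAGS = flags.FLAGS


def main(_):
  # Run with --optimizer_type=lamb to select the LAMB path added in this commit.
  print('Selected optimizer:', FLAGS.optimizer_type)


if __name__ == '__main__':
  app.run(main)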
3 changes: 2 additions & 1 deletion official/nlp/bert/run_classifier.py
@@ -125,7 +125,8 @@ def _get_classifier_model():
             hub_module_url=FLAGS.hub_module_url,
             hub_module_trainable=FLAGS.hub_module_trainable))
     optimizer = optimization.create_optimizer(
-        initial_lr, steps_per_epoch * epochs, warmup_steps)
+        initial_lr, steps_per_epoch * epochs, warmup_steps,
+        FLAGS.optimizer_type)
     classifier_model.optimizer = performance.configure_optimizer(
         optimizer,
         use_float16=common_flags.use_float16(),
3 changes: 2 additions & 1 deletion official/nlp/bert/run_pretraining.py
@@ -105,7 +105,8 @@ def _get_pretrain_model():
     pretrain_model, core_model = bert_models.pretrain_model(
         bert_config, max_seq_length, max_predictions_per_seq)
     optimizer = optimization.create_optimizer(
-        initial_lr, steps_per_epoch * epochs, warmup_steps)
+        initial_lr, steps_per_epoch * epochs, warmup_steps,
+        FLAGS.optimizer_type)
     pretrain_model.optimizer = performance.configure_optimizer(
         optimizer,
         use_float16=common_flags.use_float16(),
3 changes: 2 additions & 1 deletion official/nlp/bert/run_squad_helper.py
@@ -244,7 +244,8 @@ def _get_squad_model():
         hub_module_trainable=FLAGS.hub_module_trainable)
     optimizer = optimization.create_optimizer(FLAGS.learning_rate,
                                               steps_per_epoch * epochs,
-                                              warmup_steps)
+                                              warmup_steps,
+                                              FLAGS.optimizer_type)
 
     squad_model.optimizer = performance.configure_optimizer(
         optimizer,
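All three call sites above simply forward FLAGS.optimizer_type into optimization.create_optimizer. A hedged usage sketch of that call outside the training scripts follows; the import path matches the file location in this commit, while the learning rate and step counts are illustrative assumptions and the snippet requires the TF Model Garden on PYTHONPATH plus tensorflow-addons.

# Sketch only: exercises the new optimizer_type argument of create_optimizer.
from official.nlp import optimization

# Default keeps the existing AdamWeightDecay behaviour.
adamw = optimization.create_optimizer(
    2e-5, num_train_steps=1000, num_warmup_steps=100)

# 'lamb' selects tensorflow_addons.optimizers.LAMB (see optimization.py below).
lamb = optimization.create_optimizer(
    2e-5, num_train_steps=1000, num_warmup_steps=100, optimizer_type='lamb')

# Any other value hits the new error branch.
try:
  optimization.create_optimizer(2e-5, 1000, 100, optimizer_type='sgd')
except ValueError as err:
  print('Rejected:', err)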
34 changes: 26 additions & 8 deletions official/nlp/optimization.py
@@ -20,7 +20,9 @@
 
 import re
 
+from absl import logging
 import tensorflow as tf
+import tensorflow_addons.optimizers as tfa_optimizers
 
 
 class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
@@ -65,7 +67,8 @@ def get_config(self):
     }
 
 
-def create_optimizer(init_lr, num_train_steps, num_warmup_steps):
+def create_optimizer(init_lr, num_train_steps, num_warmup_steps,
+                     optimizer_type='adamw'):
   """Creates an optimizer with learning rate schedule."""
   # Implements linear decay of the learning rate.
   learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
@@ -76,13 +79,28 @@ def create_optimizer(init_lr, num_train_steps, num_warmup_steps):
   learning_rate_fn = WarmUp(initial_learning_rate=init_lr,
                             decay_schedule_fn=learning_rate_fn,
                             warmup_steps=num_warmup_steps)
-  optimizer = AdamWeightDecay(
-      learning_rate=learning_rate_fn,
-      weight_decay_rate=0.01,
-      beta_1=0.9,
-      beta_2=0.999,
-      epsilon=1e-6,
-      exclude_from_weight_decay=['layer_norm', 'bias'])
+
+  if optimizer_type == 'adamw':
+    logging.info('using Adamw optimizer')
+    optimizer = AdamWeightDecay(
+        learning_rate=learning_rate_fn,
+        weight_decay_rate=0.01,
+        beta_1=0.9,
+        beta_2=0.999,
+        epsilon=1e-6,
+        exclude_from_weight_decay=['layer_norm', 'bias'])
+  elif optimizer_type == 'lamb':
+    logging.info('using Lamb optimizer')
+    optimizer = tfa_optimizers.LAMB(
+        learning_rate=learning_rate_fn,
+        weight_decay_rate=0.01,
+        beta_1=0.9,
+        beta_2=0.999,
+        epsilon=1e-6,
+        exclude_from_weight_decay=['layer_norm', 'bias'])
+  else:
+    raise ValueError('Unsupported optimizer type: ', optimizer_type)
+
   return optimizer
 
 
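For context, the new 'lamb' branch relies on the LAMB implementation shipped with TensorFlow Addons rather than anything defined in this repository. Below is a minimal sketch of that optimizer driving a toy Keras model, using the same hyperparameters the branch above hard-codes; the toy model, data, and plain float learning rate are illustrative assumptions, not part of the commit.

# Illustrative sketch; requires tensorflow and tensorflow-addons.
import tensorflow as tf
import tensorflow_addons as tfa

# Same settings as the tfa_optimizers.LAMB call in optimization.py, but with a
# plain float learning rate instead of the warmup/decay schedule.
optimizer = tfa.optimizers.LAMB(
    learning_rate=1e-3,
    weight_decay_rate=0.01,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-6,
    exclude_from_weight_decay=['layer_norm', 'bias'])

# Toy regression model, just to show the optimizer plugging into Keras.
model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
model.compile(optimizer=optimizer, loss='mse')
model.fit(tf.random.normal([8, 4]), tf.random.normal([8, 2]), epochs=1, verbose=0)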
