Commit 1aaab0b

lr_schedule is a more reasonable name: the object is a callable LearningRateSchedule, not a plain function.
PiperOrigin-RevId: 303840099
saberkun authored and tensorflower-gardener committed Mar 30, 2020
1 parent 39f05f0 commit 1aaab0b
Showing 1 changed file with 27 additions and 24 deletions.
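
The rationale, concretely: a Keras LearningRateSchedule is a callable object that carries serializable state, not a plain function, so lr_schedule describes it better than learning_rate_fn. A minimal sketch of the distinction (TF 2.x; the values are illustrative):

import tensorflow as tf

# PolynomialDecay defaults to power=1.0, i.e. linear decay.
schedule = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=1e-4,
    decay_steps=10000,
    end_learning_rate=0.0)

print(callable(schedule))      # True: the object implements __call__
print(float(schedule(0)))      # 1e-4 at step 0
print(float(schedule(10000)))  # 0.0 once decay_steps is reached
print(schedule.get_config())   # serializable state a bare function lacks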
51 changes: 27 additions & 24 deletions official/nlp/optimization.py
@@ -28,13 +28,12 @@
 class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
   """Applies a warmup schedule on a given learning rate decay schedule."""
 
-  def __init__(
-      self,
-      initial_learning_rate,
-      decay_schedule_fn,
-      warmup_steps,
-      power=1.0,
-      name=None):
+  def __init__(self,
+               initial_learning_rate,
+               decay_schedule_fn,
+               warmup_steps,
+               power=1.0,
+               name=None):
     super(WarmUp, self).__init__()
     self.initial_learning_rate = initial_learning_rate
     self.warmup_steps = warmup_steps
@@ -52,10 +51,11 @@ def __call__(self, step):
       warmup_learning_rate = (
           self.initial_learning_rate *
           tf.math.pow(warmup_percent_done, self.power))
-      return tf.cond(global_step_float < warmup_steps_float,
-                     lambda: warmup_learning_rate,
-                     lambda: self.decay_schedule_fn(step),
-                     name=name)
+      return tf.cond(
+          global_step_float < warmup_steps_float,
+          lambda: warmup_learning_rate,
+          lambda: self.decay_schedule_fn(step),
+          name=name)
 
   def get_config(self):
     return {
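
For reference, the warmup branch above evaluates to initial_learning_rate * (step / warmup_steps) ** power, and tf.cond hands control to the wrapped decay schedule once the step counter reaches warmup_steps. A plain-Python sketch of that arithmetic, with illustrative values:

init_lr, warmup_steps, power = 1e-4, 1000, 1.0

for step in (0, 100, 500, 1000):
  warmup_percent_done = step / warmup_steps
  warmup_lr = init_lr * warmup_percent_done ** power
  print(step, warmup_lr)  # ramps linearly from 0.0 to 1e-4 when power=1.0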
@@ -67,23 +67,26 @@ def get_config(self):
     }
 
 
-def create_optimizer(init_lr, num_train_steps, num_warmup_steps,
+def create_optimizer(init_lr,
+                     num_train_steps,
+                     num_warmup_steps,
                      optimizer_type='adamw'):
   """Creates an optimizer with learning rate schedule."""
   # Implements linear decay of the learning rate.
-  learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
+  lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
       initial_learning_rate=init_lr,
       decay_steps=num_train_steps,
       end_learning_rate=0.0)
   if num_warmup_steps:
-    learning_rate_fn = WarmUp(initial_learning_rate=init_lr,
-                              decay_schedule_fn=learning_rate_fn,
-                              warmup_steps=num_warmup_steps)
+    lr_schedule = WarmUp(
+        initial_learning_rate=init_lr,
+        decay_schedule_fn=lr_schedule,
+        warmup_steps=num_warmup_steps)
 
   if optimizer_type == 'adamw':
     logging.info('using Adamw optimizer')
     optimizer = AdamWeightDecay(
-        learning_rate=learning_rate_fn,
+        learning_rate=lr_schedule,
         weight_decay_rate=0.01,
         beta_1=0.9,
         beta_2=0.999,
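
A usage sketch for create_optimizer as defined above; the hyperparameter values are illustrative, not defaults of this module:

from official.nlp import optimization

# Linear decay to zero over 100k steps, with the first 10k spent warming up.
optimizer = optimization.create_optimizer(
    init_lr=2e-5,
    num_train_steps=100000,
    num_warmup_steps=10000,
    optimizer_type='adamw')

# A falsy num_warmup_steps (0 or None) skips the WarmUp wrapper and attaches
# the bare PolynomialDecay schedule to the optimizer.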
@@ -92,7 +95,7 @@ def create_optimizer(init_lr, num_train_steps, num_warmup_steps,
   elif optimizer_type == 'lamb':
     logging.info('using Lamb optimizer')
     optimizer = tfa_optimizers.LAMB(
-        learning_rate=learning_rate_fn,
+        learning_rate=lr_schedule,
         weight_decay_rate=0.01,
         beta_1=0.9,
         beta_2=0.999,
@@ -127,8 +130,8 @@ def __init__(self,
                exclude_from_weight_decay=None,
                name='AdamWeightDecay',
                **kwargs):
-    super(AdamWeightDecay, self).__init__(
-        learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs)
+    super(AdamWeightDecay, self).__init__(learning_rate, beta_1, beta_2,
+                                          epsilon, amsgrad, name, **kwargs)
     self.weight_decay_rate = weight_decay_rate
     self._include_in_weight_decay = include_in_weight_decay
     self._exclude_from_weight_decay = exclude_from_weight_decay
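
A construction sketch for AdamWeightDecay using the arguments visible in this hunk; the exclude patterns are hypothetical, matched against variable names to keep normalization and bias terms out of the decay:

# Adam plus decoupled weight decay; the exclude patterns are illustrative.
optimizer = AdamWeightDecay(
    learning_rate=1e-4,
    weight_decay_rate=0.01,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-6,
    exclude_from_weight_decay=['LayerNorm', 'bias'])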
@@ -189,15 +192,15 @@ def _resource_apply_dense(self, grad, var, apply_state=None):
     lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
     decay = self._decay_weights_op(var, lr_t, apply_state)
     with tf.control_dependencies([decay]):
-      return super(AdamWeightDecay, self)._resource_apply_dense(
-          grad, var, **kwargs)
+      return super(AdamWeightDecay,
+                   self)._resource_apply_dense(grad, var, **kwargs)
 
   def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
     lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
     decay = self._decay_weights_op(var, lr_t, apply_state)
     with tf.control_dependencies([decay]):
-      return super(AdamWeightDecay, self)._resource_apply_sparse(
-          grad, var, indices, **kwargs)
+      return super(AdamWeightDecay,
+                   self)._resource_apply_sparse(grad, var, indices, **kwargs)
 
   def get_config(self):
     config = super(AdamWeightDecay, self).get_config()
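
In both methods the control dependency pins the ordering AdamW needs: the decoupled decay shrinks the variable first, and only then does the inherited Adam update run. A standalone sketch of that ordering (the decay and update here are simplified stand-ins, not this class's ops):

import tensorflow as tf

@tf.function
def decayed_update(var, lr=0.1, weight_decay_rate=0.01, grad_step=0.05):
  # Decay first: var -= var * weight_decay_rate * lr, decoupled from the gradient.
  decay = var.assign_sub(var * weight_decay_rate * lr)
  # The control dependency guarantees the decay completes before the
  # (stand-in) optimizer update executes in graph mode.
  with tf.control_dependencies([decay]):
    return var.assign_sub(grad_step)

var = tf.Variable(1.0)
decayed_update(var)
print(float(var))  # (1.0 - 0.001) - 0.05 = 0.949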
