diff --git a/official/nlp/optimization.py b/official/nlp/optimization.py
index cfbf211d9e4..8c0686dbe95 100644
--- a/official/nlp/optimization.py
+++ b/official/nlp/optimization.py
@@ -28,13 +28,12 @@
 class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
   """Applies a warmup schedule on a given learning rate decay schedule."""
 
-  def __init__(
-      self,
-      initial_learning_rate,
-      decay_schedule_fn,
-      warmup_steps,
-      power=1.0,
-      name=None):
+  def __init__(self,
+               initial_learning_rate,
+               decay_schedule_fn,
+               warmup_steps,
+               power=1.0,
+               name=None):
     super(WarmUp, self).__init__()
     self.initial_learning_rate = initial_learning_rate
     self.warmup_steps = warmup_steps
@@ -52,10 +51,11 @@ def __call__(self, step):
       warmup_learning_rate = (
           self.initial_learning_rate *
           tf.math.pow(warmup_percent_done, self.power))
-      return tf.cond(global_step_float < warmup_steps_float,
-                     lambda: warmup_learning_rate,
-                     lambda: self.decay_schedule_fn(step),
-                     name=name)
+      return tf.cond(
+          global_step_float < warmup_steps_float,
+          lambda: warmup_learning_rate,
+          lambda: self.decay_schedule_fn(step),
+          name=name)
 
   def get_config(self):
     return {
@@ -67,23 +67,26 @@ def get_config(self):
     }
 
 
-def create_optimizer(init_lr, num_train_steps, num_warmup_steps,
+def create_optimizer(init_lr,
+                     num_train_steps,
+                     num_warmup_steps,
                      optimizer_type='adamw'):
   """Creates an optimizer with learning rate schedule."""
   # Implements linear decay of the learning rate.
-  learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
+  lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
       initial_learning_rate=init_lr,
       decay_steps=num_train_steps,
       end_learning_rate=0.0)
   if num_warmup_steps:
-    learning_rate_fn = WarmUp(initial_learning_rate=init_lr,
-                              decay_schedule_fn=learning_rate_fn,
-                              warmup_steps=num_warmup_steps)
+    lr_schedule = WarmUp(
+        initial_learning_rate=init_lr,
+        decay_schedule_fn=lr_schedule,
+        warmup_steps=num_warmup_steps)
 
   if optimizer_type == 'adamw':
     logging.info('using Adamw optimizer')
     optimizer = AdamWeightDecay(
-        learning_rate=learning_rate_fn,
+        learning_rate=lr_schedule,
         weight_decay_rate=0.01,
         beta_1=0.9,
         beta_2=0.999,
@@ -92,7 +95,7 @@ def create_optimizer(init_lr, num_train_steps, num_warmup_steps,
   elif optimizer_type == 'lamb':
     logging.info('using Lamb optimizer')
     optimizer = tfa_optimizers.LAMB(
-        learning_rate=learning_rate_fn,
+        learning_rate=lr_schedule,
         weight_decay_rate=0.01,
         beta_1=0.9,
         beta_2=0.999,
@@ -127,8 +130,8 @@ def __init__(self,
                exclude_from_weight_decay=None,
                name='AdamWeightDecay',
                **kwargs):
-    super(AdamWeightDecay, self).__init__(
-        learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs)
+    super(AdamWeightDecay, self).__init__(learning_rate, beta_1, beta_2,
+                                          epsilon, amsgrad, name, **kwargs)
     self.weight_decay_rate = weight_decay_rate
     self._include_in_weight_decay = include_in_weight_decay
     self._exclude_from_weight_decay = exclude_from_weight_decay
@@ -189,15 +192,15 @@ def _resource_apply_dense(self, grad, var, apply_state=None):
     lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
     decay = self._decay_weights_op(var, lr_t, apply_state)
     with tf.control_dependencies([decay]):
-      return super(AdamWeightDecay, self)._resource_apply_dense(
-          grad, var, **kwargs)
+      return super(AdamWeightDecay,
+                   self)._resource_apply_dense(grad, var, **kwargs)
 
   def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
     lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
     decay = self._decay_weights_op(var, lr_t, apply_state)
     with tf.control_dependencies([decay]):
-      return super(AdamWeightDecay, self)._resource_apply_sparse(
-          grad, var, indices, **kwargs)
+      return super(AdamWeightDecay,
+                   self)._resource_apply_sparse(grad, var, indices, **kwargs)
 
   def get_config(self):
     config = super(AdamWeightDecay, self).get_config()
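
For context when reviewing: the sketch below (not part of the diff) shows how the renamed `lr_schedule` pieces compose after this change. It assumes TF 2.x with the tensorflow/models `official` package importable (and `tensorflow_addons` installed, since the module imports it for the LAMB path); the hyperparameter values (2e-5, 10000 train steps, 1000 warmup steps) are illustrative only.

# Usage sketch only -- not part of official/nlp/optimization.py. Assumes the
# tensorflow/models `official` package and tensorflow_addons are installed.
import tensorflow as tf

from official.nlp import optimization

# Recreate what create_optimizer() builds internally after this change:
# a linear (polynomial, power=1) decay wrapped in the WarmUp schedule.
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=2e-5,      # illustrative value
    decay_steps=10000,
    end_learning_rate=0.0)
lr_schedule = optimization.WarmUp(
    initial_learning_rate=2e-5,
    decay_schedule_fn=lr_schedule,
    warmup_steps=1000)

# The schedule is callable on the step counter: a linear ramp over the first
# 1000 steps, then the wrapped decay schedule.
print(lr_schedule(tf.constant(500)))   # mid-warmup: 0.5 * 2e-5 = 1e-5
print(lr_schedule(tf.constant(5000)))  # past warmup: on the decay branch

# Or let the helper wire the schedule into AdamWeightDecay directly.
optimizer = optimization.create_optimizer(
    init_lr=2e-5,
    num_train_steps=10000,
    num_warmup_steps=1000,
    optimizer_type='adamw')

The resulting optimizer can be passed to `model.compile(optimizer=...)` or used in a custom training loop as before; the change is formatting plus the `learning_rate_fn` -> `lr_schedule` rename, so the warmup/decay behaviour is unchanged.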