diff --git a/official/nlp/optimization.py b/official/nlp/optimization.py
index a40968a50ad..47f85039dc7 100644
--- a/official/nlp/optimization.py
+++ b/official/nlp/optimization.py
@@ -140,19 +140,19 @@ def _decay_weights_op(self, var, learning_rate, apply_state):
   def apply_gradients(self,
                       grads_and_vars,
                       name=None,
-                      all_reduce_sum_gradients=True):
+                      experimental_aggregate_gradients=True):
     grads, tvars = list(zip(*grads_and_vars))
-    if all_reduce_sum_gradients:
-      # when all_reduce_sum_gradients = False, apply_gradients() no longer
-      # implicitly allreduce gradients, users manually allreduce gradient and
-      # passed the allreduced grads_and_vars. For now, the clip_by_global_norm
-      # will be moved to before the explicit allreduce to keep the math
-      # the same as TF 1 and pre TF 2.2 implementation.
+    if experimental_aggregate_gradients:
+      # When experimental_aggregate_gradients = False, apply_gradients() no
+      # longer implicitly allreduces gradients; users manually allreduce the
+      # gradients and pass in the allreduced grads_and_vars. For now,
+      # clip_by_global_norm is moved to before the explicit allreduce to keep
+      # the math the same as the TF 1 and pre-TF 2.2 implementations.
       (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
     return super(AdamWeightDecay, self).apply_gradients(
         zip(grads, tvars),
         name=name,
-        all_reduce_sum_gradients=all_reduce_sum_gradients)
+        experimental_aggregate_gradients=experimental_aggregate_gradients)

   def _get_lr(self, var_device, var_dtype, apply_state):
     """Retrieves the learning rate with the given state."""
diff --git a/official/staging/training/grad_utils.py b/official/staging/training/grad_utils.py
index d195d2b556c..efda2e7616e 100644
--- a/official/staging/training/grad_utils.py
+++ b/official/staging/training/grad_utils.py
@@ -54,7 +54,7 @@ def _filter_and_allreduce_gradients(grads_and_vars,
   This utils function is used when users intent to explicitly allreduce
   gradients and customize gradients operations before and after allreduce.
   The allreduced gradients are then passed to optimizer.apply_gradients(
-      all_reduce_sum_gradients=False).
+      experimental_aggregate_gradients=False).

   Arguments:
     grads_and_vars: gradients and variables pairs.
@@ -139,4 +139,5 @@ def minimize_using_explicit_allreduce(tape,
     grads_and_vars = zip(allreduced_grads, filtered_training_vars)
   if post_allreduce_callbacks:
     grads_and_vars = _run_callbacks(post_allreduce_callbacks, grads_and_vars)
-  optimizer.apply_gradients(grads_and_vars, all_reduce_sum_gradients=False)
+  optimizer.apply_gradients(
+      grads_and_vars, experimental_aggregate_gradients=False)
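
For context, the following is a minimal sketch (not part of the patch) of the explicit-allreduce pattern these comments describe: clip the per-replica gradients first, sum them across replicas manually, and then pass the already-aggregated gradients to apply_gradients() with experimental_aggregate_gradients=False so the optimizer does not aggregate them a second time. Apart from the experimental_aggregate_gradients argument itself, every name here (the model, loss, train_step, and step_fn) is a hypothetical placeholder, and the sketch assumes a TF 2.2-era tf.keras optimizer.

```python
# Sketch only, not part of the patch. Assumes a TF 2.2-era tf.keras
# optimizer; `model`, `train_step`, and `step_fn` are hypothetical
# placeholders used for illustration.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
  optimizer = tf.keras.optimizers.Adam(1e-3)


@tf.function
def train_step(features, labels):

  def step_fn(x, y):
    with tf.GradientTape() as tape:
      # Scale the per-replica loss so that summing gradients across
      # replicas matches a global mean.
      loss = tf.reduce_mean(tf.square(y - model(x, training=True)))
      loss = loss / strategy.num_replicas_in_sync
    grads = tape.gradient(loss, model.trainable_variables)

    # Clip before the explicit allreduce, matching the pre-TF 2.2 math
    # mentioned in the comments above.
    (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)

    # Explicitly sum the gradients across replicas.
    replica_ctx = tf.distribute.get_replica_context()
    grads = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM, grads)

    # The gradients are already aggregated, so tell the optimizer not to
    # aggregate them again.
    optimizer.apply_gradients(
        zip(grads, model.trainable_variables),
        experimental_aggregate_gradients=False)
    return loss

  return strategy.run(step_fn, args=(features, labels))
```

This roughly mirrors the flow that minimize_using_explicit_allreduce in grad_utils.py implements, where such gradient transformations are supplied as callbacks before and after the allreduce instead of being written inline.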