From 3f2b034134a9e81dc590c65bb2b3809e4dc77809 Mon Sep 17 00:00:00 2001
From: Minhua Chen
Date: Tue, 20 May 2025 00:21:21 -0700
Subject: [PATCH] Make iter persistent for AdagradW (#4147)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/1228

Make iter persistent for AdagradW optimizer state saving.
This is to avoid potential loss of the iter information when training is restarted.

Reviewed By: q10

Differential Revision: D74717848
---
 .../split_table_batched_embeddings_ops_training.py     | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
index 4c06501815..26463db3ec 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -1122,6 +1122,7 @@ def __init__(  # noqa C901
         self._max_counter_update_freq: int = -1
         # Extract parameters from CounterBasedRegularizationDefinition or CowClipDefinition
         # which are passed as entries for OptimizerArgs
+        self._used_rowwise_adagrad_with_adagradw: bool = False
         if self._used_rowwise_adagrad_with_counter:
             if self.weight_decay_mode == WeightDecayMode.COUNTER:
                 self._max_counter_update_freq = (
@@ -1131,6 +1132,9 @@ def __init__(  # noqa C901
                     counter_based_regularization.counter_weight_decay_mode
                 )
                 counter_halflife = counter_based_regularization.counter_halflife
+                self._used_rowwise_adagrad_with_adagradw = (
+                    opt_arg_weight_decay_mode == CounterWeightDecayMode.ADAGRADW
+                )
             else:
                 opt_arg_weight_decay_mode = (
                     cowclip_regularization.counter_weight_decay_mode
@@ -1359,6 +1363,7 @@ def __init__(  # noqa C901
                 OptimType.EMAINPLACE_ROWWISE_ADAGRAD,
             )
             or self._used_rowwise_adagrad_with_global_weight_decay
+            or self._used_rowwise_adagrad_with_adagradw
         ):
             self.register_buffer(
                 "iter",
@@ -2766,10 +2771,7 @@ def get_optimizer_state(self) -> List[Dict[str, torch.Tensor]]:
                     "row_counter": states[2],
                     "iter": self.iter,
                 }
-                if self.optimizer_args.regularization_mode
-                == WeightDecayMode.COUNTER.value
-                and self.optimizer_args.weight_decay_mode
-                == CounterWeightDecayMode.ADAGRADW.value
+                if self._used_rowwise_adagrad_with_adagradw
                 else {
                     "sum": states[0],
                     "prev_iter": states[1],