
Commit 9cb160a

minhua-chen authored and facebook-github-bot committed
Make iter persistent for AdagradW (#4147)
Summary: Make iter persistent for AdagradW optimizer state saving. This avoids potential loss of the iter information when training is restarted.

Pull Request resolved: #4147
X-link: facebookresearch/FBGEMM#1228
Reviewed By: q10
Differential Revision: D74717848
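For context, here is a minimal sketch of why persistence matters (plain PyTorch, not FBGEMM code; the Toy module is a hypothetical stand-in): buffers registered with persistent=False are excluded from state_dict(), so a non-persistent iter counter would not survive a checkpoint save and restore across a training restart.

```python
import torch

class Toy(torch.nn.Module):
    """Hypothetical module illustrating persistent vs. non-persistent buffers."""

    def __init__(self, persistent: bool) -> None:
        super().__init__()
        # "iter" mirrors the step counter made persistent in this commit.
        self.register_buffer(
            "iter", torch.zeros(1, dtype=torch.int64), persistent=persistent
        )

# A persistent buffer is saved in the checkpoint and restored on reload.
print("iter" in Toy(persistent=True).state_dict())   # True
# A non-persistent buffer is absent, so its value is lost on restart.
print("iter" in Toy(persistent=False).state_dict())  # False
```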
1 parent 157e88b · commit 9cb160a

File tree

1 file changed (+6, -4 lines)


fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

Lines changed: 6 additions & 4 deletions
@@ -1122,6 +1122,7 @@ def __init__( # noqa C901
         self._max_counter_update_freq: int = -1
         # Extract parameters from CounterBasedRegularizationDefinition or CowClipDefinition
         # which are passed as entries for OptimizerArgs
+        self._used_rowwise_adagrad_with_adagradw: bool = False
         if self._used_rowwise_adagrad_with_counter:
             if self.weight_decay_mode == WeightDecayMode.COUNTER:
                 self._max_counter_update_freq = (
@@ -1131,6 +1132,9 @@ def __init__( # noqa C901
                     counter_based_regularization.counter_weight_decay_mode
                 )
                 counter_halflife = counter_based_regularization.counter_halflife
+                self._used_rowwise_adagrad_with_adagradw = (
+                    opt_arg_weight_decay_mode == CounterWeightDecayMode.ADAGRADW
+                )
             else:
                 opt_arg_weight_decay_mode = (
                     cowclip_regularization.counter_weight_decay_mode
@@ -1359,6 +1363,7 @@ def __init__( # noqa C901
                 OptimType.EMAINPLACE_ROWWISE_ADAGRAD,
             )
             or self._used_rowwise_adagrad_with_global_weight_decay
+            or self._used_rowwise_adagrad_with_adagradw
         ):
             self.register_buffer(
                 "iter",
@@ -2766,10 +2771,7 @@ def get_optimizer_state(self) -> List[Dict[str, torch.Tensor]]:
                     "row_counter": states[2],
                     "iter": self.iter,
                 }
-                if self.optimizer_args.regularization_mode
-                == WeightDecayMode.COUNTER.value
-                and self.optimizer_args.weight_decay_mode
-                == CounterWeightDecayMode.ADAGRADW.value
+                if self._used_rowwise_adagrad_with_adagradw
                 else {
                     "sum": states[0],
                     "prev_iter": states[1],
