@@ -473,6 +473,9 @@ def generate_ssd_tbes(
         weighted: bool,
         lr: float = 0.01,  # from SSDTableBatchedEmbeddingBags
         eps: float = 1.0e-8,  # from SSDTableBatchedEmbeddingBags
+        weight_decay: float = 0.0,  # used by LARS-SGD, LAMB, ADAM, and Rowwise Adagrad
+        beta1: float = 0.9,  # used by Partial Rowwise Adam
+        beta2: float = 0.999,  # used by Partial Rowwise Adam
         ssd_shards: int = 1,  # from SSDTableBatchedEmbeddingBags
         optimizer: OptimType = OptimType.EXACT_ROWWISE_ADAGRAD,
         cache_set_scale: float = 1.0,
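
The three new hyperparameters are threaded through to the TBE constructor (see the hunk at -561 below). A minimal sketch of a call that exercises them, with illustrative values taken from the new test further down; the positional arguments follow the helper's existing signature:

    emb, emb_ref = self.generate_ssd_tbes(
        T, D, B, log_E, L, weighted,
        optimizer=OptimType.PARTIAL_ROWWISE_ADAM,
        weight_decay=0.01,
        beta1=0.9,
        beta2=0.99,
    )
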
@@ -487,6 +490,7 @@ def generate_ssd_tbes(
         lazy_bulk_init_enabled: bool = False,
         backend_type: BackendType = BackendType.SSD,
         enable_raw_embedding_streaming: bool = False,
+        optimizer_state_dtypes: Dict[str, SparseType] = {},  # noqa: B006
     ) -> Tuple[SSDTableBatchedEmbeddingBags, List[torch.nn.EmbeddingBag]]:
         """
         Generate embedding modules (i.e., SSDTableBatchedEmbeddingBags and
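
The new optimizer_state_dtypes argument maps optimizer state names to SparseType storage dtypes and defaults to empty; the test below leaves it at that default. A hypothetical override, assuming state keys named after the momentum tensors (the key name is an assumption, not confirmed by this diff):

    emb, emb_ref = self.generate_ssd_tbes(
        ...,  # remaining arguments as above
        optimizer_state_dtypes={"momentum1": SparseType.FP32},  # assumed key name
    )
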
@@ -561,6 +565,9 @@ def generate_ssd_tbes(
             ssd_uniform_init_upper=0.1,
             learning_rate=lr,
             eps=eps,
+            weight_decay=weight_decay,
+            beta1=beta1,
+            beta2=beta2,
             ssd_rocksdb_shards=ssd_shards,
             optimizer=optimizer,
             pooling_mode=pooling_mode,
@@ -796,6 +803,16 @@ def split_optimizer_states_(
             bucket_asc_ids_list, no_snapshot=False, should_flush=True
         )

+    def assert_close_(self, test: torch.Tensor, ref: torch.Tensor) -> None:
+        tolerance = 1.0e-4 if test.dtype == torch.float else 1.0e-2
+
+        torch.testing.assert_close(
+            test.float().cpu(),
+            ref.float().cpu(),
+            atol=tolerance,
+            rtol=tolerance,
+        )
+
     @given(
         **default_st, backend_type=st.sampled_from([BackendType.SSD, BackendType.DRAM])
     )
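
The helper derives a single tolerance from the dtype of the tensor under test: 1.0e-4 for FP32, 1.0e-2 for everything else (e.g. FP16 momentum read back from the backend). A self-contained illustration of the comparison it performs, under those tolerances:

    import torch

    test = torch.randn(4, 8, dtype=torch.float16)
    ref = test.float() + 5.0e-3  # well within the 1.0e-2 half-precision tolerance
    torch.testing.assert_close(test.float().cpu(), ref.cpu(), atol=1.0e-2, rtol=1.0e-2)
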
@@ -1014,6 +1031,165 @@ def test_ssd_backward_adagrad(
             rtol=tolerance,
         )

+    @given(
+        **default_st,
+        backend_type=st.sampled_from([BackendType.SSD, BackendType.DRAM]),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
+    def test_ssd_backward_partial_rowwise_adam(
+        self,
+        T: int,
+        D: int,
+        B: int,
+        log_E: int,
+        L: int,
+        weighted: bool,
+        cache_set_scale: float,
+        pooling_mode: PoolingMode,
+        weights_precision: SparseType,
+        output_dtype: SparseType,
+        share_table: bool,
+        trigger_bounds_check: bool,
+        mixed_B: bool,
+        backend_type: BackendType,
+    ) -> None:
+        assume(not weighted or pooling_mode == PoolingMode.SUM)
+        # VBE is currently not supported for the PARTIAL_ROWWISE_ADAM optimizer
+        assume(not mixed_B)
+
+        # Constants
+        lr = 0.5
+        eps = 0.2
+        ssd_shards = 2
+        beta1 = 0.9
+        beta2 = 0.99
+        weight_decay = 0.01
+
+        # Generate embedding modules and inputs
+        (
+            emb,
+            emb_ref,
+        ) = self.generate_ssd_tbes(
+            T,
+            D,
+            B,
+            log_E,
+            L,
+            weighted,
+            lr=lr,
+            eps=eps,
+            weight_decay=weight_decay,
+            beta1=beta1,
+            beta2=beta2,
+            ssd_shards=ssd_shards,
+            optimizer=OptimType.PARTIAL_ROWWISE_ADAM,
+            cache_set_scale=cache_set_scale,
+            pooling_mode=pooling_mode,
+            weights_precision=weights_precision,
+            output_dtype=output_dtype,
+            share_table=share_table,
+            backend_type=backend_type,
+        )
+
+        Es = [emb.embedding_specs[t][0] for t in range(T)]
+        (
+            indices_list,
+            per_sample_weights_list,
+            indices,
+            offsets,
+            per_sample_weights,
+            batch_size_per_feature_per_rank,
+        ) = self.generate_inputs_(
+            B,
+            L,
+            Es,
+            emb.feature_table_map,
+            weights_precision=weights_precision,
+            trigger_bounds_check=trigger_bounds_check,
+            # mixed_B=mixed_B,
+        )
+
+        # Execute forward
+        output_ref_list, output = self.execute_ssd_forward_(
+            emb,
+            emb_ref,
+            indices_list,
+            per_sample_weights_list,
+            indices,
+            offsets,
+            per_sample_weights,
+            B,
+            L,
+            weighted,
+            batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
+        )
+
+        # Execute backward
+        self.execute_ssd_backward_(
+            output_ref_list,
+            output,
+            B,
+            D,
+            pooling_mode,
+            batch_size_per_feature_per_rank,
+        )
+
+        emb.flush()
+
+        tolerance = 1.0e-2
+
+        emb_test_weights = emb.debug_split_embedding_weights()
+        split_optimizer_states = self.split_optimizer_states_(emb)
+
+        for _, t in self.get_physical_table_arg_indices_(emb.feature_table_map):
+            (m1, m2) = split_optimizer_states[t]
+            print(f"{t=} {m1.shape=}, {m2.shape=}")
+
+        for f, t in self.get_physical_table_arg_indices_(emb.feature_table_map):
+            (m1, m2) = split_optimizer_states[t]
+            # Some optimizers have non-float momentum values
+            # pyre-ignore[16]
+            ref_grad = emb_ref[f].weight.grad.cpu().to_dense()
+            ref_weights = emb_ref[f].weight.cpu()
+
+            # Compare momentum2 values: (1 - beta2) * mean(dL^2) over the embedding dim
+            m2_ref = (ref_grad.pow(2).mean(dim=1)) * (1.0 - beta2)
+            self.assert_close_(m2, m2_ref)
+
+            # Compare momentum1 values: (1 - beta1) * dL
+            m1_ref = ref_grad * (1.0 - beta1)
+            print(f"{m1_ref.shape=}, {m1.shape=}")
+            self.assert_close_(m1, m1_ref)
+
+            # Bias corrections
+            iter_ = emb.iter.item()
+            v_hat_t = m2_ref / (1 - beta2**iter_)
+            v_hat_t = v_hat_t.view(v_hat_t.numel(), 1)
+            m_hat_t = m1_ref / (1 - beta1**iter_)
+
+            # Weight update
+            ref_weights_updated = (
+                torch.addcdiv(
+                    ref_weights,
+                    value=-lr,
+                    tensor1=m_hat_t,
+                    tensor2=v_hat_t.sqrt_().add_(eps),
+                )
+                - lr * weight_decay * ref_weights
+            )
+
+            if weights_precision == SparseType.FP16:
+                # Round the reference weights the same way that TBE does
+                ref_weights_updated = ref_weights_updated.half().float()
+
+            # Compare weights
+            torch.testing.assert_close(
+                emb_test_weights[t].float().cuda(),
+                ref_weights_updated.cuda(),
+                atol=tolerance,
+                rtol=tolerance,
+            )
+
     @given(
         bulk_init_chunk_size=st.sampled_from([0, 204800]),
         lazy_bulk_init_enabled=st.booleans(),
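
The momentum and weight checks in this test encode a single step of the partial rowwise Adam update. As a reading aid, here is a standalone sketch of that update rule (the function name and pure-functional style are mine, not part of the diff); at step=1 with zero-initialized moments it reduces exactly to the m2_ref/m1_ref/ref_weights_updated expressions above:

    import torch

    def partial_rowwise_adam_ref(
        w: torch.Tensor,    # (E, D) dense weights
        g: torch.Tensor,    # (E, D) dense gradient
        m1: torch.Tensor,   # (E, D) first moment
        m2: torch.Tensor,   # (E,) rowwise second moment
        step: int,
        lr: float = 0.5,
        eps: float = 0.2,
        beta1: float = 0.9,
        beta2: float = 0.99,
        weight_decay: float = 0.01,
    ) -> torch.Tensor:
        # momentum2 keeps one scalar per row: mean of squared gradients over dim 1
        m2 = beta2 * m2 + (1.0 - beta2) * g.pow(2).mean(dim=1)
        m1 = beta1 * m1 + (1.0 - beta1) * g
        # bias corrections, matching v_hat_t / m_hat_t in the test
        m_hat = m1 / (1.0 - beta1**step)
        v_hat = (m2 / (1.0 - beta2**step)).view(-1, 1)
        # update plus decoupled decay, matching `- lr * weight_decay * ref_weights`
        return w - lr * m_hat / (v_hat.sqrt() + eps) - lr * weight_decay * w

The distinguishing design choice is that the second moment is kept as one scalar per embedding row (a mean over the embedding dimension) rather than a full (E, D) tensor, which roughly halves the optimizer-state memory relative to full Adam.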