Commit c27f02f

q10 authored and facebook-github-bot committed
Enable Partial Rowwise Adam in SSD TBE (#4525)
Summary:
Pull Request resolved: #4525
X-link: facebookresearch/FBGEMM#1572

This diff enables Partial Rowwise Adam in SSD TBE, though just for the non-KV-ZCH use case.

Reviewed By: emlin

Differential Revision: D77880362
1 parent fa50579 commit c27f02f
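
Background (editor's note, not part of the commit message): unlike full Adam, Partial Rowwise Adam keeps the elementwise first moment (momentum1) but compresses the second moment (momentum2) to a single scalar per embedding row, the mean of the squared gradient over the embedding dimension. A sketch of the update as exercised by the new test below, with decoupled weight decay matching the test's reference computation (eta is the learning rate, lambda the weight decay, t the iteration count that forward() now threads through as iter_int):

    m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
    v_t = \beta_2 v_{t-1} + (1 - \beta_2) \mathrm{mean}(g_t^2)    % mean over the embedding dim; one value per row
    \hat{m}_t = m_t / (1 - \beta_1^t), \quad \hat{v}_t = v_t / (1 - \beta_2^t)
    w_t = w_{t-1} - \eta \, \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon) - \eta \lambda \, w_{t-1}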

File tree: 2 files changed (+196, -1 lines)

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 20 additions & 1 deletion
@@ -180,6 +180,7 @@ def __init__(
         # Set the optimizer
         assert optimizer in (
             OptimType.EXACT_ROWWISE_ADAGRAD,
+            OptimType.PARTIAL_ROWWISE_ADAM,
         ), f"Optimizer {optimizer} is not supported by SSDTableBatchedEmbeddingBags"
         self.optimizer = optimizer

@@ -2205,7 +2206,7 @@ def forward(
         self.step += 1

         # Increment the iteration (value is used for certain optimizers)
-        self._increment_iteration()
+        iter_int = self._increment_iteration()

         if self.optimizer == OptimType.EXACT_SGD:
             raise AssertionError(
@@ -2226,6 +2227,24 @@ def forward(
                 common_args, self.optimizer_args, momentum1
             )

+        elif self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
+            momentum2 = invokers.lookup_args_ssd.Momentum(
+                # pyre-ignore[6]
+                dev=self.momentum2_dev,
+                # pyre-ignore[6]
+                host=self.momentum2_host,
+                # pyre-ignore[6]
+                uvm=self.momentum2_uvm,
+                # pyre-ignore[6]
+                offsets=self.momentum2_offsets,
+                # pyre-ignore[6]
+                placements=self.momentum2_placements,
+            )
+
+            return invokers.lookup_partial_rowwise_adam_ssd.invoke(
+                common_args, self.optimizer_args, momentum1, momentum2, iter_int
+            )
+
     @torch.jit.ignore
     def _split_optimizer_states_non_kv_zch(
         self,
fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py

Lines changed: 176 additions & 0 deletions
@@ -473,6 +473,9 @@ def generate_ssd_tbes(
         weighted: bool,
         lr: float = 0.01,  # from SSDTableBatchedEmbeddingBags
         eps: float = 1.0e-8,  # from SSDTableBatchedEmbeddingBags
+        weight_decay: float = 0.0,  # used by LARS-SGD, LAMB, ADAM, and Rowwise Adagrad
+        beta1: float = 0.9,  # used by Partial Rowwise Adam
+        beta2: float = 0.999,  # used by Partial Rowwise Adam
         ssd_shards: int = 1,  # from SSDTableBatchedEmbeddingBags
         optimizer: OptimType = OptimType.EXACT_ROWWISE_ADAGRAD,
         cache_set_scale: float = 1.0,
@@ -487,6 +490,7 @@ def generate_ssd_tbes(
         lazy_bulk_init_enabled: bool = False,
         backend_type: BackendType = BackendType.SSD,
         enable_raw_embedding_streaming: bool = False,
+        optimizer_state_dtypes: Dict[str, SparseType] = {},  # noqa: B006
     ) -> Tuple[SSDTableBatchedEmbeddingBags, List[torch.nn.EmbeddingBag]]:
         """
         Generate embedding modules (i.e., SSDTableBatchedEmbeddingBags and
@@ -561,6 +565,9 @@ def generate_ssd_tbes(
             ssd_uniform_init_upper=0.1,
             learning_rate=lr,
             eps=eps,
+            weight_decay=weight_decay,
+            beta1=beta1,
+            beta2=beta2,
             ssd_rocksdb_shards=ssd_shards,
             optimizer=optimizer,
             pooling_mode=pooling_mode,
@@ -796,6 +803,16 @@ def split_optimizer_states_(
             bucket_asc_ids_list, no_snapshot=False, should_flush=True
         )

+    def assert_close_(self, test: torch.Tensor, ref: torch.Tensor) -> None:
+        tolerance = 1.0e-4 if test.dtype == torch.float else 1.0e-2
+
+        torch.testing.assert_close(
+            test.float().cpu(),
+            ref.float().cpu(),
+            atol=tolerance,
+            rtol=tolerance,
+        )
+
     @given(
         **default_st, backend_type=st.sampled_from([BackendType.SSD, BackendType.DRAM])
     )
@@ -1014,6 +1031,165 @@ def test_ssd_backward_adagrad(
             rtol=tolerance,
         )

+    @given(
+        **default_st,
+        backend_type=st.sampled_from([BackendType.SSD, BackendType.DRAM]),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
+    def test_ssd_backward_partial_rowwise_adam(
+        self,
+        T: int,
+        D: int,
+        B: int,
+        log_E: int,
+        L: int,
+        weighted: bool,
+        cache_set_scale: float,
+        pooling_mode: PoolingMode,
+        weights_precision: SparseType,
+        output_dtype: SparseType,
+        share_table: bool,
+        trigger_bounds_check: bool,
+        mixed_B: bool,
+        backend_type: BackendType,
+    ) -> None:
+        assume(not weighted or pooling_mode == PoolingMode.SUM)
+        # VBE is currently not supported for PARTIAL_ROWWISE_ADAM optimizer
+        assume(not mixed_B)
+
+        # Constants
+        lr = 0.5
+        eps = 0.2
+        ssd_shards = 2
+        beta1 = 0.9
+        beta2 = 0.99
+        weight_decay = 0.01
+
+        # Generate embedding modules and inputs
+        (
+            emb,
+            emb_ref,
+        ) = self.generate_ssd_tbes(
+            T,
+            D,
+            B,
+            log_E,
+            L,
+            weighted,
+            lr=lr,
+            eps=eps,
+            weight_decay=weight_decay,
+            beta1=beta1,
+            beta2=beta2,
+            ssd_shards=ssd_shards,
+            optimizer=OptimType.PARTIAL_ROWWISE_ADAM,
+            cache_set_scale=cache_set_scale,
+            pooling_mode=pooling_mode,
+            weights_precision=weights_precision,
+            output_dtype=output_dtype,
+            share_table=share_table,
+            backend_type=backend_type,
+        )
+
+        Es = [emb.embedding_specs[t][0] for t in range(T)]
+        (
+            indices_list,
+            per_sample_weights_list,
+            indices,
+            offsets,
+            per_sample_weights,
+            batch_size_per_feature_per_rank,
+        ) = self.generate_inputs_(
+            B,
+            L,
+            Es,
+            emb.feature_table_map,
+            weights_precision=weights_precision,
+            trigger_bounds_check=trigger_bounds_check,
+            # mixed_B=mixed_B,
+        )
+
+        # Execute forward
+        output_ref_list, output = self.execute_ssd_forward_(
+            emb,
+            emb_ref,
+            indices_list,
+            per_sample_weights_list,
+            indices,
+            offsets,
+            per_sample_weights,
+            B,
+            L,
+            weighted,
+            batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
+        )
+
+        # Execute backward
+        self.execute_ssd_backward_(
+            output_ref_list,
+            output,
+            B,
+            D,
+            pooling_mode,
+            batch_size_per_feature_per_rank,
+        )
+
+        emb.flush()
+
+        tolerance = 1.0e-2
+
+        emb_test_weights = emb.debug_split_embedding_weights()
+        split_optimizer_states = self.split_optimizer_states_(emb)
+
+        for _, t in self.get_physical_table_arg_indices_(emb.feature_table_map):
+            (m1, m2) = split_optimizer_states[t]
+            print(f"{t=} {m1.shape=}, {m2.shape=}")
+
+        for f, t in self.get_physical_table_arg_indices_(emb.feature_table_map):
+            (m1, m2) = split_optimizer_states[t]
+            # Some optimizers have non-float momentum values
+            # pyre-ignore[16]
+            ref_grad = emb_ref[f].weight.grad.cpu().to_dense()
+            ref_weights = emb_ref[f].weight.cpu()
+
+            # Compare momentum2 values: (1 - beta2) * dL^2
+            m2_ref = (ref_grad.pow(2).mean(dim=1)) * (1.0 - beta2)
+            self.assert_close_(m2, m2_ref)
+
+            # Compare momentum1 values: (1 - beta1) * dL
+            m1_ref = ref_grad * (1.0 - beta1)
+            print(f"{m1_ref.shape=}, {m1.shape=}")
+            self.assert_close_(m1, m1_ref)
+
+            # Bias corrections
+            iter_ = emb.iter.item()
+            v_hat_t = m2_ref / (1 - beta2**iter_)
+            v_hat_t = v_hat_t.view(v_hat_t.numel(), 1)
+            m_hat_t = m1_ref / (1 - beta1**iter_)
+
+            # Weight update
+            ref_weights_updated = (
+                torch.addcdiv(
+                    ref_weights,
+                    value=-lr,
+                    tensor1=m_hat_t,
+                    tensor2=v_hat_t.sqrt_().add_(eps),
+                )
+                - lr * weight_decay * ref_weights
+            )
+
+            if weights_precision == SparseType.FP16:
+                # Round the reference weight the same way that TBE does
+                ref_weights_updated = ref_weights_updated.half().float()
+
+            # Compare weights
+            torch.testing.assert_close(
+                emb_test_weights[t].float().cuda(),
+                ref_weights_updated.cuda(),
+                atol=tolerance,
+                rtol=tolerance,
+            )
+
     @given(
         bulk_init_chunk_size=st.sampled_from([0, 204800]),
         lazy_bulk_init_enabled=st.booleans(),
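
Editor's note on the reference math in the test: the test performs a single optimization step from freshly initialized (zero) moment state, so momentum1 = (1 - beta1) * g and momentum2 = (1 - beta2) * mean(g^2) hold exactly, which is why m1_ref and m2_ref take the closed forms above rather than a running average. To exercise just this test, an invocation along the lines of `python -m pytest fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py -k test_ssd_backward_partial_rowwise_adam` should work where pytest is available (the suite is unittest-based, which pytest can collect); this command is a suggestion, not part of the diff.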
