Commit 16ce772

linefb authored and facebook-github-bot committed
implement optimizer state with opt offloading (#4141)
Summary:
Pull Request resolved: #4141
X-link: facebookresearch/FBGEMM#1224

Implement split_optimizer_states for optimizer state dict integration.

Differential Revision: D74790121

Reviewed By: duduyi2013, bobbyliujb
1 parent d68987c commit 16ce772

2 files changed: 320 additions & 0 deletions

File tree

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 93 additions & 0 deletions
@@ -2007,6 +2007,99 @@ def split_optimizer_states(
             )
             return opt_list

+        opt_list = []
+        table_offset = 0
+
+        dtype = self.weights_precision.as_dtype()
+        optimizer_dim = self.optimizer.state_size_dim(dtype)
+        pad4_optimizer_dim = pad4(optimizer_dim)
+        logging.info(
+            f"split_optimizer_states: {optimizer_dim=} {pad4_optimizer_dim=} {self.optimizer.dtype()=} {self.enable_load_state_dict_mode=}"
+        )
+
+        for t, (emb_height, emb_dim) in enumerate(self.embedding_specs):
+            # pyre-ignore
+            bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
+            # pyre-ignore
+            bucket_size = self.kv_zch_params.bucket_sizes[t]
+            row_offset = table_offset
+            if sorted_id_tensor is None or sorted_id_tensor[t].numel() == 0:
+                opt_list.append(
+                    # empty optimizer state for module initialization
+                    torch.empty(0, dtype=self.optimizer.dtype(), device="cpu")
+                )
+            else:
+                if not self.enable_optimizer_offloading:
+                    # convert global id back to local id, then linearize with table offset
+                    local_id_tensor = (
+                        sorted_id_tensor[t]
+                        - bucket_id_start * bucket_size
+                        + table_offset
+                    )
+                    opt_list.append(
+                        self.momentum1_dev.detach().cpu()[local_id_tensor].view(-1),
+                    )
+                else:
+                    emb_opt_dim = pad4(emb_dim) + pad4_optimizer_dim
+                    row_offset = table_offset - (bucket_id_start * bucket_size)
+                    # use KVTensorWrapper to query the backend to avoid OOM, since the
+                    # backend returns both weight and optimizer state in one tensor and
+                    # reading the whole tensor out at once could OOM CPU memory
+                    tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
+                        shape=[emb_height, emb_opt_dim],
+                        dtype=dtype,
+                        row_offset=row_offset,
+                        snapshot_handle=snapshot_handle,
+                        materialized_shape=[sorted_id_tensor[t].size(0), emb_opt_dim],
+                        sorted_indices=sorted_id_tensor[t],
+                    )
+                    (
+                        tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                        if self.backend_type == BackendType.SSD
+                        else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+                    )
+                    opt_list.append(
+                        self.get_offloaded_optimizer_states(
+                            tensor_wrapper=tensor_wrapper,
+                            # we only need to copy as many rows as sorted_id_tensor has
+                            row=sorted_id_tensor[t].size(0),
+                            optimizer_dim=optimizer_dim,
+                            start_dim_pos=pad4(emb_dim),
+                        )
+                    )
+            table_offset += emb_height
+        logging.info(
+            f"KV ZCH tables split_optimizer_states query latency: {(time.time() - start_time) * 1000} ms"
+        )
+        return opt_list
+
+    @torch.jit.export
+    def get_offloaded_optimizer_states(
+        self,
+        # pyre-ignore [2]
+        tensor_wrapper,
+        row: int,
+        optimizer_dim: int,
+        start_dim_pos: int,
+    ) -> torch.Tensor:
+        weight_dtype = self.weights_precision.as_dtype()
+        # 1D optimizer state for OptimType.EXACT_ROWWISE_ADAGRAD
+        opt_state_t = torch.empty(
+            row, optimizer_dim, dtype=weight_dtype, device="cpu"
+        )
+
+        # pyre-ignore [16]
+        chunk_size = self.kv_zch_params.streaming_ckpt_chunk_size
+        for i in range(0, row, chunk_size):
+            length = min(chunk_size, row - i)
+            opt_state_t.narrow(0, i, length).copy_(
+                tensor_wrapper.narrow(0, i, length).narrow(
+                    1, start_dim_pos, optimizer_dim
+                )
+            )
+        # view optimizer state back to the correct dtype
+        return opt_state_t.view(-1).view(self.optimizer.dtype())

     @torch.jit.export
     def get_optimizer_state(
         self,
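For context, a minimal sketch of how the new API might be driven from checkpointing code, mirroring the test below. It assumes `emb` is an already-initialized SSD TBE module with kv_zch_params configured; the print loop is purely illustrative:

# Hedged usage sketch (assumed setup, not part of the commit).
emb.flush()

# split_embedding_weights returns, per table: materialized weights,
# bucket-ascending sorted ids, and per-bucket active-id counts.
emb_state_dict_list, bucket_asc_ids_list, num_active_id_per_bucket_list = (
    emb.split_embedding_weights(no_snapshot=False, should_flush=True)
)

# Passing the same sorted ids keeps each optimizer tensor row-aligned with
# the weight rows above; each entry is a flat 1D tensor in optimizer dtype.
opt_states = emb.split_optimizer_states(bucket_asc_ids_list)
for t, opt in enumerate(opt_states):
    print(f"table {t}: {opt.numel()} optimizer elements, dtype={opt.dtype}")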

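The chunked narrow/copy pattern in get_offloaded_optimizer_states bounds peak host memory to one chunk at a time, and the final view(-1).view(dtype) reinterprets the bytes (read in the weight dtype) as the optimizer dtype. A self-contained illustration of both ideas with plain tensors, under the assumption that the optimizer state occupies the trailing columns of each backend row (all dimensions below are made up):

import torch

# Illustrative dimensions, not values from the commit.
rows, weight_dim, opt_dim, chunk_size = 10, 8, 4, 3
# Stand-in for the backend tensor: each row is [weights | optimizer bytes].
backend = torch.randn(rows, weight_dim + opt_dim, dtype=torch.float16)

out = torch.empty(rows, opt_dim, dtype=torch.float16)
for i in range(0, rows, chunk_size):
    length = min(chunk_size, rows - i)
    # narrow() returns a view, so each copy_ materializes only one chunk.
    out.narrow(0, i, length).copy_(
        backend.narrow(0, i, length).narrow(1, weight_dim, opt_dim)
    )

# Reinterpret the fp16 bytes as fp32: 4 fp16 elements -> 2 fp32 per row.
opt_state = out.view(-1).view(torch.float32)
assert opt_state.numel() == rows * opt_dim // 2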
fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py

Lines changed: 227 additions & 0 deletions
@@ -2007,3 +2007,230 @@ def test_kv_opt_state_w_offloading(
                 atol=tolerance,
                 rtol=tolerance,
             )
+
+    @given(
+        **default_st,
+        num_buckets=st.integers(min_value=10, max_value=15),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
+    def test_kv_opt_state_w_offloading(
+        self,
+        T: int,
+        D: int,
+        B: int,
+        log_E: int,
+        L: int,
+        weighted: bool,
+        cache_set_scale: float,
+        pooling_mode: PoolingMode,
+        weights_precision: SparseType,
+        output_dtype: SparseType,
+        share_table: bool,
+        trigger_bounds_check: bool,
+        mixed_B: bool,
+        num_buckets: int,
+    ) -> None:
+        # Constants
+        lr = 0.5
+        eps = 0.2
+        ssd_shards = 2
+
+        trigger_bounds_check = False  # don't simulate boundary-check cases
+        assume(not weighted or pooling_mode == PoolingMode.SUM)
+        assume(not mixed_B or pooling_mode != PoolingMode.NONE)
+
+        # TODO: check split_optimizer_states when optimizer offloading is ready
+        # Generate embedding modules and inputs
+        (
+            emb,
+            emb_ref,
+            Es,
+            _,
+            bucket_offsets,
+            bucket_sizes,
+        ) = self.generate_kvzch_tbes(
+            T,
+            D,
+            B,
+            log_E,
+            L,
+            weighted,
+            lr=lr,
+            eps=eps,
+            ssd_shards=ssd_shards,
+            cache_set_scale=cache_set_scale,
+            pooling_mode=pooling_mode,
+            weights_precision=weights_precision,
+            output_dtype=output_dtype,
+            share_table=share_table,
+            num_buckets=num_buckets,
+            enable_optimizer_offloading=False,
+        )
+
+        # Generate inputs
+        (
+            indices_list,
+            per_sample_weights_list,
+            indices,
+            offsets,
+            per_sample_weights,
+            batch_size_per_feature_per_rank,
+        ) = self.generate_inputs_(
+            B,
+            L,
+            Es,
+            emb.feature_table_map,
+            weights_precision=weights_precision,
+            trigger_bounds_check=trigger_bounds_check,
+            mixed_B=mixed_B,
+            bucket_offsets=bucket_offsets,
+            bucket_sizes=bucket_sizes,
+            is_kv_tbes=True,
+        )
+
+        # Execute forward
+        output_ref_list, output = self.execute_ssd_forward_(
+            emb,
+            emb_ref,
+            indices_list,
+            per_sample_weights_list,
+            indices,
+            offsets,
+            per_sample_weights,
+            B,
+            L,
+            weighted,
+            batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
+        )
+
+        # Generate output gradient
+        output_grad_list = [torch.randn_like(out) for out in output_ref_list]
+
+        # Execute torch EmbeddingBag backward
+        [out.backward(grad) for (out, grad) in zip(output_ref_list, output_grad_list)]
+        if batch_size_per_feature_per_rank is not None:
+            grad_test = self.concat_ref_tensors_vbe(
+                output_grad_list, batch_size_per_feature_per_rank
+            )
+        else:
+            grad_test = self.concat_ref_tensors(
+                output_grad_list,
+                pooling_mode != PoolingMode.NONE,  # do_pooling
+                B,
+                D * 4,
+            )
+
+        # Execute TBE SSD backward
+        output.backward(grad_test)
+
+        tolerance = (
+            1.0e-4
+            if weights_precision == SparseType.FP32 and output_dtype == SparseType.FP32
+            else 1.0e-2
+        )
+
+        emb.flush()
+
+        # Compare emb state dict with expected values from nn.EmbeddingBag
+        emb_state_dict_list, bucket_asc_ids_list, num_active_id_per_bucket_list = (
+            emb.split_embedding_weights(no_snapshot=False, should_flush=True)
+        )
+        split_optimizer_states = emb.split_optimizer_states(bucket_asc_ids_list)
+        table_input_id_range = []
+        for t, row in enumerate(Es):
+            bucket_id_start = bucket_offsets[t][0]
+            bucket_id_end = bucket_offsets[t][1]
+            bucket_size = bucket_sizes[t]
+            table_input_id_range.append(
+                (
+                    min(bucket_id_start * bucket_size, row),
+                    min(bucket_id_end * bucket_size, row),
+                )
+            )
+            # since we use ref_emb in dense format, the rows start from id 0
+            self.assertEqual(table_input_id_range[-1][0], 0)
+
+        # Compare optimizer states
+        for f, t in self.get_physical_table_arg_indices_(emb.feature_table_map):
+            # pyre-fixme[16]: Optional type has no attribute `float`.
+            ref_emb = emb_ref[f].weight.grad.float().to_dense().pow(2).cpu()
+            ref_optimizer_state = ref_emb.mean(dim=1)[
+                table_input_id_range[t][0] : min(
+                    table_input_id_range[t][1], emb_ref[f].weight.size(0)
+                )
+            ]
+            ref_kv_opt = ref_optimizer_state[bucket_asc_ids_list[t]].view(-1)
+            torch.testing.assert_close(
+                split_optimizer_states[t].float(),
+                ref_kv_opt,
+                atol=tolerance,
+                rtol=tolerance,
+            )
+
+        for feature_index, table_index in self.get_physical_table_arg_indices_(
+            emb.feature_table_map
+        ):
+            """
+            validate bucket_asc_ids_list and num_active_id_per_bucket_list
+            """
+            bucket_asc_id = bucket_asc_ids_list[table_index]
+            num_active_id_per_bucket = num_active_id_per_bucket_list[table_index]
+
+            bucket_id_start = bucket_offsets[table_index][0]
+            bucket_id_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(
+                num_active_id_per_bucket.view(-1)
+            )
+            for bucket_idx, id_count in enumerate(num_active_id_per_bucket):
+                bucket_id = bucket_idx + bucket_id_start
+                active_id_cnt = 0
+                for idx in range(
+                    bucket_id_offsets[bucket_idx],
+                    bucket_id_offsets[bucket_idx + 1],
+                ):
+                    # for chunk-based hashing
+                    self.assertEqual(
+                        bucket_id, bucket_asc_id[idx] // bucket_sizes[table_index]
+                    )
+                    active_id_cnt += 1
+                self.assertEqual(active_id_cnt, id_count)
+
+            """
+            validate embeddings
+            """
+            num_ids = len(bucket_asc_ids_list[table_index])
+            emb_r_w = emb_ref[feature_index].weight[
+                bucket_asc_ids_list[table_index].view(-1)
+            ]
+            emb_r_w_g = (
+                emb_ref[feature_index]
+                .weight.grad.float()
+                .to_dense()[bucket_asc_ids_list[table_index].view(-1)]
+            )
+            self.assertLess(table_index, len(emb_state_dict_list))
+            assert len(split_optimizer_states[table_index]) == num_ids
+            opt = split_optimizer_states[table_index]
+            new_ref_weight = torch.addcdiv(
+                emb_r_w.float(),
+                value=-lr,
+                tensor1=emb_r_w_g,
+                tensor2=opt.float()
+                .sqrt_()
+                .add_(eps)
+                .view(num_ids, 1)
+                .cuda(),
+            ).cpu()
+
+            emb_w = (
+                emb_state_dict_list[table_index]
+                .narrow(0, 0, bucket_asc_ids_list[table_index].size(0))
+                .float()
+            )
+            torch.testing.assert_close(
+                emb_w,
+                new_ref_weight,
+                atol=tolerance,
+                rtol=tolerance,
+            )
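The reference check above reproduces the rowwise Adagrad update that OptimType.EXACT_ROWWISE_ADAGRAD applies: the per-row optimizer state is the running mean of squared gradients, and the weight update divides the gradient by that state's square root plus eps. A hedged restatement with plain tensors, covering one step from zero-initialized state, which is what the single backward pass in the test exercises (shapes are illustrative):

import torch

# lr/eps match the constants used in the test; shapes are made up.
lr, eps = 0.5, 0.2
num_rows, dim = 4, 8
weight = torch.randn(num_rows, dim)
grad = torch.randn(num_rows, dim)

# Rowwise Adagrad state after one step from zero: mean of squared
# gradients per row (the running accumulation reduces to this here).
momentum1 = grad.pow(2).mean(dim=1)

# Weight update: w -= lr * g / (sqrt(state) + eps), broadcast per row.
new_weight = torch.addcdiv(
    weight,
    value=-lr,
    tensor1=grad,
    tensor2=(momentum1.sqrt() + eps).view(num_rows, 1),
)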
