Commit daf6b8b

emlin authored and facebook-github-bot committed
support get state dict and apply state dict (#4145)
Summary:
X-link: pytorch/torchrec#2976

Pull Request resolved: #4145

X-link: facebookresearch/FBGEMM#1226

# Functions

**Saving State Dict**
When saving the state dict, we convert IDs from local to global. This allows us to avoid shifting IDs based on sharding decisions when tables are resharded.

**Checkpoint Loading Mode**
We have enabled a load-state-dict mode for checkpoint loading, which allows us to cache all ID, weight, and bucket tensors in memory before applying them to the backend. This approach ensures that all data is loaded correctly, even though the checkpoint client does not support ordered tensor loading.

# Current Solution

The current solution caches all data in Python tensors, following these steps (sketched in code below):
- Set self.local_weight_counts based on the checkpoint bucket tensor size.
- Enable load-state-dict mode to initialize the local cache tensors.
- Call state_dict to get empty tensors for the checkpoint loader.
- Let the checkpoint loader write checkpoint data from the persisted checkpoint into the cached tensors.
- Call apply_state_dict to write all cached tensors to the backend.

**Apply State Dict Flow**
During the apply_state_dict step, we perform the following operations:
- If optimizer offloading is enabled:
  - Loop through chunks of weight and optimizer.
  - Concatenate weight and optimizer together.
  - Write to the backend using the KVTensorWrapper interface.
- If optimizer offloading is disabled:
  - Set the optimizer state on the device tensor based on ID.
  - Write ID and weight to the backend for each table.

# Limitations

The current solution has two limitations:
- Memory overhead: when writing data to the backend, the Python tensor's memory cannot be released until the whole tensor has been duplicated in the backend. This can lead to high memory usage, especially for a single large table.
- Performance regression: with optimizer offloading, we need to concatenate weight and optimizer together before writing to the backend. To avoid tripling one large tensor's memory footprint, we loop through smaller chunks during writing, which can cause a performance regression.

# Future Improvements

After the first end-to-end version is ready, we plan to support unordered loading from the backend to improve performance and reduce memory overhead.

Reviewed By: bobbyliujb

Differential Revision: D74790154

fbshipit-source-id: b60e939653b4788df3974ded8404a1574adff804
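To make the flow above concrete, here is a minimal caller-side sketch. The `tbe` module handle, the `checkpoint_loader` object, and `load_checkpoint_into_backend` are illustrative assumptions, not APIs from this PR; `local_weight_counts`, `enable_load_state_dict_mode`, `split_embedding_weights`, and `apply_state_dict` are the names this diff adds.

```python
# Hedged sketch of the checkpoint-loading flow; `tbe` (an SSD TBE module
# instance) and `checkpoint_loader` are hypothetical stand-ins for the caller.

def load_checkpoint_into_backend(tbe, checkpoint_loader, rows_per_table):
    # Step 1: size the local caches from the checkpoint's bucket tensors.
    tbe.local_weight_counts = rows_per_table

    # Step 2: allocate the empty cached id/weight/bucket tensors in memory.
    tbe.enable_load_state_dict_mode()

    # Step 3: in load-state-dict mode this returns the cached (still empty)
    # tensors; the state_dict path ends up here as well.
    weights, ids, buckets = tbe.split_embedding_weights(no_snapshot=True)

    # Step 4: the checkpoint loader fills the cached tensors, in any order.
    checkpoint_loader.load_into(weights, ids, buckets)

    # Step 5: stream every cached tensor into the SSD/DRAM backend and
    # release the Python-side caches.
    tbe.apply_state_dict()
```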
1 parent 96cfdbd commit daf6b8b

File tree

2 files changed: +417 −18 lines changed

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 205 additions & 18 deletions
@@ -313,7 +313,7 @@ def __init__(
             f"weights precision: {weights_precision}, "
             f"output dtype: {output_dtype}, "
             f"chunk size in bulk init: {bulk_init_chunk_size} bytes, backend_type: {backend_type}, "
-            f"zero_collision_config: {kv_zch_params}"
+            f"kv_zch_params: {kv_zch_params}"
         )
         self.register_buffer(
             "lxu_cache_state",
@@ -1983,9 +1983,7 @@ def split_optimizer_states(
                 dtype=dtype,
                 row_offset=row_offset,
                 snapshot_handle=snapshot_handle,
-                materialized_shape=([sorted_id_tensor[t].size(0), emb_opt_dim]),
                 sorted_indices=sorted_id_tensor[t],
-                width_offset=pad4(emb_dim),
             )
             (
                 tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
@@ -2086,6 +2084,7 @@ def _may_create_snapshot_for_state_dict(
         """
         Create a rocksdb snapshot if needed.
         """
+        start_time = time.time()
         # Force device synchronize for now
         torch.cuda.synchronize()
         snapshot_handle = None
@@ -2094,7 +2093,13 @@ def _may_create_snapshot_for_state_dict(
             if not no_snapshot:
                 # Flush L1 and L2 caches
                 self.flush(force=should_flush)
+                logging.info(
+                    f"flush latency for weight states: {(time.time() - start_time) * 1000} ms"
+                )
                 snapshot_handle = self.ssd_db.create_snapshot()
+                logging.info(
+                    f"created snapshot for weight states: {snapshot_handle}, latency: {(time.time() - start_time) * 1000} ms"
+                )
         elif self.backend_type == BackendType.DRAM:
             self.flush(force=should_flush)
         return snapshot_handle
@@ -2105,7 +2110,7 @@ def split_embedding_weights(
         no_snapshot: bool = True,
         should_flush: bool = False,
     ) -> Tuple[  # TODO: make this a NamedTuple for readability
-        List[PartiallyMaterializedTensor],
+        Union[List[PartiallyMaterializedTensor], List[torch.Tensor]],
         Optional[List[torch.Tensor]],
         Optional[List[torch.Tensor]],
     ]:
@@ -2134,6 +2139,17 @@ def split_embedding_weights(
         )

         dtype = self.weights_precision.as_dtype()
+        if self.load_state_dict and self.kv_zch_params:
+            # init for checkpoint loading
+            assert (
+                self._cached_kvzch_data is not None
+            ), "weight id and bucket state are not initialized for load checkpointing"
+            return (
+                self._cached_kvzch_data.cached_weight_tensor_per_table,
+                self._cached_kvzch_data.cached_id_tensor_per_table,
+                self._cached_kvzch_data.cached_bucket_splits,
+            )
+        start_time = time.time()
         pmt_splits = []
         bucket_sorted_id_splits = [] if self.kv_zch_params else None
         active_id_cnt_per_bucket_split = [] if self.kv_zch_params else None
@@ -2142,18 +2158,15 @@ def split_embedding_weights(
         for i, (emb_height, emb_dim) in enumerate(self.embedding_specs):
             bucket_ascending_id_tensor = None
             bucket_t = None
+            row_offset = table_offset
             if self.kv_zch_params:
                 bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
                 # pyre-ignore
                 bucket_size = self.kv_zch_params.bucket_sizes[i]

                 # linearize with table offset
-                table_input_id_start = (
-                    min(bucket_id_start * bucket_size, emb_height) + table_offset
-                )
-                table_input_id_end = (
-                    min(bucket_id_end * bucket_size, emb_height) + table_offset
-                )
+                table_input_id_start = table_offset
+                table_input_id_end = table_offset + emb_height
                 # 1. get all keys from backend for one table
                 unordered_id_tensor = self._ssd_db.get_keys_in_range_by_snapshot(
                     table_input_id_start,
@@ -2166,15 +2179,38 @@ def split_embedding_weights(
                     torch.ops.fbgemm.get_bucket_sorted_indices_and_bucket_tensor(
                         unordered_id_tensor,
                         0,  # id--bucket hashing mode, 0 for chunk-based hashing, 1 for interleave-based hashing
-                        bucket_id_start,
-                        bucket_id_end,
+                        0,  # local bucket offset
+                        bucket_id_end - bucket_id_start,  # local bucket num
                         bucket_size,
                     )
                 )
-                # pyre-ignore
+                # 3. convert local id back to global id
+                bucket_ascending_id_tensor.add_(bucket_id_start * bucket_size)
+
+                if (
+                    bucket_ascending_id_tensor.size(0) == 0
+                    and self.local_weight_counts[i] > 0
+                ):
+                    logging.info(
+                        f"resetting bucket id tensor with {self.local_weight_counts[i]}"
+                    )
+                    bucket_ascending_id_tensor = torch.zeros(
+                        (self.local_weight_counts[i], 1),
+                        device=torch.device("cpu"),
+                        dtype=torch.int64,
+                    )
+                    # self.local_weight_counts[i] = 0  # Reset the count
+
+                # pyre-ignore [16] bucket_sorted_id_splits is not None
                 bucket_sorted_id_splits.append(bucket_ascending_id_tensor)
                 active_id_cnt_per_bucket_split.append(bucket_t)

+                # for KV ZCH tbe, the sorted_indices is the global id for checkpointing and publishing,
+                # but the backend uses local ids during training, so the KVTensorWrapper needs to convert
+                # each global id to a local id first, then linearize it with the table offset; the formula
+                # is x + table_offset - local_shard_offset, so row_offset is set to (table_offset - local_shard_offset)
+                row_offset = table_offset - (bucket_id_start * bucket_size)
+
             tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
                 shape=[
                     (
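An editorial aside on the `row_offset` arithmetic added in the hunk above: the wrapper maps a global id `x` to a linearized backend row via `x + table_offset - local_shard_offset`. A small worked example with made-up numbers (none of these values come from the commit):

```python
# Illustrative numbers only: this rank owns buckets [100, 200) of one table,
# with bucket_size = 50, and the table is placed at table_offset = 1000.
table_offset = 1000
bucket_id_start, bucket_size = 100, 50

local_shard_offset = bucket_id_start * bucket_size  # 5000: first global id on this shard
row_offset = table_offset - local_shard_offset      # 1000 - 5000 = -4000

global_id = 5120                                    # a global id held by this shard
linearized_local_id = global_id + row_offset        # 5120 - 4000 = 1120

# Equivalent two-step view: local id 120 within the shard, plus the table offset.
assert linearized_local_id == (global_id - local_shard_offset) + table_offset
```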
@@ -2185,33 +2221,184 @@ def split_embedding_weights(
                     pad4(emb_dim),
                 ],
                 dtype=dtype,
-                row_offset=table_offset,
+                row_offset=row_offset,
                 snapshot_handle=snapshot_handle,
+                # set bucket_ascending_id_tensor on the kvt wrapper, so narrow will follow
+                # the id order to return embedding weights.
                 sorted_indices=(
                     bucket_ascending_id_tensor if self.kv_zch_params else None
                 ),
             )
-            # TODO add if else support in the future for dram integration.
-            tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+            (
+                tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                if self.backend_type == BackendType.SSD
+                else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+            )
             table_offset += emb_height
             pmt_splits.append(
                 PartiallyMaterializedTensor(
                     tensor_wrapper,
                     True if self.kv_zch_params else False,
                 )
             )
+        logging.info(
+            f"split_embedding_weights latency: {(time.time() - start_time) * 1000} ms"
+        )
         return (pmt_splits, bucket_sorted_id_splits, active_id_cnt_per_bucket_split)

     @torch.jit.ignore
     def apply_state_dict(self) -> None:
         # After checkpoint loading, the _cached_kvzch_data will be loaded from checkpoint.
         # Caller should call this function to apply the cached states to the backend.
-        pass
+        if self.load_state_dict is False:
+            return
+        self.load_state_dict = False
+        assert self.kv_zch_params is not None, "apply_state_dict supports KV ZCH only"
+        assert (
+            self._cached_kvzch_data is not None
+            and self._cached_kvzch_data.cached_optimizer_state_per_table is not None
+        ), "optimizer state is not initialized for load checkpointing"
+        assert (
+            self._cached_kvzch_data.cached_weight_tensor_per_table is not None
+            and self._cached_kvzch_data.cached_id_tensor_per_table is not None
+        ), "weight and id state is not initialized for load checkpointing"
+
+        # Compute the number of elements of cache_dtype needed to store the
+        # optimizer state, round to the nearest 4
+        # optimizer_dim = self.optimizer.optimizer_state_size_dim(dtype)
+        # apply weight and optimizer state per table
+        table_offset = 0
+        for i, (emb_height, _) in enumerate(self.embedding_specs):
+            # pyre-ignore [16]
+            bucket_id_start, _ = self.kv_zch_params.bucket_offsets[i]
+            # pyre-ignore [16]
+            bucket_size = self.kv_zch_params.bucket_sizes[i]
+            row_offset = table_offset - bucket_id_start * bucket_size
+
+            if self.enable_optimizer_offloading:
+                # pyre-ignore [16]
+                weight_state = self._cached_kvzch_data.cached_weight_tensor_per_table[i]
+                # pyre-ignore [16]
+                opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i]
+                self.streaming_write_weight_and_id_per_table(
+                    weight_state,
+                    opt_state,
+                    # pyre-ignore [16]
+                    self._cached_kvzch_data.cached_id_tensor_per_table[i],
+                    row_offset,
+                )
+                self._cached_kvzch_data.cached_weight_tensor_per_table[i] = None
+                self._cached_kvzch_data.cached_optimizer_state_per_table[i] = None
+            else:
+                weight = self._cached_kvzch_data.cached_weight_tensor_per_table[i]
+                id = self._cached_kvzch_data.cached_id_tensor_per_table[i]
+                local_id = id + row_offset
+                logging.info(
+                    f"applying sd for table {i} without optimizer offloading, local_id is {local_id}"
+                )
+                opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i]
+                t_device = self.momentum1_dev.device
+                self.momentum1_dev.index_put_(
+                    indices=(
+                        local_id.to(t_device).view(-1),
+                    ),  # expects tuple of tensors
+                    values=opt_state.to(t_device),
+                )
+                self.ssd_db.set_cuda(
+                    local_id.view(-1),
+                    weight,
+                    torch.as_tensor(local_id.size(0)),
+                    1,
+                    False,
+                )
+            table_offset += emb_height
+        self.clear_cache()
+
+    @torch.jit.ignore
+    def streaming_write_weight_and_id_per_table(
+        self,
+        weight_state: torch.Tensor,
+        opt_state: torch.Tensor,
+        id_tensor: torch.Tensor,
+        row_offset: int,
+    ) -> None:
+        """
+        Write weight, optimizer state, and id to the backend using the kvt wrapper.
+        To avoid overusing memory, the weight and id are written in a rolling-window manner.
+
+        Args:
+            weight_state (torch.tensor): The weight state tensor to be written.
+            opt_state (torch.tensor): The optimizer state tensor to be written.
+            id_tensor (torch.tensor): The id tensor to be written.
+        """
+        D_rounded = pad4(weight_state.size(1))  # padded to 4 bytes alignment
+        dtype = self.weights_precision.as_dtype()
+        kvt = torch.classes.fbgemm.KVTensorWrapper(
+            db=self.ssd_db,
+            shape=[weight_state.size(0), self.cache_row_dim],
+            dtype=dtype,
+            row_offset=row_offset,
+            snapshot_handle=None,
+            sorted_indices=id_tensor,
+        )
+        # TODO: make chunk_size configurable or dynamic
+        chunk_size = 10000
+        row = weight_state.size(0)
+        optimizer_dim = self.optimizer.state_size_dim(dtype)
+        opt_state_2d = opt_state.view(dtype).view(-1, optimizer_dim)
+        for i in range(0, row, chunk_size):
+            length = min(chunk_size, row - i)
+            chunk_buffer = torch.empty(
+                length,
+                self.cache_row_dim,
+                dtype=dtype,
+                device="cpu",
+            )
+            chunk_buffer[:, : weight_state.size(1)] = weight_state[i : i + length, :]
+            chunk_buffer[:, D_rounded : D_rounded + optimizer_dim] = opt_state_2d[
+                i : i + length, :
+            ]
+            kvt.set_weights_and_ids(chunk_buffer, id_tensor[i : i + length, :].view(-1))

     @torch.jit.ignore
     def enable_load_state_dict_mode(self) -> None:
         # Enable load state dict mode before loading checkpoint
-        pass
+        if self.load_state_dict:
+            return
+        self.load_state_dict = True
+
+        dtype = self.weights_precision.as_dtype()
+        self._cached_kvzch_data = KVZCHCachedData([], [], [], [])
+        for i, (_, emb_dim) in enumerate(self.embedding_specs):
+            # for checkpoint loading, we need to store the weight and id tensors temporarily in memory
+            assert (
+                self.local_weight_counts[i] > 0
+            ), f"local_weight_counts for table {i} is not set"
+            # pyre-ignore [16]
+            bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
+            rows = self.local_weight_counts[i]
+            weight_state = torch.empty(rows, emb_dim, dtype=dtype, device="cpu")
+            opt_state = torch.empty(rows, dtype=torch.float32, device="cpu")
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_weight_tensor_per_table.append(weight_state)
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_optimizer_state_per_table.append(opt_state)
+            logging.info(
+                f"for checkpoint loading, table {i}, weight_state shape is {weight_state.shape}, opt_state shape is {opt_state.shape}"
+            )
+            id_tensor = torch.zeros(
+                (self.local_weight_counts[i], 1), dtype=torch.int64, device="cpu"
+            )
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_id_tensor_per_table.append(id_tensor)
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_bucket_splits.append(
+                torch.empty(
+                    (bucket_id_end - bucket_id_start, 1),
+                    dtype=torch.int64,
+                    device="cpu",
+                )
+            )

     @torch.jit.export
     def set_learning_rate(self, lr: float) -> None:
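To visualize the per-row layout that `streaming_write_weight_and_id_per_table` assembles above (`[weight | alignment padding | optimizer state]`), here is a small self-contained sketch. The `pad4` helper is reimplemented under the assumption that it rounds a dimension up to a multiple of 4, and all sizes are invented; only the slicing pattern mirrors the diff.

```python
import torch

def pad4(dim: int) -> int:
    # assumed semantics: round up to the nearest multiple of 4
    return (dim + 3) // 4 * 4

# Illustrative sizes, not taken from the commit.
rows, emb_dim, optimizer_dim = 8, 6, 2
cache_row_dim = pad4(emb_dim) + optimizer_dim  # weight region + optimizer region

weight_state = torch.randn(rows, emb_dim)
opt_state_2d = torch.randn(rows, optimizer_dim)

# Assemble one chunk the same way the streaming writer fills chunk_buffer:
chunk = torch.zeros(rows, cache_row_dim)
chunk[:, :emb_dim] = weight_state                                        # weight at the front
chunk[:, pad4(emb_dim) : pad4(emb_dim) + optimizer_dim] = opt_state_2d  # optimizer after the padding

# Columns emb_dim..pad4(emb_dim) are alignment padding and stay untouched,
# so the backend can address the optimizer region at a fixed, aligned offset.
```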
