
Commit 74925fa

emlin authored and facebook-github-bot committed
support get state dict and apply state dict (#4145)
Summary:
X-link: pytorch/torchrec#2976
Pull Request resolved: #4145
X-link: facebookresearch/FBGEMM#1226

# Functions

**Saving State Dict**

When saving the state dict, we convert IDs from local to global (see the toy example of the offset arithmetic after this summary). This allows us to avoid shifting IDs based on sharding decisions when tables are resharded.

**Checkpoint Loading Mode**

We have enabled a load-state-dict mode for checkpoint loading, which allows us to cache all ID, weight, and bucket tensors in memory before applying them to the backend. This approach ensures that all data is loaded correctly, even though the checkpoint client does not support ordered tensor loading.

# Current Solution

The current solution caches all data in Python tensors, following these steps (sketched in code after this summary):

- Set self.local_weight_counts based on the checkpoint bucket tensor size.
- Enable load-state-dict mode to initialize the local cache tensors.
- Call state_dict to get empty tensors for the checkpoint loader.
- Let the checkpoint loader write data from the persisted checkpoint into the cached tensors.
- Call apply_state_dict to write all cached tensors to the backend.

**Apply State Dict Flow**

During the apply_state_dict step, we perform the following operations:

- If optimizer offloading is enabled:
  - Loop through chunks of the weight and optimizer state.
  - Concatenate the weight and optimizer state together.
  - Write them to the backend through the KVTensorWrapper interface.
- If optimizer offloading is disabled:
  - Set the optimizer state in the device tensor, indexed by ID.
  - Write the IDs and weights to the backend for each table.

# Limitations

The current solution has two limitations:

- Memory overhead: when writing data to the backend, the Python tensor's memory cannot be released until the whole tensor has been duplicated in the backend. This can lead to high memory usage, especially with a single large table.
- Performance regression: with optimizer offloading, we need to concatenate the weight and optimizer state before writing to the backend. To avoid tripling a large tensor's memory footprint, we loop over smaller chunks during writing, which can cause a performance regression.

# Future Improvements

After the first end-to-end version is ready, we plan to support unordered loading from the backend to improve performance and reduce memory overhead.

Reviewed By: bobbyliujb

Differential Revision: D74790154
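The local-to-global conversion above, and the matching row_offset that split_embedding_weights passes to KVTensorWrapper in the diff below, reduce to simple offset arithmetic. A toy example with made-up sizes (every number here is illustrative, not from this PR):

```python
# Toy illustration of the id/offset arithmetic; all numbers are assumed.
table_offset = 1_000_000      # linearized rows of all preceding tables
bucket_id_start = 8           # first bucket owned by this shard
bucket_size = 50_000          # ids per bucket

local_shard_offset = bucket_id_start * bucket_size  # 400_000
local_id = 123                                      # id as stored by this shard

# Saving: the checkpoint records the global id, independent of sharding.
global_id = local_id + local_shard_offset           # 400_123

# Loading/writing: row_offset = table_offset - local_shard_offset, so that
# global_id + row_offset == local_id + table_offset, the backend's linear row.
row_offset = table_offset - local_shard_offset      # 600_000
assert global_id + row_offset == local_id + table_offset
```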
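And a minimal sketch of how a caller might drive the five steps above. This is not an authoritative API: `tbe` is assumed to be a constructed SSDTableBatchedEmbeddingBags from fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py, while `rows_per_table` and `checkpoint_loader` (with its `restore` method) are hypothetical stand-ins for the checkpoint metadata and checkpoint client; the tbe methods are the ones added in this diff.

```python
def load_from_checkpoint(tbe, rows_per_table, checkpoint_loader) -> None:
    # Step 1: size the local caches from the checkpoint bucket tensors.
    tbe.local_weight_counts = rows_per_table

    # Step 2: allocate the in-memory cache tensors (weights, ids, buckets).
    tbe.enable_load_state_dict_mode()

    # Step 3: with load-state-dict mode on, split_embedding_weights returns
    # the cached (still empty) tensors instead of backend-materialized ones.
    weights, ids, buckets = tbe.split_embedding_weights(no_snapshot=True)

    # Step 4: the checkpoint client fills the cached tensors in place
    # (restore() is a hypothetical method of the loader).
    checkpoint_loader.restore(weights, ids, buckets)

    # Step 5: push all cached tensors into the SSD/DRAM backend and
    # release the Python-side caches.
    tbe.apply_state_dict()
```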
1 parent 16ce772 commit 74925fa

File tree

2 files changed: +417 −17 lines


fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 205 additions & 17 deletions
@@ -313,7 +313,7 @@ def __init__(
             f"weights precision: {weights_precision}, "
             f"output dtype: {output_dtype}, "
             f"chunk size in bulk init: {bulk_init_chunk_size} bytes, backend_type: {backend_type}, "
-            f"zero_collision_config: {kv_zch_params}"
+            f"kv_zch_params: {kv_zch_params}"
         )
         self.register_buffer(
             "lxu_cache_state",
@@ -2050,7 +2050,6 @@ def split_optimizer_states(
                 dtype=dtype,
                 row_offset=row_offset,
                 snapshot_handle=snapshot_handle,
-                materialized_shape=([sorted_id_tensor[t].size(0), emb_opt_dim]),
                 sorted_indices=sorted_id_tensor[t],
             )
             (
@@ -2179,6 +2178,7 @@ def _may_create_snapshot_for_state_dict(
         """
         Create a rocksdb snapshot if needed.
         """
+        start_time = time.time()
         # Force device synchronize for now
         torch.cuda.synchronize()
         snapshot_handle = None
@@ -2187,7 +2187,13 @@ def _may_create_snapshot_for_state_dict(
         if not no_snapshot:
             # Flush L1 and L2 caches
             self.flush(force=should_flush)
+            logging.info(
+                f"flush latency for weight states: {(time.time() - start_time) * 1000} ms"
+            )
             snapshot_handle = self.ssd_db.create_snapshot()
+            logging.info(
+                f"created snapshot for weight states: {snapshot_handle}, latency: {(time.time() - start_time) * 1000} ms"
+            )
         elif self.backend_type == BackendType.DRAM:
             self.flush(force=should_flush)
         return snapshot_handle
@@ -2198,7 +2204,7 @@ def split_embedding_weights(
         no_snapshot: bool = True,
         should_flush: bool = False,
     ) -> Tuple[  # TODO: make this a NamedTuple for readability
-        List[PartiallyMaterializedTensor],
+        Union[List[PartiallyMaterializedTensor], List[torch.Tensor]],
         Optional[List[torch.Tensor]],
         Optional[List[torch.Tensor]],
     ]:
@@ -2227,6 +2233,17 @@
         )

         dtype = self.weights_precision.as_dtype()
+        if self.load_state_dict and self.kv_zch_params:
+            # init for checkpoint loading
+            assert (
+                self._cached_kvzch_data is not None
+            ), "weight id and bucket state are not initialized for load checkpointing"
+            return (
+                self._cached_kvzch_data.cached_weight_tensor_per_table,
+                self._cached_kvzch_data.cached_id_tensor_per_table,
+                self._cached_kvzch_data.cached_bucket_splits,
+            )
+        start_time = time.time()
         pmt_splits = []
         bucket_sorted_id_splits = [] if self.kv_zch_params else None
         active_id_cnt_per_bucket_split = [] if self.kv_zch_params else None
@@ -2235,18 +2252,15 @@ def split_embedding_weights(
         for i, (emb_height, emb_dim) in enumerate(self.embedding_specs):
             bucket_ascending_id_tensor = None
             bucket_t = None
+            row_offset = table_offset
             if self.kv_zch_params:
                 bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
                 # pyre-ignore
                 bucket_size = self.kv_zch_params.bucket_sizes[i]

                 # linearize with table offset
-                table_input_id_start = (
-                    min(bucket_id_start * bucket_size, emb_height) + table_offset
-                )
-                table_input_id_end = (
-                    min(bucket_id_end * bucket_size, emb_height) + table_offset
-                )
+                table_input_id_start = table_offset
+                table_input_id_end = table_offset + emb_height
                 # 1. get all keys from backend for one table
                 unordered_id_tensor = self._ssd_db.get_keys_in_range_by_snapshot(
                     table_input_id_start,
@@ -2259,15 +2273,38 @@
                 torch.ops.fbgemm.get_bucket_sorted_indices_and_bucket_tensor(
                     unordered_id_tensor,
                     0,  # id--bucket hashing mode, 0 for chunk-based hashing, 1 for interleave-based hashing
-                    bucket_id_start,
-                    bucket_id_end,
+                    0,  # local bucket offset
+                    bucket_id_end - bucket_id_start,  # local bucket num
                     bucket_size,
                 )
             )
-            # pyre-ignore
+                # 3. convert local id back to global id
+                bucket_ascending_id_tensor.add_(bucket_id_start * bucket_size)
+
+                if (
+                    bucket_ascending_id_tensor.size(0) == 0
+                    and self.local_weight_counts[i] > 0
+                ):
+                    logging.info(
+                        f"resetting bucket id tensor with {self.local_weight_counts[i]}"
+                    )
+                    bucket_ascending_id_tensor = torch.zeros(
+                        (self.local_weight_counts[i], 1),
+                        device=torch.device("cpu"),
+                        dtype=torch.int64,
+                    )
+                    # self.local_weight_counts[i] = 0  # Reset the count
+
+                # pyre-ignore [16] bucket_sorted_id_splits is not None
                 bucket_sorted_id_splits.append(bucket_ascending_id_tensor)
                 active_id_cnt_per_bucket_split.append(bucket_t)

+                # for KV ZCH tbe, sorted_indices holds global ids for checkpointing and publishing,
+                # but the backend uses local ids during training, so KVTensorWrapper needs to convert
+                # global id to local id first, then linearize the local id with the table offset; the
+                # formula is x + table_offset - local_shard_offset. To achieve this, row_offset is set
+                # to (table_offset - local_shard_offset).
+                row_offset = table_offset - (bucket_id_start * bucket_size)
+
             tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
                 shape=[
                     (
@@ -2278,33 +2315,184 @@
                     pad4(emb_dim),
                 ],
                 dtype=dtype,
-                row_offset=table_offset,
+                row_offset=row_offset,
                 snapshot_handle=snapshot_handle,
+                # set bucket_ascending_id_tensor in the kvt wrapper, so narrow will follow the id
+                # order when returning embedding weights.
                 sorted_indices=(
                     bucket_ascending_id_tensor if self.kv_zch_params else None
                 ),
             )
-            # TODO add if else support in the future for dram integration.
-            tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+            (
+                tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                if self.backend_type == BackendType.SSD
+                else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+            )
             table_offset += emb_height
             pmt_splits.append(
                 PartiallyMaterializedTensor(
                     tensor_wrapper,
                     True if self.kv_zch_params else False,
                 )
             )
+        logging.info(
+            f"split_embedding_weights latency: {(time.time() - start_time) * 1000} ms"
+        )
         return (pmt_splits, bucket_sorted_id_splits, active_id_cnt_per_bucket_split)

     @torch.jit.ignore
     def apply_state_dict(self) -> None:
         # After checkpoint loading, the _cached_kvzch_data will be loaded from checkpoint.
         # Caller should call this function to apply the cached states to backend.
-        pass
+        if self.load_state_dict is False:
+            return
+        self.load_state_dict = False
+        assert self.kv_zch_params is not None, "apply_state_dict supports KV ZCH only"
+        assert (
+            self._cached_kvzch_data is not None
+            and self._cached_kvzch_data.cached_optimizer_state_per_table is not None
+        ), "optimizer state is not initialized for load checkpointing"
+        assert (
+            self._cached_kvzch_data.cached_weight_tensor_per_table is not None
+            and self._cached_kvzch_data.cached_id_tensor_per_table is not None
+        ), "weight and id state is not initialized for load checkpointing"
+
+        # apply weight and optimizer state per table
+        table_offset = 0
+        for i, (emb_height, _) in enumerate(self.embedding_specs):
+            # pyre-ignore [16]
+            bucket_id_start, _ = self.kv_zch_params.bucket_offsets[i]
+            # pyre-ignore [16]
+            bucket_size = self.kv_zch_params.bucket_sizes[i]
+            row_offset = table_offset - bucket_id_start * bucket_size
+
+            if self.enable_optimizer_offloading:
+                # pyre-ignore [16]
+                weight_state = self._cached_kvzch_data.cached_weight_tensor_per_table[i]
+                # pyre-ignore [16]
+                opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i]
+                self.streaming_write_weight_and_id_per_table(
+                    weight_state,
+                    opt_state,
+                    # pyre-ignore [16]
+                    self._cached_kvzch_data.cached_id_tensor_per_table[i],
+                    row_offset,
+                )
+                self._cached_kvzch_data.cached_weight_tensor_per_table[i] = None
+                self._cached_kvzch_data.cached_optimizer_state_per_table[i] = None
+            else:
+                weight = self._cached_kvzch_data.cached_weight_tensor_per_table[i]
+                id = self._cached_kvzch_data.cached_id_tensor_per_table[i]
+                local_id = id + row_offset
+                logging.info(
+                    f"applying sd for table {i} without optimizer offloading, local_id is {local_id}"
+                )
+                opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i]
+                t_device = self.momentum1_dev.device
+                self.momentum1_dev.index_put_(
+                    indices=(
+                        local_id.to(t_device).view(-1),
+                    ),  # expects tuple of tensors
+                    values=opt_state.to(t_device),
+                )
+                self.ssd_db.set_cuda(
+                    local_id.view(-1),
+                    weight,
+                    torch.as_tensor(local_id.size(0)),
+                    1,
+                    False,
+                )
+            table_offset += emb_height
+        self.clear_cache()
+
+    @torch.jit.ignore
+    def streaming_write_weight_and_id_per_table(
+        self,
+        weight_state: torch.Tensor,
+        opt_state: torch.Tensor,
+        id_tensor: torch.Tensor,
+        row_offset: int,
+    ) -> None:
+        """
+        Write weights, optimizer state, and ids to the backend using the kvt wrapper.
+        To avoid overusing memory, the weights and ids are written to the backend in a
+        rolling-window manner.
+
+        Args:
+            weight_state (torch.Tensor): The weight state tensor to be written.
+            opt_state (torch.Tensor): The optimizer state tensor to be written.
+            id_tensor (torch.Tensor): The id tensor to be written.
+            row_offset (int): The offset applied to linearize ids for the backend.
+        """
+        D_rounded = pad4(weight_state.size(1))  # padded to 4 bytes alignment
+        dtype = self.weights_precision.as_dtype()
+        kvt = torch.classes.fbgemm.KVTensorWrapper(
+            db=self.ssd_db,
+            shape=[weight_state.size(0), self.cache_row_dim],
+            dtype=dtype,
+            row_offset=row_offset,
+            snapshot_handle=None,
+            sorted_indices=id_tensor,
+        )
+        # TODO: make chunk_size configurable or dynamic
+        chunk_size = 10000
+        row = weight_state.size(0)
+        optimizer_dim = self.optimizer.state_size_dim(dtype)
+        opt_state_2d = opt_state.view(dtype).view(-1, optimizer_dim)
+        for i in range(0, row, chunk_size):
+            length = min(chunk_size, row - i)
+            chunk_buffer = torch.empty(
+                length,
+                self.cache_row_dim,
+                dtype=dtype,
+                device="cpu",
+            )
+            chunk_buffer[:, : weight_state.size(1)] = weight_state[i : i + length, :]
+            chunk_buffer[:, D_rounded : D_rounded + optimizer_dim] = opt_state_2d[
+                i : i + length, :
+            ]
+            kvt.set_weights_and_ids(chunk_buffer, id_tensor[i : i + length, :].view(-1))

     @torch.jit.ignore
     def enable_load_state_dict_mode(self) -> None:
         # Enable load state dict mode before loading checkpoint
-        pass
+        if self.load_state_dict:
+            return
+        self.load_state_dict = True
+
+        dtype = self.weights_precision.as_dtype()
+        self._cached_kvzch_data = KVZCHCachedData([], [], [], [])
+        for i, (_, emb_dim) in enumerate(self.embedding_specs):
+            # for checkpoint loading, we need to store the weight and id tensors temporarily in memory
+            assert (
+                self.local_weight_counts[i] > 0
+            ), f"local_weight_counts for table {i} is not set"
+            # pyre-ignore [16]
+            bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
+            rows = self.local_weight_counts[i]
+            weight_state = torch.empty(rows, emb_dim, dtype=dtype, device="cpu")
+            opt_state = torch.empty(rows, dtype=torch.float32, device="cpu")
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_weight_tensor_per_table.append(weight_state)
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_optimizer_state_per_table.append(opt_state)
+            logging.info(
+                f"for checkpoint loading, table {i}, weight_state shape is {weight_state.shape}, opt_state shape is {opt_state.shape}"
+            )
+            id_tensor = torch.zeros(
+                (self.local_weight_counts[i], 1), dtype=torch.int64, device="cpu"
+            )
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_id_tensor_per_table.append(id_tensor)
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_bucket_splits.append(
+                torch.empty(
+                    (bucket_id_end - bucket_id_start, 1),
+                    dtype=torch.int64,
+                    device="cpu",
+                )
+            )

     @torch.jit.export
     def set_learning_rate(self, lr: float) -> None:
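A note on the chunk_buffer layout in streaming_write_weight_and_id_per_table above: each backend row packs the weight slice first, padded to a 4-element boundary, then the optimizer state. A toy illustration with assumed dimensions, and a local pad4 assumed to round up to a multiple of 4, as the alignment comment in the hunk suggests:

```python
import torch

def pad4(d: int) -> int:
    # assumed semantics: round up to the next multiple of 4
    return (d + 3) & ~3

emb_dim, optimizer_dim = 6, 1              # toy sizes, not from the PR
D_rounded = pad4(emb_dim)                  # 8
cache_row_dim = D_rounded + optimizer_dim  # 9 elements per backend row

row = torch.zeros(cache_row_dim)
row[:emb_dim] = torch.arange(emb_dim, dtype=torch.float32)        # weight slice
row[D_rounded : D_rounded + optimizer_dim] = torch.tensor([0.5])  # optimizer slice
# indices emb_dim..D_rounded-1 stay zero: alignment padding
```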
