Skip to content

Commit a92923e

Browse files
emlinfacebook-github-bot
authored andcommitted
support get state dict and apply state dict (pytorch#4145)
Summary: X-link: pytorch/torchrec#2976 X-link: facebookresearch/FBGEMM#1226 # Functions **Saving State Dict** When saving state dict, we convert IDs from local to global. This allows us to avoid shifting IDs based on sharding decisions when tables are resharded. **Checkpoint Loading Mode** We have enabled load state dict mode for checkpoint loading, which allows us to cache all ID, weight, and bucket tensors in memory before applying them to the backend. This approach ensures that all data is loaded correctly, even though the checkpoint client does not support ordered tensor loading. # Current Solution The current solution involves caching all data in Python tensors, following these steps: - Set self.local_weight_counts based on checkpoint bucket tensor size. - Enable load state dict mode to initialize local cache tensors. - Call state_dict to get empty tensors for the checkpoint loader. - Write checkpoint data to cached tensors from persisted checkpoint by the checkpoint loader. - Call apply_state_dict to write all cached tensors to the backend. **Apply State Dict Flow** During the apply_state_dict step, we perform the following operations: - If optimizer offloading is enabled: - Loop through chunks of weight and optimizer. - Concatenate weight and optimizer together. - Write to backend using KVTensorWrapper interface. - If optimizer offloading is disabled: - Set optimizer to device tensor based on ID. - Write ID weight to backend for each table. # Limitations The current solution has two limitations: - Memory overhead: - When writing data to the backend, the Python tensor's memory cannot be released until the whole tensor data is duplicated in the backend. This can lead to high memory usage, especially when dealing with single large tables. - Performance regression: - With optimizer offloading, we need to concatenate weight and optimizer together before writing to the backend. To avoid triple one large tensor's memory, we loop through smaller chunks during writing, which can cause performance regression. # Future Improvements After the first version e2e is ready, we plan to support unordered loading from the backend to improve performance and reduce memory overhead. Reviewed By: bobbyliujb Differential Revision: D74790154
1 parent e2fe03d commit a92923e

File tree

2 files changed

+417
-17
lines changed

2 files changed

+417
-17
lines changed

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 205 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def __init__(
317317
f"weights precision: {weights_precision}, "
318318
f"output dtype: {output_dtype}, "
319319
f"chunk size in bulk init: {bulk_init_chunk_size} bytes, backend_type: {backend_type}, "
320-
f"zero_collision_config: {kv_zch_params}"
320+
f"kv_zch_params: {kv_zch_params}"
321321
)
322322
self.register_buffer(
323323
"lxu_cache_state",
@@ -1986,7 +1986,6 @@ def split_optimizer_states(
19861986
dtype=dtype,
19871987
row_offset=row_offset,
19881988
snapshot_handle=snapshot_handle,
1989-
materialized_shape=([sorted_id_tensor[t].size(0), emb_opt_dim]),
19901989
sorted_indices=sorted_id_tensor[t],
19911990
)
19921991
(
@@ -2115,6 +2114,7 @@ def _may_create_snapshot_for_state_dict(
21152114
"""
21162115
Create a rocksdb snapshot if needed.
21172116
"""
2117+
start_time = time.time()
21182118
# Force device synchronize for now
21192119
torch.cuda.synchronize()
21202120
snapshot_handle = None
@@ -2124,7 +2124,13 @@ def _may_create_snapshot_for_state_dict(
21242124
if should_flush:
21252125
# Flush L1 and L2 caches
21262126
self.flush(force=True)
2127+
logging.info(
2128+
f"flush latency for weight states: {(time.time() - start_time) * 1000} ms"
2129+
)
21272130
snapshot_handle = self.ssd_db.create_snapshot()
2131+
logging.info(
2132+
f"created snapshot for weight states: {snapshot_handle}, latency: {(time.time() - start_time) * 1000} ms"
2133+
)
21282134
elif self.backend_type == BackendType.DRAM:
21292135
if should_flush:
21302136
self.flush(force=True)
@@ -2136,7 +2142,7 @@ def split_embedding_weights(
21362142
no_snapshot: bool = True,
21372143
should_flush: bool = False,
21382144
) -> Tuple[ # TODO: make this a NamedTuple for readability
2139-
List[PartiallyMaterializedTensor],
2145+
List[PartiallyMaterializedTensor] | List[torch.Tensor],
21402146
Optional[List[torch.Tensor]],
21412147
Optional[List[torch.Tensor]],
21422148
]:
@@ -2165,6 +2171,17 @@ def split_embedding_weights(
21652171
)
21662172

21672173
dtype = self.weights_precision.as_dtype()
2174+
if self.load_state_dict and self.kv_zch_params:
2175+
# init for checkpointing loading
2176+
assert (
2177+
self._cached_kvzch_data is not None
2178+
), "weight id and bucket state are not initialized for load checkpointing"
2179+
return (
2180+
self._cached_kvzch_data.cached_weight_tensor_per_table,
2181+
self._cached_kvzch_data.cached_id_tensor_per_table,
2182+
self._cached_kvzch_data.cached_bucket_splits,
2183+
)
2184+
start_time = time.time()
21682185
pmt_splits = []
21692186
bucket_sorted_id_splits = [] if self.kv_zch_params else None
21702187
active_id_cnt_per_bucket_split = [] if self.kv_zch_params else None
@@ -2173,18 +2190,15 @@ def split_embedding_weights(
21732190
for i, (emb_height, emb_dim) in enumerate(self.embedding_specs):
21742191
bucket_ascending_id_tensor = None
21752192
bucket_t = None
2193+
row_offset = table_offset
21762194
if self.kv_zch_params:
21772195
bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
21782196
# pyre-ignore
21792197
bucket_size = self.kv_zch_params.bucket_sizes[i]
21802198

21812199
# linearize with table offset
2182-
table_input_id_start = (
2183-
min(bucket_id_start * bucket_size, emb_height) + table_offset
2184-
)
2185-
table_input_id_end = (
2186-
min(bucket_id_end * bucket_size, emb_height) + table_offset
2187-
)
2200+
table_input_id_start = table_offset
2201+
table_input_id_end = table_offset + emb_height
21882202
# 1. get all keys from backend for one table
21892203
unordered_id_tensor = self._ssd_db.get_keys_in_range_by_snapshot(
21902204
table_input_id_start,
@@ -2197,15 +2211,38 @@ def split_embedding_weights(
21972211
torch.ops.fbgemm.get_bucket_sorted_indices_and_bucket_tensor(
21982212
unordered_id_tensor,
21992213
0, # id--bucket hashing mode, 0 for chunk-based hashing, 1 for interleave-based hashing
2200-
bucket_id_start,
2201-
bucket_id_end,
2214+
0, # local bucket offset
2215+
bucket_id_end - bucket_id_start, # local bucket num
22022216
bucket_size,
22032217
)
22042218
)
2205-
# pyre-ignore
2219+
# 3. convert local id back to global id
2220+
bucket_ascending_id_tensor.add_(bucket_id_start * bucket_size)
2221+
2222+
if (
2223+
bucket_ascending_id_tensor.size(0) == 0
2224+
and self.local_weight_counts[i] > 0
2225+
):
2226+
logging.info(
2227+
f"resetting bucket id tensor with {self.local_weight_counts[i]}"
2228+
)
2229+
bucket_ascending_id_tensor = torch.zeros(
2230+
(self.local_weight_counts[i], 1),
2231+
device=torch.device("cpu"),
2232+
dtype=torch.int64,
2233+
)
2234+
# self.local_weight_counts[i] = 0 # Reset the count
2235+
2236+
# pyre-ignore [16] bucket_sorted_id_splits is not None
22062237
bucket_sorted_id_splits.append(bucket_ascending_id_tensor)
22072238
active_id_cnt_per_bucket_split.append(bucket_t)
22082239

2240+
# for KV ZCH tbe, the sorted_indices is global id for checkpointing and publishing
2241+
# but in backend, local id is used during training, so the KVTensorWrapper need to convert global id to local id
2242+
# first, then linearize the local id with table offset, the formulat is x + table_offset - local_shard_offset
2243+
# to achieve this, the row_offset will be set to (table_offset - local_shard_offset)
2244+
row_offset = table_offset - (bucket_id_start * bucket_size)
2245+
22092246
tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
22102247
shape=[
22112248
(
@@ -2216,33 +2253,184 @@ def split_embedding_weights(
22162253
emb_dim,
22172254
],
22182255
dtype=dtype,
2219-
row_offset=table_offset,
2256+
row_offset=row_offset,
22202257
snapshot_handle=snapshot_handle,
2258+
# set bucket_ascending_id_tensor to kvt wrapper, so narrow will follow the id order to return
2259+
# embedding weights.
22212260
sorted_indices=(
22222261
bucket_ascending_id_tensor if self.kv_zch_params else None
22232262
),
22242263
)
2225-
# TODO add if else support in the future for dram integration.
2226-
tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
2264+
(
2265+
tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
2266+
if self.backend_type == BackendType.SSD
2267+
else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
2268+
)
22272269
table_offset += emb_height
22282270
pmt_splits.append(
22292271
PartiallyMaterializedTensor(
22302272
tensor_wrapper,
22312273
True if self.kv_zch_params else False,
22322274
)
22332275
)
2276+
logging.info(
2277+
f"split_embedding_weights latency: {(time.time() - start_time) * 1000} ms"
2278+
)
22342279
return (pmt_splits, bucket_sorted_id_splits, active_id_cnt_per_bucket_split)
22352280

22362281
@torch.jit.ignore
22372282
def apply_state_dict(self) -> None:
22382283
# After checkpoint loading, the _cached_kvzch_data will be loaded from checkpoint.
22392284
# Caller should call this function to apply the cached states to backend.
2240-
pass
2285+
if self.load_state_dict is False:
2286+
return
2287+
self.load_state_dict = False
2288+
assert self.kv_zch_params is not None, "apply_state_dict supports KV ZCH only"
2289+
assert (
2290+
self._cached_kvzch_data is not None
2291+
and self._cached_kvzch_data.cached_optimizer_state_per_table is not None
2292+
), "optimizer state is not initialized for load checkpointing"
2293+
assert (
2294+
self._cached_kvzch_data.cached_weight_tensor_per_table is not None
2295+
and self._cached_kvzch_data.cached_id_tensor_per_table is not None
2296+
), "weight and id state is not initialized for load checkpointing"
2297+
2298+
# Compute the number of elements of cache_dtype needed to store the
2299+
# optimizer state, round to the nearest 4
2300+
# optimizer_dim = self.optimizer.optimizer_state_size_dim(dtype)
2301+
# apply weight and optimizer state per table
2302+
table_offset = 0
2303+
for i, (emb_height, _) in enumerate(self.embedding_specs):
2304+
# pyre-ignore [16]
2305+
bucket_id_start, _ = self.kv_zch_params.bucket_offsets[i]
2306+
# pyre-ignore [16]
2307+
bucket_size = self.kv_zch_params.bucket_sizes[i]
2308+
row_offset = table_offset - bucket_id_start * bucket_size
2309+
2310+
if self.enable_optimizer_offloading:
2311+
# pyre-ignore [16]
2312+
weight_state = self._cached_kvzch_data.cached_weight_tensor_per_table[i]
2313+
# pyre-ignore [16]
2314+
opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i]
2315+
self.streaming_write_weight_and_id_per_table(
2316+
weight_state,
2317+
opt_state,
2318+
# pyre-ignore [16]
2319+
self._cached_kvzch_data.cached_id_tensor_per_table[i],
2320+
row_offset,
2321+
)
2322+
self._cached_kvzch_data.cached_weight_tensor_per_table[i] = None
2323+
self._cached_kvzch_data.cached_optimizer_state_per_table[i] = None
2324+
else:
2325+
weight = self._cached_kvzch_data.cached_weight_tensor_per_table[i]
2326+
id = self._cached_kvzch_data.cached_id_tensor_per_table[i]
2327+
local_id = id + row_offset
2328+
logging.info(
2329+
f"applying sd for table {i} without optimizer offloading, local_id is {local_id}"
2330+
)
2331+
opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i]
2332+
t_device = self.momentum1_dev.device
2333+
self.momentum1_dev.index_put_(
2334+
indices=(
2335+
local_id.to(t_device).view(-1),
2336+
), # expects tuple of tensors
2337+
values=opt_state.to(t_device),
2338+
)
2339+
self.ssd_db.set_cuda(
2340+
local_id.view(-1),
2341+
weight,
2342+
torch.as_tensor(local_id.size(0)),
2343+
1,
2344+
False,
2345+
)
2346+
table_offset += emb_height
2347+
self.clear_cache()
2348+
2349+
@torch.jit.ignore
2350+
def streaming_write_weight_and_id_per_table(
2351+
self,
2352+
weight_state: torch.Tensor,
2353+
opt_state: torch.Tensor,
2354+
id_tensor: torch.Tensor,
2355+
row_offset: int,
2356+
) -> None:
2357+
"""
2358+
This function is used to write weight, optimizer and id to the backend using kvt wrapper.
2359+
to avoid over use memory, we will write the weight and id to backend in a rolling window manner
2360+
2361+
Args:
2362+
weight_state (torch.tensor): The weight state tensor to be written.
2363+
opt_state (torch.tensor): The optimizer state tensor to be written.
2364+
id_tensor (torch.tensor): The id tensor to be written.
2365+
"""
2366+
D_rounded = pad4(weight_state.size(1)) # padded to 4 bytes alignment
2367+
dtype = self.weights_precision.as_dtype()
2368+
kvt = torch.classes.fbgemm.KVTensorWrapper(
2369+
db=self.ssd_db,
2370+
shape=[weight_state.size(0), self.cache_row_dim],
2371+
dtype=dtype,
2372+
row_offset=row_offset,
2373+
snapshot_handle=None,
2374+
sorted_indices=id_tensor,
2375+
)
2376+
# TODO: make chunk_size configurable or dynamic
2377+
chunk_size = 10000
2378+
row = weight_state.size(0)
2379+
optimizer_dim = self.optimizer.state_size_dim(dtype)
2380+
opt_state_2d = opt_state.view(dtype).view(-1, optimizer_dim)
2381+
for i in range(0, row, chunk_size):
2382+
length = min(chunk_size, row - i)
2383+
chunk_buffer = torch.empty(
2384+
length,
2385+
self.cache_row_dim,
2386+
dtype=dtype,
2387+
device="cpu",
2388+
)
2389+
chunk_buffer[:, : weight_state.size(1)] = weight_state[i : i + length, :]
2390+
chunk_buffer[:, D_rounded : D_rounded + optimizer_dim] = opt_state_2d[
2391+
i : i + length, :
2392+
]
2393+
kvt.set_weights_and_ids(chunk_buffer, id_tensor[i : i + length, :].view(-1))
22412394

22422395
@torch.jit.ignore
22432396
def enable_load_state_dict_mode(self) -> None:
22442397
# Enable load state dict mode before loading checkpoint
2245-
pass
2398+
if self.load_state_dict:
2399+
return
2400+
self.load_state_dict = True
2401+
2402+
dtype = self.weights_precision.as_dtype()
2403+
self._cached_kvzch_data = KVZCHCachedData([], [], [], [])
2404+
for i, (_, emb_dim) in enumerate(self.embedding_specs):
2405+
# for checkpointing loading, we need to store the weight and id tensor temporarily in memory
2406+
assert (
2407+
self.local_weight_counts[i] > 0
2408+
), f"local_weight_counts for table {i} is not set"
2409+
# pyre-ignore [16]
2410+
bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
2411+
rows = self.local_weight_counts[i]
2412+
weight_state = torch.empty(rows, emb_dim, dtype=dtype, device="cpu")
2413+
opt_state = torch.empty(rows, dtype=torch.float32, device="cpu")
2414+
# pyre-ignore [16]
2415+
self._cached_kvzch_data.cached_weight_tensor_per_table.append(weight_state)
2416+
# pyre-ignore [16]
2417+
self._cached_kvzch_data.cached_optimizer_state_per_table.append(opt_state)
2418+
logging.info(
2419+
f"for checkpoint loading, table {i}, weight_state shape is {weight_state.shape}, opt_state shape is {opt_state.shape}"
2420+
)
2421+
id_tensor = torch.zeros(
2422+
(self.local_weight_counts[i], 1), dtype=torch.int64, device="cpu"
2423+
)
2424+
# pyre-ignore [16]
2425+
self._cached_kvzch_data.cached_id_tensor_per_table.append(id_tensor)
2426+
# pyre-ignore [16]
2427+
self._cached_kvzch_data.cached_bucket_splits.append(
2428+
torch.empty(
2429+
(bucket_id_end - bucket_id_start, 1),
2430+
dtype=torch.int64,
2431+
device="cpu",
2432+
)
2433+
)
22462434

22472435
@torch.jit.export
22482436
def set_learning_rate(self, lr: float) -> None:

0 commit comments

Comments
 (0)