From ecfd2cbad696fee23ac4d2951d363f5fec4fc07b Mon Sep 17 00:00:00 2001 From: Emma Lin Date: Wed, 21 May 2025 21:18:10 -0700 Subject: [PATCH] support get state dict and apply state dict (#4145) Summary: X-link: https://github.com/pytorch/torchrec/pull/2976 X-link: https://github.com/facebookresearch/FBGEMM/pull/1226 # Functions **Saving State Dict** When saving state dict, we convert IDs from local to global. This allows us to avoid shifting IDs based on sharding decisions when tables are resharded. **Checkpoint Loading Mode** We have enabled load state dict mode for checkpoint loading, which allows us to cache all ID, weight, and bucket tensors in memory before applying them to the backend. This approach ensures that all data is loaded correctly, even though the checkpoint client does not support ordered tensor loading. # Current Solution The current solution involves caching all data in Python tensors, following these steps: - Set self.local_weight_counts based on checkpoint bucket tensor size. - Enable load state dict mode to initialize local cache tensors. - Call state_dict to get empty tensors for the checkpoint loader. - Write checkpoint data to cached tensors from persisted checkpoint by the checkpoint loader. - Call apply_state_dict to write all cached tensors to the backend. **Apply State Dict Flow** During the apply_state_dict step, we perform the following operations: - If optimizer offloading is enabled: - Loop through chunks of weight and optimizer. - Concatenate weight and optimizer together. - Write to backend using KVTensorWrapper interface. - If optimizer offloading is disabled: - Set optimizer to device tensor based on ID. - Write ID weight to backend for each table. # Limitations The current solution has two limitations: - Memory overhead: - When writing data to the backend, the Python tensor's memory cannot be released until the whole tensor data is duplicated in the backend. This can lead to high memory usage, especially when dealing with single large tables. - Performance regression: - With optimizer offloading, we need to concatenate weight and optimizer together before writing to the backend. To avoid triple one large tensor's memory, we loop through smaller chunks during writing, which can cause performance regression. # Future Improvements After the first version e2e is ready, we plan to support unordered loading from the backend to improve performance and reduce memory overhead. Reviewed By: bobbyliujb Differential Revision: D74790154 --- fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py | 223 ++++++++++++++++-- .../tbe/ssd/ssd_split_tbe_training_test.py | 212 +++++++++++++++++ 2 files changed, 417 insertions(+), 18 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py index b1c64bc1b8..771c768415 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py @@ -313,7 +313,7 @@ def __init__( f"weights precision: {weights_precision}, " f"output dtype: {output_dtype}, " f"chunk size in bulk init: {bulk_init_chunk_size} bytes, backend_type: {backend_type}, " - f"zero_collision_config: {kv_zch_params}" + f"kv_zch_params: {kv_zch_params}" ) self.register_buffer( "lxu_cache_state", @@ -1983,9 +1983,7 @@ def split_optimizer_states( dtype=dtype, row_offset=row_offset, snapshot_handle=snapshot_handle, - materialized_shape=([sorted_id_tensor[t].size(0), emb_opt_dim]), sorted_indices=sorted_id_tensor[t], - width_offset=pad4(emb_dim), ) ( tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db) @@ -2086,6 +2084,7 @@ def _may_create_snapshot_for_state_dict( """ Create a rocksdb snapshot if needed. """ + start_time = time.time() # Force device synchronize for now torch.cuda.synchronize() snapshot_handle = None @@ -2094,7 +2093,13 @@ def _may_create_snapshot_for_state_dict( if not no_snapshot: # Flush L1 and L2 caches self.flush(force=should_flush) + logging.info( + f"flush latency for weight states: {(time.time() - start_time) * 1000} ms" + ) snapshot_handle = self.ssd_db.create_snapshot() + logging.info( + f"created snapshot for weight states: {snapshot_handle}, latency: {(time.time() - start_time) * 1000} ms" + ) elif self.backend_type == BackendType.DRAM: self.flush(force=should_flush) return snapshot_handle @@ -2105,7 +2110,7 @@ def split_embedding_weights( no_snapshot: bool = True, should_flush: bool = False, ) -> Tuple[ # TODO: make this a NamedTuple for readability - List[PartiallyMaterializedTensor], + Union[List[PartiallyMaterializedTensor], List[torch.Tensor]], Optional[List[torch.Tensor]], Optional[List[torch.Tensor]], ]: @@ -2134,6 +2139,17 @@ def split_embedding_weights( ) dtype = self.weights_precision.as_dtype() + if self.load_state_dict and self.kv_zch_params: + # init for checkpointing loading + assert ( + self._cached_kvzch_data is not None + ), "weight id and bucket state are not initialized for load checkpointing" + return ( + self._cached_kvzch_data.cached_weight_tensor_per_table, + self._cached_kvzch_data.cached_id_tensor_per_table, + self._cached_kvzch_data.cached_bucket_splits, + ) + start_time = time.time() pmt_splits = [] bucket_sorted_id_splits = [] if self.kv_zch_params else None active_id_cnt_per_bucket_split = [] if self.kv_zch_params else None @@ -2142,18 +2158,15 @@ def split_embedding_weights( for i, (emb_height, emb_dim) in enumerate(self.embedding_specs): bucket_ascending_id_tensor = None bucket_t = None + row_offset = table_offset if self.kv_zch_params: bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i] # pyre-ignore bucket_size = self.kv_zch_params.bucket_sizes[i] # linearize with table offset - table_input_id_start = ( - min(bucket_id_start * bucket_size, emb_height) + table_offset - ) - table_input_id_end = ( - min(bucket_id_end * bucket_size, emb_height) + table_offset - ) + table_input_id_start = table_offset + table_input_id_end = table_offset + emb_height # 1. get all keys from backend for one table unordered_id_tensor = self._ssd_db.get_keys_in_range_by_snapshot( table_input_id_start, @@ -2166,15 +2179,38 @@ def split_embedding_weights( torch.ops.fbgemm.get_bucket_sorted_indices_and_bucket_tensor( unordered_id_tensor, 0, # id--bucket hashing mode, 0 for chunk-based hashing, 1 for interleave-based hashing - bucket_id_start, - bucket_id_end, + 0, # local bucket offset + bucket_id_end - bucket_id_start, # local bucket num bucket_size, ) ) - # pyre-ignore + # 3. convert local id back to global id + bucket_ascending_id_tensor.add_(bucket_id_start * bucket_size) + + if ( + bucket_ascending_id_tensor.size(0) == 0 + and self.local_weight_counts[i] > 0 + ): + logging.info( + f"resetting bucket id tensor with {self.local_weight_counts[i]}" + ) + bucket_ascending_id_tensor = torch.zeros( + (self.local_weight_counts[i], 1), + device=torch.device("cpu"), + dtype=torch.int64, + ) + # self.local_weight_counts[i] = 0 # Reset the count + + # pyre-ignore [16] bucket_sorted_id_splits is not None bucket_sorted_id_splits.append(bucket_ascending_id_tensor) active_id_cnt_per_bucket_split.append(bucket_t) + # for KV ZCH tbe, the sorted_indices is global id for checkpointing and publishing + # but in backend, local id is used during training, so the KVTensorWrapper need to convert global id to local id + # first, then linearize the local id with table offset, the formulat is x + table_offset - local_shard_offset + # to achieve this, the row_offset will be set to (table_offset - local_shard_offset) + row_offset = table_offset - (bucket_id_start * bucket_size) + tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper( shape=[ ( @@ -2185,14 +2221,19 @@ def split_embedding_weights( pad4(emb_dim), ], dtype=dtype, - row_offset=table_offset, + row_offset=row_offset, snapshot_handle=snapshot_handle, + # set bucket_ascending_id_tensor to kvt wrapper, so narrow will follow the id order to return + # embedding weights. sorted_indices=( bucket_ascending_id_tensor if self.kv_zch_params else None ), ) - # TODO add if else support in the future for dram integration. - tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db) + ( + tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db) + if self.backend_type == BackendType.SSD + else tensor_wrapper.set_dram_db_wrapper(self.ssd_db) + ) table_offset += emb_height pmt_splits.append( PartiallyMaterializedTensor( @@ -2200,18 +2241,164 @@ def split_embedding_weights( True if self.kv_zch_params else False, ) ) + logging.info( + f"split_embedding_weights latency: {(time.time() - start_time) * 1000} ms" + ) return (pmt_splits, bucket_sorted_id_splits, active_id_cnt_per_bucket_split) @torch.jit.ignore def apply_state_dict(self) -> None: # After checkpoint loading, the _cached_kvzch_data will be loaded from checkpoint. # Caller should call this function to apply the cached states to backend. - pass + if self.load_state_dict is False: + return + self.load_state_dict = False + assert self.kv_zch_params is not None, "apply_state_dict supports KV ZCH only" + assert ( + self._cached_kvzch_data is not None + and self._cached_kvzch_data.cached_optimizer_state_per_table is not None + ), "optimizer state is not initialized for load checkpointing" + assert ( + self._cached_kvzch_data.cached_weight_tensor_per_table is not None + and self._cached_kvzch_data.cached_id_tensor_per_table is not None + ), "weight and id state is not initialized for load checkpointing" + + # Compute the number of elements of cache_dtype needed to store the + # optimizer state, round to the nearest 4 + # optimizer_dim = self.optimizer.optimizer_state_size_dim(dtype) + # apply weight and optimizer state per table + table_offset = 0 + for i, (emb_height, _) in enumerate(self.embedding_specs): + # pyre-ignore [16] + bucket_id_start, _ = self.kv_zch_params.bucket_offsets[i] + # pyre-ignore [16] + bucket_size = self.kv_zch_params.bucket_sizes[i] + row_offset = table_offset - bucket_id_start * bucket_size + + if self.enable_optimizer_offloading: + # pyre-ignore [16] + weight_state = self._cached_kvzch_data.cached_weight_tensor_per_table[i] + # pyre-ignore [16] + opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i] + self.streaming_write_weight_and_id_per_table( + weight_state, + opt_state, + # pyre-ignore [16] + self._cached_kvzch_data.cached_id_tensor_per_table[i], + row_offset, + ) + self._cached_kvzch_data.cached_weight_tensor_per_table[i] = None + self._cached_kvzch_data.cached_optimizer_state_per_table[i] = None + else: + weight = self._cached_kvzch_data.cached_weight_tensor_per_table[i] + id = self._cached_kvzch_data.cached_id_tensor_per_table[i] + local_id = id + row_offset + logging.info( + f"applying sd for table {i} without optimizer offloading, local_id is {local_id}" + ) + opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i] + t_device = self.momentum1_dev.device + self.momentum1_dev.index_put_( + indices=( + local_id.to(t_device).view(-1), + ), # expects tuple of tensors + values=opt_state.to(t_device), + ) + self.ssd_db.set_cuda( + local_id.view(-1), + weight, + torch.as_tensor(local_id.size(0)), + 1, + False, + ) + table_offset += emb_height + self.clear_cache() + + @torch.jit.ignore + def streaming_write_weight_and_id_per_table( + self, + weight_state: torch.Tensor, + opt_state: torch.Tensor, + id_tensor: torch.Tensor, + row_offset: int, + ) -> None: + """ + This function is used to write weight, optimizer and id to the backend using kvt wrapper. + to avoid over use memory, we will write the weight and id to backend in a rolling window manner + + Args: + weight_state (torch.tensor): The weight state tensor to be written. + opt_state (torch.tensor): The optimizer state tensor to be written. + id_tensor (torch.tensor): The id tensor to be written. + """ + D_rounded = pad4(weight_state.size(1)) # padded to 4 bytes alignment + dtype = self.weights_precision.as_dtype() + kvt = torch.classes.fbgemm.KVTensorWrapper( + db=self.ssd_db, + shape=[weight_state.size(0), self.cache_row_dim], + dtype=dtype, + row_offset=row_offset, + snapshot_handle=None, + sorted_indices=id_tensor, + ) + # TODO: make chunk_size configurable or dynamic + chunk_size = 10000 + row = weight_state.size(0) + optimizer_dim = self.optimizer.state_size_dim(dtype) + opt_state_2d = opt_state.view(dtype).view(-1, optimizer_dim) + for i in range(0, row, chunk_size): + length = min(chunk_size, row - i) + chunk_buffer = torch.empty( + length, + self.cache_row_dim, + dtype=dtype, + device="cpu", + ) + chunk_buffer[:, : weight_state.size(1)] = weight_state[i : i + length, :] + chunk_buffer[:, D_rounded : D_rounded + optimizer_dim] = opt_state_2d[ + i : i + length, : + ] + kvt.set_weights_and_ids(chunk_buffer, id_tensor[i : i + length, :].view(-1)) @torch.jit.ignore def enable_load_state_dict_mode(self) -> None: # Enable load state dict mode before loading checkpoint - pass + if self.load_state_dict: + return + self.load_state_dict = True + + dtype = self.weights_precision.as_dtype() + self._cached_kvzch_data = KVZCHCachedData([], [], [], []) + for i, (_, emb_dim) in enumerate(self.embedding_specs): + # for checkpointing loading, we need to store the weight and id tensor temporarily in memory + assert ( + self.local_weight_counts[i] > 0 + ), f"local_weight_counts for table {i} is not set" + # pyre-ignore [16] + bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i] + rows = self.local_weight_counts[i] + weight_state = torch.empty(rows, emb_dim, dtype=dtype, device="cpu") + opt_state = torch.empty(rows, dtype=torch.float32, device="cpu") + # pyre-ignore [16] + self._cached_kvzch_data.cached_weight_tensor_per_table.append(weight_state) + # pyre-ignore [16] + self._cached_kvzch_data.cached_optimizer_state_per_table.append(opt_state) + logging.info( + f"for checkpoint loading, table {i}, weight_state shape is {weight_state.shape}, opt_state shape is {opt_state.shape}" + ) + id_tensor = torch.zeros( + (self.local_weight_counts[i], 1), dtype=torch.int64, device="cpu" + ) + # pyre-ignore [16] + self._cached_kvzch_data.cached_id_tensor_per_table.append(id_tensor) + # pyre-ignore [16] + self._cached_kvzch_data.cached_bucket_splits.append( + torch.empty( + (bucket_id_end - bucket_id_start, 1), + dtype=torch.int64, + device="cpu", + ) + ) @torch.jit.export def set_learning_rate(self, lr: float) -> None: diff --git a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py index c3522acaaf..476906f0e5 100644 --- a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py +++ b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py @@ -8,6 +8,7 @@ # pyre-ignore-all-errors[3,6,56] import math +import tempfile import unittest from enum import Enum @@ -2007,3 +2008,214 @@ def test_kv_opt_state_w_offloading( atol=tolerance, rtol=tolerance, ) + + @given( + **default_st, + num_buckets=st.integers(min_value=10, max_value=15), + ) + @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) + def test_apply_kv_state_dict( + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + weighted: bool, + cache_set_scale: float, + pooling_mode: PoolingMode, + weights_precision: SparseType, + output_dtype: SparseType, + share_table: bool, + trigger_bounds_check: bool, + mixed_B: bool, + num_buckets: int, + ) -> None: + # Constants + lr = 0.5 + eps = 0.2 + ssd_shards = 2 + + trigger_bounds_check = False # don't stimulate boundary check cases + assume(not weighted or pooling_mode == PoolingMode.SUM) + assume(not mixed_B or pooling_mode != PoolingMode.NONE) + + # TODO: check split_optimizer_states when optimizer offloading is ready + # Generate embedding modules and inputs + ( + emb, + emb_ref, + Es, + _, + bucket_offsets, + bucket_sizes, + ) = self.generate_kvzch_tbes( + T, + D, + B, + log_E, + L, + weighted, + lr=lr, + eps=eps, + ssd_shards=ssd_shards, + cache_set_scale=cache_set_scale, + pooling_mode=pooling_mode, + weights_precision=weights_precision, + output_dtype=output_dtype, + share_table=share_table, + num_buckets=num_buckets, + enable_optimizer_offloading=False, + ) + + # Generate inputs + ( + indices_list, + per_sample_weights_list, + indices, + offsets, + per_sample_weights, + batch_size_per_feature_per_rank, + ) = self.generate_inputs_( + B, + L, + Es, + emb.feature_table_map, + weights_precision=weights_precision, + trigger_bounds_check=trigger_bounds_check, + mixed_B=mixed_B, + bucket_offsets=bucket_offsets, + bucket_sizes=bucket_sizes, + is_kv_tbes=True, + ) + + # Execute forward + output_ref_list, output = self.execute_ssd_forward_( + emb, + emb_ref, + indices_list, + per_sample_weights_list, + indices, + offsets, + per_sample_weights, + B, + L, + weighted, + batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, + ) + + # Generate output gradient + output_grad_list = [torch.randn_like(out) for out in output_ref_list] + + # Execute torch EmbeddingBag backward + [out.backward(grad) for (out, grad) in zip(output_ref_list, output_grad_list)] + if batch_size_per_feature_per_rank is not None: + grad_test = self.concat_ref_tensors_vbe( + output_grad_list, batch_size_per_feature_per_rank + ) + else: + grad_test = self.concat_ref_tensors( + output_grad_list, + pooling_mode != PoolingMode.NONE, # do_pooling + B, + D * 4, + ) + + # Execute TBE SSD backward + output.backward(grad_test) + + tolerance = ( + 1.0e-4 + if weights_precision == SparseType.FP32 and output_dtype == SparseType.FP32 + else 1.0e-2 + ) + + emb.flush() + + # Compare emb state dict with expected values from nn.EmbeddingBag + emb_state_dict_list, bucket_asc_ids_list, num_active_id_per_bucket_list = ( + emb.split_embedding_weights(no_snapshot=False, should_flush=True) + ) + split_optimizer_states = emb.split_optimizer_states(bucket_asc_ids_list) + + # create an empty emb with same parameters + # Construct feature_table_map + + cache_sets = max(int(max(T * B * L, 1) * cache_set_scale), 1) + emb2 = SSDTableBatchedEmbeddingBags( + embedding_specs=emb.embedding_specs, + feature_table_map=emb.feature_table_map, + ssd_storage_directory=tempfile.mkdtemp(), + cache_sets=cache_sets, + ssd_uniform_init_lower=-0.1, + ssd_uniform_init_upper=0.1, + learning_rate=lr, + eps=eps, + ssd_rocksdb_shards=ssd_shards, + optimizer=OptimType.EXACT_ROWWISE_ADAGRAD, + pooling_mode=pooling_mode, + weights_precision=weights_precision, + output_dtype=output_dtype, + stochastic_rounding=True, + prefetch_pipeline=False, + bounds_check_mode=BoundsCheckMode.WARNING, + l2_cache_size=8, + backend_type=BackendType.SSD, + kv_zch_params=emb.kv_zch_params, + ).cuda() + + emb2.local_weight_counts = [ids.numel() for ids in bucket_asc_ids_list] + emb2.enable_load_state_dict_mode() + self.assertIsNotNone(emb2._cached_kvzch_data) + for i in range(len(emb.embedding_specs)): + # pyre-ignore [16] + emb2._cached_kvzch_data.cached_weight_tensor_per_table[i].copy_( + emb_state_dict_list[i].full_tensor() + ) + # pyre-ignore [16] + emb2._cached_kvzch_data.cached_optimizer_state_per_table[i].copy_( + split_optimizer_states[i] + ) + # pyre-ignore [16] + emb2._cached_kvzch_data.cached_id_tensor_per_table[i].copy_( + bucket_asc_ids_list[i] + ) + # pyre-ignore [16] + emb2._cached_kvzch_data.cached_bucket_splits[i].copy_( + num_active_id_per_bucket_list[i] + ) + + emb2.apply_state_dict() + + emb2.flush(True) + # Compare emb state dict with expected values from nn.EmbeddingBag + emb_state_dict_list2, bucket_asc_ids_list2, num_active_id_per_bucket_list2 = ( + emb2.split_embedding_weights(no_snapshot=False, should_flush=True) + ) + split_optimizer_states2 = emb2.split_optimizer_states(bucket_asc_ids_list2) + + for t in range(len(emb.embedding_specs)): + torch.testing.assert_close( + split_optimizer_states[t], + split_optimizer_states2[t], + atol=tolerance, + rtol=tolerance, + ) + torch.testing.assert_close( + emb_state_dict_list[t].full_tensor(), + emb_state_dict_list2[t].full_tensor(), + atol=tolerance, + rtol=tolerance, + ) + torch.testing.assert_close( + bucket_asc_ids_list[t], + bucket_asc_ids_list2[t], + atol=tolerance, + rtol=tolerance, + ) + torch.testing.assert_close( + num_active_id_per_bucket_list[t], + num_active_id_per_bucket_list2[t], + atol=tolerance, + rtol=tolerance, + )