diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
index b1c64bc1b8..771c768415 100644
--- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
+++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -313,7 +313,7 @@ def __init__(
             f"weights precision: {weights_precision}, "
             f"output dtype: {output_dtype}, "
             f"chunk size in bulk init: {bulk_init_chunk_size} bytes, backend_type: {backend_type}, "
-            f"zero_collision_config: {kv_zch_params}"
+            f"kv_zch_params: {kv_zch_params}"
         )
         self.register_buffer(
             "lxu_cache_state",
@@ -1983,9 +1983,7 @@ def split_optimizer_states(
                     dtype=dtype,
                     row_offset=row_offset,
                     snapshot_handle=snapshot_handle,
-                    materialized_shape=([sorted_id_tensor[t].size(0), emb_opt_dim]),
                     sorted_indices=sorted_id_tensor[t],
-                    width_offset=pad4(emb_dim),
                 )
                 (
                     tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
@@ -2086,6 +2084,7 @@ def _may_create_snapshot_for_state_dict(
         """
         Create a rocksdb snapshot if needed.
         """
+        start_time = time.time()
         # Force device synchronize for now
         torch.cuda.synchronize()
         snapshot_handle = None
@@ -2094,7 +2093,13 @@ def _may_create_snapshot_for_state_dict(
             if not no_snapshot:
                 # Flush L1 and L2 caches
                 self.flush(force=should_flush)
+                logging.info(
+                    f"flush latency for weight states: {(time.time() - start_time) * 1000} ms"
+                )
                 snapshot_handle = self.ssd_db.create_snapshot()
+                logging.info(
+                    f"created snapshot for weight states: {snapshot_handle}, latency: {(time.time() - start_time) * 1000} ms"
+                )
         elif self.backend_type == BackendType.DRAM:
             self.flush(force=should_flush)
         return snapshot_handle
@@ -2105,7 +2110,7 @@ def split_embedding_weights(
         no_snapshot: bool = True,
         should_flush: bool = False,
     ) -> Tuple[  # TODO: make this a NamedTuple for readability
-        List[PartiallyMaterializedTensor],
+        Union[List[PartiallyMaterializedTensor], List[torch.Tensor]],
         Optional[List[torch.Tensor]],
         Optional[List[torch.Tensor]],
     ]:
@@ -2134,6 +2139,17 @@ def split_embedding_weights(
         )
         dtype = self.weights_precision.as_dtype()

+        if self.load_state_dict and self.kv_zch_params:
+            # init for checkpoint loading
+            assert (
+                self._cached_kvzch_data is not None
+            ), "weight id and bucket state are not initialized for checkpoint loading"
+            return (
+                self._cached_kvzch_data.cached_weight_tensor_per_table,
+                self._cached_kvzch_data.cached_id_tensor_per_table,
+                self._cached_kvzch_data.cached_bucket_splits,
+            )
+        start_time = time.time()
         pmt_splits = []
         bucket_sorted_id_splits = [] if self.kv_zch_params else None
         active_id_cnt_per_bucket_split = [] if self.kv_zch_params else None
@@ -2142,18 +2158,15 @@
         for i, (emb_height, emb_dim) in enumerate(self.embedding_specs):
             bucket_ascending_id_tensor = None
             bucket_t = None
+            row_offset = table_offset
             if self.kv_zch_params:
                 bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
                 # pyre-ignore
                 bucket_size = self.kv_zch_params.bucket_sizes[i]
                 # linearize with table offset
-                table_input_id_start = (
-                    min(bucket_id_start * bucket_size, emb_height) + table_offset
-                )
-                table_input_id_end = (
-                    min(bucket_id_end * bucket_size, emb_height) + table_offset
-                )
+                table_input_id_start = table_offset
+                table_input_id_end = table_offset + emb_height
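+                # Note: with KV ZCH the scan now covers the whole linearized table range
+                # [table_offset, table_offset + emb_height). With illustrative (assumed)
+                # numbers table_offset=1000 and emb_height=500, keys in [1000, 1500)
+                # are fetched from the backend below.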
                 # 1. get all keys from backend for one table
                 unordered_id_tensor = self._ssd_db.get_keys_in_range_by_snapshot(
                     table_input_id_start,
@@ -2166,15 +2179,38 @@
                     torch.ops.fbgemm.get_bucket_sorted_indices_and_bucket_tensor(
                         unordered_id_tensor,
                         0,  # id--bucket hashing mode, 0 for chunk-based hashing, 1 for interleave-based hashing
-                        bucket_id_start,
-                        bucket_id_end,
+                        0,  # local bucket offset
+                        bucket_id_end - bucket_id_start,  # local bucket num
                         bucket_size,
                     )
                 )
-                # pyre-ignore
+                # 3. convert local id back to global id
+                bucket_ascending_id_tensor.add_(bucket_id_start * bucket_size)
+
+                if (
+                    bucket_ascending_id_tensor.size(0) == 0
+                    and self.local_weight_counts[i] > 0
+                ):
+                    logging.info(
+                        f"resetting bucket id tensor with {self.local_weight_counts[i]}"
+                    )
+                    bucket_ascending_id_tensor = torch.zeros(
+                        (self.local_weight_counts[i], 1),
+                        device=torch.device("cpu"),
+                        dtype=torch.int64,
+                    )
+                    # self.local_weight_counts[i] = 0  # Reset the count
+
+                # pyre-ignore [16] bucket_sorted_id_splits is not None
                 bucket_sorted_id_splits.append(bucket_ascending_id_tensor)
                 active_id_cnt_per_bucket_split.append(bucket_t)

+                # For KV ZCH TBE, sorted_indices holds global ids for checkpointing and
+                # publishing, but the backend uses local ids during training, so the
+                # KVTensorWrapper needs to convert each global id to a local id first and
+                # then linearize it with the table offset; the formula is
+                # x + table_offset - local_shard_offset. To achieve this, row_offset is
+                # set to (table_offset - local_shard_offset).
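+                # Worked example with illustrative (assumed) numbers: table_offset=1000,
+                # bucket_id_start=2, bucket_size=100 -> row_offset = 1000 - 200 = 800,
+                # so global id 250 maps to backend row 250 + 800 = 1050 (its local id 50
+                # linearized with the table offset).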
+                row_offset = table_offset - (bucket_id_start * bucket_size)
+
             tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
                 shape=[
                     (
@@ -2185,14 +2221,19 @@
                     pad4(emb_dim),
                 ],
                 dtype=dtype,
-                row_offset=table_offset,
+                row_offset=row_offset,
                 snapshot_handle=snapshot_handle,
+                # Pass bucket_ascending_id_tensor to the KVT wrapper so that narrow()
+                # returns embedding weights in this id order.
                 sorted_indices=(
                     bucket_ascending_id_tensor if self.kv_zch_params else None
                 ),
             )
-            # TODO add if else support in the future for dram integration.
-            tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+            (
+                tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                if self.backend_type == BackendType.SSD
+                else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+            )
             table_offset += emb_height
             pmt_splits.append(
                 PartiallyMaterializedTensor(
@@ -2200,18 +2241,164 @@
                     True if self.kv_zch_params else False,
                 )
             )
+        logging.info(
+            f"split_embedding_weights latency: {(time.time() - start_time) * 1000} ms"
+        )
         return (pmt_splits, bucket_sorted_id_splits, active_id_cnt_per_bucket_split)

     @torch.jit.ignore
     def apply_state_dict(self) -> None:
         # After checkpoint loading, the _cached_kvzch_data will be loaded from checkpoint.
         # Caller should call this function to apply the cached states to backend.
-        pass
+        if self.load_state_dict is False:
+            return
+        self.load_state_dict = False
+        assert self.kv_zch_params is not None, "apply_state_dict supports KV ZCH only"
+        assert (
+            self._cached_kvzch_data is not None
+            and self._cached_kvzch_data.cached_optimizer_state_per_table is not None
+        ), "optimizer state is not initialized for checkpoint loading"
+        assert (
+            self._cached_kvzch_data.cached_weight_tensor_per_table is not None
+            and self._cached_kvzch_data.cached_id_tensor_per_table is not None
+        ), "weight and id states are not initialized for checkpoint loading"
+
+        # Compute the number of elements of cache_dtype needed to store the
+        # optimizer state, round to the nearest 4
+        # optimizer_dim = self.optimizer.optimizer_state_size_dim(dtype)
+        # apply weight and optimizer state per table
+        table_offset = 0
+        for i, (emb_height, _) in enumerate(self.embedding_specs):
+            # pyre-ignore [16]
+            bucket_id_start, _ = self.kv_zch_params.bucket_offsets[i]
+            # pyre-ignore [16]
+            bucket_size = self.kv_zch_params.bucket_sizes[i]
+            row_offset = table_offset - bucket_id_start * bucket_size
+
+            if self.enable_optimizer_offloading:
+                # pyre-ignore [16]
+                weight_state = self._cached_kvzch_data.cached_weight_tensor_per_table[i]
+                # pyre-ignore [16]
+                opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i]
+                self.streaming_write_weight_and_id_per_table(
+                    weight_state,
+                    opt_state,
+                    # pyre-ignore [16]
+                    self._cached_kvzch_data.cached_id_tensor_per_table[i],
+                    row_offset,
+                )
+                self._cached_kvzch_data.cached_weight_tensor_per_table[i] = None
+                self._cached_kvzch_data.cached_optimizer_state_per_table[i] = None
+            else:
+                weight = self._cached_kvzch_data.cached_weight_tensor_per_table[i]
+                id = self._cached_kvzch_data.cached_id_tensor_per_table[i]
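+                # Despite its name, `local_id` below is the linearized backend row:
+                # global id + (table_offset - bucket_id_start * bucket_size), matching
+                # the row_offset convention used in split_embedding_weights.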
+ """ + D_rounded = pad4(weight_state.size(1)) # padded to 4 bytes alignment + dtype = self.weights_precision.as_dtype() + kvt = torch.classes.fbgemm.KVTensorWrapper( + db=self.ssd_db, + shape=[weight_state.size(0), self.cache_row_dim], + dtype=dtype, + row_offset=row_offset, + snapshot_handle=None, + sorted_indices=id_tensor, + ) + # TODO: make chunk_size configurable or dynamic + chunk_size = 10000 + row = weight_state.size(0) + optimizer_dim = self.optimizer.state_size_dim(dtype) + opt_state_2d = opt_state.view(dtype).view(-1, optimizer_dim) + for i in range(0, row, chunk_size): + length = min(chunk_size, row - i) + chunk_buffer = torch.empty( + length, + self.cache_row_dim, + dtype=dtype, + device="cpu", + ) + chunk_buffer[:, : weight_state.size(1)] = weight_state[i : i + length, :] + chunk_buffer[:, D_rounded : D_rounded + optimizer_dim] = opt_state_2d[ + i : i + length, : + ] + kvt.set_weights_and_ids(chunk_buffer, id_tensor[i : i + length, :].view(-1)) @torch.jit.ignore def enable_load_state_dict_mode(self) -> None: # Enable load state dict mode before loading checkpoint - pass + if self.load_state_dict: + return + self.load_state_dict = True + + dtype = self.weights_precision.as_dtype() + self._cached_kvzch_data = KVZCHCachedData([], [], [], []) + for i, (_, emb_dim) in enumerate(self.embedding_specs): + # for checkpointing loading, we need to store the weight and id tensor temporarily in memory + assert ( + self.local_weight_counts[i] > 0 + ), f"local_weight_counts for table {i} is not set" + # pyre-ignore [16] + bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i] + rows = self.local_weight_counts[i] + weight_state = torch.empty(rows, emb_dim, dtype=dtype, device="cpu") + opt_state = torch.empty(rows, dtype=torch.float32, device="cpu") + # pyre-ignore [16] + self._cached_kvzch_data.cached_weight_tensor_per_table.append(weight_state) + # pyre-ignore [16] + self._cached_kvzch_data.cached_optimizer_state_per_table.append(opt_state) + logging.info( + f"for checkpoint loading, table {i}, weight_state shape is {weight_state.shape}, opt_state shape is {opt_state.shape}" + ) + id_tensor = torch.zeros( + (self.local_weight_counts[i], 1), dtype=torch.int64, device="cpu" + ) + # pyre-ignore [16] + self._cached_kvzch_data.cached_id_tensor_per_table.append(id_tensor) + # pyre-ignore [16] + self._cached_kvzch_data.cached_bucket_splits.append( + torch.empty( + (bucket_id_end - bucket_id_start, 1), + dtype=torch.int64, + device="cpu", + ) + ) @torch.jit.export def set_learning_rate(self, lr: float) -> None: diff --git a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py index c3522acaaf..476906f0e5 100644 --- a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py +++ b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py @@ -8,6 +8,7 @@ # pyre-ignore-all-errors[3,6,56] import math +import tempfile import unittest from enum import Enum @@ -2007,3 +2008,214 @@ def test_kv_opt_state_w_offloading( atol=tolerance, rtol=tolerance, ) + + @given( + **default_st, + num_buckets=st.integers(min_value=10, max_value=15), + ) + @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) + def test_apply_kv_state_dict( + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + weighted: bool, + cache_set_scale: float, + pooling_mode: PoolingMode, + weights_precision: SparseType, + output_dtype: SparseType, + share_table: bool, + trigger_bounds_check: bool, + mixed_B: bool, + num_buckets: int, + ) -> 
+        for i in range(0, row, chunk_size):
+            length = min(chunk_size, row - i)
+            chunk_buffer = torch.empty(
+                length,
+                self.cache_row_dim,
+                dtype=dtype,
+                device="cpu",
+            )
+            chunk_buffer[:, : weight_state.size(1)] = weight_state[i : i + length, :]
+            chunk_buffer[:, D_rounded : D_rounded + optimizer_dim] = opt_state_2d[
+                i : i + length, :
+            ]
+            kvt.set_weights_and_ids(chunk_buffer, id_tensor[i : i + length, :].view(-1))

     @torch.jit.ignore
     def enable_load_state_dict_mode(self) -> None:
         # Enable load state dict mode before loading checkpoint
-        pass
+        if self.load_state_dict:
+            return
+        self.load_state_dict = True
+
+        dtype = self.weights_precision.as_dtype()
+        self._cached_kvzch_data = KVZCHCachedData([], [], [], [])
+        for i, (_, emb_dim) in enumerate(self.embedding_specs):
+            # For checkpoint loading, the weight and id tensors are staged temporarily in memory.
+            assert (
+                self.local_weight_counts[i] > 0
+            ), f"local_weight_counts for table {i} is not set"
+            # pyre-ignore [16]
+            bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
+            rows = self.local_weight_counts[i]
+            weight_state = torch.empty(rows, emb_dim, dtype=dtype, device="cpu")
+            opt_state = torch.empty(rows, dtype=torch.float32, device="cpu")
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_weight_tensor_per_table.append(weight_state)
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_optimizer_state_per_table.append(opt_state)
+            logging.info(
+                f"for checkpoint loading, table {i}, weight_state shape is {weight_state.shape}, opt_state shape is {opt_state.shape}"
+            )
+            id_tensor = torch.zeros(
+                (self.local_weight_counts[i], 1), dtype=torch.int64, device="cpu"
+            )
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_id_tensor_per_table.append(id_tensor)
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_bucket_splits.append(
+                torch.empty(
+                    (bucket_id_end - bucket_id_start, 1),
+                    dtype=torch.int64,
+                    device="cpu",
+                )
+            )

     @torch.jit.export
     def set_learning_rate(self, lr: float) -> None:
diff --git a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py
index c3522acaaf..476906f0e5 100644
--- a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py
+++ b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py
@@ -8,6 +8,7 @@
 # pyre-ignore-all-errors[3,6,56]

 import math
+import tempfile
 import unittest

 from enum import Enum
@@ -2007,3 +2008,214 @@ def test_kv_opt_state_w_offloading(
             atol=tolerance,
             rtol=tolerance,
         )
+
+    @given(
+        **default_st,
+        num_buckets=st.integers(min_value=10, max_value=15),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
+    def test_apply_kv_state_dict(
+        self,
+        T: int,
+        D: int,
+        B: int,
+        log_E: int,
+        L: int,
+        weighted: bool,
+        cache_set_scale: float,
+        pooling_mode: PoolingMode,
+        weights_precision: SparseType,
+        output_dtype: SparseType,
+        share_table: bool,
+        trigger_bounds_check: bool,
+        mixed_B: bool,
+        num_buckets: int,
+    ) -> None:
+        # Constants
+        lr = 0.5
+        eps = 0.2
+        ssd_shards = 2
+
+        trigger_bounds_check = False  # don't simulate bounds-check cases
+        assume(not weighted or pooling_mode == PoolingMode.SUM)
+        assume(not mixed_B or pooling_mode != PoolingMode.NONE)
+
+        # TODO: check split_optimizer_states when optimizer offloading is ready
+        # Generate embedding modules and inputs
+        (
+            emb,
+            emb_ref,
+            Es,
+            _,
+            bucket_offsets,
+            bucket_sizes,
+        ) = self.generate_kvzch_tbes(
+            T,
+            D,
+            B,
+            log_E,
+            L,
+            weighted,
+            lr=lr,
+            eps=eps,
+            ssd_shards=ssd_shards,
+            cache_set_scale=cache_set_scale,
+            pooling_mode=pooling_mode,
+            weights_precision=weights_precision,
+            output_dtype=output_dtype,
+            share_table=share_table,
+            num_buckets=num_buckets,
+            enable_optimizer_offloading=False,
+        )
+
+        # Generate inputs
+        (
+            indices_list,
+            per_sample_weights_list,
+            indices,
+            offsets,
+            per_sample_weights,
+            batch_size_per_feature_per_rank,
+        ) = self.generate_inputs_(
+            B,
+            L,
+            Es,
+            emb.feature_table_map,
+            weights_precision=weights_precision,
+            trigger_bounds_check=trigger_bounds_check,
+            mixed_B=mixed_B,
+            bucket_offsets=bucket_offsets,
+            bucket_sizes=bucket_sizes,
+            is_kv_tbes=True,
+        )
+
+        # Execute forward
+        output_ref_list, output = self.execute_ssd_forward_(
+            emb,
+            emb_ref,
+            indices_list,
+            per_sample_weights_list,
+            indices,
+            offsets,
+            per_sample_weights,
+            B,
+            L,
+            weighted,
+            batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
+        )
+
+        # Generate output gradient
+        output_grad_list = [torch.randn_like(out) for out in output_ref_list]
+
+        # Execute torch EmbeddingBag backward
+        [out.backward(grad) for (out, grad) in zip(output_ref_list, output_grad_list)]
+        if batch_size_per_feature_per_rank is not None:
+            grad_test = self.concat_ref_tensors_vbe(
+                output_grad_list, batch_size_per_feature_per_rank
+            )
+        else:
+            grad_test = self.concat_ref_tensors(
+                output_grad_list,
+                pooling_mode != PoolingMode.NONE,  # do_pooling
+                B,
+                D * 4,
+            )
+
+        # Execute TBE SSD backward
+        output.backward(grad_test)
+
+        tolerance = (
+            1.0e-4
+            if weights_precision == SparseType.FP32 and output_dtype == SparseType.FP32
+            else 1.0e-2
+        )
+
+        emb.flush()
+
+        # Compare emb state dict with expected values from nn.EmbeddingBag
+        emb_state_dict_list, bucket_asc_ids_list, num_active_id_per_bucket_list = (
+            emb.split_embedding_weights(no_snapshot=False, should_flush=True)
+        )
+        split_optimizer_states = emb.split_optimizer_states(bucket_asc_ids_list)
+
+        # Create an empty emb with the same parameters
+        # Construct feature_table_map
+
+        cache_sets = max(int(max(T * B * L, 1) * cache_set_scale), 1)
+        emb2 = SSDTableBatchedEmbeddingBags(
+            embedding_specs=emb.embedding_specs,
+            feature_table_map=emb.feature_table_map,
+            ssd_storage_directory=tempfile.mkdtemp(),
+            cache_sets=cache_sets,
+            ssd_uniform_init_lower=-0.1,
+            ssd_uniform_init_upper=0.1,
+            learning_rate=lr,
+            eps=eps,
+            ssd_rocksdb_shards=ssd_shards,
+            optimizer=OptimType.EXACT_ROWWISE_ADAGRAD,
+            pooling_mode=pooling_mode,
+            weights_precision=weights_precision,
+            output_dtype=output_dtype,
+            stochastic_rounding=True,
+            prefetch_pipeline=False,
+            bounds_check_mode=BoundsCheckMode.WARNING,
+            l2_cache_size=8,
+            backend_type=BackendType.SSD,
+            kv_zch_params=emb.kv_zch_params,
+        ).cuda()
+
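+        # Checkpoint-load flow exercised below: set local_weight_counts, call
+        # enable_load_state_dict_mode() to stage empty CPU buffers, copy the saved
+        # weights / optimizer states / ids / bucket splits into them, then call
+        # apply_state_dict() to push everything to the backend.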
+        emb2.local_weight_counts = [ids.numel() for ids in bucket_asc_ids_list]
+        emb2.enable_load_state_dict_mode()
+        self.assertIsNotNone(emb2._cached_kvzch_data)
+        for i in range(len(emb.embedding_specs)):
+            # pyre-ignore [16]
+            emb2._cached_kvzch_data.cached_weight_tensor_per_table[i].copy_(
+                emb_state_dict_list[i].full_tensor()
+            )
+            # pyre-ignore [16]
+            emb2._cached_kvzch_data.cached_optimizer_state_per_table[i].copy_(
+                split_optimizer_states[i]
+            )
+            # pyre-ignore [16]
+            emb2._cached_kvzch_data.cached_id_tensor_per_table[i].copy_(
+                bucket_asc_ids_list[i]
+            )
+            # pyre-ignore [16]
+            emb2._cached_kvzch_data.cached_bucket_splits[i].copy_(
+                num_active_id_per_bucket_list[i]
+            )
+
+        emb2.apply_state_dict()
+
+        emb2.flush(True)
+        # Compare emb2 state dict against the original emb state dict
+        emb_state_dict_list2, bucket_asc_ids_list2, num_active_id_per_bucket_list2 = (
+            emb2.split_embedding_weights(no_snapshot=False, should_flush=True)
+        )
+        split_optimizer_states2 = emb2.split_optimizer_states(bucket_asc_ids_list2)
+
+        for t in range(len(emb.embedding_specs)):
+            torch.testing.assert_close(
+                split_optimizer_states[t],
+                split_optimizer_states2[t],
+                atol=tolerance,
+                rtol=tolerance,
+            )
+            torch.testing.assert_close(
+                emb_state_dict_list[t].full_tensor(),
+                emb_state_dict_list2[t].full_tensor(),
+                atol=tolerance,
+                rtol=tolerance,
+            )
+            torch.testing.assert_close(
+                bucket_asc_ids_list[t],
+                bucket_asc_ids_list2[t],
+                atol=tolerance,
+                rtol=tolerance,
+            )
+            torch.testing.assert_close(
+                num_active_id_per_bucket_list[t],
+                num_active_id_per_bucket_list2[t],
+                atol=tolerance,
+                rtol=tolerance,
+            )