diff --git a/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py b/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py
index d86df572c1..8b7d155bbf 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py
@@ -150,6 +150,74 @@ def byte_offsets_along_row(
         else:
             return {}
 
+    def empty_states(
+        self,
+        embedding_specs: List[Tuple[int, int]],  # Tuple of (rows, dims)
+        optimizer_state_dtypes: Dict[str, "SparseType"] = {},  # noqa: B006
+        local_row_counts: Optional[List[int]] = None,
+    ) -> List[List[torch.Tensor]]:
+        """
+        Creates sets of empty tensors per table to hold optimizer states based
+        on the specified optimizer type, state dtypes, embedding specs, and
+        (optionally) local row counts.
+
+        Args:
+            embedding_specs: One (rows, dims) tuple per table.
+            optimizer_state_dtypes: Per-state dtype selection, passed through
+                to `_extract_dtype()`.  The shared default dict is only read,
+                never mutated.
+            local_row_counts: Number of rows to allocate for each table.  If
+                not specified, each table's global row count is used.
+
+        Returns:
+            One list of tensors per table, with one uninitialized CPU tensor
+            per name in `state_names()`, in that order.  A tensor is 1D with
+            `r` elements when its state size `d` is not greater than 1, and
+            has shape `(r, d)` otherwise.
+
+        Raises:
+            AssertionError: If `local_row_counts` is specified but its length
+                differs from `embedding_specs`, or any count is not positive.
+        """
+        if local_row_counts is None:
+            # If local_row_counts is not specified, then we assume that the
+            # local row count for each table is the same as the global row
+            # count.  A comprehension is used instead of zip(*embedding_specs),
+            # which fails to unpack when embedding_specs is empty
+            local_row_counts = [rows for (rows, _) in embedding_specs]
+        else:
+            # Else, check that the local row count for each table is set
+            assert len(local_row_counts) == len(embedding_specs), (
+                f"local_row_counts has {len(local_row_counts)} entries for "
+                f"{len(embedding_specs)} tables"
+            )
+            for i, r in enumerate(local_row_counts):
+                assert r > 0, f"local_row_counts for table {i} is not set"
+
+        opt_states_set: List[List[torch.Tensor]] = []
+
+        # Walk the tables together with their local row counts
+        for r, (_, D) in zip(local_row_counts, embedding_specs):
+            # Set up the table of state names to state sizes, ordered by their
+            # memory layout
+            state_size_table = self.state_size_table(D)
+            ordered_state_sizes = [(k, state_size_table[k]) for k in self.state_names()]
+
+            # Create the optimizer states for this table
+            opt_states = [
+                torch.empty(
+                    # If the state size is 1, then fix tensor to 1D to be
+                    # consistent with training.py code
+                    (r, d) if d > 1 else r,
+                    dtype=self._extract_dtype(optimizer_state_dtypes, state_name),
+                    device="cpu",
+                )
+                for state_name, d in ordered_state_sizes
+            ]
+
+            opt_states_set.append(opt_states)
+
+        return opt_states_set
+
     def 
ssd_state_splits(
         self,
         embedding_specs: List[Tuple[int, int]],  # Tuple of (rows, dims)
diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
index efd82088e8..f909c458db 100644
--- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
+++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -76,7 +76,7 @@ class IterData:
 
 @dataclass
 class KVZCHCachedData:
-    cached_optimizer_state_per_table: List[torch.Tensor]
+    cached_optimizer_states_per_table: List[List[torch.Tensor]]
     cached_weight_tensor_per_table: List[torch.Tensor]
     cached_id_tensor_per_table: List[torch.Tensor]
     cached_bucket_splits: List[torch.Tensor]
@@ -2393,18 +2393,10 @@ def split_optimizer_states(
             # init for checkpointing loading
             assert (
                 self._cached_kvzch_data is not None
-                and self._cached_kvzch_data.cached_optimizer_state_per_table
+                and self._cached_kvzch_data.cached_optimizer_states_per_table
             ), "optimizer state is not initialized for load checkpointing"
 
-            # NOTE: This is a temporary hack to have split_optimizer_states return a
-            # List[List[Tensor]] instead of List[Tensor] to match the behavior of
-            # _split_optimizer_states_non_kv_zch. This should be removed after
-            # proper support for multiple optimizers is added for the
-            # enable_optimizer_offloading=True case.
-            return [
-                [opt]
-                for opt in self._cached_kvzch_data.cached_optimizer_state_per_table
-            ]
+            return self._cached_kvzch_data.cached_optimizer_states_per_table
 
         logging.info(
             f"split_optimizer_states for KV ZCH: {no_snapshot=}, {should_flush=}"
@@ -2559,25 +2551,15 @@ def get_optimizer_state(
         should_flush: bool = False,
     ) -> List[Dict[str, torch.Tensor]]:
         """
-        Returns a list of optimizer states split by table. So far, we only support EXACT_ROWWISE_ADAGRAD
-        so only momentum1 state is returned.
+        Returns a list of dictionaries of optimizer states split by table.
         """
-        states_list = self.split_optimizer_states(
+        states_list: List[List[Tensor]] = self.split_optimizer_states(
             sorted_id_tensor=sorted_id_tensor,
             no_snapshot=no_snapshot,
             should_flush=should_flush,
         )
-
-        if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
-            keys = ["momentum1"]
-        elif self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
-            keys = ["momentum1", "momentum2"]
-        else:
-            raise NotImplementedError(
-                f"Getting optimizer states is not supported for {self.optimizer}"
-            )
-
-        return [dict(zip(keys, states)) for states in states_list]
+        state_names = self.optimizer.state_names()
+        return [dict(zip(state_names, states)) for states in states_list]
 
     @torch.jit.export
     def debug_split_embedding_weights(self) -> List[torch.Tensor]:
@@ -2829,6 +2811,94 @@ def split_embedding_weights(
 
         return (pmt_splits, bucket_sorted_id_splits, active_id_cnt_per_bucket_split)
 
+    @torch.jit.ignore
+    def _apply_state_dict_w_offloading(self) -> None:
+        """Applies the cached per-table state via streaming writes."""
+        rows = [r for (r, _) in self.embedding_specs]  # row count per table; [] if no tables
+        # Cumulative row counts per table for rowwise states
+        row_count_cumsum: List[int] = [0] + list(itertools.accumulate(rows))
+
+        for t, _ in enumerate(self.embedding_specs):
+            # pyre-ignore [16]
+            bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
+            # pyre-ignore [16]
+            bucket_size = self.kv_zch_params.bucket_sizes[t]
+            row_offset = row_count_cumsum[t] - bucket_id_start * bucket_size
+
+            # pyre-ignore [16]
+            weight_state = self._cached_kvzch_data.cached_weight_tensor_per_table[t]
+            # pyre-ignore [16]
+            opt_states = self._cached_kvzch_data.cached_optimizer_states_per_table[t]
+
+            self.streaming_write_weight_and_id_per_table(
+                weight_state,
+                opt_states,
+                # pyre-ignore [16]
+                self._cached_kvzch_data.cached_id_tensor_per_table[t],
+                row_offset,
+            )
+            self._cached_kvzch_data.cached_weight_tensor_per_table[t] = None
+            self._cached_kvzch_data.cached_optimizer_states_per_table[t] = None
+
+    @torch.jit.ignore
+    def _apply_state_dict_no_offloading(self) -> None:
+        """Applies the cached per-table state to momentum*_dev and ssd_db."""
+        rows = [r for (r, _) in self.embedding_specs]  # row count per table; [] if no tables
+        # Cumulative row counts per table for rowwise states
+        row_count_cumsum: List[int] = [0] + list(itertools.accumulate(rows))
+
+        def copy_optimizer_state_(dst: Tensor, src: Tensor, indices: Tensor) -> None:
+            device = dst.device
+            dst.index_put_(
+                indices=(
+                    # indices is expected to be a tuple of Tensors, not Tensor
+                    indices.to(device).view(-1),
+                ),
+                values=src.to(device),
+            )
+
+        for t, _ in enumerate(rows):
+            # pyre-ignore [16]
+            bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
+            # pyre-ignore [16]
+            bucket_size = self.kv_zch_params.bucket_sizes[t]
+            row_offset = row_count_cumsum[t] - bucket_id_start * bucket_size
+
+            # pyre-ignore [16]
+            weights = self._cached_kvzch_data.cached_weight_tensor_per_table[t]
+            # pyre-ignore [16]
+            ids = self._cached_kvzch_data.cached_id_tensor_per_table[t]
+            local_ids = ids + row_offset
+
+            logging.info(
+                f"applying sd for table {t} without optimizer offloading, local_ids is {local_ids}"
+            )
+            # pyre-ignore [16]
+            opt_states = self._cached_kvzch_data.cached_optimizer_states_per_table[t]
+
+            # Set up the plan for copying optimizer states over
+            if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
+                mapping = [(opt_states[0], self.momentum1_dev)]
+            elif self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
+                mapping = [
+                    (opt_states[0], self.momentum1_dev),
+                    (opt_states[1], self.momentum2_dev),
+                ]
+            else:
+                mapping = []  # other optimizers: no optimizer state is copied
+
+            # Execute the plan and copy the optimizer states over
+            for src, dst in mapping:
+                copy_optimizer_state_(dst, src, local_ids)  # pyre-ignore [6]
+
+            self.ssd_db.set_cuda(
+                local_ids.view(-1),
+                weights,
+                torch.as_tensor(local_ids.size(0)),
+                1,
+                False,
+            )
+
     @torch.jit.ignore
     def apply_state_dict(self) -> None:
         if self.backend_return_whole_row:
@@ -2844,7 +2914,7 @@ def apply_state_dict(self) -> None:
         assert self.kv_zch_params is not None, "apply_state_dict supports KV ZCH only"
         assert (
             self._cached_kvzch_data is not None
-            and 
self._cached_kvzch_data.cached_optimizer_state_per_table is not None + and self._cached_kvzch_data.cached_optimizer_states_per_table is not None ), "optimizer state is not initialized for load checkpointing" assert ( self._cached_kvzch_data.cached_weight_tensor_per_table is not None @@ -2855,51 +2925,11 @@ def apply_state_dict(self) -> None: # optimizer state, round to the nearest 4 # optimizer_dim = self.optimizer.optimizer_state_size_dim(dtype) # apply weight and optimizer state per table - table_offset = 0 - for i, (emb_height, _) in enumerate(self.embedding_specs): - # pyre-ignore [16] - bucket_id_start, _ = self.kv_zch_params.bucket_offsets[i] - # pyre-ignore [16] - bucket_size = self.kv_zch_params.bucket_sizes[i] - row_offset = table_offset - bucket_id_start * bucket_size + if self.enable_optimizer_offloading: + self._apply_state_dict_w_offloading() + else: + self._apply_state_dict_no_offloading() - if self.enable_optimizer_offloading: - # pyre-ignore [16] - weight_state = self._cached_kvzch_data.cached_weight_tensor_per_table[i] - # pyre-ignore [16] - opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i] - self.streaming_write_weight_and_id_per_table( - weight_state, - [opt_state], - # pyre-ignore [16] - self._cached_kvzch_data.cached_id_tensor_per_table[i], - row_offset, - ) - self._cached_kvzch_data.cached_weight_tensor_per_table[i] = None - self._cached_kvzch_data.cached_optimizer_state_per_table[i] = None - else: - weight = self._cached_kvzch_data.cached_weight_tensor_per_table[i] - id = self._cached_kvzch_data.cached_id_tensor_per_table[i] - local_id = id + row_offset - logging.info( - f"applying sd for table {i} without optimizer offloading, local_id is {local_id}" - ) - opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i] - t_device = self.momentum1_dev.device - self.momentum1_dev.index_put_( - indices=( - local_id.to(t_device).view(-1), - ), # expects tuple of tensors - values=opt_state.to(t_device), - ) - 
self.ssd_db.set_cuda( - local_id.view(-1), - weight, - torch.as_tensor(local_id.size(0)), - 1, - False, - ) - table_offset += emb_height self.clear_cache() @torch.jit.ignore @@ -3001,22 +3031,33 @@ def enable_load_state_dict_mode(self) -> None: dtype = self.weights_precision.as_dtype() self._cached_kvzch_data = KVZCHCachedData([], [], [], []) - for i, (_, emb_dim) in enumerate(self.embedding_specs): - # for checkpointing loading, we need to store the weight and id tensor temporarily in memory + + for i, _ in enumerate(self.embedding_specs): + # For checkpointing loading, we need to store the weight and id + # tensor temporarily in memory. First check that the local_weight_counts + # are properly set before even initializing the optimizer states assert ( self.local_weight_counts[i] > 0 ), f"local_weight_counts for table {i} is not set" + + # pyre-ignore [16] + self._cached_kvzch_data.cached_optimizer_states_per_table = ( + self.optimizer.empty_states( + self.embedding_specs, + self.optimizer_state_dtypes, + self.local_weight_counts, + ) + ) + + for i, (_, emb_dim) in enumerate(self.embedding_specs): # pyre-ignore [16] bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i] rows = self.local_weight_counts[i] weight_state = torch.empty(rows, emb_dim, dtype=dtype, device="cpu") - opt_state = torch.empty(rows, dtype=torch.float32, device="cpu") # pyre-ignore [16] self._cached_kvzch_data.cached_weight_tensor_per_table.append(weight_state) - # pyre-ignore [16] - self._cached_kvzch_data.cached_optimizer_state_per_table.append(opt_state) logging.info( - f"for checkpoint loading, table {i}, weight_state shape is {weight_state.shape}, opt_state shape is {opt_state.shape}" + f"for checkpoint loading, table {i}, weight_state shape is {weight_state.shape}" ) id_tensor = torch.zeros((rows, 1), dtype=torch.int64, device="cpu") # pyre-ignore [16] diff --git a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py 
b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py index 07117d74aa..f2a892d359 100644 --- a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py +++ b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py @@ -2477,7 +2477,7 @@ def test_apply_kv_state_dict( emb2.local_weight_counts = [ids.numel() for ids in bucket_asc_ids_list] emb2.enable_load_state_dict_mode() self.assertIsNotNone(emb2._cached_kvzch_data) - for i in range(len(emb.embedding_specs)): + for i, _ in enumerate(emb.embedding_specs): # pyre-ignore [16] emb2._cached_kvzch_data.cached_weight_tensor_per_table[i].copy_( # pyre-fixme[16]: Undefined attribute: Item `torch._tensor.Tensor` of `typing.Uni... @@ -2487,7 +2487,7 @@ def test_apply_kv_state_dict( # EXACT_ROWWISE_ADAGRAD optimizer. The test in general should # be upgraded in the future to support multiple optimizers # pyre-ignore [16] - emb2._cached_kvzch_data.cached_optimizer_state_per_table[i].copy_( + emb2._cached_kvzch_data.cached_optimizer_states_per_table[i][0].copy_( split_optimizer_states[i][0] ) # pyre-ignore [16]