Skip to content

Commit 1b04531

Browse files
q10, facebook-github-bot
authored and committed
Handle fetch optimizer states for the KV ZCH load state dict case
Summary: X-link: facebookresearch/FBGEMM#1561 This diff updates `KVZCHCachedData` to hold multiple optimizer states per table in `cached_optimizer_states_per_table`, and updates `apply_state_dict` to handle writing out multiple optimizer states per table row to the cache. This is needed for enabling other optimizers to work with SSD TBE, such as Partial Rowwise Adam. There are 4 cases to handle when attempting to fetch the split optimizer states: 1. The no-KV ZCH case. 2. The KV ZCH case, where `self.load_state_dict` is `True` (i.e. fall back to `self._cached_kvzch_data`). 3. The KV ZCH case, where `self.load_state_dict` is `False` and `self.enable_optimizer_offloading` is `False`. 4. The KV ZCH case, where `self.load_state_dict` is `False` and `self.enable_optimizer_offloading` is `True`. This diff completes the handling of returning optimizer states for the KV ZCH case where `self.load_state_dict` is `True` (case 2). Reviewed By: emlin Differential Revision: D77771359
1 parent 8594beb commit 1b04531

File tree

3 files changed

+167
-78
lines changed

3 files changed

+167
-78
lines changed

fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,54 @@ def byte_offsets_along_row(
150150
else:
151151
return {}
152152

153+
def empty_states(
    self,
    embedding_specs: List[Tuple[int, int]],  # Tuple of (rows, dims)
    optimizer_state_dtypes: Dict[str, "SparseType"] = {},  # noqa: B006
    local_row_counts: Optional[List[int]] = None,
) -> List[List[torch.Tensor]]:
    """
    Creates sets of empty tensors per table to hold optimizer states based
    on the specified optimizer type, state dtypes, embedding specs, and
    (optionally) local row counts.

    Args:
        embedding_specs: One (rows, dims) tuple per table.
        optimizer_state_dtypes: Per-state dtype overrides; the dtype of each
            state is resolved by `self._extract_dtype`. The default dict is
            only read here, never mutated.
        local_row_counts: Optional per-table row counts. When omitted, the
            row count from `embedding_specs` is used for each table.

    Returns:
        One list per table (empty if there are no tables), each holding one
        uninitialized CPU tensor per optimizer state, in the order given by
        `self.state_names()`. A state whose size is greater than 1 has shape
        (rows, size); otherwise the tensor is 1D with shape (rows,).

    Raises:
        AssertionError: If `local_row_counts` is given but its length does
            not match `embedding_specs`, or any of its entries is not
            positive.
    """
    if local_row_counts is None:
        # If local_row_counts is not specified, then we assume that the
        # local row count for each table is the same as the global row count.
        # A comprehension is used instead of unpacking zip(*embedding_specs)
        # because the latter raises ValueError when there are no tables.
        local_row_counts = [rows for (rows, _) in embedding_specs]
    else:
        # Else, check that the local row count for each table is set
        assert len(local_row_counts) == len(embedding_specs), (
            f"expected {len(embedding_specs)} local row counts, "
            f"got {len(local_row_counts)}"
        )
        for i, r in enumerate(local_row_counts):
            assert r > 0, f"local_row_counts for table {i} is not set"

    # The state names (ordered by their memory layout) do not depend on the
    # table, so look them up once
    state_names = self.state_names()

    opt_states_set: List[List[torch.Tensor]] = []

    for r, (_, D) in zip(local_row_counts, embedding_specs):
        # Table of state names to state sizes for this embedding dimension
        state_size_table = self.state_size_table(D)

        # Create the optimizer states for this table
        opt_states = [
            torch.empty(
                # If the state size is 1, then fix tensor to 1D to be
                # consistent with training.py code
                (r, state_size_table[state_name])
                if state_size_table[state_name] > 1
                else r,
                dtype=self._extract_dtype(optimizer_state_dtypes, state_name),
                device="cpu",
            )
            for state_name in state_names
        ]

        opt_states_set.append(opt_states)

    return opt_states_set
200+
153201
def ssd_state_splits(
154202
self,
155203
embedding_specs: List[Tuple[int, int]], # Tuple of (rows, dims)

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 117 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class IterData:
7676

7777
@dataclass
7878
class KVZCHCachedData:
79-
cached_optimizer_state_per_table: List[torch.Tensor]
79+
cached_optimizer_states_per_table: List[List[torch.Tensor]]
8080
cached_weight_tensor_per_table: List[torch.Tensor]
8181
cached_id_tensor_per_table: List[torch.Tensor]
8282
cached_bucket_splits: List[torch.Tensor]
@@ -2393,18 +2393,10 @@ def split_optimizer_states(
23932393
# init for checkpointing loading
23942394
assert (
23952395
self._cached_kvzch_data is not None
2396-
and self._cached_kvzch_data.cached_optimizer_state_per_table
2396+
and self._cached_kvzch_data.cached_optimizer_states_per_table
23972397
), "optimizer state is not initialized for load checkpointing"
23982398

2399-
# NOTE: This is a temporary hack to have split_optimizer_states return a
2400-
# List[List[Tensor]] instead of List[Tensor] to match the behavior of
2401-
# _split_optimizer_states_non_kv_zch. This should be removed after
2402-
# proper support for multiple optimizers is added for the
2403-
# enable_optimizer_offloading=True case.
2404-
return [
2405-
[opt]
2406-
for opt in self._cached_kvzch_data.cached_optimizer_state_per_table
2407-
]
2399+
return self._cached_kvzch_data.cached_optimizer_states_per_table
24082400

24092401
logging.info(
24102402
f"split_optimizer_states for KV ZCH: {no_snapshot=}, {should_flush=}"
@@ -2559,25 +2551,15 @@ def get_optimizer_state(
25592551
should_flush: bool = False,
25602552
) -> List[Dict[str, torch.Tensor]]:
25612553
"""
2562-
Returns a list of optimizer states split by table. So far, we only support EXACT_ROWWISE_ADAGRAD
2563-
so only momentum1 state is returned.
2554+
Returns a list of dictionaries of optimizer states split by table.
25642555
"""
2565-
states_list = self.split_optimizer_states(
2556+
states_list: List[List[Tensor]] = self.split_optimizer_states(
25662557
sorted_id_tensor=sorted_id_tensor,
25672558
no_snapshot=no_snapshot,
25682559
should_flush=should_flush,
25692560
)
2570-
2571-
if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
2572-
keys = ["momentum1"]
2573-
elif self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
2574-
keys = ["momentum1", "momentum2"]
2575-
else:
2576-
raise NotImplementedError(
2577-
f"Getting optimizer states is not supported for {self.optimizer}"
2578-
)
2579-
2580-
return [dict(zip(keys, states)) for states in states_list]
2561+
state_names = self.optimizer.state_names()
2562+
return [dict(zip(state_names, states)) for states in states_list]
25812563

25822564
@torch.jit.export
25832565
def debug_split_embedding_weights(self) -> List[torch.Tensor]:
@@ -2829,6 +2811,94 @@ def split_embedding_weights(
28292811

28302812
return (pmt_splits, bucket_sorted_id_splits, active_id_cnt_per_bucket_split)
28312813

2814+
@torch.jit.ignore
def _apply_state_dict_w_offloading(self) -> None:
    """
    Writes the cached weights, optimizer states, and ids of every table
    through `streaming_write_weight_and_id_per_table`.

    After a table has been written, its cached weight tensor and cached
    optimizer states are replaced with None; the cached id tensor is left
    in place.
    """
    row_counts, _ = zip(*self.embedding_specs)
    # Running total of rows: entry t is the number of rows in tables before t
    table_row_starts: List[int] = [0, *itertools.accumulate(row_counts)]

    # The final running-total entry is the grand total, which is not the
    # start of any table, so it is left out of the iteration
    for table_idx, table_row_start in enumerate(table_row_starts[:-1]):
        # pyre-ignore [16]
        first_bucket_id, _ = self.kv_zch_params.bucket_offsets[table_idx]
        # pyre-ignore [16]
        rows_per_bucket = self.kv_zch_params.bucket_sizes[table_idx]
        # Offset handed to the streaming write together with the ids
        # NOTE(review): presumably maps checkpoint ids to local rows - confirm
        row_offset = table_row_start - first_bucket_id * rows_per_bucket

        self.streaming_write_weight_and_id_per_table(
            # pyre-ignore [16]
            self._cached_kvzch_data.cached_weight_tensor_per_table[table_idx],
            # pyre-ignore [16]
            self._cached_kvzch_data.cached_optimizer_states_per_table[table_idx],
            # pyre-ignore [16]
            self._cached_kvzch_data.cached_id_tensor_per_table[table_idx],
            row_offset,
        )

        # Drop the cached references for this table now that it is written
        self._cached_kvzch_data.cached_weight_tensor_per_table[table_idx] = None
        self._cached_kvzch_data.cached_optimizer_states_per_table[table_idx] = None
2842+
2843+
@torch.jit.ignore
def _apply_state_dict_no_offloading(self) -> None:
    """
    Applies the cached checkpoint data of every table when optimizer
    offloading is disabled.

    For each table, the cached ids are shifted by the table's row offset,
    the cached optimizer states are scattered into the optimizer state
    tensors held on this module (`momentum1_dev`, and `momentum2_dev` for
    PARTIAL_ROWWISE_ADAM) at those shifted ids, and the cached weights are
    written via `self.ssd_db.set_cuda`.

    Optimizer states are only copied for EXACT_ROWWISE_ADAGRAD and
    PARTIAL_ROWWISE_ADAM. For any other optimizer a warning is logged and
    only the weights are written.
    """
    # Row count per table
    (rows, _) = zip(*self.embedding_specs)
    # Cumulative row counts per table for rowwise states
    row_count_cumsum: List[int] = [0] + list(itertools.accumulate(rows))

    def copy_optimizer_state_(dst: Tensor, src: Tensor, indices: Tensor) -> None:
        # Scatter the rows of src into dst at the given row indices, moving
        # both src and indices to dst's device first
        device = dst.device
        dst.index_put_(
            indices=(
                # indices is expected to be a tuple of Tensors, not Tensor
                indices.to(device).view(-1),
            ),
            values=src.to(device),
        )

    # Set up the plan for copying optimizer states over. The destinations
    # depend only on the optimizer type, so they are resolved once up front;
    # destination k receives the k-th cached optimizer state of each table.
    if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
        opt_state_dsts = [self.momentum1_dev]
    elif self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
        opt_state_dsts = [self.momentum1_dev, self.momentum2_dev]
    else:
        opt_state_dsts = []
        # Make the skipped optimizer states visible instead of dropping
        # them silently
        logging.warning(
            f"optimizer states are not applied for {self.optimizer} "
            "without optimizer offloading; only weights will be loaded"
        )

    for t, _ in enumerate(rows):
        # pyre-ignore [16]
        bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
        # pyre-ignore [16]
        bucket_size = self.kv_zch_params.bucket_sizes[t]
        row_offset = row_count_cumsum[t] - bucket_id_start * bucket_size

        # pyre-ignore [16]
        weights = self._cached_kvzch_data.cached_weight_tensor_per_table[t]
        # pyre-ignore [16]
        ids = self._cached_kvzch_data.cached_id_tensor_per_table[t]
        local_ids = ids + row_offset

        logging.info(
            f"applying sd for table {t} without optimizer offloading, local_ids is {local_ids}"
        )
        # pyre-ignore [16]
        opt_states = self._cached_kvzch_data.cached_optimizer_states_per_table[t]

        # Execute the plan and copy the optimizer states over. Indexing
        # opt_states explicitly (rather than zipping) keeps a too-short
        # state list an error instead of a silent truncation.
        for state_idx, dst in enumerate(opt_state_dsts):
            # pyre-ignore [6]
            copy_optimizer_state_(dst, opt_states[state_idx], local_ids)

        # NOTE(review): the meaning of the trailing `1, False` arguments is
        # defined by ssd_db.set_cuda and is not visible here
        self.ssd_db.set_cuda(
            local_ids.view(-1),
            weights,
            torch.as_tensor(local_ids.size(0)),
            1,
            False,
        )
2901+
28322902
@torch.jit.ignore
28332903
def apply_state_dict(self) -> None:
28342904
if self.backend_return_whole_row:
@@ -2844,7 +2914,7 @@ def apply_state_dict(self) -> None:
28442914
assert self.kv_zch_params is not None, "apply_state_dict supports KV ZCH only"
28452915
assert (
28462916
self._cached_kvzch_data is not None
2847-
and self._cached_kvzch_data.cached_optimizer_state_per_table is not None
2917+
and self._cached_kvzch_data.cached_optimizer_states_per_table is not None
28482918
), "optimizer state is not initialized for load checkpointing"
28492919
assert (
28502920
self._cached_kvzch_data.cached_weight_tensor_per_table is not None
@@ -2855,51 +2925,11 @@ def apply_state_dict(self) -> None:
28552925
# optimizer state, round to the nearest 4
28562926
# optimizer_dim = self.optimizer.optimizer_state_size_dim(dtype)
28572927
# apply weight and optimizer state per table
2858-
table_offset = 0
2859-
for i, (emb_height, _) in enumerate(self.embedding_specs):
2860-
# pyre-ignore [16]
2861-
bucket_id_start, _ = self.kv_zch_params.bucket_offsets[i]
2862-
# pyre-ignore [16]
2863-
bucket_size = self.kv_zch_params.bucket_sizes[i]
2864-
row_offset = table_offset - bucket_id_start * bucket_size
2928+
if self.enable_optimizer_offloading:
2929+
self._apply_state_dict_w_offloading()
2930+
else:
2931+
self._apply_state_dict_no_offloading()
28652932

2866-
if self.enable_optimizer_offloading:
2867-
# pyre-ignore [16]
2868-
weight_state = self._cached_kvzch_data.cached_weight_tensor_per_table[i]
2869-
# pyre-ignore [16]
2870-
opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i]
2871-
self.streaming_write_weight_and_id_per_table(
2872-
weight_state,
2873-
[opt_state],
2874-
# pyre-ignore [16]
2875-
self._cached_kvzch_data.cached_id_tensor_per_table[i],
2876-
row_offset,
2877-
)
2878-
self._cached_kvzch_data.cached_weight_tensor_per_table[i] = None
2879-
self._cached_kvzch_data.cached_optimizer_state_per_table[i] = None
2880-
else:
2881-
weight = self._cached_kvzch_data.cached_weight_tensor_per_table[i]
2882-
id = self._cached_kvzch_data.cached_id_tensor_per_table[i]
2883-
local_id = id + row_offset
2884-
logging.info(
2885-
f"applying sd for table {i} without optimizer offloading, local_id is {local_id}"
2886-
)
2887-
opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i]
2888-
t_device = self.momentum1_dev.device
2889-
self.momentum1_dev.index_put_(
2890-
indices=(
2891-
local_id.to(t_device).view(-1),
2892-
), # expects tuple of tensors
2893-
values=opt_state.to(t_device),
2894-
)
2895-
self.ssd_db.set_cuda(
2896-
local_id.view(-1),
2897-
weight,
2898-
torch.as_tensor(local_id.size(0)),
2899-
1,
2900-
False,
2901-
)
2902-
table_offset += emb_height
29032933
self.clear_cache()
29042934

29052935
@torch.jit.ignore
@@ -3001,22 +3031,33 @@ def enable_load_state_dict_mode(self) -> None:
30013031

30023032
dtype = self.weights_precision.as_dtype()
30033033
self._cached_kvzch_data = KVZCHCachedData([], [], [], [])
3004-
for i, (_, emb_dim) in enumerate(self.embedding_specs):
3005-
# for checkpointing loading, we need to store the weight and id tensor temporarily in memory
3034+
3035+
for i, _ in enumerate(self.embedding_specs):
3036+
# For checkpointing loading, we need to store the weight and id
3037+
# tensor temporarily in memory. First check that the local_weight_counts
3038+
# are properly set before even initializing the optimizer states
30063039
assert (
30073040
self.local_weight_counts[i] > 0
30083041
), f"local_weight_counts for table {i} is not set"
3042+
3043+
# pyre-ignore [16]
3044+
self._cached_kvzch_data.cached_optimizer_states_per_table = (
3045+
self.optimizer.empty_states(
3046+
self.embedding_specs,
3047+
self.optimizer_state_dtypes,
3048+
self.local_weight_counts,
3049+
)
3050+
)
3051+
3052+
for i, (_, emb_dim) in enumerate(self.embedding_specs):
30093053
# pyre-ignore [16]
30103054
bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
30113055
rows = self.local_weight_counts[i]
30123056
weight_state = torch.empty(rows, emb_dim, dtype=dtype, device="cpu")
3013-
opt_state = torch.empty(rows, dtype=torch.float32, device="cpu")
30143057
# pyre-ignore [16]
30153058
self._cached_kvzch_data.cached_weight_tensor_per_table.append(weight_state)
3016-
# pyre-ignore [16]
3017-
self._cached_kvzch_data.cached_optimizer_state_per_table.append(opt_state)
30183059
logging.info(
3019-
f"for checkpoint loading, table {i}, weight_state shape is {weight_state.shape}, opt_state shape is {opt_state.shape}"
3060+
f"for checkpoint loading, table {i}, weight_state shape is {weight_state.shape}"
30203061
)
30213062
id_tensor = torch.zeros((rows, 1), dtype=torch.int64, device="cpu")
30223063
# pyre-ignore [16]

fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2477,7 +2477,7 @@ def test_apply_kv_state_dict(
24772477
emb2.local_weight_counts = [ids.numel() for ids in bucket_asc_ids_list]
24782478
emb2.enable_load_state_dict_mode()
24792479
self.assertIsNotNone(emb2._cached_kvzch_data)
2480-
for i in range(len(emb.embedding_specs)):
2480+
for i, _ in enumerate(emb.embedding_specs):
24812481
# pyre-ignore [16]
24822482
emb2._cached_kvzch_data.cached_weight_tensor_per_table[i].copy_(
24832483
# pyre-fixme[16]: Undefined attribute: Item `torch._tensor.Tensor` of `typing.Uni...
@@ -2487,7 +2487,7 @@ def test_apply_kv_state_dict(
24872487
# EXACT_ROWWISE_ADAGRAD optimizer. The test in general should
24882488
# be upgraded in the future to support multiple optimizers
24892489
# pyre-ignore [16]
2490-
emb2._cached_kvzch_data.cached_optimizer_state_per_table[i].copy_(
2490+
emb2._cached_kvzch_data.cached_optimizer_states_per_table[i][0].copy_(
24912491
split_optimizer_states[i][0]
24922492
)
24932493
# pyre-ignore [16]

0 commit comments

Comments
 (0)