Handle fetch optimizer states for the KV ZCH load state dict case #4512

Status: Open · wants to merge 1 commit into base: main
48 changes: 48 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py
@@ -150,6 +150,54 @@ def byte_offsets_along_row(
else:
return {}

def empty_states(
self,
embedding_specs: List[Tuple[int, int]], # Tuple of (rows, dims)
optimizer_state_dtypes: Dict[str, "SparseType"] = {}, # noqa: B006
local_row_counts: Optional[List[int]] = None,
) -> List[List[torch.Tensor]]:
"""
Creates sets of empty tensors per table to hold optimizer states based
on the specified optimizer type, state dtypes, embedding specs, and
(optionally) local row counts.
"""
if local_row_counts is None:
# If local_row_counts is not specified, then we assume that the
# local row count for each table is the same as the global row count
(local_row_counts, _) = zip(*embedding_specs)
else:
# Else, check that the local row count for each table is set
assert len(local_row_counts) == len(embedding_specs)
for i, r in enumerate(local_row_counts):
assert r > 0, f"local_row_counts for table {i} is not set"

opt_states_set: List[List[torch.Tensor]] = []

for i, (_, D) in enumerate(embedding_specs):
# Get the local row count for this table
r = local_row_counts[i]

# Set up the table of state names to state sizes, ordered by their
# memory layout
state_size_table = self.state_size_table(D)
ordered_state_sizes = [(k, state_size_table[k]) for k in self.state_names()]

# Create the optimizer states for this table
opt_states = [
torch.empty(
# If the state size is 1, then fix tensor to 1D to be
# consistent with training.py code
(r, d) if d > 1 else r,
dtype=self._extract_dtype(optimizer_state_dtypes, state_name),
device="cpu",
)
for state_name, d in ordered_state_sizes
]

opt_states_set.append(opt_states)

return opt_states_set
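A minimal usage sketch of the new `empty_states` API, mirroring the call site this PR adds in `enable_load_state_dict_mode` below. The `EmbOptimType` import path and the rowwise-Adagrad state layout (a single 1-element `momentum1` per row) are assumptions drawn from the surrounding file, not guarantees of this diff:

```python
from fbgemm_gpu.split_embedding_configs import EmbOptimType

# Two tables: (global rows, dims); this rank holds 10 and 25 local rows
embedding_specs = [(1000, 128), (2000, 64)]
local_row_counts = [10, 25]

opt = EmbOptimType.EXACT_ROWWISE_ADAGRAD
states = opt.empty_states(embedding_specs, local_row_counts=local_row_counts)

# Rowwise Adagrad keeps one scalar state per row, so each table gets a
# single 1-D momentum1 tensor sized by its local row count
assert states[0][0].shape == (10,)
assert states[1][0].shape == (25,)
```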

def ssd_state_splits(
self,
embedding_specs: List[Tuple[int, int]], # Tuple of (rows, dims)
193 changes: 117 additions & 76 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
Expand Up @@ -76,7 +76,7 @@ class IterData:

@dataclass
class KVZCHCachedData:
cached_optimizer_state_per_table: List[torch.Tensor]
cached_optimizer_states_per_table: List[List[torch.Tensor]]
cached_weight_tensor_per_table: List[torch.Tensor]
cached_id_tensor_per_table: List[torch.Tensor]
cached_bucket_splits: List[torch.Tensor]
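The renamed field widens the cached layout from one tensor per table to a list of tensors per table, one entry per optimizer state; schematically:

```python
# Before: cached_optimizer_state_per_table[t]   -> Tensor         (momentum1 only)
# After:  cached_optimizer_states_per_table[t]  -> [Tensor, ...]  (momentum1, momentum2, ...)
```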
@@ -2393,18 +2393,10 @@ def split_optimizer_states(
# init for checkpointing loading
assert (
self._cached_kvzch_data is not None
and self._cached_kvzch_data.cached_optimizer_state_per_table
and self._cached_kvzch_data.cached_optimizer_states_per_table
), "optimizer state is not initialized for load checkpointing"

# NOTE: This is a temporary hack to have split_optimizer_states return a
# List[List[Tensor]] instead of List[Tensor] to match the behavior of
# _split_optimizer_states_non_kv_zch. This should be removed after
# proper support for multiple optimizers is added for the
# enable_optimizer_offloading=True case.
return [
[opt]
for opt in self._cached_kvzch_data.cached_optimizer_state_per_table
]
return self._cached_kvzch_data.cached_optimizer_states_per_table

logging.info(
f"split_optimizer_states for KV ZCH: {no_snapshot=}, {should_flush=}"
@@ -2559,25 +2551,15 @@ def get_optimizer_state(
should_flush: bool = False,
) -> List[Dict[str, torch.Tensor]]:
"""
Returns a list of optimizer states split by table. So far, we only support EXACT_ROWWISE_ADAGRAD
so only momentum1 state is returned.
Returns a list of dictionaries of optimizer states split by table.
"""
states_list = self.split_optimizer_states(
states_list: List[List[Tensor]] = self.split_optimizer_states(
sorted_id_tensor=sorted_id_tensor,
no_snapshot=no_snapshot,
should_flush=should_flush,
)

if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
keys = ["momentum1"]
elif self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
keys = ["momentum1", "momentum2"]
else:
raise NotImplementedError(
f"Getting optimizer states is not supported for {self.optimizer}"
)

return [dict(zip(keys, states)) for states in states_list]
state_names = self.optimizer.state_names()
return [dict(zip(state_names, states)) for states in states_list]
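With the keys now taken from `state_names()`, the per-table dictionaries extend naturally to multi-state optimizers. A small sketch of the resulting structure, with shapes assumed for illustration (elementwise `momentum1` of shape `(rows, dim)` and rowwise `momentum2` of shape `(rows,)` for partial rowwise Adam):

```python
import torch

# One table's entry in states_list for PARTIAL_ROWWISE_ADAM
state_names = ["momentum1", "momentum2"]
states = [torch.empty(10, 64), torch.empty(10)]  # (rows, dim) and (rows,)
table_dict = dict(zip(state_names, states))
assert set(table_dict) == {"momentum1", "momentum2"}
```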

@torch.jit.export
def debug_split_embedding_weights(self) -> List[torch.Tensor]:
@@ -2829,6 +2811,94 @@ def split_embedding_weights(

return (pmt_splits, bucket_sorted_id_splits, active_id_cnt_per_bucket_split)

@torch.jit.ignore
def _apply_state_dict_w_offloading(self) -> None:
# Row count per table
(rows, _) = zip(*self.embedding_specs)
# Cumulative row counts per table for rowwise states
row_count_cumsum: List[int] = [0] + list(itertools.accumulate(rows))

for t, _ in enumerate(self.embedding_specs):
# pyre-ignore [16]
bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
# pyre-ignore [16]
bucket_size = self.kv_zch_params.bucket_sizes[t]
row_offset = row_count_cumsum[t] - bucket_id_start * bucket_size

# pyre-ignore [16]
weight_state = self._cached_kvzch_data.cached_weight_tensor_per_table[t]
# pyre-ignore [16]
opt_states = self._cached_kvzch_data.cached_optimizer_states_per_table[t]

self.streaming_write_weight_and_id_per_table(
weight_state,
opt_states,
# pyre-ignore [16]
self._cached_kvzch_data.cached_id_tensor_per_table[t],
row_offset,
)
self._cached_kvzch_data.cached_weight_tensor_per_table[t] = None
self._cached_kvzch_data.cached_optimizer_states_per_table[t] = None
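This `row_offset` arithmetic (repeated in the no-offloading helper below) maps checkpoint ids into the backing store's cumulative row space. A worked example under assumed numbers; none of these values come from the diff:

```python
# Table t starts at global row 1000 (row_count_cumsum[t]); this rank's
# shard starts at bucket 7, and each bucket holds 100 ids
row_count_cumsum_t = 1000
bucket_id_start = 7
bucket_size = 100

row_offset = row_count_cumsum_t - bucket_id_start * bucket_size  # 300

# Checkpoint id 750 (bucket 7, slot 50) lands 50 rows into this shard,
# offset by the table base: 1000 + 50 == 1050
assert 750 + row_offset == 1050
```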

@torch.jit.ignore
def _apply_state_dict_no_offloading(self) -> None:
# Row count per table
(rows, _) = zip(*self.embedding_specs)
# Cumulative row counts per table for rowwise states
row_count_cumsum: List[int] = [0] + list(itertools.accumulate(rows))

def copy_optimizer_state_(dst: Tensor, src: Tensor, indices: Tensor) -> None:
device = dst.device
dst.index_put_(
indices=(
# indices is expected to be a tuple of Tensors, not Tensor
indices.to(device).view(-1),
),
values=src.to(device),
)

for t, _ in enumerate(rows):
# pyre-ignore [16]
bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
# pyre-ignore [16]
bucket_size = self.kv_zch_params.bucket_sizes[t]
row_offset = row_count_cumsum[t] - bucket_id_start * bucket_size

# pyre-ignore [16]
weights = self._cached_kvzch_data.cached_weight_tensor_per_table[t]
# pyre-ignore [16]
ids = self._cached_kvzch_data.cached_id_tensor_per_table[t]
local_ids = ids + row_offset

            logging.info(
                f"applying state_dict for table {t} without optimizer offloading, local_ids is {local_ids}"
            )
# pyre-ignore [16]
opt_states = self._cached_kvzch_data.cached_optimizer_states_per_table[t]

# Set up the plan for copying optimizer states over
if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
mapping = [(opt_states[0], self.momentum1_dev)]
elif self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
mapping = [
(opt_states[0], self.momentum1_dev),
(opt_states[1], self.momentum2_dev),
]
            else:
                # No optimizer state copy plan is defined for other optimizer types
                mapping = []

            # Execute the plan and copy the optimizer states over
            for src, dst in mapping:
                # pyre-ignore [6]
                copy_optimizer_state_(dst, src, local_ids)

self.ssd_db.set_cuda(
local_ids.view(-1),
weights,
torch.as_tensor(local_ids.size(0)),
1,
False,
)
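A self-contained toy run of the `index_put_` scatter performed by `copy_optimizer_state_` above; shapes and values are illustrative only:

```python
import torch

dst = torch.zeros(6)                 # stands in for momentum1_dev
src = torch.tensor([1.0, 2.0, 3.0])  # cached optimizer state rows
ids = torch.tensor([[4], [0], [2]])  # local ids, shape (n, 1)

# index_put_ expects a tuple of index tensors; each element of src is
# written into dst at the matching id
dst.index_put_(indices=(ids.view(-1),), values=src)
print(dst)  # tensor([2., 0., 3., 0., 1., 0.])
```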

@torch.jit.ignore
def apply_state_dict(self) -> None:
if self.backend_return_whole_row:
@@ -2844,7 +2914,7 @@ def apply_state_dict(self) -> None:
assert self.kv_zch_params is not None, "apply_state_dict supports KV ZCH only"
assert (
self._cached_kvzch_data is not None
and self._cached_kvzch_data.cached_optimizer_state_per_table is not None
and self._cached_kvzch_data.cached_optimizer_states_per_table is not None
), "optimizer state is not initialized for load checkpointing"
assert (
self._cached_kvzch_data.cached_weight_tensor_per_table is not None
@@ -2855,51 +2925,11 @@ def apply_state_dict(self) -> None:
# optimizer state, round to the nearest 4
# optimizer_dim = self.optimizer.optimizer_state_size_dim(dtype)
# apply weight and optimizer state per table
table_offset = 0
for i, (emb_height, _) in enumerate(self.embedding_specs):
# pyre-ignore [16]
bucket_id_start, _ = self.kv_zch_params.bucket_offsets[i]
# pyre-ignore [16]
bucket_size = self.kv_zch_params.bucket_sizes[i]
row_offset = table_offset - bucket_id_start * bucket_size
if self.enable_optimizer_offloading:
self._apply_state_dict_w_offloading()
else:
self._apply_state_dict_no_offloading()

if self.enable_optimizer_offloading:
# pyre-ignore [16]
weight_state = self._cached_kvzch_data.cached_weight_tensor_per_table[i]
# pyre-ignore [16]
opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i]
self.streaming_write_weight_and_id_per_table(
weight_state,
[opt_state],
# pyre-ignore [16]
self._cached_kvzch_data.cached_id_tensor_per_table[i],
row_offset,
)
self._cached_kvzch_data.cached_weight_tensor_per_table[i] = None
self._cached_kvzch_data.cached_optimizer_state_per_table[i] = None
else:
weight = self._cached_kvzch_data.cached_weight_tensor_per_table[i]
id = self._cached_kvzch_data.cached_id_tensor_per_table[i]
local_id = id + row_offset
logging.info(
f"applying sd for table {i} without optimizer offloading, local_id is {local_id}"
)
opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i]
t_device = self.momentum1_dev.device
self.momentum1_dev.index_put_(
indices=(
local_id.to(t_device).view(-1),
), # expects tuple of tensors
values=opt_state.to(t_device),
)
self.ssd_db.set_cuda(
local_id.view(-1),
weight,
torch.as_tensor(local_id.size(0)),
1,
False,
)
table_offset += emb_height
self.clear_cache()

@torch.jit.ignore
@@ -3001,22 +3031,33 @@ def enable_load_state_dict_mode(self) -> None:

dtype = self.weights_precision.as_dtype()
self._cached_kvzch_data = KVZCHCachedData([], [], [], [])
for i, (_, emb_dim) in enumerate(self.embedding_specs):
# for checkpointing loading, we need to store the weight and id tensor temporarily in memory

for i, _ in enumerate(self.embedding_specs):
            # For checkpoint loading, we need to store the weight and id
            # tensors temporarily in memory. First check that local_weight_counts
            # is properly set before initializing the optimizer states
assert (
self.local_weight_counts[i] > 0
), f"local_weight_counts for table {i} is not set"

# pyre-ignore [16]
self._cached_kvzch_data.cached_optimizer_states_per_table = (
self.optimizer.empty_states(
self.embedding_specs,
self.optimizer_state_dtypes,
self.local_weight_counts,
)
)

for i, (_, emb_dim) in enumerate(self.embedding_specs):
# pyre-ignore [16]
bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
rows = self.local_weight_counts[i]
weight_state = torch.empty(rows, emb_dim, dtype=dtype, device="cpu")
opt_state = torch.empty(rows, dtype=torch.float32, device="cpu")
# pyre-ignore [16]
self._cached_kvzch_data.cached_weight_tensor_per_table.append(weight_state)
# pyre-ignore [16]
self._cached_kvzch_data.cached_optimizer_state_per_table.append(opt_state)
logging.info(
f"for checkpoint loading, table {i}, weight_state shape is {weight_state.shape}, opt_state shape is {opt_state.shape}"
f"for checkpoint loading, table {i}, weight_state shape is {weight_state.shape}"
)
id_tensor = torch.zeros((rows, 1), dtype=torch.int64, device="cpu")
# pyre-ignore [16]
4 changes: 2 additions & 2 deletions fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py
@@ -2477,7 +2477,7 @@ def test_apply_kv_state_dict(
emb2.local_weight_counts = [ids.numel() for ids in bucket_asc_ids_list]
emb2.enable_load_state_dict_mode()
self.assertIsNotNone(emb2._cached_kvzch_data)
for i in range(len(emb.embedding_specs)):
for i, _ in enumerate(emb.embedding_specs):
# pyre-ignore [16]
emb2._cached_kvzch_data.cached_weight_tensor_per_table[i].copy_(
# pyre-fixme[16]: Undefined attribute: Item `torch._tensor.Tensor` of `typing.Uni...
@@ -2487,7 +2487,7 @@ def test_apply_kv_state_dict(
# EXACT_ROWWISE_ADAGRAD optimizer. The test in general should
# be upgraded in the future to support multiple optimizers
# pyre-ignore [16]
emb2._cached_kvzch_data.cached_optimizer_state_per_table[i].copy_(
emb2._cached_kvzch_data.cached_optimizer_states_per_table[i][0].copy_(
split_optimizer_states[i][0]
)
# pyre-ignore [16]