support get state dict and apply state dict #4145

Closed
wants to merge 1 commit into from
223 changes: 205 additions & 18 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -313,7 +313,7 @@ def __init__(
f"weights precision: {weights_precision}, "
f"output dtype: {output_dtype}, "
f"chunk size in bulk init: {bulk_init_chunk_size} bytes, backend_type: {backend_type}, "
f"zero_collision_config: {kv_zch_params}"
f"kv_zch_params: {kv_zch_params}"
)
self.register_buffer(
"lxu_cache_state",
@@ -1983,9 +1983,7 @@ def split_optimizer_states(
dtype=dtype,
row_offset=row_offset,
snapshot_handle=snapshot_handle,
materialized_shape=([sorted_id_tensor[t].size(0), emb_opt_dim]),
sorted_indices=sorted_id_tensor[t],
width_offset=pad4(emb_dim),
)
(
tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
@@ -2086,6 +2084,7 @@ def _may_create_snapshot_for_state_dict(
"""
Create a rocksdb snapshot if needed.
"""
start_time = time.time()
# Force device synchronize for now
torch.cuda.synchronize()
snapshot_handle = None
@@ -2094,7 +2093,13 @@
if not no_snapshot:
# Flush L1 and L2 caches
self.flush(force=should_flush)
logging.info(
f"flush latency for weight states: {(time.time() - start_time) * 1000} ms"
)
snapshot_handle = self.ssd_db.create_snapshot()
logging.info(
f"created snapshot for weight states: {snapshot_handle}, latency: {(time.time() - start_time) * 1000} ms"
)
elif self.backend_type == BackendType.DRAM:
self.flush(force=should_flush)
return snapshot_handle
@@ -2105,7 +2110,7 @@ def split_embedding_weights(
no_snapshot: bool = True,
should_flush: bool = False,
) -> Tuple[ # TODO: make this a NamedTuple for readability
List[PartiallyMaterializedTensor],
Union[List[PartiallyMaterializedTensor], List[torch.Tensor]],
Optional[List[torch.Tensor]],
Optional[List[torch.Tensor]],
]:
@@ -2134,6 +2139,17 @@ def split_embedding_weights(
)

dtype = self.weights_precision.as_dtype()
if self.load_state_dict and self.kv_zch_params:
# init for checkpoint loading
assert (
self._cached_kvzch_data is not None
), "weight id and bucket state are not initialized for load checkpointing"
return (
self._cached_kvzch_data.cached_weight_tensor_per_table,
self._cached_kvzch_data.cached_id_tensor_per_table,
self._cached_kvzch_data.cached_bucket_splits,
)
start_time = time.time()
pmt_splits = []
bucket_sorted_id_splits = [] if self.kv_zch_params else None
active_id_cnt_per_bucket_split = [] if self.kv_zch_params else None
@@ -2142,18 +2158,15 @@
for i, (emb_height, emb_dim) in enumerate(self.embedding_specs):
bucket_ascending_id_tensor = None
bucket_t = None
row_offset = table_offset
if self.kv_zch_params:
bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
# pyre-ignore
bucket_size = self.kv_zch_params.bucket_sizes[i]

# linearize with table offset
table_input_id_start = (
min(bucket_id_start * bucket_size, emb_height) + table_offset
)
table_input_id_end = (
min(bucket_id_end * bucket_size, emb_height) + table_offset
)
table_input_id_start = table_offset
table_input_id_end = table_offset + emb_height
# 1. get all keys from backend for one table
unordered_id_tensor = self._ssd_db.get_keys_in_range_by_snapshot(
table_input_id_start,
@@ -2166,15 +2179,38 @@
torch.ops.fbgemm.get_bucket_sorted_indices_and_bucket_tensor(
unordered_id_tensor,
0, # id--bucket hashing mode, 0 for chunk-based hashing, 1 for interleave-based hashing
bucket_id_start,
bucket_id_end,
0, # local bucket offset
bucket_id_end - bucket_id_start, # local bucket num
bucket_size,
)
)
# pyre-ignore
# 3. convert local id back to global id
bucket_ascending_id_tensor.add_(bucket_id_start * bucket_size)
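# e.g. (illustrative values): with bucket_id_start=2 and bucket_size=100,
# local id 50 becomes global id 50 + 200 = 250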

if (
bucket_ascending_id_tensor.size(0) == 0
and self.local_weight_counts[i] > 0
):
logging.info(
f"resetting bucket id tensor with {self.local_weight_counts[i]}"
)
bucket_ascending_id_tensor = torch.zeros(
(self.local_weight_counts[i], 1),
device=torch.device("cpu"),
dtype=torch.int64,
)
# self.local_weight_counts[i] = 0 # Reset the count

# pyre-ignore [16] bucket_sorted_id_splits is not None
bucket_sorted_id_splits.append(bucket_ascending_id_tensor)
active_id_cnt_per_bucket_split.append(bucket_t)

# For KV ZCH TBE, sorted_indices holds global ids for checkpointing and publishing,
# but the backend uses local ids during training, so the KVTensorWrapper needs to convert
# each global id to a local id first and then linearize it with the table offset; the formula is
# x + table_offset - local_shard_offset. To achieve this, row_offset is set to
# (table_offset - local_shard_offset)
row_offset = table_offset - (bucket_id_start * bucket_size)
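# e.g. (illustrative values, continuing the example above): with table_offset=1000,
# bucket_id_start=2, bucket_size=100, row_offset = 1000 - 200 = 800, so global id 250
# maps to backend row 250 + 800 = 1050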

tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
shape=[
(
@@ -2185,33 +2221,184 @@
pad4(emb_dim),
],
dtype=dtype,
row_offset=table_offset,
row_offset=row_offset,
snapshot_handle=snapshot_handle,
# pass bucket_ascending_id_tensor to the kvt wrapper, so narrow() follows the id order
# when returning embedding weights.
sorted_indices=(
bucket_ascending_id_tensor if self.kv_zch_params else None
),
)
# TODO add if else support in the future for dram integration.
tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
(
tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
if self.backend_type == BackendType.SSD
else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
)
table_offset += emb_height
pmt_splits.append(
PartiallyMaterializedTensor(
tensor_wrapper,
True if self.kv_zch_params else False,
)
)
logging.info(
f"split_embedding_weights latency: {(time.time() - start_time) * 1000} ms"
)
return (pmt_splits, bucket_sorted_id_splits, active_id_cnt_per_bucket_split)

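Taken together, this function lets a caller pull per-table weights plus the KV ZCH bucket metadata in one call. A minimal usage sketch (the `emb` instance and its construction are assumptions, not part of this diff):

```python
# Hypothetical SSD TBE module `emb` constructed with kv_zch_params set.
weights, bucket_ids, bucket_cnts = emb.split_embedding_weights(
    no_snapshot=False,   # take a rocksdb snapshot so reads see a consistent view
    should_flush=True,   # flush L1/L2 caches before snapshotting
)
for t, ids in enumerate(bucket_ids or []):
    # For KV ZCH tables, weights[t] is a PartiallyMaterializedTensor and ids
    # holds the bucket-sorted global ids for table t.
    print(f"table {t}: {ids.numel()} ids materialized")
```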
@torch.jit.ignore
def apply_state_dict(self) -> None:
# After checkpoint loading, _cached_kvzch_data is populated from the checkpoint.
# The caller should invoke this function to apply the cached states to the backend.
pass
if self.load_state_dict is False:
return
self.load_state_dict = False
assert self.kv_zch_params is not None, "apply_state_dict supports KV ZCH only"
assert (
self._cached_kvzch_data is not None
and self._cached_kvzch_data.cached_optimizer_state_per_table is not None
), "optimizer state is not initialized for load checkpointing"
assert (
self._cached_kvzch_data.cached_weight_tensor_per_table is not None
and self._cached_kvzch_data.cached_id_tensor_per_table is not None
), "weight and id state is not initialized for load checkpointing"

# Compute the number of elements of cache_dtype needed to store the
# optimizer state, rounded to the nearest multiple of 4
# optimizer_dim = self.optimizer.optimizer_state_size_dim(dtype)
# apply weight and optimizer state per table
table_offset = 0
for i, (emb_height, _) in enumerate(self.embedding_specs):
# pyre-ignore [16]
bucket_id_start, _ = self.kv_zch_params.bucket_offsets[i]
# pyre-ignore [16]
bucket_size = self.kv_zch_params.bucket_sizes[i]
row_offset = table_offset - bucket_id_start * bucket_size

if self.enable_optimizer_offloading:
# pyre-ignore [16]
weight_state = self._cached_kvzch_data.cached_weight_tensor_per_table[i]
# pyre-ignore [16]
opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i]
self.streaming_write_weight_and_id_per_table(
weight_state,
opt_state,
# pyre-ignore [16]
self._cached_kvzch_data.cached_id_tensor_per_table[i],
row_offset,
)
self._cached_kvzch_data.cached_weight_tensor_per_table[i] = None
self._cached_kvzch_data.cached_optimizer_state_per_table[i] = None
else:
weight = self._cached_kvzch_data.cached_weight_tensor_per_table[i]
id_tensor = self._cached_kvzch_data.cached_id_tensor_per_table[i]
local_id = id_tensor + row_offset
logging.info(
f"applying sd for table {i} without optimizer offloading, local_id is {local_id}"
)
opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i]
t_device = self.momentum1_dev.device
self.momentum1_dev.index_put_(
indices=(
local_id.to(t_device).view(-1),
), # expects tuple of tensors
values=opt_state.to(t_device),
)
self.ssd_db.set_cuda(
local_id.view(-1),
weight,
torch.as_tensor(local_id.size(0)),
1,
False,
)
table_offset += emb_height
self.clear_cache()

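Per the comments above, the intended restore sequence is: enable load mode, copy checkpoint data into the cached CPU buffers, then apply. A hedged sketch of that flow (`emb`, `num_tables`, and the `ckpt_*` tensors are hypothetical, and the optimizer state buffers are assumed to be populated analogously by the checkpoint loader):

```python
# 1. Tell the module how many rows each table will receive, then allocate buffers.
emb.local_weight_counts = [ckpt_ids[t].numel() for t in range(num_tables)]
emb.enable_load_state_dict_mode()

# 2. In load mode, split_embedding_weights() returns the cached CPU buffers.
weights, ids, _buckets = emb.split_embedding_weights()
for t in range(num_tables):
    weights[t].copy_(ckpt_weights[t])      # [rows, emb_dim] weight payload
    ids[t].copy_(ckpt_ids[t].view(-1, 1))  # global ids, one per row

# 3. Stream the cached state into the SSD/DRAM backend and drop the cache.
emb.apply_state_dict()
```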
@torch.jit.ignore
def streaming_write_weight_and_id_per_table(
self,
weight_state: torch.Tensor,
opt_state: torch.Tensor,
id_tensor: torch.Tensor,
row_offset: int,
) -> None:
"""
This function writes weights, optimizer state, and ids to the backend using the kvt wrapper.
To avoid overusing memory, the data is written to the backend in a rolling-window manner.

Args:
weight_state (torch.Tensor): The weight state tensor to be written.
opt_state (torch.Tensor): The optimizer state tensor to be written.
id_tensor (torch.Tensor): The id tensor to be written.
row_offset (int): The row offset applied when writing ids to the backend.
"""
D_rounded = pad4(weight_state.size(1))  # padded to a multiple of 4
dtype = self.weights_precision.as_dtype()
kvt = torch.classes.fbgemm.KVTensorWrapper(
db=self.ssd_db,
shape=[weight_state.size(0), self.cache_row_dim],
dtype=dtype,
row_offset=row_offset,
snapshot_handle=None,
sorted_indices=id_tensor,
)
# TODO: make chunk_size configurable or dynamic
chunk_size = 10000
row = weight_state.size(0)
optimizer_dim = self.optimizer.state_size_dim(dtype)
opt_state_2d = opt_state.view(dtype).view(-1, optimizer_dim)
for i in range(0, row, chunk_size):
length = min(chunk_size, row - i)
chunk_buffer = torch.empty(
length,
self.cache_row_dim,
dtype=dtype,
device="cpu",
)
chunk_buffer[:, : weight_state.size(1)] = weight_state[i : i + length, :]
chunk_buffer[:, D_rounded : D_rounded + optimizer_dim] = opt_state_2d[
i : i + length, :
]
kvt.set_weights_and_ids(chunk_buffer, id_tensor[i : i + length, :].view(-1))

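The chunk buffer above packs each row as the weight payload followed by the optimizer state at a 4-aligned column offset. A standalone illustration of that row layout, using hypothetical sizes rather than the real cache_row_dim:

```python
import torch

emb_dim, optimizer_dim, cache_row_dim = 6, 1, 12  # hypothetical sizes
D_rounded = (emb_dim + 3) // 4 * 4  # pad4: round up to a multiple of 4

row = torch.zeros(cache_row_dim, dtype=torch.float32)
row[:emb_dim] = torch.randn(emb_dim)                              # embedding weights
row[D_rounded : D_rounded + optimizer_dim] = torch.tensor([0.5])  # optimizer state
# columns emb_dim..D_rounded-1 are alignment padding; the tail is reserved space
```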
@torch.jit.ignore
def enable_load_state_dict_mode(self) -> None:
# Enable load state dict mode before loading a checkpoint
pass
if self.load_state_dict:
return
self.load_state_dict = True

dtype = self.weights_precision.as_dtype()
self._cached_kvzch_data = KVZCHCachedData([], [], [], [])
for i, (_, emb_dim) in enumerate(self.embedding_specs):
# for checkpoint loading, we need to store the weight and id tensors temporarily in memory
assert (
self.local_weight_counts[i] > 0
), f"local_weight_counts for table {i} is not set"
# pyre-ignore [16]
bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
rows = self.local_weight_counts[i]
weight_state = torch.empty(rows, emb_dim, dtype=dtype, device="cpu")
opt_state = torch.empty(rows, dtype=torch.float32, device="cpu")
# pyre-ignore [16]
self._cached_kvzch_data.cached_weight_tensor_per_table.append(weight_state)
# pyre-ignore [16]
self._cached_kvzch_data.cached_optimizer_state_per_table.append(opt_state)
logging.info(
f"for checkpoint loading, table {i}, weight_state shape is {weight_state.shape}, opt_state shape is {opt_state.shape}"
)
id_tensor = torch.zeros(
(self.local_weight_counts[i], 1), dtype=torch.int64, device="cpu"
)
# pyre-ignore [16]
self._cached_kvzch_data.cached_id_tensor_per_table.append(id_tensor)
# pyre-ignore [16]
self._cached_kvzch_data.cached_bucket_splits.append(
torch.empty(
(bucket_id_end - bucket_id_start, 1),
dtype=torch.int64,
device="cpu",
)
)

@torch.jit.export
def set_learning_rate(self, lr: float) -> None: