
Commit b71f498

emlin authored and facebook-github-bot committed
implement optimizer state with opt offloading (#4141)
Summary:
X-link: facebookresearch/FBGEMM#1224

Implement split_optimizer_states for optimizer state dict integration.

Reviewed By: bobbyliujb

Differential Revision: D74790121
1 parent cbc7dbb commit b71f498

File tree

3 files changed: +422 −11 lines changed

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py

Lines changed: 2 additions & 0 deletions
@@ -58,6 +58,8 @@ class KVZCHParams(NamedTuple):
     bucket_sizes: List[int] = []
     # enable optimizer offloading or not
     enable_optimizer_offloading: bool = True
+    # streaming load/save checkpoint chunk size
+    streaming_ckpt_chunk_size: int = 1000000
 
     def validate(self) -> None:
         assert len(self.bucket_offsets) == len(self.bucket_sizes), (
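
For context, here is a minimal sketch (not part of this commit) of how the new field might be set when constructing KVZCHParams, assuming bucket_offsets and bucket_sizes are the only required fields; the bucket values and chunk size below are illustrative placeholders.

from fbgemm_gpu.split_table_batched_embeddings_ops_common import KVZCHParams

# Hypothetical configuration: offload optimizer state to the KV backend and
# stream checkpoint reads in chunks of 2M rows instead of the 1M default.
kv_zch_params = KVZCHParams(
    bucket_offsets=[(0, 100)],  # placeholder per-table bucket id range
    bucket_sizes=[1000],        # placeholder per-table bucket size
    enable_optimizer_offloading=True,
    streaming_ckpt_chunk_size=2_000_000,
)
kv_zch_params.validate()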

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 168 additions & 5 deletions
@@ -78,9 +78,9 @@ class IterData:
 
 @dataclass
 class KVZCHCachedData:
-    cached_id_tensor_per_table: List[torch.Tensor]
-    cached_weight_tensor_per_table: List[torch.Tensor]
     cached_optimizer_state_per_table: List[torch.Tensor]
+    cached_weight_tensor_per_table: List[torch.Tensor]
+    cached_id_tensor_per_table: List[torch.Tensor]
     cached_bucket_splits: List[torch.Tensor]
 
 
@@ -175,11 +175,14 @@ def __init__(
     ) -> None:
         super(SSDTableBatchedEmbeddingBags, self).__init__()
 
+        # Set the optimizer
         assert optimizer in (
             OptimType.EXACT_ROWWISE_ADAGRAD,
         ), f"Optimizer {optimizer} is not supported by SSDTableBatchedEmbeddingBags"
         self.optimizer = optimizer
+        self.optimizer_dtype: torch.dtype = torch.float32
 
+        # Set the table weight and output dtypes
         assert weights_precision in (SparseType.FP32, SparseType.FP16)
         self.weights_precision = weights_precision
         self.output_dtype: int = output_dtype.as_int()
@@ -702,7 +705,9 @@ def __init__(
         momentum1_offsets = [0] + list(itertools.accumulate(rows))
         self._apply_split(
             SplitState(
-                dev_size=self.total_hash_size,
+                dev_size=(
+                    self.total_hash_size if not self.enable_optimizer_offloading else 0
+                ),
                 host_size=0,
                 uvm_size=0,
                 placements=[EmbeddingLocation.DEVICE for _ in range(T_)],
@@ -1720,6 +1725,7 @@ def forward(
         batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
         # pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
     ) -> Tensor:
+        self.clear_cache()
         indices, offsets, per_sample_weights, vbe_metadata = self.prepare_inputs(
             indices, offsets, per_sample_weights, batch_size_per_feature_per_rank
         )
@@ -1881,6 +1887,8 @@ def debug_split_optimizer_states(self) -> List[Tuple[torch.Tensor, int, int]]:
     def split_optimizer_states(
         self,
         sorted_id_tensor: Optional[List[torch.Tensor]] = None,
+        no_snapshot: bool = True,
+        should_flush: bool = False,
     ) -> List[torch.Tensor]:
         """
         Returns a list of optimizer states split by table. So far, we only support EXACT_ROWWISE_ADAGRAD,
@@ -1897,14 +1905,166 @@ def split_optimizer_states(
         id consistency between weight and optimizer states.
 
         """
-        raise NotImplementedError(
-            "split_optimizer_states is not implemented for SSDTableBatchedEmbeddingBags"
+
+        logging.info(f"split_optimizer_states: {no_snapshot=}, {should_flush=}")
+        start_time = time.time()
+        torch.cuda.synchronize()
+
+        (rows, _) = zip(*self.embedding_specs)
+
+        rows_cumsum = [0] + list(itertools.accumulate(rows))
+        if not self.kv_zch_params:
+            logging.info(
+                f"non KV ZCH tables split_optimizer_states query latency: {(time.time() - start_time) * 1000} ms"
+            )
+            return [
+                self.momentum1_dev.detach()[rows_cumsum[t] : rows_cumsum[t + 1]].view(
+                    row
+                )
+                for t, row in enumerate(rows)
+            ]
+
+        # With optimizer state offloading, we need to query optimizer states from backend,
+        # so create snapshot for SSD backend first.
+        snapshot_handle = None
+        if self.backend_type == BackendType.SSD:
+            # Create a rocksdb snapshot
+            if not no_snapshot:
+                if should_flush:
+                    # Flush L1 and L2 caches
+                    self.flush(force=True)
+                    logging.info(
+                        f"flushed L1 and L2 caches for optimizer state, latency: {time.time() - start_time} ms"
+                    )
+                snapshot_handle = self.ssd_db.create_snapshot()
+                logging.info(f"created snapshot for optimizer state: {snapshot_handle}")
+        elif self.backend_type == BackendType.DRAM:
+            self.flush(force=True)
+
+        opt_list = []
+        table_offset = 0
+        if self.load_state_dict:
+            # init for checkpointing loading
+            assert (
+                self._cached_kvzch_data is not None
+                and self._cached_kvzch_data.cached_optimizer_state_per_table is not None
+            ), "optimizer state is not initialized for load checkpointing"
+            return self._cached_kvzch_data.cached_optimizer_state_per_table
+
+        dtype = self.weights_precision.as_dtype()
+        optimizer_dim = self.optimizer.state_size_dim(dtype)
+        pad4_optimizer_dim = pad4(optimizer_dim)
+        logging.info(
+            f"split_optimizer_states: {optimizer_dim=} {pad4_optimizer_dim=} {self.optimizer_dtype=}"
+        )
+
+        for t, (emb_height, emb_dim) in enumerate(self.embedding_specs):
+            # pyre-ignore
+            bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
+            # pyre-ignore
+            bucket_size = self.kv_zch_params.bucket_sizes[t]
+            row_offset = table_offset
+
+            if not self.enable_optimizer_offloading:
+                if sorted_id_tensor is None or sorted_id_tensor[t].numel() == 0:
+                    opt_list.append(
+                        self.momentum1_dev.detach()
+                        .cpu()[0:1]
+                        .view(-1)  # dummy tensor for module initialization
+                    )
+                elif all(sorted_id_tensor[t] == 0):
+                    # all ids are 0, which means it's the dummy id tensor with correct shape just for loading cp
+                    opt_list.append(
+                        (
+                            self.momentum1_dev.detach().cpu()  # the shape should be correct at this point
+                        )
+                    )
+                else:
+                    # convert global id back to local id, then linearize with table offset
+                    local_id_tensor = (
+                        sorted_id_tensor[t]
+                        - bucket_id_start * bucket_size
+                        + table_offset
+                    )
+                    opt_list.append(
+                        self.momentum1_dev.detach().cpu()[local_id_tensor].view(-1),
+                    )
+            else:
+                if sorted_id_tensor is None or sorted_id_tensor[t].numel() == 0:
+                    opt_list.append(
+                        torch.empty(
+                            1,
+                            dtype=torch.float32,
+                            device="cpu",
+                        )  # dummy tensor for module initialization
+                    )
+                else:
+                    emb_opt_dim = pad4(emb_dim) + pad4_optimizer_dim
+                    row_offset = table_offset - (bucket_id_start * bucket_size)
+                    # using KVTensorWrapper to query backend to avoid OOM memory, since
+                    # backend will return both weight and optimizer in one tensor, read the whole tensor
+                    # out could OOM CPU memory.
+                    tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
+                        shape=[emb_height, emb_opt_dim],
+                        dtype=dtype,
+                        row_offset=row_offset,
+                        snapshot_handle=snapshot_handle,
+                        materialized_shape=([sorted_id_tensor[t].size(0), emb_opt_dim]),
+                        sorted_indices=sorted_id_tensor[t],
+                    )
+                    (
+                        tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                        if self.backend_type == BackendType.SSD
+                        else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+                    )
+                    opt_list.append(
+                        self.get_offloaded_optimizer_states(
+                            tensor_wrapper=tensor_wrapper,
+                            row=sorted_id_tensor[t].size(
+                                0
+                            ),  # we only need to copy the size of sorted_id_tensor
+                            optimizer_dim=optimizer_dim,
+                            start_dim_pos=pad4(emb_dim),
+                        )
+                    )
+            table_offset += emb_height
+        logging.info(
+            f"KV ZCH tables split_optimizer_states query latency: {(time.time() - start_time) * 1000} ms"
         )
+        return opt_list
+
+    @torch.jit.export
+    def get_offloaded_optimizer_states(
+        self,
+        # pyre-ignore [11]: Annotation `KVTensorWrapper` is not defined as a type.
+        tensor_wrapper: torch.classes.fbgemm.KVTensorWrapper,
+        row: int,
+        optimizer_dim: int,
+        start_dim_pos: int,
+    ) -> torch.Tensor:
+        weight_dtype = self.weights_precision.as_dtype()
+        opt_state_t = torch.empty(
+            row, optimizer_dim, dtype=weight_dtype, device="cpu"
+        )  # 1D optimizer for OptimType.EXACT_ROWWISE_ADAGRAD
+
+        # pyre-ignore [16]
+        chunk_size = self.kv_zch_params.streaming_ckpt_chunk_size
+        for i in range(0, row, chunk_size):
+            length = min(chunk_size, row - i)
+            opt_state_t.narrow(0, i, length).copy_(
+                tensor_wrapper.narrow(0, i, length).narrow(
+                    1, start_dim_pos, optimizer_dim
+                )
+            )
+        # view optimizer state back to correct dtype
+        return opt_state_t.view(-1).view(self.optimizer_dtype)
 
     @torch.jit.export
     def get_optimizer_state(
         self,
         sorted_id_tensor: Optional[List[torch.Tensor]],
+        no_snapshot: bool = True,
+        should_flush: bool = False,
     ) -> List[Dict[str, torch.Tensor]]:
         """
         Returns a list of optimizer states split by table. So far, we only support EXACT_ROWWISE_ADAGRAD
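
To make the layout behind get_offloaded_optimizer_states concrete, the following standalone sketch (not from this commit) mimics the chunked column slice it performs, assuming each backend row stores the padded weights followed by the padded optimizer state; the dimensions, the pad4 stand-in, and the fp16-weight / fp32-momentum combination are illustrative assumptions.

import torch

def pad4(n: int) -> int:
    # illustrative stand-in for fbgemm's pad4: round up to a multiple of 4 elements
    return (n + 3) // 4 * 4

emb_dim = 127                      # placeholder embedding dimension
weight_dtype = torch.float16       # assumed table weight dtype
optimizer_dim = 2                  # one fp32 rowwise-Adagrad value = 2 fp16 elements
start_dim_pos = pad4(emb_dim)      # optimizer columns start after the padded weights
emb_opt_dim = pad4(emb_dim) + pad4(optimizer_dim)  # full backend row width

rows, chunk = 10, 4
backend_rows = torch.zeros(rows, emb_opt_dim, dtype=weight_dtype)  # stands in for KVTensorWrapper
opt_state = torch.empty(rows, optimizer_dim, dtype=weight_dtype)
for i in range(0, rows, chunk):
    n = min(chunk, rows - i)
    # copy only the optimizer columns, one chunk of rows at a time
    opt_state.narrow(0, i, n).copy_(
        backend_rows.narrow(0, i, n).narrow(1, start_dim_pos, optimizer_dim)
    )
# reinterpret the copied bytes as the optimizer dtype: one fp32 value per row
momentum1 = opt_state.view(-1).view(torch.float32)
assert momentum1.numel() == rows
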
@@ -1914,6 +2074,8 @@ def get_optimizer_state(
             ({"momentum1": states})
             for states in self.split_optimizer_states(
                 sorted_id_tensor=sorted_id_tensor,
+                no_snapshot=no_snapshot,
+                should_flush=should_flush,
             )
         ]
 
@@ -1963,6 +2125,7 @@ def debug_split_embedding_weights(self) -> List[torch.Tensor]:
         return splits
 
     def clear_cache(self) -> None:
+        # clear KV ZCH cache for checkpointing
         self._cached_kvzch_data = None
 
     @torch.jit.export
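
A hedged sketch (not from this commit) of how a checkpointing caller might exercise the new arguments; `tbe` is assumed to be an already constructed SSDTableBatchedEmbeddingBags with KV ZCH and optimizer offloading enabled, and `sorted_ids` is assumed to be a per-table list of id tensors gathered elsewhere during checkpointing.

# Request per-table optimizer state dicts, taking a rocksdb snapshot on the
# SSD backend and flushing the L1/L2 caches first.
opt_state_dicts = tbe.get_optimizer_state(
    sorted_id_tensor=sorted_ids,
    no_snapshot=False,
    should_flush=True,
)
for t, state in enumerate(opt_state_dicts):
    # one rowwise-Adagrad momentum tensor per table, aligned with sorted_ids[t]
    print(t, state["momentum1"].shape)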
