
Commit 91c5a79

emlin authored and facebook-github-bot committed
implement optimizer state with opt offloading (#4141)
Summary:
Pull Request resolved: #4141
X-link: facebookresearch/FBGEMM#1224

Implement split_optimizer_states for optimizer state dict integration.

Reviewed By: duduyi2013, bobbyliujb

Differential Revision: D74790121

fbshipit-source-id: e7ba97309572d9d488919e14d00506ff7d2d821c
1 parent e0e3e8c commit 91c5a79

5 files changed: +437, -29 lines changed

fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py

Lines changed: 8 additions & 0 deletions
@@ -58,6 +58,14 @@ def state_size_dim(self, dtype: torch.dtype) -> int:
         """
         return int(math.ceil(self.state_size() / dtype.itemsize))
 
+    def dtype(self) -> torch.dtype:
+        """
+        Returns the dtype of the optimizer state
+        """
+        return {
+            EmbOptimType.EXACT_ROWWISE_ADAGRAD: torch.float32,
+        }.get(self, torch.float32)
+
 
 # Base class for quantization configuration (in case other numeric types have
 # configs)
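For reference, the new dtype() accessor is a dict lookup with a float32 fallback, and it pairs with the existing state_size_dim() byte-to-element arithmetic. A self-contained sketch of that pattern (names and byte counts here are illustrative, not the FBGEMM API):

import math
import torch

# Illustrative per-optimizer state layout: EXACT_ROWWISE_ADAGRAD keeps one
# fp32 momentum scalar per row, i.e. 4 bytes of raw state.
_STATE_BYTES = {"EXACT_ROWWISE_ADAGRAD": 4}
_STATE_DTYPE = {"EXACT_ROWWISE_ADAGRAD": torch.float32}

def state_dtype(opt: str) -> torch.dtype:
    # Same shape as EmbOptimType.dtype(): dict lookup with a float32 default.
    return _STATE_DTYPE.get(opt, torch.float32)

def state_size_dim(opt: str, storage_dtype: torch.dtype) -> int:
    # Number of storage_dtype elements needed to hold the raw state bytes.
    return int(math.ceil(_STATE_BYTES[opt] / storage_dtype.itemsize))

# 4 bytes of fp32 state occupy 2 fp16 slots when stored in an fp16 row.
assert state_size_dim("EXACT_ROWWISE_ADAGRAD", torch.float16) == 2
assert state_dtype("EXACT_ROWWISE_ADAGRAD") is torch.float32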

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py

Lines changed: 3 additions & 1 deletion
@@ -57,7 +57,9 @@ class KVZCHParams(NamedTuple):
     # the value indicates corresponding input space for each bucket id, e.g. 2^50 / total_num_buckets
     bucket_sizes: List[int] = []
     # enable optimizer offloading or not
-    enable_optimizer_offloading: bool = True
+    enable_optimizer_offloading: bool = False
+    # streaming load/save checkpoint chunk size
+    streaming_ckpt_chunk_size: int = 1000000
 
     def validate(self) -> None:
         assert len(self.bucket_offsets) == len(self.bucket_sizes), (
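A hypothetical construction showing the changed defaults in use; the field values are illustrative, and bucket_offsets is assumed to hold (start, end) bucket-id pairs per table, as validate() implies:

# Opt back into optimizer offloading now that the default flipped to False,
# and shrink the streaming checkpoint chunk size from its 1,000,000 default.
params = KVZCHParams(
    bucket_offsets=[(0, 4)],           # one table covering bucket ids [0, 4)
    bucket_sizes=[2**50 // 4],         # input space per bucket id
    enable_optimizer_offloading=True,
    streaming_ckpt_chunk_size=250_000,
)
params.validate()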

fbgemm_gpu/fbgemm_gpu/tbe/ssd/common.py

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@
 # pyre-ignore-all-errors[56]
 
 import torch
+
 from fbgemm_gpu.utils.loader import load_torch_module
 
 try:

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 173 additions & 22 deletions
@@ -50,17 +50,13 @@
 from fbgemm_gpu.split_table_batched_embeddings_ops_training_common import (
     generate_vbe_metadata,
 )
-
 from torch import distributed as dist, nn, Tensor  # usort:skip
 from dataclasses import dataclass
 
-from fbgemm_gpu.tbe.ssd.common import tensor_pad4
-
 from torch.autograd.profiler import record_function
 
 from ..cache import get_unique_indices_v2
-
-from .common import ASSOC, pad4
+from .common import ASSOC, pad4, tensor_pad4
 from .utils.partially_materialized_tensor import PartiallyMaterializedTensor
 
 
@@ -78,9 +74,9 @@ class IterData:
 
 @dataclass
 class KVZCHCachedData:
-    cached_id_tensor_per_table: List[torch.Tensor]
-    cached_weight_tensor_per_table: List[torch.Tensor]
     cached_optimizer_state_per_table: List[torch.Tensor]
+    cached_weight_tensor_per_table: List[torch.Tensor]
+    cached_id_tensor_per_table: List[torch.Tensor]
     cached_bucket_splits: List[torch.Tensor]
 
 
@@ -175,11 +171,13 @@ def __init__(
     ) -> None:
         super(SSDTableBatchedEmbeddingBags, self).__init__()
 
+        # Set the optimizer
         assert optimizer in (
            OptimType.EXACT_ROWWISE_ADAGRAD,
         ), f"Optimizer {optimizer} is not supported by SSDTableBatchedEmbeddingBags"
         self.optimizer = optimizer
 
+        # Set the table weight and output dtypes
         assert weights_precision in (SparseType.FP32, SparseType.FP16)
         self.weights_precision = weights_precision
         self.output_dtype: int = output_dtype.as_int()
@@ -702,7 +700,9 @@ def __init__(
             momentum1_offsets = [0] + list(itertools.accumulate(rows))
             self._apply_split(
                 SplitState(
-                    dev_size=self.total_hash_size,
+                    dev_size=(
+                        self.total_hash_size if not self.enable_optimizer_offloading else 0
+                    ),
                     host_size=0,
                     uvm_size=0,
                     placements=[EmbeddingLocation.DEVICE for _ in range(T_)],
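The dev_size change above is the allocation-side half of offloading: when the optimizer state lives in the SSD/DRAM backend, no HBM elements are reserved for momentum1. A conceptual sketch with stand-in variables for the module attributes:

import torch

enable_optimizer_offloading = True
total_hash_size = 1_000_000

# Mirrors the SplitState dev_size expression: device-side momentum storage
# is skipped entirely when the backend owns the optimizer state.
dev_size = total_hash_size if not enable_optimizer_offloading else 0
momentum1_dev = torch.zeros(dev_size)
assert momentum1_dev.numel() == 0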
@@ -1720,6 +1720,7 @@ def forward(
         batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
         # pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
     ) -> Tensor:
+        self.clear_cache()
         indices, offsets, per_sample_weights, vbe_metadata = self.prepare_inputs(
             indices, offsets, per_sample_weights, batch_size_per_feature_per_rank
         )
@@ -1877,10 +1878,30 @@ def debug_split_optimizer_states(self) -> List[Tuple[torch.Tensor, int, int]]:
             for t, row in enumerate(rows)
         ]
 
+    @torch.jit.ignore
+    def _split_optimizer_states_non_kv_zch(
+        self,
+    ) -> List[torch.Tensor]:
+        """
+        Returns a list of optimizer states, split by table. So far, we only support EXACT_ROWWISE_ADAGRAD,
+        so only momentum1 state is returned.
+        """
+        logging.info("_split_optimizer_states_non_kv_zch")
+        (rows, _) = zip(*self.embedding_specs)
+
+        rows_cumsum = [0] + list(itertools.accumulate(rows))
+
+        return [
+            self.momentum1_dev.detach()[rows_cumsum[t] : rows_cumsum[t + 1]].view(row)
+            for t, row in enumerate(rows)
+        ]
+
     @torch.jit.export
     def split_optimizer_states(
         self,
         sorted_id_tensor: Optional[List[torch.Tensor]] = None,
+        no_snapshot: bool = True,
+        should_flush: bool = False,
     ) -> List[torch.Tensor]:
         """
         Returns a list of optimizer states split by table. So far, we only support EXACT_ROWWISE_ADAGRAD,
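The non-KV-ZCH path added above splits one flat momentum1 buffer by table using a cumulative sum of per-table row counts. A self-contained sketch of that slicing, with toy sizes and plain tensors in place of the module state:

import itertools
import torch

rows = [4, 2, 3]  # rows per embedding table
momentum1_dev = torch.arange(sum(rows), dtype=torch.float32)

rows_cumsum = [0] + list(itertools.accumulate(rows))  # [0, 4, 6, 9]
per_table = [
    momentum1_dev[rows_cumsum[t] : rows_cumsum[t + 1]].view(row)
    for t, row in enumerate(rows)
]
assert [t.numel() for t in per_table] == rows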
@@ -1897,14 +1918,126 @@ def split_optimizer_states(
         id consistency between weight and optimizer states.
 
         """
-        raise NotImplementedError(
-            "split_optimizer_states is not implemented for SSDTableBatchedEmbeddingBags"
+
+        if not self.kv_zch_params:
+            return self._split_optimizer_states_non_kv_zch()
+
+        if self.load_state_dict:
+            # init for checkpoint loading
+            assert (
+                self._cached_kvzch_data is not None
+                and self._cached_kvzch_data.cached_optimizer_state_per_table is not None
+            ), "optimizer state is not initialized for load checkpointing"
+            return self._cached_kvzch_data.cached_optimizer_state_per_table
+
+        logging.info(
+            f"split_optimizer_states for KV ZCH: {no_snapshot=}, {should_flush=}"
+        )
+        start_time = time.time()
+        snapshot_handle = self._may_create_snapshot_for_state_dict(
+            no_snapshot=no_snapshot,
+            should_flush=should_flush,
         )
 
+        opt_list = []
+        table_offset = 0
+
+        dtype = self.weights_precision.as_dtype()
+        optimizer_dim = self.optimizer.state_size_dim(dtype)
+        pad4_optimizer_dim = pad4(optimizer_dim)
+        logging.info(
+            f"split_optimizer_states: {optimizer_dim=} {pad4_optimizer_dim=} {self.optimizer.dtype()=} {self.enable_load_state_dict_mode=}"
+        )
+
+        for t, (emb_height, emb_dim) in enumerate(self.embedding_specs):
+            # pyre-ignore
+            bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
+            # pyre-ignore
+            bucket_size = self.kv_zch_params.bucket_sizes[t]
+            row_offset = table_offset
+            if sorted_id_tensor is None or sorted_id_tensor[t].numel() == 0:
+                opt_list.append(
+                    torch.empty(0, dtype=self.optimizer.dtype(), device="cpu")
+                    # empty optimizer state for module initialization
+                )
+            else:
+                if not self.enable_optimizer_offloading:
+                    # convert global id back to local id, then linearize with table offset
+                    local_id_tensor = (
+                        sorted_id_tensor[t]
+                        - bucket_id_start * bucket_size
+                        + table_offset
+                    )
+                    opt_list.append(
+                        self.momentum1_dev.detach().cpu()[local_id_tensor].view(-1),
+                    )
+                else:
+                    emb_opt_dim = pad4(emb_dim) + pad4_optimizer_dim
+                    row_offset = table_offset - (bucket_id_start * bucket_size)
+                    # use KVTensorWrapper to query the backend in chunks: the backend
+                    # returns weight and optimizer in one tensor, and reading the
+                    # whole tensor out at once could OOM CPU memory
+                    tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
+                        shape=[emb_height, emb_opt_dim],
+                        dtype=dtype,
+                        row_offset=row_offset,
+                        snapshot_handle=snapshot_handle,
+                        materialized_shape=([sorted_id_tensor[t].size(0), emb_opt_dim]),
+                        sorted_indices=sorted_id_tensor[t],
+                    )
+                    (
+                        tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                        if self.backend_type == BackendType.SSD
+                        else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+                    )
+                    opt_list.append(
+                        self.get_offloaded_optimizer_states(
+                            tensor_wrapper=tensor_wrapper,
+                            row=sorted_id_tensor[t].size(
+                                0
+                            ),  # we only need to copy the size of sorted_id_tensor
+                            optimizer_dim=optimizer_dim,
+                            start_dim_pos=pad4(emb_dim),
+                        )
+                    )
+            table_offset += emb_height
+        logging.info(
+            f"KV ZCH tables split_optimizer_states query latency: {(time.time() - start_time) * 1000} ms"
+        )
+        return opt_list
+
+    @torch.jit.export
+    def get_offloaded_optimizer_states(
+        self,
+        # pyre-ignore [2]
+        tensor_wrapper,
+        row: int,
+        optimizer_dim: int,
+        start_dim_pos: int,
+    ) -> torch.Tensor:
+        weight_dtype = self.weights_precision.as_dtype()
+        opt_state_t = torch.empty(
+            row, optimizer_dim, dtype=weight_dtype, device="cpu"
+        )  # 1D optimizer for OptimType.EXACT_ROWWISE_ADAGRAD
+
+        # pyre-ignore [16]
+        chunk_size = self.kv_zch_params.streaming_ckpt_chunk_size
+        for i in range(0, row, chunk_size):
+            length = min(chunk_size, row - i)
+            opt_state_t.narrow(0, i, length).copy_(
+                tensor_wrapper.narrow(0, i, length).narrow(
+                    1, start_dim_pos, optimizer_dim
+                )
+            )
+        # view optimizer state back to the correct dtype
+        return opt_state_t.view(-1).view(self.optimizer.dtype())
+
     @torch.jit.export
     def get_optimizer_state(
         self,
         sorted_id_tensor: Optional[List[torch.Tensor]],
+        no_snapshot: bool = True,
+        should_flush: bool = False,
     ) -> List[Dict[str, torch.Tensor]]:
         """
         Returns a list of optimizer states split by table. So far, we only support EXACT_ROWWISE_ADAGRAD
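get_offloaded_optimizer_states bounds peak host memory by reading the optimizer columns of the combined [padded weight | optimizer] rows in fixed-size chunks. A toy sketch of that narrow/copy_ loop, with a plain tensor standing in for the KVTensorWrapper backend:

import torch

row, padded_emb_dim, optimizer_dim = 10, 8, 1
backend = torch.randn(row, padded_emb_dim + optimizer_dim)  # KVTensorWrapper stand-in

opt_state_t = torch.empty(row, optimizer_dim)
chunk_size = 4  # plays the role of streaming_ckpt_chunk_size
for i in range(0, row, chunk_size):
    length = min(chunk_size, row - i)
    # Copy only the optimizer columns for this chunk of rows.
    opt_state_t.narrow(0, i, length).copy_(
        backend.narrow(0, i, length).narrow(1, padded_emb_dim, optimizer_dim)
    )
assert torch.equal(opt_state_t, backend[:, padded_emb_dim:])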
@@ -1914,6 +2047,8 @@ def get_optimizer_state(
             ({"momentum1": states})
             for states in self.split_optimizer_states(
                 sorted_id_tensor=sorted_id_tensor,
+                no_snapshot=no_snapshot,
+                should_flush=should_flush,
             )
         ]
 
@@ -1963,8 +2098,32 @@ def debug_split_embedding_weights(self) -> List[torch.Tensor]:
         return splits
 
     def clear_cache(self) -> None:
+        # clear KV ZCH cache for checkpointing
        self._cached_kvzch_data = None
 
+    @torch.jit.ignore
+    # pyre-ignore [3] - do not define snapshot class EmbeddingSnapshotHandleWrapper to avoid import dependency in other production code
+    def _may_create_snapshot_for_state_dict(
+        self,
+        no_snapshot: bool = True,
+        should_flush: bool = False,
+    ):
+        """
+        Create a rocksdb snapshot if needed.
+        """
+        # Force device synchronize for now
+        torch.cuda.synchronize()
+        snapshot_handle = None
+        if self.backend_type == BackendType.SSD:
+            # Create a rocksdb snapshot
+            if not no_snapshot:
+                # Flush L1 and L2 caches
+                self.flush(force=should_flush)
+                snapshot_handle = self.ssd_db.create_snapshot()
+        elif self.backend_type == BackendType.DRAM:
+            self.flush(force=should_flush)
+        return snapshot_handle
+
     @torch.jit.export
     def split_embedding_weights(
         self,
@@ -1994,18 +2153,10 @@ def split_embedding_weights(
             3rd arg: active id count per bucket id, tensor size is [bucket_id_end - bucket_id_start]
                      where for the i th element, we have i + bucket_id_start = global bucket id
         """
-        # Force device synchronize for now
-        torch.cuda.synchronize()
-        snapshot_handle = None
-        if self.backend_type == BackendType.SSD:
-            # Create a rocksdb snapshot
-            if not no_snapshot:
-                if should_flush:
-                    # Flush L1 and L2 caches
-                    self.flush(force=True)
-                snapshot_handle = self.ssd_db.create_snapshot()
-        elif self.backend_type == BackendType.DRAM:
-            self.flush(force=True)
+        snapshot_handle = self._may_create_snapshot_for_state_dict(
+            no_snapshot=no_snapshot,
+            should_flush=should_flush,
+        )
 
         dtype = self.weights_precision.as_dtype()
         pmt_splits = []
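Taken together, a checkpoint writer could pull per-table optimizer state through the new arguments like this. A hedged usage sketch: tbe stands for an initialized SSDTableBatchedEmbeddingBags, and sorted_ids_per_table for per-table sorted global id tensors; neither name is defined in this commit:

# Request a rocksdb snapshot (SSD backend) and flush L1/L2 caches first so the
# snapshot sees all pending writes.
opt_states = tbe.get_optimizer_state(
    sorted_id_tensor=sorted_ids_per_table,
    no_snapshot=False,
    should_flush=True,
)
for states in opt_states:
    momentum1 = states["momentum1"]  # EXACT_ROWWISE_ADAGRAD exposes only momentum1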
