@@ -50,17 +50,13 @@
 from fbgemm_gpu.split_table_batched_embeddings_ops_training_common import (
     generate_vbe_metadata,
 )
-
 from torch import distributed as dist, nn, Tensor  # usort:skip
 from dataclasses import dataclass
 
-from fbgemm_gpu.tbe.ssd.common import tensor_pad4
-
 from torch.autograd.profiler import record_function
 
 from ..cache import get_unique_indices_v2
-
-from .common import ASSOC, pad4
+from .common import ASSOC, pad4, tensor_pad4
 from .utils.partially_materialized_tensor import PartiallyMaterializedTensor
 
 
@@ -78,9 +74,9 @@ class IterData:
 
 @dataclass
 class KVZCHCachedData:
-    cached_id_tensor_per_table: List[torch.Tensor]
-    cached_weight_tensor_per_table: List[torch.Tensor]
     cached_optimizer_state_per_table: List[torch.Tensor]
+    cached_weight_tensor_per_table: List[torch.Tensor]
+    cached_id_tensor_per_table: List[torch.Tensor]
     cached_bucket_splits: List[torch.Tensor]
 
 
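For reference, a minimal sketch (not part of this diff) of populating the reordered KVZCHCachedData container; the table count and tensor shapes are illustrative assumptions. Because the fields are passed by keyword here, the reordering above only matters for positional construction and for the order in which checkpoint-loading code fills the lists.

    import torch

    # Assumes KVZCHCachedData from the module above is importable.
    num_tables = 2  # illustrative
    cached = KVZCHCachedData(
        cached_optimizer_state_per_table=[torch.zeros(8) for _ in range(num_tables)],
        cached_weight_tensor_per_table=[torch.zeros(8, 128) for _ in range(num_tables)],
        cached_id_tensor_per_table=[torch.arange(8) for _ in range(num_tables)],
        cached_bucket_splits=[torch.tensor([8]) for _ in range(num_tables)],
    )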
@@ -175,11 +171,13 @@ def __init__(
     ) -> None:
         super(SSDTableBatchedEmbeddingBags, self).__init__()
 
+        # Set the optimizer
         assert optimizer in (
             OptimType.EXACT_ROWWISE_ADAGRAD,
         ), f"Optimizer {optimizer} is not supported by SSDTableBatchedEmbeddingBags"
         self.optimizer = optimizer
 
+        # Set the table weight and output dtypes
         assert weights_precision in (SparseType.FP32, SparseType.FP16)
         self.weights_precision = weights_precision
         self.output_dtype: int = output_dtype.as_int()
@@ -702,7 +700,9 @@ def __init__(
         momentum1_offsets = [0] + list(itertools.accumulate(rows))
         self._apply_split(
             SplitState(
-                dev_size=self.total_hash_size,
+                dev_size=(
+                    self.total_hash_size if not self.enable_optimizer_offloading else 0
+                ),
                 host_size=0,
                 uvm_size=0,
                 placements=[EmbeddingLocation.DEVICE for _ in range(T_)],
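The change above allocates no device memory for momentum1 when optimizer state is offloaded to the key-value backend. A standalone sketch of that sizing rule (total_hash_size and enable_optimizer_offloading mirror the attribute names in the diff; the surrounding SplitState plumbing is omitted):

    def momentum1_dev_size(total_hash_size: int, enable_optimizer_offloading: bool) -> int:
        # With offloading, momentum1 lives in the SSD/DRAM backend, so the
        # device-side buffer is sized to zero instead of the full hash size.
        return 0 if enable_optimizer_offloading else total_hash_size

    assert momentum1_dev_size(1024, enable_optimizer_offloading=False) == 1024
    assert momentum1_dev_size(1024, enable_optimizer_offloading=True) == 0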
@@ -1720,6 +1720,7 @@ def forward(
         batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
         # pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
     ) -> Tensor:
+        self.clear_cache()
         indices, offsets, per_sample_weights, vbe_metadata = self.prepare_inputs(
             indices, offsets, per_sample_weights, batch_size_per_feature_per_rank
         )
@@ -1877,10 +1878,30 @@ def debug_split_optimizer_states(self) -> List[Tuple[torch.Tensor, int, int]]:
             for t, row in enumerate(rows)
         ]
 
+    @torch.jit.ignore
+    def _split_optimizer_states_non_kv_zch(
+        self,
+    ) -> List[torch.Tensor]:
+        """
+        Returns a list of optimizer states, split by table. So far, we only support EXACT_ROWWISE_ADAGRAD,
+        so only momentum1 state is returned.
+        """
+        logging.info("_split_optimizer_states_non_kv_zch")
+        (rows, _) = zip(*self.embedding_specs)
+
+        rows_cumsum = [0] + list(itertools.accumulate(rows))
+
+        return [
+            self.momentum1_dev.detach()[rows_cumsum[t] : rows_cumsum[t + 1]].view(row)
+            for t, row in enumerate(rows)
+        ]
+
     @torch.jit.export
     def split_optimizer_states(
         self,
         sorted_id_tensor: Optional[List[torch.Tensor]] = None,
+        no_snapshot: bool = True,
+        should_flush: bool = False,
     ) -> List[torch.Tensor]:
         """
         Returns a list of optimizer states split by table. So far, we only support EXACT_ROWWISE_ADAGRAD,
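A self-contained sketch of the rows_cumsum slicing performed by _split_optimizer_states_non_kv_zch above, with a flat tensor standing in for momentum1_dev (the per-table row counts are illustrative):

    import itertools
    import torch

    rows = (4, 6, 2)  # rows per table, illustrative
    momentum1_dev = torch.arange(sum(rows), dtype=torch.float32)  # flat rowwise state

    rows_cumsum = [0] + list(itertools.accumulate(rows))
    per_table = [
        momentum1_dev[rows_cumsum[t] : rows_cumsum[t + 1]].view(row)
        for t, row in enumerate(rows)
    ]
    assert [t.numel() for t in per_table] == list(rows)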
@@ -1897,14 +1918,126 @@ def split_optimizer_states(
                 id consistency between weight and optimizer states.
 
         """
-        raise NotImplementedError(
-            "split_optimizer_states is not implemented for SSDTableBatchedEmbeddingBags"
+
+        if not self.kv_zch_params:
+            return self._split_optimizer_states_non_kv_zch()
+
+        if self.load_state_dict:
+            # init for checkpoint loading
+            assert (
+                self._cached_kvzch_data is not None
+                and self._cached_kvzch_data.cached_optimizer_state_per_table is not None
+            ), "optimizer state is not initialized for load checkpointing"
+            return self._cached_kvzch_data.cached_optimizer_state_per_table
+
+        logging.info(
+            f"split_optimizer_states for KV ZCH: {no_snapshot=}, {should_flush=}"
+        )
+        start_time = time.time()
+        snapshot_handle = self._may_create_snapshot_for_state_dict(
+            no_snapshot=no_snapshot,
+            should_flush=should_flush,
         )
 
+        opt_list = []
+        table_offset = 0
+
+        dtype = self.weights_precision.as_dtype()
+        optimizer_dim = self.optimizer.state_size_dim(dtype)
+        pad4_optimizer_dim = pad4(optimizer_dim)
+        logging.info(
+            f"split_optimizer_states: {optimizer_dim=} {pad4_optimizer_dim=} {self.optimizer.dtype()=} {self.enable_load_state_dict_mode=}"
+        )
+
+        for t, (emb_height, emb_dim) in enumerate(self.embedding_specs):
+            # pyre-ignore
+            bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
+            # pyre-ignore
+            bucket_size = self.kv_zch_params.bucket_sizes[t]
+            row_offset = table_offset
+            if sorted_id_tensor is None or sorted_id_tensor[t].numel() == 0:
+                opt_list.append(
+                    torch.empty(0, dtype=self.optimizer.dtype(), device="cpu")
+                    # empty optimizer state for module initialization
+                )
+            else:
+                if not self.enable_optimizer_offloading:
+                    # convert global id back to local id, then linearize with table offset
+                    local_id_tensor = (
+                        sorted_id_tensor[t]
+                        - bucket_id_start * bucket_size
+                        + table_offset
+                    )
+                    opt_list.append(
+                        self.momentum1_dev.detach().cpu()[local_id_tensor].view(-1),
+                    )
+                else:
+                    emb_opt_dim = pad4(emb_dim) + pad4_optimizer_dim
+                    row_offset = table_offset - (bucket_id_start * bucket_size)
+                    # Use KVTensorWrapper to query the backend lazily: the backend returns
+                    # both weight and optimizer state in one tensor per row, and reading the
+                    # whole tensor out at once could OOM the CPU.
+                    tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
+                        shape=[emb_height, emb_opt_dim],
+                        dtype=dtype,
+                        row_offset=row_offset,
+                        snapshot_handle=snapshot_handle,
+                        materialized_shape=([sorted_id_tensor[t].size(0), emb_opt_dim]),
+                        sorted_indices=sorted_id_tensor[t],
+                    )
+                    (
+                        tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                        if self.backend_type == BackendType.SSD
+                        else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+                    )
+                    opt_list.append(
+                        self.get_offloaded_optimizer_states(
+                            tensor_wrapper=tensor_wrapper,
+                            row=sorted_id_tensor[t].size(
+                                0
+                            ),  # we only need to copy as many rows as sorted_id_tensor has
+                            optimizer_dim=optimizer_dim,
+                            start_dim_pos=pad4(emb_dim),
+                        )
+                    )
+            table_offset += emb_height
+        logging.info(
+            f"KV ZCH tables split_optimizer_states query latency: {(time.time() - start_time) * 1000} ms"
+        )
+        return opt_list
+
+    @torch.jit.export
+    def get_offloaded_optimizer_states(
+        self,
+        # pyre-ignore [2]
+        tensor_wrapper,
+        row: int,
+        optimizer_dim: int,
+        start_dim_pos: int,
+    ) -> torch.Tensor:
+        weight_dtype = self.weights_precision.as_dtype()
+        opt_state_t = torch.empty(
+            row, optimizer_dim, dtype=weight_dtype, device="cpu"
+        )  # 1D optimizer state for OptimType.EXACT_ROWWISE_ADAGRAD
+
+        # pyre-ignore [16]
+        chunk_size = self.kv_zch_params.streaming_ckpt_chunk_size
+        for i in range(0, row, chunk_size):
+            length = min(chunk_size, row - i)
+            opt_state_t.narrow(0, i, length).copy_(
+                tensor_wrapper.narrow(0, i, length).narrow(
+                    1, start_dim_pos, optimizer_dim
+                )
+            )
+        # view optimizer state back to correct dtype
+        return opt_state_t.view(-1).view(self.optimizer.dtype())
+
     @torch.jit.export
     def get_optimizer_state(
         self,
         sorted_id_tensor: Optional[List[torch.Tensor]],
+        no_snapshot: bool = True,
+        should_flush: bool = False,
     ) -> List[Dict[str, torch.Tensor]]:
         """
         Returns a list of optimizer states split by table. So far, we only support EXACT_ROWWISE_ADAGRAD
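To illustrate the chunked read in get_offloaded_optimizer_states without a RocksDB or DRAM backend, here is a minimal sketch in plain PyTorch; the backed_rows tensor stands in for what KVTensorWrapper would materialize (per row: pad4(emb_dim) weight columns followed by optimizer columns), and the shapes and chunk size are illustrative assumptions rather than values from this diff.

    import torch

    emb_dim_padded, optimizer_dim = 128, 4
    backed_rows = torch.randn(10, emb_dim_padded + optimizer_dim)  # [weight | optimizer]

    row, start_dim_pos, chunk_size = backed_rows.size(0), emb_dim_padded, 3
    opt_state_t = torch.empty(row, optimizer_dim)

    # Copy only the optimizer columns, a chunk of rows at a time, to bound peak CPU memory.
    for i in range(0, row, chunk_size):
        length = min(chunk_size, row - i)
        opt_state_t.narrow(0, i, length).copy_(
            backed_rows.narrow(0, i, length).narrow(1, start_dim_pos, optimizer_dim)
        )

    assert torch.equal(opt_state_t, backed_rows[:, start_dim_pos:])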
@@ -1914,6 +2047,8 @@ def get_optimizer_state(
             ({"momentum1": states})
             for states in self.split_optimizer_states(
                 sorted_id_tensor=sorted_id_tensor,
+                no_snapshot=no_snapshot,
+                should_flush=should_flush,
             )
         ]
 
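A hedged usage sketch (not from this diff) of how checkpointing code might consume the new keyword arguments; tbe is assumed to be an already constructed SSDTableBatchedEmbeddingBags, and the id tensors are whatever the weight state_dict pass produced.

    from typing import Dict, List
    import torch

    def read_momentum1_for_checkpoint(
        tbe,  # assumed SSDTableBatchedEmbeddingBags instance
        sorted_ids_per_table: List[torch.Tensor],
    ) -> List[torch.Tensor]:
        # Request a rocksdb snapshot (and a cache flush) so the reads are consistent.
        states: List[Dict[str, torch.Tensor]] = tbe.get_optimizer_state(
            sorted_id_tensor=sorted_ids_per_table,
            no_snapshot=False,
            should_flush=True,
        )
        return [state["momentum1"] for state in states]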
@@ -1963,8 +2098,32 @@ def debug_split_embedding_weights(self) -> List[torch.Tensor]:
         return splits
 
     def clear_cache(self) -> None:
+        # clear KV ZCH cache for checkpointing
        self._cached_kvzch_data = None
 
+    @torch.jit.ignore
+    # pyre-ignore [3] - do not define snapshot class EmbeddingSnapshotHandleWrapper to avoid import dependency in other production code
+    def _may_create_snapshot_for_state_dict(
+        self,
+        no_snapshot: bool = True,
+        should_flush: bool = False,
+    ):
+        """
+        Create a rocksdb snapshot if needed.
+        """
+        # Force device synchronize for now
+        torch.cuda.synchronize()
+        snapshot_handle = None
+        if self.backend_type == BackendType.SSD:
+            # Create a rocksdb snapshot
+            if not no_snapshot:
+                # Flush L1 and L2 caches
+                self.flush(force=should_flush)
+                snapshot_handle = self.ssd_db.create_snapshot()
+        elif self.backend_type == BackendType.DRAM:
+            self.flush(force=should_flush)
+        return snapshot_handle
+
 
     @torch.jit.export
     def split_embedding_weights(
         self,
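A standalone sketch of the snapshot decision that _may_create_snapshot_for_state_dict centralizes above; the stub database and flush callable are illustrative stand-ins for ssd_db and self.flush, while the BackendType names mirror the diff.

    from enum import Enum

    class BackendType(Enum):  # stand-in mirroring the names used in the diff
        SSD = 0
        DRAM = 1

    class _StubDB:
        def create_snapshot(self) -> str:
            return "snapshot-handle"  # placeholder for the real rocksdb snapshot handle

    def may_create_snapshot(backend_type, no_snapshot, should_flush, ssd_db, flush):
        snapshot_handle = None
        if backend_type == BackendType.SSD:
            if not no_snapshot:
                flush(force=should_flush)  # flush L1/L2 caches before snapshotting
                snapshot_handle = ssd_db.create_snapshot()
        elif backend_type == BackendType.DRAM:
            flush(force=should_flush)  # DRAM backend only needs a flush
        return snapshot_handle

    handle = may_create_snapshot(
        BackendType.SSD, no_snapshot=False, should_flush=True,
        ssd_db=_StubDB(), flush=lambda force: None,
    )
    assert handle == "snapshot-handle"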
@@ -1994,18 +2153,10 @@ def split_embedding_weights(
             3rd arg: active id count per bucket id, tensor size is [bucket_id_end - bucket_id_start]
                 where for the i th element, we have i + bucket_id_start = global bucket id
         """
-        # Force device synchronize for now
-        torch.cuda.synchronize()
-        snapshot_handle = None
-        if self.backend_type == BackendType.SSD:
-            # Create a rocksdb snapshot
-            if not no_snapshot:
-                if should_flush:
-                    # Flush L1 and L2 caches
-                    self.flush(force=True)
-                snapshot_handle = self.ssd_db.create_snapshot()
-        elif self.backend_type == BackendType.DRAM:
-            self.flush(force=True)
+        snapshot_handle = self._may_create_snapshot_for_state_dict(
+            no_snapshot=no_snapshot,
+            should_flush=should_flush,
+        )
 
         dtype = self.weights_precision.as_dtype()
         pmt_splits = []
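Putting the pieces together, a loosely sketched checkpoint-save flow (not from this diff) that passes the same snapshot arguments to both the weight and optimizer reads; tbe is an assumed SSDTableBatchedEmbeddingBags instance, and the assumption that each entry returned by split_embedding_weights carries its id tensor as the second element is illustrative, based only on the docstring above.

    def save_kv_zch_checkpoint(tbe):
        # Reading weights may flush caches and create a rocksdb snapshot.
        weights_per_table = tbe.split_embedding_weights(no_snapshot=False, should_flush=True)

        # Reuse the id tensors from the weight pass so optimizer rows line up
        # with the same ids (see the split_optimizer_states docstring).
        sorted_ids = [entry[1] for entry in weights_per_table]  # assumed layout
        opt_states = tbe.get_optimizer_state(
            sorted_id_tensor=sorted_ids,
            no_snapshot=False,
            should_flush=False,  # caches were already flushed for the weight snapshot
        )
        return weights_per_table, opt_states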