@@ -78,9 +78,9 @@ class IterData:
@dataclass
class KVZCHCachedData:
-    cached_id_tensor_per_table: List[torch.Tensor]
-    cached_weight_tensor_per_table: List[torch.Tensor]
    cached_optimizer_state_per_table: List[torch.Tensor]
+    cached_weight_tensor_per_table: List[torch.Tensor]
+    cached_id_tensor_per_table: List[torch.Tensor]
    cached_bucket_splits: List[torch.Tensor]
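
Note on the field reorder above: any caller constructing KVZCHCachedData positionally must follow the new order (optimizer state, weights, ids, bucket splits). A minimal keyword-argument sketch, with made-up per-table shapes (`specs` and the zero-filled tensors are illustrative, not from this diff):

import torch

# assuming KVZCHCachedData is importable from this module
specs = [(100, 8), (200, 16)]  # illustrative (num_rows, emb_dim) per table
cached = KVZCHCachedData(
    cached_optimizer_state_per_table=[torch.zeros(n) for n, _ in specs],
    cached_weight_tensor_per_table=[torch.zeros(n, d) for n, d in specs],
    cached_id_tensor_per_table=[torch.zeros(n, dtype=torch.int64) for n, _ in specs],
    cached_bucket_splits=[torch.empty(0, dtype=torch.int64) for _ in specs],
)

Using keyword arguments keeps callers insensitive to this kind of reorder.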
@@ -175,11 +175,14 @@ def __init__(
    ) -> None:
        super(SSDTableBatchedEmbeddingBags, self).__init__()

+        # Set the optimizer
        assert optimizer in (
            OptimType.EXACT_ROWWISE_ADAGRAD,
        ), f"Optimizer {optimizer} is not supported by SSDTableBatchedEmbeddingBags"
        self.optimizer = optimizer
+        self.optimizer_dtype: torch.dtype = torch.float32

+        # Set the table weight and output dtypes
        assert weights_precision in (SparseType.FP32, SparseType.FP16)
        self.weights_precision = weights_precision
        self.output_dtype: int = output_dtype.as_int()
@@ -702,7 +705,9 @@ def __init__(
        momentum1_offsets = [0] + list(itertools.accumulate(rows))
        self._apply_split(
            SplitState(
-                dev_size=self.total_hash_size,
+                dev_size=(
+                    self.total_hash_size if not self.enable_optimizer_offloading else 0
+                ),
                host_size=0,
                uvm_size=0,
                placements=[EmbeddingLocation.DEVICE for _ in range(T_)],
@@ -1720,6 +1725,7 @@ def forward(
        batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
        # pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
    ) -> Tensor:
+        self.clear_cache()
        indices, offsets, per_sample_weights, vbe_metadata = self.prepare_inputs(
            indices, offsets, per_sample_weights, batch_size_per_feature_per_rank
        )
@@ -1881,6 +1887,8 @@ def debug_split_optimizer_states(self) -> List[Tuple[torch.Tensor, int, int]]:
    def split_optimizer_states(
        self,
        sorted_id_tensor: Optional[List[torch.Tensor]] = None,
+        no_snapshot: bool = True,
+        should_flush: bool = False,
    ) -> List[torch.Tensor]:
        """
        Returns a list of optimizer states split by table. So far, we only support EXACT_ROWWISE_ADAGRAD,
@@ -1897,14 +1905,166 @@ def split_optimizer_states(
        id consistency between weight and optimizer states.

        """
-        raise NotImplementedError(
-            "split_optimizer_states is not implemented for SSDTableBatchedEmbeddingBags"
+
+        logging.info(f"split_optimizer_states: {no_snapshot=}, {should_flush=}")
+        start_time = time.time()
+        torch.cuda.synchronize()
+
+        (rows, _) = zip(*self.embedding_specs)
+
+        rows_cumsum = [0] + list(itertools.accumulate(rows))
+        if not self.kv_zch_params:
+            logging.info(
+                f"non KV ZCH tables split_optimizer_states query latency: {(time.time() - start_time) * 1000} ms"
+            )
+            return [
+                self.momentum1_dev.detach()[rows_cumsum[t] : rows_cumsum[t + 1]].view(
+                    row
+                )
+                for t, row in enumerate(rows)
+            ]
+
+        # With optimizer state offloading, we need to query optimizer states from the
+        # backend, so create a snapshot for the SSD backend first.
+        snapshot_handle = None
+        if self.backend_type == BackendType.SSD:
+            # Create a rocksdb snapshot
+            if not no_snapshot:
+                if should_flush:
+                    # Flush L1 and L2 caches
+                    self.flush(force=True)
+                    logging.info(
+                        f"flushed L1 and L2 caches for optimizer state, latency: {(time.time() - start_time) * 1000} ms"
+                    )
+                snapshot_handle = self.ssd_db.create_snapshot()
+                logging.info(f"created snapshot for optimizer state: {snapshot_handle}")
+        elif self.backend_type == BackendType.DRAM:
+            self.flush(force=True)
+
+        opt_list = []
+        table_offset = 0
+        if self.load_state_dict:
+            # init for checkpoint loading
+            assert (
+                self._cached_kvzch_data is not None
+                and self._cached_kvzch_data.cached_optimizer_state_per_table is not None
+            ), "optimizer state is not initialized for load checkpointing"
+            return self._cached_kvzch_data.cached_optimizer_state_per_table
+
+        dtype = self.weights_precision.as_dtype()
+        optimizer_dim = self.optimizer.state_size_dim(dtype)
+        pad4_optimizer_dim = pad4(optimizer_dim)
+        logging.info(
+            f"split_optimizer_states: {optimizer_dim=} {pad4_optimizer_dim=} {self.optimizer_dtype=}"
+        )
+
+        for t, (emb_height, emb_dim) in enumerate(self.embedding_specs):
+            # pyre-ignore
+            bucket_id_start, _ = self.kv_zch_params.bucket_offsets[t]
+            # pyre-ignore
+            bucket_size = self.kv_zch_params.bucket_sizes[t]
+            row_offset = table_offset
+
+            if not self.enable_optimizer_offloading:
+                if sorted_id_tensor is None or sorted_id_tensor[t].numel() == 0:
+                    opt_list.append(
+                        self.momentum1_dev.detach()
+                        .cpu()[0:1]
+                        .view(-1)  # dummy tensor for module initialization
+                    )
+                elif all(sorted_id_tensor[t] == 0):
+                    # all ids are 0, which means it's the dummy id tensor with the correct shape, used only for loading the checkpoint
+                    opt_list.append(
+                        (
+                            self.momentum1_dev.detach().cpu()  # the shape should be correct at this point
+                        )
+                    )
+                else:
+                    # convert global id back to local id, then linearize with table offset
+                    local_id_tensor = (
+                        sorted_id_tensor[t]
+                        - bucket_id_start * bucket_size
+                        + table_offset
+                    )
+                    opt_list.append(
+                        self.momentum1_dev.detach().cpu()[local_id_tensor].view(-1),
+                    )
+            else:
+                if sorted_id_tensor is None or sorted_id_tensor[t].numel() == 0:
+                    opt_list.append(
+                        torch.empty(
+                            1,
+                            dtype=torch.float32,
+                            device="cpu",
+                        )  # dummy tensor for module initialization
+                    )
+                else:
+                    emb_opt_dim = pad4(emb_dim) + pad4_optimizer_dim
+                    row_offset = table_offset - (bucket_id_start * bucket_size)
+                    # Use KVTensorWrapper to query the backend to avoid OOM: the backend
+                    # returns both weight and optimizer state in one tensor, and reading
+                    # the whole tensor out at once could OOM CPU memory.
+                    tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
+                        shape=[emb_height, emb_opt_dim],
+                        dtype=dtype,
+                        row_offset=row_offset,
+                        snapshot_handle=snapshot_handle,
+                        materialized_shape=([sorted_id_tensor[t].size(0), emb_opt_dim]),
+                        sorted_indices=sorted_id_tensor[t],
+                    )
+                    (
+                        tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                        if self.backend_type == BackendType.SSD
+                        else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+                    )
+                    opt_list.append(
+                        self.get_offloaded_optimizer_states(
+                            tensor_wrapper=tensor_wrapper,
+                            row=sorted_id_tensor[t].size(
+                                0
+                            ),  # we only need to copy the size of sorted_id_tensor
+                            optimizer_dim=optimizer_dim,
+                            start_dim_pos=pad4(emb_dim),
+                        )
+                    )
+            table_offset += emb_height
+        logging.info(
+            f"KV ZCH tables split_optimizer_states query latency: {(time.time() - start_time) * 1000} ms"
        )
+        return opt_list
+
+    @torch.jit.export
+    def get_offloaded_optimizer_states(
+        self,
+        # pyre-ignore [11]: Annotation `KVTensorWrapper` is not defined as a type.
+        tensor_wrapper: torch.classes.fbgemm.KVTensorWrapper,
+        row: int,
+        optimizer_dim: int,
+        start_dim_pos: int,
+    ) -> torch.Tensor:
+        weight_dtype = self.weights_precision.as_dtype()
+        opt_state_t = torch.empty(
+            row, optimizer_dim, dtype=weight_dtype, device="cpu"
+        )  # 1D optimizer for OptimType.EXACT_ROWWISE_ADAGRAD
+
+        # pyre-ignore [16]
+        chunk_size = self.kv_zch_params.streaming_ckpt_chunk_size
+        for i in range(0, row, chunk_size):
+            length = min(chunk_size, row - i)
+            opt_state_t.narrow(0, i, length).copy_(
+                tensor_wrapper.narrow(0, i, length).narrow(
+                    1, start_dim_pos, optimizer_dim
+                )
+            )
+        # view optimizer state back to the correct dtype
+        return opt_state_t.view(-1).view(self.optimizer_dtype)

    @torch.jit.export
    def get_optimizer_state(
        self,
        sorted_id_tensor: Optional[List[torch.Tensor]],
+        no_snapshot: bool = True,
+        should_flush: bool = False,
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Returns a list of optimizer states split by table. So far, we only support EXACT_ROWWISE_ADAGRAD
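
To make the id arithmetic in the non-offloading branch above concrete, here is a small worked sketch; the bucket layout and all numbers are illustrative assumptions, not values from this diff:

import torch

# Assume table t owns buckets starting at bucket_id_start = 3, each holding
# bucket_size = 10 ids, and that its rows start at table_offset = 500 in the
# linearized momentum1 buffer (all values made up for illustration).
bucket_id_start, bucket_size, table_offset = 3, 10, 500

# Global ids owned by this table therefore start at 3 * 10 = 30.
sorted_global_ids = torch.tensor([30, 31, 37])

# Same formula as in split_optimizer_states: subtract the table's first global
# id to get a local row index, then add the table's offset into momentum1_dev.
local_id_tensor = sorted_global_ids - bucket_id_start * bucket_size + table_offset
print(local_id_tensor)  # tensor([500, 501, 507])

In the offloading branch the same shift shows up as `row_offset = table_offset - (bucket_id_start * bucket_size)`, which is handed to KVTensorWrapper together with `sorted_indices` instead of being applied on the host.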
@@ -1914,6 +2074,8 @@ def get_optimizer_state(
            ({"momentum1": states})
            for states in self.split_optimizer_states(
                sorted_id_tensor=sorted_id_tensor,
+                no_snapshot=no_snapshot,
+                should_flush=should_flush,
            )
        ]
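
A hedged usage sketch of the extended get_optimizer_state API; `emb` and `sorted_ids_per_table` are assumed to exist and are not defined in this diff:

# Take a consistent read of optimizer state for checkpointing: flush L1/L2
# caches, create a rocksdb snapshot (no_snapshot=False), then read the
# per-table momentum1 states returned as [{"momentum1": tensor}, ...].
opt_states = emb.get_optimizer_state(
    sorted_id_tensor=sorted_ids_per_table,  # per-table sorted global ids (or None)
    no_snapshot=False,
    should_flush=True,
)
momentum1_for_table_0 = opt_states[0]["momentum1"]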
@@ -1963,6 +2125,7 @@ def debug_split_embedding_weights(self) -> List[torch.Tensor]:
        return splits

    def clear_cache(self) -> None:
+        # clear KV ZCH cache for checkpointing
        self._cached_kvzch_data = None

    @torch.jit.export