@@ -318,7 +318,7 @@ def __init__(
             f"weights precision: {weights_precision}, "
             f"output dtype: {output_dtype}, "
             f"chunk size in bulk init: {bulk_init_chunk_size} bytes, backend_type: {backend_type}, "
-            f"zero_collision_config: {kv_zch_params}"
+            f"kv_zch_params: {kv_zch_params}"
         )
         self.register_buffer(
             "lxu_cache_state",
@@ -2009,7 +2009,6 @@ def split_optimizer_states(
                 dtype=dtype,
                 row_offset=row_offset,
                 snapshot_handle=snapshot_handle,
-                materialized_shape=([sorted_id_tensor[t].size(0), emb_opt_dim]),
                 sorted_indices=sorted_id_tensor[t],
             )
             (
@@ -2134,7 +2133,7 @@ def split_embedding_weights(
         no_snapshot: bool = True,
         should_flush: bool = False,
     ) -> Tuple[  # TODO: make this a NamedTuple for readability
-        List[PartiallyMaterializedTensor],
+        List[PartiallyMaterializedTensor] | List[torch.Tensor],
         Optional[List[torch.Tensor]],
         Optional[List[torch.Tensor]],
     ]:
@@ -2157,6 +2156,7 @@ def split_embedding_weights(
             3rd arg: active id count per bucket id, tensor size is [bucket_id_end - bucket_id_start],
             where for the i-th element, i + bucket_id_start = global bucket id
         """
+        start_time = time.time()
         # Force device synchronize for now
         torch.cuda.synchronize()
         snapshot_handle = None
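The docstring's bucket arithmetic is easy to misread, so here is a tiny illustration. This is a standalone sketch with made-up values, not code from the patch: it only shows how a caller would map positions in the third return value back to global bucket ids.

```python
# Illustrative only: mapping positions in the "active id count" tensor
# back to global bucket ids, per the docstring above.
bucket_id_start, bucket_id_end = 4, 8   # this rank owns global buckets [4, 8)
active_id_cnt = [7, 0, 3, 12]           # 3rd return value, one entry per local bucket

for i, cnt in enumerate(active_id_cnt):
    global_bucket_id = i + bucket_id_start  # i-th element -> global bucket id
    print(f"global bucket {global_bucket_id}: {cnt} active ids")
```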
@@ -2166,11 +2166,28 @@ def split_embedding_weights(
             if should_flush:
                 # Flush L1 and L2 caches
                 self.flush(force=True)
+                logging.info(
+                    f"flush latency for weight states: {(time.time() - start_time) * 1000} ms"
+                )
             snapshot_handle = self.ssd_db.create_snapshot()
+            logging.info(
+                f"created snapshot for weight states: {snapshot_handle}, latency: {(time.time() - start_time) * 1000} ms"
+            )
         elif self.backend_type == BackendType.DRAM:
             self.flush(force=True)

         dtype = self.weights_precision.as_dtype()
+        if self.load_state_dict and self.kv_zch_params:
+            # Init for checkpoint loading
+            assert (
+                self._cached_kvzch_data is not None
+            ), "weight id and bucket state are not initialized for checkpoint loading"
+            return (
+                self._cached_kvzch_data.cached_weight_tensor_per_table,
+                self._cached_kvzch_data.cached_id_tensor_per_table,
+                self._cached_kvzch_data.cached_bucket_splits,
+            )
+
         pmt_splits = []
         bucket_sorted_id_splits = [] if self.kv_zch_params else None
         active_id_cnt_per_bucket_split = [] if self.kv_zch_params else None
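Note the widened return type: in load_state_dict mode the method short-circuits and returns the cached plain `torch.Tensor`s instead of `PartiallyMaterializedTensor`s. A hedged call-site sketch (the `tbe` handle is assumed; this is not code from the patch):

```python
import torch

# `tbe` is assumed to be an SSD TBE module exposing split_embedding_weights().
weights, ids, bucket_splits = tbe.split_embedding_weights(no_snapshot=False)

for t, w in enumerate(weights):
    if isinstance(w, torch.Tensor):
        # Checkpoint-loading path: a fully materialized in-memory staging tensor.
        print(f"table {t}: cached tensor, shape {tuple(w.shape)}")
    else:
        # Normal path: a PartiallyMaterializedTensor backed by the KV store.
        print(f"table {t}: partially materialized tensor")
```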
@@ -2179,18 +2196,15 @@ def split_embedding_weights(
         for i, (emb_height, emb_dim) in enumerate(self.embedding_specs):
             bucket_ascending_id_tensor = None
             bucket_t = None
+            row_offset = table_offset
             if self.kv_zch_params:
                 bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
                 # pyre-ignore
                 bucket_size = self.kv_zch_params.bucket_sizes[i]

                 # linearize with table offset
-                table_input_id_start = (
-                    min(bucket_id_start * bucket_size, emb_height) + table_offset
-                )
-                table_input_id_end = (
-                    min(bucket_id_end * bucket_size, emb_height) + table_offset
-                )
+                table_input_id_start = table_offset
+                table_input_id_end = table_offset + emb_height
                 # 1. get all keys from backend for one table
                 unordered_id_tensor = self._ssd_db.get_keys_in_range_by_snapshot(
                     table_input_id_start,
@@ -2203,15 +2217,38 @@ def split_embedding_weights(
                     torch.ops.fbgemm.get_bucket_sorted_indices_and_bucket_tensor(
                         unordered_id_tensor,
                         0,  # id--bucket hashing mode, 0 for chunk-based hashing, 1 for interleave-based hashing
-                        bucket_id_start,
-                        bucket_id_end,
+                        0,  # local bucket offset
+                        bucket_id_end - bucket_id_start,  # local bucket num
                         bucket_size,
                     )
                 )
-                # pyre-ignore
+                # 3. convert local id back to global id
+                bucket_ascending_id_tensor.add_(bucket_id_start * bucket_size)
+
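The call now sorts in local bucket space (offset 0, `bucket_id_end - bucket_id_start` buckets) and the `add_` converts the sorted ids back to global space. A minimal sketch of the arithmetic this relies on, assuming chunk-based hashing (mode 0) means `bucket = id // bucket_size`:

```python
# Illustrative arithmetic only; assumes chunk-based hashing maps an id to its
# bucket via integer division by bucket_size.
bucket_size = 100
bucket_id_start, bucket_id_end = 4, 8     # this rank owns global buckets [4, 8)

global_id = 512
global_bucket = global_id // bucket_size        # 5
local_bucket = global_bucket - bucket_id_start  # 1, within [0, 4)

# Sorting happens on local ids; add_(bucket_id_start * bucket_size) undoes the shift.
local_id = global_id - bucket_id_start * bucket_size  # 112
assert local_id + bucket_id_start * bucket_size == global_id
```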
+                if (
+                    bucket_ascending_id_tensor.size(0) == 0
+                    and self.local_weight_counts[i] > 0
+                ):
+                    logging.info(
+                        f"resetting bucket id tensor with {self.local_weight_counts[i]}"
+                    )
+                    bucket_ascending_id_tensor = torch.zeros(
+                        (self.local_weight_counts[i], 1),
+                        device=torch.device("cpu"),
+                        dtype=torch.int64,
+                    )
+                    # self.local_weight_counts[i] = 0  # Reset the count
+
+                # pyre-ignore [16] bucket_sorted_id_splits is not None
                 bucket_sorted_id_splits.append(bucket_ascending_id_tensor)
                 active_id_cnt_per_bucket_split.append(bucket_t)
+                # For the KV ZCH TBE, sorted_indices holds global ids for checkpointing
+                # and publishing, but the backend uses local ids during training, so
+                # KVTensorWrapper needs to convert each global id to a local id first,
+                # then linearize the local id with the table offset; the formula is
+                # x + table_offset - local_shard_offset. To achieve this, row_offset is
+                # set to (table_offset - local_shard_offset).
+                row_offset = table_offset - (bucket_id_start * bucket_size)
+
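The row_offset comment is worth a worked example. With hypothetical numbers, the wrapper's mapping `x + row_offset` sends a global id to the correct linearized local row:

```python
# Worked example of row_offset = table_offset - (bucket_id_start * bucket_size),
# with hypothetical values.
table_offset = 1000         # total rows of all preceding tables
bucket_id_start = 4
bucket_size = 100           # this local shard starts at global id 400

row_offset = table_offset - bucket_id_start * bucket_size  # 600

global_id = 512             # a global id owned by this shard
linearized_row = global_id + row_offset  # 1112
# Same as converting to a local id first, then adding the table offset:
assert linearized_row == (global_id - 400) + table_offset
```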
             tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
                 shape=[
                     (
@@ -2222,33 +2259,184 @@ def split_embedding_weights(
                     emb_dim,
                 ],
                 dtype=dtype,
-                row_offset=table_offset,
+                row_offset=row_offset,
                 snapshot_handle=snapshot_handle,
+                # Pass bucket_ascending_id_tensor to the kvt wrapper so that narrow()
+                # returns embedding weights in the sorted id order.
                 sorted_indices=(
                     bucket_ascending_id_tensor if self.kv_zch_params else None
                 ),
             )
-            # TODO add if else support in the future for dram integration.
-            tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+            (
+                tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                if self.backend_type == BackendType.SSD
+                else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+            )
             table_offset += emb_height
             pmt_splits.append(
                 PartiallyMaterializedTensor(
                     tensor_wrapper,
                     True if self.kv_zch_params else False,
                 )
             )
+        logging.info(
+            f"split_embedding_weights latency: {(time.time() - start_time) * 1000} ms"
+        )
         return (pmt_splits, bucket_sorted_id_splits, active_id_cnt_per_bucket_split)

     @torch.jit.ignore
     def apply_state_dict(self) -> None:
         # After checkpoint loading, the _cached_kvzch_data will be loaded from checkpoint.
         # Caller should call this function to apply the cached states to backend.
-        pass
+        if self.load_state_dict is False:
+            return
+        self.load_state_dict = False
+        assert self.kv_zch_params is not None, "apply_state_dict supports KV ZCH only"
+        assert (
+            self._cached_kvzch_data is not None
+            and self._cached_kvzch_data.cached_optimizer_state_per_table is not None
+        ), "optimizer state is not initialized for checkpoint loading"
+        assert (
+            self._cached_kvzch_data.cached_weight_tensor_per_table is not None
+            and self._cached_kvzch_data.cached_id_tensor_per_table is not None
+        ), "weight and id state are not initialized for checkpoint loading"
+
+        # Compute the number of elements of cache_dtype needed to store the
+        # optimizer state, rounded to the nearest 4.
+        # optimizer_dim = self.optimizer.optimizer_state_size_dim(dtype)
+        # Apply weight and optimizer state per table.
+        table_offset = 0
+        for i, (emb_height, _) in enumerate(self.embedding_specs):
+            # pyre-ignore [16]
+            bucket_id_start, _ = self.kv_zch_params.bucket_offsets[i]
+            # pyre-ignore [16]
+            bucket_size = self.kv_zch_params.bucket_sizes[i]
+            row_offset = table_offset - bucket_id_start * bucket_size
+
+            if self.enable_optimizer_offloading:
+                # pyre-ignore [16]
+                weight_state = self._cached_kvzch_data.cached_weight_tensor_per_table[i]
+                # pyre-ignore [16]
+                opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i]
+                self.streaming_write_weight_and_id_per_table(
+                    weight_state,
+                    opt_state,
+                    # pyre-ignore [16]
+                    self._cached_kvzch_data.cached_id_tensor_per_table[i],
+                    row_offset,
+                )
+                self._cached_kvzch_data.cached_weight_tensor_per_table[i] = None
+                self._cached_kvzch_data.cached_optimizer_state_per_table[i] = None
+            else:
+                weight = self._cached_kvzch_data.cached_weight_tensor_per_table[i]
+                id = self._cached_kvzch_data.cached_id_tensor_per_table[i]
+                local_id = id + row_offset
+                logging.info(
+                    f"applying sd for table {i} without optimizer offloading, local_id is {local_id}"
+                )
+                opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i]
+                t_device = self.momentum1_dev.device
+                self.momentum1_dev.index_put_(
+                    indices=(
+                        local_id.to(t_device).view(-1),
+                    ),  # expects a tuple of tensors
+                    values=opt_state.to(t_device),
+                )
+                self.ssd_db.set_cuda(
+                    local_id.view(-1),
+                    weight,
+                    torch.as_tensor(local_id.size(0)),
+                    1,
+                    False,
+                )
+            table_offset += emb_height
+        self.clear_cache()
+
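In the non-offloading branch above, the optimizer state is scattered into `momentum1_dev` via `index_put_`, which takes a tuple of index tensors (one per indexed dimension), hence the one-element tuple. A minimal standalone sketch with made-up shapes:

```python
import torch

# Stand-ins for momentum1_dev, the [N, 1] id tensor, and per-id optimizer state.
momentum = torch.zeros(10)
local_id = torch.tensor([[2], [5], [7]])
opt_state = torch.tensor([0.1, 0.2, 0.3])

# index_put_ expects a tuple of index tensors, so the single index is wrapped.
momentum.index_put_(indices=(local_id.view(-1),), values=opt_state)
print(momentum)  # rows 2, 5, and 7 now hold 0.1, 0.2, and 0.3
```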
+    @torch.jit.ignore
+    def streaming_write_weight_and_id_per_table(
+        self,
+        weight_state: torch.Tensor,
+        opt_state: torch.Tensor,
+        id_tensor: torch.Tensor,
+        row_offset: int,
+    ) -> None:
+        """
+        Write weights, optimizer state, and ids to the backend using the kvt wrapper.
+        To avoid overusing memory, the weights and ids are written to the backend
+        in a rolling-window manner.
+
+        Args:
+            weight_state (torch.tensor): The weight state tensor to be written.
+            opt_state (torch.tensor): The optimizer state tensor to be written.
+            id_tensor (torch.tensor): The id tensor to be written.
+            row_offset (int): The offset used to linearize ids before writing.
+        """
+        D_rounded = pad4(weight_state.size(1))  # padded to 4-byte alignment
+        dtype = self.weights_precision.as_dtype()
+        kvt = torch.classes.fbgemm.KVTensorWrapper(
+            db=self.ssd_db,
+            shape=[weight_state.size(0), self.cache_row_dim],
+            dtype=dtype,
+            row_offset=row_offset,
+            snapshot_handle=None,
+            sorted_indices=id_tensor,
+        )
+        # TODO: make chunk_size configurable or dynamic
+        chunk_size = 10000
+        row = weight_state.size(0)
+        optimizer_dim = self.optimizer.state_size_dim(dtype)
+        opt_state_2d = opt_state.view(dtype).view(-1, optimizer_dim)
+        for i in range(0, row, chunk_size):
+            length = min(chunk_size, row - i)
+            chunk_buffer = torch.empty(
+                length,
+                self.cache_row_dim,
+                dtype=dtype,
+                device="cpu",
+            )
+            chunk_buffer[:, : weight_state.size(1)] = weight_state[i : i + length, :]
+            chunk_buffer[:, D_rounded : D_rounded + optimizer_dim] = opt_state_2d[
+                i : i + length, :
+            ]
+            kvt.set_weights_and_ids(chunk_buffer, id_tensor[i : i + length, :].view(-1))

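Each backend row written by the loop is laid out as [weights | padding | optimizer state], with the optimizer slice starting at the 4-element-aligned column `D_rounded`. A self-contained sketch of that layout; the local `pad4` here is a stand-in for the helper the patch uses:

```python
import torch

def pad4(d: int) -> int:
    # Stand-in for the pad4 helper: round up to a multiple of 4.
    return (d + 3) // 4 * 4

emb_dim, optimizer_dim, cache_row_dim = 6, 2, 12  # illustrative sizes
D_rounded = pad4(emb_dim)  # 8: weights fill columns [0, 6), optimizer [8, 10)

rows = 3
weight_state = torch.randn(rows, emb_dim)
opt_state_2d = torch.randn(rows, optimizer_dim)

chunk_buffer = torch.empty(rows, cache_row_dim)
chunk_buffer[:, :emb_dim] = weight_state
chunk_buffer[:, D_rounded : D_rounded + optimizer_dim] = opt_state_2d
# Columns [6, 8) and [10, 12) stay as padding, mirroring the chunked loop above.
```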
     @torch.jit.ignore
     def enable_load_state_dict_mode(self) -> None:
         # Enable load state dict mode before loading checkpoint
-        pass
+        if self.load_state_dict:
+            return
+        self.load_state_dict = True
+
+        dtype = self.weights_precision.as_dtype()
+        self._cached_kvzch_data = KVZCHCachedData([], [], [], [])
+        for i, (_, emb_dim) in enumerate(self.embedding_specs):
+            # For checkpoint loading, we need to store the weight and id tensors
+            # temporarily in memory.
+            assert (
+                self.local_weight_counts[i] > 0
+            ), f"local_weight_counts for table {i} is not set"
+            # pyre-ignore [16]
+            bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
+            rows = self.local_weight_counts[i]
+            weight_state = torch.empty(rows, emb_dim, dtype=dtype, device="cpu")
+            opt_state = torch.empty(rows, dtype=torch.float32, device="cpu")
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_weight_tensor_per_table.append(weight_state)
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_optimizer_state_per_table.append(opt_state)
+            logging.info(
+                f"for checkpoint loading, table {i}, weight_state shape is {weight_state.shape}, opt_state shape is {opt_state.shape}"
+            )
+            id_tensor = torch.zeros(
+                (self.local_weight_counts[i], 1), dtype=torch.int64, device="cpu"
+            )
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_id_tensor_per_table.append(id_tensor)
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_bucket_splits.append(
+                torch.empty(
+                    (bucket_id_end - bucket_id_start, 1),
+                    dtype=torch.int64,
+                    device="cpu",
+                )
+            )
+
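Taken together, the new methods imply a three-step checkpoint-load sequence. A hedged usage sketch (`tbe` and the per-table counts are illustrative; the exact copy mechanics depend on the checkpoint loader):

```python
# Hypothetical load sequence for a KV ZCH SSD TBE module `tbe`.
# 1. Size the in-memory staging buffers, then enter load mode.
tbe.local_weight_counts = [1000, 2000]  # one count per table, from the checkpoint
tbe.enable_load_state_dict_mode()

# 2. While in load mode, split_embedding_weights() returns the cached staging
#    tensors; the checkpoint loader fills them in place.
weights, ids, buckets = tbe.split_embedding_weights(no_snapshot=False)

# 3. Push staged weights/ids/optimizer state to the backend and free the cache.
tbe.apply_state_dict()
```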
     @torch.jit.export
     def set_learning_rate(self, lr: float) -> None: