@@ -317,7 +317,7 @@ def __init__(
             f"weights precision: {weights_precision}, "
             f"output dtype: {output_dtype}, "
             f"chunk size in bulk init: {bulk_init_chunk_size} bytes, backend_type: {backend_type}, "
-            f"zero_collision_config: {kv_zch_params}"
+            f"kv_zch_params: {kv_zch_params}"
         )
         self.register_buffer(
             "lxu_cache_state",
@@ -1986,7 +1986,6 @@ def split_optimizer_states(
                 dtype=dtype,
                 row_offset=row_offset,
                 snapshot_handle=snapshot_handle,
-                materialized_shape=([sorted_id_tensor[t].size(0), emb_opt_dim]),
                 sorted_indices=sorted_id_tensor[t],
             )
             (
@@ -2115,6 +2114,7 @@ def _may_create_snapshot_for_state_dict(
         """
         Create a rocksdb snapshot if needed.
         """
+        start_time = time.time()
         # Force device synchronize for now
         torch.cuda.synchronize()
         snapshot_handle = None
@@ -2124,7 +2124,13 @@ def _may_create_snapshot_for_state_dict(
             if should_flush:
                 # Flush L1 and L2 caches
                 self.flush(force=True)
+                logging.info(
+                    f"flush latency for weight states: {(time.time() - start_time) * 1000} ms"
+                )
             snapshot_handle = self.ssd_db.create_snapshot()
+            logging.info(
+                f"created snapshot for weight states: {snapshot_handle}, latency: {(time.time() - start_time) * 1000} ms"
+            )
         elif self.backend_type == BackendType.DRAM:
             if should_flush:
                 self.flush(force=True)
@@ -2136,7 +2142,7 @@ def split_embedding_weights(
         no_snapshot: bool = True,
         should_flush: bool = False,
     ) -> Tuple[  # TODO: make this a NamedTuple for readability
-        List[PartiallyMaterializedTensor],
+        List[PartiallyMaterializedTensor] | List[torch.Tensor],
         Optional[List[torch.Tensor]],
         Optional[List[torch.Tensor]],
     ]:
@@ -2165,6 +2171,17 @@ def split_embedding_weights(
         )

         dtype = self.weights_precision.as_dtype()
+        if self.load_state_dict and self.kv_zch_params:
+            # init for checkpoint loading
+            assert (
+                self._cached_kvzch_data is not None
+            ), "weight id and bucket state are not initialized for load checkpointing"
+            return (
+                self._cached_kvzch_data.cached_weight_tensor_per_table,
+                self._cached_kvzch_data.cached_id_tensor_per_table,
+                self._cached_kvzch_data.cached_bucket_splits,
+            )
+        start_time = time.time()
         pmt_splits = []
         bucket_sorted_id_splits = [] if self.kv_zch_params else None
         active_id_cnt_per_bucket_split = [] if self.kv_zch_params else None
@@ -2173,18 +2190,15 @@ def split_embedding_weights(
         for i, (emb_height, emb_dim) in enumerate(self.embedding_specs):
             bucket_ascending_id_tensor = None
             bucket_t = None
+            row_offset = table_offset
             if self.kv_zch_params:
                 bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
                 # pyre-ignore
                 bucket_size = self.kv_zch_params.bucket_sizes[i]

                 # linearize with table offset
-                table_input_id_start = (
-                    min(bucket_id_start * bucket_size, emb_height) + table_offset
-                )
-                table_input_id_end = (
-                    min(bucket_id_end * bucket_size, emb_height) + table_offset
-                )
+                table_input_id_start = table_offset
+                table_input_id_end = table_offset + emb_height
                 # 1. get all keys from backend for one table
                 unordered_id_tensor = self._ssd_db.get_keys_in_range_by_snapshot(
                     table_input_id_start,
@@ -2197,15 +2211,38 @@ def split_embedding_weights(
                     torch.ops.fbgemm.get_bucket_sorted_indices_and_bucket_tensor(
                         unordered_id_tensor,
                         0,  # id--bucket hashing mode, 0 for chunk-based hashing, 1 for interleave-based hashing
-                        bucket_id_start,
-                        bucket_id_end,
+                        0,  # local bucket offset
+                        bucket_id_end - bucket_id_start,  # local bucket num
                         bucket_size,
                     )
                 )
-                # pyre-ignore
+                # 3. convert local id back to global id
+                bucket_ascending_id_tensor.add_(bucket_id_start * bucket_size)
+
+                if (
+                    bucket_ascending_id_tensor.size(0) == 0
+                    and self.local_weight_counts[i] > 0
+                ):
+                    logging.info(
+                        f"resetting bucket id tensor with {self.local_weight_counts[i]}"
+                    )
+                    bucket_ascending_id_tensor = torch.zeros(
+                        (self.local_weight_counts[i], 1),
+                        device=torch.device("cpu"),
+                        dtype=torch.int64,
+                    )
+                    # self.local_weight_counts[i] = 0  # Reset the count
+
+                # pyre-ignore [16] bucket_sorted_id_splits is not None
                 bucket_sorted_id_splits.append(bucket_ascending_id_tensor)
                 active_id_cnt_per_bucket_split.append(bucket_t)

+                # For KV ZCH TBE, sorted_indices holds global ids for checkpointing and publishing,
+                # but the backend uses local ids during training, so the KVTensorWrapper needs to convert
+                # global ids to local ids and then linearize them with the table offset; the formula is
+                # x + table_offset - local_shard_offset, so row_offset is set to (table_offset - local_shard_offset).
+                row_offset = table_offset - (bucket_id_start * bucket_size)
+
             tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
                 shape=[
                     (
@@ -2216,33 +2253,184 @@ def split_embedding_weights(
                     emb_dim,
                 ],
                 dtype=dtype,
-                row_offset=table_offset,
+                row_offset=row_offset,
                 snapshot_handle=snapshot_handle,
+                # set bucket_ascending_id_tensor to kvt wrapper, so narrow will follow the id order to return
+                # embedding weights.
                 sorted_indices=(
                     bucket_ascending_id_tensor if self.kv_zch_params else None
                 ),
             )
-            # TODO add if else support in the future for dram integration.
-            tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+            (
+                tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                if self.backend_type == BackendType.SSD
+                else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+            )
             table_offset += emb_height
             pmt_splits.append(
                 PartiallyMaterializedTensor(
                     tensor_wrapper,
                     True if self.kv_zch_params else False,
                 )
             )
+        logging.info(
+            f"split_embedding_weights latency: {(time.time() - start_time) * 1000} ms"
+        )
         return (pmt_splits, bucket_sorted_id_splits, active_id_cnt_per_bucket_split)

     @torch.jit.ignore
     def apply_state_dict(self) -> None:
         # After checkpoint loading, the _cached_kvzch_data will be loaded from checkpoint.
         # Caller should call this function to apply the cached states to backend.
-        pass
+        if self.load_state_dict is False:
+            return
+        self.load_state_dict = False
+        assert self.kv_zch_params is not None, "apply_state_dict supports KV ZCH only"
+        assert (
+            self._cached_kvzch_data is not None
+            and self._cached_kvzch_data.cached_optimizer_state_per_table is not None
+        ), "optimizer state is not initialized for load checkpointing"
+        assert (
+            self._cached_kvzch_data.cached_weight_tensor_per_table is not None
+            and self._cached_kvzch_data.cached_id_tensor_per_table is not None
+        ), "weight and id state is not initialized for load checkpointing"
+
+        # Compute the number of elements of cache_dtype needed to store the
+        # optimizer state, round to the nearest 4
+        # optimizer_dim = self.optimizer.optimizer_state_size_dim(dtype)
+        # apply weight and optimizer state per table
+        table_offset = 0
+        for i, (emb_height, _) in enumerate(self.embedding_specs):
+            # pyre-ignore [16]
+            bucket_id_start, _ = self.kv_zch_params.bucket_offsets[i]
+            # pyre-ignore [16]
+            bucket_size = self.kv_zch_params.bucket_sizes[i]
+            row_offset = table_offset - bucket_id_start * bucket_size
+
+            if self.enable_optimizer_offloading:
+                # pyre-ignore [16]
+                weight_state = self._cached_kvzch_data.cached_weight_tensor_per_table[i]
+                # pyre-ignore [16]
+                opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i]
+                self.streaming_write_weight_and_id_per_table(
+                    weight_state,
+                    opt_state,
+                    # pyre-ignore [16]
+                    self._cached_kvzch_data.cached_id_tensor_per_table[i],
+                    row_offset,
+                )
+                self._cached_kvzch_data.cached_weight_tensor_per_table[i] = None
+                self._cached_kvzch_data.cached_optimizer_state_per_table[i] = None
+            else:
+                weight = self._cached_kvzch_data.cached_weight_tensor_per_table[i]
+                id = self._cached_kvzch_data.cached_id_tensor_per_table[i]
+                local_id = id + row_offset
+                logging.info(
+                    f"applying sd for table {i} without optimizer offloading, local_id is {local_id}"
+                )
+                opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i]
+                t_device = self.momentum1_dev.device
+                self.momentum1_dev.index_put_(
+                    indices=(
+                        local_id.to(t_device).view(-1),
+                    ),  # expects tuple of tensors
+                    values=opt_state.to(t_device),
+                )
+                self.ssd_db.set_cuda(
+                    local_id.view(-1),
+                    weight,
+                    torch.as_tensor(local_id.size(0)),
+                    1,
+                    False,
+                )
+            table_offset += emb_height
+        self.clear_cache()
+
+    @torch.jit.ignore
+    def streaming_write_weight_and_id_per_table(
+        self,
+        weight_state: torch.Tensor,
+        opt_state: torch.Tensor,
+        id_tensor: torch.Tensor,
+        row_offset: int,
+    ) -> None:
+        """
+        Write weights, optimizer state, and ids to the backend using the KVTensorWrapper.
+        To avoid excessive memory use, the weights and ids are written to the backend in a rolling-window (chunked) manner.
+
+        Args:
+            weight_state (torch.Tensor): The weight state tensor to be written.
+            opt_state (torch.Tensor): The optimizer state tensor to be written.
+            id_tensor (torch.Tensor): The id tensor to be written.
+        """
+        D_rounded = pad4(weight_state.size(1))  # padded to 4 byte alignment
+        dtype = self.weights_precision.as_dtype()
+        kvt = torch.classes.fbgemm.KVTensorWrapper(
+            db=self.ssd_db,
+            shape=[weight_state.size(0), self.cache_row_dim],
+            dtype=dtype,
+            row_offset=row_offset,
+            snapshot_handle=None,
+            sorted_indices=id_tensor,
+        )
+        # TODO: make chunk_size configurable or dynamic
+        chunk_size = 10000
+        row = weight_state.size(0)
+        optimizer_dim = self.optimizer.state_size_dim(dtype)
+        opt_state_2d = opt_state.view(dtype).view(-1, optimizer_dim)
+        for i in range(0, row, chunk_size):
+            length = min(chunk_size, row - i)
+            chunk_buffer = torch.empty(
+                length,
+                self.cache_row_dim,
+                dtype=dtype,
+                device="cpu",
+            )
+            chunk_buffer[:, : weight_state.size(1)] = weight_state[i : i + length, :]
+            chunk_buffer[:, D_rounded : D_rounded + optimizer_dim] = opt_state_2d[
+                i : i + length, :
+            ]
+            kvt.set_weights_and_ids(chunk_buffer, id_tensor[i : i + length, :].view(-1))

     @torch.jit.ignore
     def enable_load_state_dict_mode(self) -> None:
         # Enable load state dict mode before loading checkpoint
-        pass
+        if self.load_state_dict:
+            return
+        self.load_state_dict = True
+
+        dtype = self.weights_precision.as_dtype()
+        self._cached_kvzch_data = KVZCHCachedData([], [], [], [])
+        for i, (_, emb_dim) in enumerate(self.embedding_specs):
+            # for checkpoint loading, we need to store the weight and id tensors temporarily in memory
+            assert (
+                self.local_weight_counts[i] > 0
+            ), f"local_weight_counts for table {i} is not set"
+            # pyre-ignore [16]
+            bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
+            rows = self.local_weight_counts[i]
+            weight_state = torch.empty(rows, emb_dim, dtype=dtype, device="cpu")
+            opt_state = torch.empty(rows, dtype=torch.float32, device="cpu")
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_weight_tensor_per_table.append(weight_state)
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_optimizer_state_per_table.append(opt_state)
+            logging.info(
+                f"for checkpoint loading, table {i}, weight_state shape is {weight_state.shape}, opt_state shape is {opt_state.shape}"
+            )
+            id_tensor = torch.zeros(
+                (self.local_weight_counts[i], 1), dtype=torch.int64, device="cpu"
+            )
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_id_tensor_per_table.append(id_tensor)
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_bucket_splits.append(
+                torch.empty(
+                    (bucket_id_end - bucket_id_start, 1),
+                    dtype=torch.int64,
+                    device="cpu",
+                )
+            )

     @torch.jit.export
     def set_learning_rate(self, lr: float) -> None:
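
For context, the methods added in this diff imply a three-step checkpoint-load flow for KV ZCH tables: enable_load_state_dict_mode() allocates the cached CPU buffers, split_embedding_weights() exposes those buffers for filling, and apply_state_dict() pushes the cached state into the SSD/DRAM backend. The sketch below is a minimal illustration of that flow, not part of the PR; the `tbe` module, the `ckpt_*` tensors, and the direct assignment to `local_weight_counts` are assumptions made for the example.

    # Minimal usage sketch (assumptions: `tbe` is an SSD TBE module with kv_zch_params set;
    # ckpt_weights / ckpt_ids / ckpt_opt_states are per-table tensors read from a checkpoint).
    def load_kv_zch_checkpoint(tbe, ckpt_weights, ckpt_ids, ckpt_opt_states):
        # Tell the module how many rows each table will receive (assumed entry point),
        # then switch into load-state-dict mode so the cached CPU buffers get allocated.
        tbe.local_weight_counts = [ids.numel() for ids in ckpt_ids]
        tbe.enable_load_state_dict_mode()

        # In load-state-dict mode, split_embedding_weights() returns the cached CPU tensors
        # instead of PartiallyMaterializedTensors; fill them from the checkpoint.
        weights, id_splits, _bucket_splits = tbe.split_embedding_weights()
        for t in range(len(ckpt_weights)):
            weights[t].copy_(ckpt_weights[t])
            id_splits[t].copy_(ckpt_ids[t].view(-1, 1))
            tbe._cached_kvzch_data.cached_optimizer_state_per_table[t].copy_(ckpt_opt_states[t])

        # Write the cached weights, ids, and optimizer state to the backend and free the CPU copies.
        tbe.apply_state_dict()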