@@ -313,7 +313,7 @@ def __init__(
            f"weights precision: {weights_precision}, "
            f"output dtype: {output_dtype}, "
            f"chunk size in bulk init: {bulk_init_chunk_size} bytes, backend_type: {backend_type}, "
-            f"zero_collision_config: {kv_zch_params} "
+            f"kv_zch_params: {kv_zch_params} "
        )
        self.register_buffer(
            "lxu_cache_state",
@@ -2050,7 +2050,6 @@ def split_optimizer_states(
                dtype=dtype,
                row_offset=row_offset,
                snapshot_handle=snapshot_handle,
-                materialized_shape=([sorted_id_tensor[t].size(0), emb_opt_dim]),
                sorted_indices=sorted_id_tensor[t],
            )
            (
@@ -2179,6 +2178,7 @@ def _may_create_snapshot_for_state_dict(
        """
        Create a rocksdb snapshot if needed.
        """
+        start_time = time.time()
        # Force device synchronize for now
        torch.cuda.synchronize()
        snapshot_handle = None
@@ -2187,7 +2187,13 @@ def _may_create_snapshot_for_state_dict(
        if not no_snapshot:
            # Flush L1 and L2 caches
            self.flush(force=should_flush)
+            logging.info(
+                f"flush latency for weight states: {(time.time() - start_time) * 1000} ms"
+            )
            snapshot_handle = self.ssd_db.create_snapshot()
+            logging.info(
+                f"created snapshot for weight states: {snapshot_handle}, latency: {(time.time() - start_time) * 1000} ms"
+            )
        elif self.backend_type == BackendType.DRAM:
            self.flush(force=should_flush)
        return snapshot_handle
@@ -2198,7 +2204,7 @@ def split_embedding_weights(
        no_snapshot: bool = True,
        should_flush: bool = False,
    ) -> Tuple[  # TODO: make this a NamedTuple for readability
-        List[PartiallyMaterializedTensor],
+        Union[List[PartiallyMaterializedTensor], List[torch.Tensor]],
        Optional[List[torch.Tensor]],
        Optional[List[torch.Tensor]],
    ]:
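With this hunk, the first element of the returned tuple can be either a list of PartiallyMaterializedTensor handles (the normal snapshot path) or a list of plain torch.Tensors (the cached data returned while a checkpoint is being loaded). A minimal caller-side sketch of how the two cases might be distinguished follows; consume_split_weights and emb_module are illustrative assumptions, not part of this diff.

# Illustrative sketch only: branching on the widened return type of
# split_embedding_weights(). `emb_module` is assumed to be an instance of the
# TBE module touched by this diff; nothing here is prescribed by the diff itself.
from typing import Any

import torch


def consume_split_weights(emb_module: Any) -> None:
    weights, bucket_ids, id_counts = emb_module.split_embedding_weights(no_snapshot=False)
    for t, w in enumerate(weights):
        if isinstance(w, torch.Tensor):
            # Checkpoint-loading path: plain tensors cached in host memory.
            print(f"table {t}: cached tensor with shape {tuple(w.shape)}")
        else:
            # Snapshot path: a PartiallyMaterializedTensor handle, best read
            # lazily (e.g. in row ranges) rather than materialized all at once.
            print(f"table {t}: partially materialized tensor handle {w}")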
@@ -2227,6 +2233,17 @@ def split_embedding_weights(
        )

        dtype = self.weights_precision.as_dtype()
+        if self.load_state_dict and self.kv_zch_params:
+            # init for checkpoint loading
+            assert (
+                self._cached_kvzch_data is not None
+            ), "weight id and bucket state are not initialized for load checkpointing"
+            return (
+                self._cached_kvzch_data.cached_weight_tensor_per_table,
+                self._cached_kvzch_data.cached_id_tensor_per_table,
+                self._cached_kvzch_data.cached_bucket_splits,
+            )
+        start_time = time.time()
        pmt_splits = []
        bucket_sorted_id_splits = [] if self.kv_zch_params else None
        active_id_cnt_per_bucket_split = [] if self.kv_zch_params else None
@@ -2235,18 +2252,15 @@ def split_embedding_weights(
        for i, (emb_height, emb_dim) in enumerate(self.embedding_specs):
            bucket_ascending_id_tensor = None
            bucket_t = None
+            row_offset = table_offset
            if self.kv_zch_params:
                bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
                # pyre-ignore
                bucket_size = self.kv_zch_params.bucket_sizes[i]

                # linearize with table offset
-                table_input_id_start = (
-                    min(bucket_id_start * bucket_size, emb_height) + table_offset
-                )
-                table_input_id_end = (
-                    min(bucket_id_end * bucket_size, emb_height) + table_offset
-                )
+                table_input_id_start = table_offset
+                table_input_id_end = table_offset + emb_height
                # 1. get all keys from backend for one table
                unordered_id_tensor = self._ssd_db.get_keys_in_range_by_snapshot(
                    table_input_id_start,
@@ -2259,15 +2273,38 @@ def split_embedding_weights(
                    torch.ops.fbgemm.get_bucket_sorted_indices_and_bucket_tensor(
                        unordered_id_tensor,
                        0,  # id--bucket hashing mode, 0 for chunk-based hashing, 1 for interleave-based hashing
-                        bucket_id_start,
-                        bucket_id_end,
+                        0,  # local bucket offset
+                        bucket_id_end - bucket_id_start,  # local bucket num
                        bucket_size,
                    )
                )
-                # pyre-ignore
+                # 3. convert local id back to global id
+                bucket_ascending_id_tensor.add_(bucket_id_start * bucket_size)
+
+                if (
+                    bucket_ascending_id_tensor.size(0) == 0
+                    and self.local_weight_counts[i] > 0
+                ):
+                    logging.info(
+                        f"resetting bucket id tensor with {self.local_weight_counts[i]}"
+                    )
+                    bucket_ascending_id_tensor = torch.zeros(
+                        (self.local_weight_counts[i], 1),
+                        device=torch.device("cpu"),
+                        dtype=torch.int64,
+                    )
+                    # self.local_weight_counts[i] = 0  # Reset the count
+
+                # pyre-ignore [16] bucket_sorted_id_splits is not None
                bucket_sorted_id_splits.append(bucket_ascending_id_tensor)
                active_id_cnt_per_bucket_split.append(bucket_t)

+                # For KV ZCH TBE, sorted_indices holds global ids for checkpointing and publishing,
+                # but the backend uses local ids during training, so the KVTensorWrapper needs to
+                # convert each global id to a local id first and then linearize it with the table
+                # offset. The formula is x + table_offset - local_shard_offset; to achieve this,
+                # row_offset is set to (table_offset - local_shard_offset).
+                row_offset = table_offset - (bucket_id_start * bucket_size)
+
            tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
                shape=[
                    (
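As a side note on the offset arithmetic in the hunk above: with row_offset = table_offset - bucket_id_start * bucket_size, adding row_offset to a global id reproduces local_id + table_offset, the linearized row the backend actually stores. The standalone sketch below illustrates this; all concrete numbers are made up for the example and are not taken from the diff.

# Standalone illustration of the id linearization described above; the values
# (table_offset, bucket_id_start, bucket_size, global_ids) are example assumptions.
import torch

table_offset = 1000          # rows occupied by the tables before this one
bucket_id_start = 3          # first bucket owned by this shard
bucket_size = 100            # ids per bucket
local_shard_offset = bucket_id_start * bucket_size  # 300

# Global ids as produced for checkpointing/publishing.
global_ids = torch.tensor([310, 355, 399])

# The wrapper applies row_offset = table_offset - local_shard_offset, so a
# global id x maps to x + table_offset - local_shard_offset.
row_offset = table_offset - local_shard_offset
backend_rows = global_ids + row_offset

# Equivalent two-step view: convert to shard-local ids, then linearize.
local_ids = global_ids - local_shard_offset
assert torch.equal(backend_rows, local_ids + table_offset)
print(backend_rows)  # tensor([1010, 1055, 1099])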
@@ -2278,33 +2315,184 @@ def split_embedding_weights(
                    pad4(emb_dim),
                ],
                dtype=dtype,
-                row_offset=table_offset,
+                row_offset=row_offset,
                snapshot_handle=snapshot_handle,
+                # set bucket_ascending_id_tensor on the kvt wrapper, so narrow will follow the
+                # id order when returning embedding weights.
                sorted_indices=(
                    bucket_ascending_id_tensor if self.kv_zch_params else None
                ),
            )
-            # TODO add if else support in the future for dram integration.
-            tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+            (
+                tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                if self.backend_type == BackendType.SSD
+                else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+            )
            table_offset += emb_height
            pmt_splits.append(
                PartiallyMaterializedTensor(
                    tensor_wrapper,
                    True if self.kv_zch_params else False,
                )
            )
+        logging.info(
+            f"split_embedding_weights latency: {(time.time() - start_time) * 1000} ms"
+        )
        return (pmt_splits, bucket_sorted_id_splits, active_id_cnt_per_bucket_split)

    @torch.jit.ignore
    def apply_state_dict(self) -> None:
        # After checkpoint loading, the _cached_kvzch_data will be loaded from checkpoint.
        # Caller should call this function to apply the cached states to backend.
-        pass
+        if self.load_state_dict is False:
+            return
+        self.load_state_dict = False
+        assert self.kv_zch_params is not None, "apply_state_dict supports KV ZCH only"
+        assert (
+            self._cached_kvzch_data is not None
+            and self._cached_kvzch_data.cached_optimizer_state_per_table is not None
+        ), "optimizer state is not initialized for load checkpointing"
+        assert (
+            self._cached_kvzch_data.cached_weight_tensor_per_table is not None
+            and self._cached_kvzch_data.cached_id_tensor_per_table is not None
+        ), "weight and id state is not initialized for load checkpointing"
+
+        # Compute the number of elements of cache_dtype needed to store the
+        # optimizer state, round to the nearest 4
+        # optimizer_dim = self.optimizer.optimizer_state_size_dim(dtype)
+        # apply weight and optimizer state per table
+        table_offset = 0
+        for i, (emb_height, _) in enumerate(self.embedding_specs):
+            # pyre-ignore [16]
+            bucket_id_start, _ = self.kv_zch_params.bucket_offsets[i]
+            # pyre-ignore [16]
+            bucket_size = self.kv_zch_params.bucket_sizes[i]
+            row_offset = table_offset - bucket_id_start * bucket_size
+
+            if self.enable_optimizer_offloading:
+                # pyre-ignore [16]
+                weight_state = self._cached_kvzch_data.cached_weight_tensor_per_table[i]
+                # pyre-ignore [16]
+                opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i]
+                self.streaming_write_weight_and_id_per_table(
+                    weight_state,
+                    opt_state,
+                    # pyre-ignore [16]
+                    self._cached_kvzch_data.cached_id_tensor_per_table[i],
+                    row_offset,
+                )
+                self._cached_kvzch_data.cached_weight_tensor_per_table[i] = None
+                self._cached_kvzch_data.cached_optimizer_state_per_table[i] = None
+            else:
+                weight = self._cached_kvzch_data.cached_weight_tensor_per_table[i]
+                id = self._cached_kvzch_data.cached_id_tensor_per_table[i]
+                local_id = id + row_offset
+                logging.info(
+                    f"applying sd for table {i} without optimizer offloading, local_id is {local_id}"
+                )
+                opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i]
+                t_device = self.momentum1_dev.device
+                self.momentum1_dev.index_put_(
+                    indices=(
+                        local_id.to(t_device).view(-1),
+                    ),  # expects tuple of tensors
+                    values=opt_state.to(t_device),
+                )
+                self.ssd_db.set_cuda(
+                    local_id.view(-1),
+                    weight,
+                    torch.as_tensor(local_id.size(0)),
+                    1,
+                    False,
+                )
+            table_offset += emb_height
+        self.clear_cache()
+
+    @torch.jit.ignore
+    def streaming_write_weight_and_id_per_table(
+        self,
+        weight_state: torch.Tensor,
+        opt_state: torch.Tensor,
+        id_tensor: torch.Tensor,
+        row_offset: int,
+    ) -> None:
+        """
+        Write weights, optimizer state, and ids to the backend using the kvt wrapper.
+        To avoid overusing memory, the weights and ids are written to the backend in a
+        rolling-window manner.
+
+        Args:
+            weight_state (torch.tensor): The weight state tensor to be written.
+            opt_state (torch.tensor): The optimizer state tensor to be written.
+            id_tensor (torch.tensor): The id tensor to be written.
+            row_offset (int): The offset used to linearize ids into backend rows.
+        """
+        D_rounded = pad4(weight_state.size(1))  # padded to 4 bytes alignment
+        dtype = self.weights_precision.as_dtype()
+        kvt = torch.classes.fbgemm.KVTensorWrapper(
+            db=self.ssd_db,
+            shape=[weight_state.size(0), self.cache_row_dim],
+            dtype=dtype,
+            row_offset=row_offset,
+            snapshot_handle=None,
+            sorted_indices=id_tensor,
+        )
+        # TODO: make chunk_size configurable or dynamic
+        chunk_size = 10000
+        row = weight_state.size(0)
+        optimizer_dim = self.optimizer.state_size_dim(dtype)
+        opt_state_2d = opt_state.view(dtype).view(-1, optimizer_dim)
+        for i in range(0, row, chunk_size):
+            length = min(chunk_size, row - i)
+            chunk_buffer = torch.empty(
+                length,
+                self.cache_row_dim,
+                dtype=dtype,
+                device="cpu",
+            )
+            chunk_buffer[:, : weight_state.size(1)] = weight_state[i : i + length, :]
+            chunk_buffer[:, D_rounded : D_rounded + optimizer_dim] = opt_state_2d[
+                i : i + length, :
+            ]
+            kvt.set_weights_and_ids(chunk_buffer, id_tensor[i : i + length, :].view(-1))

    @torch.jit.ignore
    def enable_load_state_dict_mode(self) -> None:
        # Enable load state dict mode before loading checkpoint
-        pass
+        if self.load_state_dict:
+            return
+        self.load_state_dict = True
+
+        dtype = self.weights_precision.as_dtype()
+        self._cached_kvzch_data = KVZCHCachedData([], [], [], [])
+        for i, (_, emb_dim) in enumerate(self.embedding_specs):
+            # for checkpoint loading, we need to store the weight and id tensors temporarily in memory
+            assert (
+                self.local_weight_counts[i] > 0
+            ), f"local_weight_counts for table {i} is not set"
+            # pyre-ignore [16]
+            bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
+            rows = self.local_weight_counts[i]
+            weight_state = torch.empty(rows, emb_dim, dtype=dtype, device="cpu")
+            opt_state = torch.empty(rows, dtype=torch.float32, device="cpu")
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_weight_tensor_per_table.append(weight_state)
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_optimizer_state_per_table.append(opt_state)
+            logging.info(
+                f"for checkpoint loading, table {i}, weight_state shape is {weight_state.shape}, opt_state shape is {opt_state.shape}"
+            )
+            id_tensor = torch.zeros(
+                (self.local_weight_counts[i], 1), dtype=torch.int64, device="cpu"
+            )
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_id_tensor_per_table.append(id_tensor)
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_bucket_splits.append(
+                torch.empty(
+                    (bucket_id_end - bucket_id_start, 1),
+                    dtype=torch.int64,
+                    device="cpu",
+                )
+            )

    @torch.jit.export
    def set_learning_rate(self, lr: float) -> None:
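For reference, the rolling-window write added in streaming_write_weight_and_id_per_table boils down to slicing the rows into fixed-size chunks and packing the weight and optimizer state for each chunk into one host buffer before handing it to the backend. The standalone sketch below shows that packing layout; chunked_write and its write callable are stand-ins for the real KVTensorWrapper call, and it assumes cache_row_dim is large enough to hold the padded weights plus the optimizer state.

# Minimal sketch (assumption-based) of the rolling-window packing used by
# streaming_write_weight_and_id_per_table; `write` stands in for
# kvt.set_weights_and_ids and all sizes are illustrative.
from typing import Callable

import torch


def chunked_write(
    weight_state: torch.Tensor,   # [rows, emb_dim]
    opt_state_2d: torch.Tensor,   # [rows, optimizer_dim], same dtype as weights
    id_tensor: torch.Tensor,      # [rows, 1]
    cache_row_dim: int,           # padded emb_dim + optimizer_dim (plus any extra padding)
    write: Callable[[torch.Tensor, torch.Tensor], None],
    chunk_size: int = 10000,
) -> None:
    emb_dim = weight_state.size(1)
    d_rounded = (emb_dim + 3) // 4 * 4  # round up to a multiple of 4, mirroring pad4()
    rows = weight_state.size(0)
    for i in range(0, rows, chunk_size):
        length = min(chunk_size, rows - i)
        # One buffer per chunk keeps peak host memory bounded by chunk_size rows.
        buf = torch.empty(length, cache_row_dim, dtype=weight_state.dtype)
        buf[:, :emb_dim] = weight_state[i : i + length]
        buf[:, d_rounded : d_rounded + opt_state_2d.size(1)] = opt_state_2d[i : i + length]
        write(buf, id_tensor[i : i + length].view(-1))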