@@ -313,7 +313,7 @@ def __init__(
             f"weights precision: {weights_precision}, "
             f"output dtype: {output_dtype}, "
             f"chunk size in bulk init: {bulk_init_chunk_size} bytes, backend_type: {backend_type}, "
-            f"zero_collision_config: {kv_zch_params}"
+            f"kv_zch_params: {kv_zch_params}"
         )
         self.register_buffer(
             "lxu_cache_state",
@@ -1983,9 +1983,7 @@ def split_optimizer_states(
                     dtype=dtype,
                     row_offset=row_offset,
                     snapshot_handle=snapshot_handle,
-                    materialized_shape=([sorted_id_tensor[t].size(0), emb_opt_dim]),
                     sorted_indices=sorted_id_tensor[t],
-                    width_offset=pad4(emb_dim),
                 )
                 (
                     tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
@@ -2086,6 +2084,7 @@ def _may_create_snapshot_for_state_dict(
         """
         Create a rocksdb snapshot if needed.
         """
+        start_time = time.time()
         # Force device synchronize for now
         torch.cuda.synchronize()
         snapshot_handle = None
@@ -2094,7 +2093,13 @@ def _may_create_snapshot_for_state_dict(
             if not no_snapshot:
                 # Flush L1 and L2 caches
                 self.flush(force=should_flush)
+                logging.info(
+                    f"flush latency for weight states: {(time.time() - start_time) * 1000} ms"
+                )
                 snapshot_handle = self.ssd_db.create_snapshot()
+                logging.info(
+                    f"created snapshot for weight states: {snapshot_handle}, latency: {(time.time() - start_time) * 1000} ms"
+                )
         elif self.backend_type == BackendType.DRAM:
             self.flush(force=should_flush)
         return snapshot_handle
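Note (illustrative, not part of the diff): the added logging measures wall-clock latency around the flush and snapshot calls with `time.time()`. A minimal sketch of the same pattern as a reusable helper; `log_elapsed` is a hypothetical name, not an API in this module:

```python
import contextlib
import logging
import time


@contextlib.contextmanager
def log_elapsed(msg: str):
    # Log elapsed wall-clock time in ms, mirroring the pattern used above.
    start_time = time.time()
    try:
        yield
    finally:
        logging.info(f"{msg}: {(time.time() - start_time) * 1000} ms")


# Usage analogous to the flush/snapshot timing in _may_create_snapshot_for_state_dict:
# with log_elapsed("flush latency for weight states"):
#     tbe.flush(force=True)
```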
@@ -2105,7 +2110,7 @@ def split_embedding_weights(
         no_snapshot: bool = True,
         should_flush: bool = False,
     ) -> Tuple[  # TODO: make this a NamedTuple for readability
-        List[PartiallyMaterializedTensor],
+        Union[List[PartiallyMaterializedTensor], List[torch.Tensor]],
         Optional[List[torch.Tensor]],
         Optional[List[torch.Tensor]],
     ]:
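Note (illustrative, not part of the diff): with the widened return type, callers get either a list of `PartiallyMaterializedTensor` (snapshot path) or a list of plain `torch.Tensor` (checkpoint-loading path). A minimal sketch of how a caller might branch, assuming `tbe` is an instance of this module and that the PMT exposes a `full_tensor()` materialization helper:

```python
import torch

# `tbe` is assumed to be an instance of this SSD TBE module.
weights, bucket_ids, id_counts = tbe.split_embedding_weights(no_snapshot=False)
for t, w in enumerate(weights):
    if isinstance(w, torch.Tensor):
        dense = w  # checkpoint-loading path: plain CPU tensor from the cached state
    else:
        dense = w.full_tensor()  # snapshot path: assumed PMT materialization helper
    print(f"table {t}: {tuple(dense.shape)}")
```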
@@ -2134,6 +2139,17 @@ def split_embedding_weights(
         )

         dtype = self.weights_precision.as_dtype()
+        if self.load_state_dict and self.kv_zch_params:
+            # init for checkpoint loading
+            assert (
+                self._cached_kvzch_data is not None
+            ), "weight, id, and bucket state are not initialized for checkpoint loading"
+            return (
+                self._cached_kvzch_data.cached_weight_tensor_per_table,
+                self._cached_kvzch_data.cached_id_tensor_per_table,
+                self._cached_kvzch_data.cached_bucket_splits,
+            )
+        start_time = time.time()
         pmt_splits = []
         bucket_sorted_id_splits = [] if self.kv_zch_params else None
         active_id_cnt_per_bucket_split = [] if self.kv_zch_params else None
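Note (illustrative, not part of the diff): the early return above hands back the preallocated cached tensors when the module is in load-state-dict mode. A sketch of the intended checkpoint-loading sequence implied by this change; `rows_per_table`, `checkpoint_weights`, and `checkpoint_ids` are hypothetical placeholders:

```python
# Caller must provide per-table row counts before entering load mode.
tbe.local_weight_counts = rows_per_table
tbe.enable_load_state_dict_mode()  # preallocates the cached weight/id/bucket tensors

# split_embedding_weights() now returns the cached buffers to be filled in place.
weights, ids, buckets = tbe.split_embedding_weights()
for t, (w, i) in enumerate(zip(weights, ids)):
    w.copy_(checkpoint_weights[t])  # hypothetical checkpoint source
    i.copy_(checkpoint_ids[t])

tbe.apply_state_dict()  # streams the cached state into the backend
```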
@@ -2142,18 +2158,15 @@ def split_embedding_weights(
         for i, (emb_height, emb_dim) in enumerate(self.embedding_specs):
             bucket_ascending_id_tensor = None
             bucket_t = None
+            row_offset = table_offset
             if self.kv_zch_params:
                 bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
                 # pyre-ignore
                 bucket_size = self.kv_zch_params.bucket_sizes[i]

                 # linearize with table offset
-                table_input_id_start = (
-                    min(bucket_id_start * bucket_size, emb_height) + table_offset
-                )
-                table_input_id_end = (
-                    min(bucket_id_end * bucket_size, emb_height) + table_offset
-                )
+                table_input_id_start = table_offset
+                table_input_id_end = table_offset + emb_height
                 # 1. get all keys from backend for one table
                 unordered_id_tensor = self._ssd_db.get_keys_in_range_by_snapshot(
                     table_input_id_start,
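Note (illustrative, not part of the diff): the key range passed to `get_keys_in_range_by_snapshot` now covers the whole linearized table instead of being clamped to this rank's bucket span. A worked example with made-up numbers:

```python
# Illustrative numbers only.
table_offset = 1_000_000  # ids of the preceding tables occupy [0, 1_000_000)
emb_height = 200_000      # this table's global id space

# New behavior: scan the full linearized range of the table.
table_input_id_start = table_offset              # 1_000_000
table_input_id_end = table_offset + emb_height   # 1_200_000

# Removed behavior: the range was clamped to the local bucket span, e.g.
# buckets [2, 4) of size 50_000 gave [1_100_000, 1_200_000) instead.
```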
@@ -2166,15 +2179,38 @@ def split_embedding_weights(
                     torch.ops.fbgemm.get_bucket_sorted_indices_and_bucket_tensor(
                         unordered_id_tensor,
                         0,  # id-to-bucket hashing mode: 0 for chunk-based hashing, 1 for interleave-based hashing
-                        bucket_id_start,
-                        bucket_id_end,
+                        0,  # local bucket offset
+                        bucket_id_end - bucket_id_start,  # local bucket count
                         bucket_size,
                     )
                 )
-                # pyre-ignore
+                # 3. convert local ids back to global ids
+                bucket_ascending_id_tensor.add_(bucket_id_start * bucket_size)
+
+                if (
+                    bucket_ascending_id_tensor.size(0) == 0
+                    and self.local_weight_counts[i] > 0
+                ):
+                    logging.info(
+                        f"resetting bucket id tensor with {self.local_weight_counts[i]}"
+                    )
+                    bucket_ascending_id_tensor = torch.zeros(
+                        (self.local_weight_counts[i], 1),
+                        device=torch.device("cpu"),
+                        dtype=torch.int64,
+                    )
+                    # self.local_weight_counts[i] = 0  # Reset the count
+
+                # pyre-ignore [16] bucket_sorted_id_splits is not None
                 bucket_sorted_id_splits.append(bucket_ascending_id_tensor)
                 active_id_cnt_per_bucket_split.append(bucket_t)

+                # For KV ZCH TBE, sorted_indices holds global ids for checkpointing and publishing,
+                # but the backend uses local ids during training, so KVTensorWrapper needs to convert
+                # each global id to a local id and then linearize it with the table offset; the formula
+                # is x + table_offset - local_shard_offset, so row_offset is set to (table_offset - local_shard_offset).
+                row_offset = table_offset - (bucket_id_start * bucket_size)
+
             tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
                 shape=[
                     (
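Note (illustrative, not part of the diff): the `row_offset = table_offset - bucket_id_start * bucket_size` arithmetic can be sanity-checked with a small example; all numbers are made up:

```python
# Illustrative numbers only.
table_offset = 1_000_000                             # linearized offset of this table
bucket_id_start, bucket_size = 2, 50_000
local_shard_offset = bucket_id_start * bucket_size   # 100_000

row_offset = table_offset - local_shard_offset       # 900_000

global_id = 123_456                                  # a sorted_indices entry (global id)
backend_row = global_id + row_offset                 # 1_023_456
# Equivalent to converting to a local id first, then linearizing with the table offset:
assert backend_row == (global_id - local_shard_offset) + table_offset
```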
@@ -2185,33 +2221,184 @@ def split_embedding_weights(
                     pad4(emb_dim),
                 ],
                 dtype=dtype,
-                row_offset=table_offset,
+                row_offset=row_offset,
                 snapshot_handle=snapshot_handle,
+                # Set bucket_ascending_id_tensor on the KVT wrapper so that narrow() returns
+                # embedding weights following the id order.
                 sorted_indices=(
                     bucket_ascending_id_tensor if self.kv_zch_params else None
                 ),
             )
-            # TODO add if else support in the future for dram integration.
-            tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+            (
+                tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                if self.backend_type == BackendType.SSD
+                else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+            )
             table_offset += emb_height
             pmt_splits.append(
                 PartiallyMaterializedTensor(
                     tensor_wrapper,
                     True if self.kv_zch_params else False,
                 )
             )
+        logging.info(
+            f"split_embedding_weights latency: {(time.time() - start_time) * 1000} ms"
+        )
         return (pmt_splits, bucket_sorted_id_splits, active_id_cnt_per_bucket_split)

     @torch.jit.ignore
     def apply_state_dict(self) -> None:
         # After checkpoint loading, the _cached_kvzch_data will be loaded from checkpoint.
         # Caller should call this function to apply the cached states to backend.
-        pass
+        if self.load_state_dict is False:
+            return
+        self.load_state_dict = False
+        assert self.kv_zch_params is not None, "apply_state_dict supports KV ZCH only"
+        assert (
+            self._cached_kvzch_data is not None
+            and self._cached_kvzch_data.cached_optimizer_state_per_table is not None
+        ), "optimizer state is not initialized for load checkpointing"
+        assert (
+            self._cached_kvzch_data.cached_weight_tensor_per_table is not None
+            and self._cached_kvzch_data.cached_id_tensor_per_table is not None
+        ), "weight and id state is not initialized for load checkpointing"
+
+        # Compute the number of elements of cache_dtype needed to store the
+        # optimizer state, rounded to the nearest 4
+        # optimizer_dim = self.optimizer.optimizer_state_size_dim(dtype)
+        # apply weight and optimizer state per table
+        table_offset = 0
+        for i, (emb_height, _) in enumerate(self.embedding_specs):
+            # pyre-ignore [16]
+            bucket_id_start, _ = self.kv_zch_params.bucket_offsets[i]
+            # pyre-ignore [16]
+            bucket_size = self.kv_zch_params.bucket_sizes[i]
+            row_offset = table_offset - bucket_id_start * bucket_size
+
+            if self.enable_optimizer_offloading:
+                # pyre-ignore [16]
+                weight_state = self._cached_kvzch_data.cached_weight_tensor_per_table[i]
+                # pyre-ignore [16]
+                opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i]
+                self.streaming_write_weight_and_id_per_table(
+                    weight_state,
+                    opt_state,
+                    # pyre-ignore [16]
+                    self._cached_kvzch_data.cached_id_tensor_per_table[i],
+                    row_offset,
+                )
+                self._cached_kvzch_data.cached_weight_tensor_per_table[i] = None
+                self._cached_kvzch_data.cached_optimizer_state_per_table[i] = None
+            else:
+                weight = self._cached_kvzch_data.cached_weight_tensor_per_table[i]
+                id = self._cached_kvzch_data.cached_id_tensor_per_table[i]
+                local_id = id + row_offset
+                logging.info(
+                    f"applying sd for table {i} without optimizer offloading, local_id is {local_id}"
+                )
+                opt_state = self._cached_kvzch_data.cached_optimizer_state_per_table[i]
+                t_device = self.momentum1_dev.device
+                self.momentum1_dev.index_put_(
+                    indices=(
+                        local_id.to(t_device).view(-1),
+                    ),  # expects tuple of tensors
+                    values=opt_state.to(t_device),
+                )
+                self.ssd_db.set_cuda(
+                    local_id.view(-1),
+                    weight,
+                    torch.as_tensor(local_id.size(0)),
+                    1,
+                    False,
+                )
+            table_offset += emb_height
+        self.clear_cache()
+
+    @torch.jit.ignore
+    def streaming_write_weight_and_id_per_table(
+        self,
+        weight_state: torch.Tensor,
+        opt_state: torch.Tensor,
+        id_tensor: torch.Tensor,
+        row_offset: int,
+    ) -> None:
+        """
+        Write weights, optimizer state, and ids to the backend via the KVT wrapper.
+        To avoid overusing memory, the weights and ids are written to the backend in a rolling-window manner.
+
+        Args:
+            weight_state (torch.Tensor): The weight state tensor to be written.
+            opt_state (torch.Tensor): The optimizer state tensor to be written.
+            id_tensor (torch.Tensor): The id tensor to be written.
+        """
+        D_rounded = pad4(weight_state.size(1))  # padded to 4-byte alignment
+        dtype = self.weights_precision.as_dtype()
+        kvt = torch.classes.fbgemm.KVTensorWrapper(
+            db=self.ssd_db,
+            shape=[weight_state.size(0), self.cache_row_dim],
+            dtype=dtype,
+            row_offset=row_offset,
+            snapshot_handle=None,
+            sorted_indices=id_tensor,
+        )
+        # TODO: make chunk_size configurable or dynamic
+        chunk_size = 10000
+        row = weight_state.size(0)
+        optimizer_dim = self.optimizer.state_size_dim(dtype)
+        opt_state_2d = opt_state.view(dtype).view(-1, optimizer_dim)
+        for i in range(0, row, chunk_size):
+            length = min(chunk_size, row - i)
+            chunk_buffer = torch.empty(
+                length,
+                self.cache_row_dim,
+                dtype=dtype,
+                device="cpu",
+            )
+            chunk_buffer[:, : weight_state.size(1)] = weight_state[i : i + length, :]
+            chunk_buffer[:, D_rounded : D_rounded + optimizer_dim] = opt_state_2d[
+                i : i + length, :
+            ]
+            kvt.set_weights_and_ids(chunk_buffer, id_tensor[i : i + length, :].view(-1))

     @torch.jit.ignore
     def enable_load_state_dict_mode(self) -> None:
         # Enable load state dict mode before loading checkpoint
-        pass
+        if self.load_state_dict:
+            return
+        self.load_state_dict = True
+
+        dtype = self.weights_precision.as_dtype()
+        self._cached_kvzch_data = KVZCHCachedData([], [], [], [])
+        for i, (_, emb_dim) in enumerate(self.embedding_specs):
+            # For checkpoint loading, the weight and id tensors are stored temporarily in memory.
+            assert (
+                self.local_weight_counts[i] > 0
+            ), f"local_weight_counts for table {i} is not set"
+            # pyre-ignore [16]
+            bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
+            rows = self.local_weight_counts[i]
+            weight_state = torch.empty(rows, emb_dim, dtype=dtype, device="cpu")
+            opt_state = torch.empty(rows, dtype=torch.float32, device="cpu")
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_weight_tensor_per_table.append(weight_state)
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_optimizer_state_per_table.append(opt_state)
+            logging.info(
+                f"for checkpoint loading, table {i}, weight_state shape is {weight_state.shape}, opt_state shape is {opt_state.shape}"
+            )
+            id_tensor = torch.zeros(
+                (self.local_weight_counts[i], 1), dtype=torch.int64, device="cpu"
+            )
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_id_tensor_per_table.append(id_tensor)
+            # pyre-ignore [16]
+            self._cached_kvzch_data.cached_bucket_splits.append(
+                torch.empty(
+                    (bucket_id_end - bucket_id_start, 1),
+                    dtype=torch.int64,
+                    device="cpu",
+                )
+            )

     @torch.jit.export
     def set_learning_rate(self, lr: float) -> None:
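Note (illustrative, not part of the diff): `streaming_write_weight_and_id_per_table` packs each row as `[weights | padding | optimizer state]` inside a `cache_row_dim`-wide buffer and flushes it chunk by chunk. A standalone sketch of that packing loop, with a hypothetical `write_fn` standing in for `KVTensorWrapper.set_weights_and_ids`:

```python
import torch


def pad4(v: int) -> int:
    # Round up to a multiple of 4, mirroring the pad4 helper used in the diff.
    return (v + 3) & ~3


def chunked_pack_and_write(
    weight_state: torch.Tensor,  # [rows, emb_dim]
    opt_state_2d: torch.Tensor,  # [rows, optimizer_dim], already in the cache dtype
    ids: torch.Tensor,           # [rows, 1]
    cache_row_dim: int,
    write_fn,                    # stand-in for kvt.set_weights_and_ids
    chunk_size: int = 10_000,
) -> None:
    D_rounded = pad4(weight_state.size(1))
    rows = weight_state.size(0)
    for i in range(0, rows, chunk_size):
        length = min(chunk_size, rows - i)
        # Rolling-window buffer: only `length` rows are resident at a time.
        buf = torch.empty(length, cache_row_dim, dtype=weight_state.dtype, device="cpu")
        buf[:, : weight_state.size(1)] = weight_state[i : i + length]
        buf[:, D_rounded : D_rounded + opt_state_2d.size(1)] = opt_state_2d[i : i + length]
        write_fn(buf, ids[i : i + length].view(-1))
```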