16
16
import tempfile
17
17
import threading
18
18
import time
19
+ from functools import cached_property
19
20
from math import floor , log2
20
21
from typing import Any , Callable , Dict , List , Optional , Tuple , Type , Union
21
22
import torch # usort:skip
57
58
58
59
from ..cache import get_unique_indices_v2
59
60
60
- from .common import ASSOC
61
+ from .common import ASSOC , pad4
61
62
from .utils .partially_materialized_tensor import PartiallyMaterializedTensor
62
63
63
64
@@ -172,6 +173,29 @@ def __init__(
172
173
) -> None :
173
174
super (SSDTableBatchedEmbeddingBags , self ).__init__ ()
174
175
176
+ assert optimizer in (
177
+ OptimType .EXACT_ROWWISE_ADAGRAD ,
178
+ ), f"Optimizer { optimizer } is not supported by SSDTableBatchedEmbeddingBags"
179
+ self .optimizer = optimizer
180
+
181
+ assert weights_precision in (SparseType .FP32 , SparseType .FP16 )
182
+ self .weights_precision = weights_precision
183
+ self .output_dtype : int = output_dtype .as_int ()
184
+
185
+ # Zero collision TBE configurations
186
+ self .kv_zch_params = kv_zch_params
187
+ self .backend_type = backend_type
188
+ self .enable_optimizer_offloading : bool = False
189
+ if self .kv_zch_params :
190
+ self .kv_zch_params .validate ()
191
+ self .enable_optimizer_offloading = (
192
+ # pyre-ignore [16]
193
+ self .kv_zch_params .enable_optimizer_offloading
194
+ )
195
+
196
+ if self .enable_optimizer_offloading :
197
+ logging .info ("Optimizer state offloading is enabled" )
198
+
175
199
self .pooling_mode = pooling_mode
176
200
self .bounds_check_mode_int : int = bounds_check_mode .value
177
201
self .embedding_specs = embedding_specs
@@ -207,7 +231,11 @@ def __init__(
207
231
feature_dims = [dims [t ] for t in self .feature_table_map ]
208
232
D_offsets = [dims [t ] for t in self .feature_table_map ]
209
233
D_offsets = [0 ] + list (itertools .accumulate (D_offsets ))
234
+
235
+ # Sum of row length of all tables
210
236
self .total_D : int = D_offsets [- 1 ]
237
+
238
+ # Max number of elements required to store a row in the cache
211
239
self .max_D : int = max (dims )
212
240
self .register_buffer (
213
241
"D_offsets" ,
@@ -273,15 +301,15 @@ def __init__(
273
301
assert (
274
302
element_size == 4 or element_size == 2
275
303
), f"Invalid element size { element_size } "
276
- cache_size = cache_sets * ASSOC * element_size * self .max_D
304
+ cache_size = cache_sets * ASSOC * element_size * self .cache_row_dim
277
305
logging .info (
278
306
f"Using cache for SSD with admission algorithm "
279
307
f"{ CacheAlgorithm .LRU } , { cache_sets } sets, stored on { 'DEVICE' if ssd_cache_location is EmbeddingLocation .DEVICE else 'MANAGED' } with { ssd_rocksdb_shards } shards, "
280
308
f"SSD storage directory: { ssd_storage_directory } , "
281
309
f"Memtable Flush Period: { ssd_memtable_flush_period } , "
282
310
f"Memtable Flush Offset: { ssd_memtable_flush_offset } , "
283
311
f"Desired L0 files per compaction: { ssd_l0_files_per_compact } , "
284
- f"{ cache_size / 1024.0 / 1024.0 / 1024.0 : .2f} GB, "
312
+ f"Cache size: { cache_size / 1024.0 / 1024.0 / 1024.0 : .2f} GB, "
285
313
f"weights precision: { weights_precision } , "
286
314
f"output dtype: { output_dtype } , "
287
315
f"chunk size in bulk init: { bulk_init_chunk_size } bytes, backend_type: { backend_type } , "
@@ -331,10 +359,6 @@ def __init__(
331
359
EmbeddingLocation .DEVICE ,
332
360
)
333
361
334
- assert weights_precision in (SparseType .FP32 , SparseType .FP16 )
335
- self .weights_precision = weights_precision
336
- self .output_dtype : int = output_dtype .as_int ()
337
-
338
362
cache_dtype = weights_precision .as_dtype ()
339
363
if ssd_cache_location == EmbeddingLocation .MANAGED :
340
364
self .register_buffer (
@@ -345,7 +369,7 @@ def __init__(
345
369
device = self .current_device ,
346
370
dtype = cache_dtype ,
347
371
),
348
- [cache_sets * ASSOC , self .max_D ],
372
+ [cache_sets * ASSOC , self .cache_row_dim ],
349
373
is_host_mapped = self .uvm_host_mapped ,
350
374
),
351
375
)
@@ -354,7 +378,7 @@ def __init__(
354
378
"lxu_cache_weights" ,
355
379
torch .zeros (
356
380
cache_sets * ASSOC ,
357
- self .max_D ,
381
+ self .cache_row_dim ,
358
382
device = self .current_device ,
359
383
dtype = cache_dtype ,
360
384
),
@@ -457,17 +481,6 @@ def __init__(
457
481
)
458
482
# logging.info("DEBUG: weights_precision {}".format(weights_precision))
459
483
460
- # zero collision TBE configurations
461
- self .kv_zch_params = kv_zch_params
462
- self .backend_type = backend_type
463
- self .enable_optimizer_offloading : bool = False
464
- if self .kv_zch_params :
465
- self .kv_zch_params .validate ()
466
- self .enable_optimizer_offloading = (
467
- # pyre-ignore [16]
468
- self .kv_zch_params .enable_optimizer_offloading
469
- )
470
-
471
484
"""
472
485
##################### for ZCH v.Next loading checkpoints Short Term Solution #######################
473
486
weight_id tensor is the weight and optimizer keys, to load from checkpoint, weight_id tensor
@@ -510,7 +523,7 @@ def __init__(
510
523
f"Logging SSD offloading setup, tbe_unique_id:{ tbe_unique_id } , l2_cache_size:{ l2_cache_size } GB, enable_async_update:{ enable_async_update } "
511
524
f"passed_in_path={ ssd_directory } , num_shards={ ssd_rocksdb_shards } ,num_threads={ ssd_rocksdb_shards } ,"
512
525
f"memtable_flush_period={ ssd_memtable_flush_period } ,memtable_flush_offset={ ssd_memtable_flush_offset } ,"
513
- f"l0_files_per_compact={ ssd_l0_files_per_compact } ,max_D={ self .max_D } ,rate_limit_mbps={ ssd_rate_limit_mbps } ,"
526
+ f"l0_files_per_compact={ ssd_l0_files_per_compact } ,max_D={ self .max_D } ,cache_row_dim= { self . cache_row_dim } , rate_limit_mbps={ ssd_rate_limit_mbps } ,"
514
527
f"size_ratio={ ssd_size_ratio } ,compaction_trigger={ ssd_compaction_trigger } , lazy_bulk_init_enabled={ lazy_bulk_init_enabled } ,"
515
528
f"write_buffer_size_per_tbe={ ssd_rocksdb_write_buffer_size } ,max_write_buffer_num_per_db_shard={ ssd_max_write_buffer_num } ,"
516
529
f"uniform_init_lower={ ssd_uniform_init_lower } ,uniform_init_upper={ ssd_uniform_init_upper } ,"
@@ -526,7 +539,7 @@ def __init__(
526
539
ssd_memtable_flush_period ,
527
540
ssd_memtable_flush_offset ,
528
541
ssd_l0_files_per_compact ,
529
- self .max_D ,
542
+ self .cache_row_dim ,
530
543
ssd_rate_limit_mbps ,
531
544
ssd_size_ratio ,
532
545
ssd_compaction_trigger ,
@@ -567,7 +580,7 @@ def __init__(
567
580
ps_client_thread_num if ps_client_thread_num is not None else 32 ,
568
581
ps_max_key_per_request if ps_max_key_per_request is not None else 500 ,
569
582
l2_cache_size ,
570
- self .max_D ,
583
+ self .cache_row_dim ,
571
584
)
572
585
else :
573
586
raise AssertionError (f"Invalid backend type { self .backend_type } " )
@@ -707,11 +720,6 @@ def __init__(
707
720
self ._update_cache_counter_and_pointers
708
721
)
709
722
710
- assert optimizer in (
711
- OptimType .EXACT_ROWWISE_ADAGRAD ,
712
- ), f"Optimizer { optimizer } is not supported by SSDTableBatchedEmbeddingBags"
713
- self .optimizer = optimizer
714
-
715
723
# stats reporter
716
724
self .gather_ssd_cache_stats = gather_ssd_cache_stats
717
725
self .stats_reporter : Optional [TBEStatsReporter ] = (
@@ -798,6 +806,22 @@ def __init__(
798
806
799
807
self .bounds_check_version : int = get_bounds_check_version_for_platform ()
800
808
809
+ @cached_property
810
+ def cache_row_dim (self ) -> int :
811
+ """
812
+ Compute the effective physical cache row size taking into account
813
+ padding to the nearest 4 elements and the optimizer state appended to
814
+ the back of the row
815
+ """
816
+ if self .enable_optimizer_offloading :
817
+ return self .max_D + pad4 (
818
+ # Compute the number of elements of cache_dtype needed to store the
819
+ # optimizer state
820
+ self .optimizer .state_size_dim (self .weights_precision .as_dtype ())
821
+ )
822
+ else :
823
+ return self .max_D
824
+
801
825
@property
802
826
# pyre-ignore
803
827
def ssd_db (self ):
@@ -854,7 +878,7 @@ def _insert_all_kv(self) -> None:
854
878
row_offset = 0
855
879
row_count = floor (
856
880
self .bulk_init_chunk_size
857
- / (self .max_D * self .weights_precision .as_dtype ().itemsize )
881
+ / (self .cache_row_dim * self .weights_precision .as_dtype ().itemsize )
858
882
)
859
883
total_dim0 = 0
860
884
for dim0 , _ in self .embedding_specs :
@@ -863,7 +887,7 @@ def _insert_all_kv(self) -> None:
863
887
start_ts = time .time ()
864
888
chunk_tensor = torch .empty (
865
889
row_count ,
866
- self .max_D ,
890
+ self .cache_row_dim ,
867
891
dtype = self .weights_precision .as_dtype (),
868
892
device = "cuda" ,
869
893
)
@@ -1397,7 +1421,7 @@ def _prefetch( # noqa C901
1397
1421
1398
1422
# Allocation a scratch pad for the current iteration. The scratch
1399
1423
# pad is a UVA tensor
1400
- inserted_rows_shape = (assigned_cache_slots .numel (), self .max_D )
1424
+ inserted_rows_shape = (assigned_cache_slots .numel (), self .cache_row_dim )
1401
1425
if linear_cache_indices .numel () > 0 :
1402
1426
inserted_rows = torch .ops .fbgemm .new_unified_tensor (
1403
1427
torch .zeros (
@@ -2093,7 +2117,7 @@ def flush(self, force: bool = False) -> None:
2093
2117
active_slots_mask = self .lxu_cache_state != - 1
2094
2118
2095
2119
active_weights_gpu = self .lxu_cache_weights [active_slots_mask .view (- 1 )].view (
2096
- - 1 , self .max_D
2120
+ - 1 , self .cache_row_dim
2097
2121
)
2098
2122
active_ids_gpu = self .lxu_cache_state .view (- 1 )[active_slots_mask .view (- 1 )]
2099
2123
@@ -2195,7 +2219,7 @@ def _report_ssd_l1_cache_stats(self) -> None:
2195
2219
data_bytes = int (
2196
2220
ssd_cache_stats_delta [stat_index .value ]
2197
2221
* element_size
2198
- * self .max_D
2222
+ * self .cache_row_dim
2199
2223
/ passed_steps
2200
2224
),
2201
2225
)
0 commit comments