
Commit e051cad

q10 authored and facebook-github-bot committed

Append columns to the SSD cache for storing optimizer data, v3 (#4125)

Summary:
X-link: facebookresearch/FBGEMM#1206
Pull Request resolved: #4125

- Append columns to the SSD cache for storing optimizer data. For backwards compatibility, this feature is disabled by default, unless the client builds `SSDTableBatchedEmbeddingBags` with `KVZCHParams` that has `enable_optimizer_offloading` set to `true`.
- This is a graft of D74051349 to land to main, since D74051349 is based on the ZCH WIP stack.

Reviewed By: sryap, duduyi2013

Differential Revision: D74748058

fbshipit-source-id: 67f086884f6d4204e551d96a8d52b16527fd3ce2

1 parent d377dd4 · commit e051cad
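For context, opting in happens at construction time. Below is a minimal sketch of how a client might enable the feature; the import paths and the constructor arguments other than kv_zch_params are assumptions for illustration only, and other KVZCHParams fields are assumed to take defaults.

    # Hypothetical opt-in sketch; only the kv_zch_params / enable_optimizer_offloading
    # switch is the mechanism described by this commit.
    from fbgemm_gpu.split_table_batched_embeddings_ops_common import KVZCHParams  # assumed path
    from fbgemm_gpu.tbe.ssd import SSDTableBatchedEmbeddingBags  # assumed path

    emb = SSDTableBatchedEmbeddingBags(
        embedding_specs=[(1000, 128)],         # (rows, dim) per table; illustrative
        cache_sets=1024,                       # illustrative L1 cache size
        ssd_storage_directory="/tmp/ssd_tbe",  # illustrative path
        # Optimizer state offloading stays disabled unless this is set:
        kv_zch_params=KVZCHParams(enable_optimizer_offloading=True),
    )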

File tree: 5 files changed, +111 −51 lines

  fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py
  fbgemm_gpu/fbgemm_gpu/tbe/ssd/common.py
  fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
  fbgemm_gpu/test/tbe/ssd/ssd_l2_cache_test.py
  fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py

fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py
Lines changed: 18 additions & 0 deletions

@@ -8,6 +8,7 @@
 # pyre-strict
 
 import enum
+import math
 from typing import Any, Dict  # noqa: F401
 
 import torch
@@ -40,6 +41,23 @@ class EmbOptimType(enum.Enum):
     def __str__(self) -> str:
         return self.value
 
+    def state_size(self) -> int:
+        """
+        Returns the size of the data (in bytes) required to hold the optimizer
+        state (per table row), or 0 if none needed
+        """
+        return {
+            # Only holds the momentum float value per row
+            EmbOptimType.EXACT_ROWWISE_ADAGRAD: torch.float32.itemsize,
+        }.get(self, 0)
+
+    def state_size_dim(self, dtype: torch.dtype) -> int:
+        """
+        Returns the size of the data (in units of elements of dtype) required to
+        hold optimizer information (per table row)
+        """
+        return int(math.ceil(self.state_size() / dtype.itemsize))
+

 # Base class for quantization configuration (in case other numeric types have
 # configs)
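As a quick sanity check of the arithmetic above (a standalone sketch that only mirrors the new methods, not the fbgemm_gpu API): EXACT_ROWWISE_ADAGRAD keeps a single fp32 momentum value per row, so state_size() is 4 bytes and state_size_dim() rounds that up to whole elements of the cache dtype.

    import math
    import torch

    def state_size_dim(state_size_bytes: int, dtype: torch.dtype) -> int:
        # Round the optimizer-state byte count up to whole elements of `dtype`,
        # mirroring EmbOptimType.state_size_dim above.
        return int(math.ceil(state_size_bytes / dtype.itemsize))

    momentum_bytes = torch.float32.itemsize  # 4 bytes per row for rowwise Adagrad
    assert state_size_dim(momentum_bytes, torch.float32) == 1  # one fp32 element
    assert state_size_dim(momentum_bytes, torch.float16) == 2  # two fp16 elements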

fbgemm_gpu/fbgemm_gpu/tbe/ssd/common.py
Lines changed: 22 additions & 0 deletions

@@ -18,3 +18,25 @@
     pass
 
 ASSOC = 32
+
+
+def pad4(value: int) -> int:
+    """
+    Compute the smallest multiple of 4 that is greater than or equal to the given value.
+
+    Parameters:
+        value (int): The integer to align (must be non-negative).
+
+    Returns:
+        int: The aligned value.
+
+    Raises:
+        ValueError: If the input is negative.
+        TypeError: If the input is not an integer.
+    """
+    if not isinstance(value, int):
+        raise TypeError("Input must be an integer")
+    if value < 0:
+        raise ValueError("Input must be a non-negative integer")
+
+    return (value + 3) & ~3
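The bit trick in pad4 is worth spelling out: adding 3 and then masking off the low two bits rounds any non-negative integer up to the next multiple of 4. A standalone sketch of the behavior:

    def pad4(value: int) -> int:
        # Add 3, then clear the low two bits: rounds up to the nearest multiple of 4.
        return (value + 3) & ~3

    assert pad4(0) == 0
    assert pad4(1) == 4
    assert pad4(2) == 4
    assert pad4(4) == 4
    assert pad4(5) == 8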

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
Lines changed: 57 additions & 33 deletions

@@ -16,6 +16,7 @@
 import tempfile
 import threading
 import time
+from functools import cached_property
 from math import floor, log2
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 import torch  # usort:skip
@@ -57,7 +58,7 @@
 
 from ..cache import get_unique_indices_v2
 
-from .common import ASSOC
+from .common import ASSOC, pad4
 from .utils.partially_materialized_tensor import PartiallyMaterializedTensor
 
 
@@ -172,6 +173,29 @@ def __init__(
     ) -> None:
         super(SSDTableBatchedEmbeddingBags, self).__init__()
 
+        assert optimizer in (
+            OptimType.EXACT_ROWWISE_ADAGRAD,
+        ), f"Optimizer {optimizer} is not supported by SSDTableBatchedEmbeddingBags"
+        self.optimizer = optimizer
+
+        assert weights_precision in (SparseType.FP32, SparseType.FP16)
+        self.weights_precision = weights_precision
+        self.output_dtype: int = output_dtype.as_int()
+
+        # Zero collision TBE configurations
+        self.kv_zch_params = kv_zch_params
+        self.backend_type = backend_type
+        self.enable_optimizer_offloading: bool = False
+        if self.kv_zch_params:
+            self.kv_zch_params.validate()
+            self.enable_optimizer_offloading = (
+                # pyre-ignore [16]
+                self.kv_zch_params.enable_optimizer_offloading
+            )
+
+        if self.enable_optimizer_offloading:
+            logging.info("Optimizer state offloading is enabled")
+
         self.pooling_mode = pooling_mode
         self.bounds_check_mode_int: int = bounds_check_mode.value
         self.embedding_specs = embedding_specs
@@ -207,7 +231,11 @@ def __init__(
         feature_dims = [dims[t] for t in self.feature_table_map]
         D_offsets = [dims[t] for t in self.feature_table_map]
         D_offsets = [0] + list(itertools.accumulate(D_offsets))
+
+        # Sum of row length of all tables
         self.total_D: int = D_offsets[-1]
+
+        # Max number of elements required to store a row in the cache
         self.max_D: int = max(dims)
         self.register_buffer(
             "D_offsets",
@@ -273,15 +301,15 @@ def __init__(
         assert (
             element_size == 4 or element_size == 2
         ), f"Invalid element size {element_size}"
-        cache_size = cache_sets * ASSOC * element_size * self.max_D
+        cache_size = cache_sets * ASSOC * element_size * self.cache_row_dim
         logging.info(
             f"Using cache for SSD with admission algorithm "
             f"{CacheAlgorithm.LRU}, {cache_sets} sets, stored on {'DEVICE' if ssd_cache_location is EmbeddingLocation.DEVICE else 'MANAGED'} with {ssd_rocksdb_shards} shards, "
             f"SSD storage directory: {ssd_storage_directory}, "
             f"Memtable Flush Period: {ssd_memtable_flush_period}, "
             f"Memtable Flush Offset: {ssd_memtable_flush_offset}, "
             f"Desired L0 files per compaction: {ssd_l0_files_per_compact}, "
-            f"{cache_size / 1024.0 / 1024.0 / 1024.0 : .2f}GB, "
+            f"Cache size: {cache_size / 1024.0 / 1024.0 / 1024.0 : .2f}GB, "
             f"weights precision: {weights_precision}, "
             f"output dtype: {output_dtype}, "
             f"chunk size in bulk init: {bulk_init_chunk_size} bytes, backend_type: {backend_type}, "
@@ -331,10 +359,6 @@ def __init__(
             EmbeddingLocation.DEVICE,
         )
 
-        assert weights_precision in (SparseType.FP32, SparseType.FP16)
-        self.weights_precision = weights_precision
-        self.output_dtype: int = output_dtype.as_int()
-
         cache_dtype = weights_precision.as_dtype()
         if ssd_cache_location == EmbeddingLocation.MANAGED:
             self.register_buffer(
@@ -345,7 +369,7 @@ def __init__(
                     device=self.current_device,
                     dtype=cache_dtype,
                 ),
-                [cache_sets * ASSOC, self.max_D],
+                [cache_sets * ASSOC, self.cache_row_dim],
                 is_host_mapped=self.uvm_host_mapped,
             ),
         )
@@ -354,7 +378,7 @@ def __init__(
                 "lxu_cache_weights",
                 torch.zeros(
                     cache_sets * ASSOC,
-                    self.max_D,
+                    self.cache_row_dim,
                     device=self.current_device,
                     dtype=cache_dtype,
                 ),
@@ -457,17 +481,6 @@ def __init__(
         )
         # logging.info("DEBUG: weights_precision {}".format(weights_precision))
 
-        # zero collision TBE configurations
-        self.kv_zch_params = kv_zch_params
-        self.backend_type = backend_type
-        self.enable_optimizer_offloading: bool = False
-        if self.kv_zch_params:
-            self.kv_zch_params.validate()
-            self.enable_optimizer_offloading = (
-                # pyre-ignore [16]
-                self.kv_zch_params.enable_optimizer_offloading
-            )
-
         """
         ##################### for ZCH v.Next loading checkpoints Short Term Solution #######################
         weight_id tensor is the weight and optimizer keys, to load from checkpoint, weight_id tensor
@@ -510,7 +523,7 @@ def __init__(
                 f"Logging SSD offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB, enable_async_update:{enable_async_update}"
                 f"passed_in_path={ssd_directory}, num_shards={ssd_rocksdb_shards},num_threads={ssd_rocksdb_shards},"
                 f"memtable_flush_period={ssd_memtable_flush_period},memtable_flush_offset={ssd_memtable_flush_offset},"
-                f"l0_files_per_compact={ssd_l0_files_per_compact},max_D={self.max_D},rate_limit_mbps={ssd_rate_limit_mbps},"
+                f"l0_files_per_compact={ssd_l0_files_per_compact},max_D={self.max_D},cache_row_dim={self.cache_row_dim},rate_limit_mbps={ssd_rate_limit_mbps},"
                 f"size_ratio={ssd_size_ratio},compaction_trigger={ssd_compaction_trigger}, lazy_bulk_init_enabled={lazy_bulk_init_enabled},"
                 f"write_buffer_size_per_tbe={ssd_rocksdb_write_buffer_size},max_write_buffer_num_per_db_shard={ssd_max_write_buffer_num},"
                 f"uniform_init_lower={ssd_uniform_init_lower},uniform_init_upper={ssd_uniform_init_upper},"
@@ -526,7 +539,7 @@ def __init__(
                 ssd_memtable_flush_period,
                 ssd_memtable_flush_offset,
                 ssd_l0_files_per_compact,
-                self.max_D,
+                self.cache_row_dim,
                 ssd_rate_limit_mbps,
                 ssd_size_ratio,
                 ssd_compaction_trigger,
@@ -567,7 +580,7 @@ def __init__(
                 ps_client_thread_num if ps_client_thread_num is not None else 32,
                 ps_max_key_per_request if ps_max_key_per_request is not None else 500,
                 l2_cache_size,
-                self.max_D,
+                self.cache_row_dim,
             )
         else:
             raise AssertionError(f"Invalid backend type {self.backend_type}")
@@ -707,11 +720,6 @@ def __init__(
                 self._update_cache_counter_and_pointers
             )
 
-        assert optimizer in (
-            OptimType.EXACT_ROWWISE_ADAGRAD,
-        ), f"Optimizer {optimizer} is not supported by SSDTableBatchedEmbeddingBags"
-        self.optimizer = optimizer
-
         # stats reporter
         self.gather_ssd_cache_stats = gather_ssd_cache_stats
         self.stats_reporter: Optional[TBEStatsReporter] = (
@@ -798,6 +806,22 @@ def __init__(
 
         self.bounds_check_version: int = get_bounds_check_version_for_platform()
 
+    @cached_property
+    def cache_row_dim(self) -> int:
+        """
+        Compute the effective physical cache row size taking into account
+        padding to the nearest 4 elements and the optimizer state appended to
+        the back of the row
+        """
+        if self.enable_optimizer_offloading:
+            return self.max_D + pad4(
+                # Compute the number of elements of cache_dtype needed to store the
+                # optimizer state
+                self.optimizer.state_size_dim(self.weights_precision.as_dtype())
+            )
+        else:
+            return self.max_D
+
     @property
     # pyre-ignore
     def ssd_db(self):
@@ -854,7 +878,7 @@ def _insert_all_kv(self) -> None:
         row_offset = 0
         row_count = floor(
             self.bulk_init_chunk_size
-            / (self.max_D * self.weights_precision.as_dtype().itemsize)
+            / (self.cache_row_dim * self.weights_precision.as_dtype().itemsize)
         )
         total_dim0 = 0
         for dim0, _ in self.embedding_specs:
@@ -863,7 +887,7 @@ def _insert_all_kv(self) -> None:
             start_ts = time.time()
            chunk_tensor = torch.empty(
                 row_count,
-                self.max_D,
+                self.cache_row_dim,
                 dtype=self.weights_precision.as_dtype(),
                 device="cuda",
             )
@@ -1397,7 +1421,7 @@ def _prefetch(  # noqa C901
 
         # Allocation a scratch pad for the current iteration. The scratch
         # pad is a UVA tensor
-        inserted_rows_shape = (assigned_cache_slots.numel(), self.max_D)
+        inserted_rows_shape = (assigned_cache_slots.numel(), self.cache_row_dim)
         if linear_cache_indices.numel() > 0:
             inserted_rows = torch.ops.fbgemm.new_unified_tensor(
                 torch.zeros(
@@ -2093,7 +2117,7 @@ def flush(self, force: bool = False) -> None:
         active_slots_mask = self.lxu_cache_state != -1
 
         active_weights_gpu = self.lxu_cache_weights[active_slots_mask.view(-1)].view(
-            -1, self.max_D
+            -1, self.cache_row_dim
         )
         active_ids_gpu = self.lxu_cache_state.view(-1)[active_slots_mask.view(-1)]
 
@@ -2195,7 +2219,7 @@ def _report_ssd_l1_cache_stats(self) -> None:
                 data_bytes=int(
                     ssd_cache_stats_delta[stat_index.value]
                     * element_size
-                    * self.max_D
+                    * self.cache_row_dim
                     / passed_steps
                 ),
             )
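Putting the two helpers together: cache_row_dim widens each physical cache row by the optimizer state, converted to cache-dtype elements and padded to a multiple of 4, and only when offloading is enabled. A standalone sketch of the arithmetic (mirroring the cached property above; the 4-byte state assumes EXACT_ROWWISE_ADAGRAD):

    import math

    def pad4(value: int) -> int:
        return (value + 3) & ~3

    def cache_row_dim(max_d: int, state_bytes: int, cache_dtype_bytes: int,
                      enable_optimizer_offloading: bool) -> int:
        # Effective physical row width: embedding elements plus the optimizer
        # state converted to cache-dtype elements and padded to a multiple of 4.
        if not enable_optimizer_offloading:
            return max_d
        state_elems = int(math.ceil(state_bytes / cache_dtype_bytes))
        return max_d + pad4(state_elems)

    # 128-dim table, FP16 cache, one fp32 momentum per row:
    assert cache_row_dim(128, 4, 2, enable_optimizer_offloading=False) == 128
    assert cache_row_dim(128, 4, 2, enable_optimizer_offloading=True) == 132  # 128 + pad4(2)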

fbgemm_gpu/test/tbe/ssd/ssd_l2_cache_test.py
Lines changed: 10 additions & 16 deletions

@@ -57,7 +57,7 @@ def generate_fbgemm_ssd_tbe(
         mixed: bool,
         enable_l2: bool = True,
         ssd_rocksdb_shards: int = 1,
-    ) -> Tuple[SSDTableBatchedEmbeddingBags, List[int], List[int], int]:
+    ) -> Tuple[SSDTableBatchedEmbeddingBags, List[int], List[int]]:
         E = int(10**log_E)
         D = D * 4
         if not mixed:
@@ -84,7 +84,7 @@ def generate_fbgemm_ssd_tbe(
             l2_cache_size=1 if enable_l2 else 0,
             ssd_rocksdb_shards=ssd_rocksdb_shards,
         )
-        return emb, Es, Ds, max(Ds)
+        return emb, Es, Ds
 
     @given(**default_st, do_flush=st.sampled_from([True, False]))
     @settings(**default_settings)
@@ -97,12 +97,10 @@ def test_l2_flush(
         weights_precision: SparseType,
         do_flush: bool,
     ) -> None:
-        emb, Es, Ds, max_D = self.generate_fbgemm_ssd_tbe(
-            T, D, log_E, weights_precision, mixed
-        )
+        emb, Es, _ = self.generate_fbgemm_ssd_tbe(T, D, log_E, weights_precision, mixed)
         indices = torch.arange(start=0, end=sum(Es))
         weights = torch.randn(
-            indices.numel(), max_D, dtype=weights_precision.as_dtype()
+            indices.numel(), emb.cache_row_dim, dtype=weights_precision.as_dtype()
         )
         weights_from_l2 = torch.empty_like(weights)
         count = torch.as_tensor([indices.numel()])
@@ -134,7 +132,7 @@ def test_l2_io(
         weights_precision: SparseType,
         enable_l2: bool,
     ) -> None:
-        emb, _, _, max_D = self.generate_fbgemm_ssd_tbe(
+        emb, _, _ = self.generate_fbgemm_ssd_tbe(
             T, D, log_E, weights_precision, mixed, enable_l2
         )
         E = int(10**log_E)
@@ -146,7 +144,7 @@ def test_l2_io(
             np.random.choice(E, replace=False, size=(N,)), dtype=torch.int64
         )
         weights = torch.randn(
-            indices.numel(), max_D, dtype=weights_precision.as_dtype()
+            indices.numel(), emb.cache_row_dim, dtype=weights_precision.as_dtype()
         )
         sub_N = N // num_rounds
 
@@ -191,15 +189,13 @@ def test_l2_prefetch_compatibility(
         mixed: bool,
         weights_precision: SparseType,
     ) -> None:
-        emb, _, _, max_D = self.generate_fbgemm_ssd_tbe(
-            T, D, log_E, weights_precision, mixed
-        )
+        emb, _, _ = self.generate_fbgemm_ssd_tbe(T, D, log_E, weights_precision, mixed)
         E = int(10**log_E)
         N = E
         indices = torch.as_tensor(
             np.random.choice(E, replace=False, size=(N,)), dtype=torch.int64
         )
-        weights = torch.randn(N, max_D, dtype=weights_precision.as_dtype())
+        weights = torch.randn(N, emb.cache_row_dim, dtype=weights_precision.as_dtype())
         new_weights = weights + 1
         weights_out = torch.empty_like(weights)
         count = torch.as_tensor([E])
@@ -242,9 +238,7 @@ def test_l2_multiple_flush_at_same_train_iter(
         mixed: bool,
         weights_precision: SparseType,
     ) -> None:
-        emb, _, _, _ = self.generate_fbgemm_ssd_tbe(
-            T, D, log_E, weights_precision, mixed
-        )
+        emb, _, _ = self.generate_fbgemm_ssd_tbe(T, D, log_E, weights_precision, mixed)
 
         with patch.object(torch.cuda, "synchronize") as mock_calls:
             mock_calls.side_effect = None
@@ -269,7 +263,7 @@ def test_rocksdb_get_discrete_ids(
         mixed: bool,
         weights_precision: SparseType,
     ) -> None:
-        emb, Es, Ds, max_D = self.generate_fbgemm_ssd_tbe(
+        emb, Es, _ = self.generate_fbgemm_ssd_tbe(
             T, D, log_E, weights_precision, mixed, False, 8
         )
         E = int(10**log_E)
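The test updates above show the caller-visible consequence: any host buffer exchanged with emb.ssd_db must now be sized by emb.cache_row_dim rather than the embedding dimension, since offloaded rows carry extra trailing columns. A hedged sketch of that pattern (assuming an already-constructed emb and an indices tensor of row ids):

    import torch

    # `emb` is an SSDTableBatchedEmbeddingBags instance; `indices` holds row ids.
    weights_out = torch.empty(
        indices.numel(),
        emb.cache_row_dim,                       # embedding dim + padded optimizer state
        dtype=emb.weights_precision.as_dtype(),
    )
    count = torch.as_tensor([indices.numel()])
    emb.ssd_db.get_cuda(indices, weights_out, count)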

fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py
Lines changed: 4 additions & 2 deletions

@@ -109,8 +109,6 @@ def test_ssd(self, indice_int64_t: bool, weights_precision: SparseType) -> None:
         indices = torch.as_tensor(
             np.random.choice(E, replace=False, size=(N,)), dtype=torch.int32
         )
-        weights = torch.randn(N, D, dtype=weights_precision.as_dtype())
-        output_weights = torch.empty_like(weights)
         count = torch.tensor([N])
 
         feature_table_map = list(range(1))
@@ -124,6 +122,10 @@ def test_ssd(self, indice_int64_t: bool, weights_precision: SparseType) -> None:
             weights_precision=weights_precision,
             l2_cache_size=8,
         )
+
+        weights = torch.randn(N, emb.cache_row_dim, dtype=weights_precision.as_dtype())
+        output_weights = torch.empty_like(weights)
+
         emb.ssd_db.get_cuda(indices, output_weights, count)
         torch.cuda.synchronize()
         assert (output_weights <= 0.1).all().item()
