Commit ab1e637

duduyi2013 authored and facebook-github-bot committed
support filling partial rows from backend (#4158)
Summary:
X-link: facebookresearch/FBGEMM#1238

Change set:
1. After enabling optimizer offloading, when reading the optimizer state out we want to read only the trailing bytes instead of the whole value field. EmbeddingRocksDB did not support this before; add it here.
2. With the functionality above, we no longer need to call .narrow() or .contiguous() on the output tensor; remove those calls.
3. We no longer need optimizer chunking (ideally we can just return the PMT to the checkpoint and let it handle the KVT the same way as for weights).
4. Fix UT bugs in constructing the KV ZCH SSD TBE.

Differential Revision: D75048595
1 parent e8284e2 commit ab1e637

10 files changed: +409 −129 lines
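
Before the per-file diffs, here is a minimal sketch (not part of the commit; the sizes are hypothetical and the pad4 helper mirrors the one used in training.py) of the row layout that change 1 reads from: with optimizer offloading enabled, each backend row stores the padded embedding followed by the optimizer state, and the new width_offset marks where that trailing optimizer segment begins.

```python
import torch

def pad4(dim: int) -> int:
    # Round a width up to a multiple of 4 elements, as training.py's pad4() does.
    return (dim + 3) // 4 * 4

# Hypothetical sizes for illustration only.
emb_dim = 127                                # embedding width in weight-dtype elements
optimizer_dim = 2                            # optimizer state width in weight-dtype elements
row_width = pad4(emb_dim) + optimizer_dim    # full backend row: 130 elements
width_offset = pad4(emb_dim)                 # optimizer tail starts at element 128

# Filling only columns [width_offset, row_width) per row is the "partial row"
# read this commit adds, replacing full-row reads followed by narrow()/contiguous().
```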

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py

Lines changed: 0 additions & 2 deletions
@@ -58,8 +58,6 @@ class KVZCHParams(NamedTuple):
     bucket_sizes: List[int] = []
     # enable optimizer offloading or not
     enable_optimizer_offloading: bool = False
-    # streaming load/save checkpoint chunk size
-    streaming_ckpt_chunk_size: int = 1000000
 
     def validate(self) -> None:
         assert len(self.bucket_offsets) == len(self.bucket_sizes), (

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 12 additions & 37 deletions
@@ -896,6 +896,8 @@ def _insert_all_kv(self) -> None:
             total_dim0 += dim0
 
         start_ts = time.time()
+        # TODO: do we have a case for non-KV ZCH SSD with bulk init enabled + optimizer offloading? Probably not.
+        # If we have such cases, we should only init the emb dim, not the optimizer dim.
         chunk_tensor = torch.empty(
             row_count,
             self.cache_row_dim,
@@ -1944,9 +1946,8 @@ def split_optimizer_states(
 
         dtype = self.weights_precision.as_dtype()
         optimizer_dim = self.optimizer.state_size_dim(dtype)
-        pad4_optimizer_dim = pad4(optimizer_dim)
         logging.info(
-            f"split_optimizer_states: {optimizer_dim=} {pad4_optimizer_dim=} {self.optimizer.dtype()=} {self.enable_load_state_dict_mode=}"
+            f"split_optimizer_states: {optimizer_dim=}, {self.optimizer.dtype()=} {self.enable_load_state_dict_mode=}"
         )
 
         for t, (emb_height, emb_dim) in enumerate(self.embedding_specs):
@@ -1972,7 +1973,7 @@ def split_optimizer_states(
                         self.momentum1_dev.detach().cpu()[local_id_tensor].view(-1),
                     )
                 else:
-                    emb_opt_dim = pad4(emb_dim) + pad4_optimizer_dim
+                    emb_opt_dim = pad4(emb_dim) + optimizer_dim
                     row_offset = table_offset - (bucket_id_start * bucket_size)
                     # using KVTensorWrapper to query backend to avoid OOM memory, since
                     # backend will return both weight and optimizer in one tensor, read the whole tensor
@@ -1984,54 +1985,28 @@ def split_optimizer_states(
                         snapshot_handle=snapshot_handle,
                         materialized_shape=([sorted_id_tensor[t].size(0), emb_opt_dim]),
                         sorted_indices=sorted_id_tensor[t],
+                        width_offset=pad4(emb_dim),
                     )
                     (
                         tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
                         if self.backend_type == BackendType.SSD
                         else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
                     )
                     opt_list.append(
-                        self.get_offloaded_optimizer_states(
-                            tensor_wrapper=tensor_wrapper,
-                            row=sorted_id_tensor[t].size(
-                                0
-                            ),  # we only need to copy the size of sorted_id_tensor
-                            optimizer_dim=optimizer_dim,
-                            start_dim_pos=pad4(emb_dim),
+                        tensor_wrapper.narrow(
+                            0,
+                            0,
+                            sorted_id_tensor[t].size(0),
                         )
+                        .view(-1)
+                        .view(self.optimizer.dtype())
                     )
             table_offset += emb_height
         logging.info(
             f"KV ZCH tables split_optimizer_states query latency: {(time.time() - start_time) * 1000} ms"
         )
         return opt_list
 
-    @torch.jit.export
-    def get_offloaded_optimizer_states(
-        self,
-        # pyre-ignore [2]
-        tensor_wrapper,
-        row: int,
-        optimizer_dim: int,
-        start_dim_pos: int,
-    ) -> torch.Tensor:
-        weight_dtype = self.weights_precision.as_dtype()
-        opt_state_t = torch.empty(
-            row, optimizer_dim, dtype=weight_dtype, device="cpu"
-        )  # 1D optimizer for OptimType.EXACT_ROWWISE_ADAGRAD
-
-        # pyre-ignore [16]
-        chunk_size = self.kv_zch_params.streaming_ckpt_chunk_size
-        for i in range(0, row, chunk_size):
-            length = min(chunk_size, row - i)
-            opt_state_t.narrow(0, i, length).copy_(
-                tensor_wrapper.narrow(0, i, length).narrow(
-                    1, start_dim_pos, optimizer_dim
-                )
-            )
-        # view optimizer state back to correct dtype
-        return opt_state_t.view(-1).view(self.optimizer.dtype())
-
     @torch.jit.export
     def get_optimizer_state(
         self,
@@ -2207,7 +2182,7 @@ def split_embedding_weights(
                         if bucket_ascending_id_tensor is not None
                         else emb_height
                     ),
-                    emb_dim,
+                    pad4(emb_dim),
                 ],
                 dtype=dtype,
                 row_offset=table_offset,
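
For reference, the `.view(-1).view(self.optimizer.dtype())` tail in the new split_optimizer_states path reinterprets raw optimizer state, stored inline in the weights' dtype, as the optimizer's own dtype. A standalone sketch of that pattern, with hypothetical sizes:

```python
import torch

rows, optimizer_dim = 3, 2            # hypothetical: two fp16 slots hold one fp32 value
raw = torch.zeros(rows, optimizer_dim, dtype=torch.float16)  # state as stored inline

# Flatten away the row structure, then reinterpret the buffer as fp32:
opt_state = raw.view(-1).view(torch.float32)
assert opt_state.shape == (rows * optimizer_dim // 2,)       # 3 fp32 values
```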

fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp

Lines changed: 1 addition & 1 deletion
@@ -457,7 +457,7 @@ void EmbeddingKVDB::set(
         << "]skip set_cuda since number evictions is " << num_evictions;
     return;
   }
-
+  CHECK_EQ(max_D_, weights.size(1));
   // defer the L2 cache/rocksdb update to the background thread as it could
   // be parallelized with other cuda kernels, as long as all updates are
   // finished before the next L2 cache lookup

fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h

Lines changed: 10 additions & 2 deletions
@@ -271,11 +271,15 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
       const at::Tensor& weights,
       const int64_t start,
       const int64_t length,
-      const ssd::SnapshotHandle* snapshot_handle) {
+      const ssd::SnapshotHandle* snapshot_handle,
+      int64_t width_offset = 0,
+      std::optional<int64_t> width_length = std::nullopt) {
     (void)weights;
     (void)start;
     (void)length;
     (void)snapshot_handle;
+    (void)width_offset;
+    (void)width_length;
     FBEXCEPTION("Not implemented");
   }
 
@@ -287,10 +291,14 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
   virtual void get_kv_from_storage_by_snapshot(
       const at::Tensor& ids,
       const at::Tensor& weights,
-      const ssd::SnapshotHandle* snapshot_handle) {
+      const ssd::SnapshotHandle* snapshot_handle,
+      int64_t width_offset = 0,
+      std::optional<int64_t> width_length = std::nullopt) {
     (void)ids;
     (void)weights;
     (void)snapshot_handle;
+    (void)width_offset;
+    (void)width_length;
     FBEXCEPTION("Not implemented");
   }
 
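A reference model of what the widened read API is meant to do, as I read these signatures (a toy Python stand-in, not the C++ implementation): with width_offset and width_length, the backend fills the output from a column window of each stored row instead of the full max_D width.

```python
import torch
from typing import Optional

def get_range_from_snapshot_ref(
    storage: torch.Tensor,        # [num_rows, max_D] toy stand-in for the backend
    out: torch.Tensor,            # [length, width_length] preallocated output buffer
    start: int,
    length: int,
    width_offset: int = 0,
    width_length: Optional[int] = None,
) -> None:
    # Copy only the requested column window of each row, defaulting to
    # "everything from width_offset to the end of the row".
    if width_length is None:
        width_length = storage.size(1) - width_offset
    out.copy_(
        storage[start : start + length, width_offset : width_offset + width_length]
    )

storage = torch.randn(10, 8)
out = torch.empty(4, 3)
get_range_from_snapshot_ref(storage, out, start=2, length=4, width_offset=5, width_length=3)
assert torch.equal(out, storage[2:6, 5:8])
```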

fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h

Lines changed: 3 additions & 1 deletion
@@ -45,7 +45,8 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder {
       int64_t row_offset,
       std::optional<c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper>>
           snapshot_handle = std::nullopt,
-      std::optional<at::Tensor> sorted_indices = std::nullopt);
+      std::optional<at::Tensor> sorted_indices = std::nullopt,
+      int64_t width_offset = 0);
 
   at::Tensor narrow(int64_t dim, int64_t start, int64_t length);
 
@@ -97,6 +98,7 @@
   std::vector<int64_t> strides_;
   int64_t row_offset_;
   std::optional<at::Tensor> sorted_indices_ = std::nullopt;
+  int64_t width_offset_;
 };
 
 } // namespace ssd

fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper_cpu.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ KVTensorWrapper::KVTensorWrapper(
3434
int64_t row_offset,
3535
[[maybe_unused]] const std::optional<
3636
c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper>> snapshot_handle,
37-
[[maybe_unused]] const std::optional<at::Tensor> sorted_indices)
37+
[[maybe_unused]] const std::optional<at::Tensor> sorted_indices,
38+
[[maybe_unused]] int64_t width_offset)
3839
// @lint-ignore CLANGTIDY clang-diagnostic-missing-noreturn
3940
: shape_(std::move(shape)), row_offset_(row_offset) {
4041
FBEXCEPTION("Not implemented");

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp

Lines changed: 36 additions & 17 deletions
@@ -308,8 +308,13 @@ KVTensorWrapper::KVTensorWrapper(
     int64_t row_offset,
     const std::optional<c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper>>
         snapshot_handle,
-    std::optional<at::Tensor> sorted_indices)
-    : db_(nullptr), shape_(std::move(shape)), row_offset_(row_offset) {
+    std::optional<at::Tensor> sorted_indices,
+    int64_t width_offset_)
+    : db_(nullptr),
+      shape_(std::move(shape)),
+      row_offset_(row_offset),
+      width_offset_(width_offset_) {
+  CHECK_GE(width_offset_, 0);
   CHECK_EQ(shape_.size(), 2) << "Only 2D emb tensors are supported";
   options_ = at::TensorOptions()
                  .dtype(static_cast<c10::ScalarType>(dtype))
@@ -342,25 +347,28 @@ void KVTensorWrapper::set_dram_db_wrapper(
 
 at::Tensor KVTensorWrapper::narrow(int64_t dim, int64_t start, int64_t length) {
   CHECK_EQ(dim, 0) << "Only narrow on dim 0 is supported";
-  CHECK_GE(db_->get_max_D(), shape_[1]);
   CHECK_TRUE(db_ != nullptr);
-  // Do not force snapshot handle is not nullptr since DRAM DB does not have
-  // rocksdb snapshot
+  CHECK_GE(db_->get_max_D(), shape_[1]);
+  TORCH_CHECK(
+      (snapshot_handle_ == nullptr) ==
+          (std::dynamic_pointer_cast<EmbeddingRocksDB>(db_).get() == nullptr),
+      "snapshot handler must be valid for rocksdb and nullptr for emb kvdb");
   if (!sorted_indices_.has_value()) {
-    auto t = at::empty(c10::IntArrayRef({length, db_->get_max_D()}), options_);
+    int64_t tensor_width = shape_[1] - width_offset_;
+    auto t = at::empty(c10::IntArrayRef({length, tensor_width}), options_);
     db_->get_range_from_snapshot(
         t,
         start + row_offset_,
         length,
-        snapshot_handle_ != nullptr ? snapshot_handle_->handle : nullptr);
-    // TBE may have multiple embeddings in one table padded to max D
-    // narrow to the actual shape here before returning
-    return t.narrow(1, 0, shape_[1]).contiguous();
+        snapshot_handle_ != nullptr ? snapshot_handle_->handle : nullptr,
+        width_offset_,
+        tensor_width);
+    CHECK(t.is_contiguous());
+    return t;
   } else {
     at::Tensor sliced_ids =
         sorted_indices_.value().slice(0, start, start + length);
-    auto out_weights = get_weights_by_ids(sliced_ids);
-    return out_weights.narrow(1, 0, shape_[1]).contiguous();
+    return get_weights_by_ids(sliced_ids);
   }
 }
 
@@ -404,14 +412,23 @@ void KVTensorWrapper::set_weights_and_ids(
 
 at::Tensor KVTensorWrapper::get_weights_by_ids(const at::Tensor& ids) {
   CHECK_TRUE(db_ != nullptr);
+  CHECK_GE(db_->get_max_D(), shape_[1]);
+  TORCH_CHECK(
+      (snapshot_handle_ == nullptr) ==
+          (std::dynamic_pointer_cast<EmbeddingRocksDB>(db_).get() == nullptr),
+      "snapshot handler must be valid for rocksdb and nullptr for emb kvdb");
+  int64_t tensor_width = shape_[1] - width_offset_;
   auto weights =
-      at::empty(c10::IntArrayRef({ids.size(0), db_->get_max_D()}), options_);
+      at::empty(c10::IntArrayRef({ids.size(0), tensor_width}), options_);
   auto linearized_ids = ids + row_offset_;
   db_->get_kv_from_storage_by_snapshot(
       linearized_ids,
       weights,
-      snapshot_handle_ != nullptr ? snapshot_handle_->handle : nullptr);
-  return weights.narrow(1, 0, shape_[1]);
+      snapshot_handle_ != nullptr ? snapshot_handle_->handle : nullptr,
+      width_offset_,
+      tensor_width);
+  CHECK(weights.is_contiguous());
+  return weights;
 }
 
 c10::IntArrayRef KVTensorWrapper::sizes() {
@@ -634,15 +651,17 @@ static auto kv_tensor_wrapper =
             int64_t,
             std::optional<
                 c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper>>,
-            std::optional<at::Tensor>>(),
+            std::optional<at::Tensor>,
+            int64_t>(),
         "",
         {torch::arg("shape"),
          torch::arg("dtype"),
         torch::arg("row_offset"),
         // snapshot must be provided for reading
         // not needed for writing
         torch::arg("snapshot_handle") = std::nullopt,
-         torch::arg("sorted_indices") = std::nullopt})
+         torch::arg("sorted_indices") = std::nullopt,
+         torch::arg("width_offset") = 0})
     .def(
         "set_embedding_rocks_dp_wrapper",
         &KVTensorWrapper::set_embedding_rocks_dp_wrapper,
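
As a sanity check on the behavior change in narrow()/get_weights_by_ids (a toy equivalence in plain PyTorch, not the library API): fetching exactly the requested column window yields the same data the old path produced by reading full rows and then calling narrow()/contiguous(), minus the extra copy.

```python
import torch

max_D, width_offset, opt_dim = 16, 12, 4
storage = torch.randn(100, max_D)      # toy stand-in for rows in the backend

# Old path: read full rows, then slice out the optimizer tail and re-pack.
old = storage[:5].narrow(1, width_offset, opt_dim).contiguous()

# New path: fill only the requested window; the result is already contiguous.
new = storage[:5, width_offset : width_offset + opt_dim].clone()

assert new.is_contiguous() and torch.equal(old, new)
```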
