Small l opt c++ #4508

Open · wants to merge 25 commits into main
Commits (25)
91a4559
fixed find_max_Ls function to return int type
kudomcho Jan 30, 2025
139b529
removed prints on max_ls vars
kudomcho Jan 30, 2025
8fdb98b
Merge branch 'pytorch:main' into main
kudomcho Jun 19, 2025
d23685c
Merge branch 'pytorch:main' into main
kudomcho Jun 20, 2025
cdf3977
optimized inference for L=1 case
kudomcho Jun 20, 2025
f3654d6
fixed packed_bags logic on max_L and added new condition on find_max_ls
kudomcho Jun 24, 2025
43114ad
adapted linter on None condition
kudomcho Jun 24, 2025
30095f7
added args of max_ls on codegen_lookup_func
kudomcho Jun 24, 2025
b93ebac
fixed errors on flake8
kudomcho Jun 24, 2025
c889646
fixed the relative path on common on test scripts
kudomcho Jun 24, 2025
2412366
formatted ufmt and fixed flake8 on nbit_split_embeddings_test.py
kudomcho Jun 24, 2025
65c77da
formatted ufmt and fixed flake8 on nbit_split_embeddings_test.py
kudomcho Jun 24, 2025
8e9b7d9
changed datatype to optional int64 with default values
kudomcho Jul 15, 2025
a26a48f
removed Ls arg for benchmark scripts
kudomcho Jul 15, 2025
59708f9
supported nan weight handling
kudomcho Jul 15, 2025
971e8bc
removed Ls args
kudomcho Jul 16, 2025
fb0d6b3
removed Ls args on test script
kudomcho Jul 16, 2025
600df74
fixed nan weight test and ',' on host.cpp
kudomcho Jul 16, 2025
5f6cbd2
updated cpu_template and adjusted args position on host_cpu
kudomcho Jul 16, 2025
49d9373
fixed 1 space on all lines
kudomcho Jul 16, 2025
70ae79a
Merge branch 'main' into small-l-opt-c++
kudomcho Jul 16, 2025
94aaf4e
added F401 on nvfp4 import on quantize_ops
kudomcho Jul 17, 2025
86ad794
retain the fbgemm_kernel_launch
kudomcho Jul 17, 2025
dda9e4d
remove unused import
kudomcho Jul 17, 2025
a599f18
Merge branch 'pytorch:main' into small-l-opt-c++
kudomcho Jul 21, 2025
93 changes: 75 additions & 18 deletions fbgemm_gpu/codegen/inference/embedding_forward_quantized_host.cpp
@@ -1,10 +1,10 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/ATen.h>
#include <ATen/TypeDefault.h>
@@ -209,7 +209,13 @@ Tensor int_nbit_split_embedding_codegen_forward_unweighted_cuda(
Tensor lxu_cache_locations,
int64_t max_float8_D,
int64_t fp8_exponent_bits,
int64_t fp8_exponent_bias);
int64_t fp8_exponent_bias,
int64_t INT2_max_ls,
int64_t INT4_max_ls,
int64_t INT8_max_ls,
int64_t FP8_max_ls,
int64_t FP16_max_ls,
int64_t FP32_max_ls);

Tensor int_nbit_split_embedding_codegen_forward_weighted_cuda(
Tensor dev_weights,
@@ -234,7 +240,13 @@ Tensor int_nbit_split_embedding_codegen_forward_weighted_cuda(
Tensor lxu_cache_locations,
int64_t max_float8_D,
int64_t fp8_exponent_bits,
int64_t fp8_exponent_bias);
int64_t fp8_exponent_bias,
int64_t INT2_max_ls,
int64_t INT4_max_ls,
int64_t INT8_max_ls,
int64_t FP8_max_ls,
int64_t FP16_max_ls,
int64_t FP32_max_ls);

Tensor int_nbit_split_embedding_nobag_codegen_forward_unweighted_cuda(
Tensor dev_weights,
@@ -256,7 +268,13 @@ Tensor int_nbit_split_embedding_nobag_codegen_forward_unweighted_cuda(
Tensor lxu_cache_locations,
int64_t max_float8_D,
int64_t fp8_exponent_bits,
int64_t fp8_exponent_bias);
int64_t fp8_exponent_bias,
int64_t INT2_max_ls,
int64_t INT4_max_ls,
int64_t INT8_max_ls,
int64_t FP8_max_ls,
int64_t FP16_max_ls,
int64_t FP32_max_ls);

///@ingroup embedding-cuda
Tensor int_nbit_split_embedding_codegen_lookup_function(
@@ -282,7 +300,13 @@ Tensor int_nbit_split_embedding_codegen_lookup_function(
std::optional<int64_t> row_alignment,
std::optional<int64_t> max_float8_D,
std::optional<int64_t> fp8_exponent_bits,
std::optional<int64_t> fp8_exponent_bias) {
std::optional<int64_t> fp8_exponent_bias,
std::optional<int64_t> INT2_max_ls,
std::optional<int64_t> INT4_max_ls,
std::optional<int64_t> INT8_max_ls,
std::optional<int64_t> FP8_max_ls,
std::optional<int64_t> FP16_max_ls,
std::optional<int64_t> FP32_max_ls) {
if (offsets.scalar_type() != indices.scalar_type()) {
offsets = offsets.toType(indices.scalar_type());
}
@@ -316,7 +340,14 @@ Tensor int_nbit_split_embedding_codegen_lookup_function(
lxu_cache_locations.value_or(at::empty({0}, at::kInt)),
max_float8_D ? *max_float8_D : 0,
fp8_exponent_bits ? *fp8_exponent_bits : -1,
fp8_exponent_bias ? *fp8_exponent_bias : -1);
fp8_exponent_bias ? *fp8_exponent_bias : -1,
INT2_max_ls.value_or(0),
INT4_max_ls.value_or(0),
INT8_max_ls.value_or(0),
FP8_max_ls.value_or(0),
FP16_max_ls.value_or(0),
FP32_max_ls.value_or(0)
);
}
if (!indice_weights || indice_weights->numel() == 0) {
return int_nbit_split_embedding_codegen_forward_unweighted_cuda(
@@ -341,7 +372,14 @@ Tensor int_nbit_split_embedding_codegen_lookup_function(
lxu_cache_locations.value_or(at::empty({0}, at::kInt)),
max_float8_D ? *max_float8_D : 0,
fp8_exponent_bits ? *fp8_exponent_bits : -1,
fp8_exponent_bias ? *fp8_exponent_bias : -1);
fp8_exponent_bias ? *fp8_exponent_bias : -1,
INT2_max_ls.value_or(0),
INT4_max_ls.value_or(0),
INT8_max_ls.value_or(0),
FP8_max_ls.value_or(0),
FP16_max_ls.value_or(0),
FP32_max_ls.value_or(0)
);
}
// Force casting indice_weights to float (doing this in the backend to avoid
// JIT issue)
@@ -369,15 +407,21 @@ Tensor int_nbit_split_embedding_codegen_lookup_function(
lxu_cache_locations.value_or(at::empty({0}, at::kInt)),
max_float8_D ? *max_float8_D : 0,
fp8_exponent_bits ? *fp8_exponent_bits : -1,
fp8_exponent_bias ? *fp8_exponent_bias : -1);
fp8_exponent_bias ? *fp8_exponent_bias : -1,
INT2_max_ls.value_or(0),
INT4_max_ls.value_or(0),
INT8_max_ls.value_or(0),
FP8_max_ls.value_or(0),
FP16_max_ls.value_or(0),
FP32_max_ls.value_or(0)
);
}

///@ingroup embedding-cuda
/// Similar to int_nbit_split_embedding_codegen_lookup_function, but it does
/// UVM_CACHING lookup.
Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
// First args should be the same as those of
// int_nbit_split_embedding_codegen_lookup_function.

Tensor dev_weights,
Tensor uvm_weights,
Tensor weights_placements,
@@ -415,7 +459,13 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
std::optional<Tensor> lxu_cache_state,
// lxu_state: meta info for replacement (time stamp for LRU).
// 2D tensor: # sets x assoc. dtype=int64.
std::optional<Tensor> lxu_state) {
std::optional<Tensor> lxu_state,
std::optional<int64_t> INT2_max_ls,
std::optional<int64_t> INT4_max_ls,
std::optional<int64_t> INT8_max_ls,
std::optional<int64_t> FP8_max_ls,
std::optional<int64_t> FP16_max_ls,
std::optional<int64_t> FP32_max_ls) {
// This function runs the prefetch() and forward() methods of
// IntNBitTableBatchedEmbeddingBagsCodegen in sequence.
// Prefetching of multiple batches of requests is not yet supported.
@@ -557,7 +607,14 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
row_alignment,
max_float8_D,
fp8_exponent_bits,
fp8_exponent_bias);
fp8_exponent_bias,
INT2_max_ls,
INT4_max_ls,
INT8_max_ls,
FP8_max_ls,
FP16_max_ls,
FP32_max_ls
);
}

///@ingroup embedding-cuda
@@ -583,4 +640,4 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
int_nbit_split_embedding_uvm_caching_codegen_lookup_function);
DISPATCH_TO_CUDA("pruned_hashmap_lookup", pruned_hashmap_lookup_cuda);
DISPATCH_TO_CUDA("pruned_array_lookup", pruned_array_lookup_cuda);
}
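
The host entry points above thread six new per-dtype bag-length bounds (INT2_max_ls through FP32_max_ls) as std::optional<int64_t> parameters and forward them to the CUDA kernels with value_or(0), matching the 0 defaults in the registered schema. The following standalone sketch shows the same defaulting idiom; the function names are illustrative only and are not part of the PR.

#include <cstdint>
#include <iostream>
#include <optional>

// Illustrative stand-in for a kernel entry point that receives a concrete
// max_L value (0 meaning "no bound supplied").
int64_t forward_kernel_stub(int64_t int4_max_ls) {
  // A kernel can take a specialized path when every bag is known to be short
  // (for example max_L == 1) and fall back to the generic path otherwise.
  return int4_max_ls == 1 ? 1 : 0;
}

// Illustrative host wrapper mirroring the value_or(0) forwarding above.
int64_t lookup_wrapper(std::optional<int64_t> INT4_max_ls = std::nullopt) {
  return forward_kernel_stub(INT4_max_ls.value_or(0));
}

int main() {
  std::cout << lookup_wrapper() << "\n";   // prints 0: bound not supplied
  std::cout << lookup_wrapper(1) << "\n";  // prints 1: the L == 1 fast path
  return 0;
}

Passing 0 rather than std::nullopt across the host/kernel boundary keeps the kernel signatures plain int64_t, which matches how max_float8_D and fp8_exponent_bits are already forwarded in these call sites.
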
@@ -1,10 +1,10 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/ATen.h>
#include <ATen/TypeDefault.h>
@@ -106,7 +106,13 @@ Tensor int_nbit_split_embedding_codegen_lookup_function_cpu_impl(
std::optional<int64_t> max_float8_D,
std::optional<int64_t> fp8_exponent_bits,
std::optional<int64_t> fp8_exponent_bias,
std::optional<bool> scale_bias_last) {
std::optional<bool> scale_bias_last,
std::optional<int64_t> INT2_max_ls,
std::optional<int64_t> INT4_max_ls,
std::optional<int64_t> INT8_max_ls,
std::optional<int64_t> FP8_max_ls,
std::optional<int64_t> FP16_max_ls,
std::optional<int64_t> FP32_max_ls) {
if (offsets.scalar_type() != indices.scalar_type()) {
offsets = offsets.toType(indices.scalar_type());
}
@@ -199,7 +205,14 @@ Tensor int_nbit_split_embedding_codegen_lookup_function_cpu(
std::optional<int64_t> row_alignment,
std::optional<int64_t> max_float8_D,
std::optional<int64_t> fp8_exponent_bits,
std::optional<int64_t> fp8_exponent_bias) {
std::optional<int64_t> fp8_exponent_bias,
std::optional<int64_t> INT2_max_ls,
std::optional<int64_t> INT4_max_ls,
std::optional<int64_t> INT8_max_ls,
std::optional<int64_t> FP8_max_ls,
std::optional<int64_t> FP16_max_ls,
std::optional<int64_t> FP32_max_ls
) {
return int_nbit_split_embedding_codegen_lookup_function_cpu_impl(
std::move(dev_weights),
std::move(uvm_weights),
@@ -224,7 +237,15 @@ Tensor int_nbit_split_embedding_codegen_lookup_function_cpu(
std::move(max_float8_D),
std::move(fp8_exponent_bits),
std::move(fp8_exponent_bias),
false);
false,
INT2_max_ls.value_or(0),
INT4_max_ls.value_or(0),
INT8_max_ls.value_or(0),
FP8_max_ls.value_or(0),
FP16_max_ls.value_or(0),
FP32_max_ls.value_or(0)
);

}

///@ingroup embedding-cpu
@@ -257,7 +278,13 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function_cpu(
std::optional<int64_t> total_cache_hash_size [[maybe_unused]],
std::optional<Tensor> cache_index_table_map [[maybe_unused]],
std::optional<Tensor> lxu_cache_state [[maybe_unused]],
std::optional<Tensor> lxu_state [[maybe_unused]]) {
std::optional<Tensor> lxu_state [[maybe_unused]],
std::optional<int64_t> INT2_max_ls,
std::optional<int64_t> INT4_max_ls,
std::optional<int64_t> INT8_max_ls,
std::optional<int64_t> FP8_max_ls,
std::optional<int64_t> FP16_max_ls,
std::optional<int64_t> FP32_max_ls) {
LOG(WARNING)
<< "int_nbit_split_embedding_uvm_caching_codegen_lookup_function shouldn't be called for CPU; it is only for GPU.";
return int_nbit_split_embedding_codegen_lookup_function_cpu(
@@ -283,7 +310,13 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function_cpu(
row_alignment,
max_float8_D,
fp8_exponent_bits,
fp8_exponent_bias);
fp8_exponent_bias,
INT2_max_ls,
INT4_max_ls,
INT8_max_ls,
FP8_max_ls,
FP16_max_ls,
FP32_max_ls);
}

///@ingroup embedding-cpu
@@ -315,14 +348,14 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_py");
#endif
m.def(
"int_nbit_split_embedding_codegen_lookup_function(Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, SymInt total_D, int max_int2_D, int max_int4_D, int max_int8_D, int max_float16_D, int max_float32_D, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, int output_dtype=1, Tensor? lxu_cache_weights=None, Tensor? lxu_cache_locations=None, int? row_alignment = None, int? max_float8_D=0, int? fp8_exponent_bits=-1, int? fp8_exponent_bias=-1) -> Tensor",
"int_nbit_split_embedding_codegen_lookup_function(Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, SymInt total_D, int max_int2_D, int max_int4_D, int max_int8_D, int max_float16_D, int max_float32_D, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, int output_dtype=1, Tensor? lxu_cache_weights=None, Tensor? lxu_cache_locations=None, int? row_alignment = None, int? max_float8_D=0, int? fp8_exponent_bits=-1, int? fp8_exponent_bias=-1, int ? INT2_max_ls = 0, int ? INT4_max_ls = 0, int ? INT8_max_ls = 0, int ? FP8_max_ls = 0, int ? FP16_max_ls = 0, int ? FP32_max_ls = 0) -> Tensor",
{PT2_COMPLIANT_TAG});
DISPATCH_TO_CPU(
"int_nbit_split_embedding_codegen_lookup_function",
int_nbit_split_embedding_codegen_lookup_function_cpu);

m.def(
"int_nbit_split_embedding_uvm_caching_codegen_lookup_function(Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, SymInt total_D, int max_int2_D, int max_int4_D, int max_int8_D, int max_float16_D, int max_float32_D, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights=None, int output_dtype=1, Tensor? lxu_cache_weights=None, Tensor? lxu_cache_locations=None, int? row_alignment=-1, int? max_float8_D=0, int? fp8_exponent_bits=-1, int? fp8_exponent_bias=-1, Tensor? cache_hash_size_cumsum=None, int? total_cache_hash_size=-1, Tensor? cache_index_table_map=None, Tensor? lxu_cache_state=None, Tensor? lxu_state=None) -> Tensor");
"int_nbit_split_embedding_uvm_caching_codegen_lookup_function(Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, SymInt total_D, int max_int2_D, int max_int4_D, int max_int8_D, int max_float16_D, int max_float32_D, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights=None, int output_dtype=1, Tensor? lxu_cache_weights=None, Tensor? lxu_cache_locations=None, int? row_alignment=-1, int? max_float8_D=0, int? fp8_exponent_bits=-1, int? fp8_exponent_bias=-1, Tensor? cache_hash_size_cumsum=None, int? total_cache_hash_size=-1, Tensor? cache_index_table_map=None, Tensor? lxu_cache_state=None, Tensor? lxu_state=None, int ? INT2_max_ls = 0, int ? INT4_max_ls = 0, int ? INT8_max_ls = 0, int ? FP8_max_ls = 0, int ? FP16_max_ls = 0, int ? FP32_max_ls = 0) -> Tensor");
DISPATCH_TO_CPU(
"int_nbit_split_embedding_uvm_caching_codegen_lookup_function",
int_nbit_split_embedding_uvm_caching_codegen_lookup_function_cpu);
@@ -348,7 +381,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
}

class PrunedMapCPU : public torch::jit::CustomClassHolder {
public:
PrunedMapCPU() {}
explicit PrunedMapCPU(std::string serialized) {
torch::serialize::InputArchive archive;
@@ -470,7 +503,7 @@ class PrunedMapCPU : public torch::jit::CustomClassHolder {
return dense_indices;
}

private:
#ifdef FBCODE_CAFFE2
std::vector<folly::F14FastMap<int32_t, int32_t>> maps_;
#else
@@ -494,7 +527,7 @@ static auto PrunedMapCPURegistry =
});

class AtomicCounter : public torch::jit::CustomClassHolder {
public:
AtomicCounter() {
counter_ = 0;
}
@@ -531,7 +564,7 @@ class AtomicCounter : public torch::jit::CustomClassHolder {
return oss.str();
}

private:
std::atomic<int64_t> counter_{0};
};

@@ -631,7 +664,7 @@ struct TensorQueue : torch::CustomClassHolder {
std::make_tuple("queue", queue_vec));
}

private:
private:
std::deque<Tensor> queue_;
std::mutex mutex_;
Tensor init_tensor_;
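
Taken together, the new arguments give the kernels an upper bound on the pooled bag length L for each weight dtype; the commit history references a find_max_Ls helper that derives these bounds and an optimized inference path for the L=1 case. The sketch below is an illustrative standalone version of such a helper, assuming L is the per-bag length given by consecutive differences of a CSR-style offsets array; it is not the PR's implementation.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative helper (not the PR's find_max_Ls): compute the maximum bag
// length L from a CSR-style offsets array, where bag i covers the index
// range [offsets[i], offsets[i + 1]).
int64_t find_max_ls(const std::vector<int64_t>& offsets) {
  int64_t max_ls = 0;
  for (std::size_t i = 0; i + 1 < offsets.size(); ++i) {
    max_ls = std::max<int64_t>(max_ls, offsets[i + 1] - offsets[i]);
  }
  return max_ls;  // integer result, as in the "return int type" commit
}

int main() {
  // Three bags, each of length 1, so the bound is 1 and an L=1 fast path
  // could be selected.
  std::cout << find_max_ls({0, 1, 2, 3}) << "\n";  // prints 1
  return 0;
}
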