diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host.cpp b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host.cpp index 5fcc3a0176..b167ea39a2 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host.cpp +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host.cpp @@ -1,10 +1,10 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ #include #include @@ -209,7 +209,13 @@ Tensor int_nbit_split_embedding_codegen_forward_unweighted_cuda( Tensor lxu_cache_locations, int64_t max_float8_D, int64_t fp8_exponent_bits, - int64_t fp8_exponent_bias); + int64_t fp8_exponent_bias, + int64_t INT2_max_ls, + int64_t INT4_max_ls, + int64_t INT8_max_ls, + int64_t FP8_max_ls, + int64_t FP16_max_ls, + int64_t FP32_max_ls); Tensor int_nbit_split_embedding_codegen_forward_weighted_cuda( Tensor dev_weights, @@ -234,7 +240,13 @@ Tensor int_nbit_split_embedding_codegen_forward_weighted_cuda( Tensor lxu_cache_locations, int64_t max_float8_D, int64_t fp8_exponent_bits, - int64_t fp8_exponent_bias); + int64_t fp8_exponent_bias, + int64_t INT2_max_ls, + int64_t INT4_max_ls, + int64_t INT8_max_ls, + int64_t FP8_max_ls, + int64_t FP16_max_ls, + int64_t FP32_max_ls); Tensor int_nbit_split_embedding_nobag_codegen_forward_unweighted_cuda( Tensor dev_weights, @@ -256,7 +268,13 @@ Tensor int_nbit_split_embedding_nobag_codegen_forward_unweighted_cuda( Tensor lxu_cache_locations, int64_t max_float8_D, int64_t fp8_exponent_bits, - int64_t fp8_exponent_bias); + int64_t fp8_exponent_bias, + int64_t INT2_max_ls, + int64_t INT4_max_ls, + int64_t INT8_max_ls, + int64_t FP8_max_ls, + int64_t FP16_max_ls, + int64_t FP32_max_ls); ///@ingroup embedding-cuda Tensor int_nbit_split_embedding_codegen_lookup_function( @@ -282,7 +300,13 @@ Tensor int_nbit_split_embedding_codegen_lookup_function( std::optional row_alignment, std::optional max_float8_D, std::optional fp8_exponent_bits, - std::optional fp8_exponent_bias) { + std::optional fp8_exponent_bias, + std::optional INT2_max_ls, + std::optional INT4_max_ls, + std::optional INT8_max_ls, + std::optional FP8_max_ls, + std::optional FP16_max_ls, + std::optional FP32_max_ls) { if (offsets.scalar_type() != indices.scalar_type()) { offsets = offsets.toType(indices.scalar_type()); } @@ -316,7 +340,14 @@ Tensor int_nbit_split_embedding_codegen_lookup_function( lxu_cache_locations.value_or(at::empty({0}, at::kInt)), max_float8_D ? *max_float8_D : 0, fp8_exponent_bits ? *fp8_exponent_bits : -1, - fp8_exponent_bias ? *fp8_exponent_bias : -1); + fp8_exponent_bias ? *fp8_exponent_bias : -1, + INT2_max_ls.value_or(0), + INT4_max_ls.value_or(0), + INT8_max_ls.value_or(0), + FP8_max_ls.value_or(0), + FP16_max_ls.value_or(0), + FP32_max_ls.value_or(0) + ); } if (!indice_weights || indice_weights->numel() == 0) { return int_nbit_split_embedding_codegen_forward_unweighted_cuda( @@ -341,7 +372,14 @@ Tensor int_nbit_split_embedding_codegen_lookup_function( lxu_cache_locations.value_or(at::empty({0}, at::kInt)), max_float8_D ? *max_float8_D : 0, fp8_exponent_bits ? *fp8_exponent_bits : -1, - fp8_exponent_bias ? *fp8_exponent_bias : -1); + fp8_exponent_bias ? 
*fp8_exponent_bias : -1, + INT2_max_ls.value_or(0), + INT4_max_ls.value_or(0), + INT8_max_ls.value_or(0), + FP8_max_ls.value_or(0), + FP16_max_ls.value_or(0), + FP32_max_ls.value_or(0) + ); } // Force casting indice_weights to float (doing this in the backend to avoid // JIT issue) @@ -369,15 +407,21 @@ Tensor int_nbit_split_embedding_codegen_lookup_function( lxu_cache_locations.value_or(at::empty({0}, at::kInt)), max_float8_D ? *max_float8_D : 0, fp8_exponent_bits ? *fp8_exponent_bits : -1, - fp8_exponent_bias ? *fp8_exponent_bias : -1); + fp8_exponent_bias ? *fp8_exponent_bias : -1, + INT2_max_ls.value_or(0), + INT4_max_ls.value_or(0), + INT8_max_ls.value_or(0), + FP8_max_ls.value_or(0), + FP16_max_ls.value_or(0), + FP32_max_ls.value_or(0) + ); } ///@ingroup embedding-cuda -/// Simlar to int_nbit_split_embedding_codegen_lookup_function, but it does /// UVM_CACHING lookup. Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function( // First args should be the same to those of - // int_nbit_split_embedding_codegen_lookup_function. + Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, @@ -415,7 +459,13 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function( std::optional lxu_cache_state, // lxu_state: meta info for replacement (time stamp for LRU). // 2D tensor: # sets x assoc. dtype=int64. - std::optional lxu_state) { + std::optional lxu_state, + std::optional INT2_max_ls, + std::optional INT4_max_ls, + std::optional INT8_max_ls, + std::optional FP8_max_ls, + std::optional FP16_max_ls, + std::optional FP32_max_ls) { // This function does prefetch() and foward() methods in // IntNBitTableBatchedEmbeddingBagsCodegen, but run them in sequence. // Prefetching of multiple batches of requests is not yet supported. @@ -557,7 +607,14 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function( row_alignment, max_float8_D, fp8_exponent_bits, - fp8_exponent_bias); + fp8_exponent_bias, + INT2_max_ls, + INT4_max_ls, + INT8_max_ls, + FP8_max_ls, + FP16_max_ls, + FP32_max_ls + ); } ///@ingroup embedding-cuda @@ -583,4 +640,4 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { int_nbit_split_embedding_uvm_caching_codegen_lookup_function); DISPATCH_TO_CUDA("pruned_hashmap_lookup", pruned_hashmap_lookup_cuda); DISPATCH_TO_CUDA("pruned_array_lookup", pruned_array_lookup_cuda); -} +} \ No newline at end of file diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp index 98f0235ac0..74154b027c 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp @@ -1,10 +1,10 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. 
+*/ #include #include @@ -106,7 +106,13 @@ Tensor int_nbit_split_embedding_codegen_lookup_function_cpu_impl( std::optional max_float8_D, std::optional fp8_exponent_bits, std::optional fp8_exponent_bias, - std::optional scale_bias_last) { + std::optional scale_bias_last, + std::optional INT2_max_ls, + std::optional INT4_max_ls, + std::optional INT8_max_ls, + std::optional FP8_max_ls, + std::optional FP16_max_ls, + std::optional FP32_max_ls) { if (offsets.scalar_type() != indices.scalar_type()) { offsets = offsets.toType(indices.scalar_type()); } @@ -199,7 +205,14 @@ Tensor int_nbit_split_embedding_codegen_lookup_function_cpu( std::optional row_alignment, std::optional max_float8_D, std::optional fp8_exponent_bits, - std::optional fp8_exponent_bias) { + std::optional fp8_exponent_bias, + std::optional INT2_max_ls, + std::optional INT4_max_ls, + std::optional INT8_max_ls, + std::optional FP8_max_ls, + std::optional FP16_max_ls, + std::optional FP32_max_ls + ) { return int_nbit_split_embedding_codegen_lookup_function_cpu_impl( std::move(dev_weights), std::move(uvm_weights), @@ -224,7 +237,15 @@ Tensor int_nbit_split_embedding_codegen_lookup_function_cpu( std::move(max_float8_D), std::move(fp8_exponent_bits), std::move(fp8_exponent_bias), - false); + false, + INT2_max_ls.value_or(0), + INT4_max_ls.value_or(0), + INT8_max_ls.value_or(0), + FP8_max_ls.value_or(0), + FP16_max_ls.value_or(0), + FP32_max_ls.value_or(0) + ); + } ///@ingroup embedding-cpu @@ -257,7 +278,13 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function_cpu( std::optional total_cache_hash_size [[maybe_unused]], std::optional cache_index_table_map [[maybe_unused]], std::optional lxu_cache_state [[maybe_unused]], - std::optional lxu_state [[maybe_unused]]) { + std::optional lxu_state [[maybe_unused]], + std::optional INT2_max_ls, + std::optional INT4_max_ls, + std::optional INT8_max_ls, + std::optional FP8_max_ls, + std::optional FP16_max_ls, + std::optional FP32_max_ls) { LOG(WARNING) << "int_nbit_split_embedding_uvm_caching_codegen_lookup_function shouldn't be called for CPU; it is only for GPU."; return int_nbit_split_embedding_codegen_lookup_function_cpu( @@ -283,7 +310,13 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function_cpu( row_alignment, max_float8_D, fp8_exponent_bits, - fp8_exponent_bias); + fp8_exponent_bias, + INT2_max_ls, + INT4_max_ls, + INT8_max_ls, + FP8_max_ls, + FP16_max_ls, + FP32_max_ls); } ///@ingroup embedding-cpu @@ -315,14 +348,14 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_py"); #endif m.def( - "int_nbit_split_embedding_codegen_lookup_function(Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, SymInt total_D, int max_int2_D, int max_int4_D, int max_int8_D, int max_float16_D, int max_float32_D, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, int output_dtype=1, Tensor? lxu_cache_weights=None, Tensor? lxu_cache_locations=None, int? row_alignment = None, int? max_float8_D=0, int? fp8_exponent_bits=-1, int? fp8_exponent_bias=-1) -> Tensor", + "int_nbit_split_embedding_codegen_lookup_function(Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, SymInt total_D, int max_int2_D, int max_int4_D, int max_int8_D, int max_float16_D, int max_float32_D, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, int output_dtype=1, Tensor? 
lxu_cache_weights=None, Tensor? lxu_cache_locations=None, int? row_alignment = None, int? max_float8_D=0, int? fp8_exponent_bits=-1, int? fp8_exponent_bias=-1, int ? INT2_max_ls = 0, int ? INT4_max_ls = 0, int ? INT8_max_ls = 0, int ? FP8_max_ls = 0, int ? FP16_max_ls = 0, int ? FP32_max_ls = 0) -> Tensor", {PT2_COMPLIANT_TAG}); DISPATCH_TO_CPU( "int_nbit_split_embedding_codegen_lookup_function", int_nbit_split_embedding_codegen_lookup_function_cpu); m.def( - "int_nbit_split_embedding_uvm_caching_codegen_lookup_function(Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, SymInt total_D, int max_int2_D, int max_int4_D, int max_int8_D, int max_float16_D, int max_float32_D, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights=None, int output_dtype=1, Tensor? lxu_cache_weights=None, Tensor? lxu_cache_locations=None, int? row_alignment=-1, int? max_float8_D=0, int? fp8_exponent_bits=-1, int? fp8_exponent_bias=-1, Tensor? cache_hash_size_cumsum=None, int? total_cache_hash_size=-1, Tensor? cache_index_table_map=None, Tensor? lxu_cache_state=None, Tensor? lxu_state=None) -> Tensor"); + "int_nbit_split_embedding_uvm_caching_codegen_lookup_function(Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, SymInt total_D, int max_int2_D, int max_int4_D, int max_int8_D, int max_float16_D, int max_float32_D, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights=None, int output_dtype=1, Tensor? lxu_cache_weights=None, Tensor? lxu_cache_locations=None, int? row_alignment=-1, int? max_float8_D=0, int? fp8_exponent_bits=-1, int? fp8_exponent_bias=-1, Tensor? cache_hash_size_cumsum=None, int? total_cache_hash_size=-1, Tensor? cache_index_table_map=None, Tensor? lxu_cache_state=None, Tensor? lxu_state=None, int ? INT2_max_ls = 0, int ? INT4_max_ls = 0, int ? INT8_max_ls = 0, int ? FP8_max_ls = 0, int ? FP16_max_ls = 0, int ? 
FP32_max_ls = 0) -> Tensor"); DISPATCH_TO_CPU( "int_nbit_split_embedding_uvm_caching_codegen_lookup_function", int_nbit_split_embedding_uvm_caching_codegen_lookup_function_cpu); @@ -348,7 +381,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { } class PrunedMapCPU : public torch::jit::CustomClassHolder { - public: +public: PrunedMapCPU() {} explicit PrunedMapCPU(std::string serialized) { torch::serialize::InputArchive archive; @@ -470,7 +503,7 @@ class PrunedMapCPU : public torch::jit::CustomClassHolder { return dense_indices; } - private: +private: #ifdef FBCODE_CAFFE2 std::vector> maps_; #else @@ -494,7 +527,7 @@ static auto PrunedMapCPURegistry = }); class AtomicCounter : public torch::jit::CustomClassHolder { - public: +public: AtomicCounter() { counter_ = 0; } @@ -531,7 +564,7 @@ class AtomicCounter : public torch::jit::CustomClassHolder { return oss.str(); } - private: +private: std::atomic counter_{0}; }; @@ -631,7 +664,7 @@ struct TensorQueue : torch::CustomClassHolder { std::make_tuple("queue", queue_vec)); } - private: +private: std::deque queue_; std::mutex mutex_; Tensor init_tensor_; diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu index 0edd97a0c6..45e3b62b0e 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu @@ -25,7 +25,7 @@ namespace nbit { same generated source file. */ {%- for emb_weight_type in ["FP32", "FP16", "FP8", "INT8", "INT4", "INT2"] %} -template +template __launch_bounds__(WarpsPerBlock * kWarpSize) __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_kernel_small_L( const pta::PackedTensorAccessor64 dev_weights, @@ -53,6 +53,7 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no const int fp8_exponent_bias, {%- endif %} const int32_t num_packed_bags, + const int32_t num_packed_bags_L, pta::PackedTensorAccessor32 output, // [B][total_D], const pta::PackedTensorAccessor64 lxu_cache_weights, const pta::PackedTensorAccessor32 lxu_cache_locations @@ -68,9 +69,9 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no #undef X #endif - #define X(DeviceOnly, PackedMode, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ + #define X(DeviceOnly, PackedMode, PackedModeL, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ FBGEMM_LAUNCH_KERNEL( \ - ({{ func_name }}), \ + ({{ func_name }}), \ nbit::div_round_up(T * nbit::div_round_up(B, num_packed_bags * OutputRowsPerThread), kWarpsPerBlock), \ dim3(kWarpSize, kWarpsPerBlock), \ 0, \ @@ -100,6 +101,7 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no fp8_exponent_bias, \ {%- endif %} num_packed_bags, \ + num_packed_bags_L, \ PTA_B(output, output_t, 2, 32), \ PTA_B(lxu_cache_weights, uint8_t, 2, 64), \ PTA_B(lxu_cache_locations, int32_t, 1, 32) \ @@ -200,7 +202,13 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ Tensor lxu_cache_locations, const int64_t max_float8_D, const int64_t fp8_exponent_bits, - const int64_t fp8_exponent_bias + const int64_t fp8_exponent_bias, + const int64_t INT2_max_ls, + const int64_t INT4_max_ls, + const int64_t INT8_max_ls, + const int64_t FP8_max_ls, + const int64_t FP16_max_ls, + const 
int64_t FP32_max_ls ) { TENSOR_ON_CUDA_GPU(dev_weights); TENSORS_ON_SAME_DEVICE(uvm_weights, dev_weights); @@ -236,6 +244,9 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ */ #define PACKED_MODE_SWITCH(dev_only, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ int32_t num_packed_bags = 1; \ + int32_t num_packed_bags_D = 1; \ + int32_t num_packed_bags_L = 1; \ + const int64_t max_L = max_Ls; \ {%-if is_rocm and not nobag %} const static bool use_packed_bag_mode = fbgemm_gpu::config::is_feature_enabled( \ fbgemm_gpu::config::FeatureGateName::TBE_ROCM_INFERENCE_PACKED_BAGS); \ @@ -243,14 +254,21 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ /* The actual maximum number of uint4 reads per row w.r.t. row size, type and alignment */ \ const int32_t num_uint4_loads_per_row = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_D, sparse_type, row_alignment), sizeof(uint4)); \ constexpr int32_t NumUint4LoadsPerRow = MaxNum128BRows * 128 / sizeof(uint4); \ + constexpr int32_t max_indices_per_warp = kWarpSize / NumUint4LoadsPerRow; \ + num_packed_bags_L = max_L > 0 && max_indices_per_warp > max_L && !std::is_same_v && sparse_type != SparseType::FP32? max_indices_per_warp / max_L : 1; \ + num_packed_bags_D = NumUint4LoadsPerRow > num_uint4_loads_per_row && !std::is_same_v && sparse_type != SparseType::FP32 ? NumUint4LoadsPerRow / num_uint4_loads_per_row : 1; \ /* Number of bags that might be fitted to shared memory. */ \ - num_packed_bags = NumUint4LoadsPerRow > num_uint4_loads_per_row && !std::is_same_v && sparse_type != SparseType::FP32 ? NumUint4LoadsPerRow / num_uint4_loads_per_row : 1; \ + num_packed_bags = max_L==1 ? num_packed_bags_L * num_packed_bags_D : num_packed_bags_D; \ } \ {%- endif %} if (num_packed_bags > 1) { \ - X(dev_only, true, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ + if (max_L==1){ \ + X(dev_only, true, true, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ + } else{ \ + X(dev_only, true, false, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ + } \ } else { \ - X(dev_only, false, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ + X(dev_only, false, false, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ }; #define Y(...) 
\ @@ -270,6 +288,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ DISPATCH_OUTPUT_TYPES(output.scalar_type(), "int2_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { if (max_int2_D > 0) { const auto max_D = max_int2_D; + const auto max_Ls = INT2_max_ls; constexpr auto sparse_type = SparseType::INT2; auto max_int2_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_D, sparse_type, row_alignment), 128); TORCH_CHECK(max_int2_128b_rows <= 8); @@ -299,6 +318,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ DISPATCH_OUTPUT_TYPES(output.scalar_type(), "int4_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { if (max_int4_D > 0) { const auto max_D = max_int4_D; + const auto max_Ls = INT4_max_ls; constexpr auto sparse_type = SparseType::INT4; auto max_int4_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_D, sparse_type, row_alignment), 128); TORCH_CHECK(max_int4_128b_rows <= 16); @@ -345,6 +365,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ DISPATCH_OUTPUT_TYPES(output.scalar_type(), "int8_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { if (max_int8_D > 0) { const auto max_D = max_int8_D; + const auto max_Ls = INT8_max_ls; constexpr auto sparse_type = SparseType::INT8; auto max_int8_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_D, sparse_type, row_alignment), 128); TORCH_CHECK(max_int8_128b_rows <= 32); @@ -402,6 +423,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ DISPATCH_OUTPUT_TYPES(output.scalar_type(), "fp8_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { if (max_float8_D > 0) { const auto max_D = max_float8_D; + const auto max_Ls = FP8_max_ls; constexpr auto sparse_type = SparseType::FP8; auto max_fp8_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_D, sparse_type, row_alignment), 128); TORCH_CHECK(max_fp8_128b_rows <= 32); @@ -437,6 +459,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ DISPATCH_OUTPUT_TYPES(output.scalar_type(), "fp16_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { if (max_float16_D > 0) { const auto max_D = max_float16_D; + const auto max_Ls = FP16_max_ls; constexpr auto sparse_type = SparseType::FP16; auto max_fp16_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_D, sparse_type, row_alignment), 128); TORCH_CHECK(max_fp16_128b_rows <= 64); @@ -472,6 +495,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ DISPATCH_OUTPUT_TYPES(output.scalar_type(), "fp32_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { if (max_float32_D > 0) { const auto max_D = max_float32_D; + const auto max_Ls = FP32_max_ls; constexpr auto sparse_type = SparseType::FP32; auto max_fp32_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_D, sparse_type, row_alignment), 128); TORCH_CHECK(max_fp32_128b_rows <= 64); // 128 doesn't fit in 48KB SM, so FP32 TBE supports a smaller dimension than others @@ -525,7 +549,13 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ Tensor lxu_cache_locations, const int64_t max_float8_D, const int64_t fp8_exponent_bits, - const int64_t fp8_exponent_bias + const int64_t fp8_exponent_bias, + const int64_t INT2_max_ls, + const int64_t INT4_max_ls, + const 
int64_t INT8_max_ls, + const int64_t FP8_max_ls, + const int64_t FP16_max_ls, + const int64_t FP32_max_ls ) { // All argument tensors need to be on the same CUDA device TENSOR_ON_CUDA_GPU(dev_weights); @@ -586,7 +616,13 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ lxu_cache_locations, max_float8_D, fp8_exponent_bits, - fp8_exponent_bias); + fp8_exponent_bias, + INT2_max_ls, + INT4_max_ls, + INT8_max_ls, + FP8_max_ls, + FP16_max_ls, + FP32_max_ls); }); return output; diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu index c17fe9fb0f..828527dcb1 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu @@ -9,7 +9,7 @@ // clang-format off {% set wdesc = "weighted" if weighted else "unweighted" %} #include "fbgemm_gpu/embedding_forward_template_helpers.cuh" -#include "fbgemm_gpu/utils/tensor_accessor_builder.h" +#include "fbgemm_gpu/utils/tensor_accessor.h" using namespace fbgemm_gpu; using Tensor = at::Tensor; @@ -17,7 +17,7 @@ using Tensor = at::Tensor; namespace nbit { // TODO: increase code sharing (templates for accumulator_ty, accumulation, outputs per thread, etc?) -template +template __launch_bounds__(WarpsPerBlock * kWarpSize) __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_kernel_small_L( const pta::PackedTensorAccessor64 dev_weights, @@ -46,40 +46,40 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no {% endif %} // The number of bags that one warp/wave is able to process in one go. 
(NumUint4LoadsPerRow / uint4_loads_per_row) const int32_t num_packed_bags, + const int32_t num_packed_bags_L, pta::PackedTensorAccessor32 output, // [B][total_D], const pta::PackedTensorAccessor64 lxu_cache_weights, const pta::PackedTensorAccessor32 lxu_cache_locations ) { - const int32_t T = weights_offsets.size(0); - {% if not nobag %} - const bool mean_pooling = static_cast(pooling_mode) == PoolingMode::MEAN; - const int32_t B = output.size(0); - {% else %} - const int32_t B = (offsets.size(0) - 1) / T; - {% endif %} - const auto bb_t = blockIdx.x * blockDim.y + threadIdx.y; - if (bb_t >= fd_B.D() * T) { - return; - } - static_assert( - std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v, - "output_t can only be float or half or bytes now" - ); - - int32_t t; - int32_t bb; - fd_B.DivMod(bb_t, &t, &bb); - - {% if not nobag %} - const int32_t D_start = D_offsets[t]; - const int32_t D_end = D_offsets[t + 1]; - const int32_t D = D_end - D_start; - {% endif %} - SparseType weight_ty = static_cast(weights_tys[t]); - if (weight_ty != SparseType::{{ emb_weight_type.enum_name }}) { + const int32_t T = weights_offsets.size(0); + {% if not nobag %} + const bool mean_pooling = static_cast(pooling_mode) == PoolingMode::MEAN; + const int32_t B = output.size(0); + {% else %} + const int32_t B = (offsets.size(0) - 1) / T; + {% endif %} + const auto bb_t = blockIdx.x * blockDim.y + threadIdx.y; + if (bb_t >= fd_B.D() * T) { return; - } - + } + static_assert( + std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v, + "output_t can only be float or half or bytes now" + ); + + int32_t t; + int32_t bb; + fd_B.DivMod(bb_t, &t, &bb); + + {% if not nobag %} + const int32_t D_start = D_offsets[t]; + const int32_t D_end = D_offsets[t + 1]; + const int32_t D = D_end - D_start; + {% endif %} + SparseType weight_ty = static_cast(weights_tys[t]); + if (weight_ty != SparseType::{{ emb_weight_type.enum_name }}) { + return; + } // default to 16 byte alignment for GPU TBE const int32_t D_bytes = padded_row_size_in_bytes(D, weight_ty, row_alignment); @@ -100,7 +100,7 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no constexpr uint32_t NumUint4LoadsPerRow = MaxNum128BRows * 128 / sizeof(uint4); const uint32_t uint4_loads_per_row = div_round_up(D_bytes, sizeof(uint4)); - + const int32_t bag_size_offset = num_packed_bags_L > 1 ? kWarpSize/(num_packed_bags_L * NumUint4LoadsPerRow) : 1; // Index of packed bag during load stage in current warp/wave. Should fit into NumUint4LoadsPerRow (3rd) shared // memory buffer's dimension w.r.t. the actual size of the row in the bag. const uint32_t packed_bag_load_idx = PackedMode ? (threadIdx.x % NumUint4LoadsPerRow) / uint4_loads_per_row : 0; @@ -109,7 +109,10 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no // Packed bag accumulation index in current warp/wave. Each thread/lane process 1 uint instead of // 4 uints during load stage, so the index should be recomputed accordingly. const int32_t packed_bag_acc_idx = PackedMode ? (threadIdx.x / uints_per_row) % num_packed_bags : 0; + const uint32_t packed_bag_idx_L = num_packed_bags_L > 1 ? 
(threadIdx.x / NumUint4LoadsPerRow) / bag_size_offset : 0; + const uint32_t packed_bag_idx = (packed_bag_idx_L * num_packed_bags) + packed_bag_load_idx; + // const int32_t bag_d = kWarpSize/num_packed_bags_L; // num_packed_bags_L can be {1, 2, 4, 8} for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { // In case of PackedMode, b should be offseted with num_packed_bags and indexed with packed_bag_load_idx // to take into account reduced grid size in host kernel call and that the warp/wave may contain several @@ -134,315 +137,573 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no {% if not nobag %} VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> accumulators[OutputRowsPerThread][AccumulateStoreRequests]; - {% endif %} - - for (uint32_t L_start = 0; L_start < max_Ls; L_start += InputRowsInFlight) { - uint32_t input_rows_in_flight = min(static_cast(InputRowsInFlight), max_Ls - L_start); - - typedef uint4 AllBuffers[WarpsPerBlock][OutputRowsPerThread][InputRowsInFlight][NumUint4LoadsPerRow]; - __shared__ AllBuffers buffers; - - {% if weighted %} - // In case of PackedMode, overallocate indice weights buffer to store additional per-row weights for - // packed bags. - typedef float AllIndiceWeights[WarpsPerBlock][OutputRowsPerThread][InputRowsInFlight][PackedMode ? NumUint4LoadsPerRow : 1]; - __shared__ AllIndiceWeights buffers_indice_weights; - {% endif %} - for (uint32_t load_idx = threadIdx.x; load_idx < input_rows_in_flight * NumUint4LoadsPerRow; load_idx += kWarpSize) { - uint32_t row_load_idx = load_idx % NumUint4LoadsPerRow; - if constexpr (PackedMode) { - // The actual row index in packed bag w.r.t. the required uint4 loads. - row_load_idx %= uint4_loads_per_row; - } - uint32_t input_row_idx = (load_idx / NumUint4LoadsPerRow); - // In case of PackedMode, packed_bag_load_idx already takes into account uint4_loads_per_row, - // so only the packed_bag index should be evaluated against total number of packed bags. - bool load_idx_valid = PackedMode ? packed_bag_load_idx < num_packed_bags : row_load_idx < uint4_loads_per_row; - {%- if is_rocm %} - constexpr uint32_t kMaxRowUnroll = 4; - constexpr uint32_t kRowUnroll = OutputRowsPerThread < kMaxRowUnroll ? OutputRowsPerThread : kMaxRowUnroll; - - #pragma unroll - for (uint32_t outer_i = 0; outer_i < OutputRowsPerThread - OutputRowsPerThread % kRowUnroll; outer_i += kRowUnroll) { - uint4 row_data_v[kRowUnroll]; - const uint4* row_v[kRowUnroll]; - int32_t idx_v[kRowUnroll]; - int32_t cache_idx_v[kRowUnroll]; + {% endif %} + typedef uint4 AllBuffers[WarpsPerBlock][OutputRowsPerThread][InputRowsInFlight][NumUint4LoadsPerRow]; + __shared__ AllBuffers buffers; + {% if weighted %} + // In case of PackedMode, overallocate indice weights buffer to store additional per-row weights for + // packed bags. + typedef float AllIndiceWeights[WarpsPerBlock][OutputRowsPerThread][InputRowsInFlight][PackedMode ? NumUint4LoadsPerRow : 1]; + __shared__ AllIndiceWeights buffers_indice_weights; + {% endif %} + {%- if is_rocm %} + constexpr uint32_t kMaxRowUnroll = 4; + constexpr uint32_t kRowUnroll = OutputRowsPerThread < kMaxRowUnroll ? 
OutputRowsPerThread : kMaxRowUnroll; + {% endif %} + if constexpr (PackedModeL){ + for (uint32_t L_start = 0; L_start < max_Ls; L_start += InputRowsInFlight) { + uint32_t input_rows_in_flight = min(static_cast(InputRowsInFlight), max_Ls - L_start); + for (uint32_t load_idx = threadIdx.x; load_idx < input_rows_in_flight * num_packed_bags_L * NumUint4LoadsPerRow; load_idx += kWarpSize) { + uint32_t row_load_idx = load_idx % NumUint4LoadsPerRow % uint4_loads_per_row; + uint32_t input_row_idx = num_packed_bags_L>1? (load_idx / NumUint4LoadsPerRow) % bag_size_offset: (load_idx / NumUint4LoadsPerRow); + bool load_idx_valid = packed_bag_load_idx < num_packed_bags && packed_bag_idx_L < num_packed_bags_L; + {%- if is_rocm %} #pragma unroll - for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) { - uint32_t i = outer_i + inner_i; - bool valid = load_idx_valid && L_start + input_row_idx < Ls[i]; - bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid); - idx_v[inner_i] = valid ? indices_[indices_starts[i] + L_start + input_row_idx] : -1; - cache_idx_v[inner_i] = (!DeviceOnly && cache_valid) ? lxu_cache_locations[indices_starts[i] + L_start + input_row_idx] : -1; + for (uint32_t outer_i = 0; outer_i < OutputRowsPerThread - OutputRowsPerThread % kRowUnroll; outer_i += kRowUnroll) { + uint4 row_data_v[kRowUnroll]; + const uint4* row_v[kRowUnroll]; + int32_t idx_v[kRowUnroll]; + int32_t cache_idx_v[kRowUnroll]; + #pragma unroll + for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) { + uint32_t i = outer_i + inner_i; + bool valid = load_idx_valid && L_start + input_row_idx < Ls[i]; + bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid); + idx_v[inner_i] = valid ? indices_[indices_starts[i] + L_start + input_row_idx] : -1; + cache_idx_v[inner_i] = (!DeviceOnly && cache_valid) ? lxu_cache_locations[indices_starts[i] + L_start + input_row_idx] : -1; + } + #pragma unroll + for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) { + uint32_t i = outer_i + inner_i; + bool valid = load_idx_valid && L_start + input_row_idx < Ls[i]; + bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid); + valid = valid && (idx_v[inner_i] != -1); + if (!DeviceOnly && cache_valid && cache_idx_v[inner_i] != kCacheLocationMissing) { + row_v[inner_i] = reinterpret_cast(&lxu_cache_weights[static_cast(cache_idx_v[inner_i])][0]); + } else + if (valid) { + row_v[inner_i] = reinterpret_cast(&weights[static_cast(idx_v[inner_i]) * D_bytes]); + } else { + row_v[inner_i] = reinterpret_cast(&weights[0]); + } + } + #pragma unroll + for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) { + uint32_t i = outer_i + inner_i; + row_data_v[inner_i] = row_v[inner_i][row_load_idx]; + } + uint4 zeros = {0, 0, 0, 0}; + #pragma unroll + for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) { + uint32_t i = outer_i + inner_i; + bool valid = load_idx_valid && (L_start + input_row_idx < Ls[i]) && (idx_v[inner_i] != -1); + uint4 data = valid ? row_data_v[inner_i] : zeros; + buffers[warp_idx][i][input_row_idx + bag_size_offset *packed_bag_idx_L][row_load_idx + uint4_loads_per_row * packed_bag_load_idx] = data; + {% if weighted %} + buffers_indice_weights[warp_idx][i][input_row_idx][packed_bag_load_idx] = valid ? 
indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0; + {% endif %} + } } - - + {%- endif %} + + {%- if is_rocm %} + if constexpr (OutputRowsPerThread % kRowUnroll) + { #pragma unroll - for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) { - uint32_t i = outer_i + inner_i; + for (uint32_t i = OutputRowsPerThread - OutputRowsPerThread % kRowUnroll; i < OutputRowsPerThread; ++i) { + {%- else %} + #pragma unroll OutputRowsPerThread + for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { + {%- endif %} bool valid = load_idx_valid && L_start + input_row_idx < Ls[i]; bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid); - valid = valid && (idx_v[inner_i] != -1); - if (!DeviceOnly && cache_valid && cache_idx_v[inner_i] != kCacheLocationMissing) { - row_v[inner_i] = reinterpret_cast(&lxu_cache_weights[static_cast(cache_idx_v[inner_i])][0]); - } else - if (valid) { - row_v[inner_i] = reinterpret_cast(&weights[static_cast(idx_v[inner_i]) * D_bytes]); + int32_t idx = valid ? indices_[indices_starts[i] + L_start + input_row_idx] : -1; + int32_t cache_idx = (!DeviceOnly && cache_valid) ? lxu_cache_locations[indices_starts[i] + L_start + input_row_idx] : -1; + valid = valid && (idx != -1); + const uint4* row; + if (!DeviceOnly && cache_valid && cache_idx != kCacheLocationMissing) { + row = reinterpret_cast(&lxu_cache_weights[static_cast(cache_idx)][0]); + } else if (valid) { + row = reinterpret_cast(&weights[static_cast(idx) * D_bytes]); } else { - row_v[inner_i] = reinterpret_cast(&weights[0]); + row = reinterpret_cast(&weights[0]); } + cp_async_zfill_cg(&buffers[warp_idx][i][input_row_idx + bag_size_offset * packed_bag_idx_L][row_load_idx + uint4_loads_per_row * packed_bag_load_idx] , &row[row_load_idx], valid); + {% if weighted %} + buffers_indice_weights[warp_idx][i][input_row_idx][packed_bag_load_idx] = valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0; + {% endif %} } - #pragma unroll - for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) { - uint32_t i = outer_i + inner_i; - row_data_v[inner_i] = row_v[inner_i][row_load_idx]; + {%- if is_rocm %} + } // constexpr if (OutputRowsPerThread % kRowUnroll) + {%- endif %} + } + // equivalent to fence + wait. + cp_async_wait<0>(); + syncwarp(); + const int32_t packed_bag_load_idx = (threadIdx.x / uints_per_row) % num_packed_bags; + input_rows_in_flight = shfl_sync(input_rows_in_flight, packed_bag_load_idx * uint4_loads_per_row); + constexpr int32_t max_indices_per_warp = kWarpSize / (MaxNum128BRows * 128 / sizeof(uint4)); + int32_t Ls_shfl[kWarpSize]; + for(uint32_t k = 0; k < num_packed_bags_L ; ++k){ + for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { + Ls_shfl[k*OutputRowsPerThread+i] = shfl_sync(Ls[i], k * bag_size_offset * NumUint4LoadsPerRow + packed_bag_load_idx * uint4_loads_per_row); } - uint4 zeros = {0, 0, 0, 0}; - #pragma unroll - for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) { - uint32_t i = outer_i + inner_i; - bool valid = load_idx_valid && (L_start + input_row_idx < Ls[i]) && (idx_v[inner_i] != -1); - uint4 data = valid ? 
row_data_v[inner_i] : zeros; - if constexpr (PackedMode) { - // Store row data with uint4_loads_per_row offset - buffers[warp_idx][i][input_row_idx][row_load_idx + uint4_loads_per_row * packed_bag_load_idx] = data; - } else { - buffers[warp_idx][i][input_row_idx][row_load_idx] = data; + } + for(uint32_t k = 0; k < num_packed_bags_L ; ++k){ + for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) { + + #pragma unroll OutputRowsPerThread + for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { + bool valid = L_start + input_row_idx < Ls_shfl[k*OutputRowsPerThread+i]; + if (!valid) { + continue; + } + const uint32_t* row = reinterpret_cast(&buffers[warp_idx][i][input_row_idx + bag_size_offset *k][0]); + // scale and bias are at the beginning of each row. + // rationale: have scale/shift at start since these get loaded first + // and then broadcasted around so it might speed up the first cache miss. + {% if emb_weight_type.primitive_type == "INT" %} + half2 shift_scale = reinterpret_cast(row)[(packed_bag_load_idx * uints_per_row)]; + {% endif %} + + {% if weighted %} + float row_weight = buffers_indice_weights[warp_idx][i][input_row_idx][PackedMode ? packed_bag_acc_idx : 0]; + {% endif %} + + using scalar_t = {{ emb_weight_type.cpp_type_name }}; + + {% if not nobag %} + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; + {% if weighted %} + accumulators[i][j].fma(v, {% if emb_weight_type.primitive_type == "INT" %} shift_scale, {% elif emb_weight_type.enum_name == "FP8" %} exponent_bits, exponent_bias, {% endif %} row_weight); + {% else %} + accumulators[i][j].add(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); + + {% endif %} + } + + {% else %} + const int32_t output_j = indices_starts[i] + L_start + input_row_idx; + if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + // Read the uint8/4/2 values: note that first 4 Bytes will be ditched later: + // We shift back by 4/8/16 elements to remove the first 4 Bytes (which is garbage due to + // the scale/shift handling). + // Reason: to avoid divergence the first thread in the warp computes garbage. 
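// --- Illustrative aside (not part of the patch): a host-side sketch, with assumed
// example constants, of the output_d arithmetic used on the next line. Each 32-bit
// read decodes kOutputsPerThread sub-byte elements; the first read of a row holds the
// 4-byte scale/shift header, so subtracting D_padding pushes those header elements to
// negative indices and the `output_d >= 0 && output_d < D` guard drops them.
#include <cstdio>

int main() {
  const int kWarpSize = 32;
  const int bit_width = 4;                    // INT4 in this example
  const int kOutputsPerThread = 32 / bit_width;
  const int D = 64;                           // embedding dimension (example)
  const int D_padding = kOutputsPerThread;    // the 4-byte header fills one 32-bit read here
  const int AccumulateStoreRequests = 2;      // example value, not the kernel's
  int covered = 0;
  for (int j = 0; j < AccumulateStoreRequests; ++j) {
    for (int lane = 0; lane < kWarpSize; ++lane) {
      const int output_d = kWarpSize * j * kOutputsPerThread + lane * kOutputsPerThread - D_padding;
      if (output_d >= 0 && output_d < D) {
        // num_valid_outputs clamps the final read so nothing is written past D.
        const int num_valid_outputs = (D - output_d) < kOutputsPerThread ? (D - output_d) : kOutputsPerThread;
        covered += num_valid_outputs;
      }
    }
  }
  printf("elements written for D=%d: %d\n", D, covered);  // prints 64
  return 0;
}
// --- end of aside ---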
+ const int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; + if (output_d >= 0 && output_d < D) { + const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); + VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); + acc.store(&output[output_j][output_d], num_valid_outputs); + } + } + } else if constexpr (std::is_same_v) { + // INT8: + // apply per feature row-wise int8 + auto thread_local_min = std::numeric_limits::max(); + auto thread_local_max = std::numeric_limits::lowest(); + float2 qparams; + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; + VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); + if (output_d >= 0 && output_d < D) { + thread_local_max = max(thread_local_max, float{{ (32 // emb_weight_type.bit_width) }}_max(acc.acc)); + thread_local_min = min(thread_local_min, float{{ (32 // emb_weight_type.bit_width) }}_min(acc.acc)); + } + } + qparams = warp_find_qparams(thread_local_min, thread_local_max); + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + const int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; + if (output_d >= 0 && output_d < D) { + const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); + VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); + acc.store(&output[output_j][output_d], qparams, num_valid_outputs); + } + } + if (threadIdx.x == 0) { + store_qparams_to_row(&output[output_j][D], qparams); + } + } + {% endif %} + } } - {% if weighted %} - if (row_load_idx == 0) { - // Use only one thread to load the index weight to prevent a race - // condition when writing to the shared memory - buffers_indice_weights[warp_idx][i][input_row_idx][packed_bag_load_idx] = - valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0; + {% if not nobag %} + #pragma unroll OutputRowsPerThread + for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { + const int32_t num_stores_with_padding_per_row = 4 * uint4_loads_per_row; + const int32_t packed_bag_load_idx = threadIdx.x / num_stores_with_padding_per_row; + uint32_t b = min(static_cast(bb * num_packed_bags * OutputRowsPerThread + i * num_packed_bags + k*num_packed_bags + packed_bag_load_idx), static_cast(B - 1)); + const float inv_L = (mean_pooling &&Ls_shfl[k*OutputRowsPerThread+i] != 0) ? 
static_cast(1.0) / Ls_shfl[k*OutputRowsPerThread+i] : static_cast(1.0); + + if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + const int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding \ + - packed_bag_load_idx * kOutputsPerThread * num_stores_with_padding_per_row; + accumulators[i][j].mul(inv_L); + + if (output_d >= 0 && output_d < D && packed_bag_load_idx < num_packed_bags) { + const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); + accumulators[i][j].store(&output[b][D_start + output_d], num_valid_outputs); + } + + } + + } else if constexpr (std::is_same_v) { + // INT8: + // apply per feature row-wise int8 + float thread_local_min = std::numeric_limits::max(); + float thread_local_max = std::numeric_limits::lowest(); + float2 qparams; + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + accumulators[i][j].mul(inv_L); + if (output_d >= 0 && output_d < D) { + thread_local_max = max(thread_local_max, float{{ (32 // emb_weight_type.bit_width) }}_max(accumulators[i][j].acc)); + thread_local_min = min(thread_local_min, float{{ (32 // emb_weight_type.bit_width) }}_min(accumulators[i][j].acc)); + } + } + + qparams = warp_find_qparams(thread_local_min, thread_local_max); + const int output_D_start = D_start + t * 8; + const int output_D_end = output_D_start + D; + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + const int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + if (output_d >= 0 && output_d < D) { + const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); + accumulators[i][j].store(&output[b][output_D_start + output_d], qparams, num_valid_outputs); + } + } + if (threadIdx.x == 0) { + store_qparams_to_row(&output[b][output_D_end], qparams); + } + } else { + // INT4: not implemented yet } - {% endif %} } - } - {%- endif %} - {%- if is_rocm %} - if constexpr (OutputRowsPerThread % kRowUnroll) - { - #pragma unroll - for (uint32_t i = OutputRowsPerThread - OutputRowsPerThread % kRowUnroll; i < OutputRowsPerThread; ++i) { - {%- else %} - #pragma unroll OutputRowsPerThread - for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { - {%- endif %} - bool valid = load_idx_valid && L_start + input_row_idx < Ls[i]; - bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid); - int32_t idx = valid ? indices_[indices_starts[i] + L_start + input_row_idx] : -1; - int32_t cache_idx = (!DeviceOnly && cache_valid) ? 
lxu_cache_locations[indices_starts[i] + L_start + input_row_idx] : -1; - valid = valid && (idx != -1); - const uint4* row; - if (!DeviceOnly && cache_valid && cache_idx != kCacheLocationMissing) { - row = reinterpret_cast(&lxu_cache_weights[static_cast(cache_idx)][0]); - } else if (valid) { - row = reinterpret_cast(&weights[static_cast(idx) * D_bytes]); - } else { - row = reinterpret_cast(&weights[0]); + for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + accumulators[i][j].mul(0.0); // Use a dedicated clear method + } } + + + {% endif %} + } + } + } + else{ + for (uint32_t L_start = 0; L_start < max_Ls; L_start += InputRowsInFlight) { + uint32_t input_rows_in_flight = min(static_cast(InputRowsInFlight), max_Ls - L_start); + + for (uint32_t load_idx = threadIdx.x; load_idx < input_rows_in_flight * NumUint4LoadsPerRow; load_idx += kWarpSize) { + uint32_t row_load_idx = load_idx % NumUint4LoadsPerRow; + if constexpr (PackedMode) { + // The actual row index in packed bag w.r.t. the required uint4 loads. + row_load_idx %= uint4_loads_per_row; + } + uint32_t input_row_idx = (load_idx / NumUint4LoadsPerRow); + // In case of PackedMode, packed_bag_load_idx already takes into account uint4_loads_per_row, + // so only the packed_bag index should be evaluated against total number of packed bags. + bool load_idx_valid = PackedMode ? packed_bag_load_idx < num_packed_bags : row_load_idx < uint4_loads_per_row; + {%- if is_rocm %} + constexpr uint32_t kMaxRowUnroll = 4; + constexpr uint32_t kRowUnroll = OutputRowsPerThread < kMaxRowUnroll ? OutputRowsPerThread : kMaxRowUnroll; + + #pragma unroll + for (uint32_t outer_i = 0; outer_i < OutputRowsPerThread - OutputRowsPerThread % kRowUnroll; outer_i += kRowUnroll) { + uint4 row_data_v[kRowUnroll]; + const uint4* row_v[kRowUnroll]; + int32_t idx_v[kRowUnroll]; + int32_t cache_idx_v[kRowUnroll]; + #pragma unroll + for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) { + uint32_t i = outer_i + inner_i; + bool valid = load_idx_valid && L_start + input_row_idx < Ls[i]; + bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid); + idx_v[inner_i] = valid ? indices_[indices_starts[i] + L_start + input_row_idx] : -1; + cache_idx_v[inner_i] = (!DeviceOnly && cache_valid) ? lxu_cache_locations[indices_starts[i] + L_start + input_row_idx] : -1; + } + + + #pragma unroll + for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) { + uint32_t i = outer_i + inner_i; + bool valid = load_idx_valid && L_start + input_row_idx < Ls[i]; + bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid); + valid = valid && (idx_v[inner_i] != -1); + if (!DeviceOnly && cache_valid && cache_idx_v[inner_i] != kCacheLocationMissing) { + row_v[inner_i] = reinterpret_cast(&lxu_cache_weights[static_cast(cache_idx_v[inner_i])][0]); + } else + if (valid) { + row_v[inner_i] = reinterpret_cast(&weights[static_cast(idx_v[inner_i]) * D_bytes]); + } else { + row_v[inner_i] = reinterpret_cast(&weights[0]); + } + } + #pragma unroll + for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) { + uint32_t i = outer_i + inner_i; + row_data_v[inner_i] = row_v[inner_i][row_load_idx]; + } + uint4 zeros = {0, 0, 0, 0}; + #pragma unroll + for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) { + uint32_t i = outer_i + inner_i; + bool valid = load_idx_valid && (L_start + input_row_idx < Ls[i]) && (idx_v[inner_i] != -1); + uint4 data = valid ? 
row_data_v[inner_i] : zeros; + if constexpr (PackedMode) { + // Store row data with uint4_loads_per_row offset + buffers[warp_idx][i][input_row_idx][row_load_idx + uint4_loads_per_row * packed_bag_load_idx] = data; + } else { + buffers[warp_idx][i][input_row_idx][row_load_idx] = data; + } + {% if weighted %} + if (row_load_idx == 0) { + // Use only one thread to load the index weight to prevent a race + // condition when writing to the shared memory + buffers_indice_weights[warp_idx][i][input_row_idx][packed_bag_load_idx] = + valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0; + } + {% endif %} + } + } + {%- endif %} + + {%- if is_rocm %} + if constexpr (OutputRowsPerThread % kRowUnroll) + { + #pragma unroll + for (uint32_t i = OutputRowsPerThread - OutputRowsPerThread % kRowUnroll; i < OutputRowsPerThread; ++i) { + {%- else %} + #pragma unroll OutputRowsPerThread + for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { + {%- endif %} + bool valid = load_idx_valid && L_start + input_row_idx < Ls[i]; + bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid); + int32_t idx = valid ? indices_[indices_starts[i] + L_start + input_row_idx] : -1; + int32_t cache_idx = (!DeviceOnly && cache_valid) ? lxu_cache_locations[indices_starts[i] + L_start + input_row_idx] : -1; + valid = valid && (idx != -1); + const uint4* row; + if (!DeviceOnly && cache_valid && cache_idx != kCacheLocationMissing) { + row = reinterpret_cast(&lxu_cache_weights[static_cast(cache_idx)][0]); + } else if (valid) { + row = reinterpret_cast(&weights[static_cast(idx) * D_bytes]); + } else { + row = reinterpret_cast(&weights[0]); + } + if constexpr (PackedMode) { + // Load valid packed row data w.r.t. packed_bag offset + cp_async_zfill_cg(&buffers[warp_idx][i][input_row_idx][row_load_idx + uint4_loads_per_row * packed_bag_load_idx], &row[row_load_idx], valid); + } else { + cp_async_zfill_cg(&buffers[warp_idx][i][input_row_idx][row_load_idx], &row[row_load_idx], valid); + } + {% if weighted %} + if (row_load_idx == 0) { + // Use only one thread to load the index weight to prevent a race + // condition when writing to the shared memory + buffers_indice_weights[warp_idx][i][input_row_idx][packed_bag_load_idx] = + valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0; + } + {% endif %} + } + {%- if is_rocm %} + } // constexpr if (OutputRowsPerThread % kRowUnroll) + {%- endif %} + } + // equivalent to fence + wait. + cp_async_wait<0>(); + syncwarp(); + if constexpr (PackedMode) { - // Load valid packed row data w.r.t. packed_bag offset - cp_async_zfill_cg(&buffers[warp_idx][i][input_row_idx][row_load_idx + uint4_loads_per_row * packed_bag_load_idx], &row[row_load_idx], valid); - } else { - cp_async_zfill_cg(&buffers[warp_idx][i][input_row_idx][row_load_idx], &row[row_load_idx], valid); + // Since in PackedMode one warp/wave may contain different bags with different sizes, + // the permutation should be done after switching from uint4 processing during load stage + // to uint processing during accumulate and store. 
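// --- Illustrative aside (not part of the patch): a minimal CUDA sketch of the lane
// permutation done by the shfl_sync calls just below. Per-bag values (here a bag
// length) are held in the load-stage layout, one bag per group of uint4_loads_per_row
// lanes repeating every NumUint4LoadsPerRow lanes, and are re-read through a warp
// shuffle so that the accumulate-stage layout, one bag per group of uints_per_row
// lanes, sees the value of its own bag. The kernel name and constants are made up for
// the demo; the template's shfl_sync helper is assumed to behave like __shfl_sync with
// a full mask.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void demo_packed_bag_shuffle(const int* bag_len_in, int* bag_len_out) {
  constexpr int kUint4LoadsPerRow = 2;     // uint4 loads needed by one row (example)
  constexpr int kNumUint4LoadsPerRow = 8;  // shared-buffer width in uint4 (example)
  constexpr int kNumPackedBags = 4;        // bags packed into the warp (example)
  constexpr int kUintsPerRow = 4 * kUint4LoadsPerRow;
  const int lane = threadIdx.x;
  // Load stage: each group of kUint4LoadsPerRow lanes owns one bag.
  const int load_bag = (lane % kNumUint4LoadsPerRow) / kUint4LoadsPerRow;
  int L = bag_len_in[load_bag];
  // Accumulate stage: each group of kUintsPerRow lanes owns one bag.
  const int acc_bag = (lane / kUintsPerRow) % kNumPackedBags;
  // Re-read the length held by the first load-stage lane of that bag.
  L = __shfl_sync(0xffffffffu, L, acc_bag * kUint4LoadsPerRow);
  bag_len_out[lane] = L;
}

int main() {
  int h_in[4] = {3, 1, 7, 2};  // bag lengths for the four packed bags
  int h_out[32];
  int *d_in, *d_out;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(h_out));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  demo_packed_bag_shuffle<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  for (int lane = 0; lane < 32; ++lane) {
    printf("lane %2d accumulates bag %d with L = %d\n", lane, (lane / 8) % 4, h_out[lane]);
  }
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
// --- end of aside ---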
+ input_rows_in_flight = shfl_sync(input_rows_in_flight, packed_bag_acc_idx * uint4_loads_per_row); + + #pragma unroll OutputRowsPerThread + for(uint32_t i = 0; i < OutputRowsPerThread; ++i) + { + Ls[i] = shfl_sync(Ls[i], packed_bag_acc_idx * uint4_loads_per_row); + } } - {% if weighted %} - if (row_load_idx == 0) { - // Use only one thread to load the index weight to prevent a race - // condition when writing to the shared memory - buffers_indice_weights[warp_idx][i][input_row_idx][packed_bag_load_idx] = - valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0; + + for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) { + #pragma unroll OutputRowsPerThread + for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { + bool valid = L_start + input_row_idx < Ls[i]; + if (!valid) { + continue; + } + const uint32_t* row = reinterpret_cast(&buffers[warp_idx][i][input_row_idx][0]); + // scale and bias are at the beginning of each row. + // rationale: have scale/shift at start since these get loaded first + // and then broadcasted around so it might speed up the first cache miss. + {% if emb_weight_type.primitive_type == "INT" %} + // In PackedMode, row pointer may contain several rows from different bags, so each thread/lane should + // read the certain shift_scale related to the row in the packed_bag. + half2 shift_scale = reinterpret_cast(row)[PackedMode ? packed_bag_acc_idx * uints_per_row : 0]; + {% endif %} + + {% if weighted %} + float row_weight = buffers_indice_weights[warp_idx][i][input_row_idx][PackedMode ? packed_bag_acc_idx : 0]; + {% endif %} + + using scalar_t = {{ emb_weight_type.cpp_type_name }}; + + {% if not nobag %} + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; + {% if weighted %} + accumulators[i][j].fma(v, {% if emb_weight_type.primitive_type == "INT" %} shift_scale, {% elif emb_weight_type.enum_name == "FP8" %} exponent_bits, exponent_bias, {% endif %} row_weight); + {% else %} + accumulators[i][j].add(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); + {% endif %} + } + {% else %} + const int32_t output_j = indices_starts[i] + L_start + input_row_idx; + if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + // Read the uint8/4/2 values: note that first 4 Bytes will be ditched later: + // We shift back by 4/8/16 elements to remove the first 4 Bytes (which is garbage due to + // the scale/shift handling). + // Reason: to avoid divergence the first thread in the warp computes garbage. 
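// --- Illustrative aside (not part of the patch): a standalone sketch of the packed-bag
// capacity math from the PACKED_MODE_SWITCH launcher macro in the host template earlier
// in this diff: how many bags fit into one warp/wave along the row dimension
// (num_packed_bags_D) and, when every bag holds a single index (max_L == 1), additionally
// along the bag-length dimension (num_packed_bags_L). The output_t / FP32 exclusions of
// the macro are omitted and all constants are example values.
#include <cstdio>

static int div_round_up(int a, int b) { return (a + b - 1) / b; }

int main() {
  const int kWarpSize = 64;            // ROCm wavefront in this example
  const int MaxNum128BRows = 1;
  const int NumUint4LoadsPerRow = MaxNum128BRows * 128 / 16;  // 16 == sizeof(uint4)
  const int padded_row_bytes = 64;     // padded row size for this dtype/dimension (example)
  const int uint4_loads_per_row = div_round_up(padded_row_bytes, 16);
  const int max_L = 1;                 // longest bag for this dtype (the new *_max_ls input)

  const int max_indices_per_warp = kWarpSize / NumUint4LoadsPerRow;
  const int num_packed_bags_L =
      (max_L > 0 && max_indices_per_warp > max_L) ? max_indices_per_warp / max_L : 1;
  const int num_packed_bags_D =
      (NumUint4LoadsPerRow > uint4_loads_per_row) ? NumUint4LoadsPerRow / uint4_loads_per_row : 1;
  const int num_packed_bags =
      (max_L == 1) ? num_packed_bags_L * num_packed_bags_D : num_packed_bags_D;

  printf("bags packed per warp: %d (D-wise %d, L-wise %d)\n",
         num_packed_bags, num_packed_bags_D, num_packed_bags_L);  // 16 (2, 8)
  return 0;
}
// --- end of aside ---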
+ const auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; + if (output_d >= 0 && output_d < D) { + const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); + VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); + acc.store(&output[output_j][output_d], num_valid_outputs); + } + } + } else if constexpr (std::is_same_v) { + // INT8: + // apply per feature row-wise int8 + auto thread_local_min = std::numeric_limits::max(); + auto thread_local_max = std::numeric_limits::lowest(); + float2 qparams; + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; + VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); + if (output_d >= 0 && output_d < D) { + thread_local_max = max(thread_local_max, float{{ (32 // emb_weight_type.bit_width) }}_max(acc.acc)); + thread_local_min = min(thread_local_min, float{{ (32 // emb_weight_type.bit_width) }}_min(acc.acc)); + } + } + qparams = warp_find_qparams(thread_local_min, thread_local_max); + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + const auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; + if (output_d >= 0 && output_d < D) { + const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); + VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); + acc.store(&output[output_j][output_d], qparams, num_valid_outputs); + } + } + if (threadIdx.x == 0) { + store_qparams_to_row(&output[output_j][D], qparams); + } + } + {% endif %} + } } - {% endif %} } - {%- if is_rocm %} - } // constexpr if (OutputRowsPerThread % kRowUnroll) - {%- endif %} - } - // equivalent to fence + wait. - cp_async_wait<0>(); - syncwarp(); - - if constexpr (PackedMode) { - // Since in PackedMode one warp/wave may contain different bags with different sizes, - // the permutation should be done after switching from uint4 processing during load stage - // to uint processing during accumulate and store. - input_rows_in_flight = shfl_sync(input_rows_in_flight, packed_bag_acc_idx * uint4_loads_per_row); - - #pragma unroll OutputRowsPerThread - for(uint32_t i = 0; i < OutputRowsPerThread; ++i) - { - Ls[i] = shfl_sync(Ls[i], packed_bag_acc_idx * uint4_loads_per_row); - } - } - - for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) { + + {% if not nobag %} + // In case of PackedMode, computes the packed bag index during store stage w.r.t. 
+    // the real number of uints in the rows.
+    const auto packed_bag_store_idx = PackedMode ? threadIdx.x / uints_per_row : 0;
+    #pragma unroll OutputRowsPerThread
     for (uint32_t i = 0; i < OutputRowsPerThread; ++i) {
-      bool valid = L_start + input_row_idx < Ls[i];
-      if (!valid) {
-        continue;
-      }
-      const uint32_t* row = reinterpret_cast(&buffers[warp_idx][i][input_row_idx][0]);
-      // scale and bias are at the beginning of each row.
-      // rationale: have scale/shift at start since these get loaded first
-      // and then broadcasted around so it might speed up the first cache miss.
-      {% if emb_weight_type.primitive_type == "INT" %}
-      // In PackedMode, row pointer may contain several rows from different bags, so each thread/lane should
-      // read the certain shift_scale related to the row in the packed_bag.
-      half2 shift_scale = reinterpret_cast(row)[PackedMode ? packed_bag_acc_idx * uints_per_row : 0];
-      {% endif %}
-
-      {% if weighted %}
-      float row_weight = buffers_indice_weights[warp_idx][i][input_row_idx][PackedMode ? packed_bag_acc_idx : 0];
-      {% endif %}
-
-      using scalar_t = {{ emb_weight_type.cpp_type_name }};
-
-      {% if not nobag %}
-      #pragma unroll AccumulateStoreRequests
-      for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) {
-        scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x];
-        {% if weighted %}
-        accumulators[i][j].fma(v, {% if emb_weight_type.primitive_type == "INT" %} shift_scale, {% elif emb_weight_type.enum_name == "FP8" %} exponent_bits, exponent_bias, {% endif %} row_weight);
-        {% else %}
-        accumulators[i][j].add(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %});
-        {% endif %}
-      }
-      {% else %}
-      const int32_t output_j = indices_starts[i] + L_start + input_row_idx;
+      const uint32_t b = min(static_cast(bb * num_packed_bags * OutputRowsPerThread + i * num_packed_bags + packed_bag_store_idx), static_cast(B - 1));
+      const float inv_L = (mean_pooling && Ls[i] != 0) ? static_cast(1.0) / Ls[i] : static_cast(1.0);
+      if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) {
        #pragma unroll AccumulateStoreRequests
        for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) {
-          // Read the uint8/4/2 values: note that first 4 Bytes will be ditched later:
-          // We shift back by 4/8/16 elements to remove the first 4 Bytes (which is garbage due to
-          // the scale/shift handling).
-          // Reason: to avoid divergence the first thread in the warp computes garbage.
-          const auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding;
-          scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x];
-          if (output_d >= 0 && output_d < D) {
-            const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }}));
-            VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %});
-            acc.store(&output[output_j][output_d], num_valid_outputs);
+          auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding;
+          if constexpr (PackedMode) {
+            // Offset global output_d index with the size of outputs per bag w.r.t. current
+            // packed bag index
            output_d -= packed_bag_store_idx * kOutputsPerThread * uints_per_row;
          }
+          accumulators[i][j].mul(inv_L);
+
+          if constexpr (PackedMode) {
+            // Take into account the packed bag index overflow
+            if (output_d >= 0 && output_d < D && packed_bag_store_idx < num_packed_bags) {
+              const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }}));
+              accumulators[i][j].store(&output[b][D_start + output_d], num_valid_outputs);
+            }
+          } else {
+            if (output_d >= 0 && output_d < D) {
+              const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }}));
+              accumulators[i][j].store(&output[b][D_start + output_d], num_valid_outputs);
+            }
+          }
+        }
      } else if constexpr (std::is_same_v) {
        // INT8:
        // apply per feature row-wise int8
-        auto thread_local_min = std::numeric_limits::max();
-        auto thread_local_max = std::numeric_limits::lowest();
+        float thread_local_min = std::numeric_limits::max();
+        float thread_local_max = std::numeric_limits::lowest();
        float2 qparams;
        #pragma unroll AccumulateStoreRequests
        for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) {
          auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding;
-          scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x];
-          VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %});
+          accumulators[i][j].mul(inv_L);
          if (output_d >= 0 && output_d < D) {
-            thread_local_max = max(thread_local_max, float{{ (32 // emb_weight_type.bit_width) }}_max(acc.acc));
-            thread_local_min = min(thread_local_min, float{{ (32 // emb_weight_type.bit_width) }}_min(acc.acc));
+            thread_local_max = max(thread_local_max, float{{ (32 // emb_weight_type.bit_width) }}_max(accumulators[i][j].acc));
+            thread_local_min = min(thread_local_min, float{{ (32 // emb_weight_type.bit_width) }}_min(accumulators[i][j].acc));
          }
+        }
        }
+        qparams = warp_find_qparams(thread_local_min, thread_local_max);
+        const int output_D_start = D_start + t * 8;
+        const int output_D_end = output_D_start + D;
        #pragma unroll AccumulateStoreRequests
        for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) {
          const auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding;
-          scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x];
          if (output_d >= 0 && output_d < D) {
            const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }}));
-            VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %});
-            acc.store(&output[output_j][output_d], qparams, num_valid_outputs);
+            accumulators[i][j].store(&output[b][output_D_start + output_d], qparams, num_valid_outputs);
          }
        }
        if (threadIdx.x == 0) {
-          store_qparams_to_row(&output[output_j][D], qparams);
-        }
-      }
-      {% endif %}
-    }
-  }
-  }
-
-  {% if not nobag %}
-  // In case of PackedMode, computes the packed bag index during store stage w.r.t.
-  // the real number of uints in the rows.
-  const auto packed_bag_store_idx = PackedMode ? threadIdx.x / uints_per_row : 0;
-
-  #pragma unroll OutputRowsPerThread
-  for (uint32_t i = 0; i < OutputRowsPerThread; ++i) {
-    const uint32_t b = min(static_cast(bb * num_packed_bags * OutputRowsPerThread + i * num_packed_bags + packed_bag_store_idx), static_cast(B - 1));
-    const float inv_L = (mean_pooling && Ls[i] != 0) ? static_cast(1.0) / Ls[i] : static_cast(1.0);
-
-    if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) {
-      #pragma unroll AccumulateStoreRequests
-      for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) {
-        auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding;
-        if constexpr (PackedMode) {
-          // Offset global output_d index with the size of outputs per bag w.r.t. current
-          // packed bag index
-          output_d -= packed_bag_store_idx * kOutputsPerThread * uints_per_row;
-        }
-        accumulators[i][j].mul(inv_L);
-
-        if constexpr (PackedMode) {
-          // Take into account the packed bag index overflow
-          if (output_d >= 0 && output_d < D && packed_bag_store_idx < num_packed_bags) {
-            const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }}));
-            accumulators[i][j].store(&output[b][D_start + output_d], num_valid_outputs);
+        store_qparams_to_row(&output[b][output_D_end], qparams);
          }
        } else {
-          if (output_d >= 0 && output_d < D) {
-            const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }}));
-            accumulators[i][j].store(&output[b][D_start + output_d], num_valid_outputs);
-          }
-        }
-
-      }
-    } else if constexpr (std::is_same_v) {
-      // INT8:
-      // apply per feature row-wise int8
-      float thread_local_min = std::numeric_limits::max();
-      float thread_local_max = std::numeric_limits::lowest();
-      float2 qparams;
-      #pragma unroll AccumulateStoreRequests
-      for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) {
-        auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding;
-        accumulators[i][j].mul(inv_L);
-        if (output_d >= 0 && output_d < D) {
-          thread_local_max = max(thread_local_max, float{{ (32 // emb_weight_type.bit_width) }}_max(accumulators[i][j].acc));
-          thread_local_min = min(thread_local_min, float{{ (32 // emb_weight_type.bit_width) }}_min(accumulators[i][j].acc));
-        }
-      }
-
-      qparams = warp_find_qparams(thread_local_min, thread_local_max);
-      const int output_D_start = D_start + t * 8;
-      const int output_D_end = output_D_start + D;
-      #pragma unroll AccumulateStoreRequests
-      for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) {
-        const auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding;
-        if (output_d >= 0 && output_d < D) {
-          const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }}));
-          accumulators[i][j].store(&output[b][output_D_start + output_d], qparams, num_valid_outputs);
+      // INT4: not implemented yet
        }
-      }
-      if (threadIdx.x == 0) {
-        store_qparams_to_row(&output[b][output_D_end], qparams);
-      }
-    } else {
-      // INT4: not implemented yet
      }
-  }
-  {% endif %}
+    {% endif %}
+  }
 }
 // kWarpsPerBlock is defined in embedding_forward_quantized_split_nbit_host_template.cu
 {% set warps_per_block = '4' %}
 {% for packed_mode in ['true', 'false'] %}
+{% for packed_mode_L in ['true', 'false'] %}
 {% for device_only in ['true', 'false'] %}
 {% for output_type in ['at::Half', 'at::BFloat16', 'float', 'uint8_t'] %}
 {% for index_type in ['int32_t', 'int64_t'] %}
@@ -464,7 +725,8 @@ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if nobag else ""
   {{ params.min_128b_rows }},
   {{ params.max_128b_rows }},
   {{ device_only }},
-  {{ packed_mode }} > (
+  {{ packed_mode }},
+  {{ packed_mode_L }} > (
   const pta::PackedTensorAccessor64 dev_weights,
   const pta::PackedTensorAccessor64 uvm_weights,
   const pta::PackedTensorAccessor32 weights_placements,
@@ -490,6 +752,7 @@ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if nobag else ""
   const int exponent_bias,
   {% endif %}
   const int32_t num_packed_bags,
+  const int32_t num_packed_bags_L,
   pta::PackedTensorAccessor32<{{ output_type }}, 2, at::RestrictPtrTraits> output, // [B][total_D],
   const pta::PackedTensorAccessor64 lxu_cache_weights,
   const pta::PackedTensorAccessor32 lxu_cache_locations
@@ -504,7 +767,8 @@ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if nobag else ""
 {% endfor %} // for output_type in [True, False]
 {% endfor %} // device_only in [True, False]
 {% endfor %} // packed_bags in ['true', 'false']
+{% endfor %} // packed_bags in ['true', 'false']
 }
-  // clang-format on
+  // clang-format on
\ No newline at end of file
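
Reviewer note (not part of the patch): the PackedMode store path in the hunks above derives, for every lane, which packed bag it owns (packed_bag_store_idx = threadIdx.x / uints_per_row) and then shifts its output column back by packed_bag_store_idx * kOutputsPerThread * uints_per_row, guarded by packed_bag_store_idx < num_packed_bags. The standalone C++ sketch below only mirrors that index arithmetic on the host for illustration; the helper packed_lane_info and the constants chosen in main are hypothetical and are not FBGEMM APIs.

// Illustrative sketch only -- mirrors the PackedMode index math from the
// kernel template above; not a drop-in part of FBGEMM.
#include <cstdint>
#include <cstdio>

// Hypothetical helper: given a lane id within the warp, the number of 32-bit
// words each embedding row occupies, and the per-thread output width, compute
// which packed bag this lane stores and how far its global output column is
// shifted back.
struct PackedLane {
  uint32_t packed_bag_store_idx;  // which bag inside the packed row buffer
  int32_t  output_d_offset;       // amount subtracted from the global output_d
};

static PackedLane packed_lane_info(uint32_t lane,            // threadIdx.x
                                   uint32_t uints_per_row,
                                   uint32_t outputs_per_thread /* kOutputsPerThread */) {
  PackedLane p;
  p.packed_bag_store_idx = lane / uints_per_row;
  // Same arithmetic as: output_d -= packed_bag_store_idx * kOutputsPerThread * uints_per_row;
  p.output_d_offset = static_cast<int32_t>(p.packed_bag_store_idx * outputs_per_thread * uints_per_row);
  return p;
}

int main() {
  const uint32_t uints_per_row = 4;     // e.g. one uint4 (16 bytes) per row
  const uint32_t kOutputsPerThread = 4;
  const uint32_t num_packed_bags = 6;   // only 6 bags actually packed into the buffer
  for (uint32_t lane = 0; lane < 32; ++lane) {
    const PackedLane p = packed_lane_info(lane, uints_per_row, kOutputsPerThread);
    // Lanes whose packed bag index reaches num_packed_bags are masked out in
    // the kernel (the "packed bag index overflow" check above).
    std::printf("lane %2u -> bag %u, output_d offset %d%s\n",
                static_cast<unsigned>(lane),
                static_cast<unsigned>(p.packed_bag_store_idx),
                static_cast<int>(p.output_d_offset),
                p.packed_bag_store_idx < num_packed_bags ? "" : " (skipped)");
  }
  return 0;
}

With uints_per_row = 4 and kOutputsPerThread = 4, lanes 0-3 map to bag 0 with offset 0, lanes 4-7 to bag 1 with offset 16, and so on; lanes 24-31 compute bag indices 6 and 7 and are skipped, which is exactly the case the overflow guard in the kernel handles.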