diff --git a/src/EmbeddingSpMDMAutovec.cc b/src/EmbeddingSpMDMAutovec.cc index 87d27116b7..e9f519a9bc 100644 --- a/src/EmbeddingSpMDMAutovec.cc +++ b/src/EmbeddingSpMDMAutovec.cc @@ -1165,118 +1165,209 @@ typename EmbeddingSpMDMKernelSignature:: }; \ } -#define SPECIALIZE_BLOCK_SIZE( \ - HAS_WEIGHT, \ - NORMALIZE_BY_LENGTHS, \ - PREFETCH, \ - IS_WEIGHT_POSITIONAL, \ - USE_OFFSETS, \ - NO_BAG, \ - IS_BF16_OUT, \ - IS_BF16_IN) \ - SPECIALIZE( \ - /*BLOCK_SIZE*/ fixed(int64_t{32}), \ - HAS_WEIGHT, \ - NORMALIZE_BY_LENGTHS, \ - PREFETCH, \ - IS_WEIGHT_POSITIONAL, \ - USE_OFFSETS, \ - /*OUTPUT_STRIDE*/ var, \ - /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(32, false)), \ - /*SCALE_BIAS_LAST*/ fixed(false), \ - NO_BAG, \ - IS_BF16_OUT, \ - IS_BF16_IN) \ - SPECIALIZE( \ - /*BLOCK_SIZE*/ fixed(int64_t{64}), \ - HAS_WEIGHT, \ - NORMALIZE_BY_LENGTHS, \ - PREFETCH, \ - IS_WEIGHT_POSITIONAL, \ - USE_OFFSETS, \ - /*OUTPUT_STRIDE*/ var, \ - /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(64, false)), \ - /*SCALE_BIAS_LAST*/ fixed(false), \ - NO_BAG, \ - IS_BF16_OUT, \ - IS_BF16_IN) \ - SPECIALIZE( \ - /*BLOCK_SIZE*/ fixed(int64_t{124}), \ - HAS_WEIGHT, \ - NORMALIZE_BY_LENGTHS, \ - PREFETCH, \ - IS_WEIGHT_POSITIONAL, \ - USE_OFFSETS, \ - /*OUTPUT_STRIDE*/ var, \ - /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(124, false)), \ - /*SCALE_BIAS_LAST*/ fixed(false), \ - NO_BAG, \ - IS_BF16_OUT, \ - IS_BF16_IN) \ - SPECIALIZE( \ - /*BLOCK_SIZE*/ fixed(int64_t{128}), \ - HAS_WEIGHT, \ - NORMALIZE_BY_LENGTHS, \ - PREFETCH, \ - IS_WEIGHT_POSITIONAL, \ - USE_OFFSETS, \ - /*OUTPUT_STRIDE*/ var, \ - /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(128, false)), \ - /*SCALE_BIAS_LAST*/ fixed(false), \ - NO_BAG, \ - IS_BF16_OUT, \ - IS_BF16_IN) \ - SPECIALIZE( \ - /*BLOCK_SIZE*/ fixed(int64_t{252}), \ - HAS_WEIGHT, \ - NORMALIZE_BY_LENGTHS, \ - PREFETCH, \ - IS_WEIGHT_POSITIONAL, \ - USE_OFFSETS, \ - /*OUTPUT_STRIDE*/ var, \ - /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(252, false)), \ - /*SCALE_BIAS_LAST*/ fixed(false), \ - NO_BAG, \ - IS_BF16_OUT, \ - IS_BF16_IN) \ - SPECIALIZE( \ - /*BLOCK_SIZE*/ fixed(int64_t{256}), \ - HAS_WEIGHT, \ - NORMALIZE_BY_LENGTHS, \ - PREFETCH, \ - IS_WEIGHT_POSITIONAL, \ - USE_OFFSETS, \ - /*OUTPUT_STRIDE*/ var, \ - /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(256, false)), \ - /*SCALE_BIAS_LAST*/ fixed(false), \ - NO_BAG, \ - IS_BF16_OUT, \ - IS_BF16_IN) \ - SPECIALIZE( \ - /*BLOCK_SIZE*/ fixed(int64_t{508}), \ - HAS_WEIGHT, \ - NORMALIZE_BY_LENGTHS, \ - PREFETCH, \ - IS_WEIGHT_POSITIONAL, \ - USE_OFFSETS, \ - /*OUTPUT_STRIDE*/ var, \ - /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(508, false)), \ - /*SCALE_BIAS_LAST*/ fixed(false), \ - NO_BAG, \ - IS_BF16_OUT, \ - IS_BF16_IN) \ - SPECIALIZE( \ - /*BLOCK_SIZE*/ fixed(int64_t{512}), \ - HAS_WEIGHT, \ - NORMALIZE_BY_LENGTHS, \ - PREFETCH, \ - IS_WEIGHT_POSITIONAL, \ - USE_OFFSETS, \ - /*OUTPUT_STRIDE*/ var, \ - /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(512, false)), \ - /*SCALE_BIAS_LAST*/ fixed(false), \ - NO_BAG, \ - IS_BF16_OUT, \ +#define SPECIALIZE_BLOCK_SIZE( \ + HAS_WEIGHT, \ + NORMALIZE_BY_LENGTHS, \ + PREFETCH, \ + IS_WEIGHT_POSITIONAL, \ + USE_OFFSETS, \ + NO_BAG, \ + IS_BF16_OUT, \ + IS_BF16_IN) \ + SPECIALIZE( \ + /*BLOCK_SIZE*/ fixed(int64_t{4}), \ + HAS_WEIGHT, \ + NORMALIZE_BY_LENGTHS, \ + PREFETCH, \ + IS_WEIGHT_POSITIONAL, \ + USE_OFFSETS, \ + /*OUTPUT_STRIDE*/ var, \ + /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(4, false)), \ + /*SCALE_BIAS_LAST*/ fixed(false), \ + NO_BAG, \ + IS_BF16_OUT, \ + IS_BF16_IN) \ + SPECIALIZE( \ + /*BLOCK_SIZE*/ fixed(int64_t{24}), \ + HAS_WEIGHT, \ + NORMALIZE_BY_LENGTHS, \ + PREFETCH, \ + IS_WEIGHT_POSITIONAL, \ + USE_OFFSETS, \ + /*OUTPUT_STRIDE*/ var, \ + /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(24, false)), \ + /*SCALE_BIAS_LAST*/ fixed(false), \ + NO_BAG, \ + IS_BF16_OUT, \ + IS_BF16_IN) \ + SPECIALIZE( \ + /*BLOCK_SIZE*/ fixed(int64_t{32}), \ + HAS_WEIGHT, \ + NORMALIZE_BY_LENGTHS, \ + PREFETCH, \ + IS_WEIGHT_POSITIONAL, \ + USE_OFFSETS, \ + /*OUTPUT_STRIDE*/ var, \ + /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(32, false)), \ + /*SCALE_BIAS_LAST*/ fixed(false), \ + NO_BAG, \ + IS_BF16_OUT, \ + IS_BF16_IN) \ + SPECIALIZE( \ + /*BLOCK_SIZE*/ fixed(int64_t{64}), \ + HAS_WEIGHT, \ + NORMALIZE_BY_LENGTHS, \ + PREFETCH, \ + IS_WEIGHT_POSITIONAL, \ + USE_OFFSETS, \ + /*OUTPUT_STRIDE*/ var, \ + /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(64, false)), \ + /*SCALE_BIAS_LAST*/ fixed(false), \ + NO_BAG, \ + IS_BF16_OUT, \ + IS_BF16_IN) \ + SPECIALIZE( \ + /*BLOCK_SIZE*/ fixed(int64_t{96}), \ + HAS_WEIGHT, \ + NORMALIZE_BY_LENGTHS, \ + PREFETCH, \ + IS_WEIGHT_POSITIONAL, \ + USE_OFFSETS, \ + /*OUTPUT_STRIDE*/ var, \ + /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(96, false)), \ + /*SCALE_BIAS_LAST*/ fixed(false), \ + NO_BAG, \ + IS_BF16_OUT, \ + IS_BF16_IN) \ + SPECIALIZE( \ + /*BLOCK_SIZE*/ fixed(int64_t{124}), \ + HAS_WEIGHT, \ + NORMALIZE_BY_LENGTHS, \ + PREFETCH, \ + IS_WEIGHT_POSITIONAL, \ + USE_OFFSETS, \ + /*OUTPUT_STRIDE*/ var, \ + /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(124, false)), \ + /*SCALE_BIAS_LAST*/ fixed(false), \ + NO_BAG, \ + IS_BF16_OUT, \ + IS_BF16_IN) \ + SPECIALIZE( \ + /*BLOCK_SIZE*/ fixed(int64_t{128}), \ + HAS_WEIGHT, \ + NORMALIZE_BY_LENGTHS, \ + PREFETCH, \ + IS_WEIGHT_POSITIONAL, \ + USE_OFFSETS, \ + /*OUTPUT_STRIDE*/ var, \ + /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(128, false)), \ + /*SCALE_BIAS_LAST*/ fixed(false), \ + NO_BAG, \ + IS_BF16_OUT, \ + IS_BF16_IN) \ + SPECIALIZE( \ + /*BLOCK_SIZE*/ fixed(int64_t{252}), \ + HAS_WEIGHT, \ + NORMALIZE_BY_LENGTHS, \ + PREFETCH, \ + IS_WEIGHT_POSITIONAL, \ + USE_OFFSETS, \ + /*OUTPUT_STRIDE*/ var, \ + /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(252, false)), \ + /*SCALE_BIAS_LAST*/ fixed(false), \ + NO_BAG, \ + IS_BF16_OUT, \ + IS_BF16_IN) \ + SPECIALIZE( \ + /*BLOCK_SIZE*/ fixed(int64_t{256}), \ + HAS_WEIGHT, \ + NORMALIZE_BY_LENGTHS, \ + PREFETCH, \ + IS_WEIGHT_POSITIONAL, \ + USE_OFFSETS, \ + /*OUTPUT_STRIDE*/ var, \ + /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(256, false)), \ + /*SCALE_BIAS_LAST*/ fixed(false), \ + NO_BAG, \ + IS_BF16_OUT, \ + IS_BF16_IN) \ + SPECIALIZE( \ + /*BLOCK_SIZE*/ fixed(int64_t{320}), \ + HAS_WEIGHT, \ + NORMALIZE_BY_LENGTHS, \ + PREFETCH, \ + IS_WEIGHT_POSITIONAL, \ + USE_OFFSETS, \ + /*OUTPUT_STRIDE*/ var, \ + /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(320, false)), \ + /*SCALE_BIAS_LAST*/ fixed(false), \ + NO_BAG, \ + IS_BF16_OUT, \ + IS_BF16_IN) \ + SPECIALIZE( \ + /*BLOCK_SIZE*/ fixed(int64_t{384}), \ + HAS_WEIGHT, \ + NORMALIZE_BY_LENGTHS, \ + PREFETCH, \ + IS_WEIGHT_POSITIONAL, \ + USE_OFFSETS, \ + /*OUTPUT_STRIDE*/ var, \ + /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(384, false)), \ + /*SCALE_BIAS_LAST*/ fixed(false), \ + NO_BAG, \ + IS_BF16_OUT, \ + IS_BF16_IN) \ + SPECIALIZE( \ + /*BLOCK_SIZE*/ fixed(int64_t{508}), \ + HAS_WEIGHT, \ + NORMALIZE_BY_LENGTHS, \ + PREFETCH, \ + IS_WEIGHT_POSITIONAL, \ + USE_OFFSETS, \ + /*OUTPUT_STRIDE*/ var, \ + /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(508, false)), \ + /*SCALE_BIAS_LAST*/ fixed(false), \ + NO_BAG, \ + IS_BF16_OUT, \ + IS_BF16_IN) \ + SPECIALIZE( \ + /*BLOCK_SIZE*/ fixed(int64_t{512}), \ + HAS_WEIGHT, \ + NORMALIZE_BY_LENGTHS, \ + PREFETCH, \ + IS_WEIGHT_POSITIONAL, \ + USE_OFFSETS, \ + /*OUTPUT_STRIDE*/ var, \ + /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(512, false)), \ + /*SCALE_BIAS_LAST*/ fixed(false), \ + NO_BAG, \ + IS_BF16_OUT, \ + IS_BF16_IN) \ + SPECIALIZE( \ + /*BLOCK_SIZE*/ fixed(int64_t{768}), \ + HAS_WEIGHT, \ + NORMALIZE_BY_LENGTHS, \ + PREFETCH, \ + IS_WEIGHT_POSITIONAL, \ + USE_OFFSETS, \ + /*OUTPUT_STRIDE*/ var, \ + /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(768, false)), \ + /*SCALE_BIAS_LAST*/ fixed(false), \ + NO_BAG, \ + IS_BF16_OUT, \ + IS_BF16_IN) \ + SPECIALIZE( \ + /*BLOCK_SIZE*/ fixed(int64_t{1024}), \ + HAS_WEIGHT, \ + NORMALIZE_BY_LENGTHS, \ + PREFETCH, \ + IS_WEIGHT_POSITIONAL, \ + USE_OFFSETS, \ + /*OUTPUT_STRIDE*/ var, \ + /*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides(1024, false)), \ + /*SCALE_BIAS_LAST*/ fixed(false), \ + NO_BAG, \ + IS_BF16_OUT, \ IS_BF16_IN) #ifdef FBGEMM_MORE_SPECIALIZATION @@ -1473,6 +1564,19 @@ GenerateEmbeddingSpMDMNBitWithStrides_autovec( IS_BF16_OUT, \ NO_BAG, \ OUTPUT_BIT_RATE) \ + SPECIALIZE( \ + INPUT_BIT_RATE, \ + /*BLOCK_SIZE*/ fixed(int64_t{96}), \ + HAS_WEIGHT, \ + NORMALIZE_BY_LENGTHS, \ + IS_WEIGHT_POSITIONAL, \ + USE_OFFSETS, \ + /*OUTPUT_STRIDE*/ var, \ + /*INPUT_STRIDE*/ fixed(stride_SpMDMNBitWith(INPUT_BIT_RATE.value, 96)), \ + SCALE_BIAS_LAST, \ + IS_BF16_OUT, \ + NO_BAG, \ + OUTPUT_BIT_RATE) \ SPECIALIZE( \ INPUT_BIT_RATE, \ /*BLOCK_SIZE*/ fixed(int64_t{120}), \ @@ -1524,6 +1628,84 @@ GenerateEmbeddingSpMDMNBitWithStrides_autovec( SCALE_BIAS_LAST, \ IS_BF16_OUT, \ NO_BAG, \ + OUTPUT_BIT_RATE) \ + SPECIALIZE( \ + INPUT_BIT_RATE, \ + /*BLOCK_SIZE*/ fixed(int64_t{320}), \ + HAS_WEIGHT, \ + NORMALIZE_BY_LENGTHS, \ + IS_WEIGHT_POSITIONAL, \ + USE_OFFSETS, \ + /*OUTPUT_STRIDE*/ var, \ + /*INPUT_STRIDE*/ fixed(stride_SpMDMNBitWith(INPUT_BIT_RATE.value, 320)), \ + SCALE_BIAS_LAST, \ + IS_BF16_OUT, \ + NO_BAG, \ + OUTPUT_BIT_RATE) \ + SPECIALIZE( \ + INPUT_BIT_RATE, \ + /*BLOCK_SIZE*/ fixed(int64_t{384}), \ + HAS_WEIGHT, \ + NORMALIZE_BY_LENGTHS, \ + IS_WEIGHT_POSITIONAL, \ + USE_OFFSETS, \ + /*OUTPUT_STRIDE*/ var, \ + /*INPUT_STRIDE*/ fixed(stride_SpMDMNBitWith(INPUT_BIT_RATE.value, 384)), \ + SCALE_BIAS_LAST, \ + IS_BF16_OUT, \ + NO_BAG, \ + OUTPUT_BIT_RATE) \ + SPECIALIZE( \ + INPUT_BIT_RATE, \ + /*BLOCK_SIZE*/ fixed(int64_t{512}), \ + HAS_WEIGHT, \ + NORMALIZE_BY_LENGTHS, \ + IS_WEIGHT_POSITIONAL, \ + USE_OFFSETS, \ + /*OUTPUT_STRIDE*/ var, \ + /*INPUT_STRIDE*/ fixed(stride_SpMDMNBitWith(INPUT_BIT_RATE.value, 512)), \ + SCALE_BIAS_LAST, \ + IS_BF16_OUT, \ + NO_BAG, \ + OUTPUT_BIT_RATE) \ + SPECIALIZE( \ + INPUT_BIT_RATE, \ + /*BLOCK_SIZE*/ fixed(int64_t{576}), \ + HAS_WEIGHT, \ + NORMALIZE_BY_LENGTHS, \ + IS_WEIGHT_POSITIONAL, \ + USE_OFFSETS, \ + /*OUTPUT_STRIDE*/ var, \ + /*INPUT_STRIDE*/ fixed(stride_SpMDMNBitWith(INPUT_BIT_RATE.value, 576)), \ + SCALE_BIAS_LAST, \ + IS_BF16_OUT, \ + NO_BAG, \ + OUTPUT_BIT_RATE) \ + SPECIALIZE( \ + INPUT_BIT_RATE, \ + /*BLOCK_SIZE*/ fixed(int64_t{768}), \ + HAS_WEIGHT, \ + NORMALIZE_BY_LENGTHS, \ + IS_WEIGHT_POSITIONAL, \ + USE_OFFSETS, \ + /*OUTPUT_STRIDE*/ var, \ + /*INPUT_STRIDE*/ fixed(stride_SpMDMNBitWith(INPUT_BIT_RATE.value, 768)), \ + SCALE_BIAS_LAST, \ + IS_BF16_OUT, \ + NO_BAG, \ + OUTPUT_BIT_RATE) \ + SPECIALIZE( \ + INPUT_BIT_RATE, \ + /*BLOCK_SIZE*/ fixed(int64_t{1024}), \ + HAS_WEIGHT, \ + NORMALIZE_BY_LENGTHS, \ + IS_WEIGHT_POSITIONAL, \ + USE_OFFSETS, \ + /*OUTPUT_STRIDE*/ var, /*INPUT_STRIDE*/ \ + fixed(stride_SpMDMNBitWith(INPUT_BIT_RATE.value, 1024)), \ + SCALE_BIAS_LAST, \ + IS_BF16_OUT, \ + NO_BAG, \ OUTPUT_BIT_RATE) #define SPECIALIZE_INPUT_RATE( \ @@ -1562,6 +1744,14 @@ GenerateEmbeddingSpMDMNBitWithStrides_autovec( /*SCALE_BIAS_LAST*/ fixed(false), /*IS_BF16_OUT*/ var, /*NO_BAG*/ fixed(false)) + SPECIALIZE_INPUT_RATE( + /*HAS_WEIGHT*/ fixed(false), + /*NORMALIZE_BY_LENGTHS*/ fixed(false), + /*IS_WEIGHT_POSITIONAL*/ fixed(false), + /*USE_OFFSETS*/ fixed(true), + /*SCALE_BIAS_LAST*/ fixed(true), + /*IS_BF16_OUT*/ var, + /*NO_BAG*/ fixed(false)) WARN_ONCE( "fbgemm warning: " "using non-specialized EmbeddingSpMDMNBit_autovec (may be slow)\n" diff --git a/src/EmbeddingSpMDMNBit.cc b/src/EmbeddingSpMDMNBit.cc index 04f8314e88..7bf60f5ad7 100644 --- a/src/EmbeddingSpMDMNBit.cc +++ b/src/EmbeddingSpMDMNBit.cc @@ -1036,7 +1036,7 @@ typename EmbeddingSpMDMKernelSignature:: const bool no_bag /*=false*/, int output_bit_rate /*=-1*/) { if (output_bit_rate == -1) { - output_bit_rate = input_bit_rate; + output_bit_rate = sizeof(outType) * 8; } assert( (input_bit_rate == 2 || input_bit_rate == 4) &&