diff --git a/src/quant/int16.hpp b/src/quant/int16.hpp index a80cf8342f..16ecb72711 100644 --- a/src/quant/int16.hpp +++ b/src/quant/int16.hpp @@ -161,6 +161,62 @@ namespace ndd { } #endif +#if defined(USE_AVX2) + inline std::vector + quantize_vector_fp32_to_int16_buffer_avx2(const std::vector& input) { + if(input.empty()) { + return std::vector(); + } + + size_t dimension = input.size(); + size_t buffer_size = get_storage_size(dimension); + std::vector buffer(buffer_size); + + float abs_max = math::find_abs_max(input.data(), dimension); + if(abs_max == 0.0f) { + abs_max = 1.0f; + } + float scale = abs_max / INT16_SCALE; + float inv_scale = 1.0f / scale; + + int16_t* data_ptr = reinterpret_cast(buffer.data()); + const __m256 scale_vec = _mm256_set1_ps(inv_scale); + + size_t i = 0; + size_t vec_size = (dimension / 16) * 16; + + for(; i < vec_size; i += 16) { + __m256 vec0 = _mm256_loadu_ps(&input[i]); + __m256 vec1 = _mm256_loadu_ps(&input[i + 8]); + + vec0 = _mm256_mul_ps(vec0, scale_vec); + vec1 = _mm256_mul_ps(vec1, scale_vec); + + __m256i int32_0 = _mm256_cvtps_epi32(vec0); + __m256i int32_1 = _mm256_cvtps_epi32(vec1); + + __m128i packed0 = _mm_packs_epi32(_mm256_castsi256_si128(int32_0), + _mm256_extracti128_si256(int32_0, 1)); + __m128i packed1 = _mm_packs_epi32(_mm256_castsi256_si128(int32_1), + _mm256_extracti128_si256(int32_1, 1)); + + _mm_storeu_si128((__m128i*)&data_ptr[i], packed0); + _mm_storeu_si128((__m128i*)&data_ptr[i + 8], packed1); + } + + for(; i < dimension; ++i) { + float scaled = input[i] * inv_scale; + data_ptr[i] = static_cast(std::round(scaled)); + } + + float* scale_ptr = + reinterpret_cast(buffer.data() + (dimension * sizeof(int16_t))); + *scale_ptr = scale; + + return buffer; + } +#endif + #if defined(USE_NEON) // NEON optimized quantization FP32 -> INT16 buffer inline std::vector @@ -317,6 +373,8 @@ namespace ndd { quantize_vector_fp32_to_int16_buffer_auto(const std::vector& input) { #if defined(USE_AVX512) return quantize_vector_fp32_to_int16_buffer_avx512(input); +#elif defined(USE_AVX2) + return quantize_vector_fp32_to_int16_buffer_avx2(input); #elif defined(USE_SVE2) return quantize_vector_fp32_to_int16_buffer_sve(input); #elif defined(USE_NEON) @@ -434,6 +492,44 @@ namespace ndd { } #endif +#if defined(USE_AVX2) + inline std::vector dequantize_int16_buffer_to_fp32_avx2(const uint8_t* buffer, + size_t dimension) { + std::vector output(dimension); + const int16_t* data_ptr = reinterpret_cast(buffer); + float scale = extract_scale(buffer, dimension); + + const __m256 scale_vec = _mm256_set1_ps(scale); + + size_t i = 0; + size_t vec_size = (dimension / 16) * 16; + + for(; i < vec_size; i += 16) { + __m256i int16_vec = _mm256_loadu_si256((const __m256i*)&data_ptr[i]); + + __m256i int32_lo = + _mm256_cvtepi16_epi32(_mm256_castsi256_si128(int16_vec)); + __m256i int32_hi = + _mm256_cvtepi16_epi32(_mm256_extracti128_si256(int16_vec, 1)); + + __m256 float_lo = _mm256_cvtepi32_ps(int32_lo); + __m256 float_hi = _mm256_cvtepi32_ps(int32_hi); + + float_lo = _mm256_mul_ps(float_lo, scale_vec); + float_hi = _mm256_mul_ps(float_hi, scale_vec); + + _mm256_storeu_ps(&output[i], float_lo); + _mm256_storeu_ps(&output[i + 8], float_hi); + } + + for(; i < dimension; ++i) { + output[i] = static_cast(data_ptr[i]) * scale; + } + + return output; + } +#endif + #if defined(USE_SVE2) inline std::vector dequantize_int16_buffer_to_fp32_sve(const uint8_t* buffer, size_t dimension) { @@ -470,6 +566,8 @@ namespace ndd { size_t dimension) { #if defined(USE_AVX512) return dequantize_int16_buffer_to_fp32_avx512(buffer, dimension); +#elif defined(USE_AVX2) + return dequantize_int16_buffer_to_fp32_avx2(buffer, dimension); #elif defined(USE_SVE2) return dequantize_int16_buffer_to_fp32_sve(buffer, dimension); #elif defined(USE_NEON) @@ -1131,22 +1229,21 @@ namespace ndd { __m256i dot_vec_hi = _mm256_setzero_si256(); __m256i sq_vec_lo = _mm256_setzero_si256(); __m256i sq_vec_hi = _mm256_setzero_si256(); - for(; d + 8 <= block_len; d += 8) { - __m128i q_i16 = _mm_loadu_si128( - reinterpret_cast(query_vec + block_start + d)); - __m128i v_i16 = _mm_loadu_si128( - reinterpret_cast(vec + block_start + d)); - - __m256i q_i32 = _mm256_cvtepi16_epi32(q_i16); - __m256i v_i32 = _mm256_cvtepi16_epi32(v_i16); - __m256i dot_i32 = _mm256_mullo_epi32(q_i32, v_i32); - __m256i dot_i64_lo = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(dot_i32)); - __m256i dot_i64_hi = - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(dot_i32, 1)); + for(; d + 16 <= block_len; d += 16) { + __m256i q_i16 = _mm256_loadu_si256( + reinterpret_cast(query_vec + block_start + d)); + __m256i v_i16 = _mm256_loadu_si256( + reinterpret_cast(vec + block_start + d)); + + __m256i dot_i32 = _mm256_madd_epi16(q_i16, v_i16); + __m256i dot_i64_lo = + _mm256_cvtepi32_epi64(_mm256_castsi256_si128(dot_i32)); + __m256i dot_i64_hi = + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(dot_i32, 1)); dot_vec_lo = _mm256_add_epi64(dot_vec_lo, dot_i64_lo); dot_vec_hi = _mm256_add_epi64(dot_vec_hi, dot_i64_hi); if(l2_metric) { - __m256i sq_i32 = _mm256_mullo_epi32(v_i32, v_i32); + __m256i sq_i32 = _mm256_madd_epi16(v_i16, v_i16); __m256i sq_i64_lo = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sq_i32)); __m256i sq_i64_hi =