Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 110 additions & 13 deletions src/quant/int16.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,62 @@ namespace ndd {
}
#endif

#if defined(USE_AVX2)
    // AVX2 optimized quantization FP32 -> INT16 buffer.
    //
    // Buffer layout (matching the other SIMD/scalar variants, see
    // get_storage_size): dimension int16 values followed by the float
    // dequantization scale.
    //
    // @param input  vector to quantize; may be empty.
    // @return       packed int16 buffer + trailing scale, or an empty
    //               vector when input is empty.
    inline std::vector<uint8_t>
    quantize_vector_fp32_to_int16_buffer_avx2(const std::vector<float>& input) {
        if(input.empty()) {
            return std::vector<uint8_t>();
        }

        size_t dimension = input.size();
        size_t buffer_size = get_storage_size(dimension);
        std::vector<uint8_t> buffer(buffer_size);

        // Symmetric quantization: map [-abs_max, abs_max] onto the int16 range.
        float abs_max = math::find_abs_max(input.data(), dimension);
        if(abs_max == 0.0f) {
            abs_max = 1.0f;  // all-zero input: any scale works; avoid div-by-zero
        }
        float scale = abs_max / INT16_SCALE;
        float inv_scale = 1.0f / scale;

        int16_t* data_ptr = reinterpret_cast<int16_t*>(buffer.data());
        const __m256 scale_vec = _mm256_set1_ps(inv_scale);

        size_t i = 0;
        size_t vec_size = (dimension / 16) * 16;

        // Main loop: 16 floats per iteration — two 8-lane scale+convert
        // passes, then a saturating pack down to 16 int16 values.
        for(; i < vec_size; i += 16) {
            __m256 vec0 = _mm256_loadu_ps(&input[i]);
            __m256 vec1 = _mm256_loadu_ps(&input[i + 8]);

            vec0 = _mm256_mul_ps(vec0, scale_vec);
            vec1 = _mm256_mul_ps(vec1, scale_vec);

            __m256i int32_0 = _mm256_cvtps_epi32(vec0);
            __m256i int32_1 = _mm256_cvtps_epi32(vec1);

            // _mm_packs_epi32 saturates to [-32768, 32767].
            __m128i packed0 = _mm_packs_epi32(_mm256_castsi256_si128(int32_0),
                                              _mm256_extracti128_si256(int32_0, 1));
            __m128i packed1 = _mm_packs_epi32(_mm256_castsi256_si128(int32_1),
                                              _mm256_extracti128_si256(int32_1, 1));

            _mm_storeu_si128(reinterpret_cast<__m128i*>(&data_ptr[i]), packed0);
            _mm_storeu_si128(reinterpret_cast<__m128i*>(&data_ptr[i + 8]), packed1);
        }

        // Scalar tail. Clamp before the cast: the vector path saturates via
        // _mm_packs_epi32, and a float->int16_t conversion of an out-of-range
        // value (possible when inv_scale rounding nudges abs_max * inv_scale
        // past 32767) is undefined behavior.
        // NOTE(review): the vector path rounds to nearest-even via
        // _mm256_cvtps_epi32 while std::round rounds halfway cases away from
        // zero; tail values may differ from SIMD lanes by one quantization step
        // on exact .5 boundaries — preserved from the original implementation.
        for(; i < dimension; ++i) {
            float scaled = std::round(input[i] * inv_scale);
            scaled = std::min(32767.0f, std::max(-32768.0f, scaled));
            data_ptr[i] = static_cast<int16_t>(scaled);
        }

        // Append the dequantization scale after the int16 payload.
        float* scale_ptr =
            reinterpret_cast<float*>(buffer.data() + (dimension * sizeof(int16_t)));
        *scale_ptr = scale;

        return buffer;
    }
#endif

#if defined(USE_NEON)
// NEON optimized quantization FP32 -> INT16 buffer
inline std::vector<uint8_t>
Expand Down Expand Up @@ -317,6 +373,8 @@ namespace ndd {
quantize_vector_fp32_to_int16_buffer_auto(const std::vector<float>& input) {
#if defined(USE_AVX512)
return quantize_vector_fp32_to_int16_buffer_avx512(input);
#elif defined(USE_AVX2)
return quantize_vector_fp32_to_int16_buffer_avx2(input);
#elif defined(USE_SVE2)
return quantize_vector_fp32_to_int16_buffer_sve(input);
#elif defined(USE_NEON)
Expand Down Expand Up @@ -434,6 +492,44 @@ namespace ndd {
}
#endif

#if defined(USE_AVX2)
    // AVX2 optimized dequantization INT16 buffer -> FP32.
    //
    // Inverse of quantize_vector_fp32_to_int16_buffer_avx2: reads dimension
    // int16 values from the front of `buffer`, multiplies each by the scale
    // stored after the payload (retrieved via extract_scale), and returns the
    // reconstructed floats.
    //
    // @param buffer     quantized buffer (int16 payload + trailing float scale).
    // @param dimension  number of elements to dequantize.
    // @return           vector of `dimension` reconstructed floats.
    inline std::vector<float> dequantize_int16_buffer_to_fp32_avx2(const uint8_t* buffer,
                                                                   size_t dimension) {
        std::vector<float> output(dimension);
        const int16_t* data_ptr = reinterpret_cast<const int16_t*>(buffer);
        float scale = extract_scale(buffer, dimension);

        const __m256 scale_vec = _mm256_set1_ps(scale);

        size_t i = 0;
        size_t vec_size = (dimension / 16) * 16;

        // Main loop: 16 int16 per iteration — widen each 128-bit half to
        // int32, convert to float, and apply the scale.
        for(; i < vec_size; i += 16) {
            __m256i int16_vec =
                _mm256_loadu_si256(reinterpret_cast<const __m256i*>(&data_ptr[i]));

            __m256i int32_lo =
                _mm256_cvtepi16_epi32(_mm256_castsi256_si128(int16_vec));
            __m256i int32_hi =
                _mm256_cvtepi16_epi32(_mm256_extracti128_si256(int16_vec, 1));

            __m256 float_lo = _mm256_cvtepi32_ps(int32_lo);
            __m256 float_hi = _mm256_cvtepi32_ps(int32_hi);

            float_lo = _mm256_mul_ps(float_lo, scale_vec);
            float_hi = _mm256_mul_ps(float_hi, scale_vec);

            _mm256_storeu_ps(&output[i], float_lo);
            _mm256_storeu_ps(&output[i + 8], float_hi);
        }

        // Scalar tail for the remaining < 16 elements.
        for(; i < dimension; ++i) {
            output[i] = static_cast<float>(data_ptr[i]) * scale;
        }

        return output;
    }
#endif

#if defined(USE_SVE2)
inline std::vector<float> dequantize_int16_buffer_to_fp32_sve(const uint8_t* buffer,
size_t dimension) {
Expand Down Expand Up @@ -470,6 +566,8 @@ namespace ndd {
size_t dimension) {
#if defined(USE_AVX512)
return dequantize_int16_buffer_to_fp32_avx512(buffer, dimension);
#elif defined(USE_AVX2)
return dequantize_int16_buffer_to_fp32_avx2(buffer, dimension);
#elif defined(USE_SVE2)
return dequantize_int16_buffer_to_fp32_sve(buffer, dimension);
#elif defined(USE_NEON)
Expand Down Expand Up @@ -1131,22 +1229,21 @@ namespace ndd {
__m256i dot_vec_hi = _mm256_setzero_si256();
__m256i sq_vec_lo = _mm256_setzero_si256();
__m256i sq_vec_hi = _mm256_setzero_si256();
for(; d + 8 <= block_len; d += 8) {
__m128i q_i16 = _mm_loadu_si128(
reinterpret_cast<const __m128i*>(query_vec + block_start + d));
__m128i v_i16 = _mm_loadu_si128(
reinterpret_cast<const __m128i*>(vec + block_start + d));

__m256i q_i32 = _mm256_cvtepi16_epi32(q_i16);
__m256i v_i32 = _mm256_cvtepi16_epi32(v_i16);
__m256i dot_i32 = _mm256_mullo_epi32(q_i32, v_i32);
__m256i dot_i64_lo = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(dot_i32));
__m256i dot_i64_hi =
_mm256_cvtepi32_epi64(_mm256_extracti128_si256(dot_i32, 1));
for(; d + 16 <= block_len; d += 16) {
__m256i q_i16 = _mm256_loadu_si256(
reinterpret_cast<const __m256i*>(query_vec + block_start + d));
__m256i v_i16 = _mm256_loadu_si256(
reinterpret_cast<const __m256i*>(vec + block_start + d));

__m256i dot_i32 = _mm256_madd_epi16(q_i16, v_i16);
__m256i dot_i64_lo =
_mm256_cvtepi32_epi64(_mm256_castsi256_si128(dot_i32));
__m256i dot_i64_hi =
_mm256_cvtepi32_epi64(_mm256_extracti128_si256(dot_i32, 1));
dot_vec_lo = _mm256_add_epi64(dot_vec_lo, dot_i64_lo);
dot_vec_hi = _mm256_add_epi64(dot_vec_hi, dot_i64_hi);
if(l2_metric) {
__m256i sq_i32 = _mm256_mullo_epi32(v_i32, v_i32);
__m256i sq_i32 = _mm256_madd_epi16(v_i16, v_i16);
__m256i sq_i64_lo =
_mm256_cvtepi32_epi64(_mm256_castsi256_si128(sq_i32));
__m256i sq_i64_hi =
Expand Down
Loading