Skip to content

Commit 1eb1cbf

Browse files
committed
Enable quantization fallback to non-quantized mode
1 parent 74d23d9 commit 1eb1cbf

File tree

3 files changed

+34
-11
lines changed

3 files changed

+34
-11
lines changed

src/VecSim/algorithms/svs/svs_utils.h

+15-3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <cstdint>
1111
#include <cstdlib>
1212
#include <string>
13+
#include <utility>
1314

1415
namespace svs_details {
1516
// VecSim->SVS data type conversion
@@ -149,12 +150,23 @@ inline bool check_cpuid() {
149150
}
150151
// clang-format on
151152

152-
inline bool isSVSLVQModeSupported(VecSimSvsQuantBits quant_bits) {
153-
return quant_bits == VecSimSvsQuant_NONE
153+
// Check if the SVS implementation supports Qquantization mode
154+
// @param quant_bits requested SVS quantization mode
155+
// @return pair<fallbackMode, bool>
156+
inline std::pair<VecSimSvsQuantBits, bool> isSVSQuantBitsSupported(VecSimSvsQuantBits quant_bits) {
157+
// If HAVE_SVS_LVQ is not defined, we don't support any quantization mode
158+
// else we check if the CPU supports SVS LVQ
159+
bool supported = quant_bits == VecSimSvsQuant_NONE
154160
#if HAVE_SVS_LVQ
155-
|| check_cpuid() // Check if the CPU supports SVS LVQ
161+
|| check_cpuid() // Check if the CPU supports SVS LVQ
156162
#endif
157163
;
164+
165+
// If the quantization mode is not supported, we fallback to non-quantized mode
166+
auto fallBack = supported ? quant_bits : VecSimSvsQuant_NONE;
167+
168+
// And always return true, as far as non-quantized mode is always supported
169+
return std::make_pair(fallBack, true);
158170
}
159171
} // namespace svs_details
160172

src/VecSim/index_factories/svs_factory.cpp

+10-5
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,12 @@ VecSimIndex *NewIndexImpl(const VecSimParams *params, bool is_normalized) {
3737

3838
template <typename MetricType, typename DataType>
3939
VecSimIndex *NewIndexImpl(const VecSimParams *params, bool is_normalized) {
40-
if (!svs_details::isSVSLVQModeSupported(params->algoParams.svsParams.quantBits)) {
41-
return NULL;
42-
}
40+
// Ignore the 'supported' flag because we always fallback at least to the non-quantized mode
41+
// elsewhere we got code coverage failure for the `supported==false` case
42+
auto quantBits =
43+
std::get<0>(svs_details::isSVSQuantBitsSupported(params->algoParams.svsParams.quantBits));
4344

44-
switch (params->algoParams.svsParams.quantBits) {
45+
switch (quantBits) {
4546
case VecSimSvsQuant_NONE:
4647
return NewIndexImpl<MetricType, DataType, 0>(params, is_normalized);
4748
case VecSimSvsQuant_8:
@@ -97,7 +98,11 @@ constexpr size_t QuantizedVectorSize(size_t dims, size_t alignment = 0) {
9798

9899
template <typename DataType>
99100
size_t QuantizedVectorSize(VecSimSvsQuantBits quant_bits, size_t dims, size_t alignment = 0) {
100-
switch (quant_bits) {
101+
// Ignore the 'supported' flag because we always fallback at least to the non-quantized mode
102+
// elsewhere we got code coverage failure for the `supported==false` case
103+
auto quantBits = std::get<0>(svs_details::isSVSQuantBitsSupported(quant_bits));
104+
105+
switch (quantBits) {
101106
case VecSimSvsQuant_NONE:
102107
return QuantizedVectorSize<DataType, 0>(dims, alignment);
103108
case VecSimSvsQuant_8:

tests/unit/test_svs.cpp

+9-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
#define ASSERT_INDEX(index) \
1010
if (index == nullptr) { \
11-
if (svs_details::isSVSLVQModeSupported(TypeParam::get_quant_bits())) { \
11+
if (std::get<1>(svs_details::isSVSQuantBitsSupported(TypeParam::get_quant_bits()))) { \
1212
GTEST_FAIL() << "Failed to create SVS index"; \
1313
} else { \
1414
GTEST_SKIP() << "SVS LVQ is not supported."; \
@@ -698,7 +698,10 @@ TYPED_TEST(SVSTest, resizeIndex) {
698698

699699
// Initial capacity is rounded up to the block size.
700700
size_t extra_cap = n % bs == 0 ? 0 : bs - n % bs;
701-
if constexpr (TypeParam::get_quant_bits() > 0) {
701+
auto quantBits = TypeParam::get_quant_bits();
702+
// Get the fallback quantization mode
703+
quantBits = std::get<0>(svs_details::isSVSQuantBitsSupported(quantBits));
704+
if (quantBits != VecSimSvsQuant_NONE) {
702705
// LVQDataset does not provide a capacity method
703706
extra_cap = 0;
704707
}
@@ -1460,7 +1463,10 @@ TYPED_TEST(SVSTest, testSizeEstimation) {
14601463
// converted then to a number of elements.
14611464
// IMHO, would be better to always interpret block size to a number of elements
14621465
// rather than conversion to-from number of bytes
1463-
if (TypeParam::get_quant_bits() > 0) {
1466+
auto quantBits = TypeParam::get_quant_bits();
1467+
// Get the fallback quantization mode
1468+
quantBits = std::get<0>(svs_details::isSVSQuantBitsSupported(quantBits));
1469+
if (quantBits != VecSimSvsQuant_NONE) {
14641470
// Extra data in LVQ vector
14651471
const auto lvq_vector_extra = sizeof(svs::quantization::lvq::ScalarBundle);
14661472
dim -= (lvq_vector_extra * 8) / TypeParam::get_quant_bits();

0 commit comments

Comments
 (0)