Skip to content

Commit 0ed65a3

Browse files
committed
[CPP Runtime] Add get_distance and reconstruct_at to VamanaIndex API
Add get_distance() and reconstruct_at() methods to the runtime library for OpenSearch integration. These expose existing orchestrator-layer functionality through the shared library ABI. - get_distance: computes distance between a stored vector and a query - reconstruct_at: decompresses/reconstructs vectors to float32 by ID - Works with all storage kinds (FP32, FP16, SQI8, LVQ, LeanVec) - Added to both VamanaIndex (static) and DynamicVamanaIndex - Includes tests for both index types
1 parent c563257 commit 0ed65a3

File tree

6 files changed

+201
-0
lines changed

6 files changed

+201
-0
lines changed

bindings/cpp/include/svs/runtime/vamana_index.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,13 @@ struct SVS_RUNTIME_API VamanaIndex {
7676
IDFilter* filter = nullptr
7777
) const noexcept = 0;
7878

79+
// Compute distance between stored vector `id` and `query` (dim floats).
80+
virtual Status get_distance(double* distance, size_t id, const float* query)
81+
const noexcept = 0;
82+
83+
// Reconstruct `n` vectors by ID into `output` buffer (n * dim floats).
84+
virtual Status reconstruct_at(size_t n, const size_t* ids, float* output) noexcept = 0;
85+
7986
// Utility function to check storage kind support
8087
static Status check_storage_kind(StorageKind storage_kind) noexcept;
8188

bindings/cpp/src/dynamic_vamana_index.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,17 @@ struct DynamicVamanaIndexManagerBase : public DynamicVamanaIndex {
118118
Status save(std::ostream& out) const noexcept override {
119119
return runtime_error_wrapper([&] { impl_->save(out); });
120120
}
121+
122+
Status get_distance(double* distance, size_t id, const float* query)
123+
const noexcept override {
124+
return runtime_error_wrapper([&] {
125+
*distance = impl_->get_distance(id, query);
126+
});
127+
}
128+
129+
Status reconstruct_at(size_t n, const size_t* ids, float* output) noexcept override {
130+
return runtime_error_wrapper([&] { impl_->reconstruct_at(n, ids, output); });
131+
}
121132
};
122133
} // namespace
123134

bindings/cpp/src/dynamic_vamana_index_impl.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,23 @@ class DynamicVamanaIndexImpl {
305305
return remove(ids_to_delete);
306306
}
307307

308+
double get_distance(size_t id, const float* query) const {
309+
if (!impl_) {
310+
throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"};
311+
}
312+
auto query_span = std::span<const float>(query, dim_);
313+
return impl_->get_distance(id, query_span);
314+
}
315+
316+
void reconstruct_at(size_t n, const size_t* ids, float* output) {
317+
if (!impl_) {
318+
throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"};
319+
}
320+
svs::data::SimpleDataView<float> dst{output, n, dim_};
321+
std::span<const uint64_t> id_span{reinterpret_cast<const uint64_t*>(ids), n};
322+
impl_->reconstruct_at(dst, id_span);
323+
}
324+
308325
void reset() {
309326
impl_.reset();
310327
ntotal_soft_deleted = 0;

bindings/cpp/src/vamana_index.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,17 @@ struct VamanaIndexManagerBase : public VamanaIndex {
8888
Status save(std::ostream& out) const noexcept override {
8989
return runtime_error_wrapper([&] { impl_->save(out); });
9090
}
91+
92+
Status get_distance(double* distance, size_t id, const float* query)
93+
const noexcept override {
94+
return runtime_error_wrapper([&] {
95+
*distance = impl_->get_distance(id, query);
96+
});
97+
}
98+
99+
Status reconstruct_at(size_t n, const size_t* ids, float* output) noexcept override {
100+
return runtime_error_wrapper([&] { impl_->reconstruct_at(n, ids, output); });
101+
}
91102
};
92103
} // namespace
93104

bindings/cpp/src/vamana_index_impl.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,23 @@ class VamanaIndexImpl {
269269
}
270270
}
271271

272+
double get_distance(size_t id, const float* query) const {
273+
if (!impl_) {
274+
throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"};
275+
}
276+
auto query_span = std::span<const float>(query, dim_);
277+
return get_impl()->get_distance(id, query_span);
278+
}
279+
280+
void reconstruct_at(size_t n, const size_t* ids, float* output) {
281+
if (!impl_) {
282+
throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"};
283+
}
284+
svs::data::SimpleDataView<float> dst{output, n, dim_};
285+
std::span<const uint64_t> id_span{reinterpret_cast<const uint64_t*>(ids), n};
286+
get_impl()->reconstruct_at(dst, id_span);
287+
}
288+
272289
void reset() { impl_.reset(); }
273290

274291
void save(std::ostream& out) const {

bindings/cpp/tests/runtime_test.cpp

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -881,3 +881,141 @@ CATCH_TEST_CASE("RangeSearchFunctionalStatic", "[runtime][static_vamana]") {
881881

882882
svs::runtime::v0::VamanaIndex::destroy(index);
883883
}
884+
885+
CATCH_TEST_CASE("GetDistanceDynamic", "[runtime]") {
886+
const auto& test_data = get_test_data();
887+
svs::runtime::v0::DynamicVamanaIndex* index = nullptr;
888+
svs::runtime::v0::VamanaIndex::BuildParams build_params{64};
889+
auto status = svs::runtime::v0::DynamicVamanaIndex::build(
890+
&index,
891+
test_d,
892+
svs::runtime::v0::MetricType::L2,
893+
svs::runtime::v0::StorageKind::FP32,
894+
build_params
895+
);
896+
CATCH_REQUIRE(status.ok());
897+
898+
std::vector<size_t> labels(test_n);
899+
std::iota(labels.begin(), labels.end(), 0);
900+
status = index->add(test_n, labels.data(), test_data.data());
901+
CATCH_REQUIRE(status.ok());
902+
903+
// Self-distance should be approximately 0
904+
double dist = -1.0;
905+
const float* vec0 = test_data.data();
906+
status = index->get_distance(&dist, 0, vec0);
907+
CATCH_REQUIRE(status.ok());
908+
CATCH_REQUIRE(dist < 1e-6);
909+
910+
// Distance to a different vector should be positive
911+
const float* vec1 = test_data.data() + test_d;
912+
status = index->get_distance(&dist, 0, vec1);
913+
CATCH_REQUIRE(status.ok());
914+
CATCH_REQUIRE(dist > 0.0);
915+
916+
svs::runtime::v0::DynamicVamanaIndex::destroy(index);
917+
}
918+
919+
CATCH_TEST_CASE("GetDistanceStatic", "[runtime][static_vamana]") {
920+
const auto& test_data = get_test_data();
921+
svs::runtime::v0::VamanaIndex* index = nullptr;
922+
svs::runtime::v0::VamanaIndex::BuildParams build_params{64};
923+
auto status = svs::runtime::v0::VamanaIndex::build(
924+
&index,
925+
test_d,
926+
svs::runtime::v0::MetricType::L2,
927+
svs::runtime::v0::StorageKind::FP32,
928+
build_params
929+
);
930+
CATCH_REQUIRE(status.ok());
931+
932+
status = index->add(test_n, test_data.data());
933+
CATCH_REQUIRE(status.ok());
934+
935+
// Self-distance should be approximately 0
936+
double dist = -1.0;
937+
const float* vec0 = test_data.data();
938+
status = index->get_distance(&dist, 0, vec0);
939+
CATCH_REQUIRE(status.ok());
940+
CATCH_REQUIRE(dist < 1e-6);
941+
942+
// Distance to a different vector should be positive
943+
const float* vec1 = test_data.data() + test_d;
944+
status = index->get_distance(&dist, 0, vec1);
945+
CATCH_REQUIRE(status.ok());
946+
CATCH_REQUIRE(dist > 0.0);
947+
948+
svs::runtime::v0::VamanaIndex::destroy(index);
949+
}
950+
951+
CATCH_TEST_CASE("ReconstructAtDynamic", "[runtime]") {
952+
const auto& test_data = get_test_data();
953+
svs::runtime::v0::DynamicVamanaIndex* index = nullptr;
954+
svs::runtime::v0::VamanaIndex::BuildParams build_params{64};
955+
auto status = svs::runtime::v0::DynamicVamanaIndex::build(
956+
&index,
957+
test_d,
958+
svs::runtime::v0::MetricType::L2,
959+
svs::runtime::v0::StorageKind::FP32,
960+
build_params
961+
);
962+
CATCH_REQUIRE(status.ok());
963+
964+
std::vector<size_t> labels(test_n);
965+
std::iota(labels.begin(), labels.end(), 0);
966+
status = index->add(test_n, labels.data(), test_data.data());
967+
CATCH_REQUIRE(status.ok());
968+
969+
// Reconstruct first 5 vectors
970+
constexpr size_t nrecon = 5;
971+
std::vector<size_t> ids(nrecon);
972+
std::iota(ids.begin(), ids.end(), 0);
973+
std::vector<float> output(nrecon * test_d, 0.0f);
974+
975+
status = index->reconstruct_at(nrecon, ids.data(), output.data());
976+
CATCH_REQUIRE(status.ok());
977+
978+
// For FP32 storage, reconstructed vectors should match originals exactly
979+
for (size_t i = 0; i < nrecon; ++i) {
980+
for (size_t j = 0; j < test_d; ++j) {
981+
CATCH_REQUIRE(output[i * test_d + j] == test_data[i * test_d + j]);
982+
}
983+
}
984+
985+
svs::runtime::v0::DynamicVamanaIndex::destroy(index);
986+
}
987+
988+
CATCH_TEST_CASE("ReconstructAtStatic", "[runtime][static_vamana]") {
989+
const auto& test_data = get_test_data();
990+
svs::runtime::v0::VamanaIndex* index = nullptr;
991+
svs::runtime::v0::VamanaIndex::BuildParams build_params{64};
992+
auto status = svs::runtime::v0::VamanaIndex::build(
993+
&index,
994+
test_d,
995+
svs::runtime::v0::MetricType::L2,
996+
svs::runtime::v0::StorageKind::FP32,
997+
build_params
998+
);
999+
CATCH_REQUIRE(status.ok());
1000+
1001+
status = index->add(test_n, test_data.data());
1002+
CATCH_REQUIRE(status.ok());
1003+
1004+
// Reconstruct first 5 vectors
1005+
constexpr size_t nrecon = 5;
1006+
std::vector<size_t> ids(nrecon);
1007+
std::iota(ids.begin(), ids.end(), 0);
1008+
std::vector<float> output(nrecon * test_d, 0.0f);
1009+
1010+
status = index->reconstruct_at(nrecon, ids.data(), output.data());
1011+
CATCH_REQUIRE(status.ok());
1012+
1013+
// For FP32 storage, reconstructed vectors should match originals exactly
1014+
for (size_t i = 0; i < nrecon; ++i) {
1015+
for (size_t j = 0; j < test_d; ++j) {
1016+
CATCH_REQUIRE(output[i * test_d + j] == test_data[i * test_d + j]);
1017+
}
1018+
}
1019+
1020+
svs::runtime::v0::VamanaIndex::destroy(index);
1021+
}

0 commit comments

Comments
 (0)