Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions c/include/cuvs/neighbors/ivf_pq.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include <cuvs/core/c_api.h>
#include <cuvs/distance/distance.h>
#include <cuvs/neighbors/common.h>
#include <dlpack/dlpack.h>
#include <stdbool.h>
#include <stdint.h>
Expand Down Expand Up @@ -371,8 +372,9 @@ cuvsError_t cuvsIvfPqBuild(cuvsResources_t res,
* cuvsError_t params_create_status = cuvsIvfPqSearchParamsCreate(&search_params);
*
* // Search the `index` built using `cuvsIvfPqBuild`
* cuvsFilter filter = {.addr = 0, .type = NO_FILTER};
* cuvsError_t search_status = cuvsIvfPqSearch(res, search_params, index, &queries, &neighbors,
* &distances);
* &distances, filter);
*
* // de-allocate `search_params` and `res`
* cuvsError_t params_destroy_status = cuvsIvfPqSearchParamsDestroy(search_params);
Expand All @@ -385,13 +387,15 @@ cuvsError_t cuvsIvfPqBuild(cuvsResources_t res,
* @param[in] queries DLManagedTensor* queries dataset to search
* @param[out] neighbors DLManagedTensor* output `k` neighbors for queries
* @param[out] distances DLManagedTensor* output `k` distances for queries
* @param[in] filter cuvsFilter filter to apply to the search
*/
cuvsError_t cuvsIvfPqSearch(cuvsResources_t res,
cuvsIvfPqSearchParams_t search_params,
cuvsIvfPqIndex_t index,
DLManagedTensor* queries,
DLManagedTensor* neighbors,
DLManagedTensor* distances);
DLManagedTensor* distances,
cuvsFilter filter);
/**
* @}
*/
Expand Down
32 changes: 18 additions & 14 deletions c/src/neighbors/cagra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,22 +203,26 @@ void _search(cuvsResources_t res,
if (filter.type == NO_FILTER) {
cuvs::neighbors::cagra::search(
*res_ptr, search_params, *index_ptr, queries_mds, neighbors_mds, distances_mds);
} else if (filter.type == BITMAP) {
using filter_mdspan_type = raft::device_vector_view<std::uint32_t, int64_t>;
using filter_bmp_type = cuvs::core::bitmap_view<std::uint32_t, int64_t>;
auto filter_tensor = reinterpret_cast<DLManagedTensor*>(filter.addr);
auto filter_mds = cuvs::core::from_dlpack<filter_mdspan_type>(filter_tensor);
const auto bitmap_filter_obj = cuvs::neighbors::filtering::bitmap_filter(
filter_bmp_type((std::uint32_t*)filter_mds.data_handle(), queries_mds.extent(0), index_ptr->size()));
cuvs::neighbors::cagra::search(
*res_ptr, search_params, *index_ptr, queries_mds, neighbors_mds, distances_mds, bitmap_filter_obj);
} else if (filter.type == BITSET) {
using filter_mdspan_type = raft::device_vector_view<std::uint32_t, int64_t, raft::row_major>;
auto removed_indices_tensor = reinterpret_cast<DLManagedTensor*>(filter.addr);
auto removed_indices = cuvs::core::from_dlpack<filter_mdspan_type>(removed_indices_tensor);
cuvs::core::bitset_view<std::uint32_t, int64_t> removed_indices_bitset(
removed_indices, index_ptr->dataset().extent(0));
auto bitset_filter_obj = cuvs::neighbors::filtering::bitset_filter(removed_indices_bitset);
cuvs::neighbors::cagra::search(*res_ptr,
search_params,
*index_ptr,
queries_mds,
neighbors_mds,
distances_mds,
bitset_filter_obj);
using filter_mdspan_type = raft::device_vector_view<std::uint32_t, int64_t>;
using filter_bst_type = cuvs::core::bitset_view<std::uint32_t, int64_t>;
auto filter_tensor = reinterpret_cast<DLManagedTensor*>(filter.addr);
auto filter_mds = cuvs::core::from_dlpack<filter_mdspan_type>(filter_tensor);
const auto bitset_filter_obj = cuvs::neighbors::filtering::bitset_filter(
filter_bst_type((std::uint32_t*)filter_mds.data_handle(), index_ptr->size()));
cuvs::neighbors::cagra::search(
*res_ptr, search_params, *index_ptr, queries_mds, neighbors_mds, distances_mds, bitset_filter_obj);
} else {
RAFT_FAIL("Unsupported filter type: BITMAP");
RAFT_FAIL("Unsupported filter type");
}
}

Expand Down
33 changes: 18 additions & 15 deletions c/src/neighbors/ivf_flat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,23 +90,26 @@ void _search(cuvsResources_t res,
if (filter.type == NO_FILTER) {
cuvs::neighbors::ivf_flat::search(
*res_ptr, search_params, *index_ptr, queries_mds, neighbors_mds, distances_mds);
} else if (filter.type == BITMAP) {
using filter_mdspan_type = raft::device_vector_view<std::uint32_t, int64_t>;
using filter_bmp_type = cuvs::core::bitmap_view<std::uint32_t, int64_t>;
auto filter_tensor = reinterpret_cast<DLManagedTensor*>(filter.addr);
auto filter_mds = cuvs::core::from_dlpack<filter_mdspan_type>(filter_tensor);
const auto bitmap_filter_obj = cuvs::neighbors::filtering::bitmap_filter(
filter_bmp_type((std::uint32_t*)filter_mds.data_handle(), queries_mds.extent(0), index_ptr->size()));
cuvs::neighbors::ivf_flat::search(
*res_ptr, search_params, *index_ptr, queries_mds, neighbors_mds, distances_mds, bitmap_filter_obj);
} else if (filter.type == BITSET) {
using filter_mdspan_type = raft::device_vector_view<std::uint32_t, int64_t, raft::row_major>;
auto removed_indices_tensor = reinterpret_cast<DLManagedTensor*>(filter.addr);
auto removed_indices = cuvs::core::from_dlpack<filter_mdspan_type>(removed_indices_tensor);
cuvs::core::bitset_view<std::uint32_t, int64_t> removed_indices_bitset(removed_indices,
index_ptr->size());
auto bitset_filter_obj = cuvs::neighbors::filtering::bitset_filter(removed_indices_bitset);
cuvs::neighbors::ivf_flat::search(*res_ptr,
search_params,
*index_ptr,
queries_mds,
neighbors_mds,
distances_mds,
bitset_filter_obj);

using filter_mdspan_type = raft::device_vector_view<std::uint32_t, int64_t>;
using filter_bst_type = cuvs::core::bitset_view<std::uint32_t, int64_t>;
auto filter_tensor = reinterpret_cast<DLManagedTensor*>(filter.addr);
auto filter_mds = cuvs::core::from_dlpack<filter_mdspan_type>(filter_tensor);
const auto bitset_filter_obj = cuvs::neighbors::filtering::bitset_filter(
filter_bst_type((std::uint32_t*)filter_mds.data_handle(), index_ptr->size()));
cuvs::neighbors::ivf_flat::search(
*res_ptr, search_params, *index_ptr, queries_mds, neighbors_mds, distances_mds, bitset_filter_obj);
} else {
RAFT_FAIL("Unsupported filter type: BITMAP");
RAFT_FAIL("Unsupported filter type");
}
}

Expand Down
40 changes: 32 additions & 8 deletions c/src/neighbors/ivf_pq.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ void _search(cuvsResources_t res,
cuvsIvfPqIndex index,
DLManagedTensor* queries_tensor,
DLManagedTensor* neighbors_tensor,
DLManagedTensor* distances_tensor)
DLManagedTensor* distances_tensor,
cuvsFilter filter)
{
auto res_ptr = reinterpret_cast<raft::resources*>(res);
auto index_ptr = reinterpret_cast<cuvs::neighbors::ivf_pq::index<IdxT>*>(index.addr);
Expand All @@ -97,8 +98,30 @@ void _search(cuvsResources_t res,
auto neighbors_mds = cuvs::core::from_dlpack<neighbors_mdspan_type>(neighbors_tensor);
auto distances_mds = cuvs::core::from_dlpack<distances_mdspan_type>(distances_tensor);

cuvs::neighbors::ivf_pq::search(
*res_ptr, search_params, *index_ptr, queries_mds, neighbors_mds, distances_mds);
if (filter.type == NO_FILTER) {
cuvs::neighbors::ivf_pq::search(
*res_ptr, search_params, *index_ptr, queries_mds, neighbors_mds, distances_mds);
} else if (filter.type == BITMAP) {
using filter_mdspan_type = raft::device_vector_view<std::uint32_t, int64_t>;
using filter_bmp_type = cuvs::core::bitmap_view<std::uint32_t, int64_t>;
auto filter_tensor = reinterpret_cast<DLManagedTensor*>(filter.addr);
auto filter_mds = cuvs::core::from_dlpack<filter_mdspan_type>(filter_tensor);
const auto bitmap_filter_obj = cuvs::neighbors::filtering::bitmap_filter(
filter_bmp_type((std::uint32_t*)filter_mds.data_handle(), queries_mds.extent(0), index_ptr->size()));
cuvs::neighbors::ivf_pq::search(
*res_ptr, search_params, *index_ptr, queries_mds, neighbors_mds, distances_mds, bitmap_filter_obj);
} else if (filter.type == BITSET) {
using filter_mdspan_type = raft::device_vector_view<std::uint32_t, int64_t>;
using filter_bst_type = cuvs::core::bitset_view<std::uint32_t, int64_t>;
auto filter_tensor = reinterpret_cast<DLManagedTensor*>(filter.addr);
auto filter_mds = cuvs::core::from_dlpack<filter_mdspan_type>(filter_tensor);
const auto bitset_filter_obj = cuvs::neighbors::filtering::bitset_filter(
filter_bst_type((std::uint32_t*)filter_mds.data_handle(), index_ptr->size()));
cuvs::neighbors::ivf_pq::search(
*res_ptr, search_params, *index_ptr, queries_mds, neighbors_mds, distances_mds, bitset_filter_obj);
} else {
RAFT_FAIL("Unsupported filter type");
}
}

template <typename IdxT>
Expand Down Expand Up @@ -220,7 +243,8 @@ extern "C" cuvsError_t cuvsIvfPqSearch(cuvsResources_t res,
cuvsIvfPqIndex_t index_c_ptr,
DLManagedTensor* queries_tensor,
DLManagedTensor* neighbors_tensor,
DLManagedTensor* distances_tensor)
DLManagedTensor* distances_tensor,
cuvsFilter filter)
{
return cuvs::core::translate_exceptions([=] {
auto queries = queries_tensor->dl_tensor;
Expand All @@ -242,16 +266,16 @@ extern "C" cuvsError_t cuvsIvfPqSearch(cuvsResources_t res,
auto index = *index_c_ptr;
if (queries.dtype.code == kDLFloat && queries.dtype.bits == 32) {
_search<float, int64_t>(
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor, filter);
} else if (queries.dtype.code == kDLFloat && queries.dtype.bits == 16) {
_search<half, int64_t>(
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor, filter);
} else if (queries.dtype.code == kDLInt && queries.dtype.bits == 8) {
_search<int8_t, int64_t>(
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor, filter);
} else if (queries.dtype.code == kDLUInt && queries.dtype.bits == 8) {
_search<uint8_t, int64_t>(
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor, filter);
} else {
RAFT_FAIL("Unsupported queries DLtensor dtype: %d and bits: %d",
queries.dtype.code,
Expand Down
138 changes: 137 additions & 1 deletion c/tests/neighbors/ann_cagra_c.cu
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ TEST(CagraC, BuildExtendSearch)
cuvsResourcesDestroy(res);
}

TEST(CagraC, BuildSearchFiltered)
TEST(CagraC, BuildSearchBitsetFiltered)
{
// create cuvsResources_t
cuvsResources_t res;
Expand Down Expand Up @@ -442,6 +442,142 @@ TEST(CagraC, BuildSearchFiltered)
cuvsResourcesDestroy(res);
}

/**
 * Builds a CAGRA index through the C API, searches it with a BITMAP filter and
 * verifies that no even dataset index is returned.
 *
 * The filter words are all 0xAAAAAAAA (odd bits set, even bits cleared); the
 * test expects indices whose bit is cleared to be excluded from the results.
 */
TEST(CagraC, BuildSearchBitmapFiltered)
{
  const int64_t n_rows       = 100;
  const int64_t n_queries    = 10;
  const int64_t n_dim        = 16;
  const uint32_t n_neighbors = 4;

  raft::handle_t handle;
  auto stream = raft::resource::get_cuda_stream(handle);

  // Generate random dataset and query vectors on the device.
  rmm::device_uvector<float> index_data(n_rows * n_dim, stream);
  rmm::device_uvector<float> query_data(n_queries * n_dim, stream);
  raft::random::RngState r(1234ULL);
  raft::random::uniform(handle, r, index_data.data(), n_rows * n_dim, float(0.1), float(2.0));
  raft::random::uniform(handle, r, query_data.data(), n_queries * n_dim, float(0.1), float(2.0));

  // create cuvsResources_t
  cuvsResources_t res;
  cuvsResourcesCreate(&res);

  // create dataset DLTensor
  // Value-initialize so that fields this test does not set explicitly
  // (device_id, byte_offset, manager_ctx, deleter) are zero rather than
  // indeterminate when the struct is handed to the C API.
  DLManagedTensor dataset_tensor{};
  int64_t dataset_shape[2]                    = {n_rows, n_dim};
  dataset_tensor.dl_tensor.data               = index_data.data();
  dataset_tensor.dl_tensor.device.device_type = kDLCUDA;
  dataset_tensor.dl_tensor.ndim               = 2;
  dataset_tensor.dl_tensor.dtype.code         = kDLFloat;
  dataset_tensor.dl_tensor.dtype.bits         = 32;
  dataset_tensor.dl_tensor.dtype.lanes        = 1;
  dataset_tensor.dl_tensor.shape              = dataset_shape;
  dataset_tensor.dl_tensor.strides            = nullptr;

  // create index
  cuvsCagraIndex_t index;
  cuvsCagraIndexCreate(&index);

  // build index
  cuvsCagraIndexParams_t build_params;
  cuvsCagraIndexParamsCreate(&build_params);
  cuvsCagraBuild(res, build_params, &dataset_tensor, index);

  // create queries DLTensor
  DLManagedTensor queries_tensor{};
  int64_t queries_shape[2]                    = {n_queries, n_dim};
  queries_tensor.dl_tensor.data               = query_data.data();
  queries_tensor.dl_tensor.device.device_type = kDLCUDA;
  queries_tensor.dl_tensor.ndim               = 2;
  queries_tensor.dl_tensor.dtype.code         = kDLFloat;
  queries_tensor.dl_tensor.dtype.bits         = 32;
  queries_tensor.dl_tensor.dtype.lanes        = 1;
  queries_tensor.dl_tensor.shape              = queries_shape;
  queries_tensor.dl_tensor.strides            = nullptr;

  // Total number of (query, neighbor) result slots.
  const size_t n_results = static_cast<size_t>(n_queries) * n_neighbors;

  // create neighbors DLTensor
  rmm::device_uvector<uint32_t> neighbors_data(n_results, stream);
  DLManagedTensor neighbors_tensor{};
  int64_t neighbors_shape[2]                    = {n_queries, n_neighbors};
  neighbors_tensor.dl_tensor.data               = neighbors_data.data();
  neighbors_tensor.dl_tensor.device.device_type = kDLCUDA;
  neighbors_tensor.dl_tensor.ndim               = 2;
  neighbors_tensor.dl_tensor.dtype.code         = kDLUInt;
  neighbors_tensor.dl_tensor.dtype.bits         = 32;
  neighbors_tensor.dl_tensor.dtype.lanes        = 1;
  neighbors_tensor.dl_tensor.shape              = neighbors_shape;
  neighbors_tensor.dl_tensor.strides            = nullptr;

  // create distances DLTensor
  rmm::device_uvector<float> distances_data(n_results, stream);
  DLManagedTensor distances_tensor{};
  int64_t distances_shape[2]                    = {n_queries, n_neighbors};
  distances_tensor.dl_tensor.data               = distances_data.data();
  distances_tensor.dl_tensor.device.device_type = kDLCUDA;
  distances_tensor.dl_tensor.ndim               = 2;
  distances_tensor.dl_tensor.dtype.code         = kDLFloat;
  distances_tensor.dl_tensor.dtype.bits         = 32;
  distances_tensor.dl_tensor.dtype.lanes        = 1;
  distances_tensor.dl_tensor.shape              = distances_shape;
  distances_tensor.dl_tensor.strides            = nullptr;

  // Create bitmap filter - per query filter.
  // The buffer reserves ceil(n_rows / 32) 32-bit words for every query.
  // NOTE(review): every word carries the same alternating pattern and n_rows is
  // even, so the odd/even check below does not depend on whether query rows are
  // bit-packed or padded to whole words - confirm the layout expected by the
  // bitmap filter before switching to a non-uniform pattern.
  const int64_t words_per_query = (n_rows + 31) / 32;
  const int64_t bitmap_size     = n_queries * words_per_query;
  rmm::device_uvector<uint32_t> filter_bitmap(bitmap_size, stream);
  // 10101010... pattern - removes even indices for every query.
  const std::vector<uint32_t> filter_bitmap_h(static_cast<size_t>(bitmap_size),
                                              uint32_t{0xAAAAAAAA});
  raft::copy(filter_bitmap.data(), filter_bitmap_h.data(), bitmap_size, stream);

  // The filter is passed as a 1-D uint32 device tensor wrapped in a cuvsFilter.
  DLManagedTensor filter_tensor{};
  int64_t filter_shape[1]                    = {bitmap_size};
  filter_tensor.dl_tensor.data               = filter_bitmap.data();
  filter_tensor.dl_tensor.device.device_type = kDLCUDA;
  filter_tensor.dl_tensor.ndim               = 1;
  filter_tensor.dl_tensor.dtype.code         = kDLUInt;
  filter_tensor.dl_tensor.dtype.bits         = 32;
  filter_tensor.dl_tensor.dtype.lanes        = 1;
  filter_tensor.dl_tensor.shape              = filter_shape;
  filter_tensor.dl_tensor.strides            = nullptr;

  cuvsFilter filter;
  filter.type = BITMAP;
  filter.addr = (uintptr_t)&filter_tensor;

  // search index with bitmap filter
  cuvsCagraSearchParams_t search_params;
  cuvsCagraSearchParamsCreate(&search_params);
  cuvsCagraSearch(
    res, search_params, index, &queries_tensor, &neighbors_tensor, &distances_tensor, filter);

  // Copy the neighbor ids back to the host and wait for the copy to finish
  // before inspecting them.
  std::vector<uint32_t> neighbors_h(n_results);
  raft::copy(neighbors_h.data(), neighbors_data.data(), n_results, stream);
  raft::resource::sync_stream(handle);

  for (size_t i = 0; i < n_results; ++i) {
    // All neighbors should be odd indices (since even indices are filtered)
    // Note: uint32_t max value indicates no valid neighbor found
    ASSERT_TRUE(neighbors_h[i] % 2 == 1 || neighbors_h[i] == std::numeric_limits<uint32_t>::max())
      << "Neighbor at position " << i << " has value " << neighbors_h[i]
      << " which is an even index (should be filtered)";
  }

  // de-allocate index and res
  cuvsCagraSearchParamsDestroy(search_params);
  cuvsCagraIndexParamsDestroy(build_params);
  cuvsCagraIndexDestroy(index);
  cuvsResourcesDestroy(res);
}

TEST(CagraC, BuildMergeSearch)
{
cuvsResources_t res;
Expand Down
Loading
Loading