Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
0e9f61c
convert max_topk to a runtime parameter
seunghwak Nov 4, 2025
444b946
convert max_candidates to a runtime parameter
seunghwak Nov 5, 2025
54a1805
convert max_elements to a runtime parameter
seunghwak Nov 5, 2025
6fc87a2
Merge branch 'main' of https://github.com/rapidsai/cuvs into enh_sear…
seunghwak Nov 5, 2025
25a0b8d
Merge branch 'main' of https://github.com/rapidsai/cuvs into enh_sear…
seunghwak Nov 6, 2025
a13112f
Merge branch 'main' of https://github.com/rapidsai/cuvs into enh_sear…
seunghwak Nov 8, 2025
aa5c116
tighter bound on array size
seunghwak Nov 11, 2025
178459a
undo most of the changes in bitonic.hpp except for the less unrolling…
seunghwak Nov 12, 2025
6aae61d
remove __inline__ from topk_by_radix (to lower register pressure)
seunghwak Nov 12, 2025
8aaf11b
branch outside topk_by_bitonic_sort
seunghwak Nov 12, 2025
0735129
use shared memory when I need to create large stack arrays for bitoni…
seunghwak Nov 12, 2025
3c9ee5f
branch before calling topk_by_bitonic_sort_and_merge
seunghwak Nov 13, 2025
68b191b
undo changes in topk_cta_11_core (branch in the caller site)
seunghwak Nov 13, 2025
89d9698
fix build error
seunghwak Nov 13, 2025
a9bfab7
update max_itopk setting in single-CTA radix sort based search to mat…
seunghwak Nov 14, 2025
a0be265
create non-template wrapper functions to prevent high register pressu…
seunghwak Nov 14, 2025
52141c3
remove unnecessary include statements
seunghwak Nov 15, 2025
15bdc84
use smem to reduce register pressure
seunghwak Nov 17, 2025
c41d548
Merge branch 'main' of https://github.com/rapidsai/cuvs into enh_sear…
seunghwak Nov 17, 2025
2a8eac0
Merge branch 'main' of https://github.com/rapidsai/cuvs into enh_sear…
seunghwak Nov 17, 2025
5a117a5
fix build error after pulling new updates
seunghwak Nov 17, 2025
4cd6f8f
undo using shared memory to store key, value pairs
seunghwak Nov 18, 2025
c8a9673
delete dead code
seunghwak Nov 18, 2025
77d657c
fix an error
seunghwak Nov 18, 2025
efcf489
tweak register pressure
seunghwak Nov 18, 2025
55b80f4
Merge branch 'main' of https://github.com/rapidsai/cuvs into enh_sear…
seunghwak Nov 18, 2025
bde5ac5
copyright year
seunghwak Nov 18, 2025
49c1c64
Merge branch 'main' of https://github.com/rapidsai/cuvs into enh_sear…
seunghwak Nov 19, 2025
901bad6
undo conditional unrolling in bitonic.hpp (this significantly slows d…
seunghwak Nov 23, 2025
4f7ef48
final performance tweak
seunghwak Nov 24, 2025
a596c9f
Merge branch 'main' of https://github.com/rapidsai/cuvs into enh_sear…
seunghwak Nov 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 53 additions & 134 deletions cpp/src/neighbors/detail/cagra/bitonic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,139 +41,52 @@ RAFT_DEVICE_INLINE_FUNCTION void swap_if_needed(K& k0,
}
}

template <class K, class V, unsigned N, unsigned warp_size = 32>
struct warp_merge_core {
RAFT_DEVICE_INLINE_FUNCTION void operator()(K k[N],
V v[N],
const std::uint32_t range,
const bool asc)
template <class K, class V, unsigned warp_size = 32>
struct warp_merge_core_n {
RAFT_DEVICE_INLINE_FUNCTION void operator()(
K* ks, V* vs, unsigned n, const std::uint32_t range, const bool asc)
{
const auto lane_id = threadIdx.x % warp_size;

if (range == 1) {
for (std::uint32_t b = 2; b <= N; b <<= 1) {
for (std::uint32_t b = 2; b <= n; b <<= 1) {
for (std::uint32_t c = b / 2; c >= 1; c >>= 1) {
#pragma unroll
for (std::uint32_t i = 0; i < N; i++) {
for (std::uint32_t i = 0; i < n; i++) {
std::uint32_t j = i ^ c;
if (i >= j) continue;
const auto line_id = i + (N * lane_id);
const auto line_id = i + (n * lane_id);
const auto p = static_cast<bool>(line_id & b) == static_cast<bool>(line_id & c);
swap_if_needed(k[i], v[i], k[j], v[j], p);
swap_if_needed(ks[i], vs[i], ks[j], vs[j], p);
}
}
}
return;
}

const std::uint32_t b = range;
for (std::uint32_t c = b / 2; c >= 1; c >>= 1) {
const auto p = static_cast<bool>(lane_id & b) == static_cast<bool>(lane_id & c);
#pragma unroll
for (std::uint32_t i = 0; i < N; i++) {
swap_if_needed(k[i], v[i], c, p);
}
}
const auto p = ((lane_id & b) == 0);
for (std::uint32_t c = N / 2; c >= 1; c >>= 1) {
} else {
const std::uint32_t b = range;
for (std::uint32_t c = b / 2; c >= 1; c >>= 1) {
const auto p = static_cast<bool>(lane_id & b) == static_cast<bool>(lane_id & c);
#pragma unroll
for (std::uint32_t i = 0; i < N; i++) {
std::uint32_t j = i ^ c;
if (i >= j) continue;
swap_if_needed(k[i], v[i], k[j], v[j], p);
}
}
}
};

template <class K, class V, unsigned warp_size>
struct warp_merge_core<K, V, 6, warp_size> {
RAFT_DEVICE_INLINE_FUNCTION void operator()(K k[6],
V v[6],
const std::uint32_t range,
const bool asc)
{
constexpr unsigned N = 6;
const auto lane_id = threadIdx.x % warp_size;

if (range == 1) {
for (std::uint32_t i = 0; i < N; i += 3) {
const auto p = (i == 0);
swap_if_needed(k[0 + i], v[0 + i], k[1 + i], v[1 + i], p);
swap_if_needed(k[1 + i], v[1 + i], k[2 + i], v[2 + i], p);
swap_if_needed(k[0 + i], v[0 + i], k[1 + i], v[1 + i], p);
}
const auto p = ((lane_id & 1) == 0);
for (std::uint32_t i = 0; i < 3; i++) {
std::uint32_t j = i + 3;
swap_if_needed(k[i], v[i], k[j], v[j], p);
}
for (std::uint32_t i = 0; i < N; i += 3) {
swap_if_needed(k[0 + i], v[0 + i], k[1 + i], v[1 + i], p);
swap_if_needed(k[1 + i], v[1 + i], k[2 + i], v[2 + i], p);
swap_if_needed(k[0 + i], v[0 + i], k[1 + i], v[1 + i], p);
}
return;
}

const std::uint32_t b = range;
for (std::uint32_t c = b / 2; c >= 1; c >>= 1) {
const auto p = static_cast<bool>(lane_id & b) == static_cast<bool>(lane_id & c);
#pragma unroll
for (std::uint32_t i = 0; i < N; i++) {
swap_if_needed(k[i], v[i], c, p);
for (std::uint32_t i = 0; i < n; i++) {
swap_if_needed(ks[i], vs[i], c, p);
}
Copy link
Contributor

@achirkin achirkin Nov 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change effectively removes all of the loop unrolling, because n is not known at compile time (you can safely remove #pragma unroll as it does nothing now btw). In particular, this means the input arrays k and v cannot be passed and accessed via registers. This will likely have a huge impact on performance.
Please run a few benchmarks using ANN_BENCH first to see the impact on the throughput. From there, we can decide whether (a) performance is acceptable, (b) we need to profile the kernel using NCU and try to improve performance, or (c) the perf state is hopeless and cannot be recovered without manual loop unrolling / restoring the template parameter.

For the benchmarks, I'd suggest the following parameter sweep:

./build.sh -n libcuvs bench-ann --limit-bench-ann=CUVS_CAGRA_ANN_BENCH
./cpp/build/bench/ann/CUVS_CAGRA_ANN_BENCH \
  --search \
  --benchmark_min_time=10s \
  --benchmark_min_warmup_time=0.001 \
  --benchmark_counters_tabular=true \
  --benchmark_out=cagra-search-`git rev-parse --abbrev-ref HEAD`.csv \
  --benchmark_out_format=csv \
  --data_prefix=<data folder> \
  --index_prefix=<index folder> \
  --override_kv=algo:\"single_cta\" \
  --override_kv=k:10:100 \
  --override_kv=itopk:32:64:128:256:512 \
  --override_kv=max_iterations:20 \
  --override_kv=n_queries:10000 \
  <config file>

}
}
const auto p = ((lane_id & b) == 0);
for (std::uint32_t i = 0; i < 3; i++) {
std::uint32_t j = i + 3;
swap_if_needed(k[i], v[i], k[j], v[j], p);
}
for (std::uint32_t i = 0; i < N; i += N / 2) {
swap_if_needed(k[0 + i], v[0 + i], k[1 + i], v[1 + i], p);
swap_if_needed(k[1 + i], v[1 + i], k[2 + i], v[2 + i], p);
swap_if_needed(k[0 + i], v[0 + i], k[1 + i], v[1 + i], p);
}
}
};

template <class K, class V, unsigned warp_size>
struct warp_merge_core<K, V, 3, warp_size> {
RAFT_DEVICE_INLINE_FUNCTION void operator()(K k[3],
V v[3],
const std::uint32_t range,
const bool asc)
{
constexpr unsigned N = 3;
const auto lane_id = threadIdx.x % warp_size;

if (range == 1) {
const auto p = ((lane_id & 1) == 0);
swap_if_needed(k[0], v[0], k[1], v[1], p);
swap_if_needed(k[1], v[1], k[2], v[2], p);
swap_if_needed(k[0], v[0], k[1], v[1], p);
return;
}

const std::uint32_t b = range;
for (std::uint32_t c = b / 2; c >= 1; c >>= 1) {
const auto p = static_cast<bool>(lane_id & b) == static_cast<bool>(lane_id & c);
const auto p = ((lane_id & b) == 0);
for (std::uint32_t c = n / 2; c >= 1; c >>= 1) {
#pragma unroll
for (std::uint32_t i = 0; i < N; i++) {
swap_if_needed(k[i], v[i], c, p);
for (std::uint32_t i = 0; i < n; i++) {
std::uint32_t j = i ^ c;
if (i >= j) continue;
swap_if_needed(ks[i], vs[i], ks[j], vs[j], p);
}
}
}
const auto p = ((lane_id & b) == 0);
swap_if_needed(k[0], v[0], k[1], v[1], p);
swap_if_needed(k[1], v[1], k[2], v[2], p);
swap_if_needed(k[0], v[0], k[1], v[1], p);
}
};

template <class K, class V, unsigned warp_size>
struct warp_merge_core<K, V, 2, warp_size> {
RAFT_DEVICE_INLINE_FUNCTION void operator()(K k[2],
V v[2],
struct warp_merge_core_2 {
RAFT_DEVICE_INLINE_FUNCTION void operator()(K* ks,
V* vs,
const std::uint32_t range,
const bool asc)
{
Expand All @@ -182,53 +95,59 @@ struct warp_merge_core<K, V, 2, warp_size> {

if (range == 1) {
const auto p = ((lane_id & 1) == 0);
swap_if_needed(k[0], v[0], k[1], v[1], p);
return;
}

const std::uint32_t b = range;
for (std::uint32_t c = b / 2; c >= 1; c >>= 1) {
const auto p = static_cast<bool>(lane_id & b) == static_cast<bool>(lane_id & c);
swap_if_needed(ks[0], vs[0], ks[1], vs[1], p);
} else {
const std::uint32_t b = range;
for (std::uint32_t c = b / 2; c >= 1; c >>= 1) {
const auto p = static_cast<bool>(lane_id & b) == static_cast<bool>(lane_id & c);
#pragma unroll
for (std::uint32_t i = 0; i < N; i++) {
swap_if_needed(k[i], v[i], c, p);
for (std::uint32_t i = 0; i < N; i++) {
swap_if_needed(ks[i], vs[i], c, p);
}
}
const auto p = ((lane_id & b) == 0);
swap_if_needed(ks[0], vs[0], ks[1], vs[1], p);
}
const auto p = ((lane_id & b) == 0);
swap_if_needed(k[0], v[0], k[1], v[1], p);
}
};

template <class K, class V, unsigned warp_size>
struct warp_merge_core<K, V, 1, warp_size> {
RAFT_DEVICE_INLINE_FUNCTION void operator()(K k[1],
V v[1],
struct warp_merge_core_1 {
RAFT_DEVICE_INLINE_FUNCTION void operator()(K* ks,
V* vs,
const std::uint32_t range,
const bool asc)
{
const auto lane_id = threadIdx.x % warp_size;
const std::uint32_t b = range;
for (std::uint32_t c = b / 2; c >= 1; c >>= 1) {
const auto p = static_cast<bool>(lane_id & b) == static_cast<bool>(lane_id & c);
swap_if_needed(k[0], v[0], c, p);
swap_if_needed(ks[0], vs[0], c, p);
}
}
};

} // namespace detail

template <class K, class V, unsigned N, unsigned warp_size = 32>
RAFT_DEVICE_INLINE_FUNCTION void warp_merge(K k[N], V v[N], unsigned range, const bool asc = true)
template <class K, class V, unsigned warp_size = 32>
RAFT_DEVICE_INLINE_FUNCTION void warp_merge(
K* ks, V* vs, unsigned n, unsigned range, const bool asc = true)
{
detail::warp_merge_core<K, V, N, warp_size>{}(k, v, range, asc);
if (n == 1) {
detail::warp_merge_core_1<K, V, warp_size>{}(ks, vs, range, asc);
} else if (n == 2) {
detail::warp_merge_core_2<K, V, warp_size>{}(ks, vs, range, asc);
} else {
detail::warp_merge_core_n<K, V, warp_size>{}(ks, vs, n, range, asc);
}
}

template <class K, class V, unsigned N, unsigned warp_size = 32>
RAFT_DEVICE_INLINE_FUNCTION void warp_sort(K k[N], V v[N], const bool asc = true)
template <class K, class V, unsigned warp_size = 32>
RAFT_DEVICE_INLINE_FUNCTION void warp_sort(K* ks, V* vs, unsigned n, const bool asc = true)
{
#pragma unroll
#pragma unroll 1
for (std::uint32_t range = 1; range <= warp_size; range <<= 1) {
warp_merge<K, V, N, warp_size>(k, v, range, asc);
warp_merge<K, V, warp_size>(ks, vs, n, range, asc);
}
}

Expand Down
57 changes: 34 additions & 23 deletions cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -109,19 +109,25 @@ RAFT_DEVICE_INLINE_FUNCTION void pickup_next_parent(
}
}

template <unsigned MAX_ELEMENTS, class INDEX_T>
template <class INDEX_T>
RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort(float* distances, // [num_elements]
INDEX_T* indices, // [num_elements]
const uint32_t max_elements,
const uint32_t num_elements)
{
const unsigned warp_id = threadIdx.x / 32;
const unsigned warp_id = threadIdx.x / raft::warp_size();
if (warp_id > 0) { return; }
const unsigned lane_id = threadIdx.x % 32;
constexpr unsigned N = (MAX_ELEMENTS + 31) / 32;
float key[N];
INDEX_T val[N];
const unsigned lane_id = threadIdx.x % raft::warp_size();
assert(max_elements <= 256);
constexpr unsigned MAX_N =
8; // if MAX_N >> N, we may have negative performance impact, if this is significant, we may
// get memory space from dynamically sized shared memory
const unsigned N = (max_elements + (raft::warp_size() - 1)) / raft::warp_size();
assert(N <= MAX_N);
float key[MAX_N];
INDEX_T val[MAX_N];
for (unsigned i = 0; i < N; i++) {
unsigned j = lane_id + (32 * i);
unsigned j = lane_id + (raft::warp_size() * i);
if (j < num_elements) {
key[i] = distances[j];
val[i] = indices[j];
Expand All @@ -131,7 +137,7 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort(float* distances, // [num
}
}
/* Warp Sort */
bitonic::warp_sort<float, INDEX_T, N>(key, val);
bitonic::warp_sort<float, INDEX_T>(key, val, N);
/* Store sorted results */
for (unsigned i = 0; i < N; i++) {
unsigned j = (N * lane_id) + i;
Expand All @@ -145,10 +151,7 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort(float* distances, // [num
//
// multiple CTAs per single query
//
template <std::uint32_t MAX_ELEMENTS,
class DATASET_DESCRIPTOR_T,
class SourceIndexT,
class SAMPLE_FILTER_T>
template <class DATASET_DESCRIPTOR_T, class SourceIndexT, class SAMPLE_FILTER_T>
RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel(
typename DATASET_DESCRIPTOR_T::INDEX_T* const
result_indices_ptr, // [num_queries, num_cta_per_query, itopk_size]
Expand All @@ -157,6 +160,7 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel(
const DATASET_DESCRIPTOR_T* dataset_desc,
const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, // [num_queries, dataset_dim]
const typename DATASET_DESCRIPTOR_T::INDEX_T* const knn_graph, // [dataset_size, graph_degree]
const uint32_t max_elements,
const uint32_t graph_degree,
const SourceIndexT* source_indices_ptr, // [num_queries, search_width]
const unsigned num_distilation,
Expand Down Expand Up @@ -211,7 +215,7 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel(
// |<--- result_buffer_size_32 --->|
const auto result_buffer_size = itopk_size + graph_degree;
const auto result_buffer_size_32 = raft::round_up_safe<uint32_t>(result_buffer_size, 32);
assert(result_buffer_size_32 <= MAX_ELEMENTS);
assert(result_buffer_size_32 <= max_elements);

// Set smem working buffer for the distance calculation
dataset_desc = dataset_desc->setup_workspace(smem, queries_ptr, query_id);
Expand Down Expand Up @@ -268,8 +272,8 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel(
_CLK_START();
if (threadIdx.x < 32) {
// [1st warp] Topk with bitonic sort
topk_by_bitonic_sort<MAX_ELEMENTS, INDEX_T>(
result_distances_buffer, result_indices_buffer, result_buffer_size_32);
topk_by_bitonic_sort<INDEX_T>(
result_distances_buffer, result_indices_buffer, max_elements, result_buffer_size_32);
}
__syncthreads();
_CLK_REC(clk_topk);
Expand Down Expand Up @@ -487,17 +491,12 @@ struct search_kernel_config {
// Search kernel function type. Note that the actual values for the template value
// parameters do not matter, because they are not part of the function signature. The
// second to fourth value parameters will be selected by the choose_* functions below.
using kernel_t =
decltype(&search_kernel<128, DATASET_DESCRIPTOR_T, SourceIndexT, SAMPLE_FILTER_T>);
using kernel_t = decltype(&search_kernel<DATASET_DESCRIPTOR_T, SourceIndexT, SAMPLE_FILTER_T>);

static auto choose_buffer_size(unsigned result_buffer_size, unsigned block_size) -> kernel_t
{
if (result_buffer_size <= 64) {
return search_kernel<64, DATASET_DESCRIPTOR_T, SourceIndexT, SAMPLE_FILTER_T>;
} else if (result_buffer_size <= 128) {
return search_kernel<128, DATASET_DESCRIPTOR_T, SourceIndexT, SAMPLE_FILTER_T>;
} else if (result_buffer_size <= 256) {
return search_kernel<256, DATASET_DESCRIPTOR_T, SourceIndexT, SAMPLE_FILTER_T>;
if (result_buffer_size <= 256) {
return search_kernel<DATASET_DESCRIPTOR_T, SourceIndexT, SAMPLE_FILTER_T>;
}
THROW("Result buffer size %u larger than max buffer size %u", result_buffer_size, 256);
}
Expand Down Expand Up @@ -536,6 +535,17 @@ void select_and_run(const dataset_descriptor_host<DataT, IndexT, DistanceT>& dat
SourceIndexT,
SampleFilterT>::choose_buffer_size(result_buffer_size, block_size);

uint32_t max_elements{};
if (result_buffer_size <= 64) {
max_elements = 64;
} else if (result_buffer_size <= 128) {
max_elements = 128;
} else if (result_buffer_size <= 256) {
max_elements = 256;
} else {
THROW("Result buffer size %u larger than max buffer size %u", result_buffer_size, 256);
}

RAFT_CUDA_TRY(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
// Initialize hash table
Expand All @@ -560,6 +570,7 @@ void select_and_run(const dataset_descriptor_host<DataT, IndexT, DistanceT>& dat
dataset_desc.dev_ptr(stream),
queries_ptr,
graph.data_handle(),
max_elements,
graph.extent(1),
source_indices_ptr,
ps.num_random_samplings,
Expand Down
Loading