diff --git a/cpp/src/neighbors/detail/cagra/bitonic.hpp b/cpp/src/neighbors/detail/cagra/bitonic.hpp index 66c726c71a..7fd81f0a11 100644 --- a/cpp/src/neighbors/detail/cagra/bitonic.hpp +++ b/cpp/src/neighbors/detail/cagra/bitonic.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index 45328377be..53f54a0429 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -9,7 +9,6 @@ #include "factory.cuh" #include "sample_filter_utils.cuh" #include "search_plan.cuh" -#include "search_single_cta_inst.cuh" #include #include diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh index 1720841c91..44aa4c9601 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh @@ -114,14 +114,14 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort(float* distances, // [num INDEX_T* indices, // [num_elements] const uint32_t num_elements) { - const unsigned warp_id = threadIdx.x / 32; + const unsigned warp_id = threadIdx.x / raft::warp_size(); if (warp_id > 0) { return; } - const unsigned lane_id = threadIdx.x % 32; - constexpr unsigned N = (MAX_ELEMENTS + 31) / 32; + const unsigned lane_id = threadIdx.x % raft::warp_size(); + constexpr unsigned N = (MAX_ELEMENTS + (raft::warp_size() - 1)) / raft::warp_size(); float key[N]; INDEX_T val[N]; for (unsigned i = 0; i < N; i++) { - unsigned j = lane_id + (32 * i); + unsigned j = lane_id + (raft::warp_size() * i); if (j < num_elements) { key[i] = distances[j]; val[i] = indices[j]; @@ -142,13 +142,34 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort(float* distances, // [num } } +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_wrapper_64( + float* distances, // [num_elements] + uint32_t* indices, // [num_elements] + const uint32_t num_elements) +{ + topk_by_bitonic_sort<64, uint32_t>(distances, indices, num_elements); +} + +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_wrapper_128( + float* distances, // [num_elements] + uint32_t* indices, // [num_elements] + const uint32_t num_elements) +{ + topk_by_bitonic_sort<128, uint32_t>(distances, indices, num_elements); +} + +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_wrapper_256( + float* distances, // [num_elements] + uint32_t* indices, // [num_elements] + const uint32_t num_elements) +{ + topk_by_bitonic_sort<256, uint32_t>(distances, indices, num_elements); +} + // // multiple CTAs per single query // -template +template RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( typename DATASET_DESCRIPTOR_T::INDEX_T* const result_indices_ptr, // [num_queries, num_cta_per_query, itopk_size] @@ -157,6 +178,7 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( const DATASET_DESCRIPTOR_T* dataset_desc, const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, // [num_queries, dataset_dim] const typename DATASET_DESCRIPTOR_T::INDEX_T* const knn_graph, // [dataset_size, graph_degree] + const uint32_t max_elements, const uint32_t graph_degree, const SourceIndexT* source_indices_ptr, // [num_queries, search_width] const unsigned num_distilation, @@ -211,7 +233,7 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( // |<--- result_buffer_size_32 --->| const auto result_buffer_size = itopk_size + graph_degree; const auto result_buffer_size_32 = raft::round_up_safe(result_buffer_size, 32); - assert(result_buffer_size_32 <= MAX_ELEMENTS); + assert(result_buffer_size_32 <= max_elements); // Set smem working buffer for the distance calculation dataset_desc = dataset_desc->setup_workspace(smem, queries_ptr, query_id); @@ -268,8 +290,33 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( _CLK_START(); if (threadIdx.x < 32) { // [1st warp] Topk with bitonic sort - topk_by_bitonic_sort( - result_distances_buffer, result_indices_buffer, result_buffer_size_32); + if constexpr (std::is_same_v) { + // use a non-template wrapper function to avoid pre-inlining the topk_by_bitonic_sort + // function (vs post-inlining, this impacts register pressure) + if (max_elements <= 64) { + topk_by_bitonic_sort_wrapper_64( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); + } else if (max_elements <= 128) { + topk_by_bitonic_sort_wrapper_128( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); + } else { + assert(max_elements <= 256); + topk_by_bitonic_sort_wrapper_256( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); + } + } else { + if (max_elements <= 64) { + topk_by_bitonic_sort<64, INDEX_T>( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); + } else if (max_elements <= 128) { + topk_by_bitonic_sort<128, INDEX_T>( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); + } else { + assert(max_elements <= 256); + topk_by_bitonic_sort<256, INDEX_T>( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); + } + } } __syncthreads(); _CLK_REC(clk_topk); @@ -487,17 +534,12 @@ struct search_kernel_config { // Search kernel function type. Note that the actual values for the template value // parameters do not matter, because they are not part of the function signature. The // second to fourth value parameters will be selected by the choose_* functions below. - using kernel_t = - decltype(&search_kernel<128, DATASET_DESCRIPTOR_T, SourceIndexT, SAMPLE_FILTER_T>); + using kernel_t = decltype(&search_kernel); static auto choose_buffer_size(unsigned result_buffer_size, unsigned block_size) -> kernel_t { - if (result_buffer_size <= 64) { - return search_kernel<64, DATASET_DESCRIPTOR_T, SourceIndexT, SAMPLE_FILTER_T>; - } else if (result_buffer_size <= 128) { - return search_kernel<128, DATASET_DESCRIPTOR_T, SourceIndexT, SAMPLE_FILTER_T>; - } else if (result_buffer_size <= 256) { - return search_kernel<256, DATASET_DESCRIPTOR_T, SourceIndexT, SAMPLE_FILTER_T>; + if (result_buffer_size <= 256) { + return search_kernel; } THROW("Result buffer size %u larger than max buffer size %u", result_buffer_size, 256); } @@ -536,6 +578,17 @@ void select_and_run(const dataset_descriptor_host& dat SourceIndexT, SampleFilterT>::choose_buffer_size(result_buffer_size, block_size); + uint32_t max_elements{}; + if (result_buffer_size <= 64) { + max_elements = 64; + } else if (result_buffer_size <= 128) { + max_elements = 128; + } else if (result_buffer_size <= 256) { + max_elements = 256; + } else { + THROW("Result buffer size %u larger than max buffer size %u", result_buffer_size, 256); + } + RAFT_CUDA_TRY( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); // Initialize hash table @@ -560,6 +613,7 @@ void select_and_run(const dataset_descriptor_host& dat dataset_desc.dev_ptr(stream), queries_ptr, graph.data_handle(), + max_elements, graph.extent(1), source_indices_ptr, ps.num_random_samplings, diff --git a/cpp/src/neighbors/detail/cagra/search_plan.cuh b/cpp/src/neighbors/detail/cagra/search_plan.cuh index 7b15196bea..f7035a23bb 100644 --- a/cpp/src/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/src/neighbors/detail/cagra/search_plan.cuh @@ -10,7 +10,6 @@ #include "compute_distance-ext.cuh" #include #include -// #include "search_single_cta_inst.cuh" // #include "topk_for_cagra/topk.h" #include diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta.cuh index 948c1b889e..b2b4670964 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta.cuh @@ -117,7 +117,7 @@ struct search // // Determine the thread block size // - constexpr unsigned min_block_size = 64; // 32 or 64 + constexpr unsigned min_block_size = 64; constexpr unsigned min_block_size_radix = 256; constexpr unsigned max_block_size = 1024; // @@ -129,13 +129,17 @@ struct search sizeof(std::uint32_t) * topk_ws_size + sizeof(std::uint32_t); std::uint32_t additional_smem_size = 0; - if (num_itopk_candidates > 256) { - // Tentatively calculate the required share memory size when radix - // sort based topk is used, assuming the block size is the maximum. + if (num_itopk_candidates > 256) { // radix sort + // Tentatively calculate the required shared memory size when radix sort based topk is used, + // assuming the block size is the maximum. if (itopk_size <= 256) { - additional_smem_size += topk_by_radix_sort<256, INDEX_T>::smem_size * sizeof(std::uint32_t); + constexpr unsigned MAX_ITOPK = 256; + additional_smem_size += + topk_by_radix_sort::smem_size(MAX_ITOPK) * sizeof(std::uint32_t); } else { - additional_smem_size += topk_by_radix_sort<512, INDEX_T>::smem_size * sizeof(std::uint32_t); + constexpr unsigned MAX_ITOPK = 512; + additional_smem_size += + topk_by_radix_sort::smem_size(MAX_ITOPK) * sizeof(std::uint32_t); } } @@ -152,7 +156,7 @@ struct search if (block_size == 0) { block_size = min_block_size; - if (num_itopk_candidates > 256) { + if (num_itopk_candidates > 256) { // radix sort // radix-based topk is used. block_size = min_block_size_radix; @@ -190,19 +194,6 @@ struct search max_block_size); thread_block_size = block_size; - if (num_itopk_candidates <= 256) { - RAFT_LOG_DEBUG("# bitonic-sort based topk routine is used"); - } else { - RAFT_LOG_DEBUG("# radix-sort based topk routine is used"); - smem_size = base_smem_size; - if (itopk_size <= 256) { - constexpr unsigned MAX_ITOPK = 256; - smem_size += topk_by_radix_sort::smem_size * sizeof(std::uint32_t); - } else { - constexpr unsigned MAX_ITOPK = 512; - smem_size += topk_by_radix_sort::smem_size * sizeof(std::uint32_t); - } - } RAFT_LOG_DEBUG("# smem_size: %u", smem_size); hashmap_size = 0; if (small_hash_bitlen == 0 && !this->persistent) { diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh index 96e0c419f2..76a45dff81 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh @@ -57,7 +57,7 @@ namespace single_cta_search { // #define _CLK_BREAKDOWN -template +template RAFT_DEVICE_INLINE_FUNCTION void pickup_next_parents(std::uint32_t* const terminate_flag, INDEX_T* const next_parent_indices, INDEX_T* const internal_topk_indices, @@ -99,24 +99,24 @@ RAFT_DEVICE_INLINE_FUNCTION void pickup_next_parents(std::uint32_t* const termin if (threadIdx.x == 0 && (num_new_parents == 0)) { *terminate_flag = 1; } } -template +template RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_full( float* candidate_distances, // [num_candidates] IdxT* candidate_indices, // [num_candidates] const std::uint32_t num_candidates, - const std::uint32_t num_itopk, - unsigned MULTI_WARPS = 0) + const std::uint32_t num_itopk) { - const unsigned lane_id = threadIdx.x % 32; - const unsigned warp_id = threadIdx.x / 32; - if (MULTI_WARPS == 0) { + const unsigned lane_id = threadIdx.x % raft::warp_size(); + const unsigned warp_id = threadIdx.x / raft::warp_size(); + static_assert(MAX_CANDIDATES <= 256); + if constexpr (!MULTI_WARPS) { if (warp_id > 0) { return; } - constexpr unsigned N = (MAX_CANDIDATES + 31) / 32; + constexpr unsigned N = (MAX_CANDIDATES + (raft::warp_size() - 1)) / raft::warp_size(); float key[N]; IdxT val[N]; /* Candidates -> Reg */ for (unsigned i = 0; i < N; i++) { - unsigned j = lane_id + (32 * i); + unsigned j = lane_id + (raft::warp_size() * i); if (j < num_candidates) { key[i] = candidate_distances[j]; val[i] = candidate_indices[j]; @@ -136,15 +136,17 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_full( } } } else { + assert(blockDim.x >= 64); // Use two warps (64 threads) constexpr unsigned max_candidates_per_warp = (MAX_CANDIDATES + 1) / 2; - constexpr unsigned N = (max_candidates_per_warp + 31) / 32; + static_assert(max_candidates_per_warp <= 128); + constexpr unsigned N = (max_candidates_per_warp + (raft::warp_size() - 1)) / raft::warp_size(); float key[N]; IdxT val[N]; if (warp_id < 2) { /* Candidates -> Reg */ for (unsigned i = 0; i < N; i++) { - unsigned jl = lane_id + (32 * i); + unsigned jl = lane_id + (raft::warp_size() * i); unsigned j = jl + (max_candidates_per_warp * warp_id); if (j < num_candidates) { key[i] = candidate_distances[j]; @@ -188,7 +190,7 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_full( if (num_warps_used > 1) { __syncthreads(); } if (warp_id < num_warps_used) { /* Merge */ - bitonic::warp_merge(key, val, 32); + bitonic::warp_merge(key, val, raft::warp_size()); /* Reg -> Temp_itopk */ for (unsigned i = 0; i < N; i++) { unsigned jl = (N * lane_id) + i; @@ -203,7 +205,7 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_full( } } -template +template RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge( float* itopk_distances, // [num_itopk] IdxT* itopk_indices, // [num_itopk] @@ -212,20 +214,22 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge( IdxT* candidate_indices, // [num_candidates] const std::uint32_t num_candidates, std::uint32_t* work_buf, - const bool first, - unsigned MULTI_WARPS = 0) + const bool first) { - const unsigned lane_id = threadIdx.x % 32; - const unsigned warp_id = threadIdx.x / 32; - if (MULTI_WARPS == 0) { + const unsigned lane_id = threadIdx.x % raft::warp_size(); + const unsigned warp_id = threadIdx.x / raft::warp_size(); + + static_assert(MAX_ITOPK <= 512); + if constexpr (!MULTI_WARPS) { + static_assert(MAX_ITOPK <= 256); if (warp_id > 0) { return; } - constexpr unsigned N = (MAX_ITOPK + 31) / 32; + constexpr unsigned N = (MAX_ITOPK + (raft::warp_size() - 1)) / raft::warp_size(); float key[N]; IdxT val[N]; if (first) { /* Load itopk results */ for (unsigned i = 0; i < N; i++) { - unsigned j = lane_id + (32 * i); + unsigned j = lane_id + (raft::warp_size() * i); if (j < num_itopk) { key[i] = itopk_distances[j]; val[i] = itopk_indices[j]; @@ -251,7 +255,7 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge( } /* Merge candidates */ for (unsigned i = 0; i < N; i++) { - unsigned j = (N * lane_id) + i; // [0:MAX_ITOPK-1] + unsigned j = (N * lane_id) + i; // [0:max_itopk-1] unsigned k = MAX_ITOPK - 1 - j; if (k >= num_itopk || k >= num_candidates) continue; float candidate_key = candidate_distances[device::swizzling(k)]; @@ -261,7 +265,7 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge( } } /* Warp Merge */ - bitonic::warp_merge(key, val, 32); + bitonic::warp_merge(key, val, raft::warp_size()); /* Store new itopk results */ for (unsigned i = 0; i < N; i++) { unsigned j = (N * lane_id) + i; @@ -271,16 +275,18 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge( } } } else { + static_assert(MAX_ITOPK == 512); + assert(blockDim.x >= 64); // Use two warps (64 threads) or more constexpr unsigned max_itopk_per_warp = (MAX_ITOPK + 1) / 2; - constexpr unsigned N = (max_itopk_per_warp + 31) / 32; + constexpr unsigned N = (max_itopk_per_warp + (raft::warp_size() - 1)) / raft::warp_size(); float key[N]; IdxT val[N]; if (first) { /* Load itop results (not sorted) */ if (warp_id < 2) { for (unsigned i = 0; i < N; i++) { - unsigned j = lane_id + (32 * i) + (max_itopk_per_warp * warp_id); + unsigned j = lane_id + (raft::warp_size() * i) + (max_itopk_per_warp * warp_id); if (j < num_itopk) { key[i] = itopk_distances[j]; val[i] = itopk_indices[j]; @@ -314,7 +320,7 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge( } } /* Warp Merge */ - bitonic::warp_merge(key, val, 32); + bitonic::warp_merge(key, val, raft::warp_size()); } __syncthreads(); /* Store itopk results (sorted) */ @@ -396,7 +402,7 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge( } } /* Warp Merge */ - bitonic::warp_merge(key, val, 32); + bitonic::warp_merge(key, val, raft::warp_size()); /* Store new itopk results */ for (unsigned i = 0; i < N; i++) { const unsigned j = (N * lane_id) + i; @@ -410,36 +416,170 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge( } } -template +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_full_wrapper_64_false( + float* candidate_distances, // [num_candidates] + std::uint32_t* candidate_indices, // [num_candidates] + const std::uint32_t num_candidates, + const std::uint32_t num_itopk) +{ + topk_by_bitonic_sort_and_full<64, false, uint32_t>( + candidate_distances, candidate_indices, num_candidates, num_itopk); +} + +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_full_wrapper_128_false( + float* candidate_distances, // [num_candidates] + std::uint32_t* candidate_indices, // [num_candidates] + const std::uint32_t num_candidates, + const std::uint32_t num_itopk) +{ + topk_by_bitonic_sort_and_full<128, false, uint32_t>( + candidate_distances, candidate_indices, num_candidates, num_itopk); +} + +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_full_wrapper_256_false( + float* candidate_distances, // [num_candidates] + std::uint32_t* candidate_indices, // [num_candidates] + const std::uint32_t num_candidates, + const std::uint32_t num_itopk) +{ + topk_by_bitonic_sort_and_full<256, false, uint32_t>( + candidate_distances, candidate_indices, num_candidates, num_itopk); +} + +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge_wrapper_64_false( + float* itopk_distances, // [num_itopk] + uint32_t* itopk_indices, // [num_itopk] + const std::uint32_t num_itopk, + float* candidate_distances, // [num_candidates] + uint32_t* candidate_indices, // [num_candidates] + const std::uint32_t num_candidates, + std::uint32_t* work_buf, + const bool first) +{ + topk_by_bitonic_sort_and_merge<64, false, uint32_t>(itopk_distances, + itopk_indices, + num_itopk, + candidate_distances, + candidate_indices, + num_candidates, + work_buf, + first); +} + +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge_wrapper_128_false( + float* itopk_distances, // [num_itopk] + uint32_t* itopk_indices, // [num_itopk] + const std::uint32_t num_itopk, + float* candidate_distances, // [num_candidates] + uint32_t* candidate_indices, // [num_candidates] + const std::uint32_t num_candidates, + std::uint32_t* work_buf, + const bool first) +{ + topk_by_bitonic_sort_and_merge<128, false, uint32_t>(itopk_distances, + itopk_indices, + num_itopk, + candidate_distances, + candidate_indices, + num_candidates, + work_buf, + first); +} + +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge_wrapper_256_false( + float* itopk_distances, // [num_itopk] + uint32_t* itopk_indices, // [num_itopk] + const std::uint32_t num_itopk, + float* candidate_distances, // [num_candidates] + uint32_t* candidate_indices, // [num_candidates] + const std::uint32_t num_candidates, + std::uint32_t* work_buf, + const bool first) +{ + topk_by_bitonic_sort_and_merge<256, false, uint32_t>(itopk_distances, + itopk_indices, + num_itopk, + candidate_distances, + candidate_indices, + num_candidates, + work_buf, + first); +} + +template RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge( float* itopk_distances, // [num_itopk] IdxT* itopk_indices, // [num_itopk] + const std::uint32_t max_itopk, const std::uint32_t num_itopk, float* candidate_distances, // [num_candidates] IdxT* candidate_indices, // [num_candidates] + const std::uint32_t max_candidates, const std::uint32_t num_candidates, std::uint32_t* work_buf, - const bool first, - const unsigned MULTI_WARPS_1, - const unsigned MULTI_WARPS_2) + const bool first) { - // The results in candidate_distances/indices are sorted by bitonic sort. - topk_by_bitonic_sort_and_full( - candidate_distances, candidate_indices, num_candidates, num_itopk, MULTI_WARPS_1); - - // The results sorted above are merged with the internal intermediate top-k - // results so far using bitonic merge. - topk_by_bitonic_sort_and_merge(itopk_distances, - itopk_indices, - num_itopk, - candidate_distances, - candidate_indices, - num_candidates, - work_buf, - first, - MULTI_WARPS_2); + static_assert(std::is_same_v); + assert(max_itopk <= 512); + assert(max_candidates <= 256); + assert(!MULTI_WARPS || blockDim.x >= 64); + + // use a non-template wrapper function to avoid pre-inlining the topk_by_bitonic_sort_and_full + // function (vs post-inlining, this impacts register pressure) + if (max_candidates <= 64) { + topk_by_bitonic_sort_and_full_wrapper_64_false( + candidate_distances, candidate_indices, num_candidates, num_itopk); + } else if (max_candidates <= 128) { + topk_by_bitonic_sort_and_full_wrapper_128_false( + candidate_distances, candidate_indices, num_candidates, num_itopk); + } else { + topk_by_bitonic_sort_and_full_wrapper_256_false( + candidate_distances, candidate_indices, num_candidates, num_itopk); + } + + if constexpr (!MULTI_WARPS) { + assert(max_itopk <= 256); + // use a non-template wrapper function to avoid pre-inlining the topk_by_bitonic_sort_and_merge + // function (vs post-inlining, this impacts register pressure) + if (max_itopk <= 64) { + topk_by_bitonic_sort_and_merge_wrapper_64_false(itopk_distances, + itopk_indices, + num_itopk, + candidate_distances, + candidate_indices, + num_candidates, + work_buf, + first); + } else if (max_itopk <= 128) { + topk_by_bitonic_sort_and_merge_wrapper_128_false(itopk_distances, + itopk_indices, + num_itopk, + candidate_distances, + candidate_indices, + num_candidates, + work_buf, + first); + } else { + topk_by_bitonic_sort_and_merge_wrapper_256_false(itopk_distances, + itopk_indices, + num_itopk, + candidate_distances, + candidate_indices, + num_candidates, + work_buf, + first); + } + } else { + assert(max_itopk > 256); + topk_by_bitonic_sort_and_merge<512, MULTI_WARPS, uint32_t>(itopk_distances, + itopk_indices, + num_itopk, + candidate_distances, + candidate_indices, + num_candidates, + work_buf, + first); + } } // This function move the invalid index element to the end of the itopk list. @@ -502,8 +642,6 @@ RAFT_DEVICE_INLINE_FUNCTION void hashmap_restore(INDEX_T* const hashmap_ptr, /** * @brief Search operation for a single query using a single thread block. * * - * @tparam MAX_ITOPK Maximum for the internal_topk argument. - * @tparam MAX_CANDIDATES * @tparam TOPK_BY_BITONIC_SORT * @tparam DATASET_DESCRIPTOR_T * @tparam SAMPLE_FILTER_T @@ -533,13 +671,12 @@ RAFT_DEVICE_INLINE_FUNCTION void hashmap_restore(INDEX_T* const hashmap_ptr, * @param small_hash_reset_interval Interval for resetting the small hash. * @param query_id sequential id of the query in the batch */ -template -__device__ void search_core( +RAFT_DEVICE_INLINE_FUNCTION void search_core( uintptr_t result_indices_ptr, // [num_queries, top_k] typename DATASET_DESCRIPTOR_T::DISTANCE_T* const result_distances_ptr, // [num_queries, top_k] const std::uint32_t top_k, @@ -554,6 +691,8 @@ __device__ void search_core( const uint32_t num_seeds, typename DATASET_DESCRIPTOR_T::INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + const std::uint32_t max_candidates, + const std::uint32_t max_itopk, const std::uint32_t internal_topk, const std::uint32_t search_width, const std::uint32_t min_iteration, @@ -665,10 +804,10 @@ __device__ void search_core( // batch size is small (short-latency), but it might not be always good // when batch size is large (high-throughput). // topk_by_bitonic_sort_and_merge() consists of two operations: - // if MAX_CANDIDATES is greater than 128, the first operation uses two warps; - // if MAX_ITOPK is greater than 256, the second operation used two warps. - const unsigned multi_warps_1 = ((blockDim.x >= 64) && (MAX_CANDIDATES > 128)) ? 1 : 0; - const unsigned multi_warps_2 = ((blockDim.x >= 64) && (MAX_ITOPK > 256)) ? 1 : 0; + // if max_candidates is greater than 128, the first operation uses two warps; + // if max_itopk is greater than 256, the second operation used two warps. + assert(blockDim.x >= 64); + const bool bitonic_sort_and_full_multi_warps = (max_candidates > 128) ? true : false; // reset small-hash table. if ((iter + 1) % small_hash_reset_interval == 0) { @@ -681,13 +820,13 @@ __device__ void search_core( if (blockDim.x == 32) { hash_start_tid = 0; } else if (blockDim.x == 64) { - if (multi_warps_1 || multi_warps_2) { + if (bitonic_sort_and_full_multi_warps || BITONIC_SORT_AND_MERGE_MULTI_WARPS) { hash_start_tid = 0; } else { hash_start_tid = 32; } } else { - if (multi_warps_1 || multi_warps_2) { + if (bitonic_sort_and_full_multi_warps || BITONIC_SORT_AND_MERGE_MULTI_WARPS) { hash_start_tid = 64; } else { hash_start_tid = 32; @@ -709,34 +848,33 @@ __device__ void search_core( if (threadIdx.x == 0) { *terminate_flag = 0; } } - topk_by_bitonic_sort_and_merge( + topk_by_bitonic_sort_and_merge( result_distances_buffer, result_indices_buffer, + max_itopk, internal_topk, result_distances_buffer + internal_topk, result_indices_buffer + internal_topk, + max_candidates, search_width * graph_degree, topk_ws, - (iter == 0), - multi_warps_1, - multi_warps_2); + (iter == 0)); __syncthreads(); _CLK_REC(clk_topk); } else { _CLK_START(); // topk with radix block sort - topk_by_radix_sort{}( - internal_topk, - gridDim.x, - result_buffer_size, - reinterpret_cast(result_distances_buffer), - result_indices_buffer, - reinterpret_cast(result_distances_buffer), - result_indices_buffer, - nullptr, - topk_ws, - true, - smem_work_ptr); + topk_by_radix_sort{}(max_itopk, + internal_topk, + result_buffer_size, + reinterpret_cast(result_distances_buffer), + result_indices_buffer, + reinterpret_cast(result_distances_buffer), + result_indices_buffer, + nullptr, + topk_ws, + true, + smem_work_ptr); _CLK_REC(clk_topk); // reset small-hash table @@ -876,19 +1014,17 @@ __device__ void search_core( // candidate list. if (top_k > internal_topk || result_indices_buffer[top_k - 1] == invalid_index) { __syncthreads(); - const unsigned multi_warps_1 = ((blockDim.x >= 64) && (MAX_CANDIDATES > 128)) ? 1 : 0; - const unsigned multi_warps_2 = ((blockDim.x >= 64) && (MAX_ITOPK > 256)) ? 1 : 0; - topk_by_bitonic_sort_and_merge( + topk_by_bitonic_sort_and_merge( result_distances_buffer, result_indices_buffer, + max_itopk, internal_topk, result_distances_buffer + internal_topk, result_indices_buffer + internal_topk, + max_candidates, search_width * graph_degree, topk_ws, - (iter == 0), - multi_warps_1, - multi_warps_2); + (iter == 0)); } __syncthreads(); } @@ -918,7 +1054,7 @@ __device__ void search_core( for (std::uint32_t i = threadIdx.x; i < top_k; i += blockDim.x) { unsigned j = i + (top_k * query_id); unsigned ii = i; - if (TOPK_BY_BITONIC_SORT) { ii = device::swizzling(i); } + if constexpr (TOPK_BY_BITONIC_SORT) { ii = device::swizzling(i); } if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[ii]; } constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; @@ -958,9 +1094,8 @@ __device__ void search_core( #endif } -template @@ -979,6 +1114,8 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( const uint32_t num_seeds, typename DATASET_DESCRIPTOR_T::INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + const std::uint32_t max_candidates, + const std::uint32_t max_itopk, const std::uint32_t internal_topk, const std::uint32_t search_width, const std::uint32_t min_iteration, @@ -990,9 +1127,8 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( SAMPLE_FILTER_T sample_filter) { const auto query_id = blockIdx.y; - search_core(result_indices_ptr, @@ -1008,6 +1144,8 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( seed_ptr, num_seeds, visited_hashmap_ptr, + max_candidates, + max_itopk, internal_topk, search_width, min_iteration, @@ -1079,9 +1217,8 @@ constexpr auto is_worker_busy(worker_handle_t::handle_t h) -> bool return (h != kWaitForWork) && (h != kNoMoreWork); } -template @@ -1099,6 +1236,8 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel_p( const uint32_t num_seeds, typename DATASET_DESCRIPTOR_T::INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + const std::uint32_t max_candidates, + const std::uint32_t max_itopk, const std::uint32_t internal_topk, const std::uint32_t search_width, const std::uint32_t min_iteration, @@ -1150,9 +1289,8 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel_p( auto query_id = worker_data.value.query_id; // work phase - search_core(result_indices_ptr, @@ -1168,6 +1306,8 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel_p( seed_ptr, num_seeds, visited_hashmap_ptr, + max_candidates, + max_itopk, internal_topk, search_width, min_iteration, @@ -1195,24 +1335,22 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel_p( } template auto dispatch_kernel = []() { + static_assert(TOPK_BY_BITONIC_SORT || !BITONIC_SORT_AND_MERGE_MULTI_WARPS); if constexpr (Persistent) { - return search_kernel_p; } else { - return search_kernel; @@ -1225,87 +1363,43 @@ template struct search_kernel_config { using kernel_t = decltype(dispatch_kernel); - template - static auto choose_search_kernel(unsigned itopk_size) -> kernel_t - { - if (itopk_size <= 64) { - return dispatch_kernel; - } else if (itopk_size <= 128) { - return dispatch_kernel; - } else if (itopk_size <= 256) { - return dispatch_kernel; - } else if (itopk_size <= 512) { - return dispatch_kernel; - } - THROW("No kernel for parametels itopk_size %u, max_candidates %u", itopk_size, MAX_CANDIDATES); - } - static auto choose_itopk_and_mx_candidates(unsigned itopk_size, unsigned num_itopk_candidates, unsigned block_size) -> kernel_t { - if (num_itopk_candidates <= 64) { - // use bitonic sort based topk - return choose_search_kernel<64, 1>(itopk_size); - } else if (num_itopk_candidates <= 128) { - return choose_search_kernel<128, 1>(itopk_size); - } else if (num_itopk_candidates <= 256) { - return choose_search_kernel<256, 1>(itopk_size); - } else { - // Radix-based topk is used - constexpr unsigned max_candidates = 32; // to avoid build failure + assert(itopk_size <= 512); + if (num_itopk_candidates <= 256) { if (itopk_size <= 256) { return dispatch_kernel; - } else if (itopk_size <= 512) { + } else { + assert(block_size >= 64); return dispatch_kernel; } + } else { + // Radix-based topk is used + return dispatch_kernel; } - THROW("No kernel for parametels itopk_size %u, num_itopk_candidates %u", - itopk_size, - num_itopk_candidates); } }; @@ -1803,6 +1897,7 @@ struct alignas(kCacheLineBytes) persistent_runner_t : public persistent_runner_b std::reference_wrapper> dataset_desc, raft::device_matrix_view graph, const SourceIndexT* source_indices_ptr, + uint32_t max_candidates, uint32_t num_itopk_candidates, uint32_t block_size, // uint32_t smem_size, @@ -1812,6 +1907,7 @@ struct alignas(kCacheLineBytes) persistent_runner_t : public persistent_runner_b uint32_t num_random_samplings, uint64_t rand_xor_mask, uint32_t num_seeds, + uint32_t max_itopk, size_t itopk_size, size_t search_width, size_t min_iterations, @@ -1831,6 +1927,7 @@ struct alignas(kCacheLineBytes) persistent_runner_t : public persistent_runner_b std::reference_wrapper> dataset_desc, raft::device_matrix_view graph, const SourceIndexT* source_indices_ptr, + uint32_t max_candidates, uint32_t num_itopk_candidates, uint32_t block_size, // uint32_t smem_size, @@ -1840,6 +1937,7 @@ struct alignas(kCacheLineBytes) persistent_runner_t : public persistent_runner_b uint32_t num_random_samplings, uint64_t rand_xor_mask, uint32_t num_seeds, + uint32_t max_itopk, size_t itopk_size, size_t search_width, size_t min_iterations, @@ -1859,6 +1957,7 @@ struct alignas(kCacheLineBytes) persistent_runner_t : public persistent_runner_b param_hash(calculate_parameter_hash(dd_host, graph, source_indices_ptr, + max_candidates, num_itopk_candidates, block_size, smem_size, @@ -1868,6 +1967,7 @@ struct alignas(kCacheLineBytes) persistent_runner_t : public persistent_runner_b num_random_samplings, rand_xor_mask, num_seeds, + max_itopk, itopk_size, search_width, min_iterations, @@ -1938,6 +2038,8 @@ struct alignas(kCacheLineBytes) persistent_runner_t : public persistent_runner_b &dev_seed_ptr, &num_seeds, &hashmap_ptr, // visited_hashmap_ptr: [num_queries, 1 << hash_bitlen] + &max_candidates, + &max_itopk, &itopk_size, &search_width, &min_iterations, @@ -2139,6 +2241,38 @@ void select_and_run( const SourceIndexT* source_indices_ptr = source_indices.has_value() ? source_indices->data_handle() : nullptr; + uint32_t max_candidates{}; + if (num_itopk_candidates <= 64) { + max_candidates = 64; + } else if (num_itopk_candidates <= 128) { + max_candidates = 128; + } else if (num_itopk_candidates <= 256) { + max_candidates = 256; + } else { + max_candidates = + 32; // irrelevant, radix based topk is used (see choose_itopk_and_max_candidates) + } + + uint32_t max_itopk{}; + assert(ps.itopk_size <= 512); + if (num_itopk_candidates <= 256) { // bitonic sort + if (ps.itopk_size <= 64) { + max_itopk = 64; + } else if (ps.itopk_size <= 128) { + max_itopk = 128; + } else if (ps.itopk_size <= 256) { + max_itopk = 256; + } else { + max_itopk = 512; + } + } else { // radix sort + if (ps.itopk_size <= 256) { + max_itopk = 256; + } else { + max_itopk = 512; + } + } + if (ps.persistent) { using runner_type = persistent_runner_t; @@ -2159,6 +2293,8 @@ control is returned in this thread (in persistent_runner_t constructor), so we'r ps.num_random_samplings, ps.rand_xor_mask, num_seeds, + max_candidates, + max_itopk, ps.itopk_size, ps.search_width, ps.min_iterations, @@ -2190,6 +2326,8 @@ control is returned in this thread (in persistent_runner_t constructor), so we'r dev_seed_ptr, num_seeds, hashmap_ptr, + max_candidates, + max_itopk, ps.itopk_size, ps.search_width, ps.min_iterations, diff --git a/cpp/src/neighbors/detail/cagra/topk_by_radix.cuh b/cpp/src/neighbors/detail/cagra/topk_by_radix.cuh index 68aab54053..7b0e3dbcad 100644 --- a/cpp/src/neighbors/detail/cagra/topk_by_radix.cuh +++ b/cpp/src/neighbors/detail/cagra/topk_by_radix.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -9,20 +9,53 @@ namespace cuvs::neighbors::cagra::detail { namespace single_cta_search { -template struct topk_by_radix_sort_base { - static constexpr std::uint32_t smem_size = MAX_INTERNAL_TOPK * 2 + 2048 + 8; - static constexpr std::uint32_t state_bit_lenght = 0; + static constexpr std::uint32_t state_bit_length = 0; static constexpr std::uint32_t vecLen = 2; // TODO + + static constexpr uint32_t smem_size(uint32_t max_itopk) { return max_itopk * 2 + 2048 + 8; } }; -template -struct topk_by_radix_sort : topk_by_radix_sort_base {}; -template -struct topk_by_radix_sort> - : topk_by_radix_sort_base { - RAFT_DEVICE_INLINE_FUNCTION void operator()(uint32_t topk, - uint32_t batch_size, +RAFT_DEVICE_INLINE_FUNCTION void topk_cta_11_core_wrapper_256(uint32_t topk, + uint32_t len_x, + const uint32_t* _x, + const uint32_t* _in_vals, + uint32_t* _y, + uint32_t* _out_vals, + uint8_t* state, + uint32_t* _hints, + bool sort, + uint32_t* _smem) +{ + topk_cta_11_core(topk, len_x, _x, _in_vals, _y, _out_vals, state, _hints, sort, _smem); +} + +RAFT_DEVICE_INLINE_FUNCTION void topk_cta_11_core_wrapper_512(uint32_t topk, + uint32_t len_x, + const uint32_t* _x, + const uint32_t* _in_vals, + uint32_t* _y, + uint32_t* _out_vals, + uint8_t* state, + uint32_t* _hints, + bool sort, + uint32_t* _smem) +{ + topk_cta_11_core(topk, len_x, _x, _in_vals, _y, _out_vals, state, _hints, sort, _smem); +} + +template +struct topk_by_radix_sort : topk_by_radix_sort_base { + RAFT_DEVICE_INLINE_FUNCTION void operator()(uint32_t max_itopk, + uint32_t topk, uint32_t len_x, const uint32_t* _x, const IdxT* _in_vals, @@ -33,48 +66,40 @@ struct topk_by_radix_sort(work); - topk_cta_11_core::state_bit_lenght, - topk_by_radix_sort_base::vecLen, - 64, - 32, - IdxT>(topk, len_x, _x, _in_vals, _y, _out_vals, state, _hints, sort, _smem); + if constexpr (std::is_same_v) { // use a non-template wrapper function to avoid + // pre-inlining the topk_cta_11_core function (vs + // post-inlining, this impacts register pressure) + std::uint8_t* const state = reinterpret_cast(work); + if (max_itopk <= 256) { + topk_cta_11_core_wrapper_256( + topk, len_x, _x, _in_vals, _y, _out_vals, state, _hints, sort, _smem); + } else { + assert(max_itopk <= 512); + topk_cta_11_core_wrapper_512( + topk, len_x, _x, _in_vals, _y, _out_vals, state, _hints, sort, _smem); + } + } else { // currently, unused + std::uint8_t* const state = reinterpret_cast(work); + if (max_itopk <= 256) { + topk_cta_11_core( + topk, len_x, _x, _in_vals, _y, _out_vals, state, _hints, sort, _smem); + } else { + assert(max_itopk <= 512); + topk_cta_11_core( + topk, len_x, _x, _in_vals, _y, _out_vals, state, _hints, sort, _smem); + } + } } }; -#define TOP_FUNC_PARTIAL_SPECIALIZATION(V) \ - template \ - struct topk_by_radix_sort< \ - MAX_INTERNAL_TOPK, \ - IdxT, \ - std::enable_if_t<((MAX_INTERNAL_TOPK <= V) && (2 * MAX_INTERNAL_TOPK > V))>> \ - : topk_by_radix_sort_base { \ - RAFT_DEVICE_INLINE_FUNCTION void operator()(uint32_t topk, \ - uint32_t batch_size, \ - uint32_t len_x, \ - const uint32_t* _x, \ - const IdxT* _in_vals, \ - uint32_t* _y, \ - IdxT* _out_vals, \ - uint32_t* work, \ - uint32_t* _hints, \ - bool sort, \ - uint32_t* _smem) \ - { \ - assert(blockDim.x >= V / 4); \ - std::uint8_t* state = (std::uint8_t*)work; \ - topk_cta_11_core::state_bit_lenght, \ - topk_by_radix_sort_base::vecLen, \ - V, \ - V / 4, \ - IdxT>( \ - topk, len_x, _x, _in_vals, _y, _out_vals, state, _hints, sort, _smem); \ - } \ - }; -TOP_FUNC_PARTIAL_SPECIALIZATION(128); -TOP_FUNC_PARTIAL_SPECIALIZATION(256); -TOP_FUNC_PARTIAL_SPECIALIZATION(512); -TOP_FUNC_PARTIAL_SPECIALIZATION(1024); - } // namespace single_cta_search } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh b/cpp/src/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh index 7eb6a8a1a0..97ea7dc236 100644 --- a/cpp/src/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh +++ b/cpp/src/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -749,7 +749,7 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_cta_11_core(uint32_t topk, // Sorting by thread if (thread_id < numSortThreads) { const bool ascending = ((thread_id & mask) == 0); - if (numTopkPerThread == 3) { + if constexpr (numTopkPerThread == 3) { swap_if_needed(my_keys[0], my_keys[1], my_vals[0], my_vals[1], ascending); swap_if_needed(my_keys[0], my_keys[2], my_vals[0], my_vals[2], ascending); swap_if_needed(my_keys[1], my_keys[2], my_vals[1], my_vals[2], ascending); @@ -812,7 +812,7 @@ RAFT_DEVICE_INLINE_FUNCTION void topk_cta_11_core(uint32_t topk, if (thread_id < numSortThreads) { const bool ascending = ((thread_id & next_mask) == 0); - if (numTopkPerThread == 3) { + if constexpr (numTopkPerThread == 3) { swap_if_needed(my_keys[0], my_keys[1], my_vals[0], my_vals[1], ascending); swap_if_needed(my_keys[0], my_keys[2], my_vals[0], my_vals[2], ascending); swap_if_needed(my_keys[1], my_keys[2], my_vals[1], my_vals[2], ascending);