diff --git a/c/include/cuvs/neighbors/nn_descent.h b/c/include/cuvs/neighbors/nn_descent.h
index 6a1e67d92c..0c7102e3e1 100644
--- a/c/include/cuvs/neighbors/nn_descent.h
+++ b/c/include/cuvs/neighbors/nn_descent.h
@@ -15,6 +15,18 @@
 extern "C" {
 #endif
 
+/**
+ * @brief Dtype to use for distance computation
+ * - `NND_DIST_COMP_AUTO`: Automatically determine the best dtype for distance computation based on the dataset dimensions.
+ * - `NND_DIST_COMP_FP32`: Use fp32 distance computation for better precision at the cost of performance and memory usage.
+ * - `NND_DIST_COMP_FP16`: Use fp16 distance computation.
+ */
+typedef enum {
+  NND_DIST_COMP_AUTO = 0,
+  NND_DIST_COMP_FP32 = 1,
+  NND_DIST_COMP_FP16 = 2
+} cuvsNNDescentDistCompDtype;
+
 /**
  * @defgroup nn_descent_c_index_params The nn-descent algorithm parameters.
  * @{
@@ -34,6 +46,8 @@ extern "C" {
 * `max_iterations`: The number of iterations that nn-descent will refine
 *    the graph for. More iterations produce a better quality graph at cost of performance
 * `termination_threshold`: The delta at which nn-descent will terminate its iterations
+ * `return_distances`: Boolean to decide whether to return distances array
+ * `dist_comp_dtype`: dtype to use for distance computation. Defaults to `NND_DIST_COMP_AUTO`, which automatically determines the best dtype for distance computation based on the dataset dimensions. Use `NND_DIST_COMP_FP32` for better precision at the cost of performance and memory usage; this option is only valid when the data type is fp32. Use `NND_DIST_COMP_FP16` for better performance and memory usage at the cost of precision.
 */
 struct cuvsNNDescentIndexParams {
   cuvsDistanceType metric;
@@ -43,6 +57,7 @@ struct cuvsNNDescentIndexParams {
   size_t max_iterations;
   float termination_threshold;
   bool return_distances;
+  cuvsNNDescentDistCompDtype dist_comp_dtype;
 };
 
 typedef struct cuvsNNDescentIndexParams* cuvsNNDescentIndexParams_t;
diff --git a/c/src/neighbors/nn_descent.cpp b/c/src/neighbors/nn_descent.cpp
index 62079bb274..708056144a 100644
--- a/c/src/neighbors/nn_descent.cpp
+++ b/c/src/neighbors/nn_descent.cpp
@@ -43,6 +43,7 @@ void* _build(cuvsResources_t res,
   build_params.max_iterations        = params.max_iterations;
   build_params.termination_threshold = params.termination_threshold;
   build_params.return_distances      = params.return_distances;
+  build_params.dist_comp_dtype = static_cast<cuvs::neighbors::nn_descent::DIST_COMP_DTYPE>(static_cast<int>(params.dist_comp_dtype));
 
   using graph_type = raft::host_matrix_view<uint32_t, int64_t, raft::row_major>;
   std::optional<graph_type> graph;
@@ -177,7 +178,8 @@ extern "C" cuvsError_t cuvsNNDescentIndexParamsCreate(cuvsNNDescentIndexParams_t
     .intermediate_graph_degree = cpp_params.intermediate_graph_degree,
     .max_iterations            = cpp_params.max_iterations,
     .termination_threshold     = cpp_params.termination_threshold,
-    .return_distances          = cpp_params.return_distances};
+    .return_distances          = cpp_params.return_distances,
+    .dist_comp_dtype = static_cast<cuvsNNDescentDistCompDtype>(static_cast<int>(cpp_params.dist_comp_dtype))};
   });
 }
diff --git a/cpp/include/cuvs/neighbors/nn_descent.hpp b/cpp/include/cuvs/neighbors/nn_descent.hpp
index f04dda90e3..c2f6121303 100644
--- a/cpp/include/cuvs/neighbors/nn_descent.hpp
+++ b/cpp/include/cuvs/neighbors/nn_descent.hpp
@@ -24,6 +24,16 @@ namespace cuvs::neighbors::nn_descent {
  * @{
  */
 
+/**
+ * @brief Dtype to use for distance computation
+ * - `AUTO`: Automatically determine the best dtype for distance computation based on the dataset
+ *   dimensions.
+ * - `FP32`: Use fp32 distance computation for better precision at the cost of performance and
+ *   memory usage.
+ * - `FP16`: Use fp16 distance computation.
+ */
+enum class DIST_COMP_DTYPE { AUTO = 0, FP32 = 1, FP16 = 2 };
+
 /**
  * @brief Parameters used to build an nn-descent index
  * - `graph_degree`: For an input dataset of dimensions (N, D),
@@ -37,6 +47,11 @@ namespace cuvs::neighbors::nn_descent {
  *   the graph for. More iterations produce a better quality graph at cost of performance
  * - `termination_threshold`: The delta at which nn-descent will terminate its iterations
  * - `return_distances`: Boolean to decide whether to return distances array
+ * - `dist_comp_dtype`: dtype to use for distance computation. Defaults to `AUTO`, which
+ *   automatically determines the best dtype for distance computation based on the dataset
+ *   dimensions. Use `FP32` for better precision at the cost of performance and memory usage;
+ *   this option is only valid when the data type is fp32. Use `FP16` for better performance and
+ *   memory usage at the cost of precision.
  */
 struct index_params : cuvs::neighbors::index_params {
   size_t graph_degree = 64;
@@ -44,6 +59,7 @@ struct index_params : cuvs::neighbors::index_params {
   size_t max_iterations       = 20;
   float termination_threshold = 0.0001;
   bool return_distances       = true;
+  DIST_COMP_DTYPE dist_comp_dtype = DIST_COMP_DTYPE::AUTO;
 
   /** @brief Construct NN descent parameters for a specific kNN graph degree
    *
diff --git a/cpp/src/neighbors/detail/nn_descent.cuh b/cpp/src/neighbors/detail/nn_descent.cuh
index 184cbc72cd..24ccdd5027 100644
--- a/cpp/src/neighbors/detail/nn_descent.cuh
+++ b/cpp/src/neighbors/detail/nn_descent.cuh
@@ -168,9 +168,6 @@ __device__ __forceinline__ void warp_bitonic_sort(T* element_ptr, const int lane_id)
   return;
 }
 
-constexpr int TILE_ROW_WIDTH = 64;
-constexpr int TILE_COL_WIDTH = 128;
-
 constexpr int NUM_SAMPLES = 32;
 // For now, the max. number of samples is 32, so the sample cache size is fixed
 // to 64 (32 * 2).
@@ -228,11 +225,11 @@ __device__ __forceinline__ void load_vec(Data_t* vec_buffer,
 }
 
 // TODO: Replace with RAFT utilities https://github.com/rapidsai/raft/issues/1827
-/** Calculate L2 norm, and cast data to __half */
-template <typename Data_t>
+/** Calculate L2 norm, and cast data to Output_t */
+template <typename Data_t, typename Output_t>
 RAFT_KERNEL preprocess_data_kernel(
   const Data_t* input_data,
-  __half* output_data,
+  Output_t* output_data,
   int dim,
   DistData_t* l2_norms,
   size_t list_offset = 0,
@@ -302,7 +299,12 @@ RAFT_KERNEL add_rev_edges_kernel(const Index_t* graph,
 
   for (int idx = threadIdx.x; idx < list_size.x; idx += blockDim.x) {
     // each node has same number (num_samples) of forward and reverse edges
-    size_t rev_list_id = graph[list_id * num_samples + idx];
+    Index_t rev_list_id = graph[list_id * num_samples + idx];
+    if (rev_list_id == std::numeric_limits<Index_t>::max()) {
+      // sentinel value
+      continue;
+    }
+
     // there are already num_samples forward edges
     int idx_in_rev_list = atomicAdd(&list_sizes[rev_list_id].y, 1);
     if (idx_in_rev_list >= num_samples) {
@@ -480,6 +482,326 @@ __device__ __forceinline__ void remove_duplicates(
   }
 }
 
+template <typename Index_t, typename Data_t, typename DistEpilogue_t>
+__device__ __forceinline__ void calculate_metric(float* s_distances,
+                                                 Index_t* row_neighbors,
+                                                 int list_row_size,
+                                                 Index_t* col_neighbors,
+                                                 int list_col_size,
+                                                 const Data_t* data,
+                                                 const int data_dim,
+                                                 DistData_t* l2_norms,
+                                                 cuvs::distance::DistanceType metric,
+                                                 DistEpilogue_t dist_epilogue)
+{
+  // if we have a distance epilogue, distances need to be fully calculated instead of
+  // postprocessing them.
+  bool can_postprocess_dist = std::is_same_v<DistEpilogue_t, raft::identity_op>;
+
+  for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x) {
+    int row_id = i / SKEWED_MAX_NUM_BI_SAMPLES;
+    int col_id = i % SKEWED_MAX_NUM_BI_SAMPLES;
+
+    if (row_id < list_row_size && col_id < list_col_size) {
+      if (metric == cuvs::distance::DistanceType::InnerProduct && can_postprocess_dist) {
+        s_distances[i] = -s_distances[i];
+      } else if (metric == cuvs::distance::DistanceType::CosineExpanded) {
+        s_distances[i] = 1.0 - s_distances[i];
+      } else if (metric == cuvs::distance::DistanceType::BitwiseHamming) {
+        s_distances[i] = 0.0;
+        int n1 = row_neighbors[row_id];
+        int n2 = col_neighbors[col_id];
+        // TODO: https://github.com/rapidsai/cuvs/issues/1127
+        const uint8_t* data_n1 = reinterpret_cast<const uint8_t*>(data) + n1 * data_dim;
+        const uint8_t* data_n2 = reinterpret_cast<const uint8_t*>(data) + n2 * data_dim;
+        for (int d = 0; d < data_dim; d++) {
+          s_distances[i] += __popc(static_cast<uint32_t>(data_n1[d] ^ data_n2[d]) & 0xff);
+        }
+      } else {  // L2Expanded or L2SqrtExpanded
+        s_distances[i] =
+          l2_norms[row_neighbors[row_id]] + l2_norms[col_neighbors[col_id]] - 2.0 * s_distances[i];
+        // for fp32 vs fp16 precision differences resulting in negative distances when the
+        // distance should be 0; related issue: https://github.com/rapidsai/cuvs/issues/991
+        s_distances[i] = s_distances[i] < 0.0f ? 0.0f : s_distances[i];
+        if (!can_postprocess_dist && metric == cuvs::distance::DistanceType::L2SqrtExpanded) {
+          s_distances[i] = sqrtf(s_distances[i]);
+        }
+      }
+      s_distances[i] = dist_epilogue(s_distances[i], row_neighbors[row_id], col_neighbors[col_id]);
+    } else {
+      s_distances[i] = std::numeric_limits<float>::max();
+    }
+  }
+}
+
+// launch_bounds here denote BLOCK_SIZE = 512 and MIN_BLOCKS_PER_SM = 4
+// Per
+// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications,
+// MAX_RESIDENT_THREAD_PER_SM = BLOCK_SIZE * BLOCKS_PER_SM = 2048
+// For architectures 750 and 860 (890), the values for MAX_RESIDENT_THREAD_PER_SM
+// are 1024 and 1536 respectively, which means the bounds don't work anymore
+template <typename Index_t, typename ID_t = InternalID_t<Index_t>, typename DistEpilogue_t>
+RAFT_KERNEL
+#ifdef __CUDA_ARCH__
+// Use minBlocksPerMultiprocessor = 4 on specific arches
+#if (__CUDA_ARCH__) == 700 || (__CUDA_ARCH__) == 800 || (__CUDA_ARCH__) == 900 || \
+  (__CUDA_ARCH__) == 1000
+__launch_bounds__(BLOCK_SIZE, 4)
+#else
+__launch_bounds__(BLOCK_SIZE)
+#endif
+#endif
+  local_join_kernel(const Index_t* graph_new,
+                    const Index_t* rev_graph_new,
+                    const int2* sizes_new,
+                    const Index_t* graph_old,
+                    const Index_t* rev_graph_old,
+                    const int2* sizes_old,
+                    const int width,
+                    const float* data,
+                    const int data_dim,
+                    ID_t* graph,
+                    DistData_t* dists,
+                    int graph_width,
+                    int* locks,
+                    DistData_t* l2_norms,
+                    cuvs::distance::DistanceType metric,
+                    DistEpilogue_t dist_epilogue)
+{
+#if (__CUDA_ARCH__ >= 700)
+  using namespace nvcuda;
+  __shared__ int s_list[MAX_NUM_BI_SAMPLES * 2];
+
+  constexpr int APAD           = 4;
+  constexpr int BPAD           = 4;
+  constexpr int TILE_COL_WIDTH = 32;
+  __shared__ float s_nv[MAX_NUM_BI_SAMPLES][TILE_COL_WIDTH + APAD];
+  __shared__ float s_ov[MAX_NUM_BI_SAMPLES][TILE_COL_WIDTH + BPAD];
+  __shared__ float s_distances[MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES];
+
+  // s_distances: MAX_NUM_BI_SAMPLES x SKEWED_MAX_NUM_BI_SAMPLES; reuse the space of s_ov
+  // for the unique counters
+  int* s_unique_counter = (int*)&s_ov[0][0];
+
+  if (threadIdx.x == 0) {
+    s_unique_counter[0] = 0;
+    s_unique_counter[1] = 0;
+  }
+
+  Index_t* new_neighbors = s_list;
+  Index_t* old_neighbors = s_list + MAX_NUM_BI_SAMPLES;
+
+  size_t list_id      = blockIdx.x;
+  int2 list_new_size2 = sizes_new[list_id];
+  int list_new_size   = list_new_size2.x + list_new_size2.y;
+  int2 list_old_size2 = sizes_old[list_id];
+  int list_old_size   = list_old_size2.x + list_old_size2.y;
+
+  if (!list_new_size) return;
+  int tx = threadIdx.x;
+
+  if (tx < list_new_size2.x) {
+    new_neighbors[tx] = graph_new[list_id * width + tx];
+  } else if (tx >= list_new_size2.x && tx < list_new_size) {
+    new_neighbors[tx] = rev_graph_new[list_id * width + tx - list_new_size2.x];
+  }
+
+  if (tx < list_old_size2.x) {
+    old_neighbors[tx] = graph_old[list_id * width + tx];
+  } else if (tx >= list_old_size2.x && tx < list_old_size) {
+    old_neighbors[tx] = rev_graph_old[list_id * width + tx - list_old_size2.x];
+  }
+
+  __syncthreads();
+
+  remove_duplicates(new_neighbors,
+                    list_new_size2.x,
+                    new_neighbors + list_new_size2.x,
+                    list_new_size2.y,
+                    s_unique_counter[0],
+                    0);
+
+  remove_duplicates(old_neighbors,
+                    list_old_size2.x,
+                    old_neighbors + list_old_size2.x,
+                    list_old_size2.y,
+                    s_unique_counter[1],
+                    1);
+  __syncthreads();
+  list_new_size = list_new_size2.x + s_unique_counter[0];
+  list_old_size = list_old_size2.x + s_unique_counter[1];
+
+  int warp_id             = threadIdx.x / raft::warp_size();
+  int lane_id             = threadIdx.x % raft::warp_size();
+  constexpr int num_warps = BLOCK_SIZE / raft::warp_size();
+
+  if (metric != cuvs::distance::DistanceType::BitwiseHamming) {
+    int tid = threadIdx.x;
+    for (int i = tid; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x)
+      s_distances[i] = 0.0f;
+
+    __syncthreads();
+
+    for (int step = 0; step < raft::ceildiv(data_dim, TILE_COL_WIDTH); step++) {
+      int num_load_elems = (step == raft::ceildiv(data_dim, TILE_COL_WIDTH) - 1)
+                             ? data_dim - step * TILE_COL_WIDTH
+                             : TILE_COL_WIDTH;
+#pragma unroll
+      for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) {
+        int idx = i * num_warps + warp_id;
+        if (idx < list_new_size) {
+          size_t neighbor_id = new_neighbors[idx];
+          size_t idx_in_data = neighbor_id * data_dim;
+          load_vec(s_nv[idx],
+                   data + idx_in_data + step * TILE_COL_WIDTH,
+                   num_load_elems,
+                   TILE_COL_WIDTH,
+                   lane_id);
+        }
+      }
+      __syncthreads();
+
+      // this is much faster than a warp-collaborative multiplication because MAX_NUM_BI_SAMPLES
+      // is fixed and small (64)
+      for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES;
+           i += blockDim.x) {
+        int tmp_row = i / SKEWED_MAX_NUM_BI_SAMPLES;
+        int tmp_col = i % SKEWED_MAX_NUM_BI_SAMPLES;
+        if (tmp_row < list_new_size && tmp_col < list_new_size) {
+          float acc = 0.0f;
+          for (int d = 0; d < num_load_elems; d++) {
+            acc += s_nv[tmp_row][d] * s_nv[tmp_col][d];
+          }
+          s_distances[i] += acc;
+        }
+      }
+      __syncthreads();
+    }
+  }
+  __syncthreads();
+
+  calculate_metric(s_distances,
+                   new_neighbors,
+                   list_new_size,
+                   new_neighbors,
+                   list_new_size,
+                   data,
+                   data_dim,
+                   l2_norms,
+                   metric,
+                   dist_epilogue);
+
+  __syncthreads();
+
+  for (int step = 0; step < raft::ceildiv(list_new_size, num_warps); step++) {
+    int idx_in_list = step * num_warps + tx / raft::warp_size();
+    if (idx_in_list >= list_new_size) continue;
+    auto min_elem = get_min_item(s_list[idx_in_list], idx_in_list, new_neighbors, s_distances);
+    if (min_elem.id() < gridDim.x) {
+      insert_to_global_graph(min_elem, s_list[idx_in_list], graph, dists, graph_width, locks);
+    }
+  }
+
+  if (!list_old_size) return;
+
+  __syncthreads();
+
+  if (metric != cuvs::distance::DistanceType::BitwiseHamming) {
+    int tid = threadIdx.x;
+    for (int i = tid; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x)
+      s_distances[i] = 0.0f;
+
+    __syncthreads();
+
+    for (int step = 0; step < raft::ceildiv(data_dim, TILE_COL_WIDTH); step++) {
+      int num_load_elems = (step == raft::ceildiv(data_dim, TILE_COL_WIDTH) - 1)
+                             ? data_dim - step * TILE_COL_WIDTH
+                             : TILE_COL_WIDTH;
+      if (TILE_COL_WIDTH < data_dim) {
+#pragma unroll
+        for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) {
+          int idx = i * num_warps + warp_id;
+          if (idx < list_new_size) {
+            size_t neighbor_id = new_neighbors[idx];
+            size_t idx_in_data = neighbor_id * data_dim;
+            load_vec(s_nv[idx],
+                     data + idx_in_data + step * TILE_COL_WIDTH,
+                     num_load_elems,
+                     TILE_COL_WIDTH,
+                     lane_id);
+          }
+        }
+      }
+#pragma unroll
+      for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) {
+        int idx = i * num_warps + warp_id;
+        if (idx < list_old_size) {
+          size_t neighbor_id = old_neighbors[idx];
+          size_t idx_in_data = neighbor_id * data_dim;
+          load_vec(s_ov[idx],
+                   data + idx_in_data + step * TILE_COL_WIDTH,
+                   num_load_elems,
+                   TILE_COL_WIDTH,
+                   lane_id);
+        }
+      }
+      __syncthreads();
+
+      // this is much faster than a warp-collaborative multiplication because MAX_NUM_BI_SAMPLES
+      // is fixed and small (64)
+      for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES;
+           i += blockDim.x) {
+        int tmp_row = i / SKEWED_MAX_NUM_BI_SAMPLES;
+        int tmp_col = i % SKEWED_MAX_NUM_BI_SAMPLES;
+        if (tmp_row < list_new_size && tmp_col < list_old_size) {
+          float acc = 0.0f;
+          for (int d = 0; d < num_load_elems; d++) {
+            acc += s_nv[tmp_row][d] * s_ov[tmp_col][d];
+          }
+          s_distances[i] += acc;
+        }
+      }
+      __syncthreads();
+    }
+  }
+  __syncthreads();
+
+  calculate_metric(s_distances,
+                   new_neighbors,
+                   list_new_size,
+                   old_neighbors,
+                   list_old_size,
+                   data,
+                   data_dim,
+                   l2_norms,
+                   metric,
+                   dist_epilogue);
+
+  __syncthreads();
+
+  for (int step = 0; step < raft::ceildiv(MAX_NUM_BI_SAMPLES * 2, num_warps); step++) {
+    int idx_in_list = step * num_warps + tx / raft::warp_size();
+    if (idx_in_list >= list_new_size && idx_in_list < MAX_NUM_BI_SAMPLES) continue;
+    if (idx_in_list >= MAX_NUM_BI_SAMPLES + list_old_size && idx_in_list < MAX_NUM_BI_SAMPLES * 2)
+      continue;
+    ResultItem<Index_t> min_elem{std::numeric_limits<Index_t>::max(),
+                                 std::numeric_limits<DistData_t>::max()};
+    if (idx_in_list < MAX_NUM_BI_SAMPLES) {
+      auto temp_min_item =
+        get_min_item(s_list[idx_in_list], idx_in_list, old_neighbors, s_distances);
+      if (temp_min_item.dist() < min_elem.dist()) { min_elem = temp_min_item; }
+    } else {
+      auto temp_min_item = get_min_item(
+        s_list[idx_in_list], idx_in_list - MAX_NUM_BI_SAMPLES, new_neighbors, s_distances, false);
+      if (temp_min_item.dist() < min_elem.dist()) { min_elem = temp_min_item; }
+    }
+
+    if (min_elem.id() < gridDim.x) {
+      insert_to_global_graph(min_elem, s_list[idx_in_list], graph, dists, graph_width, locks);
+    }
+  }
+#endif
+}
+
 // launch_bounds here denote BLOCK_SIZE = 512 and MIN_BLOCKS_PER_SM = 4
 // Per
 // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications,
 // MAX_RESIDENT_THREAD_PER_SM = BLOCK_SIZE * BLOCKS_PER_SM = 2048
@@ -518,8 +840,9 @@ __launch_bounds__(BLOCK_SIZE)
   using namespace nvcuda;
   __shared__ int s_list[MAX_NUM_BI_SAMPLES * 2];
 
-  constexpr int APAD = 8;
-  constexpr int BPAD = 8;
+  constexpr int APAD           = 8;
+  constexpr int BPAD           = 8;
+  constexpr int TILE_COL_WIDTH = 128;
   __shared__ __half s_nv[MAX_NUM_BI_SAMPLES][TILE_COL_WIDTH + APAD];  // New vectors
   __shared__ __half s_ov[MAX_NUM_BI_SAMPLES][TILE_COL_WIDTH + BPAD];  // Old vectors
   static_assert(sizeof(float) * MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES <=
@@ -559,10 +882,6 @@ __launch_bounds__(BLOCK_SIZE)
 
   __syncthreads();
 
-  // if we have a distance epilogue, distances need to be fully calculated instead of postprocessing
-  // them.
-  bool can_postprocess_dist = std::is_same_v<DistEpilogue_t, raft::identity_op>;
-
   remove_duplicates(new_neighbors,
                     list_new_size2.x,
                     new_neighbors + list_new_size2.x,
@@ -630,40 +949,16 @@ __launch_bounds__(BLOCK_SIZE)
   }
   __syncthreads();
 
-  for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x) {
-    int row_id = i % SKEWED_MAX_NUM_BI_SAMPLES;
-    int col_id = i / SKEWED_MAX_NUM_BI_SAMPLES;
-
-    if (row_id < list_new_size && col_id < list_new_size) {
-      if (metric == cuvs::distance::DistanceType::InnerProduct && can_postprocess_dist) {
-        s_distances[i] = -s_distances[i];
-      } else if (metric == cuvs::distance::DistanceType::CosineExpanded) {
-        s_distances[i] = 1.0 - s_distances[i];
-      } else if (metric == cuvs::distance::DistanceType::BitwiseHamming) {
-        s_distances[i] = 0.0;
-        int n1 = new_neighbors[row_id];
-        int n2 = new_neighbors[col_id];
-        // TODO: https://github.com/rapidsai/cuvs/issues/1127
-        const uint8_t* data_n1 = reinterpret_cast<const uint8_t*>(data) + n1 * data_dim;
-        const uint8_t* data_n2 = reinterpret_cast<const uint8_t*>(data) + n2 * data_dim;
-        for (int d = 0; d < data_dim; d++) {
-          s_distances[i] += __popc(static_cast<uint32_t>(data_n1[d] ^ data_n2[d]) & 0xff);
-        }
-      } else {  // L2Expanded or L2SqrtExpanded
-        s_distances[i] =
-          l2_norms[new_neighbors[row_id]] + l2_norms[new_neighbors[col_id]] - 2.0 * s_distances[i];
-        // for fp32 vs fp16 precision differences resulting in negative distances when distance
-        // should be 0 related issue: https://github.com/rapidsai/cuvs/issues/991
-        s_distances[i] = s_distances[i] < 0.0f ? 0.0f : s_distances[i];
-        if (!can_postprocess_dist && metric == cuvs::distance::DistanceType::L2SqrtExpanded) {
-          s_distances[i] = sqrtf(s_distances[i]);
-        }
-      }
-      s_distances[i] = dist_epilogue(s_distances[i], new_neighbors[row_id], new_neighbors[col_id]);
-    } else {
-      s_distances[i] = std::numeric_limits<float>::max();
-    }
-  }
+  calculate_metric(s_distances,
+                   new_neighbors,
+                   list_new_size,
+                   new_neighbors,
+                   list_new_size,
+                   data,
+                   data_dim,
+                   l2_norms,
+                   metric,
+                   dist_epilogue);
   __syncthreads();
 
   for (int step = 0; step < raft::ceildiv(list_new_size, num_warps); step++) {
@@ -733,39 +1028,17 @@ __launch_bounds__(BLOCK_SIZE)
     __syncthreads();
   }
 
-  for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x) {
-    int row_id = i % SKEWED_MAX_NUM_BI_SAMPLES;
-    int col_id = i / SKEWED_MAX_NUM_BI_SAMPLES;
-    if (row_id < list_old_size && col_id < list_new_size) {
-      if (metric == cuvs::distance::DistanceType::InnerProduct && can_postprocess_dist) {
-        s_distances[i] = -s_distances[i];
-      } else if (metric == cuvs::distance::DistanceType::CosineExpanded) {
-        s_distances[i] = 1.0 - s_distances[i];
-      } else if (metric == cuvs::distance::DistanceType::BitwiseHamming) {
-        s_distances[i] = 0.0;
-        int n1 = old_neighbors[row_id];
-        int n2 = new_neighbors[col_id];
-        // TODO: https://github.com/rapidsai/cuvs/issues/1127
-        const uint8_t* data_n1 = reinterpret_cast<const uint8_t*>(data) + n1 * data_dim;
-        const uint8_t* data_n2 = reinterpret_cast<const uint8_t*>(data) + n2 * data_dim;
-        for (int d = 0; d < data_dim; d++) {
-          s_distances[i] += __popc(static_cast<uint32_t>(data_n1[d] ^ data_n2[d]) & 0xff);
-        }
-      } else {  // L2Expanded or L2SqrtExpanded
-        s_distances[i] =
-          l2_norms[old_neighbors[row_id]] + l2_norms[new_neighbors[col_id]] - 2.0 * s_distances[i];
-        // for fp32 vs fp16 precision differences resulting in negative distances when distance
-        // should be 0 related issue: https://github.com/rapidsai/cuvs/issues/991
-        s_distances[i] = s_distances[i] < 0.0f ? 0.0f : s_distances[i];
-        if (!can_postprocess_dist && metric == cuvs::distance::DistanceType::L2SqrtExpanded) {
-          s_distances[i] = sqrtf(s_distances[i]);
-        }
-      }
-      s_distances[i] = dist_epilogue(s_distances[i], old_neighbors[row_id], new_neighbors[col_id]);
-    } else {
-      s_distances[i] = std::numeric_limits<float>::max();
-    }
-  }
+  calculate_metric(s_distances,
+                   new_neighbors,
+                   list_new_size,
+                   old_neighbors,
+                   list_old_size,
+                   data,
+                   data_dim,
+                   l2_norms,
+                   metric,
+                   dist_epilogue);
+
   __syncthreads();
 
   for (int step = 0; step < raft::ceildiv(MAX_NUM_BI_SAMPLES * 2, num_warps); step++) {
@@ -856,6 +1129,7 @@ GnndGraph<Index_t>::GnndGraph(raft::resources const& res,
 template <typename Index_t>
 void GnndGraph<Index_t>::sample_graph_new(InternalID_t<Index_t>* new_neighbors, const size_t width)
 {
+  std::fill_n(h_graph_new.data_handle(), nrow * num_samples, std::numeric_limits<Index_t>::max());
 #pragma omp parallel for
   for (size_t i = 0; i < nrow; i++) {
     auto list_new = h_graph_new.data_handle() + i * num_samples;
@@ -914,6 +1188,11 @@ void GnndGraph<Index_t>::init_random_graph()
 template <typename Index_t>
 void GnndGraph<Index_t>::sample_graph(bool sample_new)
 {
+  std::fill_n(h_graph_old.data_handle(), nrow * num_samples, std::numeric_limits<Index_t>::max());
+  if (sample_new) {
+    std::fill_n(h_graph_new.data_handle(), nrow * num_samples, std::numeric_limits<Index_t>::max());
+  }
+
 #pragma omp parallel for
   for (size_t i = 0; i < nrow; i++) {
     h_list_sizes_old.data_handle()[i].x = 0;
@@ -1016,12 +1295,6 @@ GNND<Data_t, Index_t>::GNND(raft::resources const& res, const BuildConfig& build_config)
     NUM_SAMPLES),
   nrow_(build_config.max_dataset_size),
   ndim_(build_config.dataset_dim),
-  d_data_{raft::make_device_matrix<__half, size_t, raft::row_major>(
-    res,
-    nrow_,
-    build_config.metric == cuvs::distance::DistanceType::BitwiseHamming
-      ? (build_config.dataset_dim + 1) / 2
-      : build_config.dataset_dim)},
  l2_norms_{raft::make_device_vector<DistData_t, size_t>(res, 0)},
  graph_buffer_{
    raft::make_device_matrix<ID_t, size_t, raft::row_major>(res, nrow_, DEGREE_ON_DEVICE)},
@@ -1043,6 +1316,23 @@ GNND<Data_t, Index_t>::GNND(raft::resources const& res, const BuildConfig& build_config)
 {
   static_assert(NUM_SAMPLES <= 32);
 
+  using input_t = typename std::remove_const<Data_t>::type;
+  if (std::is_same_v<input_t, float> &&
+      (build_config.dist_comp_dtype == cuvs::neighbors::nn_descent::DIST_COMP_DTYPE::FP32 ||
+       (build_config.dist_comp_dtype == cuvs::neighbors::nn_descent::DIST_COMP_DTYPE::AUTO &&
+        build_config.dataset_dim <= 16))) {
+    // use fp32 distance computation for better precision with smaller dimensions
+    d_data_float_.emplace(
+      raft::make_device_matrix<float, size_t, raft::row_major>(res, nrow_, ndim_));
+  } else {
+    d_data_half_.emplace(raft::make_device_matrix<__half, size_t, raft::row_major>(
+      res,
+      nrow_,
+      build_config.metric == cuvs::distance::DistanceType::BitwiseHamming
+        ? (build_config.dataset_dim + 1) / 2
+        : build_config.dataset_dim));
+  }
+
   raft::matrix::fill(res, dists_buffer_.view(), std::numeric_limits<DistData_t>::max());
   auto graph_buffer_view = raft::make_device_matrix_view<Index_t, int64_t>(
     reinterpret_cast<Index_t*>(graph_buffer_.data_handle()), nrow_, DEGREE_ON_DEVICE);
@@ -1072,10 +1362,13 @@ void GNND<Data_t, Index_t>::add_reverse_edges(Index_t* graph_ptr,
                                               int2* list_sizes,
                                               cudaStream_t stream)
 {
+  raft::matrix::fill(
+    res,
+    raft::make_device_matrix_view<Index_t, size_t>(d_rev_graph_ptr, nrow_, DEGREE_ON_DEVICE),
+    std::numeric_limits<Index_t>::max());
   add_rev_edges_kernel<<<nrow_, raft::warp_size(), 0, stream>>>(
     graph_ptr, d_rev_graph_ptr, NUM_SAMPLES, list_sizes);
-  raft::copy(
-    h_rev_graph_ptr, d_rev_graph_ptr, nrow_ * NUM_SAMPLES, raft::resource::get_cuda_stream(res));
+  raft::copy(h_rev_graph_ptr, d_rev_graph_ptr, nrow_ * NUM_SAMPLES, stream);
 }
 
 template <typename Data_t, typename Index_t>
@@ -1083,22 +1376,41 @@ template <typename DistEpilogue_t>
 void GNND<Data_t, Index_t>::local_join(cudaStream_t stream, DistEpilogue_t dist_epilogue)
 {
   raft::matrix::fill(res, dists_buffer_.view(), std::numeric_limits<DistData_t>::max());
-  local_join_kernel<<<nrow_, BLOCK_SIZE, 0, stream>>>(graph_.h_graph_new.data_handle(),
-                                                      h_rev_graph_new_.data_handle(),
-                                                      d_list_sizes_new_.data_handle(),
-                                                      h_graph_old_.data_handle(),
-                                                      h_rev_graph_old_.data_handle(),
-                                                      d_list_sizes_old_.data_handle(),
-                                                      NUM_SAMPLES,
-                                                      d_data_.data_handle(),
-                                                      ndim_,
-                                                      graph_buffer_.data_handle(),
-                                                      dists_buffer_.data_handle(),
-                                                      DEGREE_ON_DEVICE,
-                                                      d_locks_.data_handle(),
-                                                      l2_norms_.data_handle(),
-                                                      build_config_.metric,
-                                                      dist_epilogue);
+  if (d_data_float_.has_value()) {
+    local_join_kernel<<<nrow_, BLOCK_SIZE, 0, stream>>>(graph_.h_graph_new.data_handle(),
+                                                        h_rev_graph_new_.data_handle(),
+                                                        d_list_sizes_new_.data_handle(),
+                                                        h_graph_old_.data_handle(),
+                                                        h_rev_graph_old_.data_handle(),
+                                                        d_list_sizes_old_.data_handle(),
+                                                        NUM_SAMPLES,
+                                                        d_data_float_.value().data_handle(),
+                                                        ndim_,
+                                                        graph_buffer_.data_handle(),
+                                                        dists_buffer_.data_handle(),
+                                                        DEGREE_ON_DEVICE,
+                                                        d_locks_.data_handle(),
+                                                        l2_norms_.data_handle(),
+                                                        build_config_.metric,
+                                                        dist_epilogue);
+  } else {
+    local_join_kernel<<<nrow_, BLOCK_SIZE, 0, stream>>>(graph_.h_graph_new.data_handle(),
+                                                        h_rev_graph_new_.data_handle(),
+                                                        d_list_sizes_new_.data_handle(),
+                                                        h_graph_old_.data_handle(),
+                                                        h_rev_graph_old_.data_handle(),
+                                                        d_list_sizes_old_.data_handle(),
+                                                        NUM_SAMPLES,
+                                                        d_data_half_.value().data_handle(),
+                                                        ndim_,
+                                                        graph_buffer_.data_handle(),
+                                                        dists_buffer_.data_handle(),
+                                                        DEGREE_ON_DEVICE,
+                                                        d_locks_.data_handle(),
+                                                        l2_norms_.data_handle(),
+                                                        build_config_.metric,
+                                                        dist_epilogue);
+  }
 }
 
 template <typename Data_t, typename Index_t>
@@ -1124,7 +1436,12 @@ void GNND<Data_t, Index_t>::build(Data_t* data,
   graph_.bloom_filter.set_nrow(nrow);
   update_counter_ = 0;
   graph_.h_graph  = (InternalID_t<Index_t>*)output_graph;
-  raft::matrix::fill(res, d_data_.view(), static_cast<__half>(0));
+
+  if (d_data_float_.has_value()) {
+    raft::matrix::fill(res, d_data_float_.value().view(), static_cast<float>(0));
+  } else {
+    raft::matrix::fill(res, d_data_half_.value().view(), static_cast<__half>(0));
+  }
 
   cudaPointerAttributes data_ptr_attr;
   RAFT_CUDA_TRY(cudaPointerGetAttributes(&data_ptr_attr, data));
@@ -1133,17 +1450,33 @@ void GNND<Data_t, Index_t>::build(Data_t* data,
   cuvs::spatial::knn::detail::utils::batch_load_iterator<Data_t> vec_batches{
     data, static_cast<size_t>(nrow_), build_config_.dataset_dim, batch_size, stream};
   for (auto const& batch : vec_batches) {
-    preprocess_data_kernel<<<
-      batch.size(),
-      raft::warp_size(),
-      sizeof(Data_t) * ceildiv(build_config_.dataset_dim, static_cast<size_t>(raft::warp_size())) *
-        raft::warp_size(),
-      stream>>>(batch.data(),
-                d_data_.data_handle(),
-                build_config_.dataset_dim,
-                l2_norms_.data_handle(),
-                batch.offset(),
-                build_config_.metric);
+    if (d_data_float_.has_value()) {
+      preprocess_data_kernel<<<
+        batch.size(),
+        raft::warp_size(),
+        sizeof(Data_t) * ceildiv(build_config_.dataset_dim, static_cast<size_t>(raft::warp_size())) *
+          raft::warp_size(),
+        stream>>>(batch.data(),
+                  d_data_float_.value().data_handle(),
+                  build_config_.dataset_dim,
+                  l2_norms_.data_handle(),
+                  batch.offset(),
+                  build_config_.metric);
+    } else {
+      preprocess_data_kernel<<<
+        batch.size(),
+        raft::warp_size(),
+        sizeof(Data_t) * ceildiv(build_config_.dataset_dim, static_cast<size_t>(raft::warp_size())) *
+          raft::warp_size(),
+        stream>>>(batch.data(),
+                  d_data_half_.value().data_handle(),
+                  build_config_.dataset_dim,
+                  l2_norms_.data_handle(),
+                  batch.offset(),
+                  build_config_.metric);
+    }
   }
 
   graph_.clear();
@@ -1222,11 +1555,11 @@ void GNND<Data_t, Index_t>::build(Data_t* data,
             graph_buffer_.data_handle(),
             nrow_ * DEGREE_ON_DEVICE,
             raft::resource::get_cuda_stream(res));
-  raft::resource::sync_stream(res);
   raft::copy(dists_host_buffer_.data_handle(),
              dists_buffer_.data_handle(),
             nrow_ * DEGREE_ON_DEVICE,
             raft::resource::get_cuda_stream(res));
+  raft::resource::sync_stream(res);
 
   graph_.sample_graph_new(graph_host_buffer_.data_handle(), DEGREE_ON_DEVICE);
 }
diff --git a/cpp/src/neighbors/detail/nn_descent_gnnd.hpp b/cpp/src/neighbors/detail/nn_descent_gnnd.hpp
index cc453b83de..b0799505f4 100644
--- a/cpp/src/neighbors/detail/nn_descent_gnnd.hpp
+++ b/cpp/src/neighbors/detail/nn_descent_gnnd.hpp
@@ -64,6 +64,8 @@ struct BuildConfig {
   float termination_threshold{0.0001};
   size_t output_graph_degree{32};
   cuvs::distance::DistanceType metric{cuvs::distance::DistanceType::L2Expanded};
+  cuvs::neighbors::nn_descent::DIST_COMP_DTYPE dist_comp_dtype{
+    cuvs::neighbors::nn_descent::DIST_COMP_DTYPE::AUTO};
 };
 
 template <typename Data_t, typename Index_t>
@@ -226,7 +228,8 @@ class GNND {
   size_t nrow_;
   size_t ndim_;
 
-  raft::device_matrix<__half, size_t, raft::row_major> d_data_;
+  std::optional<raft::device_matrix<float, size_t, raft::row_major>> d_data_float_;
+  std::optional<raft::device_matrix<__half, size_t, raft::row_major>> d_data_half_;
   raft::device_vector<DistData_t, size_t> l2_norms_;
 
   raft::device_matrix<ID_t, size_t, raft::row_major> graph_buffer_;
@@ -302,7 +305,8 @@ inline BuildConfig get_build_config(raft::resources const& res,
     .max_iterations        = params.max_iterations,
     .termination_threshold = params.termination_threshold,
     .output_graph_degree   = params.graph_degree,
-    .metric                = params.metric};
+    .metric                = params.metric,
+    .dist_comp_dtype       = params.dist_comp_dtype};
 
   return build_config;
 }
diff --git a/python/cuvs/cuvs/neighbors/nn_descent/nn_descent.pxd b/python/cuvs/cuvs/neighbors/nn_descent/nn_descent.pxd
index 2595c9ff20..9568c88082 100644
--- a/python/cuvs/cuvs/neighbors/nn_descent/nn_descent.pxd
+++ b/python/cuvs/cuvs/neighbors/nn_descent/nn_descent.pxd
@@ -13,6 +13,10 @@ from cuvs.distance_type cimport cuvsDistanceType
 
 cdef extern from "cuvs/neighbors/nn_descent.h" nogil:
 
+    enum cuvsNNDescentDistCompDtype:
+        NND_DIST_COMP_AUTO = 0,
+        NND_DIST_COMP_FP32 = 1,
+        NND_DIST_COMP_FP16 = 2
 
     ctypedef struct cuvsNNDescentIndexParams:
         cuvsDistanceType metric
@@ -22,6 +26,7 @@ cdef extern from "cuvs/neighbors/nn_descent.h" nogil:
         size_t max_iterations
         float termination_threshold
         bool return_distances
+        cuvsNNDescentDistCompDtype dist_comp_dtype
 
     ctypedef cuvsNNDescentIndexParams* cuvsNNDescentIndexParams_t
diff --git a/python/cuvs/cuvs/neighbors/nn_descent/nn_descent.pyx b/python/cuvs/cuvs/neighbors/nn_descent/nn_descent.pyx
index afae653f35..7cb7f59b60 100644
--- a/python/cuvs/cuvs/neighbors/nn_descent/nn_descent.pyx
+++ b/python/cuvs/cuvs/neighbors/nn_descent/nn_descent.pyx
@@ -63,6 +63,12 @@ cdef class IndexParams:
         The delta at which nn-descent will terminate its iterations
     return_distances : bool
         Whether to return distances array
+    dist_comp_dtype : str, default = "auto"
+        Dtype to use for distance computation.
+        Supported dtypes are `auto`, `fp32`, and `fp16`.
+        `auto` automatically determines the best dtype for distance computation based on the dataset dimensions.
+        `fp32` uses fp32 distance computation for better precision at the cost of performance and memory usage; this option is only valid when the data type is fp32.
+        `fp16` uses fp16 distance computation for better performance and memory usage at the cost of precision.
     """
 
     cdef cuvsNNDescentIndexParams* params
@@ -81,7 +87,8 @@ cdef class IndexParams:
             intermediate_graph_degree=None,
             max_iterations=None,
             termination_threshold=None,
-            return_distances=None
+            return_distances=None,
+            dist_comp_dtype="auto"
     ):
         if metric is not None:
             self.params.metric = DISTANCE_TYPES[metric]
@@ -96,6 +103,15 @@ cdef class IndexParams:
         if return_distances is not None:
             self.params.return_distances = return_distances
 
+        if dist_comp_dtype == "auto":
+            self.params.dist_comp_dtype = cuvsNNDescentDistCompDtype.NND_DIST_COMP_AUTO
+        elif dist_comp_dtype == "fp32":
+            self.params.dist_comp_dtype = cuvsNNDescentDistCompDtype.NND_DIST_COMP_FP32
+        elif dist_comp_dtype == "fp16":
+            self.params.dist_comp_dtype = cuvsNNDescentDistCompDtype.NND_DIST_COMP_FP16
+        else:
+            raise ValueError(f"Invalid dist_comp_dtype: {dist_comp_dtype}. Supported options are 'auto', 'fp32', and 'fp16'.")
+
     @property
     def metric(self):
         return DISTANCE_NAMES[self.params.metric]
diff --git a/python/cuvs/cuvs/tests/test_nn_descent.py b/python/cuvs/cuvs/tests/test_nn_descent.py
index 862f82dc12..4463142fcb 100644
--- a/python/cuvs/cuvs/tests/test_nn_descent.py
+++ b/python/cuvs/cuvs/tests/test_nn_descent.py
@@ -4,7 +4,9 @@
 
 import numpy as np
 import pytest
+import cupy as cp
 from pylibraft.common import device_ndarray
+from sklearn.datasets import make_blobs
 
 from cuvs.neighbors import brute_force, nn_descent
 from cuvs.tests.ann_utils import calc_recall
@@ -56,3 +58,38 @@ def test_nn_descent(
         assert distances.shape == graph.shape
 
     assert calc_recall(graph, bfknn_graph) > 0.9
+
+
+@pytest.mark.parametrize("n_cols", [2, 17, 32])
+@pytest.mark.parametrize("dist_comp_dtype", ["auto", "fp32", "fp16"])
+@pytest.mark.parametrize("dtype", [np.float32, np.float16])
+def test_nn_descent_dist_comp_dtype(n_cols, dist_comp_dtype, dtype):
+    metric = "sqeuclidean"
+    graph_degree = 32
+    n_rows = 100_000
+
+    X, _ = make_blobs(
+        n_samples=n_rows, n_features=n_cols, centers=10, random_state=42
+    )
+    X = X.astype(dtype)
+
+    params = nn_descent.IndexParams(
+        metric=metric,
+        graph_degree=graph_degree,
+        return_distances=True,
+        dist_comp_dtype=dist_comp_dtype,
+    )
+
+    index = nn_descent.build(params, X)
+    nnd_indices = index.graph
+
+    gpu_X = cp.asarray(X)
+    index = brute_force.build(gpu_X, metric=metric)
+    _, bf_indices = brute_force.search(index, gpu_X, k=graph_degree)
+    bf_indices = bf_indices.copy_to_host()
+
+    if n_cols <= 16 and dist_comp_dtype == "fp16" and dtype == np.float32:
+        # for small dims, if the data is fp32 but dist_comp_dtype is fp16, recall will be low
+        assert calc_recall(nnd_indices, bf_indices) < 0.7
+    else:
+        assert calc_recall(nnd_indices, bf_indices) > 0.9
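---

For reviewers who want to try the new knob end-to-end, here is a minimal usage sketch of the Python API added in this diff. The dataset shape, metric, and graph degree below are illustrative assumptions, not part of the change; only `dist_comp_dtype` and the surrounding `IndexParams`/`build` calls come from this PR.

```python
# Minimal sketch of the new `dist_comp_dtype` option (illustrative values).
import numpy as np

from cuvs.neighbors import nn_descent

# Any fp32 row-major dataset works; the shape here is arbitrary.
X = np.random.default_rng(0).random((10_000, 64), dtype=np.float32)

# "auto" (the default) picks fp32 accumulation for small dimensions
# (<= 16 per this diff) and fp16 otherwise; "fp32" forces full precision,
# "fp16" favors speed and memory at the cost of precision.
params = nn_descent.IndexParams(
    metric="sqeuclidean",
    graph_degree=32,
    return_distances=True,
    dist_comp_dtype="fp32",  # one of "auto", "fp32", "fp16"
)

index = nn_descent.build(params, X)
graph = index.graph  # (n_rows, graph_degree) neighbor indices
```

Passing any other string raises the `ValueError` added in `nn_descent.pyx`; note that `"fp32"` is only meaningful for fp32 input data, since fp16 inputs always use the fp16 path.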