Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions cpp/src/distance/detail/sparse/coo_spmv_kernel.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,7 @@
#pragma once

#include <raft/core/detail/macros.hpp>
#include <raft/util/cuda_dev_essentials.cuh>

#include <cub/block/block_load.cuh>
#include <cub/block/block_radix_sort.cuh>
Expand Down Expand Up @@ -131,10 +132,10 @@ RAFT_KERNEL balanced_coo_generalized_spmv_kernel(strategy_t strategy,

extern __shared__ char smem[];

typename strategy_t::smem_type A = (typename strategy_t::smem_type)(smem);
typename warp_reduce::TempStorage* temp_storage = (typename warp_reduce::TempStorage*)(A + dim);
void* A = smem;
typename warp_reduce::TempStorage* temp_storage = (typename warp_reduce::TempStorage*)((char*)A + dim);

auto inserter = strategy.init_insert(A, dim);
auto map_ref = strategy.init_map(A, dim);

__syncthreads();

Expand All @@ -145,13 +146,11 @@ RAFT_KERNEL balanced_coo_generalized_spmv_kernel(strategy_t strategy,

// Convert current row vector in A to dense
for (int i = tid; i <= (stop_offset_a - start_offset_a); i += blockDim.x) {
strategy.insert(inserter, indicesA[start_offset_a + i], dataA[start_offset_a + i]);
strategy.insert(map_ref, indicesA[start_offset_a + i], dataA[start_offset_a + i]);
}

__syncthreads();

auto finder = strategy.init_find(A, dim);

if (cur_row_a > m || cur_chunk_offset > n_blocks_per_row) return;
if (ind >= nnz_b) return;

Expand All @@ -177,7 +176,7 @@ RAFT_KERNEL balanced_coo_generalized_spmv_kernel(strategy_t strategy,
auto in_bounds = indptrA.check_indices_bounds(start_index_a, stop_index_a, index_b);

if (in_bounds) {
value_t a_col = strategy.find(finder, index_b);
value_t a_col = strategy.find(map_ref, index_b);
if (!rev || a_col == 0.0) { c = product_func(a_col, dataB[ind]); }
}
}
Expand Down Expand Up @@ -215,7 +214,7 @@ RAFT_KERNEL balanced_coo_generalized_spmv_kernel(strategy_t strategy,
auto index_b = indicesB[ind];
auto in_bounds = indptrA.check_indices_bounds(start_index_a, stop_index_a, index_b);
if (in_bounds) {
value_t a_col = strategy.find(finder, index_b);
value_t a_col = strategy.find(map_ref, index_b);

if (!rev || a_col == 0.0) { c = accum_func(c, product_func(a_col, dataB[ind])); }
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,7 +18,7 @@

#include "base_strategy.cuh"

#include <raft/util/cuda_dev_essentials.cuh> // raft::ceildiv
#include <raft/util/cuda_dev_essentials.cuh>

namespace cuvs {
namespace distance {
Expand All @@ -28,9 +28,7 @@ namespace sparse {
template <typename value_idx, typename value_t, int tpb>
class dense_smem_strategy : public coo_spmv_strategy<value_idx, value_t, tpb> {
public:
using smem_type = value_t*;
using insert_type = smem_type;
using find_type = smem_type;
using map_type = value_t*;

dense_smem_strategy(const distances_config_t<value_idx, value_t>& config_)
: coo_spmv_strategy<value_idx, value_t, tpb>(config_)
Expand Down Expand Up @@ -94,25 +92,21 @@ class dense_smem_strategy : public coo_spmv_strategy<value_idx, value_t, tpb> {
n_blocks_per_row);
}

__device__ inline insert_type init_insert(smem_type cache, const value_idx& cache_size)
__device__ inline map_type init_map(void* storage, const value_idx& cache_size)
{
auto cache = static_cast<value_t*>(storage);
for (int k = threadIdx.x; k < cache_size; k += blockDim.x) {
cache[k] = 0.0;
}
return cache;
}

__device__ inline void insert(insert_type cache, const value_idx& key, const value_t& value)
__device__ inline void insert(map_type& cache, const value_idx& key, const value_t& value)
{
cache[key] = value;
}

__device__ inline find_type init_find(smem_type cache, const value_idx& cache_size)
{
return cache;
}

__device__ inline value_t find(find_type cache, const value_idx& key) { return cache[key]; }
__device__ inline value_t find(map_type& cache, const value_idx& key) { return cache[key]; }
};

} // namespace sparse
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,11 +20,15 @@

#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/util/cuda_dev_essentials.cuh>

#include <cuco/static_map.cuh>
#include <thrust/copy.h>
#include <thrust/iterator/counting_iterator.h>

#include <cooperative_groups.h>
#include <rmm/device_uvector.hpp>

// this is needed by cuco as key, value must be bitwise comparable.
// compilers don't declare float/double as bitwise comparable
// but that is too strict
Expand All @@ -43,11 +47,19 @@ namespace sparse {
template <typename value_idx, typename value_t, int tpb>
class hash_strategy : public coo_spmv_strategy<value_idx, value_t, tpb> {
public:
using insert_type = typename cuco::legacy::
static_map<value_idx, value_t, cuda::thread_scope_block>::device_mutable_view;
using smem_type = typename insert_type::slot_type*;
using find_type =
typename cuco::legacy::static_map<value_idx, value_t, cuda::thread_scope_block>::device_view;
static constexpr value_idx empty_key_sentinel = value_idx{-1};
static constexpr value_t empty_value_sentinel = value_t{0};
using probing_scheme_type = cuco::linear_probing<1, cuco::murmurhash3_32<value_idx>>;
using storage_ref_type = cuco::bucket_storage_ref<cuco::pair<value_idx, value_t>, 1, cuco::extent<int>>;
using map_type = cuco::static_map_ref<
value_idx,
value_t,
cuda::thread_scope_block,
cuda::std::equal_to<value_idx>,
probing_scheme_type,
storage_ref_type,
cuco::op::insert_tag,
cuco::op::find_tag>;

hash_strategy(const distances_config_t<value_idx, value_t>& config_,
float capacity_threshold_ = 0.5,
Expand Down Expand Up @@ -231,32 +243,33 @@ class hash_strategy : public coo_spmv_strategy<value_idx, value_t, tpb> {
}
}

__device__ inline insert_type init_insert(smem_type cache, const value_idx& cache_size)
__device__ inline map_type init_map(void* storage, const value_idx& cache_size)
{
return insert_type::make_from_uninitialized_slots(cooperative_groups::this_thread_block(),
cache,
cache_size,
cuco::empty_key{value_idx{-1}},
cuco::empty_value{value_t{0}});
auto map_ref = map_type{
cuco::empty_key<value_idx>{empty_key_sentinel},
cuco::empty_value<value_t>{empty_value_sentinel},
cuda::std::equal_to<value_idx>{},
probing_scheme_type{},
cuco::cuda_thread_scope<cuda::thread_scope_block>{},
storage_ref_type{cuco::extent<int>{cache_size}, static_cast<typename storage_ref_type::value_type*>(storage)}};
map_ref.initialize(cooperative_groups::this_thread_block());

return map_ref;
}

__device__ inline void insert(insert_type cache, const value_idx& key, const value_t& value)
__device__ inline void insert(map_type& map_ref, const value_idx& key, const value_t& value)
{
auto success = cache.insert(cuco::pair<value_idx, value_t>(key, value));
map_ref.insert(cuco::pair{key, value});
}

__device__ inline find_type init_find(smem_type cache, const value_idx& cache_size)
{
return find_type(
cache, cache_size, cuco::empty_key{value_idx{-1}}, cuco::empty_value{value_t{0}});
}
// Note: init_find is now merged with init_map since the new API uses the same ref for both operations

__device__ inline value_t find(find_type cache, const value_idx& key)
__device__ inline value_t find(map_type& map_ref, const value_idx& key)
{
auto a_pair = cache.find(key);
auto a_pair = map_ref.find(key);

value_t a_col = 0.0;
if (a_pair != cache.end()) { a_col = a_pair->second; }
if (a_pair != map_ref.end()) { a_col = a_pair->second; }
return a_col;
}

Expand All @@ -282,7 +295,7 @@ class hash_strategy : public coo_spmv_strategy<value_idx, value_t, tpb> {
inline static int get_map_size()
{
return (raft::getSharedMemPerBlock() - ((tpb / raft::warp_size()) * sizeof(value_t))) /
sizeof(typename insert_type::slot_type);
sizeof(cuco::pair<value_idx, value_t>);
}

private:
Expand Down