Draft
Commits (29)
bc8a196
[No Miss Sync] Remove a miss-sync in batched_memset with duplicating …
JigaoLuo May 22, 2025
9bb5399
[No Miss Sync] Remove a miss-sync in hostdevice_vector::element funct…
JigaoLuo May 23, 2025
14721d0
[No Miss Sync] Remove a miss-sync in hostdevice_vector::element funct…
JigaoLuo May 23, 2025
02cf20e
[No Miss Sync] Remove a miss-sync due to rmm::device_scalar::value fu…
JigaoLuo May 23, 2025
a040bb8
[No Miss Sync] Update comments
JigaoLuo May 23, 2025
6bb4263
[No Miss Sync] Remove a miss-sync inside cudf::reduction::detail::reduce
JigaoLuo May 23, 2025
7a75b63
[No Miss Sync] Add a TODO due to BITMASK null_mask_partition_bulk test
JigaoLuo May 24, 2025
40c35f0
[No Miss Sync] Remove miss-sync from some thrust::reduce causing miss…
JigaoLuo May 24, 2025
bfee8ce
[No Miss Sync] Remove miss-sync from some thrust::reduce causing miss…
JigaoLuo May 24, 2025
de2c750
[No Miss Sync] Remove miss-sync from thrust logical via thrust::trans…
JigaoLuo May 25, 2025
5316923
[No Miss Sync] Remove miss-sync from thrust::reduce_by_key via CUB re…
JigaoLuo May 25, 2025
8c0ca98
[No Miss Sync] Remove miss-sync from thrust::reduce_by_key via CUB re…
JigaoLuo May 25, 2025
ad2bd94
[No Miss Sync] Remove miss-sync from thrust::reduce_by_key via CUB re…
JigaoLuo May 25, 2025
c11a00e
[No Miss Sync] Remove miss-sync from thrust::copy_if via CUB rewriting
JigaoLuo May 25, 2025
c4bfa96
[No Miss Sync] Remove miss-sync: dummy comments
JigaoLuo May 25, 2025
c9ac9fb
[No Miss Sync] Remove miss-sync due to rmm::device_uvector::element()
JigaoLuo May 25, 2025
75bd459
[No Miss Sync] Remove miss-sync in WriteFinalOffsets with duplicating…
JigaoLuo May 25, 2025
59983c8
[No Miss Sync] Dummy update
JigaoLuo May 25, 2025
bdc814f
[No Miss Sync] Refactor with cudf::detail::device_scalar
JigaoLuo May 25, 2025
87a55b2
[No Miss Sync] Refactor with batched_memset.hpp (and vector_factories…
JigaoLuo May 26, 2025
b109204
[No Miss Sync] Refactor with page_data.cu (and vector_factories.hpp l…
JigaoLuo May 26, 2025
24f02c6
[No Miss Sync] Refactor with page_string_decode.cu (and vector_factor…
JigaoLuo May 26, 2025
9b3996f
[No Miss Sync] Refactor with page_data.cu (and vector_factories.hpp l…
JigaoLuo May 26, 2025
e05f53e
[No Miss Sync] Refactor with page_string_decode.cu (and vector_factor…
JigaoLuo May 26, 2025
993d16e
[No Miss Sync] Refactor with reader_impl_chunking.cu (and vector_fact…
JigaoLuo May 26, 2025
506739e
[No Miss Sync] Refactor with page_data.cu (and vector_factories.hpp l…
JigaoLuo May 26, 2025
a99869f
[No Miss Sync] Refactor with reader_impl_preprocess.cu (and vector_fa…
JigaoLuo May 26, 2025
1ec74c9
[No Miss Sync] Refactor with hostdevice_vector.hpp (and vector_factor…
JigaoLuo May 26, 2025
ca71572
[No Miss Sync] Refactor with cuda_memcpy.hpp and vector_factories.hpp
JigaoLuo May 26, 2025
8 changes: 4 additions & 4 deletions cpp/include/cudf/detail/device_scalar.hpp
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -54,15 +54,15 @@ class device_scalar : public rmm::device_scalar<T> {
explicit device_scalar(
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref())
: rmm::device_scalar<T>(stream, mr), bounce_buffer{make_host_vector<T>(1, stream)}
: rmm::device_scalar<T>(stream, mr), bounce_buffer{make_pinned_vector<T>(1, stream)}
{
}

explicit device_scalar(
T const& initial_value,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref())
: rmm::device_scalar<T>(stream, mr), bounce_buffer{make_host_vector<T>(1, stream)}
: rmm::device_scalar<T>(stream, mr), bounce_buffer{make_pinned_vector<T>(1, stream)}
{
bounce_buffer[0] = initial_value;
cuda_memcpy_async<T>(device_span<T>{this->data(), 1}, bounce_buffer, stream);
@@ -71,7 +71,7 @@ class device_scalar : public rmm::device_scalar<T> {
device_scalar(device_scalar const& other,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref())
: rmm::device_scalar<T>(other, stream, mr), bounce_buffer{make_host_vector<T>(1, stream)}
: rmm::device_scalar<T>(other, stream, mr), bounce_buffer{make_pinned_vector<T>(1, stream)}
{
}

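For context on the device_scalar change above: with a pinned bounce buffer the device-to-host read stays asynchronous and stream-ordered, whereas a pageable buffer typically falls back to a staged, effectively synchronous copy. A standalone sketch of the bounce-buffer pattern in plain CUDA (illustrative only, not cudf code; error checking omitted):

```cpp
#include <cuda_runtime.h>

// Read one int back from the device through a pinned bounce buffer. The DtoH
// copy is asynchronous because the host buffer is page-locked; the explicit
// stream synchronize is the only blocking point.
int read_device_int(int const* d_value, cudaStream_t stream)
{
  int* bounce = nullptr;
  cudaMallocHost(&bounce, sizeof(int));             // pinned (page-locked) host memory
  cudaMemcpyAsync(bounce, d_value, sizeof(int),
                  cudaMemcpyDeviceToHost, stream);  // stays async with pinned memory
  cudaStreamSynchronize(stream);                    // wait only on this stream
  int const result = *bounce;
  cudaFreeHost(bounce);
  return result;
}
```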
6 changes: 3 additions & 3 deletions cpp/include/cudf/detail/sizes_to_offsets_iterator.cuh
@@ -17,12 +17,12 @@
#pragma once

#include <cudf/column/column_factories.hpp>
#include <cudf/detail/device_scalar.hpp>
#include <cudf/detail/iterator.cuh>
#include <cudf/types.hpp>
#include <cudf/utilities/memory_resource.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_scalar.hpp>
#include <rmm/exec_policy.hpp>

#include <cuda/functional>
@@ -203,7 +203,7 @@ struct sizes_to_offsets_iterator {
* auto begin = // begin input iterator
* auto end = // end input iterator
* auto result = rmm::device_uvector(std::distance(begin,end), stream);
* auto last = rmm::device_scalar<int64_t>(0, stream);
* auto last = cudf::detail::device_scalar<int64_t>(0, stream);
* auto itr = make_sizes_to_offsets_iterator(result.begin(),
* result.end(),
* last.data());
@@ -270,7 +270,7 @@ auto sizes_to_offsets(SizesIterator begin,
"Only numeric types are supported by sizes_to_offsets");

using LastType = std::conditional_t<std::is_signed_v<SizeType>, int64_t, uint64_t>;
auto last_element = rmm::device_scalar<LastType>(0, stream);
auto last_element = cudf::detail::device_scalar<LastType>(0, stream);
auto output_itr =
make_sizes_to_offsets_iterator(result, result + std::distance(begin, end), last_element.data());
// This function uses the type of the initialization parameter as the accumulator type
9 changes: 7 additions & 2 deletions cpp/include/cudf/detail/utilities/batched_memset.hpp
@@ -50,9 +50,14 @@ void batched_memset(std::vector<cudf::device_span<T>> const& bufs,
// define task and bytes parameters
auto const num_bufs = bufs.size();

auto host_pinned_bufs = cudf::detail::make_pinned_vector_async<cudf::device_span<T>>(
bufs, stream); // host pageable -> host pinned memory

// copy bufs into device memory and then get sizes
auto gpu_bufs =
cudf::detail::make_device_uvector_async(bufs, stream, cudf::get_current_device_resource_ref());
auto gpu_bufs = cudf::detail::make_device_uvector_async(
host_pinned_bufs,
stream,
cudf::get_current_device_resource_ref()); // host pinned -> device memory

// get a vector with the sizes of all buffers
auto sizes = thrust::make_transform_iterator(
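Why the hunk above stages `bufs` through pinned memory first: a host-to-device cudaMemcpyAsync from pageable memory is not guaranteed to be asynchronous, while from pinned memory it is stream-ordered and returns immediately. A plain-CUDA sketch of the same staging idea (illustrative only; the PR itself uses cudf's make_pinned_vector_async / make_device_uvector_async helpers):

```cpp
#include <cuda_runtime.h>

#include <algorithm>
#include <vector>

// Upload a pageable std::vector by first copying it into a pinned staging
// buffer, so that the HtoD copy itself is asynchronous and stream-ordered.
// Error checking omitted.
void upload_via_pinned(std::vector<int> const& pageable, int* d_out, cudaStream_t stream)
{
  std::size_t const bytes = pageable.size() * sizeof(int);
  int* staging = nullptr;
  cudaMallocHost(&staging, bytes);                       // pinned staging buffer
  std::copy(pageable.begin(), pageable.end(), staging);  // pageable -> pinned (CPU copy)
  cudaMemcpyAsync(d_out, staging, bytes,
                  cudaMemcpyHostToDevice, stream);       // pinned -> device, async
  cudaStreamSynchronize(stream);                         // staging must outlive the copy
  cudaFreeHost(staging);
}
```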
38 changes: 37 additions & 1 deletion cpp/include/cudf/detail/utilities/cuda_memcpy.hpp
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -71,6 +71,26 @@ void cuda_memcpy_async(host_span<T> dst, device_span<T const> src, rmm::cuda_str
stream);
}

/**
* @brief Asynchronously copies data from host memory to host memory.
*
* Implementation may use different strategies depending on the size and type of host data.
*
* @param dst Destination host memory
* @param src Source host memory
* @param stream CUDA stream used for the copy
*/
template <typename T>
void cuda_memcpy_async(host_span<T> dst, host_span<T const> src, rmm::cuda_stream_view stream)
Member:
Using a cuda_memcpy_async named utility for host-to-host copies seems super fishy to me 😄

Contributor:
It kind of depends, right? Consider: if the src span is pinned memory and was populated by a cudaMemcpyAsync(..., DtoH, stream), that copy is async and stream-ordered. Now we copy from src to dst: we either need to sync the stream, or else use cudaMemcpyAsync(..., HtoH, stream) to ensure that things continue in stream order.

Contributor Author:
@wence- Thanks, that is exactly what I have in mind. In my opinion, cudaMemcpyHostToHost is the only option that preserves CUDA stream semantics.
But we can discuss where to put this function if this file is not a good fit.

{
CUDF_EXPECTS(dst.size() == src.size(), "Mismatched sizes in cuda_memcpy_async");
cuda_memcpy_async_impl(dst.data(),
src.data(),
src.size_bytes(),
host_memory_kind::PAGEABLE, // use copy_pageable for host-to-host copy
stream);
}
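To make the scenario from the review thread concrete, here is a sketch that uses only the overloads declared in this header; `d_data` (a device_span<int const> produced on `stream`) and the local variables are assumed names for the illustration:

```cpp
auto staging = cudf::detail::make_pinned_vector_async<int>(d_data.size(), stream);
std::vector<int> result(d_data.size());

// DtoH into pinned memory: asynchronous and stream-ordered.
cudf::detail::cuda_memcpy_async<int>(staging, d_data, stream);
// HtoH via the overload above: intended to stay ordered after the DtoH copy,
// instead of requiring a stream synchronize in between.
cudf::detail::cuda_memcpy_async<int>(
  cudf::host_span<int>{result.data(), result.size()}, staging, stream);
stream.synchronize();  // a single blocking point, at the point of use
```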

/**
* @brief Synchronously copies data from host to device memory.
*
@@ -103,5 +123,21 @@ void cuda_memcpy(host_span<T> dst, device_span<T const> src, rmm::cuda_stream_vi
stream.synchronize();
}

/**
* @brief Synchronously copies data from host memory to host memory.
*
* Implementation may use different strategies depending on the size and type of host data.
*
* @param dst Destination host memory
* @param src Source host memory
* @param stream CUDA stream used for the copy
*/
template <typename T>
void cuda_memcpy(host_span<T> dst, host_span<T const> src, rmm::cuda_stream_view stream)
{
cuda_memcpy_async(dst, src, stream);
stream.synchronize();
}

} // namespace detail
} // namespace CUDF_EXPORT cudf
115 changes: 111 additions & 4 deletions cpp/include/cudf/detail/utilities/vector_factories.hpp
@@ -354,7 +354,7 @@
* @note This function does not synchronize `stream` after the copy.
*
* @tparam T The type of the data to copy
* @param source_data The device data to copy
* @param source_data The device_span of data to copy
* @param stream The stream on which to perform the copy
* @return The data copied to the host
*/
@@ -396,7 +396,7 @@ std::vector<typename Container::value_type> make_std_vector_async(Container cons
* @note This function does a synchronize on `stream` after the copy.
*
* @tparam T The type of the data to copy
* @param source_data The device data to copy
* @param source_data The device_span of data to copy
* @param stream The stream on which to perform the copy
* @return The data copied to the host
*/
@@ -499,7 +499,7 @@ host_vector<T> make_empty_host_vector(size_t capacity, rmm::cuda_stream_view str
* using a pinned memory resource.
*
* @tparam T The type of the data to copy
* @param source_data The device data to copy
* @param source_data The device_span of data to copy
* @param stream The stream on which to perform the copy
* @return The data copied to the host
*/
@@ -542,7 +542,7 @@ host_vector<typename Container::value_type> make_host_vector_async(Container con
* using a pinned memory resource.
*
* @tparam T The type of the data to copy
* @param source_data The device data to copy
* @param source_data The device_span of data to copy
* @param stream The stream on which to perform the copy
* @return The data copied to the host
*/
@@ -637,6 +637,113 @@ host_vector<T> make_pinned_vector(size_t size, rmm::cuda_stream_view stream)
return result;
}

/**
* @brief Asynchronously construct a `cudf::detail::host_vector` containing a copy of data from a
* `device_span`
*
* @note This function does not synchronize `stream` after the copy. The returned vector uses
* a pinned memory resource.
*
* @tparam T The type of the data to copy
* @param v The device_span of data to copy
* @param stream The stream on which to perform the copy
* @return The data copied to the host
*/
template <typename T>
host_vector<T> make_pinned_vector_async(device_span<T const> v, rmm::cuda_stream_view stream)
{
auto result = make_pinned_vector<T>(v.size(), stream);
cuda_memcpy_async<T>(result, v, stream);
return result;
}

/**
*
* @brief Asynchronously construct a `cudf::detail::host_vector` containing a copy of data from a
* device container
*
* @note This function does not synchronize `stream` after the copy. The returned vector uses
* a pinned memory resource.
*
* @tparam Container The type of the container to copy from
* @tparam T The type of the data to copy
* @param c The input device container from which to copy
* @param stream The stream on which to perform the copy
* @return The data copied to the host
*/
template <
typename Container,
std::enable_if_t<
Member:
Let's use CUDF_ENABLE_IF instead of this if possible

Member:
Update: We have moved to using a requires clause now

std::is_convertible_v<Container, device_span<typename Container::value_type const>>>* = nullptr>
host_vector<typename Container::value_type> make_pinned_vector_async(Container const& c,
rmm::cuda_stream_view stream)
{
return make_pinned_vector_async(device_span<typename Container::value_type const>{c}, stream);
}
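Regarding the review note above, a sketch of how this constraint might read with a C++20 requires clause instead of std::enable_if_t (illustrative only, not the code in this PR):

```cpp
template <typename Container>
  requires std::is_convertible_v<Container,
                                 device_span<typename Container::value_type const>>
host_vector<typename Container::value_type> make_pinned_vector_async(Container const& c,
                                                                      rmm::cuda_stream_view stream)
{
  return make_pinned_vector_async(device_span<typename Container::value_type const>{c}, stream);
}
```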

/**
* @brief Asynchronously construct a `cudf::detail::host_vector` containing a copy of data from a
* `host_span`
*
* @note This function does not synchronize `stream` after the copy. The returned vector uses
* a pinned memory resource.
*
* @tparam T The type of the data to copy
* @param v The host_span of data to copy
* @param stream The stream on which to perform the copy
* @return The data copied to the host
*/
template <typename T>
host_vector<T> make_pinned_vector_async(host_span<T const> v, rmm::cuda_stream_view stream)
{
auto result = make_pinned_vector<T>(v.size(), stream);
cuda_memcpy_async<T>(result, v, stream);
return result;
}

/**
* @brief Synchronously construct a `cudf::detail::host_vector` containing a copy of data from a
* `device_span`
*
* @note This function does a synchronize on `stream` after the copy. The returned vector uses
* a pinned memory resource.
*
* @tparam T The type of the data to copy
* @param v The device_span of data to copy
* @param stream The stream on which to perform the copy
* @return The data copied to the host
*/
template <typename T>
host_vector<T> make_pinned_vector(device_span<T const> v, rmm::cuda_stream_view stream)
{
auto result = make_pinned_vector_async(v, stream);
stream.synchronize();
return result;
}

/**
* @brief Synchronously construct a `cudf::detail::host_vector` containing a copy of data from a
* device container
*
* @note This function synchronizes `stream` after the copy. The returned vector uses
* a pinned memory resource.
*
* @tparam Container The type of the container to copy from
* @tparam T The type of the data to copy
* @param c The input device container from which to copy
* @param stream The stream on which to perform the copy
* @return The data copied to the host
*/
template <
typename Container,
std::enable_if_t<
Member:
Same here

std::is_convertible_v<Container, device_span<typename Container::value_type const>>>* = nullptr>
host_vector<typename Container::value_type> make_pinned_vector(Container const& c,
rmm::cuda_stream_view stream)
{
return make_pinned_vector(device_span<typename Container::value_type const>{c}, stream);
}

/**
* @copydoc cudf::detail::make_pinned_vector(size_t size, rmm::cuda_stream_view stream)
*
81 changes: 80 additions & 1 deletion cpp/include/cudf/reduction/detail/reduction.cuh
@@ -67,7 +67,16 @@ std::unique_ptr<scalar> reduce(InputIterator d_in,
{
auto const binary_op = cudf::detail::cast_functor<OutputType>(op.get_binary_op());
auto const initial_value = init.value_or(op.template get_identity<OutputType>());
auto dev_result = rmm::device_scalar<OutputType>{initial_value, stream, mr};
auto host_scalar =
Member:
Since this is just one value being copied to device, is there any performance advantage from first creating a pinned vector of size 1 and then using that to copy to device?

Contributor Author:
The main goal of the code changes is to get rid of any cudaMemcpy using pageable memory.

Here, the old code copied from the stack (pageable memory) to the device, which can cause a “miss-sync” by interfering with other running CUDA streams. Sorry, I didn’t explain this better up front!

cudf::detail::make_pinned_vector_async<OutputType>(1, stream); // as host pinned memory
CUDF_CUDA_TRY(cudaMemcpyAsync(
host_scalar.data(), &initial_value, sizeof(OutputType), cudaMemcpyHostToHost, stream.value()));
Contributor Author (on lines +72 to +73):
Host to Host copy

rmm::device_scalar<OutputType> dev_result{stream, mr};
CUDF_CUDA_TRY(cudaMemcpyAsync(dev_result.data(),
host_scalar.data(),
sizeof(OutputType),
cudaMemcpyHostToDevice,
stream.value())); // device <- host pinned

// Allocate temporary storage
rmm::device_buffer d_temp_storage;
@@ -231,6 +240,76 @@ std::unique_ptr<scalar> reduce(InputIterator d_in,
return std::unique_ptr<scalar>(result);
}

/**
* @brief Compute the specified by-key reduction over the input range of elements.
*
* @param[in] d_keys_in the begin iterator of input keys
* @param[out] d_unique_out the begin iterator of output keys (one key per run)
* @param[in] d_values_in the begin iterator of input values
* @param[out] d_aggregates_out the begin iterator of output aggregated values (one aggregate
* per run)
* @param[out] d_num_runs_out the pointer to the total number of runs encountered (i.e., the
* length of d_unique_out)
* @param[in] op the reduction operator
* @param[in] num_items the number of key+value pairs (i.e., the length of d_keys_in and
* d_values_in)
* @param[in] stream CUDA stream used for device memory operations and kernel launches
* @param[in] mr Device memory resource used to allocate the temporary storage
*
* @tparam Op the reduction operator with device binary operator
* @tparam KeysInputIteratorT the input keys iterator
* @tparam UniqueOutputIteratorT the output keys iterator
* @tparam ValuesInputIteratorT the input values iterator
* @tparam AggregatesOutputIteratorT the output values iterator
* @tparam OutputType the output type of reduction
*/
template <typename Op,
typename KeysInputIteratorT,
typename UniqueOutputIteratorT,
typename ValuesInputIteratorT,
typename AggregatesOutputIteratorT,
typename OutputType = cuda::std::iter_value_t<KeysInputIteratorT>,
std::enable_if_t<is_fixed_width<OutputType>() &&
not cudf::is_fixed_point<OutputType>()>* = nullptr>
void reduce_by_key(KeysInputIteratorT d_keys_in,
UniqueOutputIteratorT d_unique_out,
ValuesInputIteratorT d_values_in,
AggregatesOutputIteratorT d_aggregates_out,
cudf::size_type* d_num_runs_out,
op::simple_op<Op> op,
cudf::size_type num_items,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
auto const binary_op = cudf::detail::cast_functor<OutputType>(op.get_binary_op());
// Allocate temporary storage
rmm::device_buffer d_temp_storage;
size_t temp_storage_bytes = 0;
cub::DeviceReduce::ReduceByKey(d_temp_storage.data(),
temp_storage_bytes,
d_keys_in,
d_unique_out,
d_values_in,
d_aggregates_out,
d_num_runs_out,
binary_op,
num_items,
stream.value());
d_temp_storage = rmm::device_buffer{temp_storage_bytes, stream, mr};

// Run reduction
cub::DeviceReduce::ReduceByKey(d_temp_storage.data(),
temp_storage_bytes,
d_keys_in,
d_unique_out,
d_values_in,
d_aggregates_out,
d_num_runs_out,
binary_op,
num_items,
stream.value());
}
} // namespace detail
} // namespace reduction
} // namespace cudf