Skip to content

Commit e58d119

Browse files
Implement the new tuning API for DispatchHistogram (#8212)
Fixes: #7637 Fixes: #7475
1 parent 0599ec3 commit e58d119

File tree

6 files changed

+1086
-1104
lines changed

6 files changed

+1086
-1104
lines changed

c/parallel/src/histogram.cu

Lines changed: 32 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,14 @@
99
//===----------------------------------------------------------------------===//
1010

1111
#include <cub/detail/launcher/cuda_driver.cuh>
12-
#include <cub/detail/ptx-json-parser.cuh>
1312
#include <cub/device/device_histogram.cuh>
1413

15-
#include <cuda/std/algorithm>
16-
1714
#include <format>
1815
#include <limits>
16+
#include <sstream>
1917
#include <vector>
2018

2119
#include "cccl/c/types.h"
22-
#include "cub/util_type.cuh"
2320
#include "kernels/iterators.h"
2421
#include "util/context.h"
2522
#include "util/indirect_arg.h"
@@ -39,35 +36,6 @@ struct samples_iterator_t;
3936

4037
namespace histogram
4138
{
42-
struct histogram_runtime_tuning_policy
43-
{
44-
cub::detail::RuntimeHistogramAgentPolicy histogram;
45-
46-
auto Histogram() const
47-
{
48-
return histogram;
49-
}
50-
51-
CUB_RUNTIME_FUNCTION int BlockThreads() const
52-
{
53-
return histogram.BlockThreads();
54-
}
55-
56-
CUB_RUNTIME_FUNCTION int PixelsPerThread() const
57-
{
58-
return histogram.PixelsPerThread();
59-
}
60-
61-
using HistogramPolicy = cub::detail::RuntimeHistogramAgentPolicy;
62-
using MaxPolicy = histogram_runtime_tuning_policy;
63-
64-
template <typename F>
65-
cudaError_t Invoke(int, F& op)
66-
{
67-
return op.template Invoke<histogram_runtime_tuning_policy>(*this);
68-
}
69-
};
70-
7139
struct histogram_kernel_source
7240
{
7341
cccl_device_histogram_build_result_t& build;
@@ -292,8 +260,25 @@ try
292260
const std::string samples_iterator_src =
293261
make_kernel_input_iterator(offset_cpp, samples_iterator_name, sample_cpp, d_samples);
294262

295-
std::string policy_hub_expr = std::format(
296-
"cub::detail::histogram::policy_hub<{}, {}, {}, {}, {}>",
263+
const bool sample_is_primitive = d_samples.value_type.type != CCCL_STORAGE; // TODO(bgruber): how to check if sample
264+
// is primitive?
265+
const auto policy_sel = cub::detail::histogram::policy_selector{
266+
sample_is_primitive,
267+
static_cast<int>(d_samples.value_type.size),
268+
static_cast<int>(d_output_histograms.value_type.size),
269+
static_cast<int>(d_samples.value_type.size),
270+
num_channels,
271+
num_active_channels,
272+
is_evenly_segmented};
273+
274+
const auto arch_id = cuda::to_arch_id(cuda::compute_capability{cc_major, cc_minor});
275+
const auto active_policy = policy_sel(arch_id);
276+
277+
std::stringstream policy_sel_str;
278+
policy_sel_str << active_policy;
279+
280+
std::string policy_selector_expr = std::format(
281+
"cub::detail::histogram::policy_selector_from_types<{}, {}, {}, {}, {}>",
297282
sample_cpp,
298283
counter_cpp,
299284
num_channels,
@@ -311,18 +296,16 @@ struct __align__({1}) storage_t {{
311296
char data[{0}];
312297
}};
313298
{2}
314-
using device_histogram_policy = {3}::MaxPolicy;
315-
316-
#include <cub/detail/ptx-json/json.cuh>
317-
__device__ consteval auto& policy_generator() {{
318-
return ptx_json::id<ptx_json::string("device_histogram_policy")>()
319-
= cub::detail::histogram::HistogramPolicyWrapper<device_histogram_policy::ActivePolicy>::EncodedPolicy();
320-
}}
299+
using device_histogram_policy = {3};
300+
using namespace cub;
301+
using namespace cub::detail::histogram;
302+
static_assert(device_histogram_policy()(::cuda::arch_id{{CUB_PTX_ARCH / 10}}) == {4}, "Host generated and JIT compiled policy mismatch");
321303
)XXX",
322304
d_samples.value_type.size, // 0
323305
d_samples.value_type.alignment, // 1
324306
samples_iterator_src, // 2
325-
policy_hub_expr // 3
307+
policy_selector_expr, // 3
308+
policy_sel_str.view() // 4
326309
);
327310

328311
#if false // CCCL_DEBUGGING_SWITCH
@@ -368,7 +351,6 @@ __device__ consteval auto& policy_generator() {{
368351
"-dlto",
369352
"-default-device",
370353
"-DCUB_DISABLE_CDP",
371-
"-DCUB_ENABLE_POLICY_PTX_JSON",
372354
"-std=c++20"};
373355

374356
cccl::detail::extend_args_with_build_config(args, config);
@@ -398,12 +380,6 @@ __device__ consteval auto& policy_generator() {{
398380
check(cuLibraryGetKernel(&build_ptr->init_kernel, build_ptr->library, init_kernel_lowered_name.c_str()));
399381
check(cuLibraryGetKernel(&build_ptr->sweep_kernel, build_ptr->library, sweep_kernel_lowered_name.c_str()));
400382

401-
nlohmann::json runtime_policy =
402-
cub::detail::ptx_json::parse("device_histogram_policy", {result.data.get(), result.size});
403-
404-
using cub::detail::RuntimeHistogramAgentPolicy;
405-
auto histogram_policy = RuntimeHistogramAgentPolicy::from_json(runtime_policy, "HistogramPolicy");
406-
407383
build_ptr->cc = cc;
408384
build_ptr->cubin = (void*) result.data.release();
409385
build_ptr->cubin_size = result.size;
@@ -413,7 +389,7 @@ __device__ consteval auto& policy_generator() {{
413389
build_ptr->num_active_channels = num_active_channels;
414390
build_ptr->may_overflow = false; // This is set in cccl_device_histogram_even_impl so that kernel source can access
415391
// it later.
416-
build_ptr->runtime_policy = new histogram::histogram_runtime_tuning_policy{histogram_policy};
392+
build_ptr->runtime_policy = new cub::detail::histogram::policy_selector{policy_sel};
417393

418394
return CUDA_SUCCESS;
419395
}
@@ -477,7 +453,7 @@ CUresult cccl_device_histogram_even_impl(
477453
indirect_arg_t, // CounterT
478454
indirect_arg_t, // LevelT
479455
OffsetT, // OffsetT
480-
histogram::histogram_runtime_tuning_policy, // PolicyHub
456+
cub::detail::histogram::policy_selector, // PolicySelector
481457
indirect_arg_t, // SampleT
482458
histogram::histogram_kernel_source, // KernelSource
483459
cub::detail::CudaDriverLauncherFactory // KernelLauncherFactory
@@ -493,9 +469,9 @@ CUresult cccl_device_histogram_even_impl(
493469
row_stride_samples,
494470
stream,
495471
is_byte_sample{},
472+
*reinterpret_cast<cub::detail::histogram::policy_selector*>(build.runtime_policy),
496473
{build},
497-
cub::detail::CudaDriverLauncherFactory{cu_device, build.cc},
498-
*reinterpret_cast<histogram::histogram_runtime_tuning_policy*>(build.runtime_policy));
474+
cub::detail::CudaDriverLauncherFactory{cu_device, build.cc});
499475

500476
error = static_cast<CUresult>(exec_status);
501477
}
@@ -595,7 +571,8 @@ try
595571
}
596572

597573
std::unique_ptr<char[]> cubin(reinterpret_cast<char*>(build_ptr->cubin));
598-
std::unique_ptr<char[]> policy(reinterpret_cast<char*>(build_ptr->runtime_policy));
574+
std::unique_ptr<cub::detail::histogram::policy_selector> policy(
575+
static_cast<cub::detail::histogram::policy_selector*>(build_ptr->runtime_policy));
599576
check(cuLibraryUnload(build_ptr->library));
600577

601578
return CUDA_SUCCESS;

cub/cub/agent/agent_histogram.cuh

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@
2727
#include <cuda/std/__type_traits/integral_constant.h>
2828
#include <cuda/std/__type_traits/is_pointer.h>
2929

30+
#if !_CCCL_COMPILER(NVRTC)
31+
# include <ostream>
32+
#endif // !_CCCL_COMPILER(NVRTC)
33+
3034
CUB_NAMESPACE_BEGIN
3135

3236
enum BlockHistogramMemoryPreference
@@ -36,6 +40,23 @@ enum BlockHistogramMemoryPreference
3640
BLEND
3741
};
3842

43+
#if !_CCCL_COMPILER(NVRTC)
44+
inline ::std::ostream& operator<<(::std::ostream& os, BlockHistogramMemoryPreference mempref)
45+
{
46+
switch (mempref)
47+
{
48+
case GMEM:
49+
return os << "GMEM";
50+
case SMEM:
51+
return os << "SMEM";
52+
case BLEND:
53+
return os << "BLEND";
54+
default:
55+
return os << "<unknown BlockHistogramMemoryPreference: " << static_cast<int>(mempref) << ">";
56+
}
57+
}
58+
#endif // !_CCCL_COMPILER(NVRTC)
59+
3960
//! Parameterizable tuning policy type for AgentHistogram
4061
//!
4162
//! @tparam BlockThreads

0 commit comments

Comments
 (0)