99// ===----------------------------------------------------------------------===//
1010
1111#include < cub/detail/launcher/cuda_driver.cuh>
12- #include < cub/detail/ptx-json-parser.cuh>
1312#include < cub/device/device_histogram.cuh>
1413
15- #include < cuda/std/algorithm>
16-
1714#include < format>
1815#include < limits>
16+ #include < sstream>
1917#include < vector>
2018
2119#include " cccl/c/types.h"
22- #include " cub/util_type.cuh"
2320#include " kernels/iterators.h"
2421#include " util/context.h"
2522#include " util/indirect_arg.h"
@@ -39,35 +36,6 @@ struct samples_iterator_t;
3936
4037namespace histogram
4138{
42- struct histogram_runtime_tuning_policy
43- {
44- cub::detail::RuntimeHistogramAgentPolicy histogram;
45-
46- auto Histogram () const
47- {
48- return histogram;
49- }
50-
51- CUB_RUNTIME_FUNCTION int BlockThreads () const
52- {
53- return histogram.BlockThreads ();
54- }
55-
56- CUB_RUNTIME_FUNCTION int PixelsPerThread () const
57- {
58- return histogram.PixelsPerThread ();
59- }
60-
61- using HistogramPolicy = cub::detail::RuntimeHistogramAgentPolicy;
62- using MaxPolicy = histogram_runtime_tuning_policy;
63-
64- template <typename F>
65- cudaError_t Invoke (int , F& op)
66- {
67- return op.template Invoke <histogram_runtime_tuning_policy>(*this );
68- }
69- };
70-
7139struct histogram_kernel_source
7240{
7341 cccl_device_histogram_build_result_t & build;
292260 const std::string samples_iterator_src =
293261 make_kernel_input_iterator (offset_cpp, samples_iterator_name, sample_cpp, d_samples);
294262
295- std::string policy_hub_expr = std::format (
296- " cub::detail::histogram::policy_hub<{}, {}, {}, {}, {}>" ,
263+ const bool sample_is_primitive = d_samples.value_type .type != CCCL_STORAGE; // TODO(bgruber): how to check if sample
264+ // is primitive?
265+ const auto policy_sel = cub::detail::histogram::policy_selector{
266+ sample_is_primitive,
267+ static_cast <int >(d_samples.value_type .size ),
268+ static_cast <int >(d_output_histograms.value_type .size ),
269+ static_cast <int >(d_samples.value_type .size ),
270+ num_channels,
271+ num_active_channels,
272+ is_evenly_segmented};
273+
274+ const auto arch_id = cuda::to_arch_id (cuda::compute_capability{cc_major, cc_minor});
275+ const auto active_policy = policy_sel (arch_id);
276+
277+ std::stringstream policy_sel_str;
278+ policy_sel_str << active_policy;
279+
280+ std::string policy_selector_expr = std::format (
281+ " cub::detail::histogram::policy_selector_from_types<{}, {}, {}, {}, {}>" ,
297282 sample_cpp,
298283 counter_cpp,
299284 num_channels,
@@ -311,18 +296,16 @@ struct __align__({1}) storage_t {{
311296 char data[{0}];
312297}};
313298{2}
314- using device_histogram_policy = {3}::MaxPolicy;
315-
316- #include <cub/detail/ptx-json/json.cuh>
317- __device__ consteval auto& policy_generator() {{
318- return ptx_json::id<ptx_json::string("device_histogram_policy")>()
319- = cub::detail::histogram::HistogramPolicyWrapper<device_histogram_policy::ActivePolicy>::EncodedPolicy();
320- }}
299+ using device_histogram_policy = {3};
300+ using namespace cub;
301+ using namespace cub::detail::histogram;
302+ static_assert(device_histogram_policy()(::cuda::arch_id{{CUB_PTX_ARCH / 10}}) == {4}, "Host generated and JIT compiled policy mismatch");
321303)XXX" ,
322304 d_samples.value_type .size , // 0
323305 d_samples.value_type .alignment , // 1
324306 samples_iterator_src, // 2
325- policy_hub_expr // 3
307+ policy_selector_expr, // 3
308+ policy_sel_str.view () // 4
326309 );
327310
328311#if false // CCCL_DEBUGGING_SWITCH
@@ -368,7 +351,6 @@ __device__ consteval auto& policy_generator() {{
368351 " -dlto" ,
369352 " -default-device" ,
370353 " -DCUB_DISABLE_CDP" ,
371- " -DCUB_ENABLE_POLICY_PTX_JSON" ,
372354 " -std=c++20" };
373355
374356 cccl::detail::extend_args_with_build_config (args, config);
@@ -398,12 +380,6 @@ __device__ consteval auto& policy_generator() {{
398380 check (cuLibraryGetKernel (&build_ptr->init_kernel , build_ptr->library , init_kernel_lowered_name.c_str ()));
399381 check (cuLibraryGetKernel (&build_ptr->sweep_kernel , build_ptr->library , sweep_kernel_lowered_name.c_str ()));
400382
401- nlohmann::json runtime_policy =
402- cub::detail::ptx_json::parse (" device_histogram_policy" , {result.data .get (), result.size });
403-
404- using cub::detail::RuntimeHistogramAgentPolicy;
405- auto histogram_policy = RuntimeHistogramAgentPolicy::from_json (runtime_policy, " HistogramPolicy" );
406-
407383 build_ptr->cc = cc;
408384 build_ptr->cubin = (void *) result.data .release ();
409385 build_ptr->cubin_size = result.size ;
@@ -413,7 +389,7 @@ __device__ consteval auto& policy_generator() {{
413389 build_ptr->num_active_channels = num_active_channels;
414390 build_ptr->may_overflow = false ; // This is set in cccl_device_histogram_even_impl so that kernel source can access
415391 // it later.
416- build_ptr->runtime_policy = new histogram::histogram_runtime_tuning_policy{histogram_policy };
392+ build_ptr->runtime_policy = new cub::detail:: histogram::policy_selector{policy_sel };
417393
418394 return CUDA_SUCCESS;
419395}
@@ -477,7 +453,7 @@ CUresult cccl_device_histogram_even_impl(
477453 indirect_arg_t , // CounterT
478454 indirect_arg_t , // LevelT
479455 OffsetT, // OffsetT
480- histogram::histogram_runtime_tuning_policy , // PolicyHub
456+ cub::detail:: histogram::policy_selector , // PolicySelector
481457 indirect_arg_t , // SampleT
482458 histogram::histogram_kernel_source, // KernelSource
483459 cub::detail::CudaDriverLauncherFactory // KernelLauncherFactory
@@ -493,9 +469,9 @@ CUresult cccl_device_histogram_even_impl(
493469 row_stride_samples,
494470 stream,
495471 is_byte_sample{},
472+ *reinterpret_cast <cub::detail::histogram::policy_selector*>(build.runtime_policy ),
496473 {build},
497- cub::detail::CudaDriverLauncherFactory{cu_device, build.cc },
498- *reinterpret_cast <histogram::histogram_runtime_tuning_policy*>(build.runtime_policy ));
474+ cub::detail::CudaDriverLauncherFactory{cu_device, build.cc });
499475
500476 error = static_cast <CUresult>(exec_status);
501477 }
595571 }
596572
597573 std::unique_ptr<char []> cubin (reinterpret_cast <char *>(build_ptr->cubin ));
598- std::unique_ptr<char []> policy (reinterpret_cast <char *>(build_ptr->runtime_policy ));
574+ std::unique_ptr<cub::detail::histogram::policy_selector> policy (
575+ static_cast <cub::detail::histogram::policy_selector*>(build_ptr->runtime_policy ));
599576 check (cuLibraryUnload (build_ptr->library ));
600577
601578 return CUDA_SUCCESS;
0 commit comments