From 09f16c2b36335cb7044d7935054fdb24e71f9263 Mon Sep 17 00:00:00 2001 From: keptsecret Date: Mon, 28 Apr 2025 10:54:49 +0700 Subject: [PATCH 01/25] minor fixes, example --- examples_tests | 2 +- .../builtin/hlsl/workgroup2/arithmetic.hlsl | 36 +++++ .../builtin/hlsl/workgroup2/shared_scan.hlsl | 125 ++++++++++++++++++ 3 files changed, 162 insertions(+), 1 deletion(-) create mode 100644 include/nbl/builtin/hlsl/workgroup2/arithmetic.hlsl create mode 100644 include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl diff --git a/examples_tests b/examples_tests index 8c76367c1c..20011f5fdd 160000 --- a/examples_tests +++ b/examples_tests @@ -1 +1 @@ -Subproject commit 8c76367c1c226cce3d66f1c60f540e29a501a1cb +Subproject commit 20011f5fdd3e8454bb830ded6f4221ec75036809 diff --git a/include/nbl/builtin/hlsl/workgroup2/arithmetic.hlsl b/include/nbl/builtin/hlsl/workgroup2/arithmetic.hlsl new file mode 100644 index 0000000000..dcd2a5df5d --- /dev/null +++ b/include/nbl/builtin/hlsl/workgroup2/arithmetic.hlsl @@ -0,0 +1,36 @@ +// Copyright (C) 2025 - DevSH Graphics Programming Sp. z O.O. +// This file is part of the "Nabla Engine". +// For conditions of distribution and use, see copyright notice in nabla.h +#ifndef _NBL_BUILTIN_HLSL_WORKGROUP2_ARITHMETIC_INCLUDED_ +#define _NBL_BUILTIN_HLSL_WORKGROUP2_ARITHMETIC_INCLUDED_ + + +#include "nbl/builtin/hlsl/functional.hlsl" +#include "nbl/builtin/hlsl/workgroup/ballot.hlsl" +#include "nbl/builtin/hlsl/workgroup/broadcast.hlsl" +#include "nbl/builtin/hlsl/workgroup2/shared_scan.hlsl" + + +namespace nbl +{ +namespace hlsl +{ +namespace workgroup2 +{ + +template +struct reduction +{ + template + static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor) + { + impl::reduce fn; + fn.__call(dataAccessor, scratchAccessor); + } +} + +} +} +} + +#endif diff --git a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl new file mode 100644 index 0000000000..9c2eb164cf --- /dev/null +++ b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl @@ -0,0 +1,125 @@ +// Copyright (C) 2025 - DevSH Graphics Programming Sp. z O.O. +// This file is part of the "Nabla Engine". 
+// For conditions of distribution and use, see copyright notice in nabla.h +#ifndef _NBL_BUILTIN_HLSL_WORKGROUP2_SHARED_SCAN_INCLUDED_ +#define _NBL_BUILTIN_HLSL_WORKGROUP2_SHARED_SCAN_INCLUDED_ + +#include "nbl/builtin/hlsl/cpp_compat.hlsl" +#include "nbl/builtin/hlsl/workgroup/broadcast.hlsl" +#include "nbl/builtin/hlsl/glsl_compat/subgroup_basic.hlsl" +#include "nbl/builtin/hlsl/subgroup/ballot.hlsl" +#include "nbl/builtin/hlsl/subgroup2/arithmetic_portability.hlsl" + +namespace nbl +{ +namespace hlsl +{ +namespace workgroup2 +{ + +template +struct Configuration +{ + NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(_WorkgroupSize); + NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSizeLog2 = uint16_t(_SubgroupSizeLog2); + NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSize = uint16_t(0x1u) << SubgroupSizeLog2; + NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation = uint16_t(_ItemsPerInvocation); + + // must have at least enough level 0 outputs to feed a single subgroup + NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupsPerVirtualWorkgroup = hlsl::max(WorkgroupSize >> SubgroupSizeLog2, SubgroupSize); + NBL_CONSTEXPR_STATIC_INLINE uint32_t VirtualWorkgroupSize = SubgroupsPerVirtualWorkgroup << SubgroupSizeLog2; + NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation[2] = { Config::ItemsPerInvocation, SubgroupsPerVirtualWorkgroup >> SubgroupSizeLog2 }; + static_assert(ItemsPerInvocation[1]<=4, "3 level scan would have been needed with this config!"); +}; + +namespace impl +{ + +template +struct reduce +{ + using scalar_t = typename BinOp::type_t; + using vector_lv0_t = vector; // data accessor needs to be this type + using vector_lv1_t = vector; // scratch smem accessor needs to be this type + + template + void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor) // groupshared vector_lv1_t scratch[Config::SubgroupsPerVirtualWorkgroup] + { + using config_t = subgroup2::Configuration; + using params_lv0_t = subgroup2::ArithmeticParams; + using params_lv1_t = subgroup2::ArithmeticParams; + BinOp binop; + + vector_lv0_t scan_local[Config::VirtualWorkgroupSize / Config::WorkgroupSize]; + const uint32_t invocationIndex = SubgroupContiguousIndex(); + subgroup2::inclusive_scan inclusiveScan0; + // level 0 scan + [unroll] + for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) + { + scan_local[idx] = inclusiveScan0(dataAccessor.get(idx * Config::WorkgroupSize + virtualInvocationIndex)); + if (subgroup::ElectLast()) + { + const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); + scratchAccessor.set(virtualSubgroupID, scan_local[idx][Config::ItemsPerInvocation[0]-1]); // set last element of subgroup scan (reduction) to level 1 scan + } + } + scratchAccessor.workgroupExecutionAndMemoryBarrier(); + + subgroup2::inclusive_scan inclusiveScan1; + // level 1 scan + if (glsl::gl_SubgroupID() == 0) + { + scratchAccessor.set(invocationIndex, inclusiveScan1(scratchAccessor.get(invocationIndex))); + } + scratchAccessor.workgroupExecutionAndMemoryBarrier(); + + // set as last element in scan (reduction) + [unroll] + for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) + { + const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); + dataAccessor.set(idx * Config::WorkgroupSize + 
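// The scans in this header are composed from subgroup2 primitives: inputs stream in through a
// user-supplied DataAccessor and per-subgroup partials are staged through a ScratchAccessor over
// groupshared memory. A minimal sketch of the contract the call sites below rely on (method
// names taken from this file; element type and backing store are merely illustrative):
//   groupshared float32_t scratch[SCRATCH_SIZE]; // SCRATCH_SIZE is a hypothetical constant
//   struct ScratchAccessorSketch
//   {
//       float32_t get(const uint32_t ix) { return scratch[ix]; }
//       void set(const uint32_t ix, const float32_t value) { scratch[ix] = value; }
//       void workgroupExecutionAndMemoryBarrier() { glsl::barrier(); }
//   };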
virtualInvocationIndex, scratchAccessor.get(Config::SubgroupsPerVirtualWorkgroup-1)); + } + } +}; + +template +struct scan +{ + using scalar_t = typename BinOp::type_t; + using vector_lv0_t = vector; // data accessor needs to be this type + using vector_lv1_t = vector; // scratch smem accessor needs to be this type + + template + void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor) // groupshared vector_lv1_t scratch[Config::SubgroupsPerVirtualWorkgroup] + { + // TODO get this working + // same thing for level 0 + + subgroup2::inclusive_scan inclusiveScan1; + // level 1 scan + if (glsl::gl_SubgroupID() == 0) + { + const vector_lv1_t shiftedInput = hlsl::mix(BinOp::identity, scratchAccessor.get(invocationIndex-1), bool(invocationIndex)); + scratchAccessor.set(invocationIndex, inclusiveScan1(shiftedInput)); + } + scratchAccessor.workgroupExecutionAndMemoryBarrier(); + + // combine with level 0 + [unroll] + for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) + { + const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); + dataAccessor.set(idx * Config::WorkgroupSize + virtualInvocationIndex, binop(scratchAccessor.get(virtualSubgroupID), scan_local[idx])); + } + } +}; + +} + +} +} +} + +#endif From 6f5f8b05bc33cc8ea848d3f003bc7218a2d6bbac Mon Sep 17 00:00:00 2001 From: keptsecret Date: Mon, 28 Apr 2025 17:03:39 +0700 Subject: [PATCH 02/25] bug fixes and example --- .../builtin/hlsl/workgroup2/arithmetic.hlsl | 4 +- .../builtin/hlsl/workgroup2/shared_scan.hlsl | 69 ++++++++++--------- 2 files changed, 40 insertions(+), 33 deletions(-) diff --git a/include/nbl/builtin/hlsl/workgroup2/arithmetic.hlsl b/include/nbl/builtin/hlsl/workgroup2/arithmetic.hlsl index dcd2a5df5d..2753344e43 100644 --- a/include/nbl/builtin/hlsl/workgroup2/arithmetic.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/arithmetic.hlsl @@ -25,9 +25,9 @@ struct reduction static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor) { impl::reduce fn; - fn.__call(dataAccessor, scratchAccessor); + fn.template __call(dataAccessor, scratchAccessor); } -} +}; } } diff --git a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl index 9c2eb164cf..7be002e8d3 100644 --- a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl @@ -9,6 +9,7 @@ #include "nbl/builtin/hlsl/glsl_compat/subgroup_basic.hlsl" #include "nbl/builtin/hlsl/subgroup/ballot.hlsl" #include "nbl/builtin/hlsl/subgroup2/arithmetic_portability.hlsl" +#include "nbl/builtin/hlsl/mpl.hlsl" namespace nbl { @@ -23,13 +24,15 @@ struct Configuration NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(_WorkgroupSize); NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSizeLog2 = uint16_t(_SubgroupSizeLog2); NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSize = uint16_t(0x1u) << SubgroupSizeLog2; - NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation = uint16_t(_ItemsPerInvocation); + // NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation = uint16_t(_ItemsPerInvocation); // must have at least enough level 0 outputs to feed a single subgroup - NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupsPerVirtualWorkgroup = hlsl::max(WorkgroupSize >> SubgroupSizeLog2, SubgroupSize); + NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupsPerVirtualWorkgroup = 
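// reduce: after the level-1 pass the last scratch slot holds the reduction of the entire
// virtual workgroup, so this loop is a broadcast rather than a prefix combine, e.g. with
// plus<uint32_t> every output element becomes the total sum of all inputs.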
mpl::max<uint32_t, (WorkgroupSize >> SubgroupSizeLog2), SubgroupSize>::value; //TODO expression not constant apparently
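// worked example: WorkgroupSize=256 with SubgroupSizeLog2=5 (32-wide subgroups) yields
// 256>>5 = 8 level-0 partials, fewer than the 32 a single level-1 subgroup scan consumes,
// so SubgroupsPerVirtualWorkgroup rounds up to 32, VirtualWorkgroupSize = 32<<5 = 1024 and
// each real invocation covers 1024/256 = 4 virtual invocations in the loops below.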
[unroll] - for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) - { - const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); - dataAccessor.set(idx * Config::WorkgroupSize + virtualInvocationIndex, binop(scratchAccessor.get(virtualSubgroupID), scan_local[idx])); - } + // // TODO get this working + // // same thing for level 0 + // using config_t = subgroup2::Configuration; + // using params_lv0_t = subgroup2::ArithmeticParams; + // using params_lv1_t = subgroup2::ArithmeticParams; + // BinOp binop; + + // subgroup2::inclusive_scan inclusiveScan1; + // // level 1 scan + // if (glsl::gl_SubgroupID() == 0) + // { + // const vector_lv1_t shiftedInput = hlsl::mix(BinOp::identity, scratchAccessor.get(invocationIndex-1), bool(invocationIndex)); + // scratchAccessor.set(invocationIndex, inclusiveScan1(shiftedInput)); + // } + // scratchAccessor.workgroupExecutionAndMemoryBarrier(); + + // // combine with level 0 + // [unroll] + // for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) + // { + // const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); + // dataAccessor.set(idx * Config::WorkgroupSize + virtualInvocationIndex, binop(scratchAccessor.get(virtualSubgroupID), scan_local[idx])); + // } } }; From 1bac2478f5f09c05b45fa625c70da6ca44023970 Mon Sep 17 00:00:00 2001 From: keptsecret Date: Tue, 29 Apr 2025 12:05:04 +0700 Subject: [PATCH 03/25] fix to data accessor indexing --- include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl index 7be002e8d3..3cba3a2d57 100644 --- a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl @@ -60,7 +60,7 @@ struct reduce [unroll] for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) { - scan_local[idx] = inclusiveScan0(dataAccessor.get(idx * Config::WorkgroupSize + virtualInvocationIndex)); + scan_local[idx] = inclusiveScan0(dataAccessor.get(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex)); if (subgroup::ElectLast()) { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); @@ -70,6 +70,7 @@ struct reduce scratchAccessor.workgroupExecutionAndMemoryBarrier(); subgroup2::inclusive_scan inclusiveScan1; + // subgroup2::reduction reduce1; // level 1 scan if (glsl::gl_SubgroupID() == 0) { @@ -81,8 +82,8 @@ struct reduce [unroll] for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) { - const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); - dataAccessor.set(idx * Config::WorkgroupSize + virtualInvocationIndex, scratchAccessor.get(Config::SubgroupsPerVirtualWorkgroup-1)); + // const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); + dataAccessor.set(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + 
virtualInvocationIndex, scratchAccessor.get(Config::SubgroupSize-1)); } } }; From 305ac7bd3997f7b491ff9adb30a8f9c8e54ab5ca Mon Sep 17 00:00:00 2001 From: keptsecret Date: Tue, 29 Apr 2025 16:58:04 +0700 Subject: [PATCH 04/25] added template spec for vector dim 1 --- include/nbl/builtin/hlsl/vector_utils/vector_traits.hlsl | 1 + 1 file changed, 1 insertion(+) diff --git a/include/nbl/builtin/hlsl/vector_utils/vector_traits.hlsl b/include/nbl/builtin/hlsl/vector_utils/vector_traits.hlsl index 9aefc3b3d8..652cabd7c7 100644 --- a/include/nbl/builtin/hlsl/vector_utils/vector_traits.hlsl +++ b/include/nbl/builtin/hlsl/vector_utils/vector_traits.hlsl @@ -28,6 +28,7 @@ struct vector_traits >\ NBL_CONSTEXPR_STATIC_INLINE bool IsVector = true;\ };\ +DEFINE_VECTOR_TRAITS_TEMPLATE_SPECIALIZATION(1) DEFINE_VECTOR_TRAITS_TEMPLATE_SPECIALIZATION(2) DEFINE_VECTOR_TRAITS_TEMPLATE_SPECIALIZATION(3) DEFINE_VECTOR_TRAITS_TEMPLATE_SPECIALIZATION(4) From c08063da62a3bed85cb4ff9d59668ed7474604f7 Mon Sep 17 00:00:00 2001 From: keptsecret Date: Tue, 29 Apr 2025 17:03:13 +0700 Subject: [PATCH 05/25] added inclusive scan --- .../builtin/hlsl/workgroup2/arithmetic.hlsl | 11 +++ .../builtin/hlsl/workgroup2/shared_scan.hlsl | 77 +++++++++++-------- 2 files changed, 57 insertions(+), 31 deletions(-) diff --git a/include/nbl/builtin/hlsl/workgroup2/arithmetic.hlsl b/include/nbl/builtin/hlsl/workgroup2/arithmetic.hlsl index 2753344e43..acfa5feba8 100644 --- a/include/nbl/builtin/hlsl/workgroup2/arithmetic.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/arithmetic.hlsl @@ -29,6 +29,17 @@ struct reduction } }; +template +struct inclusive_scan +{ + template + static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor) + { + impl::scan fn; + fn.template __call(dataAccessor, scratchAccessor); + } +}; + } } } diff --git a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl index 3cba3a2d57..6358bf24ad 100644 --- a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl @@ -24,7 +24,6 @@ struct Configuration NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(_WorkgroupSize); NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSizeLog2 = uint16_t(_SubgroupSizeLog2); NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSize = uint16_t(0x1u) << SubgroupSizeLog2; - // NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation = uint16_t(_ItemsPerInvocation); // must have at least enough level 0 outputs to feed a single subgroup NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupsPerVirtualWorkgroup = mpl::max> SubgroupSizeLog2), SubgroupSize>::value; //TODO expression not constant apparently @@ -46,7 +45,7 @@ struct reduce using vector_lv1_t = vector; // scratch smem accessor needs to be this type template - void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor) // groupshared vector_lv1_t scratch[Config::SubgroupsPerVirtualWorkgroup] + void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor) { using config_t = subgroup2::Configuration; using params_lv0_t = subgroup2::ArithmeticParams; @@ -55,8 +54,8 @@ struct reduce vector_lv0_t scan_local[Config::VirtualWorkgroupSize / Config::WorkgroupSize]; const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex(); - subgroup2::inclusive_scan inclusiveScan0; // level 0 scan + subgroup2::inclusive_scan inclusiveScan0; [unroll] for (uint32_t idx = 0, 
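// the gl_WorkGroupID().x * VirtualWorkgroupSize term points the accessor at this workgroup's
// slice of the global buffer (workgroup 1 with VirtualWorkgroupSize=1024 touches elements
// [1024,2048)), and the workgroup total is read from slot SubgroupSize-1, the last lane of
// the single subgroup that ran the level-1 scan.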
virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) { @@ -69,9 +68,8 @@ struct reduce } scratchAccessor.workgroupExecutionAndMemoryBarrier(); - subgroup2::inclusive_scan inclusiveScan1; - // subgroup2::reduction reduce1; // level 1 scan + subgroup2::inclusive_scan inclusiveScan1; if (glsl::gl_SubgroupID() == 0) { scratchAccessor.set(invocationIndex, inclusiveScan1(scratchAccessor.get(invocationIndex))); @@ -82,13 +80,12 @@ struct reduce [unroll] for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) { - // const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); dataAccessor.set(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, scratchAccessor.get(Config::SubgroupSize-1)); } } }; -template +template struct scan { using scalar_t = typename BinOp::type_t; @@ -96,31 +93,49 @@ struct scan using vector_lv1_t = vector; // scratch smem accessor needs to be this type template - void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor) // groupshared vector_lv1_t scratch[Config::SubgroupsPerVirtualWorkgroup] + void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor) { - // // TODO get this working - // // same thing for level 0 - // using config_t = subgroup2::Configuration; - // using params_lv0_t = subgroup2::ArithmeticParams; - // using params_lv1_t = subgroup2::ArithmeticParams; - // BinOp binop; - - // subgroup2::inclusive_scan inclusiveScan1; - // // level 1 scan - // if (glsl::gl_SubgroupID() == 0) - // { - // const vector_lv1_t shiftedInput = hlsl::mix(BinOp::identity, scratchAccessor.get(invocationIndex-1), bool(invocationIndex)); - // scratchAccessor.set(invocationIndex, inclusiveScan1(shiftedInput)); - // } - // scratchAccessor.workgroupExecutionAndMemoryBarrier(); - - // // combine with level 0 - // [unroll] - // for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) - // { - // const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); - // dataAccessor.set(idx * Config::WorkgroupSize + virtualInvocationIndex, binop(scratchAccessor.get(virtualSubgroupID), scan_local[idx])); - // } + using config_t = subgroup2::Configuration; + using params_lv0_t = subgroup2::ArithmeticParams; + using params_lv1_t = subgroup2::ArithmeticParams; + BinOp binop; + + vector_lv0_t scan_local[Config::VirtualWorkgroupSize / Config::WorkgroupSize]; + const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex(); + subgroup2::inclusive_scan inclusiveScan0; + // level 0 scan + [unroll] + for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) + { + scan_local[idx] = inclusiveScan0(dataAccessor.get(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex)); + if (subgroup::ElectLast()) + { + const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); + scratchAccessor.set(virtualSubgroupID, scan_local[idx][Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan + } + } + 
scratchAccessor.workgroupExecutionAndMemoryBarrier(); + + // level 1 scan + subgroup2::inclusive_scan inclusiveScan1; + if (glsl::gl_SubgroupID() == 0) + { + const vector_lv1_t shiftedInput = hlsl::mix(hlsl::promote(BinOp::identity), scratchAccessor.get(invocationIndex-1), bool(invocationIndex)); + scratchAccessor.set(invocationIndex, inclusiveScan1(shiftedInput)); + } + scratchAccessor.workgroupExecutionAndMemoryBarrier(); + + // combine with level 0 + [unroll] + for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) + { + const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); + const vector_lv1_t lhs = scratchAccessor.get(virtualSubgroupID); + [unroll] + for (uint32_t i = 0; i < Config::ItemsPerInvocation_0; i++) + scan_local[idx][i] = binop(lhs, scan_local[idx][i]); + dataAccessor.set(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); + } } }; From b1d804f520eed03d72a1d625bb904e777a34b23a Mon Sep 17 00:00:00 2001 From: keptsecret Date: Wed, 30 Apr 2025 14:08:38 +0700 Subject: [PATCH 06/25] exclusive scan working --- .../builtin/hlsl/workgroup2/arithmetic.hlsl | 11 +++++++++++ .../builtin/hlsl/workgroup2/shared_scan.hlsl | 18 ++++++++++++++---- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/include/nbl/builtin/hlsl/workgroup2/arithmetic.hlsl b/include/nbl/builtin/hlsl/workgroup2/arithmetic.hlsl index acfa5feba8..6824e92afa 100644 --- a/include/nbl/builtin/hlsl/workgroup2/arithmetic.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/arithmetic.hlsl @@ -40,6 +40,17 @@ struct inclusive_scan } }; +template +struct exclusive_scan +{ + template + static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor) + { + impl::scan fn; + fn.template __call(dataAccessor, scratchAccessor); + } +}; + } } } diff --git a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl index 6358bf24ad..331951d3f3 100644 --- a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl @@ -130,10 +130,20 @@ struct scan for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); - const vector_lv1_t lhs = scratchAccessor.get(virtualSubgroupID); - [unroll] - for (uint32_t i = 0; i < Config::ItemsPerInvocation_0; i++) - scan_local[idx][i] = binop(lhs, scan_local[idx][i]); + const vector_lv1_t left = scratchAccessor.get(virtualSubgroupID); + if (Exclusive) + { + scalar_t left_last_elem = hlsl::mix(BinOp::identity, glsl::subgroupShuffleUp(scan_local[idx][Config::ItemsPerInvocation_0-1],1), bool(glsl::gl_SubgroupInvocationID())); + [unroll] + for (uint32_t i = 0; i < Config::ItemsPerInvocation_0; i++) + scan_local[idx][Config::ItemsPerInvocation_0-i-1] = binop(left, hlsl::mix(scan_local[idx][Config::ItemsPerInvocation_0-i-2], left_last_elem, (Config::ItemsPerInvocation_0-i-1==0))); + } + else + { + [unroll] + for (uint32_t i = 0; i < Config::ItemsPerInvocation_0; i++) + scan_local[idx][i] = binop(left, scan_local[idx][i]); + } dataAccessor.set(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, 
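// the execution+memory barrier is mandatory here: the level-1 scan (run by subgroup 0 alone)
// reads partials that every other subgroup wrote to groupshared scratch during level 0.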
scan_local[idx]); } } From 3cf98ab4abe77fecd7a779d58c7f85c42d85251e Mon Sep 17 00:00:00 2001 From: keptsecret Date: Wed, 30 Apr 2025 14:12:55 +0700 Subject: [PATCH 07/25] removed outdated comment --- include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl index 331951d3f3..cd49cb1c1b 100644 --- a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl @@ -26,7 +26,7 @@ struct Configuration NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSize = uint16_t(0x1u) << SubgroupSizeLog2; // must have at least enough level 0 outputs to feed a single subgroup - NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupsPerVirtualWorkgroup = mpl::max> SubgroupSizeLog2), SubgroupSize>::value; //TODO expression not constant apparently + NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupsPerVirtualWorkgroup = mpl::max> SubgroupSizeLog2), SubgroupSize>::value; NBL_CONSTEXPR_STATIC_INLINE uint32_t VirtualWorkgroupSize = SubgroupsPerVirtualWorkgroup << SubgroupSizeLog2; // NBL_CONSTEXPR_STATIC_INLINE uint32_t2 ItemsPerInvocation; TODO? doesn't allow inline definitions for uint32_t2 for some reason, uint32_t[2] as well ; declaring out of line results in not constant expression NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_0 = _ItemsPerInvocation; From 7b310e01f9c4c557dec87555121c3ee7cebed456 Mon Sep 17 00:00:00 2001 From: keptsecret Date: Thu, 1 May 2025 12:18:35 +0700 Subject: [PATCH 08/25] minor changes to config usage --- include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl index cd49cb1c1b..c789c8a482 100644 --- a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl @@ -18,19 +18,20 @@ namespace hlsl namespace workgroup2 { -template +template struct Configuration { - NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(_WorkgroupSize); + NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(0x1u) << WorkgroupSizeLog2; NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSizeLog2 = uint16_t(_SubgroupSizeLog2); NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSize = uint16_t(0x1u) << SubgroupSizeLog2; + static_assert(WorkgroupSizeLog2>=_SubgroupSizeLog2, "WorkgroupSize cannot be smaller than SubgroupSize"); // must have at least enough level 0 outputs to feed a single subgroup - NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupsPerVirtualWorkgroup = mpl::max> SubgroupSizeLog2), SubgroupSize>::value; - NBL_CONSTEXPR_STATIC_INLINE uint32_t VirtualWorkgroupSize = SubgroupsPerVirtualWorkgroup << SubgroupSizeLog2; + NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupsPerVirtualWorkgroupLog2 = mpl::max::value - SubgroupSizeLog2; + NBL_CONSTEXPR_STATIC_INLINE uint32_t VirtualWorkgroupSize = uint32_t(0x1u) << (SubgroupsPerVirtualWorkgroupLog2 + SubgroupSizeLog2); // NBL_CONSTEXPR_STATIC_INLINE uint32_t2 ItemsPerInvocation; TODO? 
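// combine step, illustrated for plus<uint32_t> with ItemsPerInvocation_0=2: a subgroup whose
// level-1 exclusive prefix is P and whose local inclusive scan is [a, a+b] stores [P+a, P+a+b];
// the exclusive variant instead shuffles in the previous lane's last element (identity for
// lane 0) so the stored results end up shifted right by one element overall.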
DXC doesn't allow inline definitions for uint32_t2 (or uint32_t[2]) for some reason; declaring them out of line results in a non-constant expression
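// e.g. WorkgroupSizeLog2=11 with SubgroupSizeLog2=5: SubgroupsPerVirtualWorkgroupLog2 = 6,
// so 64 level-1 partials are scanned by 32 invocations and ItemsPerInvocation_1 = 1<<(6-5) = 2;
// the static_assert below rejects configs that would need more than 4 (i.e. a third level).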
DXC doesn't allow inline definitions for uint32_t2 (or uint32_t[2]) for some reason; declaring them out of line results in a non-constant expression
+ } + } +}; + +// 2-level scans template -struct reduce +struct reduce { using scalar_t = typename BinOp::type_t; using vector_lv0_t = vector; // data accessor needs to be this type @@ -87,7 +150,7 @@ struct reduce }; template -struct scan +struct scan { using scalar_t = typename BinOp::type_t; using vector_lv0_t = vector; // data accessor needs to be this type From 2e5f29f10e53f1f8632e8f45099cece1e4b72601 Mon Sep 17 00:00:00 2001 From: keptsecret Date: Fri, 2 May 2025 09:41:52 +0700 Subject: [PATCH 10/25] fixes to 1 level scans --- include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl index c18c00f83e..0128c3320d 100644 --- a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl @@ -63,8 +63,8 @@ struct reduce subgroup2::reduction reduction; if (glsl::gl_SubgroupID() == 0) { - vector_t value = reduction(dataAccessor.get(glsl::gl_WorkGroupID().x * Config::WorkgroupSize + workgroup::SubgroupContiguousIndex())); - dataAccessor.set(glsl::gl_WorkGroupID().x * Config::WorkgroupSize + workgroup::SubgroupContiguousIndex(), value); // can be safely merged with top line? + vector_t value = reduction(dataAccessor.get(glsl::gl_WorkGroupID().x * Config::SubgroupSize + workgroup::SubgroupContiguousIndex())); + dataAccessor.set(glsl::gl_WorkGroupID().x * Config::SubgroupSize + workgroup::SubgroupContiguousIndex(), value); // can be safely merged with top line? } } }; @@ -88,14 +88,14 @@ struct scan if (Exclusive) { subgroup2::exclusive_scan excl_scan; - value = excl_scan(dataAccessor.get(glsl::gl_WorkGroupID().x * Config::WorkgroupSize + workgroup::SubgroupContiguousIndex())); + value = excl_scan(dataAccessor.get(glsl::gl_WorkGroupID().x * Config::SubgroupSize + workgroup::SubgroupContiguousIndex())); } else { subgroup2::inclusive_scan incl_scan; - value = incl_scan(dataAccessor.get(glsl::gl_WorkGroupID().x * Config::WorkgroupSize + workgroup::SubgroupContiguousIndex())); + value = incl_scan(dataAccessor.get(glsl::gl_WorkGroupID().x * Config::SubgroupSize + workgroup::SubgroupContiguousIndex())); } - dataAccessor.set(glsl::gl_WorkGroupID().x * Config::WorkgroupSize + workgroup::SubgroupContiguousIndex(), value); // can be safely merged with above lines? + dataAccessor.set(glsl::gl_WorkGroupID().x * Config::SubgroupSize + workgroup::SubgroupContiguousIndex(), value); // can be safely merged with above lines? 
} } }; From 054b26916204d3ece92e474cb87ec74ebdead9bb Mon Sep 17 00:00:00 2001 From: keptsecret Date: Fri, 2 May 2025 10:54:33 +0700 Subject: [PATCH 11/25] added handling >1 vectors on level 1 scan (untested) --- include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl index 0128c3320d..b32bc3efde 100644 --- a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl @@ -127,7 +127,7 @@ struct reduce if (subgroup::ElectLast()) { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); - scratchAccessor.set(virtualSubgroupID, scan_local[idx][Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan + scratchAccessor.setByComponent(virtualSubgroupID, scan_local[idx][Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan } } scratchAccessor.workgroupExecutionAndMemoryBarrier(); @@ -144,7 +144,7 @@ struct reduce [unroll] for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) { - dataAccessor.set(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, scratchAccessor.get(Config::SubgroupSize-1)); + dataAccessor.set(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, scratchAccessor.getByComponent((1u << Config::SubgroupsPerVirtualWorkgroupLog2)-1)); } } }; @@ -175,7 +175,7 @@ struct scan if (subgroup::ElectLast()) { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); - scratchAccessor.set(virtualSubgroupID, scan_local[idx][Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan + scratchAccessor.setByComponent(virtualSubgroupID, scan_local[idx][Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan } } scratchAccessor.workgroupExecutionAndMemoryBarrier(); @@ -194,7 +194,7 @@ struct scan for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); - const vector_lv1_t left = scratchAccessor.get(virtualSubgroupID); + const scalar_t left = scratchAccessor.getByComponent(virtualSubgroupID); if (Exclusive) { scalar_t left_last_elem = hlsl::mix(BinOp::identity, glsl::subgroupShuffleUp(scan_local[idx][Config::ItemsPerInvocation_0-1],1), bool(glsl::gl_SubgroupInvocationID())); From 1b5282c8b5c37a3d387ec89ce2c2ea12384c41b7 Mon Sep 17 00:00:00 2001 From: keptsecret Date: Mon, 5 May 2025 17:16:12 +0700 Subject: [PATCH 12/25] move load/store smem into scan funcs, setup config for 3 levels --- .../builtin/hlsl/workgroup2/shared_scan.hlsl | 200 +++++++++++++++++- 1 file changed, 191 insertions(+), 9 deletions(-) diff --git a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl index b32bc3efde..c88694d1ac 100644 --- a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl @@ -18,6 +18,25 @@ namespace 
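// usage sketch for the single-subgroup path (accessor types hypothetical, template argument
// order inferred from arithmetic.hlsl above):
//   using config_t = workgroup2::Configuration<5, 5, 1>; // 32-wide workgroup == one subgroup
//   MyDataAccessor data; MyScratchAccessor scratch;      // scratch goes unused on this path
//   workgroup2::inclusive_scan<config_t, plus<uint32_t> >::__call(data, scratch);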
hlsl namespace workgroup2 { +namespace impl +{ +template +struct virtual_wg_size_log2 +{ + NBL_CONSTEXPR_STATIC_INLINE uint16_t levels = conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2+2),uint16_t,conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2*2+2),uint16_t,3,2>::value,1>::value; + NBL_CONSTEXPR_STATIC_INLINE uint16_t value = mpl::max_v+SubgroupSizeLog2; +}; + +template +struct items_per_invocation +{ + NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocationProductLog2 = mpl::max_v; + NBL_CONSTEXPR_STATIC_INLINE uint16_t value0 = conditional_value::value; + NBL_CONSTEXPR_STATIC_INLINE uint16_t value1 = uint16_t(0x1u) << conditional_value, ItemsPerInvocationProductLog2>::value; + NBL_CONSTEXPR_STATIC_INLINE uint16_t value2 = uint16_t(0x1u) << mpl::max_v; +}; +} + template struct Configuration { @@ -26,17 +45,43 @@ struct Configuration NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSize = uint16_t(0x1u) << SubgroupSizeLog2; static_assert(WorkgroupSizeLog2>=_SubgroupSizeLog2, "WorkgroupSize cannot be smaller than SubgroupSize"); - NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelCount = conditional_value::value; - // must have at least enough level 0 outputs to feed a single subgroup - NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupsPerVirtualWorkgroupLog2 = mpl::max::value - SubgroupSizeLog2; - NBL_CONSTEXPR_STATIC_INLINE uint32_t VirtualWorkgroupSize = uint32_t(0x1u) << (SubgroupsPerVirtualWorkgroupLog2 + SubgroupSizeLog2); + NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupsPerVirtualWorkgroupLog2 = mpl::max_v; + + using virtual_wg_t = impl::virtual_wg_size_log2; + NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelCount = virtual_wg_t::levels; + NBL_CONSTEXPR_STATIC_INLINE uint16_t VirtualWorkgroupSize = uint16_t(0x1u) << virtual_wg_t::value; + using items_per_invoc_t = impl::items_per_invocation; // NBL_CONSTEXPR_STATIC_INLINE uint32_t2 ItemsPerInvocation; TODO? 
DXC doesn't allow inline definitions for uint32_t2 (or uint32_t[2]) for some reason; declaring them out of line results in a non-constant expression
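// level selection from virtual_wg_size_log2 above, e.g. for SubgroupSizeLog2=4:
// WorkgroupSizeLog2 <= 6 -> 1 level (a single subgroup, up to 4 items per invocation),
// 7..10 -> 2 levels, 11 -> 3 levels (the 2048-invocation special case right below).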
vector<scalar_t, Config::ItemsPerInvocation_0>; // data accessor needs to be this type
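// i.e. DataAccessor::get/set traffic whole per-invocation packets (e.g. a float32_t2 when
// ItemsPerInvocation_0=2) while the scratch accessor is now addressed scalar-by-scalar.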
subgroup2::ArithmeticParams; + BinOp binop; + + vector_lv0_t scan_local[Config::VirtualWorkgroupSize / Config::WorkgroupSize]; + const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex(); + // level 0 scan + subgroup2::inclusive_scan inclusiveScan0; + [unroll] + for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) + { + scan_local[idx] = inclusiveScan0(dataAccessor.get(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex)); + if (subgroup::ElectLast()) + { + const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); + scratchAccessor.setByComponent(virtualSubgroupID, scan_local[idx][Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan + } + } + scratchAccessor.workgroupExecutionAndMemoryBarrier(); + + // level 1 scan + subgroup2::inclusive_scan inclusiveScan1; + if (glsl::gl_SubgroupID() < Config::SubgroupSizeLog2*Config::ItemsPerInvocation_1) { scratchAccessor.set(invocationIndex, inclusiveScan1(scratchAccessor.get(invocationIndex))); } scratchAccessor.workgroupExecutionAndMemoryBarrier(); + // level 2 scan + // TODO + subgroup2::inclusive_scan inclusiveScan2; + if (glsl::gl_SubgroupID() == 0) + { + scratchAccessor.set(invocationIndex, inclusiveScan2(scratchAccessor.get(invocationIndex))); + } + scratchAccessor.workgroupExecutionAndMemoryBarrier(); + // set as last element in scan (reduction) [unroll] for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) @@ -150,7 +332,7 @@ struct reduce }; template -struct scan +struct scan { using scalar_t = typename BinOp::type_t; using vector_lv0_t = vector; // data accessor needs to be this type @@ -212,7 +394,7 @@ struct scan } } }; - +*/ } } From c6dc5bc9579877d03f2e1e5531ef527cdd1b4eda Mon Sep 17 00:00:00 2001 From: keptsecret Date: Tue, 6 May 2025 10:52:05 +0700 Subject: [PATCH 13/25] change to use coalesced indexing for 2-level scans --- .../nbl/builtin/hlsl/workgroup2/shared_scan.hlsl | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl index c88694d1ac..26fb969ace 100644 --- a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl @@ -172,7 +172,8 @@ struct reduce if (subgroup::ElectLast()) { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); - scratchAccessor.set(virtualSubgroupID, scan_local[idx][Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan + const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroupLog2 + (virtualSubgroupID/Config::ItemsPerInvocation_1); + scratchAccessor.set(bankedIndex, scan_local[idx][Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan } } scratchAccessor.workgroupExecutionAndMemoryBarrier(); @@ -184,7 +185,7 @@ struct reduce vector_lv1_t lv1_val; [unroll] for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++) - scratchAccessor.get(invocationIndex*Config::ItemsPerInvocation_1+i,lv1_val[i]); + 
scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroupLog2+invocationIndex,lv1_val[i]); lv1_val = inclusiveScan1(lv1_val); scratchAccessor.set(invocationIndex, lv1_val[Config::ItemsPerInvocation_1-1]); } @@ -227,7 +228,8 @@ struct scan if (subgroup::ElectLast()) { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); - scratchAccessor.set(virtualSubgroupID, scan_local[idx][Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan + const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroupLog2 + (virtualSubgroupID/Config::ItemsPerInvocation_1); + scratchAccessor.set(bankedIndex, scan_local[idx][Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan } } scratchAccessor.workgroupExecutionAndMemoryBarrier(); @@ -240,7 +242,7 @@ struct scan const uint32_t prevIndex = invocationIndex-1; [unroll] for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++) - scratchAccessor.get(prevIndex*Config::ItemsPerInvocation_1+i,lv1_val[i]); + scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroupLog2+prevIndex,lv1_val[i]); vector_lv1_t shiftedInput = hlsl::mix(hlsl::promote(BinOp::identity), lv1_val, bool(invocationIndex)); shiftedInput = inclusiveScan1(shiftedInput); scratchAccessor.set(invocationIndex, shiftedInput[Config::ItemsPerInvocation_1-1]); @@ -272,8 +274,7 @@ struct scan } }; -// 2-level scans -/* +// 3-level scans template struct reduce { @@ -394,7 +395,7 @@ struct scan } } }; -*/ + } } From aa0c36c8b48f480325c74334fa2fb8400b1fc76e Mon Sep 17 00:00:00 2001 From: keptsecret Date: Tue, 6 May 2025 14:35:02 +0700 Subject: [PATCH 14/25] added 3-level scans --- .../builtin/hlsl/workgroup2/shared_scan.hlsl | 69 +++++++++++++++---- 1 file changed, 56 insertions(+), 13 deletions(-) diff --git a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl index 26fb969ace..91596bace0 100644 --- a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl @@ -151,7 +151,7 @@ struct reduce { using scalar_t = typename BinOp::type_t; using vector_lv0_t = vector; // data accessor needs to be this type - using vector_lv1_t = vector; // scratch smem accessor needs to be this type + using vector_lv1_t = vector; template void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor) @@ -207,7 +207,7 @@ struct scan { using scalar_t = typename BinOp::type_t; using vector_lv0_t = vector; // data accessor needs to be this type - using vector_lv1_t = vector; // scratch smem accessor needs to be this type + using vector_lv1_t = vector; template void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor) @@ -280,7 +280,8 @@ struct reduce { using scalar_t = typename BinOp::type_t; using vector_lv0_t = vector; // data accessor needs to be this type - using vector_lv1_t = vector; // scratch smem accessor needs to be this type + using vector_lv1_t = vector; + using vector_lv2_t = vector; template void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor) @@ -288,6 +289,7 @@ struct reduce using config_t = subgroup2::Configuration; using params_lv0_t = subgroup2::ArithmeticParams; using params_lv1_t = subgroup2::ArithmeticParams; + using params_lv2_t = subgroup2::ArithmeticParams; BinOp binop; 
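// banked layout example with ItemsPerInvocation_1=2 and stride S: virtual subgroups 0,1,2,3
// land in scratch slots 0, S, 1, S+1, so all component-0 partials are contiguous and the
// per-component gather above reads consecutive addresses across the subgroup (the coalesced
// access this commit is after).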
vector_lv0_t scan_local[Config::VirtualWorkgroupSize / Config::WorkgroupSize]; @@ -301,7 +303,8 @@ struct reduce if (subgroup::ElectLast()) { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); - scratchAccessor.setByComponent(virtualSubgroupID, scan_local[idx][Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan + const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroupLog2 + (virtualSubgroupID/Config::ItemsPerInvocation_1); + scratchAccessor.set(bankedIndex, scan_local[idx][Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan } } scratchAccessor.workgroupExecutionAndMemoryBarrier(); @@ -310,16 +313,29 @@ struct reduce subgroup2::inclusive_scan inclusiveScan1; if (glsl::gl_SubgroupID() < Config::SubgroupSizeLog2*Config::ItemsPerInvocation_1) { - scratchAccessor.set(invocationIndex, inclusiveScan1(scratchAccessor.get(invocationIndex))); + vector_lv1_t lv1_val; + [unroll] + for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++) + scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroupLog2+invocationIndex,lv1_val[i]); + lv1_val = inclusiveScan1(lv1_val); + if (subgroup::ElectLast()) + { + const uint32_t bankedIndex = (invocationIndex & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupsPerVirtualWorkgroupLog2 + (invocationIndex/Config::ItemsPerInvocation_2); + scratchAccessor.set(bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]); + } } scratchAccessor.workgroupExecutionAndMemoryBarrier(); // level 2 scan - // TODO - subgroup2::inclusive_scan inclusiveScan2; + subgroup2::inclusive_scan inclusiveScan2; if (glsl::gl_SubgroupID() == 0) { - scratchAccessor.set(invocationIndex, inclusiveScan2(scratchAccessor.get(invocationIndex))); + vector_lv2_t lv2_val; + [unroll] + for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++) + scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroupLog2+invocationIndex,lv2_val[i]); + lv2_val = inclusiveScan2(lv2_val); + scratchAccessor.set(invocationIndex, lv2_val[Config::ItemsPerInvocation_2-1]); } scratchAccessor.workgroupExecutionAndMemoryBarrier(); @@ -327,7 +343,9 @@ struct reduce [unroll] for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) { - dataAccessor.set(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, scratchAccessor.getByComponent((1u << Config::SubgroupsPerVirtualWorkgroupLog2)-1)); + scalar_t reduce_val; + scratchAccessor.get(Config::SubgroupSize-1,reduce_val); + dataAccessor.set(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, reduce_val); } } }; @@ -358,17 +376,41 @@ struct scan if (subgroup::ElectLast()) { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); - scratchAccessor.setByComponent(virtualSubgroupID, scan_local[idx][Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan + const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroupLog2 + (virtualSubgroupID/Config::ItemsPerInvocation_1); + scratchAccessor.set(bankedIndex, scan_local[idx][Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 
scan } } scratchAccessor.workgroupExecutionAndMemoryBarrier(); // level 1 scan subgroup2::inclusive_scan inclusiveScan1; + if (glsl::gl_SubgroupID() < Config::SubgroupSizeLog2*Config::ItemsPerInvocation_1) + { + vector_lv1_t lv1_val; + [unroll] + for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++) + scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroupLog2+invocationIndex,lv1_val[i]); + lv1_val = inclusiveScan1(lv1_val); + if (subgroup::ElectLast()) + { + const uint32_t bankedIndex = (glsl::gl_SubgroupID() & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupsPerVirtualWorkgroupLog2 + (glsl::gl_SubgroupID()/Config::ItemsPerInvocation_2); + scratchAccessor.set(bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]); + } + } + scratchAccessor.workgroupExecutionAndMemoryBarrier(); + + // level 2 scan + subgroup2::inclusive_scan inclusiveScan2; if (glsl::gl_SubgroupID() == 0) { - const vector_lv1_t shiftedInput = hlsl::mix(hlsl::promote(BinOp::identity), scratchAccessor.get(invocationIndex-1), bool(invocationIndex)); - scratchAccessor.set(invocationIndex, inclusiveScan1(shiftedInput)); + vector_lv2_t lv2_val; + const uint32_t prevIndex = invocationIndex-1; + [unroll] + for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++) + scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroupLog2+prevIndex,lv2_val[i]); + vector_lv2_t shiftedInput = hlsl::mix(hlsl::promote(BinOp::identity), lv2_val, bool(invocationIndex)); + shiftedInput = inclusiveScan2(shiftedInput); + scratchAccessor.set(invocationIndex, shiftedInput[Config::ItemsPerInvocation_2-1]); } scratchAccessor.workgroupExecutionAndMemoryBarrier(); @@ -377,7 +419,8 @@ struct scan for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); - const scalar_t left = scratchAccessor.getByComponent(virtualSubgroupID); + const scalar_t left; + scratchAccessor.get(virtualSubgroupID, left); if (Exclusive) { scalar_t left_last_elem = hlsl::mix(BinOp::identity, glsl::subgroupShuffleUp(scan_local[idx][Config::ItemsPerInvocation_0-1],1), bool(glsl::gl_SubgroupInvocationID())); From 74c359bed10f1a2d3d55b126863f3d962b87826d Mon Sep 17 00:00:00 2001 From: keptsecret Date: Tue, 6 May 2025 14:41:01 +0700 Subject: [PATCH 15/25] minor bug fixes --- include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl index 91596bace0..141deccb7b 100644 --- a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl @@ -355,7 +355,8 @@ struct scan { using scalar_t = typename BinOp::type_t; using vector_lv0_t = vector; // data accessor needs to be this type - using vector_lv1_t = vector; // scratch smem accessor needs to be this type + using vector_lv1_t = vector; + using vector_lv2_t = vector; template void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor) @@ -363,6 +364,7 @@ struct scan using config_t = subgroup2::Configuration; using params_lv0_t = subgroup2::ArithmeticParams; using params_lv1_t = subgroup2::ArithmeticParams; + using params_lv2_t = subgroup2::ArithmeticParams; BinOp binop; vector_lv0_t scan_local[Config::VirtualWorkgroupSize / Config::WorkgroupSize]; From 
ce244e2d24d2da9e79197226799098aaa7675be9 Mon Sep 17 00:00:00 2001 From: keptsecret Date: Wed, 7 May 2025 16:55:34 +0700 Subject: [PATCH 16/25] changes to data accessor usage --- .../builtin/hlsl/workgroup2/shared_scan.hlsl | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl index 141deccb7b..057e9ebd24 100644 --- a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl @@ -108,7 +108,9 @@ struct reduce subgroup2::reduction reduction; if (glsl::gl_SubgroupID() == 0) { - vector_t value = reduction(dataAccessor.get(glsl::gl_WorkGroupID().x * Config::SubgroupSize + workgroup::SubgroupContiguousIndex())); + vector_t value; + dataAccessor.get(glsl::gl_WorkGroupID().x * Config::SubgroupSize + workgroup::SubgroupContiguousIndex(), value); + value = reduction(value); dataAccessor.set(glsl::gl_WorkGroupID().x * Config::SubgroupSize + workgroup::SubgroupContiguousIndex(), value); // can be safely merged with top line? } } @@ -130,15 +132,16 @@ struct scan if (glsl::gl_SubgroupID() == 0) { vector_t value; + dataAccessor.get(glsl::gl_WorkGroupID().x * Config::SubgroupSize + workgroup::SubgroupContiguousIndex(), value); if (Exclusive) { subgroup2::exclusive_scan excl_scan; - value = excl_scan(dataAccessor.get(glsl::gl_WorkGroupID().x * Config::SubgroupSize + workgroup::SubgroupContiguousIndex())); + value = excl_scan(value); } else { subgroup2::inclusive_scan incl_scan; - value = incl_scan(dataAccessor.get(glsl::gl_WorkGroupID().x * Config::SubgroupSize + workgroup::SubgroupContiguousIndex())); + value = incl_scan(value); } dataAccessor.set(glsl::gl_WorkGroupID().x * Config::SubgroupSize + workgroup::SubgroupContiguousIndex(), value); // can be safely merged with above lines? 
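// accessor convention after this change: get(index, out value) / set(index, value);
// a minimal user-side DataAccessor could look like this (hypothetical sketch, the names
// `buffer` and `vector_t` are placeholders and not part of the patch):
//   void get(const uint32_t ix, NBL_REF_ARG(vector_t) v) { v = buffer[ix]; }
//   void set(const uint32_t ix, vector_t v) { buffer[ix] = v; }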
} @@ -168,7 +171,8 @@ struct reduce [unroll] for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) { - scan_local[idx] = inclusiveScan0(dataAccessor.get(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex)); + dataAccessor.get(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); + scan_local[idx] = inclusiveScan0(scan_local[idx]); if (subgroup::ElectLast()) { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); @@ -224,7 +228,8 @@ struct scan [unroll] for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) { - scan_local[idx] = inclusiveScan0(dataAccessor.get(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex)); + dataAccessor.get(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); + scan_local[idx] = inclusiveScan0(scan_local[idx]); if (subgroup::ElectLast()) { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); @@ -299,7 +304,8 @@ struct reduce [unroll] for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) { - scan_local[idx] = inclusiveScan0(dataAccessor.get(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex)); + dataAccessor.get(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); + scan_local[idx] = inclusiveScan0(scan_local[idx]); if (subgroup::ElectLast()) { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); @@ -374,7 +380,8 @@ struct scan [unroll] for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) { - scan_local[idx] = inclusiveScan0(dataAccessor.get(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex)); + dataAccessor.get(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); + scan_local[idx] = inclusiveScan0(scan_local[idx]); if (subgroup::ElectLast()) { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); From 90b19d817b7d5e9651ed755ff503873881e33311 Mon Sep 17 00:00:00 2001 From: keptsecret Date: Thu, 8 May 2025 17:03:47 +0700 Subject: [PATCH 17/25] wg reduction uses reduce instead of scan --- .../builtin/hlsl/workgroup2/shared_scan.hlsl | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl index 057e9ebd24..7ed16faf09 100644 --- a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl @@ -167,12 +167,12 @@ struct reduce vector_lv0_t scan_local[Config::VirtualWorkgroupSize / Config::WorkgroupSize]; const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex(); // level 0 
scan - subgroup2::inclusive_scan inclusiveScan0; + subgroup2::reduction reduction0; [unroll] for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) { dataAccessor.get(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); - scan_local[idx] = inclusiveScan0(scan_local[idx]); + scan_local[idx] = reduction0(scan_local[idx]); if (subgroup::ElectLast()) { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); @@ -183,14 +183,14 @@ struct reduce scratchAccessor.workgroupExecutionAndMemoryBarrier(); // level 1 scan - subgroup2::inclusive_scan inclusiveScan1; + subgroup2::reduction reduction1; if (glsl::gl_SubgroupID() == 0) { vector_lv1_t lv1_val; [unroll] for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++) scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroupLog2+invocationIndex,lv1_val[i]); - lv1_val = inclusiveScan1(lv1_val); + lv1_val = reduction1(lv1_val); scratchAccessor.set(invocationIndex, lv1_val[Config::ItemsPerInvocation_1-1]); } scratchAccessor.workgroupExecutionAndMemoryBarrier(); @@ -200,7 +200,7 @@ struct reduce for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) { scalar_t reduce_val; - scratchAccessor.get(Config::SubgroupSize-1,reduce_val); + scratchAccessor.get(glsl::gl_SubgroupInvocationID(),reduce_val); dataAccessor.set(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, reduce_val); } } @@ -300,12 +300,12 @@ struct reduce vector_lv0_t scan_local[Config::VirtualWorkgroupSize / Config::WorkgroupSize]; const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex(); // level 0 scan - subgroup2::inclusive_scan inclusiveScan0; + subgroup2::reduction reduction0; [unroll] for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) { dataAccessor.get(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); - scan_local[idx] = inclusiveScan0(scan_local[idx]); + scan_local[idx] = reduction0(scan_local[idx]); if (subgroup::ElectLast()) { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); @@ -316,14 +316,14 @@ struct reduce scratchAccessor.workgroupExecutionAndMemoryBarrier(); // level 1 scan - subgroup2::inclusive_scan inclusiveScan1; + subgroup2::reduction reduction1; if (glsl::gl_SubgroupID() < Config::SubgroupSizeLog2*Config::ItemsPerInvocation_1) { vector_lv1_t lv1_val; [unroll] for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++) scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroupLog2+invocationIndex,lv1_val[i]); - lv1_val = inclusiveScan1(lv1_val); + lv1_val = reduction1(lv1_val); if (subgroup::ElectLast()) { const uint32_t bankedIndex = (invocationIndex & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupsPerVirtualWorkgroupLog2 + (invocationIndex/Config::ItemsPerInvocation_2); @@ -333,14 +333,14 @@ struct reduce scratchAccessor.workgroupExecutionAndMemoryBarrier(); // level 2 scan - subgroup2::inclusive_scan inclusiveScan2; + subgroup2::reduction reduction2; if (glsl::gl_SubgroupID() == 0) { vector_lv2_t lv2_val; [unroll] for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++) 
scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroupLog2+invocationIndex,lv2_val[i]); - lv2_val = inclusiveScan2(lv2_val); + lv2_val = reduction2(lv2_val); scratchAccessor.set(invocationIndex, lv2_val[Config::ItemsPerInvocation_2-1]); } scratchAccessor.workgroupExecutionAndMemoryBarrier(); @@ -350,7 +350,7 @@ struct reduce for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) { scalar_t reduce_val; - scratchAccessor.get(Config::SubgroupSize-1,reduce_val); + scratchAccessor.get(glsl::gl_SubgroupInvocationID(),reduce_val); dataAccessor.set(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, reduce_val); } } From d2a16634dc52ecd1271d9a39cb6bcbe3ada2056c Mon Sep 17 00:00:00 2001 From: keptsecret Date: Fri, 9 May 2025 14:03:47 +0700 Subject: [PATCH 18/25] fixes to calculating levels in config --- .../builtin/hlsl/workgroup2/shared_scan.hlsl | 70 +++++++++---------- 1 file changed, 33 insertions(+), 37 deletions(-) diff --git a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl index 7ed16faf09..7ea8d6594b 100644 --- a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl @@ -23,7 +23,7 @@ namespace impl template struct virtual_wg_size_log2 { - NBL_CONSTEXPR_STATIC_INLINE uint16_t levels = conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2+2),uint16_t,conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2*2+2),uint16_t,3,2>::value,1>::value; + NBL_CONSTEXPR_STATIC_INLINE uint16_t levels = conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2),uint16_t,conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2*2+2),uint16_t,3,2>::value,1>::value; NBL_CONSTEXPR_STATIC_INLINE uint16_t value = mpl::max_v+SubgroupSizeLog2; }; @@ -31,7 +31,7 @@ template; - NBL_CONSTEXPR_STATIC_INLINE uint16_t value0 = conditional_value::value; + NBL_CONSTEXPR_STATIC_INLINE uint16_t value0 = BaseItemsPerInvocation; NBL_CONSTEXPR_STATIC_INLINE uint16_t value1 = uint16_t(0x1u) << conditional_value, ItemsPerInvocationProductLog2>::value; NBL_CONSTEXPR_STATIC_INLINE uint16_t value2 = uint16_t(0x1u) << mpl::max_v; }; @@ -47,6 +47,7 @@ struct Configuration // must have at least enough level 0 outputs to feed a single subgroup NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupsPerVirtualWorkgroupLog2 = mpl::max_v; + NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupsPerVirtualWorkgroup = 0x1u << SubgroupsPerVirtualWorkgroupLog2; using virtual_wg_t = impl::virtual_wg_size_log2; NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelCount = virtual_wg_t::levels; @@ -67,8 +68,9 @@ struct Configuration<11, 4, ITEMS_PER_INVOC>\ NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(0x1u) << 11u;\ NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSizeLog2 = uint16_t(4u);\ NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSize = uint16_t(0x1u) << SubgroupSizeLog2;\ - NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupsPerVirtualWorkgroupLog2 = 128u;\ - NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelCount = 3;\ + NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupsPerVirtualWorkgroupLog2 = 7u;\ + NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupsPerVirtualWorkgroup = 128u;\ + NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelCount = 3u;\ NBL_CONSTEXPR_STATIC_INLINE uint16_t VirtualWorkgroupSize = uint16_t(0x1u) << 4096;\ NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_0 = ITEMS_PER_INVOC;\ NBL_CONSTEXPR_STATIC_INLINE uint32_t 
ItemsPerInvocation_1 = 1u;\ @@ -106,13 +108,10 @@ struct reduce using params_t = subgroup2::ArithmeticParams; subgroup2::reduction reduction; - if (glsl::gl_SubgroupID() == 0) - { - vector_t value; - dataAccessor.get(glsl::gl_WorkGroupID().x * Config::SubgroupSize + workgroup::SubgroupContiguousIndex(), value); - value = reduction(value); - dataAccessor.set(glsl::gl_WorkGroupID().x * Config::SubgroupSize + workgroup::SubgroupContiguousIndex(), value); // can be safely merged with top line? - } + vector_t value; + dataAccessor.get(glsl::gl_WorkGroupID().x * Config::SubgroupSize + workgroup::SubgroupContiguousIndex(), value); + value = reduction(value); + dataAccessor.set(glsl::gl_WorkGroupID().x * Config::SubgroupSize + workgroup::SubgroupContiguousIndex(), value); // can be safely merged with top line? } }; @@ -129,22 +128,19 @@ struct scan using config_t = subgroup2::Configuration; using params_t = subgroup2::ArithmeticParams; - if (glsl::gl_SubgroupID() == 0) + vector_t value; + dataAccessor.get(glsl::gl_WorkGroupID().x * Config::SubgroupSize + workgroup::SubgroupContiguousIndex(), value); + if (Exclusive) { - vector_t value; - dataAccessor.get(glsl::gl_WorkGroupID().x * Config::SubgroupSize + workgroup::SubgroupContiguousIndex(), value); - if (Exclusive) - { - subgroup2::exclusive_scan excl_scan; - value = excl_scan(value); - } - else - { - subgroup2::inclusive_scan incl_scan; - value = incl_scan(value); - } - dataAccessor.set(glsl::gl_WorkGroupID().x * Config::SubgroupSize + workgroup::SubgroupContiguousIndex(), value); // can be safely merged with above lines? + subgroup2::exclusive_scan excl_scan; + value = excl_scan(value); + } + else + { + subgroup2::inclusive_scan incl_scan; + value = incl_scan(value); } + dataAccessor.set(glsl::gl_WorkGroupID().x * Config::SubgroupSize + workgroup::SubgroupContiguousIndex(), value); // can be safely merged with above lines? 
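// worked example of the two branches above: for an addition BinOp (identity 0) with per-lane
// inputs (1,2,3,4) across the subgroup, inclusive_scan yields (1,3,6,10) while exclusive_scan
// yields (0,1,3,6), i.e. the same prefix shifted by one with the identity in front;
// this single-subgroup path needs no shared scratch or barriers at all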
} }; @@ -176,7 +172,7 @@ struct reduce if (subgroup::ElectLast()) { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); - const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroupLog2 + (virtualSubgroupID/Config::ItemsPerInvocation_1); + const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1); scratchAccessor.set(bankedIndex, scan_local[idx][Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan } } @@ -189,7 +185,7 @@ struct reduce vector_lv1_t lv1_val; [unroll] for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++) - scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroupLog2+invocationIndex,lv1_val[i]); + scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]); lv1_val = reduction1(lv1_val); scratchAccessor.set(invocationIndex, lv1_val[Config::ItemsPerInvocation_1-1]); } @@ -233,7 +229,7 @@ struct scan if (subgroup::ElectLast()) { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); - const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroupLog2 + (virtualSubgroupID/Config::ItemsPerInvocation_1); + const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1); scratchAccessor.set(bankedIndex, scan_local[idx][Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan } } @@ -247,7 +243,7 @@ struct scan const uint32_t prevIndex = invocationIndex-1; [unroll] for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++) - scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroupLog2+prevIndex,lv1_val[i]); + scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroup+prevIndex,lv1_val[i]); vector_lv1_t shiftedInput = hlsl::mix(hlsl::promote(BinOp::identity), lv1_val, bool(invocationIndex)); shiftedInput = inclusiveScan1(shiftedInput); scratchAccessor.set(invocationIndex, shiftedInput[Config::ItemsPerInvocation_1-1]); @@ -309,7 +305,7 @@ struct reduce if (subgroup::ElectLast()) { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); - const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroupLog2 + (virtualSubgroupID/Config::ItemsPerInvocation_1); + const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1); scratchAccessor.set(bankedIndex, scan_local[idx][Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan } } @@ -322,11 +318,11 @@ struct reduce vector_lv1_t lv1_val; [unroll] for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++) - scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroupLog2+invocationIndex,lv1_val[i]); + scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]); lv1_val = reduction1(lv1_val); if (subgroup::ElectLast()) { - const uint32_t bankedIndex = (invocationIndex & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupsPerVirtualWorkgroupLog2 + 
(invocationIndex/Config::ItemsPerInvocation_2); + const uint32_t bankedIndex = (invocationIndex & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupsPerVirtualWorkgroup + (invocationIndex/Config::ItemsPerInvocation_2); scratchAccessor.set(bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]); } } @@ -339,7 +335,7 @@ struct reduce vector_lv2_t lv2_val; [unroll] for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++) - scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroupLog2+invocationIndex,lv2_val[i]); + scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv2_val[i]); lv2_val = reduction2(lv2_val); scratchAccessor.set(invocationIndex, lv2_val[Config::ItemsPerInvocation_2-1]); } @@ -385,7 +381,7 @@ struct scan if (subgroup::ElectLast()) { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); - const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroupLog2 + (virtualSubgroupID/Config::ItemsPerInvocation_1); + const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1); scratchAccessor.set(bankedIndex, scan_local[idx][Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan } } @@ -398,11 +394,11 @@ struct scan vector_lv1_t lv1_val; [unroll] for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++) - scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroupLog2+invocationIndex,lv1_val[i]); + scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]); lv1_val = inclusiveScan1(lv1_val); if (subgroup::ElectLast()) { - const uint32_t bankedIndex = (glsl::gl_SubgroupID() & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupsPerVirtualWorkgroupLog2 + (glsl::gl_SubgroupID()/Config::ItemsPerInvocation_2); + const uint32_t bankedIndex = (glsl::gl_SubgroupID() & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupsPerVirtualWorkgroup + (glsl::gl_SubgroupID()/Config::ItemsPerInvocation_2); scratchAccessor.set(bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]); } } @@ -416,7 +412,7 @@ struct scan const uint32_t prevIndex = invocationIndex-1; [unroll] for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++) - scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroupLog2+prevIndex,lv2_val[i]); + scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroup+prevIndex,lv2_val[i]); vector_lv2_t shiftedInput = hlsl::mix(hlsl::promote(BinOp::identity), lv2_val, bool(invocationIndex)); shiftedInput = inclusiveScan2(shiftedInput); scratchAccessor.set(invocationIndex, shiftedInput[Config::ItemsPerInvocation_2-1]); From ea39d9e698867a97b0d1f75ff356119d11b12302 Mon Sep 17 00:00:00 2001 From: keptsecret Date: Mon, 12 May 2025 16:17:49 +0700 Subject: [PATCH 19/25] fixes to 3-level scan --- .../builtin/hlsl/workgroup2/shared_scan.hlsl | 23 +++++++++++++++---- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl index 7ea8d6594b..1abd9cccd2 100644 --- a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl @@ -58,6 +58,8 @@ struct Configuration NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_1 = items_per_invoc_t::value1; NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_2 = 
items_per_invoc_t::value2; static_assert(ItemsPerInvocation_1<=4, "3 level scan would have been needed with this config!"); + + NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedMemSize = conditional_value::value + SubgroupsPerVirtualWorkgroup*ItemsPerInvocation_1; }; // special case when workgroup size 2048 and subgroup size 16 needs 3 levels and virtual workgroup size 4096 to get a full subgroup scan each on level 1 and 2 16x16x16=4096 @@ -388,8 +390,9 @@ struct scan scratchAccessor.workgroupExecutionAndMemoryBarrier(); // level 1 scan + const uint32_t lv1_smem_size = Config::SubgroupsPerVirtualWorkgroup*Config::ItemsPerInvocation_1; subgroup2::inclusive_scan inclusiveScan1; - if (glsl::gl_SubgroupID() < Config::SubgroupSizeLog2*Config::ItemsPerInvocation_1) + if (glsl::gl_SubgroupID() < lv1_smem_size) { vector_lv1_t lv1_val; [unroll] @@ -398,8 +401,8 @@ struct scan lv1_val = inclusiveScan1(lv1_val); if (subgroup::ElectLast()) { - const uint32_t bankedIndex = (glsl::gl_SubgroupID() & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupsPerVirtualWorkgroup + (glsl::gl_SubgroupID()/Config::ItemsPerInvocation_2); - scratchAccessor.set(bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]); + const uint32_t bankedIndex = (glsl::gl_SubgroupID() & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupSize + (glsl::gl_SubgroupID()/Config::ItemsPerInvocation_2); + scratchAccessor.set(lv1_smem_size+bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]); } } scratchAccessor.workgroupExecutionAndMemoryBarrier(); @@ -412,10 +415,20 @@ struct scan const uint32_t prevIndex = invocationIndex-1; [unroll] for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++) - scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroup+prevIndex,lv2_val[i]); + scratchAccessor.get(lv1_smem_size+i*Config::SubgroupSize+prevIndex,lv2_val[i]); vector_lv2_t shiftedInput = hlsl::mix(hlsl::promote(BinOp::identity), lv2_val, bool(invocationIndex)); shiftedInput = inclusiveScan2(shiftedInput); - scratchAccessor.set(invocationIndex, shiftedInput[Config::ItemsPerInvocation_2-1]); + + // combine with level 1, only last element of each + [unroll] + for (uint32_t i = 0; i < Config::SubgroupsPerVirtualWorkgroup; i++) + { + scalar_t last_val; + scratchAccessor.get((Config::ItemsPerInvocation_1-1)*Config::SubgroupsPerVirtualWorkgroup+(Config::SubgroupsPerVirtualWorkgroup-1-i),last_val); + scalar_t val = hlsl::mix(hlsl::promote(BinOp::identity), lv2_val, bool(i)); + val = binop(last_val, shiftedInput[Config::ItemsPerInvocation_2-1]); + scratchAccessor.set((Config::ItemsPerInvocation_1-1)*Config::SubgroupsPerVirtualWorkgroup+(Config::SubgroupsPerVirtualWorkgroup-1-i), last_val); + } } scratchAccessor.workgroupExecutionAndMemoryBarrier(); From 1c0e72efdf18c17c474e6494a3850f3f132afbcb Mon Sep 17 00:00:00 2001 From: keptsecret Date: Wed, 14 May 2025 15:28:55 +0700 Subject: [PATCH 20/25] split config into new file --- examples_tests | 2 +- .../nbl/builtin/hlsl/subgroup2/ballot.hlsl | 13 +++ .../nbl/builtin/hlsl/workgroup2/config.hlsl | 88 +++++++++++++++++++ .../builtin/hlsl/workgroup2/shared_scan.hlsl | 86 ++---------------- 4 files changed, 111 insertions(+), 78 deletions(-) create mode 100644 include/nbl/builtin/hlsl/workgroup2/config.hlsl diff --git a/examples_tests b/examples_tests index 20011f5fdd..4a951b307b 160000 --- a/examples_tests +++ b/examples_tests @@ -1 +1 @@ -Subproject commit 20011f5fdd3e8454bb830ded6f4221ec75036809 +Subproject commit 4a951b307b09ecf4a054f7ac27d4dac01f5e8fb9 diff --git 
a/include/nbl/builtin/hlsl/subgroup2/ballot.hlsl b/include/nbl/builtin/hlsl/subgroup2/ballot.hlsl index 724887b995..6c7ec4f593 100644 --- a/include/nbl/builtin/hlsl/subgroup2/ballot.hlsl +++ b/include/nbl/builtin/hlsl/subgroup2/ballot.hlsl @@ -11,6 +11,19 @@ namespace hlsl namespace subgroup2 { +uint32_t LastSubgroupInvocation() +{ + // why this code was wrong before: + // - only compute can use SubgroupID + // - but there's no mapping of InvocationID to SubgroupID and Index + return glsl::subgroupBallotFindMSB(glsl::subgroupBallot(true)); +} + +bool ElectLast() +{ + return glsl::gl_SubgroupInvocationID()==LastSubgroupInvocation(); +} + template struct Configuration { diff --git a/include/nbl/builtin/hlsl/workgroup2/config.hlsl b/include/nbl/builtin/hlsl/workgroup2/config.hlsl new file mode 100644 index 0000000000..7855cc1701 --- /dev/null +++ b/include/nbl/builtin/hlsl/workgroup2/config.hlsl @@ -0,0 +1,88 @@ +// Copyright (C) 2025 - DevSH Graphics Programming Sp. z O.O. +// This file is part of the "Nabla Engine". +// For conditions of distribution and use, see copyright notice in nabla.h +#ifndef _NBL_BUILTIN_HLSL_WORKGROUP2_CONFIG_INCLUDED_ +#define _NBL_BUILTIN_HLSL_WORKGROUP2_CONFIG_INCLUDED_ + +#include "nbl/builtin/hlsl/cpp_compat.hlsl" + +namespace nbl +{ +namespace hlsl +{ +namespace workgroup2 +{ + +namespace impl +{ +template +struct virtual_wg_size_log2 +{ + NBL_CONSTEXPR_STATIC_INLINE uint16_t levels = conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2),uint16_t,conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2*2+2),uint16_t,3,2>::value,1>::value; + NBL_CONSTEXPR_STATIC_INLINE uint16_t value = mpl::max_v+SubgroupSizeLog2; +}; + +template +struct items_per_invocation +{ + NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocationProductLog2 = mpl::max_v; + NBL_CONSTEXPR_STATIC_INLINE uint16_t value0 = BaseItemsPerInvocation; + NBL_CONSTEXPR_STATIC_INLINE uint16_t value1 = uint16_t(0x1u) << conditional_value, ItemsPerInvocationProductLog2>::value; + NBL_CONSTEXPR_STATIC_INLINE uint16_t value2 = uint16_t(0x1u) << mpl::max_v; +}; +} + +template +struct Configuration +{ + NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSizeLog2 = _WorkgroupSizeLog2; + NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(0x1u) << WorkgroupSizeLog2; + NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSizeLog2 = _SubgroupSizeLog2; + NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSize = uint16_t(0x1u) << SubgroupSizeLog2; + static_assert(WorkgroupSizeLog2>=_SubgroupSizeLog2, "WorkgroupSize cannot be smaller than SubgroupSize"); + + // must have at least enough level 0 outputs to feed a single subgroup + NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupsPerVirtualWorkgroupLog2 = mpl::max_v; + NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupsPerVirtualWorkgroup = uint16_t(0x1u) << SubgroupsPerVirtualWorkgroupLog2; + + using virtual_wg_t = impl::virtual_wg_size_log2; + NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelCount = virtual_wg_t::levels; + NBL_CONSTEXPR_STATIC_INLINE uint16_t VirtualWorkgroupSize = uint16_t(0x1u) << virtual_wg_t::value; + using items_per_invoc_t = impl::items_per_invocation; + // NBL_CONSTEXPR_STATIC_INLINE uint32_t2 ItemsPerInvocation; TODO? 
doesn't allow inline definitions for uint32_t2 for some reason, uint32_t[2] as well ; declaring out of line results in not constant expression + NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_0 = items_per_invoc_t::value0; + NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_1 = items_per_invoc_t::value1; + NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_2 = items_per_invoc_t::value2; + static_assert(ItemsPerInvocation_1<=4, "3 level scan would have been needed with this config!"); + + NBL_CONSTEXPR_STATIC_INLINE uint16_t SharedMemSize = conditional_value::value + SubgroupsPerVirtualWorkgroup*ItemsPerInvocation_1; +}; + +// special case when workgroup size 2048 and subgroup size 16 needs 3 levels and virtual workgroup size 4096 to get a full subgroup scan each on level 1 and 2 16x16x16=4096 +// specializing with macros because of DXC bug: https://github.com/microsoft/DirectXShaderCom0piler/issues/7007 +#define SPECIALIZE_CONFIG_CASE_2048_16(ITEMS_PER_INVOC) template<>\ +struct Configuration<11, 4, ITEMS_PER_INVOC>\ +{\ + NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(0x1u) << 11u;\ + NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSizeLog2 = uint16_t(4u);\ + NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSize = uint16_t(0x1u) << SubgroupSizeLog2;\ + NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupsPerVirtualWorkgroupLog2 = 7u;\ + NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupsPerVirtualWorkgroup = 128u;\ + NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelCount = 3u;\ + NBL_CONSTEXPR_STATIC_INLINE uint16_t VirtualWorkgroupSize = uint16_t(0x1u) << 4096;\ + NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_0 = ITEMS_PER_INVOC;\ + NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_1 = 1u;\ + NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_2 = 1u;\ +};\ + +SPECIALIZE_CONFIG_CASE_2048_16(1) +SPECIALIZE_CONFIG_CASE_2048_16(2) +SPECIALIZE_CONFIG_CASE_2048_16(4) + +} +} +} + +#undef SPECIALIZE_CONFIG_CASE_2048_16 + +#endif diff --git a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl index 1abd9cccd2..b03120b5f6 100644 --- a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl @@ -4,88 +4,20 @@ #ifndef _NBL_BUILTIN_HLSL_WORKGROUP2_SHARED_SCAN_INCLUDED_ #define _NBL_BUILTIN_HLSL_WORKGROUP2_SHARED_SCAN_INCLUDED_ -#include "nbl/builtin/hlsl/cpp_compat.hlsl" #include "nbl/builtin/hlsl/workgroup/broadcast.hlsl" #include "nbl/builtin/hlsl/glsl_compat/subgroup_basic.hlsl" -#include "nbl/builtin/hlsl/subgroup/ballot.hlsl" +#include "nbl/builtin/hlsl/subgroup2/ballot.hlsl" #include "nbl/builtin/hlsl/subgroup2/arithmetic_portability.hlsl" #include "nbl/builtin/hlsl/mpl.hlsl" +#include "nbl/builtin/hlsl/workgroup2/config.hlsl" -namespace nbl +namespace nbl { namespace hlsl { namespace workgroup2 { -namespace impl -{ -template -struct virtual_wg_size_log2 -{ - NBL_CONSTEXPR_STATIC_INLINE uint16_t levels = conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2),uint16_t,conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2*2+2),uint16_t,3,2>::value,1>::value; - NBL_CONSTEXPR_STATIC_INLINE uint16_t value = mpl::max_v+SubgroupSizeLog2; -}; - -template -struct items_per_invocation -{ - NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocationProductLog2 = mpl::max_v; - NBL_CONSTEXPR_STATIC_INLINE uint16_t value0 = BaseItemsPerInvocation; - NBL_CONSTEXPR_STATIC_INLINE uint16_t value1 = uint16_t(0x1u) << conditional_value, ItemsPerInvocationProductLog2>::value; - 
NBL_CONSTEXPR_STATIC_INLINE uint16_t value2 = uint16_t(0x1u) << mpl::max_v; -}; -} - -template -struct Configuration -{ - NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(0x1u) << WorkgroupSizeLog2; - NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSizeLog2 = uint16_t(_SubgroupSizeLog2); - NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSize = uint16_t(0x1u) << SubgroupSizeLog2; - static_assert(WorkgroupSizeLog2>=_SubgroupSizeLog2, "WorkgroupSize cannot be smaller than SubgroupSize"); - - // must have at least enough level 0 outputs to feed a single subgroup - NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupsPerVirtualWorkgroupLog2 = mpl::max_v; - NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupsPerVirtualWorkgroup = 0x1u << SubgroupsPerVirtualWorkgroupLog2; - - using virtual_wg_t = impl::virtual_wg_size_log2; - NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelCount = virtual_wg_t::levels; - NBL_CONSTEXPR_STATIC_INLINE uint16_t VirtualWorkgroupSize = uint16_t(0x1u) << virtual_wg_t::value; - using items_per_invoc_t = impl::items_per_invocation; - // NBL_CONSTEXPR_STATIC_INLINE uint32_t2 ItemsPerInvocation; TODO? doesn't allow inline definitions for uint32_t2 for some reason, uint32_t[2] as well ; declaring out of line results in not constant expression - NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_0 = items_per_invoc_t::value0; - NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_1 = items_per_invoc_t::value1; - NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_2 = items_per_invoc_t::value2; - static_assert(ItemsPerInvocation_1<=4, "3 level scan would have been needed with this config!"); - - NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedMemSize = conditional_value::value + SubgroupsPerVirtualWorkgroup*ItemsPerInvocation_1; -}; - -// special case when workgroup size 2048 and subgroup size 16 needs 3 levels and virtual workgroup size 4096 to get a full subgroup scan each on level 1 and 2 16x16x16=4096 -// specializing with macros because of DXC bug: https://github.com/microsoft/DirectXShaderCom0piler/issues/7007 -#define SPECIALIZE_CONFIG_CASE_2048_16(ITEMS_PER_INVOC) template<>\ -struct Configuration<11, 4, ITEMS_PER_INVOC>\ -{\ - NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(0x1u) << 11u;\ - NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSizeLog2 = uint16_t(4u);\ - NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSize = uint16_t(0x1u) << SubgroupSizeLog2;\ - NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupsPerVirtualWorkgroupLog2 = 7u;\ - NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupsPerVirtualWorkgroup = 128u;\ - NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelCount = 3u;\ - NBL_CONSTEXPR_STATIC_INLINE uint16_t VirtualWorkgroupSize = uint16_t(0x1u) << 4096;\ - NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_0 = ITEMS_PER_INVOC;\ - NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_1 = 1u;\ - NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation_2 = 1u;\ -};\ - -SPECIALIZE_CONFIG_CASE_2048_16(1) -SPECIALIZE_CONFIG_CASE_2048_16(2) -SPECIALIZE_CONFIG_CASE_2048_16(4) - -#undef SPECIALIZE_CONFIG_CASE_2048_16 - - namespace impl { @@ -171,7 +103,7 @@ struct reduce { dataAccessor.get(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); scan_local[idx] = reduction0(scan_local[idx]); - if (subgroup::ElectLast()) + if (subgroup2::ElectLast()) { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); const uint32_t bankedIndex = 
(virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1); @@ -228,7 +160,7 @@ struct scan { dataAccessor.get(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); scan_local[idx] = inclusiveScan0(scan_local[idx]); - if (subgroup::ElectLast()) + if (subgroup2::ElectLast()) { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1); @@ -304,7 +236,7 @@ struct reduce { dataAccessor.get(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); scan_local[idx] = reduction0(scan_local[idx]); - if (subgroup::ElectLast()) + if (subgroup2::ElectLast()) { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1); @@ -322,7 +254,7 @@ struct reduce for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++) scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]); lv1_val = reduction1(lv1_val); - if (subgroup::ElectLast()) + if (subgroup2::ElectLast()) { const uint32_t bankedIndex = (invocationIndex & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupsPerVirtualWorkgroup + (invocationIndex/Config::ItemsPerInvocation_2); scratchAccessor.set(bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]); @@ -380,7 +312,7 @@ struct scan { dataAccessor.get(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); scan_local[idx] = inclusiveScan0(scan_local[idx]); - if (subgroup::ElectLast()) + if (subgroup2::ElectLast()) { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1); @@ -399,7 +331,7 @@ struct scan for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++) scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]); lv1_val = inclusiveScan1(lv1_val); - if (subgroup::ElectLast()) + if (subgroup2::ElectLast()) { const uint32_t bankedIndex = (glsl::gl_SubgroupID() & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupSize + (glsl::gl_SubgroupID()/Config::ItemsPerInvocation_2); scratchAccessor.set(lv1_smem_size+bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]); From 507904f462c9fe50928b198ca2aabd7fa5c8b460 Mon Sep 17 00:00:00 2001 From: keptsecret Date: Thu, 15 May 2025 10:38:03 +0700 Subject: [PATCH 21/25] minor fixes --- examples_tests | 2 +- include/nbl/builtin/hlsl/subgroup2/ballot.hlsl | 9 +++++---- .../{config.hlsl => arithmetic_config.hlsl} | 8 ++++---- include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl | 12 ++++++------ 4 files changed, 16 insertions(+), 15 deletions(-) rename include/nbl/builtin/hlsl/workgroup2/{config.hlsl => arithmetic_config.hlsl} (95%) diff --git a/examples_tests b/examples_tests index a42a742f36..908abd110c 160000 --- 
a/examples_tests +++ b/examples_tests @@ -1 +1 @@ -Subproject commit a42a742f363bda827991794053fb93fd803023f1 +Subproject commit 908abd110c387d48110ce8aeb67f0e0f2dd68943 diff --git a/include/nbl/builtin/hlsl/subgroup2/ballot.hlsl b/include/nbl/builtin/hlsl/subgroup2/ballot.hlsl index 6c7ec4f593..52ae6de2d9 100644 --- a/include/nbl/builtin/hlsl/subgroup2/ballot.hlsl +++ b/include/nbl/builtin/hlsl/subgroup2/ballot.hlsl @@ -11,12 +11,13 @@ namespace hlsl namespace subgroup2 { +template uint32_t LastSubgroupInvocation() { - // why this code was wrong before: - // - only compute can use SubgroupID - // - but there's no mapping of InvocationID to SubgroupID and Index - return glsl::subgroupBallotFindMSB(glsl::subgroupBallot(true)); + if (AssumeAllActive) + return glsl::gl_SubgroupSize()-1; + else + return glsl::subgroupBallotFindMSB(glsl::subgroupBallot(true)); } bool ElectLast() diff --git a/include/nbl/builtin/hlsl/workgroup2/config.hlsl b/include/nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl similarity index 95% rename from include/nbl/builtin/hlsl/workgroup2/config.hlsl rename to include/nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl index 7855cc1701..2f24c863da 100644 --- a/include/nbl/builtin/hlsl/workgroup2/config.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl @@ -1,8 +1,8 @@ // Copyright (C) 2025 - DevSH Graphics Programming Sp. z O.O. // This file is part of the "Nabla Engine". // For conditions of distribution and use, see copyright notice in nabla.h -#ifndef _NBL_BUILTIN_HLSL_WORKGROUP2_CONFIG_INCLUDED_ -#define _NBL_BUILTIN_HLSL_WORKGROUP2_CONFIG_INCLUDED_ +#ifndef _NBL_BUILTIN_HLSL_WORKGROUP2_ARITHMETIC_CONFIG_INCLUDED_ +#define _NBL_BUILTIN_HLSL_WORKGROUP2_ARITHMETIC_CONFIG_INCLUDED_ #include "nbl/builtin/hlsl/cpp_compat.hlsl" @@ -33,7 +33,7 @@ struct items_per_invocation } template -struct Configuration +struct ArithmeticConfiguration { NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSizeLog2 = _WorkgroupSizeLog2; NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(0x1u) << WorkgroupSizeLog2; @@ -61,7 +61,7 @@ struct Configuration // special case when workgroup size 2048 and subgroup size 16 needs 3 levels and virtual workgroup size 4096 to get a full subgroup scan each on level 1 and 2 16x16x16=4096 // specializing with macros because of DXC bug: https://github.com/microsoft/DirectXShaderCom0piler/issues/7007 #define SPECIALIZE_CONFIG_CASE_2048_16(ITEMS_PER_INVOC) template<>\ -struct Configuration<11, 4, ITEMS_PER_INVOC>\ +struct ArithmeticConfiguration<11, 4, ITEMS_PER_INVOC>\ {\ NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(0x1u) << 11u;\ NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSizeLog2 = uint16_t(4u);\ diff --git a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl index b03120b5f6..681ba39911 100644 --- a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl @@ -103,7 +103,7 @@ struct reduce { dataAccessor.get(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); scan_local[idx] = reduction0(scan_local[idx]); - if (subgroup2::ElectLast()) + if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1) { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * 
Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1);
@@ -160,7 +160,7 @@ struct scan
     {
         dataAccessor.get(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]);
         scan_local[idx] = inclusiveScan0(scan_local[idx]);
-        if (subgroup2::ElectLast())
+        if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1)
         {
             const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID();
             const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1);
@@ -236,7 +236,7 @@ struct reduce
     {
         dataAccessor.get(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]);
         scan_local[idx] = reduction0(scan_local[idx]);
-        if (subgroup2::ElectLast())
+        if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1)
         {
             const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID();
             const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1);
@@ -254,7 +254,7 @@ struct reduce
         for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
             scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]);
         lv1_val = reduction1(lv1_val);
-        if (subgroup2::ElectLast())
+        if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1)
         {
             const uint32_t bankedIndex = (invocationIndex & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupsPerVirtualWorkgroup + (invocationIndex/Config::ItemsPerInvocation_2);
             scratchAccessor.set(bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]);
@@ -312,7 +312,7 @@ struct scan
     {
         dataAccessor.get(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]);
         scan_local[idx] = inclusiveScan0(scan_local[idx]);
-        if (subgroup2::ElectLast())
+        if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1)
         {
             const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID();
             const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1);
@@ -331,7 +331,7 @@ struct scan
         for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
             scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]);
         lv1_val = inclusiveScan1(lv1_val);
-        if (subgroup2::ElectLast())
+        if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1)
         {
             const uint32_t bankedIndex = (glsl::gl_SubgroupID() & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupSize + (glsl::gl_SubgroupID()/Config::ItemsPerInvocation_2);
             scratchAccessor.set(lv1_smem_size+bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]);

From 542592f7c5926f601351bb1872d65e171b742440 Mon Sep 17 00:00:00 2001
From: keptsecret
Date: Thu, 15 May 2025 14:44:10 +0700
Subject: [PATCH 22/25] some changes to arithmetic config

---
 examples_tests                                  |  2 +-
 .../hlsl/workgroup2/arithmetic_config.hlsl      | 46 +++++++++----------
 .../builtin/hlsl/workgroup2/shared_scan.hlsl    |  2 +-
 3 files changed, 23 insertions(+), 27 deletions(-)

diff --git a/examples_tests b/examples_tests
index 908abd110c..81238adaec 160000
--- a/examples_tests
+++ 
b/examples_tests @@ -1 +1 @@ -Subproject commit 908abd110c387d48110ce8aeb67f0e0f2dd68943 +Subproject commit 81238adaecbd8d717bdab0dd73e08e2938a794c6 diff --git a/include/nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl b/include/nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl index 2f24c863da..d0800d6996 100644 --- a/include/nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl @@ -18,6 +18,8 @@ namespace impl template struct virtual_wg_size_log2 { + static_assert(WorkgroupSizeLog2>=SubgroupSizeLog2, "WorkgroupSize cannot be smaller than SubgroupSize"); + static_assert(WorkgroupSizeLog2<=SubgroupSizeLog2+4, "WorkgroupSize cannot be larger than SubgroupSize*16"); NBL_CONSTEXPR_STATIC_INLINE uint16_t levels = conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2),uint16_t,conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2*2+2),uint16_t,3,2>::value,1>::value; NBL_CONSTEXPR_STATIC_INLINE uint16_t value = mpl::max_v+SubgroupSizeLog2; }; @@ -30,6 +32,24 @@ struct items_per_invocation NBL_CONSTEXPR_STATIC_INLINE uint16_t value1 = uint16_t(0x1u) << conditional_value, ItemsPerInvocationProductLog2>::value; NBL_CONSTEXPR_STATIC_INLINE uint16_t value2 = uint16_t(0x1u) << mpl::max_v; }; + +// explicit specializations for cases that don't fit +#define SPECIALIZE_VIRTUAL_WG_SIZE_CASE(WGLOG2, SGLOG2, LEVELS, VALUE) template<>\ +struct virtual_wg_size_log2\ +{\ + NBL_CONSTEXPR_STATIC_INLINE uint16_t levels = LEVELS;\ + NBL_CONSTEXPR_STATIC_INLINE uint16_t value = VALUE;\ +};\ + +SPECIALIZE_VIRTUAL_WG_SIZE_CASE(11,4,3,12); +SPECIALIZE_VIRTUAL_WG_SIZE_CASE(7,7,1,7); +SPECIALIZE_VIRTUAL_WG_SIZE_CASE(6,6,1,6); +SPECIALIZE_VIRTUAL_WG_SIZE_CASE(5,5,1,5); +SPECIALIZE_VIRTUAL_WG_SIZE_CASE(4,4,1,4); +SPECIALIZE_VIRTUAL_WG_SIZE_CASE(3,3,1,3); +SPECIALIZE_VIRTUAL_WG_SIZE_CASE(2,2,1,2); + +#undef SPECIALIZE_VIRTUAL_WG_SIZE_CASE } template @@ -39,7 +59,6 @@ struct ArithmeticConfiguration NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(0x1u) << WorkgroupSizeLog2; NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSizeLog2 = _SubgroupSizeLog2; NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSize = uint16_t(0x1u) << SubgroupSizeLog2; - static_assert(WorkgroupSizeLog2>=_SubgroupSizeLog2, "WorkgroupSize cannot be smaller than SubgroupSize"); // must have at least enough level 0 outputs to feed a single subgroup NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupsPerVirtualWorkgroupLog2 = mpl::max_v; @@ -55,34 +74,11 @@ struct ArithmeticConfiguration NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_2 = items_per_invoc_t::value2; static_assert(ItemsPerInvocation_1<=4, "3 level scan would have been needed with this config!"); - NBL_CONSTEXPR_STATIC_INLINE uint16_t SharedMemSize = conditional_value::value + SubgroupsPerVirtualWorkgroup*ItemsPerInvocation_1; + NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementCount = conditional_value::value + SubgroupSize*ItemsPerInvocation_1>::value; }; -// special case when workgroup size 2048 and subgroup size 16 needs 3 levels and virtual workgroup size 4096 to get a full subgroup scan each on level 1 and 2 16x16x16=4096 -// specializing with macros because of DXC bug: https://github.com/microsoft/DirectXShaderCom0piler/issues/7007 -#define SPECIALIZE_CONFIG_CASE_2048_16(ITEMS_PER_INVOC) template<>\ -struct ArithmeticConfiguration<11, 4, ITEMS_PER_INVOC>\ -{\ - NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(0x1u) << 11u;\ - NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSizeLog2 = uint16_t(4u);\ - 
NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSize = uint16_t(0x1u) << SubgroupSizeLog2;\ - NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupsPerVirtualWorkgroupLog2 = 7u;\ - NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupsPerVirtualWorkgroup = 128u;\ - NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelCount = 3u;\ - NBL_CONSTEXPR_STATIC_INLINE uint16_t VirtualWorkgroupSize = uint16_t(0x1u) << 4096;\ - NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_0 = ITEMS_PER_INVOC;\ - NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_1 = 1u;\ - NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_2 = 1u;\ -};\ - -SPECIALIZE_CONFIG_CASE_2048_16(1) -SPECIALIZE_CONFIG_CASE_2048_16(2) -SPECIALIZE_CONFIG_CASE_2048_16(4) - } } } -#undef SPECIALIZE_CONFIG_CASE_2048_16 - #endif diff --git a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl index 681ba39911..461b685c99 100644 --- a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl @@ -9,7 +9,7 @@ #include "nbl/builtin/hlsl/subgroup2/ballot.hlsl" #include "nbl/builtin/hlsl/subgroup2/arithmetic_portability.hlsl" #include "nbl/builtin/hlsl/mpl.hlsl" -#include "nbl/builtin/hlsl/workgroup2/config.hlsl" +#include "nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl" namespace nbl { From a9930a025b4b252c1a08c4abc59cd1652cb666ac Mon Sep 17 00:00:00 2001 From: keptsecret Date: Thu, 15 May 2025 16:00:34 +0700 Subject: [PATCH 23/25] removed referencing workgroupID in scans --- examples_tests | 2 +- .../hlsl/workgroup2/arithmetic_config.hlsl | 10 ++++++++ .../builtin/hlsl/workgroup2/shared_scan.hlsl | 24 +++++++++---------- 3 files changed, 23 insertions(+), 13 deletions(-) diff --git a/examples_tests b/examples_tests index 81238adaec..1de31ddfd7 160000 --- a/examples_tests +++ b/examples_tests @@ -1 +1 @@ -Subproject commit 81238adaecbd8d717bdab0dd73e08e2938a794c6 +Subproject commit 1de31ddfd725009bd650f1fe80f1c4a8c2e6a14a diff --git a/include/nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl b/include/nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl index d0800d6996..88ff328e05 100644 --- a/include/nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl @@ -77,6 +77,16 @@ struct ArithmeticConfiguration NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementCount = conditional_value::value + SubgroupSize*ItemsPerInvocation_1>::value; }; +template +struct is_configuration : bool_constant {}; + +template +struct is_configuration > : bool_constant {}; + +template +NBL_CONSTEXPR bool is_configuration_v = is_configuration::value; + + } } } diff --git a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl index 461b685c99..1043decd73 100644 --- a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl @@ -43,9 +43,9 @@ struct reduce subgroup2::reduction reduction; vector_t value; - dataAccessor.get(glsl::gl_WorkGroupID().x * Config::SubgroupSize + workgroup::SubgroupContiguousIndex(), value); + dataAccessor.get(workgroup::SubgroupContiguousIndex(), value); value = reduction(value); - dataAccessor.set(glsl::gl_WorkGroupID().x * Config::SubgroupSize + workgroup::SubgroupContiguousIndex(), value); // can be safely merged with top line? 
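// (from here on the glsl::gl_WorkGroupID().x offset moves out of the scan structs and into the
// accessor; to reproduce the old addressing a data accessor could do, hypothetically,
//   v = buffer[glsl::gl_WorkGroupID().x*Config::SubgroupSize + ix];
// keeping reduce/scan themselves dispatch-agnostic)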
+ dataAccessor.set(workgroup::SubgroupContiguousIndex(), value); } }; @@ -63,7 +63,7 @@ struct scan using params_t = subgroup2::ArithmeticParams; vector_t value; - dataAccessor.get(glsl::gl_WorkGroupID().x * Config::SubgroupSize + workgroup::SubgroupContiguousIndex(), value); + dataAccessor.get(workgroup::SubgroupContiguousIndex(), value); if (Exclusive) { subgroup2::exclusive_scan excl_scan; @@ -74,7 +74,7 @@ struct scan subgroup2::inclusive_scan incl_scan; value = incl_scan(value); } - dataAccessor.set(glsl::gl_WorkGroupID().x * Config::SubgroupSize + workgroup::SubgroupContiguousIndex(), value); // can be safely merged with above lines? + dataAccessor.set(workgroup::SubgroupContiguousIndex(), value); // can be safely merged with above lines? } }; @@ -101,7 +101,7 @@ struct reduce [unroll] for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) { - dataAccessor.get(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); + dataAccessor.get(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); scan_local[idx] = reduction0(scan_local[idx]); if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1) { @@ -131,7 +131,7 @@ struct reduce { scalar_t reduce_val; scratchAccessor.get(glsl::gl_SubgroupInvocationID(),reduce_val); - dataAccessor.set(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, reduce_val); + dataAccessor.set(idx * Config::WorkgroupSize + virtualInvocationIndex, reduce_val); } } }; @@ -158,7 +158,7 @@ struct scan [unroll] for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) { - dataAccessor.get(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); + dataAccessor.get(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); scan_local[idx] = inclusiveScan0(scan_local[idx]); if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1) { @@ -204,7 +204,7 @@ struct scan for (uint32_t i = 0; i < Config::ItemsPerInvocation_0; i++) scan_local[idx][i] = binop(left, scan_local[idx][i]); } - dataAccessor.set(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); + dataAccessor.set(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); } } }; @@ -234,7 +234,7 @@ struct reduce [unroll] for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) { - dataAccessor.get(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); + dataAccessor.get(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); scan_local[idx] = reduction0(scan_local[idx]); if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1) { @@ -281,7 +281,7 @@ struct reduce { scalar_t reduce_val; scratchAccessor.get(glsl::gl_SubgroupInvocationID(),reduce_val); - dataAccessor.set(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, reduce_val); + dataAccessor.set(idx * Config::WorkgroupSize + virtualInvocationIndex, reduce_val); } } }; @@ -310,7 +310,7 @@ struct scan [unroll] for (uint32_t idx = 0, virtualInvocationIndex = 
invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) { - dataAccessor.get(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); + dataAccessor.get(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); scan_local[idx] = inclusiveScan0(scan_local[idx]); if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1) { @@ -384,7 +384,7 @@ struct scan for (uint32_t i = 0; i < Config::ItemsPerInvocation_0; i++) scan_local[idx][i] = binop(left, scan_local[idx][i]); } - dataAccessor.set(glsl::gl_WorkGroupID().x * Config::VirtualWorkgroupSize + idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); + dataAccessor.set(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); } } }; From 55d89c5c2e3be03e178af923f0b70dc3420f63d4 Mon Sep 17 00:00:00 2001 From: keptsecret Date: Fri, 16 May 2025 10:09:41 +0700 Subject: [PATCH 24/25] no need to store locals in reduce --- .../nbl/builtin/hlsl/workgroup2/shared_scan.hlsl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl index 1043decd73..add3acc687 100644 --- a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl @@ -94,20 +94,20 @@ struct reduce using params_lv1_t = subgroup2::ArithmeticParams; BinOp binop; - vector_lv0_t scan_local[Config::VirtualWorkgroupSize / Config::WorkgroupSize]; const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex(); // level 0 scan subgroup2::reduction reduction0; [unroll] for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) { - dataAccessor.get(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); - scan_local[idx] = reduction0(scan_local[idx]); + vector_lv0_t scan_local; + dataAccessor.get(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local); + scan_local = reduction0(scan_local); if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1) { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1); - scratchAccessor.set(bankedIndex, scan_local[idx][Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan + scratchAccessor.set(bankedIndex, scan_local[Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan } } scratchAccessor.workgroupExecutionAndMemoryBarrier(); @@ -227,20 +227,20 @@ struct reduce using params_lv2_t = subgroup2::ArithmeticParams; BinOp binop; - vector_lv0_t scan_local[Config::VirtualWorkgroupSize / Config::WorkgroupSize]; const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex(); // level 0 scan subgroup2::reduction reduction0; [unroll] for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) { - dataAccessor.get(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); - scan_local[idx] = reduction0(scan_local[idx]); + vector_lv0_t scan_local; + dataAccessor.get(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local); 
+ scan_local = reduction0(scan_local); if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1) { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1); - scratchAccessor.set(bankedIndex, scan_local[idx][Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan + scratchAccessor.set(bankedIndex, scan_local[Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan } } scratchAccessor.workgroupExecutionAndMemoryBarrier(); From 4e4f26e994a2ca5c5009ba3768b0121b627f50bd Mon Sep 17 00:00:00 2001 From: keptsecret Date: Fri, 16 May 2025 11:18:51 +0700 Subject: [PATCH 25/25] added workgroup accessor concepts, refactor accessor usage --- examples_tests | 2 +- .../accessors/workgroup_arithmetic.hlsl | 57 ++++++++++++++++ .../builtin/hlsl/workgroup2/arithmetic.hlsl | 7 +- .../builtin/hlsl/workgroup2/shared_scan.hlsl | 66 +++++++++---------- src/nbl/builtin/CMakeLists.txt | 9 +++ 5 files changed, 104 insertions(+), 37 deletions(-) create mode 100644 include/nbl/builtin/hlsl/concepts/accessors/workgroup_arithmetic.hlsl diff --git a/examples_tests b/examples_tests index 1de31ddfd7..e828dc49ef 160000 --- a/examples_tests +++ b/examples_tests @@ -1 +1 @@ -Subproject commit 1de31ddfd725009bd650f1fe80f1c4a8c2e6a14a +Subproject commit e828dc49ef0a223dcbb8b4af8d722974747f29ee diff --git a/include/nbl/builtin/hlsl/concepts/accessors/workgroup_arithmetic.hlsl b/include/nbl/builtin/hlsl/concepts/accessors/workgroup_arithmetic.hlsl new file mode 100644 index 0000000000..de5e5a3c35 --- /dev/null +++ b/include/nbl/builtin/hlsl/concepts/accessors/workgroup_arithmetic.hlsl @@ -0,0 +1,57 @@ +#ifndef _NBL_BUILTIN_HLSL_CONCEPTS_ACCESSORS_WORKGROUP_ARITHMETIC_INCLUDED_ +#define _NBL_BUILTIN_HLSL_CONCEPTS_ACCESSORS_WORKGROUP_ARITHMETIC_INCLUDED_ + +#include "nbl/builtin/hlsl/concepts.hlsl" + +namespace nbl +{ +namespace hlsl +{ +namespace workgroup2 +{ + +#define NBL_CONCEPT_NAME ArithmeticSharedMemoryAccessor +#define NBL_CONCEPT_TPLT_PRM_KINDS (typename) +#define NBL_CONCEPT_TPLT_PRM_NAMES (T) +#define NBL_CONCEPT_PARAM_0 (accessor, T) +#define NBL_CONCEPT_PARAM_1 (index, uint32_t) +#define NBL_CONCEPT_PARAM_2 (val, uint32_t) +NBL_CONCEPT_BEGIN(3) +#define accessor NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_0 +#define index NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_1 +#define val NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_2 +NBL_CONCEPT_END( + ((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template set(index, val)), is_same_v, void)) + ((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template get(index, val)), is_same_v, void)) + ((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.workgroupExecutionAndMemoryBarrier()), is_same_v, void)) +); +#undef val +#undef index +#undef accessor +#include + +#define NBL_CONCEPT_NAME ArithmeticDataAccessor +#define NBL_CONCEPT_TPLT_PRM_KINDS (typename) +#define NBL_CONCEPT_TPLT_PRM_NAMES (T) +#define NBL_CONCEPT_PARAM_0 (accessor, T) +#define NBL_CONCEPT_PARAM_1 (index, uint32_t) +#define NBL_CONCEPT_PARAM_2 (val, uint32_t) +NBL_CONCEPT_BEGIN(3) +#define accessor NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_0 +#define index NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_1 +#define val NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_2 +NBL_CONCEPT_END( + ((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template set(index, val)), 
is_same_v, void)) + ((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template get(index, val)), is_same_v, void)) + ((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.workgroupExecutionAndMemoryBarrier()), is_same_v, void)) +); +#undef val +#undef index +#undef accessor +#include + +} +} +} + +#endif diff --git a/include/nbl/builtin/hlsl/workgroup2/arithmetic.hlsl b/include/nbl/builtin/hlsl/workgroup2/arithmetic.hlsl index 3b4a028d2c..d0a26cdf94 100644 --- a/include/nbl/builtin/hlsl/workgroup2/arithmetic.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/arithmetic.hlsl @@ -8,6 +8,7 @@ #include "nbl/builtin/hlsl/functional.hlsl" #include "nbl/builtin/hlsl/workgroup/ballot.hlsl" #include "nbl/builtin/hlsl/workgroup/broadcast.hlsl" +#include "nbl/builtin/hlsl/concepts/accessors/workgroup_arithmetic.hlsl" #include "nbl/builtin/hlsl/workgroup2/shared_scan.hlsl" @@ -21,7 +22,7 @@ namespace workgroup2 template struct reduction { - template + template && ArithmeticSharedMemoryAccessor) static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor) { impl::reduce fn; @@ -32,7 +33,7 @@ struct reduction template struct inclusive_scan { - template + template && ArithmeticSharedMemoryAccessor) static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor) { impl::scan fn; @@ -43,7 +44,7 @@ struct inclusive_scan template struct exclusive_scan { - template + template && ArithmeticSharedMemoryAccessor) static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor) { impl::scan fn; diff --git a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl index add3acc687..d53bfd6000 100644 --- a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl +++ b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl @@ -43,9 +43,9 @@ struct reduce subgroup2::reduction reduction; vector_t value; - dataAccessor.get(workgroup::SubgroupContiguousIndex(), value); + dataAccessor.template get(workgroup::SubgroupContiguousIndex(), value); value = reduction(value); - dataAccessor.set(workgroup::SubgroupContiguousIndex(), value); + dataAccessor.template set(workgroup::SubgroupContiguousIndex(), value); } }; @@ -63,7 +63,7 @@ struct scan using params_t = subgroup2::ArithmeticParams; vector_t value; - dataAccessor.get(workgroup::SubgroupContiguousIndex(), value); + dataAccessor.template get(workgroup::SubgroupContiguousIndex(), value); if (Exclusive) { subgroup2::exclusive_scan excl_scan; @@ -74,7 +74,7 @@ struct scan subgroup2::inclusive_scan incl_scan; value = incl_scan(value); } - dataAccessor.set(workgroup::SubgroupContiguousIndex(), value); // can be safely merged with above lines? + dataAccessor.template set(workgroup::SubgroupContiguousIndex(), value); // can be safely merged with above lines? 
} }; @@ -101,13 +101,13 @@ struct reduce for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) { vector_lv0_t scan_local; - dataAccessor.get(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local); + dataAccessor.template get(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local); scan_local = reduction0(scan_local); if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1) { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1); - scratchAccessor.set(bankedIndex, scan_local[Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan + scratchAccessor.template set(bankedIndex, scan_local[Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan } } scratchAccessor.workgroupExecutionAndMemoryBarrier(); @@ -119,9 +119,9 @@ struct reduce vector_lv1_t lv1_val; [unroll] for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++) - scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]); + scratchAccessor.template get(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]); lv1_val = reduction1(lv1_val); - scratchAccessor.set(invocationIndex, lv1_val[Config::ItemsPerInvocation_1-1]); + scratchAccessor.template set(invocationIndex, lv1_val[Config::ItemsPerInvocation_1-1]); } scratchAccessor.workgroupExecutionAndMemoryBarrier(); @@ -130,8 +130,8 @@ struct reduce for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) { scalar_t reduce_val; - scratchAccessor.get(glsl::gl_SubgroupInvocationID(),reduce_val); - dataAccessor.set(idx * Config::WorkgroupSize + virtualInvocationIndex, reduce_val); + scratchAccessor.template get(glsl::gl_SubgroupInvocationID(),reduce_val); + dataAccessor.template set(idx * Config::WorkgroupSize + virtualInvocationIndex, hlsl::promote(reduce_val)); } } }; @@ -158,13 +158,13 @@ struct scan [unroll] for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) { - dataAccessor.get(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); + dataAccessor.template get(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); scan_local[idx] = inclusiveScan0(scan_local[idx]); if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1) { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1); - scratchAccessor.set(bankedIndex, scan_local[idx][Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan + scratchAccessor.template set(bankedIndex, scan_local[idx][Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan } } scratchAccessor.workgroupExecutionAndMemoryBarrier(); @@ -177,10 +177,10 @@ struct scan const uint32_t prevIndex = invocationIndex-1; [unroll] for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++) - 
scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroup+prevIndex,lv1_val[i]); + scratchAccessor.template get(i*Config::SubgroupsPerVirtualWorkgroup+prevIndex,lv1_val[i]); vector_lv1_t shiftedInput = hlsl::mix(hlsl::promote(BinOp::identity), lv1_val, bool(invocationIndex)); shiftedInput = inclusiveScan1(shiftedInput); - scratchAccessor.set(invocationIndex, shiftedInput[Config::ItemsPerInvocation_1-1]); + scratchAccessor.template set(invocationIndex, shiftedInput[Config::ItemsPerInvocation_1-1]); } scratchAccessor.workgroupExecutionAndMemoryBarrier(); @@ -190,7 +190,7 @@ struct scan { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); scalar_t left; - scratchAccessor.get(virtualSubgroupID,left); + scratchAccessor.template get(virtualSubgroupID,left); if (Exclusive) { scalar_t left_last_elem = hlsl::mix(BinOp::identity, glsl::subgroupShuffleUp(scan_local[idx][Config::ItemsPerInvocation_0-1],1), bool(glsl::gl_SubgroupInvocationID())); @@ -204,7 +204,7 @@ struct scan for (uint32_t i = 0; i < Config::ItemsPerInvocation_0; i++) scan_local[idx][i] = binop(left, scan_local[idx][i]); } - dataAccessor.set(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); + dataAccessor.template set(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); } } }; @@ -234,13 +234,13 @@ struct reduce for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) { vector_lv0_t scan_local; - dataAccessor.get(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local); + dataAccessor.template get(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local); scan_local = reduction0(scan_local); if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1) { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1); - scratchAccessor.set(bankedIndex, scan_local[Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan + scratchAccessor.template set(bankedIndex, scan_local[Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan } } scratchAccessor.workgroupExecutionAndMemoryBarrier(); @@ -252,12 +252,12 @@ struct reduce vector_lv1_t lv1_val; [unroll] for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++) - scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]); + scratchAccessor.template get(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]); lv1_val = reduction1(lv1_val); if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1) { const uint32_t bankedIndex = (invocationIndex & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupsPerVirtualWorkgroup + (invocationIndex/Config::ItemsPerInvocation_2); - scratchAccessor.set(bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]); + scratchAccessor.template set(bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]); } } scratchAccessor.workgroupExecutionAndMemoryBarrier(); @@ -269,9 +269,9 @@ struct reduce vector_lv2_t lv2_val; [unroll] for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++) - scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv2_val[i]); + scratchAccessor.template 
get(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv2_val[i]); lv2_val = reduction2(lv2_val); - scratchAccessor.set(invocationIndex, lv2_val[Config::ItemsPerInvocation_2-1]); + scratchAccessor.template set(invocationIndex, lv2_val[Config::ItemsPerInvocation_2-1]); } scratchAccessor.workgroupExecutionAndMemoryBarrier(); @@ -280,8 +280,8 @@ struct reduce for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) { scalar_t reduce_val; - scratchAccessor.get(glsl::gl_SubgroupInvocationID(),reduce_val); - dataAccessor.set(idx * Config::WorkgroupSize + virtualInvocationIndex, reduce_val); + scratchAccessor.template get(glsl::gl_SubgroupInvocationID(),reduce_val); + dataAccessor.template set(idx * Config::WorkgroupSize + virtualInvocationIndex, reduce_val); } } }; @@ -310,13 +310,13 @@ struct scan [unroll] for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) { - dataAccessor.get(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); + dataAccessor.template get(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); scan_local[idx] = inclusiveScan0(scan_local[idx]); if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1) { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1); - scratchAccessor.set(bankedIndex, scan_local[idx][Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan + scratchAccessor.template set(bankedIndex, scan_local[idx][Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan } } scratchAccessor.workgroupExecutionAndMemoryBarrier(); @@ -329,12 +329,12 @@ struct scan vector_lv1_t lv1_val; [unroll] for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++) - scratchAccessor.get(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]); + scratchAccessor.template get(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]); lv1_val = inclusiveScan1(lv1_val); if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1) { const uint32_t bankedIndex = (glsl::gl_SubgroupID() & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupSize + (glsl::gl_SubgroupID()/Config::ItemsPerInvocation_2); - scratchAccessor.set(lv1_smem_size+bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]); + scratchAccessor.template set(lv1_smem_size+bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]); } } scratchAccessor.workgroupExecutionAndMemoryBarrier(); @@ -347,7 +347,7 @@ struct scan const uint32_t prevIndex = invocationIndex-1; [unroll] for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++) - scratchAccessor.get(lv1_smem_size+i*Config::SubgroupSize+prevIndex,lv2_val[i]); + scratchAccessor.template get(lv1_smem_size+i*Config::SubgroupSize+prevIndex,lv2_val[i]); vector_lv2_t shiftedInput = hlsl::mix(hlsl::promote(BinOp::identity), lv2_val, bool(invocationIndex)); shiftedInput = inclusiveScan2(shiftedInput); @@ -356,10 +356,10 @@ struct scan for (uint32_t i = 0; i < Config::SubgroupsPerVirtualWorkgroup; i++) { scalar_t last_val; - 
scratchAccessor.get((Config::ItemsPerInvocation_1-1)*Config::SubgroupsPerVirtualWorkgroup+(Config::SubgroupsPerVirtualWorkgroup-1-i),last_val); + scratchAccessor.template get((Config::ItemsPerInvocation_1-1)*Config::SubgroupsPerVirtualWorkgroup+(Config::SubgroupsPerVirtualWorkgroup-1-i),last_val); scalar_t val = hlsl::mix(hlsl::promote(BinOp::identity), lv2_val, bool(i)); val = binop(last_val, shiftedInput[Config::ItemsPerInvocation_2-1]); - scratchAccessor.set((Config::ItemsPerInvocation_1-1)*Config::SubgroupsPerVirtualWorkgroup+(Config::SubgroupsPerVirtualWorkgroup-1-i), last_val); + scratchAccessor.template set((Config::ItemsPerInvocation_1-1)*Config::SubgroupsPerVirtualWorkgroup+(Config::SubgroupsPerVirtualWorkgroup-1-i), last_val); } } scratchAccessor.workgroupExecutionAndMemoryBarrier(); @@ -370,7 +370,7 @@ struct scan { const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); const scalar_t left; - scratchAccessor.get(virtualSubgroupID, left); + scratchAccessor.template get(virtualSubgroupID, left); if (Exclusive) { scalar_t left_last_elem = hlsl::mix(BinOp::identity, glsl::subgroupShuffleUp(scan_local[idx][Config::ItemsPerInvocation_0-1],1), bool(glsl::gl_SubgroupInvocationID())); @@ -384,7 +384,7 @@ struct scan for (uint32_t i = 0; i < Config::ItemsPerInvocation_0; i++) scan_local[idx][i] = binop(left, scan_local[idx][i]); } - dataAccessor.set(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); + dataAccessor.template set(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); } } }; diff --git a/src/nbl/builtin/CMakeLists.txt b/src/nbl/builtin/CMakeLists.txt index 9333a0d3b4..a6405a3c99 100644 --- a/src/nbl/builtin/CMakeLists.txt +++ b/src/nbl/builtin/CMakeLists.txt @@ -330,6 +330,10 @@ LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/subgroup/basic.hlsl") LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/subgroup/arithmetic_portability.hlsl") LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/subgroup/arithmetic_portability_impl.hlsl") LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/subgroup/fft.hlsl") +#subgroup2 +LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/subgroup2/ballot.hlsl") +LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/subgroup2/arithmetic_portability.hlsl") +LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/subgroup2/arithmetic_portability_impl.hlsl") #shared header between C++ and HLSL LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/surface_transform.h") #workgroup @@ -341,6 +345,10 @@ LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/fft.hlsl") LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/scratch_size.hlsl") LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/shared_scan.hlsl") LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/shuffle.hlsl") +#workgroup2 +LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup2/arithmetic_config.hlsl") +LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup2/arithmetic.hlsl") +LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup2/shared_scan.hlsl") #Extensions LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/ext/FullScreenTriangle/SVertexAttributes.hlsl") LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/ext/FullScreenTriangle/default.vert.hlsl") @@ -362,6 +370,7 @@ LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/concepts/accessors/loadable_i LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/concepts/accessors/mip_mapped.hlsl") 
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/concepts/accessors/storable_image.hlsl") LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/concepts/accessors/fft.hlsl") +LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/concepts/accessors/workgroup_arithmetic.hlsl") #tgmath LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/tgmath.hlsl") LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/tgmath/impl.hlsl")
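---

Usage sketch for the API as it stands at the end of this series. The accessor shapes below follow the two concepts added in hlsl/concepts/accessors/workgroup_arithmetic.hlsl (templated get/set by index plus workgroupExecutionAndMemoryBarrier()), and the per-workgroup offset now lives in the caller's data accessor, since PATCH 23 removed all gl_WorkGroupID() references from shared_scan.hlsl. Everything not shown in the patches above is an assumption for illustration: the buffer name and binding, the chosen sizes (workgroup 256, subgroup 32, one item per invocation, following the <WorkgroupSizeLog2, SubgroupSizeLog2, ItemsPerInvocation> parameter order of the ArithmeticConfiguration<11, 4, ITEMS_PER_INVOC> specialization removed above), and the assumption that reduction takes only the config and the binary op (any trailing device-capabilities parameter is left defaulted).

#include "nbl/builtin/hlsl/glsl_compat/core.hlsl"
#include "nbl/builtin/hlsl/workgroup2/arithmetic.hlsl"

using namespace nbl::hlsl;

using config_t = workgroup2::ArithmeticConfiguration<8, 5, 1>; // assumes the device actually runs 32-wide subgroups
using data_t = vector<uint32_t, 1>; // the vector type level 0 hands to the data accessor when ItemsPerInvocation_0 == 1

[[vk::binding(0, 0)]] RWStructuredBuffer<data_t> inoutValues; // hypothetical binding

groupshared uint32_t scratch[config_t::ElementCount]; // sized by the ElementCount member introduced earlier in the series

// models workgroup2::ArithmeticSharedMemoryAccessor
struct ScratchAccessor
{
    template<typename T> // T is the config's scalar type (uint32_t here) for every call the library makes
    void get(const uint32_t index, NBL_REF_ARG(T) value) { value = scratch[index]; }
    template<typename T>
    void set(const uint32_t index, const T value) { scratch[index] = value; }
    void workgroupExecutionAndMemoryBarrier() { glsl::barrier(); }
};

// models workgroup2::ArithmeticDataAccessor; indices arriving here are virtual-workgroup-local,
// so the global offset is applied exactly once, in user code
struct DataAccessor
{
    uint32_t workgroupOffset;

    template<typename T> // T is data_t for every call the library makes with this config
    void get(const uint32_t index, NBL_REF_ARG(T) value) { value = inoutValues[workgroupOffset + index]; }
    template<typename T>
    void set(const uint32_t index, const T value) { inoutValues[workgroupOffset + index] = value; }
    void workgroupExecutionAndMemoryBarrier() { glsl::barrier(); }
};

[numthreads(256, 1, 1)] // must equal config_t::WorkgroupSize
void main()
{
    DataAccessor data;
    data.workgroupOffset = glsl::gl_WorkGroupID().x * config_t::VirtualWorkgroupSize;
    ScratchAccessor smem;
    workgroup2::reduction<config_t, plus<uint32_t> >::__call(data, smem);
    // workgroup2::inclusive_scan / exclusive_scan dispatch identically; only the struct name changes
}

Keeping the offset arithmetic in the user's accessor rather than in shared_scan.hlsl lets the same scan code serve both plain dispatches and virtual-workgroup schemes, which appears to be the point of PATCH 23.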