diff --git a/examples_tests b/examples_tests index 825c73d5d8..e828dc49ef 160000 --- a/examples_tests +++ b/examples_tests @@ -1 +1 @@ -Subproject commit 825c73d5d8307efef2488f0b6ce82b69c32855ea +Subproject commit e828dc49ef0a223dcbb8b4af8d722974747f29ee diff --git a/include/nbl/builtin/hlsl/concepts/accessors/workgroup_arithmetic.hlsl b/include/nbl/builtin/hlsl/concepts/accessors/workgroup_arithmetic.hlsl new file mode 100644 index 0000000000..de5e5a3c35 --- /dev/null +++ b/include/nbl/builtin/hlsl/concepts/accessors/workgroup_arithmetic.hlsl @@ -0,0 +1,57 @@ +#ifndef _NBL_BUILTIN_HLSL_CONCEPTS_ACCESSORS_WORKGROUP_ARITHMETIC_INCLUDED_ +#define _NBL_BUILTIN_HLSL_CONCEPTS_ACCESSORS_WORKGROUP_ARITHMETIC_INCLUDED_ + +#include "nbl/builtin/hlsl/concepts.hlsl" + +namespace nbl +{ +namespace hlsl +{ +namespace workgroup2 +{ + +#define NBL_CONCEPT_NAME ArithmeticSharedMemoryAccessor +#define NBL_CONCEPT_TPLT_PRM_KINDS (typename) +#define NBL_CONCEPT_TPLT_PRM_NAMES (T) +#define NBL_CONCEPT_PARAM_0 (accessor, T) +#define NBL_CONCEPT_PARAM_1 (index, uint32_t) +#define NBL_CONCEPT_PARAM_2 (val, uint32_t) +NBL_CONCEPT_BEGIN(3) +#define accessor NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_0 +#define index NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_1 +#define val NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_2 +NBL_CONCEPT_END( + ((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template set(index, val)), is_same_v, void)) + ((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template get(index, val)), is_same_v, void)) + ((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.workgroupExecutionAndMemoryBarrier()), is_same_v, void)) +); +#undef val +#undef index +#undef accessor +#include + +#define NBL_CONCEPT_NAME ArithmeticDataAccessor +#define NBL_CONCEPT_TPLT_PRM_KINDS (typename) +#define NBL_CONCEPT_TPLT_PRM_NAMES (T) +#define NBL_CONCEPT_PARAM_0 (accessor, T) +#define NBL_CONCEPT_PARAM_1 (index, uint32_t) +#define NBL_CONCEPT_PARAM_2 (val, uint32_t) +NBL_CONCEPT_BEGIN(3) +#define accessor NBL_CONCEPT_PARAM_T 
NBL_CONCEPT_PARAM_0 +#define index NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_1 +#define val NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_2 +NBL_CONCEPT_END( + ((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template set(index, val)), is_same_v, void)) + ((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template get(index, val)), is_same_v, void)) + ((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.workgroupExecutionAndMemoryBarrier()), is_same_v, void)) +); +#undef val +#undef index +#undef accessor +#include + +} +} +} + +#endif diff --git a/include/nbl/builtin/hlsl/subgroup2/ballot.hlsl b/include/nbl/builtin/hlsl/subgroup2/ballot.hlsl index 724887b995..52ae6de2d9 100644 --- a/include/nbl/builtin/hlsl/subgroup2/ballot.hlsl +++ b/include/nbl/builtin/hlsl/subgroup2/ballot.hlsl @@ -11,6 +11,20 @@ namespace hlsl namespace subgroup2 { +template +uint32_t LastSubgroupInvocation() +{ + if (AssumeAllActive) + return glsl::gl_SubgroupSize()-1; + else + return glsl::subgroupBallotFindMSB(glsl::subgroupBallot(true)); +} + +bool ElectLast() +{ + return glsl::gl_SubgroupInvocationID()==LastSubgroupInvocation(); +} + template struct Configuration { diff --git a/include/nbl/builtin/hlsl/vector_utils/vector_traits.hlsl b/include/nbl/builtin/hlsl/vector_utils/vector_traits.hlsl index 9aefc3b3d8..652cabd7c7 100644 --- a/include/nbl/builtin/hlsl/vector_utils/vector_traits.hlsl +++ b/include/nbl/builtin/hlsl/vector_utils/vector_traits.hlsl @@ -28,6 +28,7 @@ struct vector_traits >\ NBL_CONSTEXPR_STATIC_INLINE bool IsVector = true;\ };\ +DEFINE_VECTOR_TRAITS_TEMPLATE_SPECIALIZATION(1) DEFINE_VECTOR_TRAITS_TEMPLATE_SPECIALIZATION(2) DEFINE_VECTOR_TRAITS_TEMPLATE_SPECIALIZATION(3) DEFINE_VECTOR_TRAITS_TEMPLATE_SPECIALIZATION(4) diff --git a/include/nbl/builtin/hlsl/workgroup2/arithmetic.hlsl b/include/nbl/builtin/hlsl/workgroup2/arithmetic.hlsl new file mode 100644 index 0000000000..d0a26cdf94 --- /dev/null +++ b/include/nbl/builtin/hlsl/workgroup2/arithmetic.hlsl @@ -0,0 +1,59 @@ +// Copyright (C) 2025 - DevSH 
Graphics Programming Sp. z O.O. +// This file is part of the "Nabla Engine". +// For conditions of distribution and use, see copyright notice in nabla.h +#ifndef _NBL_BUILTIN_HLSL_WORKGROUP2_ARITHMETIC_INCLUDED_ +#define _NBL_BUILTIN_HLSL_WORKGROUP2_ARITHMETIC_INCLUDED_ + + +#include "nbl/builtin/hlsl/functional.hlsl" +#include "nbl/builtin/hlsl/workgroup/ballot.hlsl" +#include "nbl/builtin/hlsl/workgroup/broadcast.hlsl" +#include "nbl/builtin/hlsl/concepts/accessors/workgroup_arithmetic.hlsl" +#include "nbl/builtin/hlsl/workgroup2/shared_scan.hlsl" + + +namespace nbl +{ +namespace hlsl +{ +namespace workgroup2 +{ + +template +struct reduction +{ + template && ArithmeticSharedMemoryAccessor) + static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor) + { + impl::reduce fn; + fn.template __call(dataAccessor, scratchAccessor); + } +}; + +template +struct inclusive_scan +{ + template && ArithmeticSharedMemoryAccessor) + static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor) + { + impl::scan fn; + fn.template __call(dataAccessor, scratchAccessor); + } +}; + +template +struct exclusive_scan +{ + template && ArithmeticSharedMemoryAccessor) + static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor) + { + impl::scan fn; + fn.template __call(dataAccessor, scratchAccessor); + } +}; + +} +} +} + +#endif diff --git a/include/nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl b/include/nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl new file mode 100644 index 0000000000..88ff328e05 --- /dev/null +++ b/include/nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl @@ -0,0 +1,94 @@ +// Copyright (C) 2025 - DevSH Graphics Programming Sp. z O.O. +// This file is part of the "Nabla Engine". 
+// For conditions of distribution and use, see copyright notice in nabla.h +#ifndef _NBL_BUILTIN_HLSL_WORKGROUP2_ARITHMETIC_CONFIG_INCLUDED_ +#define _NBL_BUILTIN_HLSL_WORKGROUP2_ARITHMETIC_CONFIG_INCLUDED_ + +#include "nbl/builtin/hlsl/cpp_compat.hlsl" + +namespace nbl +{ +namespace hlsl +{ +namespace workgroup2 +{ + +namespace impl +{ +template +struct virtual_wg_size_log2 +{ + static_assert(WorkgroupSizeLog2>=SubgroupSizeLog2, "WorkgroupSize cannot be smaller than SubgroupSize"); + static_assert(WorkgroupSizeLog2<=SubgroupSizeLog2+4, "WorkgroupSize cannot be larger than SubgroupSize*16"); + NBL_CONSTEXPR_STATIC_INLINE uint16_t levels = conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2),uint16_t,conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2*2+2),uint16_t,3,2>::value,1>::value; + NBL_CONSTEXPR_STATIC_INLINE uint16_t value = mpl::max_v+SubgroupSizeLog2; +}; + +template +struct items_per_invocation +{ + NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocationProductLog2 = mpl::max_v; + NBL_CONSTEXPR_STATIC_INLINE uint16_t value0 = BaseItemsPerInvocation; + NBL_CONSTEXPR_STATIC_INLINE uint16_t value1 = uint16_t(0x1u) << conditional_value, ItemsPerInvocationProductLog2>::value; + NBL_CONSTEXPR_STATIC_INLINE uint16_t value2 = uint16_t(0x1u) << mpl::max_v; +}; + +// explicit specializations for cases that don't fit +#define SPECIALIZE_VIRTUAL_WG_SIZE_CASE(WGLOG2, SGLOG2, LEVELS, VALUE) template<>\ +struct virtual_wg_size_log2\ +{\ + NBL_CONSTEXPR_STATIC_INLINE uint16_t levels = LEVELS;\ + NBL_CONSTEXPR_STATIC_INLINE uint16_t value = VALUE;\ +};\ + +SPECIALIZE_VIRTUAL_WG_SIZE_CASE(11,4,3,12); +SPECIALIZE_VIRTUAL_WG_SIZE_CASE(7,7,1,7); +SPECIALIZE_VIRTUAL_WG_SIZE_CASE(6,6,1,6); +SPECIALIZE_VIRTUAL_WG_SIZE_CASE(5,5,1,5); +SPECIALIZE_VIRTUAL_WG_SIZE_CASE(4,4,1,4); +SPECIALIZE_VIRTUAL_WG_SIZE_CASE(3,3,1,3); +SPECIALIZE_VIRTUAL_WG_SIZE_CASE(2,2,1,2); + +#undef SPECIALIZE_VIRTUAL_WG_SIZE_CASE +} + +template +struct ArithmeticConfiguration +{ + 
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSizeLog2 = _WorkgroupSizeLog2; + NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(0x1u) << WorkgroupSizeLog2; + NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSizeLog2 = _SubgroupSizeLog2; + NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSize = uint16_t(0x1u) << SubgroupSizeLog2; + + // must have at least enough level 0 outputs to feed a single subgroup + NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupsPerVirtualWorkgroupLog2 = mpl::max_v; + NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupsPerVirtualWorkgroup = uint16_t(0x1u) << SubgroupsPerVirtualWorkgroupLog2; + + using virtual_wg_t = impl::virtual_wg_size_log2; + NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelCount = virtual_wg_t::levels; + NBL_CONSTEXPR_STATIC_INLINE uint16_t VirtualWorkgroupSize = uint16_t(0x1u) << virtual_wg_t::value; + using items_per_invoc_t = impl::items_per_invocation; + // NBL_CONSTEXPR_STATIC_INLINE uint32_t2 ItemsPerInvocation; TODO? doesn't allow inline definitions for uint32_t2 for some reason, uint32_t[2] as well ; declaring out of line results in not constant expression + NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_0 = items_per_invoc_t::value0; + NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_1 = items_per_invoc_t::value1; + NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_2 = items_per_invoc_t::value2; + static_assert(ItemsPerInvocation_1<=4, "3 level scan would have been needed with this config!"); + + NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementCount = conditional_value::value + SubgroupSize*ItemsPerInvocation_1>::value; +}; + +template +struct is_configuration : bool_constant {}; + +template +struct is_configuration > : bool_constant {}; + +template +NBL_CONSTEXPR bool is_configuration_v = is_configuration::value; + + +} +} +} + +#endif diff --git a/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl new file mode 100644 index 0000000000..d53bfd6000 
--- /dev/null +++ b/include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl @@ -0,0 +1,398 @@ +// Copyright (C) 2025 - DevSH Graphics Programming Sp. z O.O. +// This file is part of the "Nabla Engine". +// For conditions of distribution and use, see copyright notice in nabla.h +#ifndef _NBL_BUILTIN_HLSL_WORKGROUP2_SHARED_SCAN_INCLUDED_ +#define _NBL_BUILTIN_HLSL_WORKGROUP2_SHARED_SCAN_INCLUDED_ + +#include "nbl/builtin/hlsl/workgroup/broadcast.hlsl" +#include "nbl/builtin/hlsl/glsl_compat/subgroup_basic.hlsl" +#include "nbl/builtin/hlsl/subgroup2/ballot.hlsl" +#include "nbl/builtin/hlsl/subgroup2/arithmetic_portability.hlsl" +#include "nbl/builtin/hlsl/mpl.hlsl" +#include "nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl" + +namespace nbl +{ +namespace hlsl +{ +namespace workgroup2 +{ + +namespace impl +{ + +template +struct reduce; + +template +struct scan; + +// 1-level scans +template +struct reduce +{ + using scalar_t = typename BinOp::type_t; + using vector_t = vector; // data accessor needs to be this type + // doesn't use scratch smem, need as param? + + template + void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor) + { + using config_t = subgroup2::Configuration; + using params_t = subgroup2::ArithmeticParams; + + subgroup2::reduction reduction; + vector_t value; + dataAccessor.template get(workgroup::SubgroupContiguousIndex(), value); + value = reduction(value); + dataAccessor.template set(workgroup::SubgroupContiguousIndex(), value); + } +}; + +template +struct scan +{ + using scalar_t = typename BinOp::type_t; + using vector_t = vector; // data accessor needs to be this type + // doesn't use scratch smem, need as param? 
+ + template + void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor) + { + using config_t = subgroup2::Configuration; + using params_t = subgroup2::ArithmeticParams; + + vector_t value; + dataAccessor.template get(workgroup::SubgroupContiguousIndex(), value); + if (Exclusive) + { + subgroup2::exclusive_scan excl_scan; + value = excl_scan(value); + } + else + { + subgroup2::inclusive_scan incl_scan; + value = incl_scan(value); + } + dataAccessor.template set(workgroup::SubgroupContiguousIndex(), value); // can be safely merged with above lines? + } +}; + +// 2-level scans +template +struct reduce +{ + using scalar_t = typename BinOp::type_t; + using vector_lv0_t = vector; // data accessor needs to be this type + using vector_lv1_t = vector; + + template + void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor) + { + using config_t = subgroup2::Configuration; + using params_lv0_t = subgroup2::ArithmeticParams; + using params_lv1_t = subgroup2::ArithmeticParams; + BinOp binop; + + const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex(); + // level 0 scan + subgroup2::reduction reduction0; + [unroll] + for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) + { + vector_lv0_t scan_local; + dataAccessor.template get(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local); + scan_local = reduction0(scan_local); + if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1) + { + const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); + const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1); + scratchAccessor.template set(bankedIndex, scan_local[Config::ItemsPerInvocation_0-1]); // set last element of subgroup 
scan (reduction) to level 1 scan + } + } + scratchAccessor.workgroupExecutionAndMemoryBarrier(); + + // level 1 scan + subgroup2::reduction reduction1; + if (glsl::gl_SubgroupID() == 0) + { + vector_lv1_t lv1_val; + [unroll] + for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++) + scratchAccessor.template get(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]); + lv1_val = reduction1(lv1_val); + scratchAccessor.template set(invocationIndex, lv1_val[Config::ItemsPerInvocation_1-1]); + } + scratchAccessor.workgroupExecutionAndMemoryBarrier(); + + // set as last element in scan (reduction) + [unroll] + for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) + { + scalar_t reduce_val; + scratchAccessor.template get(glsl::gl_SubgroupInvocationID(),reduce_val); + dataAccessor.template set(idx * Config::WorkgroupSize + virtualInvocationIndex, hlsl::promote(reduce_val)); + } + } +}; + +template +struct scan +{ + using scalar_t = typename BinOp::type_t; + using vector_lv0_t = vector; // data accessor needs to be this type + using vector_lv1_t = vector; + + template + void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor) + { + using config_t = subgroup2::Configuration; + using params_lv0_t = subgroup2::ArithmeticParams; + using params_lv1_t = subgroup2::ArithmeticParams; + BinOp binop; + + vector_lv0_t scan_local[Config::VirtualWorkgroupSize / Config::WorkgroupSize]; + const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex(); + subgroup2::inclusive_scan inclusiveScan0; + // level 0 scan + [unroll] + for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) + { + dataAccessor.template get(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); + scan_local[idx] = inclusiveScan0(scan_local[idx]); + if 
(glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1) + { + const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); + const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1); + scratchAccessor.template set(bankedIndex, scan_local[idx][Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan + } + } + scratchAccessor.workgroupExecutionAndMemoryBarrier(); + + // level 1 scan + subgroup2::inclusive_scan inclusiveScan1; + if (glsl::gl_SubgroupID() == 0) + { + vector_lv1_t lv1_val; + const uint32_t prevIndex = invocationIndex-1; + [unroll] + for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++) + scratchAccessor.template get(i*Config::SubgroupsPerVirtualWorkgroup+prevIndex,lv1_val[i]); + vector_lv1_t shiftedInput = hlsl::mix(hlsl::promote(BinOp::identity), lv1_val, bool(invocationIndex)); + shiftedInput = inclusiveScan1(shiftedInput); + scratchAccessor.template set(invocationIndex, shiftedInput[Config::ItemsPerInvocation_1-1]); + } + scratchAccessor.workgroupExecutionAndMemoryBarrier(); + + // combine with level 0 + [unroll] + for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) + { + const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); + scalar_t left; + scratchAccessor.template get(virtualSubgroupID,left); + if (Exclusive) + { + scalar_t left_last_elem = hlsl::mix(BinOp::identity, glsl::subgroupShuffleUp(scan_local[idx][Config::ItemsPerInvocation_0-1],1), bool(glsl::gl_SubgroupInvocationID())); + [unroll] + for (uint32_t i = 0; i < Config::ItemsPerInvocation_0; i++) + scan_local[idx][Config::ItemsPerInvocation_0-i-1] = binop(left, hlsl::mix(scan_local[idx][Config::ItemsPerInvocation_0-i-2], 
left_last_elem, (Config::ItemsPerInvocation_0-i-1==0))); + } + else + { + [unroll] + for (uint32_t i = 0; i < Config::ItemsPerInvocation_0; i++) + scan_local[idx][i] = binop(left, scan_local[idx][i]); + } + dataAccessor.template set(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); + } + } +}; + +// 3-level scans +template +struct reduce +{ + using scalar_t = typename BinOp::type_t; + using vector_lv0_t = vector; // data accessor needs to be this type + using vector_lv1_t = vector; + using vector_lv2_t = vector; + + template + void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor) + { + using config_t = subgroup2::Configuration; + using params_lv0_t = subgroup2::ArithmeticParams; + using params_lv1_t = subgroup2::ArithmeticParams; + using params_lv2_t = subgroup2::ArithmeticParams; + BinOp binop; + + const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex(); + // level 0 scan + subgroup2::reduction reduction0; + [unroll] + for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) + { + vector_lv0_t scan_local; + dataAccessor.template get(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local); + scan_local = reduction0(scan_local); + if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1) + { + const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); + const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1); + scratchAccessor.template set(bankedIndex, scan_local[Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan + } + } + scratchAccessor.workgroupExecutionAndMemoryBarrier(); + + // level 1 scan + subgroup2::reduction reduction1; + if (glsl::gl_SubgroupID() < 
Config::SubgroupSizeLog2*Config::ItemsPerInvocation_1) + { + vector_lv1_t lv1_val; + [unroll] + for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++) + scratchAccessor.template get(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]); + lv1_val = reduction1(lv1_val); + if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1) + { + const uint32_t bankedIndex = (invocationIndex & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupsPerVirtualWorkgroup + (invocationIndex/Config::ItemsPerInvocation_2); + scratchAccessor.template set(bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]); + } + } + scratchAccessor.workgroupExecutionAndMemoryBarrier(); + + // level 2 scan + subgroup2::reduction reduction2; + if (glsl::gl_SubgroupID() == 0) + { + vector_lv2_t lv2_val; + [unroll] + for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++) + scratchAccessor.template get(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv2_val[i]); + lv2_val = reduction2(lv2_val); + scratchAccessor.template set(invocationIndex, lv2_val[Config::ItemsPerInvocation_2-1]); + } + scratchAccessor.workgroupExecutionAndMemoryBarrier(); + + // set as last element in scan (reduction) + [unroll] + for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) + { + scalar_t reduce_val; + scratchAccessor.template get(glsl::gl_SubgroupInvocationID(),reduce_val); + dataAccessor.template set(idx * Config::WorkgroupSize + virtualInvocationIndex, reduce_val); + } + } +}; + +template +struct scan +{ + using scalar_t = typename BinOp::type_t; + using vector_lv0_t = vector; // data accessor needs to be this type + using vector_lv1_t = vector; + using vector_lv2_t = vector; + + template + void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor) + { + using config_t = subgroup2::Configuration; + using params_lv0_t = subgroup2::ArithmeticParams; + using params_lv1_t = 
subgroup2::ArithmeticParams; + using params_lv2_t = subgroup2::ArithmeticParams; + BinOp binop; + + vector_lv0_t scan_local[Config::VirtualWorkgroupSize / Config::WorkgroupSize]; + const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex(); + subgroup2::inclusive_scan inclusiveScan0; + // level 0 scan + [unroll] + for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) + { + dataAccessor.template get(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); + scan_local[idx] = inclusiveScan0(scan_local[idx]); + if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1) + { + const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); + const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1); + scratchAccessor.template set(bankedIndex, scan_local[idx][Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan + } + } + scratchAccessor.workgroupExecutionAndMemoryBarrier(); + + // level 1 scan + const uint32_t lv1_smem_size = Config::SubgroupsPerVirtualWorkgroup*Config::ItemsPerInvocation_1; + subgroup2::inclusive_scan inclusiveScan1; + if (glsl::gl_SubgroupID() < lv1_smem_size) + { + vector_lv1_t lv1_val; + [unroll] + for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++) + scratchAccessor.template get(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]); + lv1_val = inclusiveScan1(lv1_val); + if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1) + { + const uint32_t bankedIndex = (glsl::gl_SubgroupID() & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupSize + (glsl::gl_SubgroupID()/Config::ItemsPerInvocation_2); + scratchAccessor.template set(lv1_smem_size+bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]); + } + } + 
scratchAccessor.workgroupExecutionAndMemoryBarrier(); + + // level 2 scan + subgroup2::inclusive_scan inclusiveScan2; + if (glsl::gl_SubgroupID() == 0) + { + vector_lv2_t lv2_val; + const uint32_t prevIndex = invocationIndex-1; + [unroll] + for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++) + scratchAccessor.template get(lv1_smem_size+i*Config::SubgroupSize+prevIndex,lv2_val[i]); + vector_lv2_t shiftedInput = hlsl::mix(hlsl::promote(BinOp::identity), lv2_val, bool(invocationIndex)); + shiftedInput = inclusiveScan2(shiftedInput); + + // combine with level 1, only last element of each + [unroll] + for (uint32_t i = 0; i < Config::SubgroupsPerVirtualWorkgroup; i++) + { + scalar_t last_val; + scratchAccessor.template get((Config::ItemsPerInvocation_1-1)*Config::SubgroupsPerVirtualWorkgroup+(Config::SubgroupsPerVirtualWorkgroup-1-i),last_val); + scalar_t val = hlsl::mix(hlsl::promote(BinOp::identity), lv2_val, bool(i)); + val = binop(last_val, shiftedInput[Config::ItemsPerInvocation_2-1]); + scratchAccessor.template set((Config::ItemsPerInvocation_1-1)*Config::SubgroupsPerVirtualWorkgroup+(Config::SubgroupsPerVirtualWorkgroup-1-i), val); + } + } + scratchAccessor.workgroupExecutionAndMemoryBarrier(); + + // combine with level 0 + [unroll] + for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++) + { + const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID(); + scalar_t left; + scratchAccessor.template get(virtualSubgroupID, left); + if (Exclusive) + { + scalar_t left_last_elem = hlsl::mix(BinOp::identity, glsl::subgroupShuffleUp(scan_local[idx][Config::ItemsPerInvocation_0-1],1), bool(glsl::gl_SubgroupInvocationID())); + [unroll] + for (uint32_t i = 0; i < Config::ItemsPerInvocation_0; i++) + scan_local[idx][Config::ItemsPerInvocation_0-i-1] = binop(left, hlsl::mix(scan_local[idx][Config::ItemsPerInvocation_0-i-2], 
left_last_elem, (Config::ItemsPerInvocation_0-i-1==0))); + } + else + { + [unroll] + for (uint32_t i = 0; i < Config::ItemsPerInvocation_0; i++) + scan_local[idx][i] = binop(left, scan_local[idx][i]); + } + dataAccessor.template set(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]); + } + } +}; + +} + +} +} +} + +#endif diff --git a/src/nbl/builtin/CMakeLists.txt b/src/nbl/builtin/CMakeLists.txt index 9333a0d3b4..a6405a3c99 100644 --- a/src/nbl/builtin/CMakeLists.txt +++ b/src/nbl/builtin/CMakeLists.txt @@ -330,6 +330,10 @@ LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/subgroup/basic.hlsl") LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/subgroup/arithmetic_portability.hlsl") LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/subgroup/arithmetic_portability_impl.hlsl") LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/subgroup/fft.hlsl") +#subgroup2 +LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/subgroup2/ballot.hlsl") +LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/subgroup2/arithmetic_portability.hlsl") +LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/subgroup2/arithmetic_portability_impl.hlsl") #shared header between C++ and HLSL LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/surface_transform.h") #workgroup @@ -341,6 +345,10 @@ LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/fft.hlsl") LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/scratch_size.hlsl") LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/shared_scan.hlsl") LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/shuffle.hlsl") +#workgroup2 +LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup2/arithmetic_config.hlsl") +LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup2/arithmetic.hlsl") +LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup2/shared_scan.hlsl") #Extensions LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/ext/FullScreenTriangle/SVertexAttributes.hlsl") 
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/ext/FullScreenTriangle/default.vert.hlsl") @@ -362,6 +370,7 @@ LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/concepts/accessors/loadable_i LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/concepts/accessors/mip_mapped.hlsl") LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/concepts/accessors/storable_image.hlsl") LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/concepts/accessors/fft.hlsl") +LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/concepts/accessors/workgroup_arithmetic.hlsl") #tgmath LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/tgmath.hlsl") LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/tgmath/impl.hlsl")