Skip to content
Closed
19 changes: 19 additions & 0 deletions include/nbl/builtin/hlsl/bitonic_sort/common.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#ifndef _NBL_BUILTIN_HLSL_BITONIC_SORT_COMMON_INCLUDED_
#define _NBL_BUILTIN_HLSL_BITONIC_SORT_COMMON_INCLUDED_

#include <nbl/builtin/hlsl/cpp_compat.hlsl>
#include <nbl/builtin/hlsl/concepts.hlsl>
#include <nbl/builtin/hlsl/math/intutil.hlsl>

namespace nbl
{
namespace hlsl
{
namespace bitonic_sort
{

}
}
}

#endif
110 changes: 110 additions & 0 deletions include/nbl/builtin/hlsl/subgroup/bitonic_sort.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
#ifndef NBL_BUILTIN_HLSL_SUBGROUP_BITONIC_SORT_INCLUDED
#define NBL_BUILTIN_HLSL_SUBGROUP_BITONIC_SORT_INCLUDED
#include "nbl/builtin/hlsl/bitonic_sort/common.hlsl"
#include "nbl/builtin/hlsl/glsl_compat/subgroup_basic.hlsl"
#include "nbl/builtin/hlsl/glsl_compat/subgroup_shuffle.hlsl"
#include "nbl/builtin/hlsl/functional.hlsl"
namespace nbl
{
namespace hlsl
{
namespace subgroup
{

template<typename KeyType, typename ValueType, typename Comparator = less<KeyType> >
struct bitonic_sort_config
{
using key_t = KeyType;
using value_t = ValueType;
using comparator_t = Comparator;
};

template<typename Config, class device_capabilities = void>
struct bitonic_sort;

template<typename KeyType, typename ValueType, typename Comparator, class device_capabilities>
struct bitonic_sort<bitonic_sort_config<KeyType, ValueType, Comparator>, device_capabilities>
{
using config_t = bitonic_sort_config<KeyType, ValueType, Comparator>;
using key_t = typename config_t::key_t;
using value_t = typename config_t::value_t;
using comparator_t = typename config_t::comparator_t;

// Thread-level compare and swap (operates on lo/hi in registers)
static void compareAndSwap(bool ascending, NBL_REF_ARG(key_t) loKey, NBL_REF_ARG(key_t) hiKey,
NBL_REF_ARG(value_t) loVal, NBL_REF_ARG(value_t) hiVal)
{
comparator_t comp;
const bool shouldSwap = ascending ? comp(hiKey, loKey) : comp(loKey, hiKey);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The compiler is probably dumb and might not realize the right term is the negation of the left term. Ternaries in SPIR-V usually get compiled to an OpSelect which treats both terms after the ? not as branches to conditionally execute, but as operands whose result must be evaluated before the select operation runs. That is to say, if the compiler is stupid you're going to run two comparisons. If you make the right term the negation of the left one, CSE is likely to kick in and evaluate the comparison only once.

if (shouldSwap)
{
// Swap keys
key_t tempKey = loKey;
loKey = hiKey;
hiKey = tempKey;
// Swap values
value_t tempVal = loVal;
loVal = hiVal;
hiVal = tempVal;
}
Comment on lines +39 to +49
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make this branchless like you did the swaps in the subgroup branch

}

static void mergeStage(uint32_t stage, bool bitonicAscending, uint32_t invocationID, NBL_REF_ARG(key_t) loKey, NBL_REF_ARG(key_t) hiKey,
NBL_REF_ARG(value_t) loVal, NBL_REF_ARG(value_t) hiVal)
{
[unroll]
for (uint32_t pass = 0; pass <= stage; pass++)
{
const uint32_t stride = 1u << (stage - pass); // Element stride
const uint32_t threadStride = stride >> 1;
if (threadStride == 0)
{
// Local compare and swap for stage 0
compareAndSwap(bitonicAscending, loKey, hiKey, loVal, hiVal);
}
else
{
// Shuffle from partner using XOR
const key_t pLoKey = glsl::subgroupShuffleXor<key_t>(loKey, threadStride);
const key_t pHiKey = glsl::subgroupShuffleXor<key_t>(hiKey, threadStride);
const value_t pLoVal = glsl::subgroupShuffleXor<value_t>(loVal, threadStride);
const value_t pHiVal = glsl::subgroupShuffleXor<value_t>(hiVal, threadStride);

// Branchless compare-and-swap
const bool isUpper = bool(invocationID & threadStride);
const bool takeLarger = isUpper == bitonicAscending;
comparator_t comp;

// lo update
const bool loSelfSmaller = comp(loKey, pLoKey);
const bool takePartnerLo = takeLarger ? loSelfSmaller : !loSelfSmaller;
loKey = takePartnerLo ? pLoKey : loKey;
loVal = takePartnerLo ? pLoVal : loVal;

// hi update
const bool hiSelfSmaller = comp(hiKey, pHiKey);
const bool takePartnerHi = takeLarger ? hiSelfSmaller : !hiSelfSmaller;
hiKey = takePartnerHi ? pHiKey : hiKey;
hiVal = takePartnerHi ? pHiVal : hiVal;
Comment on lines +78 to +88
Copy link
Contributor

@Fletterio Fletterio Oct 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It feels like both the lo and hi update can be expressed using the compareAndSwap method above, using !takeLarger (or maybe takeLarger but I feel it's negated) instead of ascending

}
}
}

static void __call(bool ascending, NBL_REF_ARG(key_t) loKey, NBL_REF_ARG(key_t) hiKey,
NBL_REF_ARG(value_t) loVal, NBL_REF_ARG(value_t) hiVal)
{
const uint32_t invocationID = glsl::gl_SubgroupInvocationID();
const uint32_t subgroupSizeLog2 = glsl::gl_SubgroupSizeLog2();
[unroll]
for (uint32_t stage = 0; stage <= subgroupSizeLog2; stage++)
{
const bool bitonicAscending = (stage == subgroupSizeLog2) ? ascending : !bool(invocationID & (1u << stage));
mergeStage(stage, bitonicAscending, invocationID, loKey, hiKey, loVal, hiVal);
}
}
};

}
}
}
#endif
Loading