include/nbl/builtin/hlsl/bitonic_sort/common.hlsl (new file, +19 lines)
#ifndef _NBL_BUILTIN_HLSL_BITONIC_SORT_COMMON_INCLUDED_
#define _NBL_BUILTIN_HLSL_BITONIC_SORT_COMMON_INCLUDED_

#include <nbl/builtin/hlsl/cpp_compat.hlsl>
#include <nbl/builtin/hlsl/concepts.hlsl>
#include <nbl/builtin/hlsl/math/intutil.hlsl>

namespace nbl
{
namespace hlsl
{
namespace bitonic_sort
{

}
}
}

#endif
include/nbl/builtin/hlsl/subgroup/bitonic_sort.hlsl (new file, +136 lines)
#ifndef NBL_BUILTIN_HLSL_SUBGROUP_BITONIC_SORT_INCLUDED
#define NBL_BUILTIN_HLSL_SUBGROUP_BITONIC_SORT_INCLUDED
#include "nbl/builtin/hlsl/bitonic_sort/common.hlsl"
#include "nbl/builtin/hlsl/glsl_compat/subgroup_basic.hlsl"
#include "nbl/builtin/hlsl/glsl_compat/subgroup_shuffle.hlsl"
#include "nbl/builtin/hlsl/functional.hlsl"
namespace nbl
{
namespace hlsl
{
namespace subgroup
{

template<typename KeyType, typename ValueType, typename Comparator = less<KeyType> >
struct bitonic_sort_config
{
using key_t = KeyType;
using value_t = ValueType;
using comparator_t = Comparator;
};
template<bool Ascending, typename Config, class device_capabilities = void>
struct bitonic_sort;
template<bool Ascending, typename KeyType, typename ValueType, typename Comparator, class device_capabilities>
Review comment (Contributor):
I get that Ascending is used because when moving onto workgroup you're going to need to call alternating subgroup sorts. However, as a front-facing API, if I wanted a single subgroup sort I'd usually want it in the order specified by the Comparator. Maybe push it after the Config and give it a default value of true. Or better yet, since Ascending can be confusing, consider calling it ReverseOrder or something simpler that conveys the intent better.
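
Combining both suggestions, a sketch of the front-facing signature (ReverseOrder = false would correspond to the Comparator's own order; illustrative only, not part of the diff):
// sketch of the suggested reordering/renaming
template<typename Config, bool ReverseOrder = false, class device_capabilities = void>
struct bitonic_sort;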

Review comment (Contributor):
Ascending and later names like takeLarger implicitly assume the comparator is less (lo and hi don't, those are related to the "lane" order in the bitonic sort diagram). That's fine on its own, it makes the code more readable vs naming them with a more generic option. However, there should be comments mentioning that names assume this implicitly so there's no confusion.
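
For example, the kind of note the comment asks for could be as simple as (hypothetical wording):
// NOTE: "Ascending" and "takeLarger" assume comparator_t is less-like; lo/hi only refer to lane order in the bitonic network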

struct bitonic_sort<Ascending, bitonic_sort_config<KeyType, ValueType, Comparator>, device_capabilities>
{
using config_t = bitonic_sort_config<KeyType, ValueType, Comparator>;
using key_t = typename config_t::key_t;
using value_t = typename config_t::value_t;
using comparator_t = typename config_t::comparator_t;
// Thread-level compare and swap (operates on lo/hi in registers)
static void compareAndSwap(bool ascending, NBL_REF_ARG(key_t) loKey, NBL_REF_ARG(key_t) hiKey,
NBL_REF_ARG(value_t) loVal, NBL_REF_ARG(value_t) hiVal)
{
comparator_t comp;
const bool shouldSwap = ascending ? comp(hiKey, loKey) : comp(loKey, hiKey);
Review comment (Contributor):
The compiler is probably dumb and might not realize the right term is the negation of the left term. Ternaries in SPIR-V usually get compiled to an OpSelect, which treats both terms after the ? not as branches to conditionally execute, but as operands whose result must be evaluated before the select operation runs. That is to say, if the compiler is stupid you're going to run two comparisons. If you make the right term the negation of the left one, CSE is likely to kick in and evaluate the comparison only once.
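
A sketch of that suggestion, keeping the original semantics while feeding a single comparison into both operands of the select (this would replace the line above):
const bool hiBeforeLo = comp(hiKey, loKey);
const bool shouldSwap = ascending ? hiBeforeLo : !hiBeforeLo; // right term is the negation of the left, so CSE can keep a single comparison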

if (shouldSwap)
{
// Swap keys
key_t tempKey = loKey;
loKey = hiKey;
hiKey = tempKey;
// Swap values
value_t tempVal = loVal;
loVal = hiVal;
hiVal = tempVal;
}
Review comment (Contributor), on lines +39 to +49:
Make this branchless like you did the swaps in the subgroup branch.
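
A branchless version of the swap above, as a sketch of what the comment asks for:
const key_t newLoKey = shouldSwap ? hiKey : loKey;
const key_t newHiKey = shouldSwap ? loKey : hiKey;
const value_t newLoVal = shouldSwap ? hiVal : loVal;
const value_t newHiVal = shouldSwap ? loVal : hiVal;
loKey = newLoKey; hiKey = newHiKey;
loVal = newLoVal; hiVal = newHiVal;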

}


static void lastMergeStage(uint32_t stage, uint32_t invocationID, NBL_REF_ARG(key_t) loKey, NBL_REF_ARG(key_t) hiKey,
Review comment (Contributor):
In the end this is just mergeStage with bitonicAscending = true, right? I think you can just have mergeStage and avoid having this function duplicated.

NBL_REF_ARG(value_t) loVal, NBL_REF_ARG(value_t) hiVal)
{
[unroll]
for (uint32_t pass = 0; pass <= stage; pass++)
{
const uint32_t stride = 1u << (stage - pass); // Element stride
const uint32_t threadStride = stride >> 1;
if (threadStride == 0)
{
// Local compare and swap for stage 0
compareAndSwap(Ascending, loKey, hiKey, loVal, hiVal);
}
else
{
// Shuffle from partner using XOR
const key_t pLoKey = glsl::subgroupShuffleXor<key_t>(loKey, threadStride);
const key_t pHiKey = glsl::subgroupShuffleXor<key_t>(hiKey, threadStride);
const value_t pLoVal = glsl::subgroupShuffleXor<value_t>(loVal, threadStride);
const value_t pHiVal = glsl::subgroupShuffleXor<value_t>(hiVal, threadStride);
comparator_t comp;
if (comp(loKey, pLoKey)) { loKey = pLoKey; loVal = pLoVal; }
Review comment (Contributor):
Unlike the other method, both threads keep the min elements? Like upperHalf is not being considered here, so I'm inclined to believe this function is going to fail. I'd delete this method and just use mergeStage, since this is just that method but with a forced bitonicAscending = true.

Review comment (Contributor):
Test this code to make sure it's right, but it feels wrong. Either way, just use mergeStage and avoid having this duped.

if (comp(hiKey, pHiKey)) { hiKey = pHiKey; hiVal = pHiVal; }

}

}
}
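
Per the three comments above, lastMergeStage could be dropped entirely; a minimal sketch of the replacement at the call site in __call, following the reviewer's wording of a forced bitonicAscending = true:
// instead of: lastMergeStage(subgroupSizeLog2, invocationID, loKey, hiKey, loVal, hiVal);
mergeStage(subgroupSizeLog2, true, invocationID, loKey, hiKey, loVal, hiVal);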

static void mergeStage(uint32_t stage, bool bitonicAscending, uint32_t invocationID, NBL_REF_ARG(key_t) loKey, NBL_REF_ARG(key_t) hiKey,
NBL_REF_ARG(value_t) loVal, NBL_REF_ARG(value_t) hiVal)
{
[unroll]
for (uint32_t pass = 0; pass <= stage; pass++)
{
const uint32_t stride = 1u << (stage - pass); // Element stride
const uint32_t threadStride = stride >> 1;
if (threadStride == 0)
{
// Local compare and swap for stage 0
compareAndSwap(bitonicAscending, loKey, hiKey, loVal, hiVal);
}
else
{
// Shuffle from partner using XOR
const key_t pLoKey = glsl::subgroupShuffleXor<key_t>(loKey, threadStride);
const key_t pHiKey = glsl::subgroupShuffleXor<key_t>(hiKey, threadStride);
const value_t pLoVal = glsl::subgroupShuffleXor<value_t>(loVal, threadStride);
const value_t pHiVal = glsl::subgroupShuffleXor<value_t>(hiVal, threadStride);
// Determine if we're upper or lower half
const bool upperHalf = bool(invocationID & threadStride);
const bool takeLarger = upperHalf == bitonicAscending;
comparator_t comp;
if (takeLarger)
{
if (comp(loKey, pLoKey)) { loKey = pLoKey; loVal = pLoVal; }
if (comp(hiKey, pHiKey)) { hiKey = pHiKey; hiVal = pHiVal; }
}
else
{
if (comp(pLoKey, loKey)) { loKey = pLoKey; loVal = pLoVal; }
if (comp(pHiKey, hiKey)) { hiKey = pHiKey; hiVal = pHiVal; }
}
}
}
}

static void __call(NBL_REF_ARG(key_t) loKey, NBL_REF_ARG(key_t) hiKey,
NBL_REF_ARG(value_t) loVal, NBL_REF_ARG(value_t) hiVal)
{
const uint32_t invocationID = glsl::gl_SubgroupInvocationID();
const uint32_t subgroupSizeLog2 = glsl::gl_SubgroupSizeLog2();
[unroll]
for (uint32_t stage = 0; stage < subgroupSizeLog2; stage++)
{
const bool bitonicAscending = (stage == subgroupSizeLog2) ? Ascending : !bool(invocationID & (1u << stage));
Review comment (Contributor):
stage == subgroupSizeLog2 is never true in this loop, so just assign the term for the false clause.
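
A sketch of the simplification (inside this loop, stage is always strictly less than subgroupSizeLog2):
const bool bitonicAscending = !bool(invocationID & (1u << stage));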

mergeStage(stage, bitonicAscending, invocationID, loKey, hiKey, loVal, hiVal);
}
lastMergeStage(subgroupSizeLog2, invocationID, loKey, hiKey, loVal, hiVal);

}
};

}
}
}
#endif
include/nbl/builtin/hlsl/workgroup/bitonic_sort.hlsl (new file, +155 lines)
#ifndef NBL_BUILTIN_HLSL_WORKGROUP_BITONIC_SORT_INCLUDED
#define NBL_BUILTIN_HLSL_WORKGROUP_BITONIC_SORT_INCLUDED
#include "nbl/builtin/hlsl/bitonic_sort/common.hlsl"
#include "nbl/builtin/hlsl/memory_accessor.hlsl"
#include "nbl/builtin/hlsl/functional.hlsl"
#include "nbl/builtin/hlsl/subgroup/bitonic_sort.hlsl"
#include "nbl/builtin/hlsl/bit.hlsl"
#include "nbl/builtin/hlsl/workgroup/shuffle.hlsl"
#include "nbl/builtin/hlsl/workgroup/basic.hlsl"

namespace nbl
{
namespace hlsl
{
namespace workgroup
{
namespace bitonic_sort
{
// Reorder: non-type parameters FIRST, then typename parameters with defaults
// This matches FFT's pattern and avoids DXC bugs
template<uint16_t _ElementsPerInvocationLog2, uint16_t _WorkgroupSizeLog2, typename KeyType, typename ValueType, typename Comparator = less<KeyType> >
struct bitonic_sort_config
{
using key_t = KeyType;
using value_t = ValueType;
using comparator_t = Comparator;

NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocationLog2 = _ElementsPerInvocationLog2;
Review comment (Contributor):
NBL_CONSTEXPR_STATIC_INLINE resolves to const static when preprocessed. DXC is stupid and when it sees a static it WILL initialize a variable. But if it sees const on its own it does compile it down to a constant, which is the behaviour you would expect (and what would happen in C++). So for now just replace these usages with just const.

@devshgraphicsprogramming do we just change this macro to resolve to const in HLSL in master?

NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSizeLog2 = _WorkgroupSizeLog2;

NBL_CONSTEXPR_STATIC_INLINE uint32_t ElementsPerInvocation = 1u << ElementsPerInvocationLog2;
NBL_CONSTEXPR_STATIC_INLINE uint32_t WorkgroupSize = 1u << WorkgroupSizeLog2;
};
}

template<typename Config, class device_capabilities = void>
struct BitonicSort;


template<uint16_t ElementsPerInvocationLog2, uint16_t WorkgroupSizeLog2, typename KeyType, typename ValueType, typename Comparator, class device_capabilities>
struct BitonicSort<bitonic_sort::bitonic_sort_config<ElementsPerInvocationLog2, WorkgroupSizeLog2, KeyType, ValueType, Comparator>, device_capabilities>
{
using config_t = bitonic_sort::bitonic_sort_config<ElementsPerInvocationLog2, WorkgroupSizeLog2, KeyType, ValueType, Comparator>;
using key_t = KeyType;
using value_t = ValueType;
using comparator_t = Comparator;

using SortConfig = subgroup::bitonic_sort_config<key_t, value_t, comparator_t>;

template<typename SharedMemoryAccessor>
static void mergeStage(NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor, uint32_t stage, bool bitonicAscending, uint32_t invocationID, NBL_REF_ARG(key_t) loKey, NBL_REF_ARG(key_t) hiKey,
NBL_REF_ARG(value_t) loVal, NBL_REF_ARG(value_t) hiVal)
{
NBL_CONSTEXPR_STATIC_INLINE uint32_t WorkgroupSize = config_t::WorkgroupSize;
using adaptor_t = accessor_adaptors::StructureOfArrays<SharedMemoryAccessor, key_t, value_t, 1, WorkgroupSize>;
Review comment (Contributor):
This looks wrong. You're making the index type for the array be value_t. If your values were floats, for example, this would blow up. You are getting lucky here because (I guess) the types you're testing with are all integers.

You'll want the shared memory accessor to satisfy this concept:

#define NBL_CONCEPT_NAME GenericSharedMemoryAccessor

That concept basically states that the accessor can read and write uint32_ts. This is because shared memory (in most architectures) works with certain restrictions due to memory banking and size per transaction for each bank. It is the adaptor that is later in charge of reading/writing from/to shared memory with your actual type.

What you want here is to have a SharedMemoryAccessor accessing shared memory that is at least max(sizeof(key_t), sizeof(value_t)) * ArraySize bytes (this is unenforceable via concepts, but you can make a utility for the config that returns this value so the user can allocate such an array).

Then you want TWO different adaptors: one is going to be
key_adaptor = accessor_adaptors::StructureOfArrays<SharedMemoryAccessor, key_t, uint32_t, 1, WorkgroupSize>
and the other is going to be
value_adaptor = accessor_adaptors::StructureOfArrays<SharedMemoryAccessor, value_t, uint32_t, 1, WorkgroupSize>

You would then shuffle the keys using the key adaptor, barrier, then shuffle the values using the value adaptor.
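
Putting that together, a rough sketch of the setup and one key/value shuffle round (key_adaptor/value_adaptor are the reviewer's names; the rest mirrors the surrounding code, so treat it as illustrative rather than final):
// two adaptors over the same uint32_t-backed shared memory accessor
using key_adaptor_t = accessor_adaptors::StructureOfArrays<SharedMemoryAccessor, key_t, uint32_t, 1, WorkgroupSize>;
using value_adaptor_t = accessor_adaptors::StructureOfArrays<SharedMemoryAccessor, value_t, uint32_t, 1, WorkgroupSize>;
key_adaptor_t keyAdaptor; keyAdaptor.accessor = sharedmemAccessor;
value_adaptor_t valueAdaptor; valueAdaptor.accessor = sharedmemAccessor;
// shuffle keys, barrier to unalias the scratch, then shuffle values
shuffleXor(pLoKey, threadStride, keyAdaptor);
keyAdaptor.workgroupExecutionAndMemoryBarrier();
shuffleXor(pLoVal, threadStride, valueAdaptor);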

adaptor_t sharedmemAdaptor;
sharedmemAdaptor.accessor = sharedmemAccessor;

const uint32_t subgroupSizeLog2 = glsl::gl_SubgroupSizeLog2();

[unroll]
for (uint32_t pass = 0; pass <= stage; pass++)
{
// Stride calculation: stage S merges 2^(S+1) subgroups
const uint32_t stridePower = (stage - pass + 1) + subgroupSizeLog2;
const uint32_t stride = 1u << stridePower;
const uint32_t threadStride = stride >> 1;

// Separate shuffles for lo/hi streams (two-round shuffle as per PR review)
// TODO: Consider single-round shuffle of key-value pairs for better performance
key_t pLoKey = loKey;
shuffleXor(pLoKey, threadStride, sharedmemAdaptor);
Review comment (Contributor):
This would be a shuffle using the key_adaptor.

value_t pLoVal = loVal;
shuffleXor(pLoVal, threadStride, sharedmemAdaptor);
Review comment (Contributor):
And this one would be a shuffle using the value_adaptor.


key_t pHiKey = hiKey;
shuffleXor(pHiKey, threadStride, sharedmemAdaptor);
value_t pHiVal = hiVal;
Review comment (Contributor):
In between shuffles, your array is aliased. The first shuffle has the following behaviour: all threads write values, then there's a barrier (so all threads are done writing before they start reading), then they start reading. On the next shuffle, they immediately start writing again. If you don't barrier in between these shuffles, you risk writing before some other thread was done reading, overwriting what needed to be read. Between shuffles, therefore, you need a barrier to unalias the memory.
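
Concretely, the fix amounts to a barrier on the adaptor between every pair of consecutive shuffles, e.g. (sketch):
sharedmemAdaptor.workgroupExecutionAndMemoryBarrier(); // unalias the scratch before the next shuffle starts writing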

shuffleXor(pHiVal, threadStride, sharedmemAdaptor);

const bool isUpper = (invocationID & threadStride) != 0;
const bool takeLarger = isUpper == bitonicAscending;

comparator_t comp;

// lo update
const bool loSelfSmaller = comp(loKey, pLoKey);
const bool takePartnerLo = takeLarger ? loSelfSmaller : !loSelfSmaller;
loKey = takePartnerLo ? pLoKey : loKey;
loVal = takePartnerLo ? pLoVal : loVal;

// hi update
const bool hiSelfSmaller = comp(hiKey, pHiKey);
const bool takePartnerHi = takeLarger ? hiSelfSmaller : !hiSelfSmaller;
hiKey = takePartnerHi ? pHiKey : hiKey;
hiVal = takePartnerHi ? pHiVal : hiVal;

sharedmemAdaptor.workgroupExecutionAndMemoryBarrier();
}
}

template<typename Accessor, typename SharedMemoryAccessor>
static void __call(
NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor,
NBL_REF_ARG(key_t) loKey, NBL_REF_ARG(key_t) hiKey,
NBL_REF_ARG(value_t) loVal, NBL_REF_ARG(value_t) hiVal)
{
NBL_CONSTEXPR_STATIC_INLINE uint32_t WorkgroupSize = config_t::WorkgroupSize;

const uint32_t invocationID = glsl::gl_LocalInvocationID().x;
const uint32_t subgroupSizeLog2 = glsl::gl_SubgroupSizeLog2();
const uint32_t subgroupSize = 1u << subgroupSizeLog2;
const uint32_t subgroupID = glsl::gl_SubgroupID();
const uint32_t numSubgroups = WorkgroupSize / subgroupSize;
const uint32_t numSubgroupsLog2 = findMSB(numSubgroups);


const bool subgroupAscending = (subgroupID & 1) == 0;
subgroup::bitonic_sort<SortConfig>::__call(subgroupAscending, loKey, hiKey, loVal, hiVal);


[unroll]
for (uint32_t stage = 0; stage < numSubgroupsLog2; ++stage)
{
const bool isLastStage = (stage == numSubgroupsLog2 - 1);
const bool bitonicAscending = isLastStage ? true : !bool(invocationID & (subgroupSize << (stage + 1)));
Review comment (Contributor):
Might be wrong here, but it feels like the formula on the right yields true even for the last stage (at that point the single 1 in the bitmask is too far to the left, so the result of the & is 0).
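
A quick check with hypothetical numbers backs this up: with WorkgroupSize = 128 and subgroupSize = 32, numSubgroupsLog2 = 2 and the last stage is stage = 1; the mask is then subgroupSize << (stage + 1) = 128, every invocationID in [0, 127] ANDs with it to 0, so the negated result is already true.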


mergeStage(sharedmemAccessor, stage, bitonicAscending, invocationID, loKey, hiKey, loVal, hiVal);

const uint32_t subgroupInvocationID = glsl::gl_SubgroupInvocationID();
Review comment (Contributor):
Pull this one out of the loop.

subgroup::bitonic_sort<SortConfig>::mergeStage(subgroupSizeLog2, bitonicAscending, subgroupInvocationID, loKey, hiKey, loVal, hiVal);
}


// Final: ensure lo <= hi within each thread (for ascending sort)
comparator_t comp;
if (comp(hiKey, loKey))
{
// Swap keys
key_t tempKey = loKey;
loKey = hiKey;
hiKey = tempKey;
// Swap values
value_t tempVal = loVal;
loVal = hiVal;
hiVal = tempVal;
}
}
};

}
}
}

#endif