Skip to content

Commit

Permalink
Bitonic Sort Update (#309)
Browse files Browse the repository at this point in the history
  • Loading branch information
cmhhelgeson authored Oct 31, 2023
1 parent 3ff5cfa commit b77d32c
Show file tree
Hide file tree
Showing 6 changed files with 378 additions and 171 deletions.
7 changes: 7 additions & 0 deletions src/sample/bitonicSort/atomicToZero.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
@group(0) @binding(3) var<storage, read_write> counter: atomic<u32>;

// Resets `counter` to zero.
// A single atomicStore replaces the earlier atomicLoad + atomicSub pair: that pair was
// two separate atomic operations, so a modification of `counter` landing between them
// would leave a non-zero value behind. One store always ends at exactly 0.
@compute @workgroup_size(1, 1, 1)
fn atomicToZero() {
  atomicStore(&counter, 0u);
}
26 changes: 23 additions & 3 deletions src/sample/bitonicSort/bitonicDisplay.frag.wgsl
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
struct Uniforms {
struct ComputeUniforms {
width: f32,
height: f32,
algo: u32,
blockHeight: u32,
}

// Uniforms that belong to the fragment stage only (bound separately from the compute uniforms).
struct FragmentUniforms {
// Used as a boolean, either 0 or 1. frag_main compares it against 1 and, when equal,
// returns a green/red highlight color instead of the default color.
highlight: u32,
}

// Input structure received by frag_main.
struct VertexOutput {
// Built-in position value.
@builtin(position) Position: vec4<f32>,
// Two-component UV value delivered at location 0.
@location(0) fragUV: vec2<f32>
}

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(1) @binding(0) var<storage, read> data: array<u32>;
// Uniforms from compute shader
@group(0) @binding(0) var<storage, read> data: array<u32>;
@group(0) @binding(2) var<uniform> uniforms: ComputeUniforms;
// Fragment shader uniforms
@group(1) @binding(0) var<uniform> fragment_uniforms: FragmentUniforms;

@fragment
fn frag_main(input: VertexOutput) -> @location(0) vec4<f32> {
Expand All @@ -28,6 +38,16 @@ fn frag_main(input: VertexOutput) -> @location(0) vec4<f32> {

var subtracter = f32(colorChanger) / (uniforms.width * uniforms.height);

if (fragment_uniforms.highlight == 1) {
return select(
//If element is above halfHeight, highlight green
vec4<f32>(vec3<f32>(0.0, 1.0 - subtracter, 0.0).rgb, 1.0),
//If element is below halfheight, highlight red
vec4<f32>(vec3<f32>(1.0 - subtracter, 0.0, 0.0).rgb, 1.0),
elementIndex % uniforms.blockHeight < uniforms.blockHeight / 2
);
}

var color: vec3<f32> = vec3f(
1.0 - subtracter
);
Expand Down
43 changes: 15 additions & 28 deletions src/sample/bitonicSort/bitonicDisplay.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import {
BindGroupsObjectsAndLayout,
createBindGroupDescriptor,
BindGroupCluster,
Base2DRendererClass,
createBindGroupCluster,
} from './utils';

import bitonicDisplay from './bitonicDisplay.frag.wgsl';

interface BitonicDisplayRenderArgs {
width: number;
height: number;
highlight: number;
}

export default class BitonicDisplayRenderer extends Base2DRendererClass {
Expand All @@ -19,26 +18,25 @@ export default class BitonicDisplayRenderer extends Base2DRendererClass {

switchBindGroup: (name: string) => void;
setArguments: (args: BitonicDisplayRenderArgs) => void;
computeBGDescript: BindGroupsObjectsAndLayout;
computeBGDescript: BindGroupCluster;

constructor(
device: GPUDevice,
presentationFormat: GPUTextureFormat,
renderPassDescriptor: GPURenderPassDescriptor,
bindGroupNames: string[],
computeBGDescript: BindGroupsObjectsAndLayout,
computeBGDescript: BindGroupCluster,
label: string
) {
super();
this.renderPassDescriptor = renderPassDescriptor;
this.computeBGDescript = computeBGDescript;

const uniformBuffer = device.createBuffer({
size: Float32Array.BYTES_PER_ELEMENT * 2,
size: Uint32Array.BYTES_PER_ELEMENT,
usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
});

const bgDescript = createBindGroupDescriptor(
const bgCluster = createBindGroupCluster(
[0],
[GPUShaderStage.FRAGMENT],
['buffer'],
Expand All @@ -48,41 +46,30 @@ export default class BitonicDisplayRenderer extends Base2DRendererClass {
device
);

this.currentBindGroup = bgDescript.bindGroups[0];
this.currentBindGroupName = bindGroupNames[0];

this.bindGroupMap = {};

bgDescript.bindGroups.forEach((bg, idx) => {
this.bindGroupMap[bindGroupNames[idx]] = bg;
});
this.currentBindGroup = bgCluster.bindGroups[0];

this.pipeline = super.create2DRenderPipeline(
device,
label,
[bgDescript.bindGroupLayout, this.computeBGDescript.bindGroupLayout],
[this.computeBGDescript.bindGroupLayout, bgCluster.bindGroupLayout],
bitonicDisplay,
presentationFormat
);

this.switchBindGroup = (name: string) => {
this.currentBindGroup = this.bindGroupMap[name];
this.currentBindGroupName = name;
};

this.setArguments = (args: BitonicDisplayRenderArgs) => {
super.setUniformArguments(device, uniformBuffer, args, [
'width',
'height',
]);
device.queue.writeBuffer(
uniformBuffer,
0,
new Uint32Array([args.highlight])
);
};
}

startRun(commandEncoder: GPUCommandEncoder, args: BitonicDisplayRenderArgs) {
this.setArguments(args);
super.executeRun(commandEncoder, this.renderPassDescriptor, this.pipeline, [
this.currentBindGroup,
this.computeBGDescript.bindGroups[0],
this.currentBindGroup,
]);
}
}
93 changes: 64 additions & 29 deletions src/sample/bitonicSort/computeShader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,19 @@ struct Uniforms {
}
// Create local workgroup data that can contain all elements
var<workgroup> local_data: array<u32, ${threadsPerWorkgroup * 2}>;
//Compare and swap values in local_data
fn compare_and_swap(idx_before: u32, idx_after: u32) {
// Define groups (functions refer to this data)
@group(0) @binding(0) var<storage, read> input_data: array<u32>;
@group(0) @binding(1) var<storage, read_write> output_data: array<u32>;
@group(0) @binding(2) var<uniform> uniforms: Uniforms;
@group(0) @binding(3) var<storage, read_write> counter: atomic<u32>;
// Compare and swap values in local_data
fn local_compare_and_swap(idx_before: u32, idx_after: u32) {
//idx_before should always be < idx_after
if (local_data[idx_after] < local_data[idx_before]) {
atomicAdd(&counter, 1);
var temp: u32 = local_data[idx_before];
local_data[idx_before] = local_data[idx_after];
local_data[idx_after] = temp;
Expand All @@ -30,65 +36,94 @@ fn compare_and_swap(idx_before: u32, idx_after: u32) {
}
// thread_id goes from 0 to threadsPerWorkgroup
fn prepare_flip(thread_id: u32, block_height: u32) {
let q: u32 = ((2 * thread_id) / block_height) * block_height;
fn get_flip_indices(thread_id: u32, block_height: u32) -> vec2<u32> {
// Calculate index offset (i.e. move indices into the correct block)
let block_offset: u32 = ((2 * thread_id) / block_height) * block_height;
let half_height = block_height / 2;
// Calculate index spacing
var idx: vec2<u32> = vec2<u32>(
thread_id % half_height, block_height - (thread_id % half_height) - 1,
);
idx.x += q;
idx.y += q;
compare_and_swap(idx.x, idx.y);
idx.x += block_offset;
idx.y += block_offset;
return idx;
}
fn prepare_disperse(thread_id: u32, block_height: u32) {
var q: u32 = ((2 * thread_id) / block_height) * block_height;
fn get_disperse_indices(thread_id: u32, block_height: u32) -> vec2<u32> {
var block_offset: u32 = ((2 * thread_id) / block_height) * block_height;
let half_height = block_height / 2;
var idx: vec2<u32> = vec2<u32>(
thread_id % half_height, (thread_id % half_height) + half_height
);
idx.x += q;
idx.y += q;
compare_and_swap(idx.x, idx.y);
idx.x += block_offset;
idx.y += block_offset;
return idx;
}
@group(0) @binding(0) var<storage, read> input_data: array<u32>;
@group(0) @binding(1) var<storage, read_write> output_data: array<u32>;
@group(0) @binding(2) var<uniform> uniforms: Uniforms;
// Compares two elements of input_data and, when the element at idx_after is smaller
// than the element at idx_before, writes both to output_data in exchanged positions.
// Nothing is written when the pair is already in order.
// NOTE(review): presumably output_data already mirrors input_data before the dispatch,
// since unswapped pairs are never copied here — confirm against the host code.
fn global_compare_and_swap(idx_before: u32, idx_after: u32) {
  let value_before: u32 = input_data[idx_before];
  let value_after: u32 = input_data[idx_after];
  if (value_after < value_before) {
    output_data[idx_before] = value_after;
    output_data[idx_after] = value_before;
  }
}
// Constants/enum: named values for uniforms.algo.
// Each value matches the corresponding case label in computeMain's switch statement.
const ALGO_NONE = 0;
const ALGO_LOCAL_FLIP = 1;
const ALGO_LOCAL_DISPERSE = 2;
const ALGO_GLOBAL_FLIP = 3;
// Added so that every algo value handled by the switch (case 4 calls
// get_disperse_indices with global_id) has a named constant.
const ALGO_GLOBAL_DISPERSE = 4;
// Our compute shader will execute specified # of threads or elements / 2 threads
@compute @workgroup_size(${threadsPerWorkgroup}, 1, 1)
fn computeMain(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
//Each thread will populate the workgroup data... (1 thread for every 2 elements)
local_data[local_id.x * 2] = input_data[local_id.x * 2];
local_data[local_id.x * 2 + 1] = input_data[local_id.x * 2 + 1];
let offset = ${threadsPerWorkgroup} * 2 * workgroup_id.x;
// If we will perform a local swap, then populate the local data
if (uniforms.algo <= 2) {
// Assign range of input_data to local_data.
// Range cannot exceed maxWorkgroupsX * 2
// Each thread will populate the workgroup data... (1 thread for every 2 elements)
local_data[local_id.x * 2] = input_data[offset + local_id.x * 2];
local_data[local_id.x * 2 + 1] = input_data[offset + local_id.x * 2 + 1];
}
//...and wait for each other to finish their own bit of data population.
workgroupBarrier();
var num_elements = uniforms.width * uniforms.height;
switch uniforms.algo {
case 1: { //Local Flip
prepare_flip(local_id.x, uniforms.blockHeight);
case 1: { // Local Flip
let idx = get_flip_indices(local_id.x, uniforms.blockHeight);
local_compare_and_swap(idx.x, idx.y);
}
case 2: { // Local Disperse
let idx = get_disperse_indices(local_id.x, uniforms.blockHeight);
local_compare_and_swap(idx.x, idx.y);
}
case 3: { // Global Flip
let idx = get_flip_indices(global_id.x, uniforms.blockHeight);
global_compare_and_swap(idx.x, idx.y);
}
case 2: { //Local Disperse
prepare_disperse(local_id.x, uniforms.blockHeight);
case 4: {
let idx = get_disperse_indices(global_id.x, uniforms.blockHeight);
global_compare_and_swap(idx.x, idx.y);
}
default: {
}
}
//Ensure that all threads have swapped their own regions of data
// Ensure that all threads have swapped their own regions of data
workgroupBarrier();
//Repopulate global data with local data
output_data[local_id.x * 2] = local_data[local_id.x * 2];
output_data[local_id.x * 2 + 1] = local_data[local_id.x * 2 + 1];
if (uniforms.algo <= ALGO_LOCAL_DISPERSE) {
//Repopulate global data with local data
output_data[offset + local_id.x * 2] = local_data[local_id.x * 2];
output_data[offset + local_id.x * 2 + 1] = local_data[local_id.x * 2 + 1];
}
}`;
};
Loading

0 comments on commit b77d32c

Please sign in to comment.