Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bitonic Sort Update #309

Merged
merged 11 commits into from
Oct 31, 2023
7 changes: 7 additions & 0 deletions src/sample/bitonicSort/atomicToZero.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
@group(0) @binding(3) var<storage, read_write> counter: atomic<u32>;

// Resets the `counter` storage atomic to zero.
// Dispatched with a 1x1x1 workgroup size, so each workgroup runs a single invocation.
@compute @workgroup_size(1, 1, 1)
fn atomicToZero() {
  // One atomic store replaces the former atomicLoad + atomicSub pair: the result
  // is zero regardless of the previous value, and there is no window between the
  // load and the subtract in which another write could leave a non-zero remainder.
  atomicStore(&counter, 0u);
}
26 changes: 23 additions & 3 deletions src/sample/bitonicSort/bitonicDisplay.frag.wgsl
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
struct Uniforms {
struct ComputeUniforms {
width: f32,
height: f32,
algo: u32,
blockHeight: u32,
}

// Uniform data read only by the fragment shader (separate from the compute uniforms).
struct FragmentUniforms {
// boolean, either 0 or 1
// When 1, frag_main returns a green or red highlight color chosen by the
// element's position within its block instead of continuing to the regular color.
highlight: u32,
}

// Inputs received by frag_main: the builtin position and a UV coordinate
// passed at location 0.
struct VertexOutput {
@builtin(position) Position: vec4<f32>,
@location(0) fragUV: vec2<f32>
}

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(1) @binding(0) var<storage, read> data: array<u32>;
// Uniforms from compute shader
@group(0) @binding(0) var<storage, read> data: array<u32>;
@group(0) @binding(2) var<uniform> uniforms: ComputeUniforms;
// Fragment shader uniforms
@group(1) @binding(0) var<uniform> fragment_uniforms: FragmentUniforms;

@fragment
fn frag_main(input: VertexOutput) -> @location(0) vec4<f32> {
Expand All @@ -28,6 +38,16 @@ fn frag_main(input: VertexOutput) -> @location(0) vec4<f32> {

var subtracter = f32(colorChanger) / (uniforms.width * uniforms.height);

if (fragment_uniforms.highlight == 1) {
return select(
//If element is above halfHeight, highlight green
vec4<f32>(vec3<f32>(0.0, 1.0 - subtracter, 0.0).rgb, 1.0),
//If element is below halfheight, highlight red
vec4<f32>(vec3<f32>(1.0 - subtracter, 0.0, 0.0).rgb, 1.0),
elementIndex % uniforms.blockHeight < uniforms.blockHeight / 2
);
}

var color: vec3<f32> = vec3f(
1.0 - subtracter
);
Expand Down
43 changes: 15 additions & 28 deletions src/sample/bitonicSort/bitonicDisplay.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import {
BindGroupsObjectsAndLayout,
createBindGroupDescriptor,
BindGroupCluster,
Base2DRendererClass,
createBindGroupCluster,
} from './utils';

import bitonicDisplay from './bitonicDisplay.frag.wgsl';

interface BitonicDisplayRenderArgs {
width: number;
height: number;
highlight: number;
}

export default class BitonicDisplayRenderer extends Base2DRendererClass {
Expand All @@ -19,26 +18,25 @@ export default class BitonicDisplayRenderer extends Base2DRendererClass {

switchBindGroup: (name: string) => void;
setArguments: (args: BitonicDisplayRenderArgs) => void;
computeBGDescript: BindGroupsObjectsAndLayout;
computeBGDescript: BindGroupCluster;

constructor(
device: GPUDevice,
presentationFormat: GPUTextureFormat,
renderPassDescriptor: GPURenderPassDescriptor,
bindGroupNames: string[],
computeBGDescript: BindGroupsObjectsAndLayout,
computeBGDescript: BindGroupCluster,
label: string
) {
super();
this.renderPassDescriptor = renderPassDescriptor;
this.computeBGDescript = computeBGDescript;

const uniformBuffer = device.createBuffer({
size: Float32Array.BYTES_PER_ELEMENT * 2,
size: Uint32Array.BYTES_PER_ELEMENT,
usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
});

const bgDescript = createBindGroupDescriptor(
const bgCluster = createBindGroupCluster(
[0],
[GPUShaderStage.FRAGMENT],
['buffer'],
Expand All @@ -48,41 +46,30 @@ export default class BitonicDisplayRenderer extends Base2DRendererClass {
device
);

this.currentBindGroup = bgDescript.bindGroups[0];
this.currentBindGroupName = bindGroupNames[0];

this.bindGroupMap = {};

bgDescript.bindGroups.forEach((bg, idx) => {
this.bindGroupMap[bindGroupNames[idx]] = bg;
});
this.currentBindGroup = bgCluster.bindGroups[0];

this.pipeline = super.create2DRenderPipeline(
device,
label,
[bgDescript.bindGroupLayout, this.computeBGDescript.bindGroupLayout],
[this.computeBGDescript.bindGroupLayout, bgCluster.bindGroupLayout],
bitonicDisplay,
presentationFormat
);

this.switchBindGroup = (name: string) => {
this.currentBindGroup = this.bindGroupMap[name];
this.currentBindGroupName = name;
};

this.setArguments = (args: BitonicDisplayRenderArgs) => {
super.setUniformArguments(device, uniformBuffer, args, [
'width',
'height',
]);
device.queue.writeBuffer(
uniformBuffer,
0,
new Uint32Array([args.highlight])
);
};
}

startRun(commandEncoder: GPUCommandEncoder, args: BitonicDisplayRenderArgs) {
this.setArguments(args);
super.executeRun(commandEncoder, this.renderPassDescriptor, this.pipeline, [
this.currentBindGroup,
this.computeBGDescript.bindGroups[0],
this.currentBindGroup,
]);
}
}
93 changes: 64 additions & 29 deletions src/sample/bitonicSort/computeShader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,19 @@ struct Uniforms {
}

// Create local workgroup data that can contain all elements

var<workgroup> local_data: array<u32, ${threadsPerWorkgroup * 2}>;

//Compare and swap values in local_data
fn compare_and_swap(idx_before: u32, idx_after: u32) {
// Define groups (functions refer to this data)
@group(0) @binding(0) var<storage, read> input_data: array<u32>;
@group(0) @binding(1) var<storage, read_write> output_data: array<u32>;
@group(0) @binding(2) var<uniform> uniforms: Uniforms;
@group(0) @binding(3) var<storage, read_write> counter: atomic<u32>;

// Compare and swap values in local_data
fn local_compare_and_swap(idx_before: u32, idx_after: u32) {
//idx_before should always be < idx_after
if (local_data[idx_after] < local_data[idx_before]) {
atomicAdd(&counter, 1);
var temp: u32 = local_data[idx_before];
local_data[idx_before] = local_data[idx_after];
local_data[idx_after] = temp;
Expand All @@ -30,65 +36,94 @@ fn compare_and_swap(idx_before: u32, idx_after: u32) {
}

// thread_id goes from 0 to threadsPerWorkgroup
fn prepare_flip(thread_id: u32, block_height: u32) {
let q: u32 = ((2 * thread_id) / block_height) * block_height;
fn get_flip_indices(thread_id: u32, block_height: u32) -> vec2<u32> {
// Calculate index offset (i.e., move indices into the correct block)
let block_offset: u32 = ((2 * thread_id) / block_height) * block_height;
let half_height = block_height / 2;
// Calculate index spacing
var idx: vec2<u32> = vec2<u32>(
thread_id % half_height, block_height - (thread_id % half_height) - 1,
);
idx.x += q;
idx.y += q;
compare_and_swap(idx.x, idx.y);
idx.x += block_offset;
idx.y += block_offset;
return idx;
}

fn prepare_disperse(thread_id: u32, block_height: u32) {
var q: u32 = ((2 * thread_id) / block_height) * block_height;
fn get_disperse_indices(thread_id: u32, block_height: u32) -> vec2<u32> {
var block_offset: u32 = ((2 * thread_id) / block_height) * block_height;
let half_height = block_height / 2;
var idx: vec2<u32> = vec2<u32>(
thread_id % half_height, (thread_id % half_height) + half_height
);
idx.x += q;
idx.y += q;
compare_and_swap(idx.x, idx.y);
idx.x += block_offset;
idx.y += block_offset;
return idx;
}

@group(0) @binding(0) var<storage, read> input_data: array<u32>;
@group(0) @binding(1) var<storage, read_write> output_data: array<u32>;
@group(0) @binding(2) var<uniform> uniforms: Uniforms;
// Compares one pair of elements read from input_data and, when they are out of
// order, writes them swapped into output_data.
// idx_before is expected to be the lower index of the pair.
fn global_compare_and_swap(idx_before: u32, idx_after: u32) {
  let before_value = input_data[idx_before];
  let after_value = input_data[idx_after];
  // Only out-of-order pairs are written; an already ordered pair leaves
  // output_data untouched.
  if (after_value < before_value) {
    output_data[idx_before] = after_value;
    output_data[idx_after] = before_value;
  }
}

// Constants/enum
// Values of uniforms.algo that select which sorting step computeMain performs.
// Values 0-2 operate on workgroup-local data; 3 and 4 operate on the global buffers.
const ALGO_NONE = 0;
const ALGO_LOCAL_FLIP = 1;
const ALGO_LOCAL_DISPERSE = 2;
const ALGO_GLOBAL_FLIP = 3;
// Value 4 is handled by the switch in computeMain (global disperse), so it is
// named here alongside the other steps.
const ALGO_GLOBAL_DISPERSE = 4;

// Our compute shader will execute specified # of threads or elements / 2 threads
@compute @workgroup_size(${threadsPerWorkgroup}, 1, 1)
fn computeMain(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
//Each thread will populate the workgroup data... (1 thread for every 2 elements)
local_data[local_id.x * 2] = input_data[local_id.x * 2];
local_data[local_id.x * 2 + 1] = input_data[local_id.x * 2 + 1];

let offset = ${threadsPerWorkgroup} * 2 * workgroup_id.x;
// If we will perform a local swap, then populate the local data
if (uniforms.algo <= 2) {
// Assign range of input_data to local_data.
// Range cannot exceed maxWorkgroupsX * 2
// Each thread will populate the workgroup data... (1 thread for every 2 elements)
local_data[local_id.x * 2] = input_data[offset + local_id.x * 2];
local_data[local_id.x * 2 + 1] = input_data[offset + local_id.x * 2 + 1];
}

//...and wait for each other to finish their own bit of data population.
workgroupBarrier();

var num_elements = uniforms.width * uniforms.height;

switch uniforms.algo {
case 1: { //Local Flip
prepare_flip(local_id.x, uniforms.blockHeight);
case 1: { // Local Flip
let idx = get_flip_indices(local_id.x, uniforms.blockHeight);
local_compare_and_swap(idx.x, idx.y);
}
case 2: { // Local Disperse
let idx = get_disperse_indices(local_id.x, uniforms.blockHeight);
local_compare_and_swap(idx.x, idx.y);
}
case 3: { // Global Flip
let idx = get_flip_indices(global_id.x, uniforms.blockHeight);
global_compare_and_swap(idx.x, idx.y);
}
case 2: { //Local Disperse
prepare_disperse(local_id.x, uniforms.blockHeight);
case 4: {
let idx = get_disperse_indices(global_id.x, uniforms.blockHeight);
global_compare_and_swap(idx.x, idx.y);
}
default: {

}
}

//Ensure that all threads have swapped their own regions of data
// Ensure that all threads have swapped their own regions of data
workgroupBarrier();

//Repopulate global data with local data
output_data[local_id.x * 2] = local_data[local_id.x * 2];
output_data[local_id.x * 2 + 1] = local_data[local_id.x * 2 + 1];
if (uniforms.algo <= ALGO_LOCAL_DISPERSE) {
//Repopulate global data with local data
output_data[offset + local_id.x * 2] = local_data[local_id.x * 2];
output_data[offset + local_id.x * 2 + 1] = local_data[local_id.x * 2 + 1];
}

}`;
};
Loading
Loading