Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bitonic Sort Update #309

Merged
merged 11 commits into from
Oct 31, 2023
16 changes: 8 additions & 8 deletions src/sample/bitonicSort/bitonicDisplay.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import {
BindGroupsObjectsAndLayout,
createBindGroupDescriptor,
BindGroupCluster,
createBindGroupCluster,
Base2DRendererClass,
} from './utils';

Expand All @@ -19,14 +19,14 @@ export default class BitonicDisplayRenderer extends Base2DRendererClass {

switchBindGroup: (name: string) => void;
setArguments: (args: BitonicDisplayRenderArgs) => void;
computeBGDescript: BindGroupsObjectsAndLayout;
computeBGDescript: BindGroupCluster;

constructor(
device: GPUDevice,
presentationFormat: GPUTextureFormat,
renderPassDescriptor: GPURenderPassDescriptor,
bindGroupNames: string[],
computeBGDescript: BindGroupsObjectsAndLayout,
computeBGDescript: BindGroupCluster,
label: string
) {
super();
Expand All @@ -38,7 +38,7 @@ export default class BitonicDisplayRenderer extends Base2DRendererClass {
usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
});

const bgDescript = createBindGroupDescriptor(
const bgCluster = createBindGroupCluster(
[0],
[GPUShaderStage.FRAGMENT],
['buffer'],
Expand All @@ -48,19 +48,19 @@ export default class BitonicDisplayRenderer extends Base2DRendererClass {
device
);

this.currentBindGroup = bgDescript.bindGroups[0];
this.currentBindGroup = bgCluster.bindGroups[0];
this.currentBindGroupName = bindGroupNames[0];

this.bindGroupMap = {};

bgDescript.bindGroups.forEach((bg, idx) => {
bgCluster.bindGroups.forEach((bg, idx) => {
this.bindGroupMap[bindGroupNames[idx]] = bg;
});

this.pipeline = super.create2DRenderPipeline(
device,
label,
[bgDescript.bindGroupLayout, this.computeBGDescript.bindGroupLayout],
[bgCluster.bindGroupLayout, this.computeBGDescript.bindGroupLayout],
bitonicDisplay,
presentationFormat
);
Expand Down
36 changes: 25 additions & 11 deletions src/sample/bitonicSort/computeShader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ struct Uniforms {

var<workgroup> local_data: array<u32, ${threadsPerWorkgroup * 2}>;

//Compare and swap values in local_data
fn compare_and_swap(idx_before: u32, idx_after: u32) {
// Compare and swap values in local_data
fn local_compare_and_swap(idx_before: u32, idx_after: u32) {
//idx_before should always be < idx_after
if (local_data[idx_after] < local_data[idx_before]) {
var temp: u32 = local_data[idx_before];
Expand All @@ -30,37 +30,45 @@ fn compare_and_swap(idx_before: u32, idx_after: u32) {
}

// thread_id ranges from 0 to threadsPerWorkgroup - 1
fn prepare_flip(thread_id: u32, block_height: u32) {
fn get_flip_indices(thread_id: u32, block_height: u32) -> vec2<u32> {
let q: u32 = ((2 * thread_id) / block_height) * block_height;
let half_height = block_height / 2;
var idx: vec2<u32> = vec2<u32>(
thread_id % half_height, block_height - (thread_id % half_height) - 1,
);
idx.x += q;
idx.y += q;
compare_and_swap(idx.x, idx.y);
return idx;
}

fn prepare_disperse(thread_id: u32, block_height: u32) {
fn get_disperse_indices(thread_id: u32, block_height: u32) -> vec2<u32> {
var q: u32 = ((2 * thread_id) / block_height) * block_height;
let half_height = block_height / 2;
var idx: vec2<u32> = vec2<u32>(
thread_id % half_height, (thread_id % half_height) + half_height
);
idx.x += q;
idx.y += q;
compare_and_swap(idx.x, idx.y);
return idx;
}

@group(0) @binding(0) var<storage, read> input_data: array<u32>;
@group(0) @binding(1) var<storage, read_write> output_data: array<u32>;
@group(0) @binding(2) var<uniform> uniforms: Uniforms;

// Compares two elements of the read-only `input_data` storage buffer and, when
// the element at `idx_after` is smaller than the element at `idx_before`,
// writes the pair into `output_data` in swapped positions.
//
// idx_before: index of the element expected to hold the smaller value.
// idx_after:  index of the element expected to hold the larger value.
//
// NOTE(review): when the pair is already in order nothing is written to
// `output_data` — presumably the caller has already populated it; confirm.
fn global_compare_and_swap(idx_before: u32, idx_after: u32) {
  // `input_data` is bound as `read`, so each element can be fetched once up front.
  let value_before: u32 = input_data[idx_before];
  let value_after: u32 = input_data[idx_after];
  if (value_after < value_before) {
    output_data[idx_before] = value_after;
    output_data[idx_after] = value_before;
  }
}

// Our compute shader will execute either the specified number of threads or (number of elements / 2) threads
@compute @workgroup_size(${threadsPerWorkgroup}, 1, 1)
fn computeMain(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) workgroup: vec3<u32>,
) {
//Each thread will populate the workgroup data... (1 thread for every 2 elements)
local_data[local_id.x * 2] = input_data[local_id.x * 2];
Expand All @@ -72,18 +80,24 @@ fn computeMain(
var num_elements = uniforms.width * uniforms.height;

switch uniforms.algo {
case 1: { //Local Flip
prepare_flip(local_id.x, uniforms.blockHeight);
case 1: { // Local Flip
let idx = get_flip_indices(local_id.x, uniforms.blockHeight);
local_compare_and_swap(idx.x, idx.y);
}
case 2: { // Local Disperse
let idx = get_disperse_indices(local_id.x, uniforms.blockHeight);
local_compare_and_swap(idx.x, idx.y);
}
case 2: { //Local Disperse
prepare_disperse(local_id.x, uniforms.blockHeight);
case 4: { // Global Flip
let idx = get_flip_indices(local_id.x, uniforms.blockHeight);
global_compare_and_swap(idx.x, idx.y);
}
default: {

}
}

//Ensure that all threads have swapped their own regions of data
// Ensure that all threads have swapped their own regions of data
workgroupBarrier();

//Repopulate global data with local data
Expand Down
Loading
Loading