Removed all references to threads in bitonicCompute.ts
cmhhelgeson committed Nov 30, 2023
1 parent 63d98ee commit 32db6c9
Showing 1 changed file with 16 additions and 16 deletions.
src/sample/bitonicSort/bitonicCompute.ts (32 changes: 16 additions & 16 deletions)
@@ -1,8 +1,8 @@
 export const computeArgKeys = ['width', 'height', 'algo', 'blockHeight'];
 
-export const NaiveBitonicCompute = (threadsPerWorkgroup: number) => {
-  if (threadsPerWorkgroup % 2 !== 0 || threadsPerWorkgroup > 256) {
-    threadsPerWorkgroup = 256;
+export const NaiveBitonicCompute = (invocationsPerWorkgroup: number) => {
+  if (invocationsPerWorkgroup % 2 !== 0 || invocationsPerWorkgroup > 256) {
+    invocationsPerWorkgroup = 256;
   }
   // Ensure that workgroupSize is half the number of elements
   return `
@@ -15,7 +15,7 @@ struct Uniforms {
 }
 
 // Create local workgroup data that can contain all elements
-var<workgroup> local_data: array<u32, ${threadsPerWorkgroup * 2}>;
+var<workgroup> local_data: array<u32, ${invocationsPerWorkgroup * 2}>;
 
 // Define groups (functions refer to this data)
 @group(0) @binding(0) var<storage, read> input_data: array<u32>;
@@ -35,25 +35,25 @@ fn local_compare_and_swap(idx_before: u32, idx_after: u32) {
   return;
 }
 
-// thread_id goes from 0 to threadsPerWorkgroup
-fn get_flip_indices(thread_id: u32, block_height: u32) -> vec2<u32> {
+// invoke_id goes from 0 to invocationsPerWorkgroup
+fn get_flip_indices(invoke_id: u32, block_height: u32) -> vec2<u32> {
   // Calculate index offset (i.e. move indices into correct block)
-  let block_offset: u32 = ((2 * thread_id) / block_height) * block_height;
+  let block_offset: u32 = ((2 * invoke_id) / block_height) * block_height;
   let half_height = block_height / 2;
   // Calculate index spacing
   var idx: vec2<u32> = vec2<u32>(
-    thread_id % half_height, block_height - (thread_id % half_height) - 1,
+    invoke_id % half_height, block_height - (invoke_id % half_height) - 1,
   );
   idx.x += block_offset;
   idx.y += block_offset;
   return idx;
 }
 
-fn get_disperse_indices(thread_id: u32, block_height: u32) -> vec2<u32> {
-  var block_offset: u32 = ((2 * thread_id) / block_height) * block_height;
+fn get_disperse_indices(invoke_id: u32, block_height: u32) -> vec2<u32> {
+  var block_offset: u32 = ((2 * invoke_id) / block_height) * block_height;
   let half_height = block_height / 2;
   var idx: vec2<u32> = vec2<u32>(
-    thread_id % half_height, (thread_id % half_height) + half_height
+    invoke_id % half_height, (invoke_id % half_height) + half_height
   );
   idx.x += block_offset;
   idx.y += block_offset;
@@ -73,20 +73,20 @@ const ALGO_LOCAL_FLIP = 1;
 const ALGO_LOCAL_DISPERSE = 2;
 const ALGO_GLOBAL_FLIP = 3;
 
-// Our compute shader will execute specified # of threads or elements / 2 threads
-@compute @workgroup_size(${threadsPerWorkgroup}, 1, 1)
+// Our compute shader will execute specified # of invocations or elements / 2 invocations
+@compute @workgroup_size(${invocationsPerWorkgroup}, 1, 1)
 fn computeMain(
   @builtin(global_invocation_id) global_id: vec3<u32>,
   @builtin(local_invocation_id) local_id: vec3<u32>,
   @builtin(workgroup_id) workgroup_id: vec3<u32>,
 ) {
-  let offset = ${threadsPerWorkgroup} * 2 * workgroup_id.x;
+  let offset = ${invocationsPerWorkgroup} * 2 * workgroup_id.x;
 
   // If we will perform a local swap, then populate the local data
   if (uniforms.algo <= 2) {
     // Assign range of input_data to local_data.
     // Range cannot exceed maxWorkgroupsX * 2
-    // Each thread will populate the workgroup data... (1 thread for every 2 elements)
+    // Each invocation will populate the workgroup data... (1 invocation for every 2 elements)
     local_data[local_id.x * 2] = input_data[offset + local_id.x * 2];
     local_data[local_id.x * 2 + 1] = input_data[offset + local_id.x * 2 + 1];
   }
@@ -116,7 +116,7 @@ fn computeMain(
   }
 }
 
-// Ensure that all threads have swapped their own regions of data
+// Ensure that all invocations have swapped their own regions of data
 workgroupBarrier();
 
 if (uniforms.algo <= ALGO_LOCAL_DISPERSE) {
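
For context: NaiveBitonicCompute returns WGSL source as a template string, so this rename touches only the generator's parameter and the identifiers baked into the emitted shader; the compiled shader behaves exactly as before. Below is a minimal sketch of how the generated source can be turned into a compute pipeline, assuming an already-initialized GPUDevice named `device` and an illustrative element count (both are assumptions here; the sample's real setup code is outside this commit):

    // Sketch only: `device` and `totalElements` are assumed, not part of this commit.
    const invocationsPerWorkgroup = 256;
    const totalElements = 1024; // bitonic sort expects a power-of-two element count

    const module = device.createShaderModule({
      code: NaiveBitonicCompute(invocationsPerWorkgroup),
    });

    const pipeline = device.createComputePipeline({
      layout: 'auto',
      compute: { module, entryPoint: 'computeMain' },
    });

    // Each invocation compares-and-swaps 2 elements, so one workgroup covers
    // invocationsPerWorkgroup * 2 elements of the input per dispatch along x.
    const workgroupCount = Math.ceil(totalElements / (invocationsPerWorkgroup * 2));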
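
The flip/disperse indexing the shader uses is easy to sanity-check on the CPU. The helpers below are a TypeScript port of get_flip_indices and get_disperse_indices for illustration only; they are not part of the sample:

    // Hypothetical CPU-side port of the WGSL index math, for checking pair patterns.
    function getFlipIndices(invokeId: number, blockHeight: number): [number, number] {
      const blockOffset = Math.floor((2 * invokeId) / blockHeight) * blockHeight;
      const half = blockHeight / 2;
      return [
        blockOffset + (invokeId % half),
        blockOffset + blockHeight - (invokeId % half) - 1,
      ];
    }

    function getDisperseIndices(invokeId: number, blockHeight: number): [number, number] {
      const blockOffset = Math.floor((2 * invokeId) / blockHeight) * blockHeight;
      const half = blockHeight / 2;
      return [
        blockOffset + (invokeId % half),
        blockOffset + (invokeId % half) + half,
      ];
    }

    // With blockHeight = 4 and invocations 0..3 (covering 8 elements):
    // flip pairs:     (0,3) (1,2) (4,7) (5,6)
    // disperse pairs: (0,2) (1,3) (4,6) (5,7)
    for (let i = 0; i < 4; i++) {
      console.log(getFlipIndices(i, 4), getDisperseIndices(i, 4));
    }

These match the classic bitonic network: a flip compares mirrored elements within a block, while a disperse compares elements half a block apart.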
