Skip to content

Commit

Permalink
Safety commit
Browse files Browse the repository at this point in the history
  • Loading branch information
cmhhelgeson committed Oct 30, 2023
1 parent 4b65446 commit ebf1684
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 40 deletions.
3 changes: 2 additions & 1 deletion src/sample/bitonicSort/bitonicDisplay.frag.wgsl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
struct Uniforms {
width: f32,
height: f32,
algo: u32,
blockHeight: u32,
}

struct VertexOutput {
Expand All @@ -9,7 +11,6 @@ struct VertexOutput {
}

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(1) @binding(0) var<storage, read> data: array<u32>;

@fragment
fn frag_main(input: VertexOutput) -> @location(0) vec4<f32> {
Expand Down
56 changes: 35 additions & 21 deletions src/sample/bitonicSort/computeShader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,67 +31,79 @@ fn local_compare_and_swap(idx_before: u32, idx_after: u32) {
// thread_id goes from 0 to threadsPerWorkgroup
fn get_flip_indices(thread_id: u32, block_height: u32) -> vec2<u32> {
let q: u32 = ((2 * thread_id) / block_height) * block_height;
// Caculate index offset (i.e move indices into correct block)
let block_offset: u32 = ((2 * thread_id) / block_height) * block_height;
let half_height = block_height / 2;
// Calculate index spacing
var idx: vec2<u32> = vec2<u32>(
thread_id % half_height, block_height - (thread_id % half_height) - 1,
);
idx.x += q;
idx.y += q;
idx.x += block_offset;
idx.y += block_offset;
return idx;
}
fn get_disperse_indices(thread_id: u32, block_height: u32) -> vec2<u32> {
var q: u32 = ((2 * thread_id) / block_height) * block_height;
var block_offset: u32 = ((2 * thread_id) / block_height) * block_height;
let half_height = block_height / 2;
var idx: vec2<u32> = vec2<u32>(
thread_id % half_height, (thread_id % half_height) + half_height
);
idx.x += q;
idx.y += q;
idx.x += block_offset;
idx.y += block_offset;
return idx;
}
@group(0) @binding(0) var<storage, read> input_data: array<u32>;
@group(0) @binding(1) var<storage, read_write> output_data: array<u32>;
@group(0) @binding(2) var<uniform> uniforms: Uniforms;
fn global_compare_and_swap(idx_before: u32, idx_after: u32) {
if (input_data[idx_after] < input_data[idx_before]) {
output_data[idx_before] = input_data[idx_after];
output_data[idx_after] = input_data[idx_before];
}
}
// Constants/enum
const ALGO_NONE = 0;
const ALGO_LOCAL_FLIP = 1;
const ALGO_LOCAL_DISPERSE = 2;
const ALGO_GLOBAL_FLIP = 3;
// Our compute shader will execute specified # of threads or elements / 2 threads
@compute @workgroup_size(${threadsPerWorkgroup}, 1, 1)
fn computeMain(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) workgroup: vec3<u32>,
@builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
//Each thread will populate the workgroup data... (1 thread for every 2 elements)
local_data[local_id.x * 2] = input_data[local_id.x * 2];
local_data[local_id.x * 2 + 1] = input_data[local_id.x * 2 + 1];
// If we will perform a local swap, then populate the local data
if (uniforms.algo <= 2) {
//Each thread will populate the workgroup data... (1 thread for every 2 elements)
local_data[global_id.x * 2] = input_data[global_id.x * 2];
local_data[global_id.x * 2 + 1] = input_data[global_id.x * 2 + 1];
}
//...and wait for each other to finish their own bit of data population.
workgroupBarrier();
var num_elements = uniforms.width * uniforms.height;
switch uniforms.algo {
case 1: { // Local Flip
let idx = get_flip_indices(local_id.x, uniforms.blockHeight);
let idx = get_flip_indices(global_id.x, uniforms.blockHeight);
local_compare_and_swap(idx.x, idx.y);
}
}
case 2: { // Local Disperse
let idx = get_disperse_indices(local_id.x, uniforms.blockHeight);
let idx = get_disperse_indices(global_id.x, uniforms.blockHeight);
local_compare_and_swap(idx.x, idx.y);
}
case 4: { // Global Flip
let idx = get_flip_indices(local_id.x, uniforms.blockHeight);
}
case 3: { // Global Flip
let idx = get_flip_indices(global_id.x, uniforms.blockHeight);
global_compare_and_swap(idx.x, idx.y);
}
// case 4: { //Global Disperse
default: {
}
Expand All @@ -100,9 +112,11 @@ fn computeMain(
// Ensure that all threads have swapped their own regions of data
workgroupBarrier();
//Repopulate global data with local data
output_data[local_id.x * 2] = local_data[local_id.x * 2];
output_data[local_id.x * 2 + 1] = local_data[local_id.x * 2 + 1];
if (uniforms.algo <= ALGO_LOCAL_DISPERSE) {
//Repopulate global data with local data
output_data[local_id.x * 2] = local_data[local_id.x * 2];
output_data[local_id.x * 2 + 1] = local_data[local_id.x * 2 + 1];
}
}`;
};
31 changes: 13 additions & 18 deletions src/sample/bitonicSort/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,11 @@ enum StepEnum {
NONE,
FLIP_LOCAL,
DISPERSE_LOCAL,
FLIP_DISPERSE_LOCAL,
FLIP_GLOBAL,
}

// String access to StepEnum
type StepType =
| 'NONE'
| 'FLIP_LOCAL'
| 'DISPERSE_LOCAL'
| 'FLIP_DISPERSE_LOCAL'
| 'FLIP_GLOBAL';
type StepType = 'NONE' | 'FLIP_LOCAL' | 'DISPERSE_LOCAL' | 'FLIP_GLOBAL';

// Gui settings object
interface SettingsInterface {
Expand Down Expand Up @@ -94,7 +88,7 @@ SampleInitFactoryWebGPU(
// Max thread span of next block
'Next Swap Span': 2,
// Workgroups to dispatch per frame,
'Total Workgroups': 1,
'Total Workgroups': maxElements / (maxWorkgroupsX * 2),
// Whether we will dispatch a workload this frame
executeStep: false,
'Randomize Values': () => {
Expand Down Expand Up @@ -275,6 +269,7 @@ SampleInitFactoryWebGPU(
let swappedIndex: number;
switch (settings['Next Step']) {
case 'FLIP_LOCAL':
case 'FLIP_GLOBAL':
{
const blockHeight = settings['Next Swap Span'];
const p2 = Math.floor(settings['Hovered Cell'] / blockHeight) + 1;
Expand Down Expand Up @@ -497,16 +492,16 @@ SampleInitFactoryWebGPU(
nextBlockHeightController.setValue(settings['Next Swap Span'] / 2);
if (settings['Next Swap Span'] === 1) {
highestBlockHeight *= 2;
nextStepController.setValue(
highestBlockHeight === settings['Total Elements'] * 2
? 'NONE'
: 'FLIP_LOCAL'
);
nextBlockHeightController.setValue(
highestBlockHeight === settings['Total Elements'] * 2
? 0
: highestBlockHeight
);
if (highestBlockHeight === settings['Total Elements'] * 2) {
nextStepController.setValue('NONE');
nextBlockHeightController.setValue(0);
} else if (highestBlockHeight > settings['Total Threads'] * 2) {
nextStepController.setValue('FLIP_GLOBAL');
nextBlockHeightController.setValue(highestBlockHeight);
} else {
nextStepController.setValue('FLIP_LOCAL');
nextBlockHeightController.setValue(highestBlockHeight);
}
} else {
nextStepController.setValue('DISPERSE_LOCAL');
}
Expand Down

0 comments on commit ebf1684

Please sign in to comment.