diff --git a/src/sample/bitonicSort/bitonicDisplay.frag.wgsl b/src/sample/bitonicSort/bitonicDisplay.frag.wgsl index 3f4a17ea..a963e97c 100644 --- a/src/sample/bitonicSort/bitonicDisplay.frag.wgsl +++ b/src/sample/bitonicSort/bitonicDisplay.frag.wgsl @@ -1,6 +1,8 @@ struct Uniforms { width: f32, height: f32, + algo: u32, + blockHeight: u32, } struct VertexOutput { @@ -9,7 +11,6 @@ struct VertexOutput { } @group(0) @binding(0) var uniforms: Uniforms; -@group(1) @binding(0) var data: array; @fragment fn frag_main(input: VertexOutput) -> @location(0) vec4 { diff --git a/src/sample/bitonicSort/computeShader.ts b/src/sample/bitonicSort/computeShader.ts index 6b9e731d..f0f2e781 100644 --- a/src/sample/bitonicSort/computeShader.ts +++ b/src/sample/bitonicSort/computeShader.ts @@ -31,24 +31,26 @@ fn local_compare_and_swap(idx_before: u32, idx_after: u32) { // thread_id goes from 0 to threadsPerWorkgroup fn get_flip_indices(thread_id: u32, block_height: u32) -> vec2 { - let q: u32 = ((2 * thread_id) / block_height) * block_height; + // Calculate index offset (i.e. move indices into correct block) + let block_offset: u32 = ((2 * thread_id) / block_height) * block_height; let half_height = block_height / 2; + // Calculate index spacing var idx: vec2 = vec2( thread_id % half_height, block_height - (thread_id % half_height) - 1, ); - idx.x += q; - idx.y += q; + idx.x += block_offset; + idx.y += block_offset; return idx; } fn get_disperse_indices(thread_id: u32, block_height: u32) -> vec2 { - var q: u32 = ((2 * thread_id) / block_height) * block_height; + var block_offset: u32 = ((2 * thread_id) / block_height) * block_height; let half_height = block_height / 2; var idx: vec2 = vec2( thread_id % half_height, (thread_id % half_height) + half_height ); - idx.x += q; - idx.y += q; + idx.x += block_offset; + idx.y += block_offset; return idx; } @@ -56,6 +58,7 @@ fn get_disperse_indices(thread_id: u32, block_height: u32) -> vec2 { @group(0) @binding(1) var output_data: array; @group(0) 
@binding(2) var uniforms: Uniforms; + fn global_compare_and_swap(idx_before: u32, idx_after: u32) { if (input_data[idx_after] < input_data[idx_before]) { output_data[idx_before] = input_data[idx_after]; @@ -63,35 +66,44 @@ fn global_compare_and_swap(idx_before: u32, idx_after: u32) { } } +// Constants/enum +const ALGO_NONE = 0; +const ALGO_LOCAL_FLIP = 1; +const ALGO_LOCAL_DISPERSE = 2; +const ALGO_GLOBAL_FLIP = 3; + // Our compute shader will execute specified # of threads or elements / 2 threads @compute @workgroup_size(${threadsPerWorkgroup}, 1, 1) fn computeMain( @builtin(global_invocation_id) global_id: vec3, @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) workgroup: vec3, + @builtin(workgroup_id) workgroup_id: vec3, ) { - //Each thread will populate the workgroup data... (1 thread for every 2 elements) - local_data[local_id.x * 2] = input_data[local_id.x * 2]; - local_data[local_id.x * 2 + 1] = input_data[local_id.x * 2 + 1]; + + // If we will perform a local swap, then populate the local data + if (uniforms.algo <= 2) { + //Each thread will populate the workgroup data... (1 thread for every 2 elements) + local_data[global_id.x * 2] = input_data[global_id.x * 2]; + local_data[global_id.x * 2 + 1] = input_data[global_id.x * 2 + 1]; + } //...and wait for each other to finish their own bit of data population. 
workgroupBarrier(); - var num_elements = uniforms.width * uniforms.height; - switch uniforms.algo { case 1: { // Local Flip - let idx = get_flip_indices(local_id.x, uniforms.blockHeight); + let idx = get_flip_indices(global_id.x, uniforms.blockHeight); local_compare_and_swap(idx.x, idx.y); - } + } case 2: { // Local Disperse - let idx = get_disperse_indices(local_id.x, uniforms.blockHeight); + let idx = get_disperse_indices(global_id.x, uniforms.blockHeight); local_compare_and_swap(idx.x, idx.y); - } - case 4: { // Global Flip - let idx = get_flip_indices(local_id.x, uniforms.blockHeight); + } + case 3: { // Global Flip + let idx = get_flip_indices(global_id.x, uniforms.blockHeight); global_compare_and_swap(idx.x, idx.y); } + // case 4: { //Global Disperse default: { } @@ -100,9 +112,11 @@ fn computeMain( // Ensure that all threads have swapped their own regions of data workgroupBarrier(); - //Repopulate global data with local data - output_data[local_id.x * 2] = local_data[local_id.x * 2]; - output_data[local_id.x * 2 + 1] = local_data[local_id.x * 2 + 1]; + if (uniforms.algo <= ALGO_LOCAL_DISPERSE) { + //Repopulate global data with local data + output_data[local_id.x * 2] = local_data[local_id.x * 2]; + output_data[local_id.x * 2 + 1] = local_data[local_id.x * 2 + 1]; + } }`; }; diff --git a/src/sample/bitonicSort/main.ts b/src/sample/bitonicSort/main.ts index da130e0b..ed2d7da4 100644 --- a/src/sample/bitonicSort/main.ts +++ b/src/sample/bitonicSort/main.ts @@ -10,17 +10,11 @@ enum StepEnum { NONE, FLIP_LOCAL, DISPERSE_LOCAL, - FLIP_DISPERSE_LOCAL, FLIP_GLOBAL, } // String access to StepEnum -type StepType = - | 'NONE' - | 'FLIP_LOCAL' - | 'DISPERSE_LOCAL' - | 'FLIP_DISPERSE_LOCAL' - | 'FLIP_GLOBAL'; +type StepType = 'NONE' | 'FLIP_LOCAL' | 'DISPERSE_LOCAL' | 'FLIP_GLOBAL'; // Gui settings object interface SettingsInterface { @@ -94,7 +88,7 @@ SampleInitFactoryWebGPU( // Max thread span of next block 'Next Swap Span': 2, // Workgroups to dispatch per frame, - 
'Total Workgroups': 1, + 'Total Workgroups': maxElements / (maxWorkgroupsX * 2), // Whether we will dispatch a workload this frame executeStep: false, 'Randomize Values': () => { @@ -275,6 +269,7 @@ SampleInitFactoryWebGPU( let swappedIndex: number; switch (settings['Next Step']) { case 'FLIP_LOCAL': + case 'FLIP_GLOBAL': { const blockHeight = settings['Next Swap Span']; const p2 = Math.floor(settings['Hovered Cell'] / blockHeight) + 1; @@ -497,16 +492,16 @@ SampleInitFactoryWebGPU( nextBlockHeightController.setValue(settings['Next Swap Span'] / 2); if (settings['Next Swap Span'] === 1) { highestBlockHeight *= 2; - nextStepController.setValue( - highestBlockHeight === settings['Total Elements'] * 2 - ? 'NONE' - : 'FLIP_LOCAL' - ); - nextBlockHeightController.setValue( - highestBlockHeight === settings['Total Elements'] * 2 - ? 0 - : highestBlockHeight - ); + if (highestBlockHeight === settings['Total Elements'] * 2) { + nextStepController.setValue('NONE'); + nextBlockHeightController.setValue(0); + } else if (highestBlockHeight > settings['Total Threads'] * 2) { + nextStepController.setValue('FLIP_GLOBAL'); + nextBlockHeightController.setValue(highestBlockHeight); + } else { + nextStepController.setValue('FLIP_LOCAL'); + nextBlockHeightController.setValue(highestBlockHeight); + } } else { nextStepController.setValue('DISPERSE_LOCAL'); }