Skip to content

Commit

Permalink
Finished adding updates
Browse files Browse the repository at this point in the history
  • Loading branch information
cmhhelgeson committed Oct 30, 2023
1 parent fc40521 commit 50a1bae
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 16 deletions.
22 changes: 14 additions & 8 deletions src/sample/bitonicSort/computeShader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,30 +80,36 @@ fn computeMain(
@builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
let offset = ${threadsPerWorkgroup} * 2 * workgroup_id.x;
// If we will perform a local swap, then populate the local data
if (uniforms.algo <= 2) {
//Each thread will populate the workgroup data... (1 thread for every 2 elements)
local_data[global_id.x * 2] = input_data[global_id.x * 2];
local_data[global_id.x * 2 + 1] = input_data[global_id.x * 2 + 1];
// Assign range of input_data to local_data.
// Range cannot exceed maxWorkgroupsX * 2
// Each thread will populate the workgroup data... (1 thread for every 2 elements)
local_data[local_id.x * 2] = input_data[offset + local_id.x * 2];
local_data[local_id.x * 2 + 1] = input_data[offset + local_id.x * 2 + 1];
}
//...and wait for each other to finish their own bit of data population.
workgroupBarrier();
switch uniforms.algo {
case 1: { // Local Flip
let idx = get_flip_indices(global_id.x, uniforms.blockHeight);
let idx = get_flip_indices(local_id.x, uniforms.blockHeight);
local_compare_and_swap(idx.x, idx.y);
}
case 2: { // Local Disperse
let idx = get_disperse_indices(global_id.x, uniforms.blockHeight);
let idx = get_disperse_indices(local_id.x, uniforms.blockHeight);
local_compare_and_swap(idx.x, idx.y);
}
case 3: { // Global Flip
let idx = get_flip_indices(global_id.x, uniforms.blockHeight);
global_compare_and_swap(idx.x, idx.y);
}
// case 4: { //Global Disperse
case 4: {
let idx = get_disperse_indices(global_id.x, uniforms.blockHeight);
global_compare_and_swap(idx.x, idx.y);
}
default: {
}
Expand All @@ -114,8 +120,8 @@ fn computeMain(
if (uniforms.algo <= ALGO_LOCAL_DISPERSE) {
//Repopulate global data with local data
output_data[local_id.x * 2] = local_data[local_id.x * 2];
output_data[local_id.x * 2 + 1] = local_data[local_id.x * 2 + 1];
output_data[offset + local_id.x * 2] = local_data[local_id.x * 2];
output_data[offset + local_id.x * 2 + 1] = local_data[local_id.x * 2 + 1];
}
}`;
Expand Down
24 changes: 16 additions & 8 deletions src/sample/bitonicSort/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,16 @@ enum StepEnum {
FLIP_LOCAL,
DISPERSE_LOCAL,
FLIP_GLOBAL,
DISPERSE_GLOBAL,
}

// String access to StepEnum
type StepType = 'NONE' | 'FLIP_LOCAL' | 'DISPERSE_LOCAL' | 'FLIP_GLOBAL';
type StepType =
| 'NONE'
| 'FLIP_LOCAL'
| 'DISPERSE_LOCAL'
| 'FLIP_GLOBAL'
| 'DISPERSE_GLOBAL';

type DisplayType = 'Elements' | 'Swap Highlight';

Expand Down Expand Up @@ -50,10 +56,10 @@ const getNumSteps = (numElements: number) => {
let init: SampleInit;
SampleInitFactoryWebGPU(
async ({ pageState, device, gui, presentationFormat, context, canvas }) => {
const maxWorkgroupsX = device.limits.maxComputeWorkgroupSizeX;
const maxThreadsX = device.limits.maxComputeWorkgroupSizeX;

const totalElementLengths = [];
const maxElements = maxWorkgroupsX * 2;
const maxElements = maxThreadsX * 32;
for (let i = maxElements; i >= 4; i /= 2) {
totalElementLengths.push(i);
}
Expand All @@ -73,7 +79,7 @@ SampleInitFactoryWebGPU(
// height of screen in cells
'Grid Height': defaultGridHeight,
// number of threads to execute in a workgroup ('Total Threads', 1, 1)
'Total Threads': maxWorkgroupsX,
'Total Threads': maxThreadsX,
// Cell in the element grid that the mouse is hovering over
'Hovered Cell': 0,
// element the hovered cell just swapped with,
Expand All @@ -91,7 +97,7 @@ SampleInitFactoryWebGPU(
// Max thread span of next block
'Next Swap Span': 2,
// Workgroups to dispatch per frame,
'Total Workgroups': maxElements / (maxWorkgroupsX * 2),
'Total Workgroups': maxElements / (maxThreadsX * 2),
// Whether we will dispatch a workload this frame
executeStep: false,
'Display Mode': 'Elements',
Expand Down Expand Up @@ -195,12 +201,12 @@ SampleInitFactoryWebGPU(
const resetExecutionInformation = () => {
// Total threads are either elements / 2 or the device's maxComputeWorkgroupSizeX limit, whichever is smaller
totalThreadsController.setValue(
Math.min(settings['Total Elements'] / 2, maxWorkgroupsX)
Math.min(settings['Total Elements'] / 2, maxThreadsX)
);

// Dispatch a workgroup for every (Max threads * 2) elements
const workgroupsPerStep =
(settings['Total Elements'] - 1) / (maxWorkgroupsX * 2);
(settings['Total Elements'] - 1) / (maxThreadsX * 2);

totalWorkgroupsController.setValue(Math.ceil(workgroupsPerStep));

Expand Down Expand Up @@ -506,7 +512,9 @@ SampleInitFactoryWebGPU(
nextBlockHeightController.setValue(highestBlockHeight);
}
} else {
nextStepController.setValue('DISPERSE_LOCAL');
settings['Next Swap Span'] > settings['Total Threads'] * 2
? nextStepController.setValue('DISPERSE_GLOBAL')
: nextStepController.setValue('DISPERSE_LOCAL');
}
commandEncoder.copyBufferToBuffer(
elementsOutputBuffer,
Expand Down

0 comments on commit 50a1bae

Please sign in to comment.