Skip to content

Commit

Permalink
Added bitonic sort example using compute shaders. (#301)
Browse files Browse the repository at this point in the history
* Added bitonic sort example

* Changed shaped of hover cursor in shader to be more readable

* Changed names of values and added auto-complete sort functionality

* Removed unused argKeys value

* Implemented suggested changes

* Implemented lolokung suggested changes

* Removed references to reticle

* Implemented non-extant austinEng suggestions

* Removed unused enums (may add back later), changed type of completeSortIntervalID, opened Sort Controls folder on init for sake of clarity

* Removed createWGSLUniforms

* Minor shader fix
  • Loading branch information
cmhhelgeson authored Oct 17, 2023
1 parent 6ef6ed2 commit 20a26d2
Show file tree
Hide file tree
Showing 6 changed files with 949 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/pages/samples/[slug].tsx
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ export const pages: PageComponentType = {
renderBundles: dynamic(() => import('../../sample/renderBundles/main')),
worker: dynamic(() => import('../../sample/worker/main')),
'A-buffer': dynamic(() => import('../../sample/a-buffer/main')),
bitonicSort: dynamic(() => import('../../sample/bitonicSort/main')),
};

function Page({ slug }: Props): JSX.Element {
Expand Down
36 changes: 36 additions & 0 deletions src/sample/bitonicSort/bitonicDisplay.frag.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
struct Uniforms {
width: f32,
height: f32,
}

struct VertexOutput {
@builtin(position) Position: vec4<f32>,
@location(0) fragUV: vec2<f32>
}

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(1) @binding(0) var<storage, read> data: array<u32>;

@fragment
fn frag_main(input: VertexOutput) -> @location(0) vec4<f32> {
var uv: vec2<f32> = vec2<f32>(
input.fragUV.x * uniforms.width,
input.fragUV.y * uniforms.height
);

var pixel: vec2<u32> = vec2<u32>(
u32(floor(uv.x)),
u32(floor(uv.y)),
);

var elementIndex = u32(uniforms.width) * pixel.y + pixel.x;
var colorChanger = data[elementIndex];

var subtracter = f32(colorChanger) / (uniforms.width * uniforms.height);

var color: vec3<f32> = vec3f(
1.0 - subtracter
);

return vec4<f32>(color.rgb, 1.0);
}
88 changes: 88 additions & 0 deletions src/sample/bitonicSort/bitonicDisplay.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import {
BindGroupsObjectsAndLayout,
createBindGroupDescriptor,
Base2DRendererClass,
} from './utils';

import bitonicDisplay from './bitonicDisplay.frag.wgsl';

interface BitonicDisplayRenderArgs {
width: number;
height: number;
}

export default class BitonicDisplayRenderer extends Base2DRendererClass {
static sourceInfo = {
name: __filename.substring(__dirname.length + 1),
contents: __SOURCE__,
};

switchBindGroup: (name: string) => void;
setArguments: (args: BitonicDisplayRenderArgs) => void;
computeBGDescript: BindGroupsObjectsAndLayout;

constructor(
device: GPUDevice,
presentationFormat: GPUTextureFormat,
renderPassDescriptor: GPURenderPassDescriptor,
bindGroupNames: string[],
computeBGDescript: BindGroupsObjectsAndLayout,
label: string
) {
super();
this.renderPassDescriptor = renderPassDescriptor;
this.computeBGDescript = computeBGDescript;

const uniformBuffer = device.createBuffer({
size: Float32Array.BYTES_PER_ELEMENT * 2,
usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
});

const bgDescript = createBindGroupDescriptor(
[0],
[GPUShaderStage.FRAGMENT],
['buffer'],
[{ type: 'uniform' }],
[[{ buffer: uniformBuffer }]],
label,
device
);

this.currentBindGroup = bgDescript.bindGroups[0];
this.currentBindGroupName = bindGroupNames[0];

this.bindGroupMap = {};

bgDescript.bindGroups.forEach((bg, idx) => {
this.bindGroupMap[bindGroupNames[idx]] = bg;
});

this.pipeline = super.create2DRenderPipeline(
device,
label,
[bgDescript.bindGroupLayout, this.computeBGDescript.bindGroupLayout],
bitonicDisplay,
presentationFormat
);

this.switchBindGroup = (name: string) => {
this.currentBindGroup = this.bindGroupMap[name];
this.currentBindGroupName = name;
};

this.setArguments = (args: BitonicDisplayRenderArgs) => {
super.setUniformArguments(device, uniformBuffer, args, [
'width',
'height',
]);
};
}

startRun(commandEncoder: GPUCommandEncoder, args: BitonicDisplayRenderArgs) {
this.setArguments(args);
super.executeRun(commandEncoder, this.renderPassDescriptor, this.pipeline, [
this.currentBindGroup,
this.computeBGDescript.bindGroups[0],
]);
}
}
94 changes: 94 additions & 0 deletions src/sample/bitonicSort/computeShader.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
export const computeArgKeys = ['width', 'height', 'algo', 'blockHeight'];

export const NaiveBitonicCompute = (threadsPerWorkgroup: number) => {
if (threadsPerWorkgroup % 2 !== 0 || threadsPerWorkgroup > 256) {
threadsPerWorkgroup = 256;
}
// Ensure that workgroupSize is half the number of elements
return `
struct Uniforms {
width: f32,
height: f32,
algo: u32,
blockHeight: u32,
}
// Create local workgroup data that can contain all elements
var<workgroup> local_data: array<u32, ${threadsPerWorkgroup * 2}>;
//Compare and swap values in local_data
fn compare_and_swap(idx_before: u32, idx_after: u32) {
//idx_before should always be < idx_after
if (local_data[idx_after] < local_data[idx_before]) {
var temp: u32 = local_data[idx_before];
local_data[idx_before] = local_data[idx_after];
local_data[idx_after] = temp;
}
return;
}
// thread_id goes from 0 to threadsPerWorkgroup
fn prepare_flip(thread_id: u32, block_height: u32) {
let q: u32 = ((2 * thread_id) / block_height) * block_height;
let half_height = block_height / 2;
var idx: vec2<u32> = vec2<u32>(
thread_id % half_height, block_height - (thread_id % half_height) - 1,
);
idx.x += q;
idx.y += q;
compare_and_swap(idx.x, idx.y);
}
fn prepare_disperse(thread_id: u32, block_height: u32) {
var q: u32 = ((2 * thread_id) / block_height) * block_height;
let half_height = block_height / 2;
var idx: vec2<u32> = vec2<u32>(
thread_id % half_height, (thread_id % half_height) + half_height
);
idx.x += q;
idx.y += q;
compare_and_swap(idx.x, idx.y);
}
@group(0) @binding(0) var<storage, read> input_data: array<u32>;
@group(0) @binding(1) var<storage, read_write> output_data: array<u32>;
@group(0) @binding(2) var<uniform> uniforms: Uniforms;
// Our compute shader will execute specified # of threads or elements / 2 threads
@compute @workgroup_size(${threadsPerWorkgroup}, 1, 1)
fn computeMain(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
) {
//Each thread will populate the workgroup data... (1 thread for every 2 elements)
local_data[local_id.x * 2] = input_data[local_id.x * 2];
local_data[local_id.x * 2 + 1] = input_data[local_id.x * 2 + 1];
//...and wait for each other to finish their own bit of data population.
workgroupBarrier();
var num_elements = uniforms.width * uniforms.height;
switch uniforms.algo {
case 1: { //Local Flip
prepare_flip(local_id.x, uniforms.blockHeight);
}
case 2: { //Local Disperse
prepare_disperse(local_id.x, uniforms.blockHeight);
}
default: {
}
}
//Ensure that all threads have swapped their own regions of data
workgroupBarrier();
//Repopulate global data with local data
output_data[local_id.x * 2] = local_data[local_id.x * 2];
output_data[local_id.x * 2 + 1] = local_data[local_id.x * 2 + 1];
}`;
};
Loading

0 comments on commit 20a26d2

Please sign in to comment.