Skip to content

Commit 20a26d2

Browse files
authored
Added bitonic sort example using compute shaders. (#301)
* Added bitonic sort example * Changed shaped of hover cursor in shader to be more readable * Changed names of values and added auto-complete sort functionality * Removed unused argKeys value * Implemented suggested changes * Implemented lolokung suggested changes * Removed references to reticle * Implemented non-extant austinEng suggestions * Removed unused enums (may add back later), changed type of completeSortIntervalID, opened Sort Controls folder on init for sake of clarity * Removed createWGSLUniforms * Minor shader fix
1 parent 6ef6ed2 commit 20a26d2

File tree

6 files changed

+949
-0
lines changed

6 files changed

+949
-0
lines changed

src/pages/samples/[slug].tsx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ export const pages: PageComponentType = {
4747
renderBundles: dynamic(() => import('../../sample/renderBundles/main')),
4848
worker: dynamic(() => import('../../sample/worker/main')),
4949
'A-buffer': dynamic(() => import('../../sample/a-buffer/main')),
50+
bitonicSort: dynamic(() => import('../../sample/bitonicSort/main')),
5051
};
5152

5253
function Page({ slug }: Props): JSX.Element {
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
struct Uniforms {
2+
width: f32,
3+
height: f32,
4+
}
5+
6+
struct VertexOutput {
7+
@builtin(position) Position: vec4<f32>,
8+
@location(0) fragUV: vec2<f32>
9+
}
10+
11+
@group(0) @binding(0) var<uniform> uniforms: Uniforms;
12+
@group(1) @binding(0) var<storage, read> data: array<u32>;
13+
14+
@fragment
15+
fn frag_main(input: VertexOutput) -> @location(0) vec4<f32> {
16+
var uv: vec2<f32> = vec2<f32>(
17+
input.fragUV.x * uniforms.width,
18+
input.fragUV.y * uniforms.height
19+
);
20+
21+
var pixel: vec2<u32> = vec2<u32>(
22+
u32(floor(uv.x)),
23+
u32(floor(uv.y)),
24+
);
25+
26+
var elementIndex = u32(uniforms.width) * pixel.y + pixel.x;
27+
var colorChanger = data[elementIndex];
28+
29+
var subtracter = f32(colorChanger) / (uniforms.width * uniforms.height);
30+
31+
var color: vec3<f32> = vec3f(
32+
1.0 - subtracter
33+
);
34+
35+
return vec4<f32>(color.rgb, 1.0);
36+
}
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import {
2+
BindGroupsObjectsAndLayout,
3+
createBindGroupDescriptor,
4+
Base2DRendererClass,
5+
} from './utils';
6+
7+
import bitonicDisplay from './bitonicDisplay.frag.wgsl';
8+
9+
interface BitonicDisplayRenderArgs {
10+
width: number;
11+
height: number;
12+
}
13+
14+
export default class BitonicDisplayRenderer extends Base2DRendererClass {
15+
static sourceInfo = {
16+
name: __filename.substring(__dirname.length + 1),
17+
contents: __SOURCE__,
18+
};
19+
20+
switchBindGroup: (name: string) => void;
21+
setArguments: (args: BitonicDisplayRenderArgs) => void;
22+
computeBGDescript: BindGroupsObjectsAndLayout;
23+
24+
constructor(
25+
device: GPUDevice,
26+
presentationFormat: GPUTextureFormat,
27+
renderPassDescriptor: GPURenderPassDescriptor,
28+
bindGroupNames: string[],
29+
computeBGDescript: BindGroupsObjectsAndLayout,
30+
label: string
31+
) {
32+
super();
33+
this.renderPassDescriptor = renderPassDescriptor;
34+
this.computeBGDescript = computeBGDescript;
35+
36+
const uniformBuffer = device.createBuffer({
37+
size: Float32Array.BYTES_PER_ELEMENT * 2,
38+
usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
39+
});
40+
41+
const bgDescript = createBindGroupDescriptor(
42+
[0],
43+
[GPUShaderStage.FRAGMENT],
44+
['buffer'],
45+
[{ type: 'uniform' }],
46+
[[{ buffer: uniformBuffer }]],
47+
label,
48+
device
49+
);
50+
51+
this.currentBindGroup = bgDescript.bindGroups[0];
52+
this.currentBindGroupName = bindGroupNames[0];
53+
54+
this.bindGroupMap = {};
55+
56+
bgDescript.bindGroups.forEach((bg, idx) => {
57+
this.bindGroupMap[bindGroupNames[idx]] = bg;
58+
});
59+
60+
this.pipeline = super.create2DRenderPipeline(
61+
device,
62+
label,
63+
[bgDescript.bindGroupLayout, this.computeBGDescript.bindGroupLayout],
64+
bitonicDisplay,
65+
presentationFormat
66+
);
67+
68+
this.switchBindGroup = (name: string) => {
69+
this.currentBindGroup = this.bindGroupMap[name];
70+
this.currentBindGroupName = name;
71+
};
72+
73+
this.setArguments = (args: BitonicDisplayRenderArgs) => {
74+
super.setUniformArguments(device, uniformBuffer, args, [
75+
'width',
76+
'height',
77+
]);
78+
};
79+
}
80+
81+
startRun(commandEncoder: GPUCommandEncoder, args: BitonicDisplayRenderArgs) {
82+
this.setArguments(args);
83+
super.executeRun(commandEncoder, this.renderPassDescriptor, this.pipeline, [
84+
this.currentBindGroup,
85+
this.computeBGDescript.bindGroups[0],
86+
]);
87+
}
88+
}
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
export const computeArgKeys = ['width', 'height', 'algo', 'blockHeight'];
2+
3+
export const NaiveBitonicCompute = (threadsPerWorkgroup: number) => {
4+
if (threadsPerWorkgroup % 2 !== 0 || threadsPerWorkgroup > 256) {
5+
threadsPerWorkgroup = 256;
6+
}
7+
// Ensure that workgroupSize is half the number of elements
8+
return `
9+
10+
struct Uniforms {
11+
width: f32,
12+
height: f32,
13+
algo: u32,
14+
blockHeight: u32,
15+
}
16+
17+
// Create local workgroup data that can contain all elements
18+
19+
var<workgroup> local_data: array<u32, ${threadsPerWorkgroup * 2}>;
20+
21+
//Compare and swap values in local_data
22+
fn compare_and_swap(idx_before: u32, idx_after: u32) {
23+
//idx_before should always be < idx_after
24+
if (local_data[idx_after] < local_data[idx_before]) {
25+
var temp: u32 = local_data[idx_before];
26+
local_data[idx_before] = local_data[idx_after];
27+
local_data[idx_after] = temp;
28+
}
29+
return;
30+
}
31+
32+
// thread_id goes from 0 to threadsPerWorkgroup
33+
fn prepare_flip(thread_id: u32, block_height: u32) {
34+
let q: u32 = ((2 * thread_id) / block_height) * block_height;
35+
let half_height = block_height / 2;
36+
var idx: vec2<u32> = vec2<u32>(
37+
thread_id % half_height, block_height - (thread_id % half_height) - 1,
38+
);
39+
idx.x += q;
40+
idx.y += q;
41+
compare_and_swap(idx.x, idx.y);
42+
}
43+
44+
fn prepare_disperse(thread_id: u32, block_height: u32) {
45+
var q: u32 = ((2 * thread_id) / block_height) * block_height;
46+
let half_height = block_height / 2;
47+
var idx: vec2<u32> = vec2<u32>(
48+
thread_id % half_height, (thread_id % half_height) + half_height
49+
);
50+
idx.x += q;
51+
idx.y += q;
52+
compare_and_swap(idx.x, idx.y);
53+
}
54+
55+
@group(0) @binding(0) var<storage, read> input_data: array<u32>;
56+
@group(0) @binding(1) var<storage, read_write> output_data: array<u32>;
57+
@group(0) @binding(2) var<uniform> uniforms: Uniforms;
58+
59+
// Our compute shader will execute specified # of threads or elements / 2 threads
60+
@compute @workgroup_size(${threadsPerWorkgroup}, 1, 1)
61+
fn computeMain(
62+
@builtin(global_invocation_id) global_id: vec3<u32>,
63+
@builtin(local_invocation_id) local_id: vec3<u32>,
64+
) {
65+
//Each thread will populate the workgroup data... (1 thread for every 2 elements)
66+
local_data[local_id.x * 2] = input_data[local_id.x * 2];
67+
local_data[local_id.x * 2 + 1] = input_data[local_id.x * 2 + 1];
68+
69+
//...and wait for each other to finish their own bit of data population.
70+
workgroupBarrier();
71+
72+
var num_elements = uniforms.width * uniforms.height;
73+
74+
switch uniforms.algo {
75+
case 1: { //Local Flip
76+
prepare_flip(local_id.x, uniforms.blockHeight);
77+
}
78+
case 2: { //Local Disperse
79+
prepare_disperse(local_id.x, uniforms.blockHeight);
80+
}
81+
default: {
82+
83+
}
84+
}
85+
86+
//Ensure that all threads have swapped their own regions of data
87+
workgroupBarrier();
88+
89+
//Repopulate global data with local data
90+
output_data[local_id.x * 2] = local_data[local_id.x * 2];
91+
output_data[local_id.x * 2 + 1] = local_data[local_id.x * 2 + 1];
92+
93+
}`;
94+
};

0 commit comments

Comments
 (0)