Skip to content

Commit b77d32c

Browse files
authored
Bitonic Sort Update (#309)
1 parent 3ff5cfa commit b77d32c

File tree

6 files changed

+378
-171
lines changed

6 files changed

+378
-171
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
@group(0) @binding(3) var<storage, read_write> counter: atomic<u32>;
2+
3+
@compute @workgroup_size(1, 1, 1)
4+
fn atomicToZero() {
5+
let counterValue = atomicLoad(&counter);
6+
atomicSub(&counter, counterValue);
7+
}

src/sample/bitonicSort/bitonicDisplay.frag.wgsl

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,25 @@
1-
struct Uniforms {
1+
struct ComputeUniforms {
22
width: f32,
33
height: f32,
4+
algo: u32,
5+
blockHeight: u32,
6+
}
7+
8+
struct FragmentUniforms {
9+
// boolean, either 0 or 1
10+
highlight: u32,
411
}
512

613
struct VertexOutput {
714
@builtin(position) Position: vec4<f32>,
815
@location(0) fragUV: vec2<f32>
916
}
1017

11-
@group(0) @binding(0) var<uniform> uniforms: Uniforms;
12-
@group(1) @binding(0) var<storage, read> data: array<u32>;
18+
// Uniforms from compute shader
19+
@group(0) @binding(0) var<storage, read> data: array<u32>;
20+
@group(0) @binding(2) var<uniform> uniforms: ComputeUniforms;
21+
// Fragment shader uniforms
22+
@group(1) @binding(0) var<uniform> fragment_uniforms: FragmentUniforms;
1323

1424
@fragment
1525
fn frag_main(input: VertexOutput) -> @location(0) vec4<f32> {
@@ -28,6 +38,16 @@ fn frag_main(input: VertexOutput) -> @location(0) vec4<f32> {
2838

2939
var subtracter = f32(colorChanger) / (uniforms.width * uniforms.height);
3040

41+
if (fragment_uniforms.highlight == 1) {
42+
return select(
43+
//If element is above halfHeight, highlight green
44+
vec4<f32>(vec3<f32>(0.0, 1.0 - subtracter, 0.0).rgb, 1.0),
45+
//If element is below halfheight, highlight red
46+
vec4<f32>(vec3<f32>(1.0 - subtracter, 0.0, 0.0).rgb, 1.0),
47+
elementIndex % uniforms.blockHeight < uniforms.blockHeight / 2
48+
);
49+
}
50+
3151
var color: vec3<f32> = vec3f(
3252
1.0 - subtracter
3353
);
Lines changed: 15 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
import {
2-
BindGroupsObjectsAndLayout,
3-
createBindGroupDescriptor,
2+
BindGroupCluster,
43
Base2DRendererClass,
4+
createBindGroupCluster,
55
} from './utils';
66

77
import bitonicDisplay from './bitonicDisplay.frag.wgsl';
88

99
interface BitonicDisplayRenderArgs {
10-
width: number;
11-
height: number;
10+
highlight: number;
1211
}
1312

1413
export default class BitonicDisplayRenderer extends Base2DRendererClass {
@@ -19,26 +18,25 @@ export default class BitonicDisplayRenderer extends Base2DRendererClass {
1918

2019
switchBindGroup: (name: string) => void;
2120
setArguments: (args: BitonicDisplayRenderArgs) => void;
22-
computeBGDescript: BindGroupsObjectsAndLayout;
21+
computeBGDescript: BindGroupCluster;
2322

2423
constructor(
2524
device: GPUDevice,
2625
presentationFormat: GPUTextureFormat,
2726
renderPassDescriptor: GPURenderPassDescriptor,
28-
bindGroupNames: string[],
29-
computeBGDescript: BindGroupsObjectsAndLayout,
27+
computeBGDescript: BindGroupCluster,
3028
label: string
3129
) {
3230
super();
3331
this.renderPassDescriptor = renderPassDescriptor;
3432
this.computeBGDescript = computeBGDescript;
3533

3634
const uniformBuffer = device.createBuffer({
37-
size: Float32Array.BYTES_PER_ELEMENT * 2,
35+
size: Uint32Array.BYTES_PER_ELEMENT,
3836
usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
3937
});
4038

41-
const bgDescript = createBindGroupDescriptor(
39+
const bgCluster = createBindGroupCluster(
4240
[0],
4341
[GPUShaderStage.FRAGMENT],
4442
['buffer'],
@@ -48,41 +46,30 @@ export default class BitonicDisplayRenderer extends Base2DRendererClass {
4846
device
4947
);
5048

51-
this.currentBindGroup = bgDescript.bindGroups[0];
52-
this.currentBindGroupName = bindGroupNames[0];
53-
54-
this.bindGroupMap = {};
55-
56-
bgDescript.bindGroups.forEach((bg, idx) => {
57-
this.bindGroupMap[bindGroupNames[idx]] = bg;
58-
});
49+
this.currentBindGroup = bgCluster.bindGroups[0];
5950

6051
this.pipeline = super.create2DRenderPipeline(
6152
device,
6253
label,
63-
[bgDescript.bindGroupLayout, this.computeBGDescript.bindGroupLayout],
54+
[this.computeBGDescript.bindGroupLayout, bgCluster.bindGroupLayout],
6455
bitonicDisplay,
6556
presentationFormat
6657
);
6758

68-
this.switchBindGroup = (name: string) => {
69-
this.currentBindGroup = this.bindGroupMap[name];
70-
this.currentBindGroupName = name;
71-
};
72-
7359
this.setArguments = (args: BitonicDisplayRenderArgs) => {
74-
super.setUniformArguments(device, uniformBuffer, args, [
75-
'width',
76-
'height',
77-
]);
60+
device.queue.writeBuffer(
61+
uniformBuffer,
62+
0,
63+
new Uint32Array([args.highlight])
64+
);
7865
};
7966
}
8067

8168
startRun(commandEncoder: GPUCommandEncoder, args: BitonicDisplayRenderArgs) {
8269
this.setArguments(args);
8370
super.executeRun(commandEncoder, this.renderPassDescriptor, this.pipeline, [
84-
this.currentBindGroup,
8571
this.computeBGDescript.bindGroups[0],
72+
this.currentBindGroup,
8673
]);
8774
}
8875
}

src/sample/bitonicSort/computeShader.ts

Lines changed: 64 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,19 @@ struct Uniforms {
1515
}
1616
1717
// Create local workgroup data that can contain all elements
18-
1918
var<workgroup> local_data: array<u32, ${threadsPerWorkgroup * 2}>;
2019
21-
//Compare and swap values in local_data
22-
fn compare_and_swap(idx_before: u32, idx_after: u32) {
20+
// Define groups (functions refer to this data)
21+
@group(0) @binding(0) var<storage, read> input_data: array<u32>;
22+
@group(0) @binding(1) var<storage, read_write> output_data: array<u32>;
23+
@group(0) @binding(2) var<uniform> uniforms: Uniforms;
24+
@group(0) @binding(3) var<storage, read_write> counter: atomic<u32>;
25+
26+
// Compare and swap values in local_data
27+
fn local_compare_and_swap(idx_before: u32, idx_after: u32) {
2328
//idx_before should always be < idx_after
2429
if (local_data[idx_after] < local_data[idx_before]) {
30+
atomicAdd(&counter, 1);
2531
var temp: u32 = local_data[idx_before];
2632
local_data[idx_before] = local_data[idx_after];
2733
local_data[idx_after] = temp;
@@ -30,65 +36,94 @@ fn compare_and_swap(idx_before: u32, idx_after: u32) {
3036
}
3137
3238
// thread_id goes from 0 to threadsPerWorkgroup
33-
fn prepare_flip(thread_id: u32, block_height: u32) {
34-
let q: u32 = ((2 * thread_id) / block_height) * block_height;
39+
fn get_flip_indices(thread_id: u32, block_height: u32) -> vec2<u32> {
40+
// Caculate index offset (i.e move indices into correct block)
41+
let block_offset: u32 = ((2 * thread_id) / block_height) * block_height;
3542
let half_height = block_height / 2;
43+
// Calculate index spacing
3644
var idx: vec2<u32> = vec2<u32>(
3745
thread_id % half_height, block_height - (thread_id % half_height) - 1,
3846
);
39-
idx.x += q;
40-
idx.y += q;
41-
compare_and_swap(idx.x, idx.y);
47+
idx.x += block_offset;
48+
idx.y += block_offset;
49+
return idx;
4250
}
4351
44-
fn prepare_disperse(thread_id: u32, block_height: u32) {
45-
var q: u32 = ((2 * thread_id) / block_height) * block_height;
52+
fn get_disperse_indices(thread_id: u32, block_height: u32) -> vec2<u32> {
53+
var block_offset: u32 = ((2 * thread_id) / block_height) * block_height;
4654
let half_height = block_height / 2;
4755
var idx: vec2<u32> = vec2<u32>(
4856
thread_id % half_height, (thread_id % half_height) + half_height
4957
);
50-
idx.x += q;
51-
idx.y += q;
52-
compare_and_swap(idx.x, idx.y);
58+
idx.x += block_offset;
59+
idx.y += block_offset;
60+
return idx;
5361
}
5462
55-
@group(0) @binding(0) var<storage, read> input_data: array<u32>;
56-
@group(0) @binding(1) var<storage, read_write> output_data: array<u32>;
57-
@group(0) @binding(2) var<uniform> uniforms: Uniforms;
63+
fn global_compare_and_swap(idx_before: u32, idx_after: u32) {
64+
if (input_data[idx_after] < input_data[idx_before]) {
65+
output_data[idx_before] = input_data[idx_after];
66+
output_data[idx_after] = input_data[idx_before];
67+
}
68+
}
69+
70+
// Constants/enum
71+
const ALGO_NONE = 0;
72+
const ALGO_LOCAL_FLIP = 1;
73+
const ALGO_LOCAL_DISPERSE = 2;
74+
const ALGO_GLOBAL_FLIP = 3;
5875
5976
// Our compute shader will execute specified # of threads or elements / 2 threads
6077
@compute @workgroup_size(${threadsPerWorkgroup}, 1, 1)
6178
fn computeMain(
6279
@builtin(global_invocation_id) global_id: vec3<u32>,
6380
@builtin(local_invocation_id) local_id: vec3<u32>,
81+
@builtin(workgroup_id) workgroup_id: vec3<u32>,
6482
) {
65-
//Each thread will populate the workgroup data... (1 thread for every 2 elements)
66-
local_data[local_id.x * 2] = input_data[local_id.x * 2];
67-
local_data[local_id.x * 2 + 1] = input_data[local_id.x * 2 + 1];
83+
84+
let offset = ${threadsPerWorkgroup} * 2 * workgroup_id.x;
85+
// If we will perform a local swap, then populate the local data
86+
if (uniforms.algo <= 2) {
87+
// Assign range of input_data to local_data.
88+
// Range cannot exceed maxWorkgroupsX * 2
89+
// Each thread will populate the workgroup data... (1 thread for every 2 elements)
90+
local_data[local_id.x * 2] = input_data[offset + local_id.x * 2];
91+
local_data[local_id.x * 2 + 1] = input_data[offset + local_id.x * 2 + 1];
92+
}
6893
6994
//...and wait for each other to finish their own bit of data population.
7095
workgroupBarrier();
7196
72-
var num_elements = uniforms.width * uniforms.height;
73-
7497
switch uniforms.algo {
75-
case 1: { //Local Flip
76-
prepare_flip(local_id.x, uniforms.blockHeight);
98+
case 1: { // Local Flip
99+
let idx = get_flip_indices(local_id.x, uniforms.blockHeight);
100+
local_compare_and_swap(idx.x, idx.y);
101+
}
102+
case 2: { // Local Disperse
103+
let idx = get_disperse_indices(local_id.x, uniforms.blockHeight);
104+
local_compare_and_swap(idx.x, idx.y);
105+
}
106+
case 3: { // Global Flip
107+
let idx = get_flip_indices(global_id.x, uniforms.blockHeight);
108+
global_compare_and_swap(idx.x, idx.y);
77109
}
78-
case 2: { //Local Disperse
79-
prepare_disperse(local_id.x, uniforms.blockHeight);
110+
case 4: {
111+
let idx = get_disperse_indices(global_id.x, uniforms.blockHeight);
112+
global_compare_and_swap(idx.x, idx.y);
80113
}
81114
default: {
82115
83116
}
84117
}
85118
86-
//Ensure that all threads have swapped their own regions of data
119+
// Ensure that all threads have swapped their own regions of data
87120
workgroupBarrier();
88121
89-
//Repopulate global data with local data
90-
output_data[local_id.x * 2] = local_data[local_id.x * 2];
91-
output_data[local_id.x * 2 + 1] = local_data[local_id.x * 2 + 1];
122+
if (uniforms.algo <= ALGO_LOCAL_DISPERSE) {
123+
//Repopulate global data with local data
124+
output_data[offset + local_id.x * 2] = local_data[local_id.x * 2];
125+
output_data[offset + local_id.x * 2 + 1] = local_data[local_id.x * 2 + 1];
126+
}
92127
93128
}`;
94129
};

0 commit comments

Comments
 (0)