@@ -15,13 +15,19 @@ struct Uniforms {
15
15
}
16
16
17
17
// Create local workgroup data that can contain all elements
// (each workgroup sorts threadsPerWorkgroup * 2 elements in shared memory:
// one thread handles two elements).
var<workgroup> local_data: array<u32, ${threadsPerWorkgroup * 2}>;
// Define groups (functions refer to this data)
// input_data:  elements to sort for this pass (read-only).
// output_data: where this pass writes its results.
// uniforms:    per-pass parameters; `algo` and `blockHeight` are read below.
// counter:     global swap counter, incremented atomically on each swap.
@group(0) @binding(0) var<storage, read> input_data: array<u32>;
@group(0) @binding(1) var<storage, read_write> output_data: array<u32>;
@group(0) @binding(2) var<uniform> uniforms: Uniforms;
@group(0) @binding(3) var<storage, read_write> counter: atomic<u32>;
// Compare and swap values in local_data.
// Leaves the smaller of the pair at idx_before and records the swap
// in the global `counter`.
fn local_compare_and_swap(idx_before: u32, idx_after: u32) {
  //idx_before should always be < idx_after
  if (local_data[idx_after] < local_data[idx_before]) {
    atomicAdd(&counter, 1);
    var temp: u32 = local_data[idx_before];
    local_data[idx_before] = local_data[idx_after];
    local_data[idx_after] = temp;
  }
  return;
}
// thread_id goes from 0 to threadsPerWorkgroup
// Returns the pair of indices a thread compares during a bitonic "flip":
// the two positions mirrored around the centre of its block.
fn get_flip_indices(thread_id: u32, block_height: u32) -> vec2<u32> {
  // Calculate index offset (i.e. move indices into correct block)
  let block_offset: u32 = ((2 * thread_id) / block_height) * block_height;
  let half_height = block_height / 2;
  // Calculate index spacing: x counts up from the block start,
  // y mirrors it down from the block end.
  var idx: vec2<u32> = vec2<u32>(
    thread_id % half_height, block_height - (thread_id % half_height) - 1,
  );
  idx.x += block_offset;
  idx.y += block_offset;
  return idx;
}
// Returns the pair of indices a thread compares during a bitonic
// "disperse": two positions exactly half_height apart within its block.
fn get_disperse_indices(thread_id: u32, block_height: u32) -> vec2<u32> {
  // Move indices into the block this thread works on.
  // (let, not var: never mutated — matches get_flip_indices.)
  let block_offset: u32 = ((2 * thread_id) / block_height) * block_height;
  let half_height = block_height / 2;
  var idx: vec2<u32> = vec2<u32>(
    thread_id % half_height, (thread_id % half_height) + half_height
  );
  idx.x += block_offset;
  idx.y += block_offset;
  return idx;
}
// Compare and swap values directly in the storage buffers (used when a
// block is larger than workgroup-local memory). Always writes both slots
// of output_data so every element is carried forward each pass.
fn global_compare_and_swap(idx_before: u32, idx_after: u32) {
  if (input_data[idx_after] < input_data[idx_before]) {
    // Count global swaps too, consistent with local_compare_and_swap.
    atomicAdd(&counter, 1);
    output_data[idx_before] = input_data[idx_after];
    output_data[idx_after] = input_data[idx_before];
  } else {
    // BUGFIX: pairs already in order must still be copied through,
    // otherwise output_data keeps stale values for untouched elements.
    output_data[idx_before] = input_data[idx_before];
    output_data[idx_after] = input_data[idx_after];
  }
}
// Constants/enum — the values uniforms.algo takes to select a step.
const ALGO_NONE = 0;
const ALGO_LOCAL_FLIP = 1;
const ALGO_LOCAL_DISPERSE = 2;
const ALGO_GLOBAL_FLIP = 3;
// Named to match the `case 4` branch in computeMain.
const ALGO_GLOBAL_DISPERSE = 4;
// Our compute shader will execute specified # of threads or elements / 2 threads
@compute @workgroup_size(${threadsPerWorkgroup}, 1, 1)
fn computeMain(
  @builtin(global_invocation_id) global_id: vec3<u32>,
  @builtin(local_invocation_id) local_id: vec3<u32>,
  @builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
  // First input_data element owned by this workgroup
  // (each workgroup covers threadsPerWorkgroup * 2 elements).
  let offset = ${threadsPerWorkgroup} * 2 * workgroup_id.x;

  // If we will perform a local swap, then populate the local data
  // (use the named constant rather than a magic 2, matching the guard below).
  if (uniforms.algo <= ALGO_LOCAL_DISPERSE) {
    // Assign range of input_data to local_data.
    // Range cannot exceed maxWorkgroupsX * 2
    // Each thread will populate the workgroup data... (1 thread for every 2 elements)
    local_data[local_id.x * 2] = input_data[offset + local_id.x * 2];
    local_data[local_id.x * 2 + 1] = input_data[offset + local_id.x * 2 + 1];
  }

  //...and wait for each other to finish their own bit of data population.
  // NOTE: workgroupBarrier only synchronizes within this workgroup; global
  // steps rely on separate dispatches for cross-workgroup ordering.
  workgroupBarrier();

  switch uniforms.algo {
    case 1: { // Local Flip
      let idx = get_flip_indices(local_id.x, uniforms.blockHeight);
      local_compare_and_swap(idx.x, idx.y);
    }
    case 2: { // Local Disperse
      let idx = get_disperse_indices(local_id.x, uniforms.blockHeight);
      local_compare_and_swap(idx.x, idx.y);
    }
    case 3: { // Global Flip
      let idx = get_flip_indices(global_id.x, uniforms.blockHeight);
      global_compare_and_swap(idx.x, idx.y);
    }
    case 4: { // Global Disperse
      let idx = get_disperse_indices(global_id.x, uniforms.blockHeight);
      global_compare_and_swap(idx.x, idx.y);
    }
    default: {
      // ALGO_NONE or unknown: no swap this pass.
    }
  }

  // Ensure that all threads have swapped their own regions of data
  workgroupBarrier();

  if (uniforms.algo <= ALGO_LOCAL_DISPERSE) {
    // Repopulate global data with local data
    output_data[offset + local_id.x * 2] = local_data[local_id.x * 2];
    output_data[offset + local_id.x * 2 + 1] = local_data[local_id.x * 2 + 1];
  }
}`;
94
129
} ;
0 commit comments