Skip to content

Commit 912d35a

Browse files
committed
Flesh out compute example
Just a copy of wgpu-rs's 'hello-compute' example: https://github.com/gfx-rs/wgpu-rs/tree/v0.6/examples/hello-compute
1 parent b8cea94 commit 912d35a

File tree

7 files changed

+143
-55
lines changed

7 files changed

+143
-55
lines changed

Cargo.lock

+21
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

examples/runners/wgpu/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ wgpu = "0.6.0"
2323
winit = { version = "0.24", features = ["web-sys"] }
2424
clap = "3.0.0-beta.2"
2525
strum = { version = "0.19", default_features = false, features = ["derive"] }
26+
bytemuck = { version = "1.4", features = ["derive"] }
2627

2728
[build-dependencies]
2829
spirv-builder = { path = "../../../crates/spirv-builder", default-features = false }

examples/runners/wgpu/src/compute.rs

+80-42
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
use super::{shader_module, Options};
2-
use core::num::NonZeroU64;
2+
use std::convert::TryInto;
3+
use wgpu::util::DeviceExt;
34

4-
fn create_device_queue() -> (wgpu::Device, wgpu::Queue) {
5+
async fn create_device_queue() -> (wgpu::Device, wgpu::Queue) {
56
async fn create_device_queue_async() -> (wgpu::Device, wgpu::Queue) {
67
let instance = wgpu::Instance::new(wgpu::BackendBit::PRIMARY);
78
let adapter = instance
89
.request_adapter(&wgpu::RequestAdapterOptions {
9-
power_preference: wgpu::PowerPreference::default(),
10+
power_preference: wgpu::PowerPreference::Default,
1011
compatible_surface: None,
1112
})
1213
.await
@@ -28,33 +29,57 @@ fn create_device_queue() -> (wgpu::Device, wgpu::Queue) {
2829
if #[cfg(target_arch = "wasm32")] {
2930
wasm_bindgen_futures::spawn_local(create_device_queue_async())
3031
} else {
31-
futures::executor::block_on(create_device_queue_async())
32+
create_device_queue_async().await
3233
}
3334
}
3435
}
3536

36-
pub fn start(options: &Options) {
37-
let (device, queue) = create_device_queue();
37+
pub async fn start(options: &Options) -> Vec<u32> {
38+
wgpu_subscriber::initialize_default_subscriber(None);
39+
let numbers: Vec<u32> = vec![1, 2, 3, 4];
40+
let slice_size = numbers.len() * std::mem::size_of::<u32>();
41+
let size = slice_size as wgpu::BufferAddress;
3842

39-
// Load the shaders from disk
40-
let module = device.create_shader_module(shader_module(options.shader));
43+
let (device, queue) = create_device_queue().await;
44+
45+
let cs_module = device.create_shader_module(shader_module(options.shader));
46+
47+
let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
48+
label: None,
49+
size,
50+
usage: wgpu::BufferUsage::MAP_READ | wgpu::BufferUsage::COPY_DST,
51+
mapped_at_creation: false,
52+
});
53+
54+
let storage_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
55+
label: Some("Storage Buffer"),
56+
contents: bytemuck::cast_slice(&numbers),
57+
usage: wgpu::BufferUsage::STORAGE
58+
| wgpu::BufferUsage::COPY_DST
59+
| wgpu::BufferUsage::COPY_SRC,
60+
});
4161

4262
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
4363
label: None,
44-
entries: &[
45-
// XXX - some graphics cards do not support empty bind layout groups, so
46-
// create a dummy entry.
47-
wgpu::BindGroupLayoutEntry {
48-
binding: 0,
49-
count: None,
50-
visibility: wgpu::ShaderStage::COMPUTE,
51-
ty: wgpu::BindingType::StorageBuffer {
52-
dynamic: false,
53-
min_binding_size: Some(NonZeroU64::new(1).unwrap()),
54-
readonly: false,
55-
},
64+
entries: &[wgpu::BindGroupLayoutEntry {
65+
binding: 0,
66+
visibility: wgpu::ShaderStage::COMPUTE,
67+
ty: wgpu::BindingType::StorageBuffer {
68+
dynamic: false,
69+
readonly: false,
70+
min_binding_size: wgpu::BufferSize::new(4),
5671
},
57-
],
72+
count: None,
73+
}],
74+
});
75+
76+
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
77+
label: None,
78+
layout: &bind_group_layout,
79+
entries: &[wgpu::BindGroupEntry {
80+
binding: 0,
81+
resource: wgpu::BindingResource::Buffer(storage_buffer.slice(..)),
82+
}],
5883
});
5984

6085
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
@@ -67,36 +92,49 @@ pub fn start(options: &Options) {
6792
label: None,
6893
layout: Some(&pipeline_layout),
6994
compute_stage: wgpu::ProgrammableStageDescriptor {
70-
module: &module,
95+
module: &cs_module,
7196
entry_point: "main_cs",
7297
},
7398
});
7499

75-
let buf = device.create_buffer(&wgpu::BufferDescriptor {
76-
label: None,
77-
size: 1,
78-
usage: wgpu::BufferUsage::STORAGE,
79-
mapped_at_creation: false,
80-
});
81-
82-
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
83-
label: None,
84-
layout: &bind_group_layout,
85-
entries: &[wgpu::BindGroupEntry {
86-
binding: 0,
87-
resource: wgpu::BindingResource::Buffer(buf.slice(..)),
88-
}],
89-
});
90-
91100
let mut encoder =
92101
device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
93-
94102
{
95103
let mut cpass = encoder.begin_compute_pass();
96-
cpass.set_bind_group(0, &bind_group, &[]);
97104
cpass.set_pipeline(&compute_pipeline);
98-
cpass.dispatch(1, 1, 1);
105+
cpass.set_bind_group(0, &bind_group, &[]);
106+
cpass.insert_debug_marker("compute collatz iterations");
107+
cpass.dispatch(numbers.len() as u32, 1, 1);
99108
}
100-
109+
encoder.copy_buffer_to_buffer(&storage_buffer, 0, &staging_buffer, 0, size);
101110
queue.submit(Some(encoder.finish()));
111+
// Note that we're not calling `.await` here.
112+
113+
let buffer_slice = staging_buffer.slice(..);
114+
let buffer_future = buffer_slice.map_async(wgpu::MapMode::Read);
115+
116+
// Poll the device in a blocking manner so that our future resolves.
117+
// In an actual application, `device.poll(...)` should
118+
// be called in an event loop or on another thread.
119+
device.poll(wgpu::Maintain::Wait);
120+
121+
if let Ok(()) = buffer_future.await {
122+
let data = buffer_slice.get_mapped_range();
123+
let result = data
124+
.chunks_exact(4)
125+
.map(|b| u32::from_ne_bytes(b.try_into().unwrap()))
126+
.collect();
127+
128+
// With the current interface, we have to make sure all mapped views are
129+
// dropped before we unmap the buffer.
130+
drop(data);
131+
staging_buffer.unmap();
132+
133+
println!("Times: {:?}", result);
134+
#[cfg(target_arch = "wasm32")]
135+
log::info!("Times: {:?}", result);
136+
result
137+
} else {
138+
panic!("failed to run compute on gpu!")
139+
}
102140
}

examples/runners/wgpu/src/graphics.rs

+7-7
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ async fn run(
261261
});
262262
}
263263

264-
pub fn start(options: &Options) {
264+
pub async fn start(options: &Options) {
265265
let event_loop = EventLoop::new();
266266
let window = winit::window::WindowBuilder::new()
267267
.with_title("Rust GPU - wgpu")
@@ -282,16 +282,16 @@ pub fn start(options: &Options) {
282282
body.append_child(&web_sys::Element::from(window.canvas()))
283283
.ok()
284284
})
285-
.expect("couldn't append canvas to document body");
285+
.expect("couldn't append canvas to document body");
286286
// Temporarily avoid srgb formats for the swapchain on the web
287287
wasm_bindgen_futures::spawn_local(run(
288-
event_loop,
289-
window,
290-
wgpu::TextureFormat::Bgra8Unorm,
288+
event_loop,
289+
window,
290+
wgpu::TextureFormat::Bgra8Unorm,
291291
));
292292
} else {
293293
wgpu_subscriber::initialize_default_subscriber(None);
294-
futures::executor::block_on(run(
294+
run(
295295
options,
296296
event_loop,
297297
window,
@@ -300,7 +300,7 @@ pub fn start(options: &Options) {
300300
} else {
301301
wgpu::TextureFormat::Bgra8UnormSrgb
302302
},
303-
));
303+
).await;
304304
}
305305
}
306306
}

examples/runners/wgpu/src/lib.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,12 @@ pub struct Options {
3232
}
3333

3434
#[cfg_attr(target_os = "android", ndk_glue::main(backtrace = "on"))]
35-
pub fn main() {
35+
pub async fn main() {
3636
let options: Options = Options::parse();
3737

3838
if is_compute_shader(options.shader) {
39-
compute::start(&options)
39+
compute::start(&options).await;
4040
} else {
41-
graphics::start(&options);
41+
graphics::start(&options).await;
4242
}
4343
}

examples/runners/wgpu/src/main.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
fn main() {
2-
example_runner_wgpu::main()
2+
futures::executor::block_on(example_runner_wgpu::main());
33
}

examples/shaders/compute-shader/src/lib.rs

+30-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,34 @@ extern crate spirv_std;
1111
#[macro_use]
1212
pub extern crate spirv_std_macros;
1313

14+
use spirv_std::storage_class::{Input, StorageBuffer};
15+
16+
// The Collatz Conjecture states that for any integer n:
17+
// If n is even, n = n/2
18+
// If n is odd, n = 3n+1
19+
// And repeat this process for each new n, you will always eventually reach 1.
20+
// Though the conjecture has not been proven, no counterexample has ever been found.
21+
// This function returns how many times this recurrence needs to be applied to reach 1.
22+
pub fn collatz_iterations(mut n: i32) -> i32 {
23+
let mut i = 0;
24+
while n != 1 {
25+
if n.rem_euclid(2) == 0 {
26+
n = n / 2;
27+
} else {
28+
n = 3 * n + 1;
29+
}
30+
i += 1;
31+
}
32+
i
33+
}
34+
1435
#[allow(unused_attributes)]
15-
#[spirv(gl_compute)]
16-
pub fn main_cs() {}
36+
#[spirv(gl_compute(local_size_x = 1))]
37+
pub fn main_cs(
38+
#[spirv(global_invocation_id)] gid: Input<i32>,
39+
#[spirv(storage_buffer)] mut storage: StorageBuffer<u32>,
40+
) {
41+
let gid = gid.load();
42+
let result = collatz_iterations(gid);
43+
storage.store(result as u32)
44+
}

0 commit comments

Comments
 (0)