Skip to content

Commit 0431f10

Browse files
committed
Flesh out compute example
Just a copy of wgpu-rs's 'hello-compute' example: https://github.com/gfx-rs/wgpu-rs/tree/v0.6/examples/hello-compute
1 parent d431dfd commit 0431f10

File tree

6 files changed

+137
-20
lines changed

6 files changed

+137
-20
lines changed

Cargo.lock

+21
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

examples/runners/wgpu/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ wgpu = "0.6.0"
2323
winit = { version = "0.24", features = ["web-sys"] }
2424
clap = "3.0.0-beta.2"
2525
strum = { version = "0.19", default_features = false, features = ["derive"] }
26+
bytemuck = { version = "1.4", features = ["derive"] }
2627

2728
[build-dependencies]
2829
spirv-builder = { path = "../../../crates/spirv-builder", default-features = false }

examples/runners/wgpu/src/compute.rs

+74-15
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
use super::{shader_module, Options};
2+
use std::convert::TryInto;
3+
use wgpu::util::DeviceExt;
24

35
fn create_device_queue() -> (wgpu::Device, wgpu::Queue) {
46
async fn create_device_queue_async() -> (wgpu::Device, wgpu::Queue) {
57
let instance = wgpu::Instance::new(wgpu::BackendBit::PRIMARY);
68
let adapter = instance
79
.request_adapter(&wgpu::RequestAdapterOptions {
8-
power_preference: wgpu::PowerPreference::default(),
10+
power_preference: wgpu::PowerPreference::Default,
911
compatible_surface: None,
1012
})
1113
.await
@@ -32,15 +34,51 @@ fn create_device_queue() -> (wgpu::Device, wgpu::Queue) {
3234
}
3335
}
3436

35-
pub fn start(options: &Options) {
37+
pub async fn start(options: &Options) -> Vec<u32> {
38+
let numbers: Vec<u32> = vec![1, 2, 3, 4];
39+
let slice_size = numbers.len() * std::mem::size_of::<u32>();
40+
let size = slice_size as wgpu::BufferAddress;
41+
3642
let (device, queue) = create_device_queue();
3743

38-
// Load the shaders from disk
39-
let module = device.create_shader_module(shader_module(options.shader));
44+
let cs_module = device.create_shader_module(shader_module(options.shader));
45+
46+
let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
47+
label: None,
48+
size,
49+
usage: wgpu::BufferUsage::MAP_READ | wgpu::BufferUsage::COPY_DST,
50+
mapped_at_creation: false,
51+
});
52+
53+
let storage_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
54+
label: Some("Storage Buffer"),
55+
contents: bytemuck::cast_slice(&numbers),
56+
usage: wgpu::BufferUsage::STORAGE
57+
| wgpu::BufferUsage::COPY_DST
58+
| wgpu::BufferUsage::COPY_SRC,
59+
});
4060

4161
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
4262
label: None,
43-
entries: &[],
63+
entries: &[wgpu::BindGroupLayoutEntry {
64+
binding: 0,
65+
visibility: wgpu::ShaderStage::COMPUTE,
66+
ty: wgpu::BindingType::StorageBuffer {
67+
dynamic: false,
68+
readonly: false,
69+
min_binding_size: wgpu::BufferSize::new(4),
70+
},
71+
count: None,
72+
}],
73+
});
74+
75+
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
76+
label: None,
77+
layout: &bind_group_layout,
78+
entries: &[wgpu::BindGroupEntry {
79+
binding: 0,
80+
resource: wgpu::BindingResource::Buffer(storage_buffer.slice(..)),
81+
}],
4482
});
4583

4684
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
@@ -53,26 +91,47 @@ pub fn start(options: &Options) {
5391
label: None,
5492
layout: Some(&pipeline_layout),
5593
compute_stage: wgpu::ProgrammableStageDescriptor {
56-
module: &module,
94+
module: &cs_module,
5795
entry_point: "main_cs",
5896
},
5997
});
6098

61-
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
62-
label: None,
63-
layout: &bind_group_layout,
64-
entries: &[],
65-
});
66-
6799
let mut encoder =
68100
device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
69-
70101
{
71102
let mut cpass = encoder.begin_compute_pass();
72-
cpass.set_bind_group(0, &bind_group, &[]);
73103
cpass.set_pipeline(&compute_pipeline);
74-
cpass.dispatch(1, 1, 1);
104+
cpass.set_bind_group(0, &bind_group, &[]);
105+
cpass.insert_debug_marker("compute collatz iterations");
106+
cpass.dispatch(numbers.len() as u32, 1, 1);
75107
}
108+
encoder.copy_buffer_to_buffer(&storage_buffer, 0, &staging_buffer, 0, size);
76109

77110
queue.submit(Some(encoder.finish()));
111+
112+
// Note that we're not calling `.await` here.
113+
let buffer_slice = staging_buffer.slice(..);
114+
let buffer_future = buffer_slice.map_async(wgpu::MapMode::Read);
115+
116+
// Poll the device in a blocking manner so that our future resolves.
117+
// In an actual application, `device.poll(...)` should
118+
// be called in an event loop or on another thread.
119+
device.poll(wgpu::Maintain::Wait);
120+
121+
if let Ok(()) = buffer_future.await {
122+
let data = buffer_slice.get_mapped_range();
123+
let result = data
124+
.chunks_exact(4)
125+
.map(|b| u32::from_ne_bytes(b.try_into().unwrap()))
126+
.collect();
127+
128+
// With the current interface, we have to make sure all mapped views are
129+
// dropped before we unmap the buffer.
130+
drop(data);
131+
staging_buffer.unmap();
132+
133+
result
134+
} else {
135+
panic!("failed to run compute on gpu!")
136+
}
78137
}

examples/runners/wgpu/src/lib.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,11 @@ pub struct Options {
3232
}
3333

3434
#[cfg_attr(target_os = "android", ndk_glue::main(backtrace = "on"))]
35-
pub fn main() {
35+
pub async fn main() {
3636
let options: Options = Options::parse();
3737

3838
if is_compute_shader(options.shader) {
39-
compute::start(&options)
39+
compute::start(&options).await;
4040
} else {
4141
graphics::start(&options);
4242
}

examples/runners/wgpu/src/main.rs

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
11
fn main() {
2-
example_runner_wgpu::main()
2+
subscriber::initialize_default_subscriber(None);
3+
futures::executor::block_on(run());
4+
}
5+
6+
async fn run() {
7+
let times = example_runner_wgpu::main().await;
8+
println!("Times: {:?}", times);
9+
#[cfg(target_arch = "wasm32")]
10+
log::info!("Times: {:?}", times);
311
}

examples/shaders/compute-shader/src/lib.rs

+30-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,34 @@ extern crate spirv_std;
1111
#[macro_use]
1212
pub extern crate spirv_std_macros;
1313

14+
use spirv_std::storage_class::{Input, StorageBuffer};
15+
16+
// The Collatz Conjecture states that for any integer n:
17+
// If n is even, n = n/2
18+
// If n is odd, n = 3n+1
19+
// And repeat this process for each new n, you will always eventually reach 1.
20+
// Though the conjecture has not been proven, no counterexample has ever been found.
21+
// This function returns how many times this recurrence needs to be applied to reach 1.
22+
pub fn collatz_iterations(mut n: i32) -> i32 {
23+
let mut i = 0;
24+
while n != 1 {
25+
if n.rem_euclid(2) == 0 {
26+
n = n / 2;
27+
} else {
28+
n = 3 * n + 1;
29+
}
30+
i += 1;
31+
}
32+
i
33+
}
34+
1435
#[allow(unused_attributes)]
15-
#[spirv(gl_compute)]
16-
pub fn main_cs() {}
36+
#[spirv(gl_compute(local_size_x = 1))]
37+
pub fn main_cs(
38+
#[spirv(global_invocation_id)] gid: Input<i32>,
39+
#[spirv(storage_buffer)] mut storage: StorageBuffer<u32>,
40+
) {
41+
let gid = gid.load();
42+
let result = collatz_iterations(gid);
43+
storage.store(result as u32)
44+
}

0 commit comments

Comments
 (0)