Flesh out compute example #360


Status: Closed · wants to merge 3 commits
1 change: 1 addition & 0 deletions .github/workflows/test.sh
@@ -51,3 +51,4 @@ cargo_test examples/runners/wgpu
 cargo_test_no_features examples/runners/cpu
 cargo_test_no_features examples/shaders/sky-shader
 cargo_test_no_features examples/shaders/simplest-shader
+cargo_test_no_features examples/shaders/compute-shader
21 changes: 21 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions examples/runners/wgpu/Cargo.toml
@@ -23,6 +23,7 @@ wgpu = "0.6.0"
 winit = { version = "0.24", features = ["web-sys"] }
 clap = "3.0.0-beta.2"
 strum = { version = "0.19", default_features = false, features = ["derive"] }
+bytemuck = { version = "1.4", features = ["derive"] }
 
 [build-dependencies]
 spirv-builder = { path = "../../../crates/spirv-builder", default-features = false }
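The new bytemuck dependency is what lets the runner hand its Vec<u32> to wgpu, whose buffer-upload APIs take raw bytes. A minimal sketch of the cast that compute.rs performs below (the helper name here is illustrative, not part of the PR):

// Reinterpret a &[u32] as &[u8] without copying. This is safe because
// u32 implements bytemuck::Pod, so every bit pattern is valid in both views.
fn as_byte_slice(numbers: &[u32]) -> &[u8] {
    bytemuck::cast_slice(numbers)
}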
140 changes: 98 additions & 42 deletions examples/runners/wgpu/src/compute.rs
@@ -1,12 +1,13 @@
 use super::{shader_module, Options};
-use core::num::NonZeroU64;
+use std::convert::TryInto;
+use wgpu::util::DeviceExt;
 
-fn create_device_queue() -> (wgpu::Device, wgpu::Queue) {
+async fn create_device_queue() -> (wgpu::Device, wgpu::Queue) {
     async fn create_device_queue_async() -> (wgpu::Device, wgpu::Queue) {
         let instance = wgpu::Instance::new(wgpu::BackendBit::PRIMARY);
         let adapter = instance
             .request_adapter(&wgpu::RequestAdapterOptions {
-                power_preference: wgpu::PowerPreference::default(),
+                power_preference: wgpu::PowerPreference::Default,
                 compatible_surface: None,
             })
             .await
@@ -28,33 +29,57 @@ fn create_device_queue() -> (wgpu::Device, wgpu::Queue) {
         if #[cfg(target_arch = "wasm32")] {
             wasm_bindgen_futures::spawn_local(create_device_queue_async())
         } else {
-            futures::executor::block_on(create_device_queue_async())
+            create_device_queue_async().await
         }
     }
 }
 
-pub fn start(options: &Options) {
-    let (device, queue) = create_device_queue();
+pub async fn start(options: &Options, numbers: Vec<u32>) -> Vec<u32> {
+    wgpu_subscriber::initialize_default_subscriber(None);
 
-    // Load the shaders from disk
-    let module = device.create_shader_module(shader_module(options.shader));
+    let slice_size = numbers.len() * std::mem::size_of::<u32>();
+    let size = slice_size as wgpu::BufferAddress;
+
+    let (device, queue) = create_device_queue().await;
+
+    let cs_module = device.create_shader_module(shader_module(options.shader));
+
+    let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
+        label: None,
+        size,
+        usage: wgpu::BufferUsage::MAP_READ | wgpu::BufferUsage::COPY_DST,
+        mapped_at_creation: false,
+    });
+
+    let storage_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
+        label: Some("Storage Buffer"),
+        contents: bytemuck::cast_slice(&numbers),
+        usage: wgpu::BufferUsage::STORAGE
+            | wgpu::BufferUsage::COPY_DST
+            | wgpu::BufferUsage::COPY_SRC,
+    });
 
     let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
         label: None,
-        entries: &[
-            // XXX - some graphics cards do not support empty bind layout groups, so
-            //       create a dummy entry.
-            wgpu::BindGroupLayoutEntry {
-                binding: 0,
-                count: None,
-                visibility: wgpu::ShaderStage::COMPUTE,
-                ty: wgpu::BindingType::StorageBuffer {
-                    dynamic: false,
-                    min_binding_size: Some(NonZeroU64::new(1).unwrap()),
-                    readonly: false,
-                },
+        entries: &[wgpu::BindGroupLayoutEntry {
+            binding: 0,
+            visibility: wgpu::ShaderStage::COMPUTE,
+            ty: wgpu::BindingType::StorageBuffer {
+                dynamic: false,
+                readonly: false,
+                min_binding_size: wgpu::BufferSize::new(4),
             },
-        ],
+            count: None,
+        }],
     });
 
+    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
+        label: None,
+        layout: &bind_group_layout,
+        entries: &[wgpu::BindGroupEntry {
+            binding: 0,
+            resource: wgpu::BindingResource::Buffer(storage_buffer.slice(..)),
+        }],
+    });
+
     let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
@@ -67,36 +92,67 @@ pub fn start(options: &Options) {
         label: None,
         layout: Some(&pipeline_layout),
         compute_stage: wgpu::ProgrammableStageDescriptor {
-            module: &module,
+            module: &cs_module,
             entry_point: "main_cs",
         },
     });
 
-    let buf = device.create_buffer(&wgpu::BufferDescriptor {
-        label: None,
-        size: 1,
-        usage: wgpu::BufferUsage::STORAGE,
-        mapped_at_creation: false,
-    });
-
-    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
-        label: None,
-        layout: &bind_group_layout,
-        entries: &[wgpu::BindGroupEntry {
-            binding: 0,
-            resource: wgpu::BindingResource::Buffer(buf.slice(..)),
-        }],
-    });
-
     let mut encoder =
         device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
 
     {
         let mut cpass = encoder.begin_compute_pass();
-        cpass.set_bind_group(0, &bind_group, &[]);
         cpass.set_pipeline(&compute_pipeline);
-        cpass.dispatch(1, 1, 1);
+        cpass.set_bind_group(0, &bind_group, &[]);
+        cpass.insert_debug_marker("compute collatz iterations");
+        cpass.dispatch(numbers.len() as u32, 1, 1);
     }
 
+    encoder.copy_buffer_to_buffer(&storage_buffer, 0, &staging_buffer, 0, size);
     queue.submit(Some(encoder.finish()));
+    // Note that we're not calling `.await` here.
+
+    let buffer_slice = staging_buffer.slice(..);
+    let buffer_future = buffer_slice.map_async(wgpu::MapMode::Read);
+
+    // Poll the device in a blocking manner so that our future resolves.
+    // In an actual application, `device.poll(...)` should
+    // be called in an event loop or on another thread.
+    device.poll(wgpu::Maintain::Wait);
+
+    if let Ok(()) = buffer_future.await {
+        let data = buffer_slice.get_mapped_range();
+        let result = data
+            .chunks_exact(4)
+            .map(|b| u32::from_ne_bytes(b.try_into().unwrap()))
+            .collect();
+
+        // With the current interface, we have to make sure all mapped views are
+        // dropped before we unmap the buffer.
+        drop(data);
+        staging_buffer.unmap();
+
+        println!("Times: {:?}", result);
+        #[cfg(target_arch = "wasm32")]
+        log::info!("Times: {:?}", result);
+        result
+    } else {
+        panic!("failed to run compute on gpu!")
+    }
 }
 
+// #[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_compute_1() {
+        let input = vec![1, 2, 3, 4];
+        futures::executor::block_on(assert_start(input, vec![0, 1, 7, 2]));
+    }
+
+    async fn assert_start(input: Vec<u32>, expected: Vec<u32>) {
+        let options: Options = Options {
+            shader: crate::RustGPUShader::Compute,
+        };
+        assert_eq!(start(&options, input).await, expected);
+    }
+}
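The readback at the end of start follows wgpu 0.6's map-then-poll idiom: copy into a MAP_READ staging buffer, request the mapping, poll the device so the mapping future can resolve, then pull the bytes out before unmapping. Condensed into a standalone helper for reference (a sketch against the same wgpu 0.6 API; the function is hypothetical, not part of the PR):

use std::convert::TryInto;

// Map `staging`, block until the GPU finishes, and decode its contents as u32s.
async fn read_back_u32s(device: &wgpu::Device, staging: &wgpu::Buffer) -> Vec<u32> {
    let slice = staging.slice(..);
    let mapping = slice.map_async(wgpu::MapMode::Read);
    // Blocking poll; a real application would do this on an event loop
    // or a separate thread instead.
    device.poll(wgpu::Maintain::Wait);
    mapping.await.expect("failed to map staging buffer");
    let data = slice.get_mapped_range();
    let result = data
        .chunks_exact(4)
        .map(|b| u32::from_ne_bytes(b.try_into().unwrap()))
        .collect();
    // All mapped views must be dropped before the buffer is unmapped.
    drop(data);
    staging.unmap();
    result
}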
14 changes: 7 additions & 7 deletions examples/runners/wgpu/src/graphics.rs
@@ -261,7 +261,7 @@ async fn run(
         });
 }
 
-pub fn start(options: &Options) {
+pub async fn start(options: &Options) {
     let event_loop = EventLoop::new();
     let window = winit::window::WindowBuilder::new()
         .with_title("Rust GPU - wgpu")
@@ -282,16 +282,16 @@ pub fn start(options: &Options) {
                     body.append_child(&web_sys::Element::from(window.canvas()))
                         .ok()
                 })
-                .expect("couldn't append canvas to document body");
+                .expect("couldn't append canvas to document body");
             // Temporarily avoid srgb formats for the swapchain on the web
             wasm_bindgen_futures::spawn_local(run(
-                event_loop,
-                window,
-                wgpu::TextureFormat::Bgra8Unorm,
+                event_loop,
+                window,
+                wgpu::TextureFormat::Bgra8Unorm,
             ));
         } else {
             wgpu_subscriber::initialize_default_subscriber(None);
-            futures::executor::block_on(run(
+            run(
                 options,
                 event_loop,
                 window,
@@ -300,7 +300,7 @@ pub fn start(options: &Options) {
                 } else {
                     wgpu::TextureFormat::Bgra8UnormSrgb
                 },
-            ));
+            ).await;
         }
     }
 }
4 changes: 2 additions & 2 deletions examples/runners/wgpu/src/lib.rs
@@ -36,8 +36,8 @@ pub fn main() {
     let options: Options = Options::parse();
 
     if is_compute_shader(options.shader) {
-        compute::start(&options)
+        futures::executor::block_on(compute::start(&options, vec![1, 2, 3, 4]));
     } else {
-        graphics::start(&options);
+        futures::executor::block_on(graphics::start(&options));
     }
 }
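Taken together, this is the control-flow shape the PR moves both runners to: start becomes async, and the synchronous entry point drives it to completion with a blocking executor on native, while on wasm the future is handed to the browser's event loop instead. A stripped-down sketch of that shape (assuming the futures and wasm-bindgen-futures crates this repo already uses):

async fn start() {
    // ... request the adapter/device, dispatch the shader, read results ...
}

pub fn main() {
    // On the web there is no blocking executor, so the future is spawned
    // onto the browser's event loop; on native we block until it finishes.
    #[cfg(target_arch = "wasm32")]
    wasm_bindgen_futures::spawn_local(start());
    #[cfg(not(target_arch = "wasm32"))]
    futures::executor::block_on(start());
}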
2 changes: 1 addition & 1 deletion examples/runners/wgpu/src/main.rs
@@ -1,3 +1,3 @@
 fn main() {
-    example_runner_wgpu::main()
+    example_runner_wgpu::main();
 }
32 changes: 30 additions & 2 deletions examples/shaders/compute-shader/src/lib.rs
@@ -11,6 +11,34 @@ extern crate spirv_std;
 #[macro_use]
 pub extern crate spirv_std_macros;
 
+use spirv_std::storage_class::{Input, StorageBuffer};
+
+// The Collatz Conjecture states that for any integer n:
+// If n is even, n = n/2
+// If n is odd, n = 3n+1
+// Repeat this process for each new n and you will always eventually reach 1.
+// Though the conjecture has not been proven, no counterexample has ever been found.
+// This function returns how many times this recurrence needs to be applied to reach 1.
+pub fn collatz_iterations(mut n: i32) -> i32 {
+    let mut i = 0;
+    while n != 1 {
+        if n.rem_euclid(2) == 0 {
+            n /= 2;
+        } else {
+            n = 3 * n + 1;
+        }
+        i += 1;
+    }
+    i
+}
+
 #[allow(unused_attributes)]
-#[spirv(gl_compute)]
-pub fn main_cs() {}
+#[spirv(gl_compute(local_size_x = 1))]
+pub fn main_cs(
+    #[spirv(global_invocation_id)] gid: Input<i32>,
+    #[spirv(storage_buffer)] mut storage: StorageBuffer<u32>,
+) {
+    let gid = gid.load();
+    let result = collatz_iterations(gid);
+    storage.store(result as u32)
+}
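For reference, the iteration counts behind the runner's expected test vector [0, 1, 7, 2] for inputs [1, 2, 3, 4] work out as follows; a host-side check one could add (a sketch, not part of the PR):

// 1 is already 1                          -> 0 iterations
// 2 -> 1                                  -> 1 iteration
// 3 -> 10 -> 5 -> 16 -> 8 -> 4 -> 2 -> 1  -> 7 iterations
// 4 -> 2 -> 1                             -> 2 iterations
#[test]
fn collatz_iteration_counts() {
    assert_eq!(collatz_iterations(1), 0);
    assert_eq!(collatz_iterations(2), 1);
    assert_eq!(collatz_iterations(3), 7);
    assert_eq!(collatz_iterations(4), 2);
}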