diff --git a/Cargo.lock b/Cargo.lock index 356f02768c..752b4a8cbf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -401,6 +401,7 @@ dependencies = [ name = "compute-shader" version = "0.4.0-alpha.8" dependencies = [ + "rayon", "spirv-std", ] diff --git a/crates/spirv-builder/src/lib.rs b/crates/spirv-builder/src/lib.rs index 09f47b3386..47471c8812 100644 --- a/crates/spirv-builder/src/lib.rs +++ b/crates/spirv-builder/src/lib.rs @@ -367,6 +367,7 @@ fn invoke_rustc(builder: &SpirvBuilder) -> Result { let mut cargo = Command::new("cargo"); cargo.args(&[ "build", + "--lib", "--message-format=json-render-diagnostics", "-Zbuild-std=core", "-Zbuild-std-features=compiler-builtins-mem", diff --git a/examples/runners/wgpu/builder/src/main.rs b/examples/runners/wgpu/builder/src/main.rs index ce85b246e6..d85cfa5fd3 100644 --- a/examples/runners/wgpu/builder/src/main.rs +++ b/examples/runners/wgpu/builder/src/main.rs @@ -11,7 +11,7 @@ fn build_shader( ) -> Result<(), Box> { let builder_dir = &Path::new(env!("CARGO_MANIFEST_DIR")); let path_to_crate = builder_dir.join(path_to_crate); - let mut builder = SpirvBuilder::new(path_to_crate, "spirv-unknown-vulkan1.0"); + let mut builder = SpirvBuilder::new(path_to_crate, "spirv-unknown-vulkan1.1"); for &cap in caps { builder = builder.capability(cap); } @@ -28,7 +28,12 @@ fn build_shader( fn main() -> Result<(), Box> { build_shader("../../../shaders/sky-shader", true, &[])?; build_shader("../../../shaders/simplest-shader", false, &[])?; - build_shader("../../../shaders/compute-shader", false, &[])?; + // We need the int8 capability for using `Option` + build_shader( + "../../../shaders/compute-shader", + false, + &[Capability::Int8], + )?; build_shader("../../../shaders/mouse-shader", false, &[])?; Ok(()) } diff --git a/examples/runners/wgpu/src/compute.rs b/examples/runners/wgpu/src/compute.rs index 947d81ef25..64fd90ec72 100644 --- a/examples/runners/wgpu/src/compute.rs +++ b/examples/runners/wgpu/src/compute.rs @@ -1,34 +1,15 @@ +use wgpu::util::DeviceExt; + use super::{shader_module, Options}; -use core::num::NonZeroU64; - -fn create_device_queue() -> (wgpu::Device, wgpu::Queue) { - async fn create_device_queue_async() -> (wgpu::Device, wgpu::Queue) { - let instance = wgpu::Instance::new(wgpu::BackendBit::PRIMARY); - let adapter = instance - .request_adapter(&wgpu::RequestAdapterOptions { - power_preference: wgpu::PowerPreference::default(), - compatible_surface: None, - }) - .await - .expect("Failed to find an appropriate adapter"); - - adapter - .request_device( - &wgpu::DeviceDescriptor { - label: None, - features: wgpu::Features::empty(), - limits: wgpu::Limits::default(), - }, - None, - ) - .await - .expect("Failed to create device") - } +use futures::future::join; +use std::{convert::TryInto, future::Future, num::NonZeroU64, time::Duration}; + +fn block_on(future: impl Future) -> T { cfg_if::cfg_if! { if #[cfg(target_arch = "wasm32")] { - wasm_bindgen_futures::spawn_local(create_device_queue_async()) + wasm_bindgen_futures::spawn_local(future) } else { - futures::executor::block_on(create_device_queue_async()) + futures::executor::block_on(future) } } } @@ -36,11 +17,49 @@ fn create_device_queue() -> (wgpu::Device, wgpu::Queue) { pub fn start(options: &Options) { let shader_binary = shader_module(options.shader); - let (device, queue) = create_device_queue(); + block_on(start_internal(options, shader_binary)) +} + +pub async fn start_internal( + _options: &Options, + shader_binary: wgpu::ShaderModuleDescriptor<'static>, +) { + let instance = wgpu::Instance::new(wgpu::BackendBit::PRIMARY); + let adapter = instance + .request_adapter(&wgpu::RequestAdapterOptions { + power_preference: wgpu::PowerPreference::default(), + compatible_surface: None, + }) + .await + .expect("Failed to find an appropriate adapter"); + let timestamp_period = adapter.get_timestamp_period(); + let (device, queue) = adapter + .request_device( + &wgpu::DeviceDescriptor { + label: None, + features: wgpu::Features::TIMESTAMP_QUERY, + limits: wgpu::Limits::default(), + }, + None, + ) + .await + .expect("Failed to create device"); + drop(instance); + drop(adapter); // Load the shaders from disk let module = device.create_shader_module(&shader_binary); + let top = 2u32.pow(20); + let src_range = 1..top; + + let src = src_range + .clone() + // Not sure which endianness is correct to use here + .map(u32::to_ne_bytes) + .flat_map(core::array::IntoIter::new) + .collect::>(); + let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { label: None, entries: &[ @@ -72,10 +91,26 @@ pub fn start(options: &Options) { entry_point: "main_cs", }); - let buf = device.create_buffer(&wgpu::BufferDescriptor { + let readback_buffer = device.create_buffer(&wgpu::BufferDescriptor { label: None, - size: 1, - usage: wgpu::BufferUsage::STORAGE, + size: src.len() as wgpu::BufferAddress, + // Can be read to the CPU, and can be copied from the shader's storage buffer + usage: wgpu::BufferUsage::MAP_READ | wgpu::BufferUsage::COPY_DST, + mapped_at_creation: false, + }); + + let storage_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor { + label: Some("Collatz Conjecture Input"), + contents: &src, + usage: wgpu::BufferUsage::STORAGE + | wgpu::BufferUsage::COPY_DST + | wgpu::BufferUsage::COPY_SRC, + }); + + let timestamp_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Timestamps buffer"), + size: 16, + usage: wgpu::BufferUsage::MAP_READ | wgpu::BufferUsage::COPY_DST, mapped_at_creation: false, }); @@ -84,14 +119,15 @@ pub fn start(options: &Options) { layout: &bind_group_layout, entries: &[wgpu::BindGroupEntry { binding: 0, - resource: wgpu::BindingResource::Buffer { - buffer: &buf, - offset: 0, - size: None, - }, + resource: storage_buffer.as_entire_binding(), }], }); + let queries = device.create_query_set(&wgpu::QuerySetDescriptor { + count: 2, + ty: wgpu::QueryType::Timestamp, + }); + let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); @@ -99,8 +135,58 @@ pub fn start(options: &Options) { let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: None }); cpass.set_bind_group(0, &bind_group, &[]); cpass.set_pipeline(&compute_pipeline); - cpass.dispatch(1, 1, 1); + cpass.write_timestamp(&queries, 0); + cpass.dispatch(src_range.len() as u32 / 64, 1, 1); + cpass.write_timestamp(&queries, 1); } + encoder.copy_buffer_to_buffer( + &storage_buffer, + 0, + &readback_buffer, + 0, + src.len() as wgpu::BufferAddress, + ); + encoder.resolve_query_set(&queries, 0..2, ×tamp_buffer, 0); + queue.submit(Some(encoder.finish())); + let buffer_slice = readback_buffer.slice(..); + let timestamp_slice = timestamp_buffer.slice(..); + let timestamp_future = timestamp_slice.map_async(wgpu::MapMode::Read); + let buffer_future = buffer_slice.map_async(wgpu::MapMode::Read); + device.poll(wgpu::Maintain::Wait); + + if let (Ok(()), Ok(())) = join(buffer_future, timestamp_future).await { + let data = buffer_slice.get_mapped_range(); + let timing_data = timestamp_slice.get_mapped_range(); + let result = data + .chunks_exact(4) + .map(|b| u32::from_ne_bytes(b.try_into().unwrap())) + .collect::>(); + let timings = timing_data + .chunks_exact(8) + .map(|b| u64::from_ne_bytes(b.try_into().unwrap())) + .collect::>(); + drop(data); + readback_buffer.unmap(); + drop(timing_data); + timestamp_buffer.unmap(); + let mut max = 0; + for (src, out) in src_range.zip(result.iter().copied()) { + if out == u32::MAX { + println!("{}: overflowed", src); + break; + } else if out > max { + max = out; + // Should produce + println!("{}: {}", src, out); + } + } + println!( + "Took: {:?}", + Duration::from_nanos( + ((timings[1] - timings[0]) as f64 * f64::from(timestamp_period)) as u64 + ) + ); + } } diff --git a/examples/runners/wgpu/src/lib.rs b/examples/runners/wgpu/src/lib.rs index e1d3c390ad..846162c0d3 100644 --- a/examples/runners/wgpu/src/lib.rs +++ b/examples/runners/wgpu/src/lib.rs @@ -61,7 +61,7 @@ fn shader_module(shader: RustGPUShader) -> wgpu::ShaderModuleDescriptor<'static> { use spirv_builder::{Capability, SpirvBuilder}; use std::borrow::Cow; - use std::path::{Path, PathBuf}; + use std::path::PathBuf; // Hack: spirv_builder builds into a custom directory if running under cargo, to not // deadlock, and the default target directory if not. However, packages like `proc-macro2` // have different configurations when being built here vs. when building @@ -73,22 +73,16 @@ fn shader_module(shader: RustGPUShader) -> wgpu::ShaderModuleDescriptor<'static> let (crate_name, capabilities): (_, &[Capability]) = match shader { RustGPUShader::Simplest => ("simplest-shader", &[]), RustGPUShader::Sky => ("sky-shader", &[]), - RustGPUShader::Compute => ("compute-shader", &[]), + RustGPUShader::Compute => ("compute-shader", &[Capability::Int8]), RustGPUShader::Mouse => ("mouse-shader", &[]), }; let manifest_dir = env!("CARGO_MANIFEST_DIR"); - let crate_path = [ - Path::new(manifest_dir), - Path::new(".."), - Path::new(".."), - Path::new("shaders"), - Path::new(crate_name), - ] - .iter() - .copied() - .collect::(); + let crate_path = [manifest_dir, "..", "..", "shaders", crate_name] + .iter() + .copied() + .collect::(); let mut builder = - SpirvBuilder::new(crate_path, "spirv-unknown-vulkan1.0").print_metadata(false); + SpirvBuilder::new(crate_path, "spirv-unknown-vulkan1.1").print_metadata(false); for &cap in capabilities { builder = builder.capability(cap); } diff --git a/examples/shaders/compute-shader/Cargo.toml b/examples/shaders/compute-shader/Cargo.toml index a798d10c3c..7546cd0b9c 100644 --- a/examples/shaders/compute-shader/Cargo.toml +++ b/examples/shaders/compute-shader/Cargo.toml @@ -7,7 +7,10 @@ license = "MIT OR Apache-2.0" publish = false [lib] -crate-type = ["dylib"] +crate-type = ["dylib", "lib"] [dependencies] spirv-std = { path = "../../../crates/spirv-std", features = ["glam"] } + +[target.'cfg(not(target_arch = "spirv"))'.dependencies] +rayon = "1.5" diff --git a/examples/shaders/compute-shader/src/lib.rs b/examples/shaders/compute-shader/src/lib.rs index 84cfb2366e..602c3bda8c 100644 --- a/examples/shaders/compute-shader/src/lib.rs +++ b/examples/shaders/compute-shader/src/lib.rs @@ -1,17 +1,48 @@ #![cfg_attr( target_arch = "spirv", - no_std, feature(register_attr), - register_attr(spirv) + register_attr(spirv), + no_std )] // HACK(eddyb) can't easily see warnings otherwise from `spirv-builder` builds. #![deny(warnings)] extern crate spirv_std; +use glam::UVec3; +use spirv_std::glam; #[cfg(not(target_arch = "spirv"))] use spirv_std::macros::spirv; -// LocalSize/numthreads of (x = 32, y = 1, z = 1) -#[spirv(compute(threads(32)))] -pub fn main_cs() {} +// Adapted from the wgpu hello-compute example + +pub fn collatz(mut n: u32) -> Option { + let mut i = 0; + if n == 0 { + return None; + } + while n != 1 { + n = if n % 2 == 0 { + n / 2 + } else { + // Overflow? (i.e. 3*n + 1 > 0xffff_ffff) + if n >= 0x5555_5555 { + return None; + } + // TODO: Use this instead when/if checked add/mul can work: n.checked_mul(3)?.checked_add(1)? + 3 * n + 1 + }; + i += 1; + } + Some(i) +} + +// LocalSize/numthreads of (x = 64, y = 1, z = 1) +#[spirv(compute(threads(64)))] +pub fn main_cs( + #[spirv(global_invocation_id)] id: UVec3, + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] prime_indices: &mut [u32], +) { + let index = id.x as usize; + prime_indices[index] = collatz(prime_indices[index]).unwrap_or(u32::MAX); +} diff --git a/examples/shaders/compute-shader/src/main.rs b/examples/shaders/compute-shader/src/main.rs new file mode 100644 index 0000000000..ddabff5345 --- /dev/null +++ b/examples/shaders/compute-shader/src/main.rs @@ -0,0 +1,32 @@ +use std::time::Instant; + +use compute_shader::collatz; +use rayon::prelude::*; + +fn main() { + let top = 2u32.pow(20); + let src_range = 1..top; + let start = Instant::now(); + let result = src_range + .clone() + .into_par_iter() + .map(collatz) + .collect::>(); + let took = start.elapsed(); + let mut max = 0; + for (src, out) in src_range.zip(result.iter().copied()) { + match out { + Some(out) if out > max => { + max = out; + // Should produce + println!("{}: {}", src, out); + } + Some(_) => (), + None => { + println!("{}: overflowed", src); + break; + } + } + } + println!("Took: {:?}", took); +}