Skip to content

Compute a collatz sequence in the compute example #623

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/spirv-builder/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ fn invoke_rustc(builder: &SpirvBuilder) -> Result<PathBuf, SpirvBuilderError> {
let mut cargo = Command::new("cargo");
cargo.args(&[
"build",
"--lib",
"--message-format=json-render-diagnostics",
"-Zbuild-std=core",
"-Zbuild-std-features=compiler-builtins-mem",
Expand Down
9 changes: 7 additions & 2 deletions examples/runners/wgpu/builder/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ fn build_shader(
) -> Result<(), Box<dyn Error>> {
let builder_dir = &Path::new(env!("CARGO_MANIFEST_DIR"));
let path_to_crate = builder_dir.join(path_to_crate);
let mut builder = SpirvBuilder::new(path_to_crate, "spirv-unknown-vulkan1.0");
let mut builder = SpirvBuilder::new(path_to_crate, "spirv-unknown-vulkan1.1");
for &cap in caps {
builder = builder.capability(cap);
}
Expand All @@ -28,7 +28,12 @@ fn build_shader(
fn main() -> Result<(), Box<dyn Error>> {
build_shader("../../../shaders/sky-shader", true, &[])?;
build_shader("../../../shaders/simplest-shader", false, &[])?;
build_shader("../../../shaders/compute-shader", false, &[])?;
// We need the int8 capability for using `Option`
build_shader(
"../../../shaders/compute-shader",
false,
&[Capability::Int8],
)?;
build_shader("../../../shaders/mouse-shader", false, &[])?;
Ok(())
}
160 changes: 123 additions & 37 deletions examples/runners/wgpu/src/compute.rs
Original file line number Diff line number Diff line change
@@ -1,46 +1,65 @@
use wgpu::util::DeviceExt;

use super::{shader_module, Options};
use core::num::NonZeroU64;

fn create_device_queue() -> (wgpu::Device, wgpu::Queue) {
async fn create_device_queue_async() -> (wgpu::Device, wgpu::Queue) {
let instance = wgpu::Instance::new(wgpu::BackendBit::PRIMARY);
let adapter = instance
.request_adapter(&wgpu::RequestAdapterOptions {
power_preference: wgpu::PowerPreference::default(),
compatible_surface: None,
})
.await
.expect("Failed to find an appropriate adapter");

adapter
.request_device(
&wgpu::DeviceDescriptor {
label: None,
features: wgpu::Features::empty(),
limits: wgpu::Limits::default(),
},
None,
)
.await
.expect("Failed to create device")
}
use futures::future::join;
use std::{convert::TryInto, future::Future, num::NonZeroU64, time::Duration};

fn block_on<T>(future: impl Future<Output = T>) -> T {
cfg_if::cfg_if! {
if #[cfg(target_arch = "wasm32")] {
wasm_bindgen_futures::spawn_local(create_device_queue_async())
wasm_bindgen_futures::spawn_local(future)
} else {
futures::executor::block_on(create_device_queue_async())
futures::executor::block_on(future)
}
}
}

pub fn start(options: &Options) {
let shader_binary = shader_module(options.shader);

let (device, queue) = create_device_queue();
block_on(start_internal(options, shader_binary))
}

pub async fn start_internal(
_options: &Options,
shader_binary: wgpu::ShaderModuleDescriptor<'static>,
) {
let instance = wgpu::Instance::new(wgpu::BackendBit::PRIMARY);
let adapter = instance
.request_adapter(&wgpu::RequestAdapterOptions {
power_preference: wgpu::PowerPreference::default(),
compatible_surface: None,
})
.await
.expect("Failed to find an appropriate adapter");

let timestamp_period = adapter.get_timestamp_period();
let (device, queue) = adapter
.request_device(
&wgpu::DeviceDescriptor {
label: None,
features: wgpu::Features::TIMESTAMP_QUERY,
limits: wgpu::Limits::default(),
},
None,
)
.await
.expect("Failed to create device");
drop(instance);
drop(adapter);
// Load the shaders from disk
let module = device.create_shader_module(&shader_binary);

let top = 2u32.pow(20);
let src_range = 1..top;

let src = src_range
.clone()
// Not sure which endianness is correct to use here
.map(u32::to_ne_bytes)
.flat_map(core::array::IntoIter::new)
.collect::<Vec<_>>();

let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: None,
entries: &[
Expand Down Expand Up @@ -72,10 +91,26 @@ pub fn start(options: &Options) {
entry_point: "main_cs",
});

let buf = device.create_buffer(&wgpu::BufferDescriptor {
let readback_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: 1,
usage: wgpu::BufferUsage::STORAGE,
size: src.len() as wgpu::BufferAddress,
// Can be read to the CPU, and can be copied from the shader's storage buffer
usage: wgpu::BufferUsage::MAP_READ | wgpu::BufferUsage::COPY_DST,
mapped_at_creation: false,
});

let storage_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Collatz Conjecture Input"),
contents: &src,
usage: wgpu::BufferUsage::STORAGE
| wgpu::BufferUsage::COPY_DST
| wgpu::BufferUsage::COPY_SRC,
});

let timestamp_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Timestamps buffer"),
size: 16,
usage: wgpu::BufferUsage::MAP_READ | wgpu::BufferUsage::COPY_DST,
mapped_at_creation: false,
});

Expand All @@ -84,23 +119,74 @@ pub fn start(options: &Options) {
layout: &bind_group_layout,
entries: &[wgpu::BindGroupEntry {
binding: 0,
resource: wgpu::BindingResource::Buffer {
buffer: &buf,
offset: 0,
size: None,
},
resource: storage_buffer.as_entire_binding(),
}],
});

let queries = device.create_query_set(&wgpu::QuerySetDescriptor {
count: 2,
ty: wgpu::QueryType::Timestamp,
});

let mut encoder =
device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });

{
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: None });
cpass.set_bind_group(0, &bind_group, &[]);
cpass.set_pipeline(&compute_pipeline);
cpass.dispatch(1, 1, 1);
cpass.write_timestamp(&queries, 0);
cpass.dispatch(src_range.len() as u32 / 64, 1, 1);
cpass.write_timestamp(&queries, 1);
}

encoder.copy_buffer_to_buffer(
&storage_buffer,
0,
&readback_buffer,
0,
src.len() as wgpu::BufferAddress,
);
encoder.resolve_query_set(&queries, 0..2, &timestamp_buffer, 0);

queue.submit(Some(encoder.finish()));
let buffer_slice = readback_buffer.slice(..);
let timestamp_slice = timestamp_buffer.slice(..);
let timestamp_future = timestamp_slice.map_async(wgpu::MapMode::Read);
let buffer_future = buffer_slice.map_async(wgpu::MapMode::Read);
device.poll(wgpu::Maintain::Wait);

if let (Ok(()), Ok(())) = join(buffer_future, timestamp_future).await {
let data = buffer_slice.get_mapped_range();
let timing_data = timestamp_slice.get_mapped_range();
let result = data
.chunks_exact(4)
.map(|b| u32::from_ne_bytes(b.try_into().unwrap()))
.collect::<Vec<_>>();
let timings = timing_data
.chunks_exact(8)
.map(|b| u64::from_ne_bytes(b.try_into().unwrap()))
.collect::<Vec<_>>();
drop(data);
readback_buffer.unmap();
drop(timing_data);
timestamp_buffer.unmap();
let mut max = 0;
for (src, out) in src_range.zip(result.iter().copied()) {
if out == u32::MAX {
println!("{}: overflowed", src);
break;
} else if out > max {
max = out;
// Should produce <https://oeis.org/A006877>
println!("{}: {}", src, out);
}
}
println!(
"Took: {:?}",
Duration::from_nanos(
((timings[1] - timings[0]) as f64 * f64::from(timestamp_period)) as u64
)
);
}
}
20 changes: 7 additions & 13 deletions examples/runners/wgpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ fn shader_module(shader: RustGPUShader) -> wgpu::ShaderModuleDescriptor<'static>
{
use spirv_builder::{Capability, SpirvBuilder};
use std::borrow::Cow;
use std::path::{Path, PathBuf};
use std::path::PathBuf;
// Hack: spirv_builder builds into a custom directory if running under cargo, to not
// deadlock, and the default target directory if not. However, packages like `proc-macro2`
// have different configurations when being built here vs. when building
Expand All @@ -73,22 +73,16 @@ fn shader_module(shader: RustGPUShader) -> wgpu::ShaderModuleDescriptor<'static>
let (crate_name, capabilities): (_, &[Capability]) = match shader {
RustGPUShader::Simplest => ("simplest-shader", &[]),
RustGPUShader::Sky => ("sky-shader", &[]),
RustGPUShader::Compute => ("compute-shader", &[]),
RustGPUShader::Compute => ("compute-shader", &[Capability::Int8]),
RustGPUShader::Mouse => ("mouse-shader", &[]),
};
let manifest_dir = env!("CARGO_MANIFEST_DIR");
let crate_path = [
Path::new(manifest_dir),
Path::new(".."),
Path::new(".."),
Path::new("shaders"),
Path::new(crate_name),
]
.iter()
.copied()
.collect::<PathBuf>();
let crate_path = [manifest_dir, "..", "..", "shaders", crate_name]
.iter()
.copied()
.collect::<PathBuf>();
let mut builder =
SpirvBuilder::new(crate_path, "spirv-unknown-vulkan1.0").print_metadata(false);
SpirvBuilder::new(crate_path, "spirv-unknown-vulkan1.1").print_metadata(false);
for &cap in capabilities {
builder = builder.capability(cap);
}
Expand Down
5 changes: 4 additions & 1 deletion examples/shaders/compute-shader/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ license = "MIT OR Apache-2.0"
publish = false

[lib]
crate-type = ["dylib"]
crate-type = ["dylib", "lib"]

[dependencies]
spirv-std = { path = "../../../crates/spirv-std", features = ["glam"] }

[target.'cfg(not(target_arch = "spirv"))'.dependencies]
rayon = "1.5"
41 changes: 36 additions & 5 deletions examples/shaders/compute-shader/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,48 @@
#![cfg_attr(
target_arch = "spirv",
no_std,
feature(register_attr),
register_attr(spirv)
register_attr(spirv),
no_std
)]
// HACK(eddyb) can't easily see warnings otherwise from `spirv-builder` builds.
#![deny(warnings)]

extern crate spirv_std;

use glam::UVec3;
use spirv_std::glam;
#[cfg(not(target_arch = "spirv"))]
use spirv_std::macros::spirv;

// LocalSize/numthreads of (x = 32, y = 1, z = 1)
#[spirv(compute(threads(32)))]
pub fn main_cs() {}
// Adapted from the wgpu hello-compute example

/// Counts how many Collatz steps it takes for `n` to reach 1.
///
/// Each step halves an even value and maps an odd value `n` to `3 * n + 1`.
///
/// Returns `None` when `n` is zero, or when an odd intermediate value is so
/// large that `3 * n + 1` would not fit in a `u32`. Otherwise returns
/// `Some(steps)`; `collatz(1)` is `Some(0)`.
pub fn collatz(mut n: u32) -> Option<u32> {
    // Smallest value for which `3 * n + 1` exceeds `0xffff_ffff`.
    const OVERFLOW_THRESHOLD: u32 = 0x5555_5555;

    if n == 0 {
        return None;
    }

    let mut steps = 0;
    loop {
        if n == 1 {
            return Some(steps);
        }
        if (n & 1) == 0 {
            n >>= 1;
        } else if n < OVERFLOW_THRESHOLD {
            // TODO: Use this instead when/if checked add/mul can work: n.checked_mul(3)?.checked_add(1)?
            n = 3 * n + 1;
        } else {
            // The next odd step would overflow `u32`.
            return None;
        }
        steps += 1;
    }
}

/// Compute entry point: replaces one element of the bound storage buffer with
/// its Collatz step count.
///
/// Each invocation uses the `x` component of its global invocation id as the
/// index into `prime_indices`. It reads the value there and writes back the
/// result of `collatz`, or `u32::MAX` when `collatz` returns `None`.
// LocalSize/numthreads of (x = 64, y = 1, z = 1)
#[spirv(compute(threads(64)))]
pub fn main_cs(
    #[spirv(global_invocation_id)] id: UVec3,
    #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] prime_indices: &mut [u32],
) {
    let slot = id.x as usize;
    let input = prime_indices[slot];
    // `u32::MAX` is the sentinel for "no step count available".
    prime_indices[slot] = collatz(input).unwrap_or(u32::MAX);
}
32 changes: 32 additions & 0 deletions examples/shaders/compute-shader/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use std::time::Instant;

use compute_shader::collatz;
use rayon::prelude::*;

/// CPU reference run: computes the Collatz step count for every value in
/// `1..2^20` in parallel, prints each new record step count, and reports how
/// long the parallel computation took.
///
/// Printing stops at the first input for which `collatz` returns `None`.
fn main() {
    let top = 1u32 << 20;
    let src_range = 1..top;

    // Only the parallel map/collect is timed; printing happens afterwards.
    let start = Instant::now();
    let result: Vec<_> = src_range.clone().into_par_iter().map(collatz).collect();
    let took = start.elapsed();

    let mut max = 0;
    for (src, out) in src_range.zip(result) {
        if let Some(steps) = out {
            if steps > max {
                max = steps;
                // Should produce <https://oeis.org/A006877>
                println!("{}: {}", src, steps);
            }
        } else {
            println!("{}: overflowed", src);
            break;
        }
    }
    println!("Took: {:?}", took);
}