Skip to content

Commit cb95256

Browse files
authored
Collatz Computation (#623)
Build the compute shader for vulkan1.1 as required
1 parent a5e9fe7 commit cb95256

File tree

8 files changed

+211
-58
lines changed

8 files changed

+211
-58
lines changed

Cargo.lock

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/spirv-builder/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,7 @@ fn invoke_rustc(builder: &SpirvBuilder) -> Result<PathBuf, SpirvBuilderError> {
367367
let mut cargo = Command::new("cargo");
368368
cargo.args(&[
369369
"build",
370+
"--lib",
370371
"--message-format=json-render-diagnostics",
371372
"-Zbuild-std=core",
372373
"-Zbuild-std-features=compiler-builtins-mem",

examples/runners/wgpu/builder/src/main.rs

+7-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ fn build_shader(
1111
) -> Result<(), Box<dyn Error>> {
1212
let builder_dir = &Path::new(env!("CARGO_MANIFEST_DIR"));
1313
let path_to_crate = builder_dir.join(path_to_crate);
14-
let mut builder = SpirvBuilder::new(path_to_crate, "spirv-unknown-vulkan1.0");
14+
let mut builder = SpirvBuilder::new(path_to_crate, "spirv-unknown-vulkan1.1");
1515
for &cap in caps {
1616
builder = builder.capability(cap);
1717
}
@@ -28,7 +28,12 @@ fn build_shader(
2828
fn main() -> Result<(), Box<dyn Error>> {
2929
build_shader("../../../shaders/sky-shader", true, &[])?;
3030
build_shader("../../../shaders/simplest-shader", false, &[])?;
31-
build_shader("../../../shaders/compute-shader", false, &[])?;
31+
// We need the int8 capability for using `Option`
32+
build_shader(
33+
"../../../shaders/compute-shader",
34+
false,
35+
&[Capability::Int8],
36+
)?;
3237
build_shader("../../../shaders/mouse-shader", false, &[])?;
3338
Ok(())
3439
}

examples/runners/wgpu/src/compute.rs

+123-37
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,65 @@
1+
use wgpu::util::DeviceExt;
2+
13
use super::{shader_module, Options};
2-
use core::num::NonZeroU64;
3-
4-
fn create_device_queue() -> (wgpu::Device, wgpu::Queue) {
5-
async fn create_device_queue_async() -> (wgpu::Device, wgpu::Queue) {
6-
let instance = wgpu::Instance::new(wgpu::BackendBit::PRIMARY);
7-
let adapter = instance
8-
.request_adapter(&wgpu::RequestAdapterOptions {
9-
power_preference: wgpu::PowerPreference::default(),
10-
compatible_surface: None,
11-
})
12-
.await
13-
.expect("Failed to find an appropriate adapter");
14-
15-
adapter
16-
.request_device(
17-
&wgpu::DeviceDescriptor {
18-
label: None,
19-
features: wgpu::Features::empty(),
20-
limits: wgpu::Limits::default(),
21-
},
22-
None,
23-
)
24-
.await
25-
.expect("Failed to create device")
26-
}
4+
use futures::future::join;
5+
use std::{convert::TryInto, future::Future, num::NonZeroU64, time::Duration};
6+
7+
fn block_on<T>(future: impl Future<Output = T>) -> T {
278
cfg_if::cfg_if! {
289
if #[cfg(target_arch = "wasm32")] {
29-
wasm_bindgen_futures::spawn_local(create_device_queue_async())
10+
wasm_bindgen_futures::spawn_local(future)
3011
} else {
31-
futures::executor::block_on(create_device_queue_async())
12+
futures::executor::block_on(future)
3213
}
3314
}
3415
}
3516

3617
pub fn start(options: &Options) {
3718
let shader_binary = shader_module(options.shader);
3819

39-
let (device, queue) = create_device_queue();
20+
block_on(start_internal(options, shader_binary))
21+
}
22+
23+
pub async fn start_internal(
24+
_options: &Options,
25+
shader_binary: wgpu::ShaderModuleDescriptor<'static>,
26+
) {
27+
let instance = wgpu::Instance::new(wgpu::BackendBit::PRIMARY);
28+
let adapter = instance
29+
.request_adapter(&wgpu::RequestAdapterOptions {
30+
power_preference: wgpu::PowerPreference::default(),
31+
compatible_surface: None,
32+
})
33+
.await
34+
.expect("Failed to find an appropriate adapter");
4035

36+
let timestamp_period = adapter.get_timestamp_period();
37+
let (device, queue) = adapter
38+
.request_device(
39+
&wgpu::DeviceDescriptor {
40+
label: None,
41+
features: wgpu::Features::TIMESTAMP_QUERY,
42+
limits: wgpu::Limits::default(),
43+
},
44+
None,
45+
)
46+
.await
47+
.expect("Failed to create device");
48+
drop(instance);
49+
drop(adapter);
4150
// Load the shaders from disk
4251
let module = device.create_shader_module(&shader_binary);
4352

53+
let top = 2u32.pow(20);
54+
let src_range = 1..top;
55+
56+
let src = src_range
57+
.clone()
58+
// Not sure which endianness is correct to use here
59+
.map(u32::to_ne_bytes)
60+
.flat_map(core::array::IntoIter::new)
61+
.collect::<Vec<_>>();
62+
4463
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
4564
label: None,
4665
entries: &[
@@ -72,10 +91,26 @@ pub fn start(options: &Options) {
7291
entry_point: "main_cs",
7392
});
7493

75-
let buf = device.create_buffer(&wgpu::BufferDescriptor {
94+
let readback_buffer = device.create_buffer(&wgpu::BufferDescriptor {
7695
label: None,
77-
size: 1,
78-
usage: wgpu::BufferUsage::STORAGE,
96+
size: src.len() as wgpu::BufferAddress,
97+
// Can be read to the CPU, and can be copied from the shader's storage buffer
98+
usage: wgpu::BufferUsage::MAP_READ | wgpu::BufferUsage::COPY_DST,
99+
mapped_at_creation: false,
100+
});
101+
102+
let storage_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
103+
label: Some("Collatz Conjecture Input"),
104+
contents: &src,
105+
usage: wgpu::BufferUsage::STORAGE
106+
| wgpu::BufferUsage::COPY_DST
107+
| wgpu::BufferUsage::COPY_SRC,
108+
});
109+
110+
let timestamp_buffer = device.create_buffer(&wgpu::BufferDescriptor {
111+
label: Some("Timestamps buffer"),
112+
size: 16,
113+
usage: wgpu::BufferUsage::MAP_READ | wgpu::BufferUsage::COPY_DST,
79114
mapped_at_creation: false,
80115
});
81116

@@ -84,23 +119,74 @@ pub fn start(options: &Options) {
84119
layout: &bind_group_layout,
85120
entries: &[wgpu::BindGroupEntry {
86121
binding: 0,
87-
resource: wgpu::BindingResource::Buffer {
88-
buffer: &buf,
89-
offset: 0,
90-
size: None,
91-
},
122+
resource: storage_buffer.as_entire_binding(),
92123
}],
93124
});
94125

126+
let queries = device.create_query_set(&wgpu::QuerySetDescriptor {
127+
count: 2,
128+
ty: wgpu::QueryType::Timestamp,
129+
});
130+
95131
let mut encoder =
96132
device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
97133

98134
{
99135
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: None });
100136
cpass.set_bind_group(0, &bind_group, &[]);
101137
cpass.set_pipeline(&compute_pipeline);
102-
cpass.dispatch(1, 1, 1);
138+
cpass.write_timestamp(&queries, 0);
139+
cpass.dispatch(src_range.len() as u32 / 64, 1, 1);
140+
cpass.write_timestamp(&queries, 1);
103141
}
104142

143+
encoder.copy_buffer_to_buffer(
144+
&storage_buffer,
145+
0,
146+
&readback_buffer,
147+
0,
148+
src.len() as wgpu::BufferAddress,
149+
);
150+
encoder.resolve_query_set(&queries, 0..2, &timestamp_buffer, 0);
151+
105152
queue.submit(Some(encoder.finish()));
153+
let buffer_slice = readback_buffer.slice(..);
154+
let timestamp_slice = timestamp_buffer.slice(..);
155+
let timestamp_future = timestamp_slice.map_async(wgpu::MapMode::Read);
156+
let buffer_future = buffer_slice.map_async(wgpu::MapMode::Read);
157+
device.poll(wgpu::Maintain::Wait);
158+
159+
if let (Ok(()), Ok(())) = join(buffer_future, timestamp_future).await {
160+
let data = buffer_slice.get_mapped_range();
161+
let timing_data = timestamp_slice.get_mapped_range();
162+
let result = data
163+
.chunks_exact(4)
164+
.map(|b| u32::from_ne_bytes(b.try_into().unwrap()))
165+
.collect::<Vec<_>>();
166+
let timings = timing_data
167+
.chunks_exact(8)
168+
.map(|b| u64::from_ne_bytes(b.try_into().unwrap()))
169+
.collect::<Vec<_>>();
170+
drop(data);
171+
readback_buffer.unmap();
172+
drop(timing_data);
173+
timestamp_buffer.unmap();
174+
let mut max = 0;
175+
for (src, out) in src_range.zip(result.iter().copied()) {
176+
if out == u32::MAX {
177+
println!("{}: overflowed", src);
178+
break;
179+
} else if out > max {
180+
max = out;
181+
// Should produce <https://oeis.org/A006877>
182+
println!("{}: {}", src, out);
183+
}
184+
}
185+
println!(
186+
"Took: {:?}",
187+
Duration::from_nanos(
188+
((timings[1] - timings[0]) as f64 * f64::from(timestamp_period)) as u64
189+
)
190+
);
191+
}
106192
}

examples/runners/wgpu/src/lib.rs

+7-13
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ fn shader_module(shader: RustGPUShader) -> wgpu::ShaderModuleDescriptor<'static>
6161
{
6262
use spirv_builder::{Capability, SpirvBuilder};
6363
use std::borrow::Cow;
64-
use std::path::{Path, PathBuf};
64+
use std::path::PathBuf;
6565
// Hack: spirv_builder builds into a custom directory if running under cargo, to not
6666
// deadlock, and the default target directory if not. However, packages like `proc-macro2`
6767
// have different configurations when being built here vs. when building
@@ -73,22 +73,16 @@ fn shader_module(shader: RustGPUShader) -> wgpu::ShaderModuleDescriptor<'static>
7373
let (crate_name, capabilities): (_, &[Capability]) = match shader {
7474
RustGPUShader::Simplest => ("simplest-shader", &[]),
7575
RustGPUShader::Sky => ("sky-shader", &[]),
76-
RustGPUShader::Compute => ("compute-shader", &[]),
76+
RustGPUShader::Compute => ("compute-shader", &[Capability::Int8]),
7777
RustGPUShader::Mouse => ("mouse-shader", &[]),
7878
};
7979
let manifest_dir = env!("CARGO_MANIFEST_DIR");
80-
let crate_path = [
81-
Path::new(manifest_dir),
82-
Path::new(".."),
83-
Path::new(".."),
84-
Path::new("shaders"),
85-
Path::new(crate_name),
86-
]
87-
.iter()
88-
.copied()
89-
.collect::<PathBuf>();
80+
let crate_path = [manifest_dir, "..", "..", "shaders", crate_name]
81+
.iter()
82+
.copied()
83+
.collect::<PathBuf>();
9084
let mut builder =
91-
SpirvBuilder::new(crate_path, "spirv-unknown-vulkan1.0").print_metadata(false);
85+
SpirvBuilder::new(crate_path, "spirv-unknown-vulkan1.1").print_metadata(false);
9286
for &cap in capabilities {
9387
builder = builder.capability(cap);
9488
}

examples/shaders/compute-shader/Cargo.toml

+4-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@ license = "MIT OR Apache-2.0"
77
publish = false
88

99
[lib]
10-
crate-type = ["dylib"]
10+
crate-type = ["dylib", "lib"]
1111

1212
[dependencies]
1313
spirv-std = { path = "../../../crates/spirv-std", features = ["glam"] }
14+
15+
[target.'cfg(not(target_arch = "spirv"))'.dependencies]
16+
rayon = "1.5"
+36-5
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,48 @@
11
#![cfg_attr(
22
target_arch = "spirv",
3-
no_std,
43
feature(register_attr),
5-
register_attr(spirv)
4+
register_attr(spirv),
5+
no_std
66
)]
77
// HACK(eddyb) can't easily see warnings otherwise from `spirv-builder` builds.
88
#![deny(warnings)]
99

1010
extern crate spirv_std;
1111

12+
use glam::UVec3;
13+
use spirv_std::glam;
1214
#[cfg(not(target_arch = "spirv"))]
1315
use spirv_std::macros::spirv;
1416

15-
// LocalSize/numthreads of (x = 32, y = 1, z = 1)
16-
#[spirv(compute(threads(32)))]
17-
pub fn main_cs() {}
17+
// Adapted from the wgpu hello-compute example
18+
19+
pub fn collatz(mut n: u32) -> Option<u32> {
20+
let mut i = 0;
21+
if n == 0 {
22+
return None;
23+
}
24+
while n != 1 {
25+
n = if n % 2 == 0 {
26+
n / 2
27+
} else {
28+
// Overflow? (i.e. 3*n + 1 > 0xffff_ffff)
29+
if n >= 0x5555_5555 {
30+
return None;
31+
}
32+
// TODO: Use this instead when/if checked add/mul can work: n.checked_mul(3)?.checked_add(1)?
33+
3 * n + 1
34+
};
35+
i += 1;
36+
}
37+
Some(i)
38+
}
39+
40+
// LocalSize/numthreads of (x = 64, y = 1, z = 1)
41+
#[spirv(compute(threads(64)))]
42+
pub fn main_cs(
43+
#[spirv(global_invocation_id)] id: UVec3,
44+
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] prime_indices: &mut [u32],
45+
) {
46+
let index = id.x as usize;
47+
prime_indices[index] = collatz(prime_indices[index]).unwrap_or(u32::MAX);
48+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
use std::time::Instant;
2+
3+
use compute_shader::collatz;
4+
use rayon::prelude::*;
5+
6+
fn main() {
7+
let top = 2u32.pow(20);
8+
let src_range = 1..top;
9+
let start = Instant::now();
10+
let result = src_range
11+
.clone()
12+
.into_par_iter()
13+
.map(collatz)
14+
.collect::<Vec<_>>();
15+
let took = start.elapsed();
16+
let mut max = 0;
17+
for (src, out) in src_range.zip(result.iter().copied()) {
18+
match out {
19+
Some(out) if out > max => {
20+
max = out;
21+
// Should produce <https://oeis.org/A006877>
22+
println!("{}: {}", src, out);
23+
}
24+
Some(_) => (),
25+
None => {
26+
println!("{}: overflowed", src);
27+
break;
28+
}
29+
}
30+
}
31+
println!("Took: {:?}", took);
32+
}

0 commit comments

Comments
 (0)