diff --git a/CHANGELOG.md b/CHANGELOG.md index 18fcf8e4d2..ffe2021988 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,10 +29,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] -### 🚨BREAKING🚨 -- Signed for loops like `for _ in 0..4i32 {}` no longer compile. We recommend switching to unsigned for loops and casting back to signed integers in the meanwhile. - ### Changed 🛠 +- [PR#17](https://github.com/Rust-GPU/rust-gpu/pull/17) refactor ByteAddressableBuffer to allow reading from read-only buffers - [PR#14](https://github.com/Rust-GPU/rust-gpu/pull/14) add subgroup intrinsics matching glsl's [`GL_KHR_shader_subgroup`](https://github.com/KhronosGroup/GLSL/blob/main/extensions/khr/GL_KHR_shader_subgroup.txt) - [PR#13](https://github.com/Rust-GPU/rust-gpu/pull/13) allow cargo features to be passed to the shader crate - [PR#12](https://github.com/rust-gpu/rust-gpu/pull/12) updated toolchain to `nightly-2024-04-24` diff --git a/crates/spirv-std/src/byte_addressable_buffer.rs b/crates/spirv-std/src/byte_addressable_buffer.rs index f22cfe7ff9..0640e54bda 100644 --- a/crates/spirv-std/src/byte_addressable_buffer.rs +++ b/crates/spirv-std/src/byte_addressable_buffer.rs @@ -43,74 +43,120 @@ unsafe fn buffer_store_intrinsic( .write(value); } -/// `ByteAddressableBuffer` is an untyped blob of data, allowing loads and stores of arbitrary -/// basic data types at arbitrary indices. However, all data must be aligned to size 4, each -/// element within the data (e.g. struct fields) must have a size and alignment of a multiple of 4, -/// and the `byte_index` passed to load and store must be a multiple of 4 (`byte_index` will be -/// rounded down to the nearest multiple of 4). So, it's not technically a *byte* addressable -/// buffer, but rather a *word* buffer, but this naming and behavior was inherited from HLSL (where -/// it's UB to pass in an index not a multiple of 4). +/// `ByteAddressableBuffer` is a view to an untyped blob of data, allowing +/// loads and stores of arbitrary basic data types at arbitrary indices. +/// +/// # Alignment +/// All data must be aligned to size 4, each element within the data (e.g. +/// struct fields) must have a size and alignment of a multiple of 4, and the +/// `byte_index` passed to load and store must be a multiple of 4. Technically +/// it is not a *byte* addressable buffer, but rather a *word* buffer, but this +/// naming and behavior was inherited from HLSL (where it's UB to pass in an +/// index not a multiple of 4). +/// +/// # Safety +/// Using these functions allows reading a different type from the buffer than +/// was originally written (by a previous `store()` or the host API), allowing +/// all sorts of safety guarantees to be bypassed, making it effectively a +/// transmute. #[repr(transparent)] -pub struct ByteAddressableBuffer<'a> { +pub struct ByteAddressableBuffer { /// The underlying array of bytes, able to be directly accessed. - pub data: &'a mut [u32], + pub data: T, } -impl<'a> ByteAddressableBuffer<'a> { +fn bounds_check(data: &[u32], byte_index: u32) { + let sizeof = mem::size_of::() as u32; + if byte_index % 4 != 0 { + panic!("`byte_index` should be a multiple of 4"); + } + if byte_index + sizeof > data.len() as u32 { + let last_byte = byte_index + sizeof; + panic!( + "index out of bounds: the len is {} but loading {} bytes at `byte_index` {} reads until {} (exclusive)", + data.len(), + sizeof, + byte_index, + last_byte, + ); + } +} + +impl<'a> ByteAddressableBuffer<&'a [u32]> { /// Creates a `ByteAddressableBuffer` from the untyped blob of data. #[inline] - pub fn new(data: &'a mut [u32]) -> Self { + pub fn from_slice(data: &'a [u32]) -> Self { Self { data } } - /// Loads an arbitrary type from the buffer. `byte_index` must be a multiple of 4, otherwise, - /// it will get silently rounded down to the nearest multiple of 4. + /// Loads an arbitrary type from the buffer. `byte_index` must be a + /// multiple of 4. /// /// # Safety - /// This function allows writing a type to an untyped buffer, then reading a different type - /// from the same buffer, allowing all sorts of safety guarantees to be bypassed (effectively a - /// transmute) + /// See [`Self`]. pub unsafe fn load(&self, byte_index: u32) -> T { - if byte_index + mem::size_of::() as u32 > self.data.len() as u32 { - panic!("Index out of range"); - } + bounds_check::(self.data, byte_index); buffer_load_intrinsic(self.data, byte_index) } - /// Loads an arbitrary type from the buffer. `byte_index` must be a multiple of 4, otherwise, - /// it will get silently rounded down to the nearest multiple of 4. Bounds checking is not - /// performed. + /// Loads an arbitrary type from the buffer. `byte_index` must be a + /// multiple of 4. /// /// # Safety - /// This function allows writing a type to an untyped buffer, then reading a different type - /// from the same buffer, allowing all sorts of safety guarantees to be bypassed (effectively a - /// transmute). Additionally, bounds checking is not performed. + /// See [`Self`]. Additionally, bounds or alignment checking is not performed. pub unsafe fn load_unchecked(&self, byte_index: u32) -> T { buffer_load_intrinsic(self.data, byte_index) } +} + +impl<'a> ByteAddressableBuffer<&'a mut [u32]> { + /// Creates a `ByteAddressableBuffer` from the untyped blob of data. + #[inline] + pub fn from_mut_slice(data: &'a mut [u32]) -> Self { + Self { data } + } + + /// Create a non-mutable `ByteAddressableBuffer` from this mutable one. + #[inline] + pub fn as_ref(&self) -> ByteAddressableBuffer<&[u32]> { + ByteAddressableBuffer { data: self.data } + } + + /// Loads an arbitrary type from the buffer. `byte_index` must be a + /// multiple of 4. + /// + /// # Safety + /// See [`Self`]. + #[inline] + pub unsafe fn load(&self, byte_index: u32) -> T { + self.as_ref().load(byte_index) + } + + /// Loads an arbitrary type from the buffer. `byte_index` must be a + /// multiple of 4. + /// + /// # Safety + /// See [`Self`]. Additionally, bounds or alignment checking is not performed. + #[inline] + pub unsafe fn load_unchecked(&self, byte_index: u32) -> T { + self.as_ref().load_unchecked(byte_index) + } - /// Stores an arbitrary type int the buffer. `byte_index` must be a multiple of 4, otherwise, - /// it will get silently rounded down to the nearest multiple of 4. + /// Stores an arbitrary type into the buffer. `byte_index` must be a + /// multiple of 4. /// /// # Safety - /// This function allows writing a type to an untyped buffer, then reading a different type - /// from the same buffer, allowing all sorts of safety guarantees to be bypassed (effectively a - /// transmute) + /// See [`Self`]. pub unsafe fn store(&mut self, byte_index: u32, value: T) { - if byte_index + mem::size_of::() as u32 > self.data.len() as u32 { - panic!("Index out of range"); - } + bounds_check::(self.data, byte_index); buffer_store_intrinsic(self.data, byte_index, value); } - /// Stores an arbitrary type int the buffer. `byte_index` must be a multiple of 4, otherwise, - /// it will get silently rounded down to the nearest multiple of 4. Bounds checking is not - /// performed. + /// Stores an arbitrary type into the buffer. `byte_index` must be a + /// multiple of 4. /// /// # Safety - /// This function allows writing a type to an untyped buffer, then reading a different type - /// from the same buffer, allowing all sorts of safety guarantees to be bypassed (effectively a - /// transmute). Additionally, bounds checking is not performed. + /// See [`Self`]. Additionally, bounds or alignment checking is not performed. pub unsafe fn store_unchecked(&mut self, byte_index: u32, value: T) { buffer_store_intrinsic(self.data, byte_index, value); } diff --git a/tests/ui/byte_addressable_buffer/arr.rs b/tests/ui/byte_addressable_buffer/arr.rs index 4cacbcf906..798c3cd515 100644 --- a/tests/ui/byte_addressable_buffer/arr.rs +++ b/tests/ui/byte_addressable_buffer/arr.rs @@ -5,11 +5,22 @@ use spirv_std::{glam::Vec4, ByteAddressableBuffer}; #[spirv(fragment)] pub fn load( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &[u32], + out: &mut [i32; 4], +) { + unsafe { + let buf = ByteAddressableBuffer::from_slice(buf); + *out = buf.load(5); + } +} + +#[spirv(fragment)] +pub fn load_mut( #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], out: &mut [i32; 4], ) { unsafe { - let buf = ByteAddressableBuffer::new(buf); + let buf = ByteAddressableBuffer::from_mut_slice(buf); *out = buf.load(5); } } @@ -20,7 +31,7 @@ pub fn store( #[spirv(flat)] val: [i32; 4], ) { unsafe { - let mut buf = ByteAddressableBuffer::new(buf); + let mut buf = ByteAddressableBuffer::from_mut_slice(buf); buf.store(5, val); } } diff --git a/tests/ui/byte_addressable_buffer/big_struct.rs b/tests/ui/byte_addressable_buffer/big_struct.rs index 27907afe3c..233cb87fae 100644 --- a/tests/ui/byte_addressable_buffer/big_struct.rs +++ b/tests/ui/byte_addressable_buffer/big_struct.rs @@ -14,11 +14,22 @@ pub struct BigStruct { #[spirv(fragment)] pub fn load( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &[u32], + out: &mut BigStruct, +) { + unsafe { + let buf = ByteAddressableBuffer::from_slice(buf); + *out = buf.load(5); + } +} + +#[spirv(fragment)] +pub fn load_mut( #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], out: &mut BigStruct, ) { unsafe { - let buf = ByteAddressableBuffer::new(buf); + let buf = ByteAddressableBuffer::from_mut_slice(buf); *out = buf.load(5); } } @@ -29,7 +40,7 @@ pub fn store( #[spirv(flat)] val: BigStruct, ) { unsafe { - let mut buf = ByteAddressableBuffer::new(buf); + let mut buf = ByteAddressableBuffer::from_mut_slice(buf); buf.store(5, val); } } diff --git a/tests/ui/byte_addressable_buffer/complex.rs b/tests/ui/byte_addressable_buffer/complex.rs index 1ec3267a2e..b9e3edf128 100644 --- a/tests/ui/byte_addressable_buffer/complex.rs +++ b/tests/ui/byte_addressable_buffer/complex.rs @@ -20,11 +20,22 @@ pub struct Nesty { #[spirv(fragment)] pub fn load( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &[u32], + out: &mut Nesty, +) { + unsafe { + let buf = ByteAddressableBuffer::from_slice(buf); + *out = buf.load(5); + } +} + +#[spirv(fragment)] +pub fn load_mut( #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], out: &mut Nesty, ) { unsafe { - let buf = ByteAddressableBuffer::new(buf); + let buf = ByteAddressableBuffer::from_mut_slice(buf); *out = buf.load(5); } } @@ -35,7 +46,7 @@ pub fn store( val: Nesty, ) { unsafe { - let mut buf = ByteAddressableBuffer::new(buf); + let mut buf = ByteAddressableBuffer::from_mut_slice(buf); buf.store(5, val); } } diff --git a/tests/ui/byte_addressable_buffer/empty_struct.rs b/tests/ui/byte_addressable_buffer/empty_struct.rs index b586139d3a..1425526112 100644 --- a/tests/ui/byte_addressable_buffer/empty_struct.rs +++ b/tests/ui/byte_addressable_buffer/empty_struct.rs @@ -7,11 +7,22 @@ pub struct EmptyStruct {} #[spirv(fragment)] pub fn load( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &[u32], + out: &mut EmptyStruct, +) { + unsafe { + let buf = ByteAddressableBuffer::from_slice(buf); + *out = buf.load(5); + } +} + +#[spirv(fragment)] +pub fn load_mut( #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], out: &mut EmptyStruct, ) { unsafe { - let buf = ByteAddressableBuffer::new(buf); + let buf = ByteAddressableBuffer::from_mut_slice(buf); *out = buf.load(5); } } @@ -20,7 +31,7 @@ pub fn load( pub fn store(#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32]) { let val = EmptyStruct {}; unsafe { - let mut buf = ByteAddressableBuffer::new(buf); + let mut buf = ByteAddressableBuffer::from_mut_slice(buf); buf.store(5, val); } } diff --git a/tests/ui/byte_addressable_buffer/f32.rs b/tests/ui/byte_addressable_buffer/f32.rs index 3dc5d3ffe3..2b82f89d16 100644 --- a/tests/ui/byte_addressable_buffer/f32.rs +++ b/tests/ui/byte_addressable_buffer/f32.rs @@ -4,12 +4,20 @@ use spirv_std::spirv; use spirv_std::ByteAddressableBuffer; #[spirv(fragment)] -pub fn load( +pub fn load(#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &[u32], out: &mut f32) { + unsafe { + let buf = ByteAddressableBuffer::from_slice(buf); + *out = buf.load(5); + } +} + +#[spirv(fragment)] +pub fn load_mut( #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], out: &mut f32, ) { unsafe { - let buf = ByteAddressableBuffer::new(buf); + let buf = ByteAddressableBuffer::from_mut_slice(buf); *out = buf.load(5); } } @@ -17,7 +25,7 @@ pub fn load( #[spirv(fragment)] pub fn store(#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], val: f32) { unsafe { - let mut buf = ByteAddressableBuffer::new(buf); + let mut buf = ByteAddressableBuffer::from_mut_slice(buf); buf.store(5, val); } } diff --git a/tests/ui/byte_addressable_buffer/small_struct.rs b/tests/ui/byte_addressable_buffer/small_struct.rs index d6959db768..948af2dd92 100644 --- a/tests/ui/byte_addressable_buffer/small_struct.rs +++ b/tests/ui/byte_addressable_buffer/small_struct.rs @@ -10,11 +10,22 @@ pub struct SmallStruct { #[spirv(fragment)] pub fn load( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &[u32], + out: &mut SmallStruct, +) { + unsafe { + let buf = ByteAddressableBuffer::from_slice(buf); + *out = buf.load(5); + } +} + +#[spirv(fragment)] +pub fn load_mut( #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], out: &mut SmallStruct, ) { unsafe { - let buf = ByteAddressableBuffer::new(buf); + let buf = ByteAddressableBuffer::from_mut_slice(buf); *out = buf.load(5); } } @@ -27,7 +38,7 @@ pub fn store( ) { let val = SmallStruct { a, b }; unsafe { - let mut buf = ByteAddressableBuffer::new(buf); + let mut buf = ByteAddressableBuffer::from_mut_slice(buf); buf.store(5, val); } } diff --git a/tests/ui/byte_addressable_buffer/u32.rs b/tests/ui/byte_addressable_buffer/u32.rs index 8d5d03ad65..d0e1e44624 100644 --- a/tests/ui/byte_addressable_buffer/u32.rs +++ b/tests/ui/byte_addressable_buffer/u32.rs @@ -4,12 +4,20 @@ use spirv_std::spirv; use spirv_std::ByteAddressableBuffer; #[spirv(fragment)] -pub fn load( +pub fn load(#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &[u32], out: &mut u32) { + unsafe { + let buf = ByteAddressableBuffer::from_slice(buf); + *out = buf.load(5); + } +} + +#[spirv(fragment)] +pub fn load_mut( #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], out: &mut u32, ) { unsafe { - let buf = ByteAddressableBuffer::new(buf); + let buf = ByteAddressableBuffer::from_mut_slice(buf); *out = buf.load(5); } } @@ -20,7 +28,7 @@ pub fn store( #[spirv(flat)] val: u32, ) { unsafe { - let mut buf = ByteAddressableBuffer::new(buf); + let mut buf = ByteAddressableBuffer::from_mut_slice(buf); buf.store(5, val); } } diff --git a/tests/ui/byte_addressable_buffer/vec.rs b/tests/ui/byte_addressable_buffer/vec.rs index 49ecbc15c2..e934071b12 100644 --- a/tests/ui/byte_addressable_buffer/vec.rs +++ b/tests/ui/byte_addressable_buffer/vec.rs @@ -13,12 +13,25 @@ pub struct Mat4 { #[spirv(fragment)] pub fn load( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &[u32], + out: &mut Vec4, + outmat: &mut Mat4, +) { + unsafe { + let buf = ByteAddressableBuffer::from_slice(buf); + *out = buf.load(5); + *outmat = buf.load(5); + } +} + +#[spirv(fragment)] +pub fn load_mut( #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], out: &mut Vec4, outmat: &mut Mat4, ) { unsafe { - let buf = ByteAddressableBuffer::new(buf); + let buf = ByteAddressableBuffer::from_mut_slice(buf); *out = buf.load(5); *outmat = buf.load(5); } @@ -31,7 +44,7 @@ pub fn store( valmat: Mat4, ) { unsafe { - let mut buf = ByteAddressableBuffer::new(buf); + let mut buf = ByteAddressableBuffer::from_mut_slice(buf); buf.store(5, val); buf.store(5, valmat); }