diff --git a/benches/benches/bevy_ecs/iteration/batched_compute.rs b/benches/benches/bevy_ecs/iteration/batched_compute.rs new file mode 100644 index 0000000000000..8a9a46bc48253 --- /dev/null +++ b/benches/benches/bevy_ecs/iteration/batched_compute.rs @@ -0,0 +1,238 @@ +use bevy_ecs::prelude::*; +use core::arch::x86_64::*; +use glam::*; +use rand::prelude::*; + +use criterion::BenchmarkId; +use criterion::Criterion; + +#[derive(Component, Copy, Clone, Default)] +struct Position(Vec3); + +#[derive(Component, Copy, Clone, Default)] +#[repr(transparent)] +struct Health(f32); + +//A hyperplane describing solid geometry, (x,y,z) = n with d such that nx + d = 0 +#[derive(Component, Copy, Clone, Default)] +struct Wall(Vec3, f32); + +struct Benchmark(World); + +fn rnd_vec3(rng: &mut ThreadRng) -> Vec3 { + let x1 = rng.gen_range(-16.0..=16.0); + let x2 = rng.gen_range(-16.0..=16.0); + let x3 = rng.gen_range(-16.0..=16.0); + + Vec3::new(x1, x2, x3) +} + +fn rnd_wall(rng: &mut ThreadRng) -> Wall { + let d = rng.gen_range(-16.0..=16.0); + + Wall(rnd_vec3(rng).normalize_or_zero(), d) +} + +// AoS to SoA data layout conversion for x86 AVX. +// This code has been adapted from: +// https://www.intel.com/content/dam/develop/external/us/en/documents/normvec-181650.pdf +#[inline(always)] +// This example is written in a way that benefits from inlined data layout conversion. +fn aos_to_soa_83(aos_inner: &[Vec3; 8]) -> [__m256; 3] { + unsafe { + //# SAFETY: Vec3 is repr(C) for x86_64 + let mx0 = _mm_loadu_ps((aos_inner as *const Vec3 as *const f32).offset(0)); + let mx1 = _mm_loadu_ps((aos_inner as *const Vec3 as *const f32).offset(4)); + let mx2 = _mm_loadu_ps((aos_inner as *const Vec3 as *const f32).offset(8)); + let mx3 = _mm_loadu_ps((aos_inner as *const Vec3 as *const f32).offset(12)); + let mx4 = _mm_loadu_ps((aos_inner as *const Vec3 as *const f32).offset(16)); + let mx5 = _mm_loadu_ps((aos_inner as *const Vec3 as *const f32).offset(20)); + + let mut m03 = _mm256_castps128_ps256(mx0); // load lower halves + let mut m14 = _mm256_castps128_ps256(mx1); + let mut m25 = _mm256_castps128_ps256(mx2); + m03 = _mm256_insertf128_ps(m03, mx3, 1); // load upper halves + m14 = _mm256_insertf128_ps(m14, mx4, 1); + m25 = _mm256_insertf128_ps(m25, mx5, 1); + + let xy = _mm256_shuffle_ps::<0b10011110>(m14, m25); // upper x's and y's + let yz = _mm256_shuffle_ps::<0b01001001>(m03, m14); // lower y's and z's + let x = _mm256_shuffle_ps::<0b10001100>(m03, xy); + let y = _mm256_shuffle_ps::<0b11011000>(yz, xy); + let z = _mm256_shuffle_ps::<0b11001101>(yz, m25); + [x, y, z] + } +} + +impl Benchmark { + fn new(size: i32) -> Benchmark { + let mut world = World::new(); + + let mut rng = rand::thread_rng(); + + world.spawn_batch((0..size).map(|_| (Position(rnd_vec3(&mut rng)), Health(100.0)))); + world.spawn_batch((0..(2_i32.pow(12) - 1)).map(|_| (rnd_wall(&mut rng)))); + + Self(world) + } + + fn scalar(mut pos_healths: Query<(&Position, &mut Health)>, walls: Query<&Wall>) { + pos_healths.for_each_mut(|(position, mut health)| { + // This forms the scalar path: it behaves just like `for_each_mut`. + + // Optional: disable change detection for more performance. + let health = &mut health.bypass_change_detection().0; + + // Test each (Position,Health) against each Wall. + walls.for_each(|wall| { + let plane = wall.0; + + // Test which side of the wall we are on + let dotproj = plane.dot(position.0); + + // Test against the Wall's displacement/discriminant value + if dotproj < wall.1 { + //Ouch! Take damage! 
+ *health -= 1.0; + } + }); + }); + } + + // Perform collision detection against a set of Walls, forming a convex polygon. + // Each entity has a Position and some Health (initialized to 100.0). + // If the position of an entity is found to be outside of a Wall, decrement its "health" by 1.0. + // The effect is cumulative based on the number of walls. + // An entity entirely inside the convex polygon will have its health remain unchanged. + fn batched_avx(mut pos_healths: Query<(&Position, &mut Health)>, walls: Query<&Wall>) { + // Conceptually, this system is executed using two loops: the outer "batched" loop receiving + // batches of 8 Positions and Health components at a time, and the inner loop iterating over + // the Walls. + + // There's more than one way to vectorize this system -- this example may not be optimal. + pos_healths.for_each_mut_batched::<8>( + |(position, mut health)| { + // This forms the scalar path: it behaves just like `for_each_mut`. + + // Optional: disable change detection for more performance. + let health = &mut health.bypass_change_detection().0; + + // Test each (Position,Health) against each Wall. + walls.for_each(|wall| { + let plane = wall.0; + + // Test which side of the wall we are on + let dotproj = plane.dot(position.0); + + // Test against the Wall's displacement/discriminant value + if dotproj < wall.1 { + //Ouch! Take damage! + *health -= 1.0; + } + }); + }, + |(positions, mut healths)| { + // This forms the vector path: the closure receives a batch of + // 8 Positions and 8 Healths as arrays. + + // Optional: disable change detection for more performance. + let healths = healths.bypass_change_detection(); + + // Treat the Health batch as a batch of 8 f32s. + unsafe { + // # SAFETY: Health is repr(transprent)! + let healths_raw = healths as *mut Health as *mut f32; + let mut healths = _mm256_loadu_ps(healths_raw); + + // NOTE: array::map optimizes poorly -- it is recommended to unpack your arrays + // manually as shown to avoid spurious copies which will impact your performance. + let [p0, p1, p2, p3, p4, p5, p6, p7] = positions; + + // Perform data layout conversion from AoS to SoA. + // ps_x will receive all of the X components of the positions, + // ps_y will receive all of the Y components + // and ps_z will receive all of the Z's. + let [ps_x, ps_y, ps_z] = + aos_to_soa_83(&[p0.0, p1.0, p2.0, p3.0, p4.0, p5.0, p6.0, p7.0]); + + // Iterate over each wall without batching. + walls.for_each(|wall| { + // Test each wall against all 8 positions at once. The "broadcast" intrinsic + // helps us achieve this by duplicating the Wall's X coordinate over an entire + // vector register, e.g., [X X ... X]. The same goes for the Wall's Y and Z + // coordinates. + + // This is the exact same formula as implemented in the scalar path, but + // modified to be calculated in parallel across each lane. + + // Multiply all of the X coordinates of each Position against Wall's Normal X + let xs_dot = _mm256_mul_ps(ps_x, _mm256_broadcast_ss(&wall.0.x)); + // Multiply all of the Y coordinates of each Position against Wall's Normal Y + let ys_dot = _mm256_mul_ps(ps_y, _mm256_broadcast_ss(&wall.0.y)); + // Multiply all of the Z coordinates of each Position against Wall's Normal Z + let zs_dot = _mm256_mul_ps(ps_z, _mm256_broadcast_ss(&wall.0.z)); + + // Now add them together: the result is a vector register containing the dot + // product of each Position against the Wall's Normal vector. 
+ let dotprojs = _mm256_add_ps(_mm256_add_ps(xs_dot, ys_dot), zs_dot); + + // Take the Wall's discriminant/displacement value and broadcast it like before. + let wall_d = _mm256_broadcast_ss(&wall.1); + + // Compare each dot product against the Wall's discriminant, using the + // "Less Than" relation as we did in the scalar code. + // The result will be be either -1 or zero *as an integer*. + let cmp = _mm256_cmp_ps::<_CMP_LT_OS>(dotprojs, wall_d); + + // Convert the integer values back to f32 values (-1.0 or 0.0). + // These form the damage values for each entity. + let damages = _mm256_cvtepi32_ps(_mm256_castps_si256(cmp)); //-1.0 or 0.0 + + // Update the healths of each entity being processed with the results of the + // collision detection. + healths = _mm256_add_ps(healths, damages); + }); + // Now that all Walls have been processed, write the final updated Health values + // for this batch of entities back to main memory. + _mm256_storeu_ps(healths_raw, healths); + } + }, + ); + } +} + +pub fn batched_compute(c: &mut Criterion) { + let mut group = c.benchmark_group("batched_compute"); + group.warm_up_time(std::time::Duration::from_secs(1)); + group.measurement_time(std::time::Duration::from_secs(9)); + + for exp in 14..17 { + let size = 2_i32.pow(exp) - 1; //Ensure scalar path gets run too (incomplete batch at end) + + group.bench_with_input( + BenchmarkId::new("autovectorized", size), + &size, + |b, &size| { + let Benchmark(mut world) = Benchmark::new(size); + + let mut system = IntoSystem::into_system(Benchmark::scalar); + system.initialize(&mut world); + system.update_archetype_component_access(&world); + + b.iter(move || system.run((), &mut world)); + }, + ); + + group.bench_with_input(BenchmarkId::new("batched_avx", size), &size, |b, &size| { + let Benchmark(mut world) = Benchmark::new(size); + + let mut system = IntoSystem::into_system(Benchmark::batched_avx); + system.initialize(&mut world); + system.update_archetype_component_access(&world); + + b.iter(move || system.run((), &mut world)); + }); + } + + group.finish(); +} diff --git a/benches/benches/bevy_ecs/iteration/mod.rs b/benches/benches/bevy_ecs/iteration/mod.rs index e3ed6a6afeabe..c0a1fddf155c0 100644 --- a/benches/benches/bevy_ecs/iteration/mod.rs +++ b/benches/benches/bevy_ecs/iteration/mod.rs @@ -1,5 +1,8 @@ use criterion::*; +#[cfg(target_feature = "avx")] +mod batched_compute; + mod heavy_compute; mod iter_frag; mod iter_frag_foreach; @@ -19,8 +22,21 @@ mod iter_simple_system; mod iter_simple_wide; mod iter_simple_wide_sparse_set; +#[cfg(target_feature = "avx")] +use batched_compute::batched_compute; + use heavy_compute::*; +#[cfg(target_feature = "avx")] +criterion_group!( + iterations_benches, + iter_frag, + iter_frag_sparse, + iter_simple, + heavy_compute, + batched_compute, +); +#[cfg(not(target_feature = "avx"))] criterion_group!( iterations_benches, iter_frag, diff --git a/crates/bevy_ecs/Cargo.toml b/crates/bevy_ecs/Cargo.toml index 15c8dd2bb3e51..789ba50842188 100644 --- a/crates/bevy_ecs/Cargo.toml +++ b/crates/bevy_ecs/Cargo.toml @@ -30,6 +30,7 @@ serde = { version = "1", features = ["derive"] } [dev-dependencies] rand = "0.8" +bevy_math = { path = "../bevy_math", version = "0.9.0-dev" } [[example]] name = "events" diff --git a/crates/bevy_ecs/src/archetype.rs b/crates/bevy_ecs/src/archetype.rs index 777dd0d2e3fcd..692641eee9931 100644 --- a/crates/bevy_ecs/src/archetype.rs +++ b/crates/bevy_ecs/src/archetype.rs @@ -23,7 +23,10 @@ use crate::{ bundle::BundleId, component::{ComponentId, StorageType}, 
entity::{Entity, EntityLocation}, - storage::{ImmutableSparseSet, SparseArray, SparseSet, SparseSetIndex, TableId, TableRow}, + storage::{ + aligned_vec::SimdAlignedVec, ImmutableSparseSet, SparseArray, SparseSet, SparseSetIndex, + TableId, TableRow, + }, }; use std::{ collections::HashMap, @@ -297,7 +300,7 @@ pub struct Archetype { id: ArchetypeId, table_id: TableId, edges: Edges, - entities: Vec, + entities: SimdAlignedVec, components: ImmutableSparseSet, } @@ -333,7 +336,7 @@ impl Archetype { Self { id, table_id, - entities: Vec::new(), + entities: SimdAlignedVec::new(), components: components.into_immutable(), edges: Default::default(), } diff --git a/crates/bevy_ecs/src/change_detection.rs b/crates/bevy_ecs/src/change_detection.rs index 3802af55dd349..34b9bdf004edc 100644 --- a/crates/bevy_ecs/src/change_detection.rs +++ b/crates/bevy_ecs/src/change_detection.rs @@ -2,10 +2,10 @@ use crate::{ component::{Tick, TickCells}, - ptr::PtrMut, + ptr::{Batch, Ptr, PtrMut, UnsafeCellDeref}, system::Resource, }; -use bevy_ptr::{Ptr, UnsafeCellDeref}; +use core::marker::PhantomData; use std::ops::{Deref, DerefMut}; /// The (arbitrarily chosen) minimum number of world tick increments between `check_tick` scans. @@ -352,6 +352,12 @@ impl<'a> TicksMut<'a> { } } } +pub(crate) struct TicksBatch<'a, const N: usize> { + pub(crate) added_ticks: &'a mut Batch, + pub(crate) changed_ticks: &'a mut Batch, + pub(crate) last_change_tick: u32, + pub(crate) change_tick: u32, +} impl<'a> From> for Ticks<'a> { fn from(ticks: TicksMut<'a>) -> Self { @@ -580,6 +586,120 @@ change_detection_mut_impl!(Mut<'a, T>, T,); impl_methods!(Mut<'a, T>, T,); impl_debug!(Mut<'a, T>,); +/// Unique mutable borrow of an entity's component (batched version). +/// Each batch changes in unison: a batch has changed if any of its elements have changed. 
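+/// Note that taking a mutable reference through `DerefMut` marks the change ticks of *every* element
+/// in the batch as changed, even if only some elements are actually written to.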
+pub struct MutBatch<'a, T, const N: usize> { + pub(crate) value: &'a mut Batch, + pub(crate) ticks: TicksBatch<'a, N>, + pub(crate) _marker: PhantomData, +} + +impl<'a, T, const N: usize> DetectChanges for MutBatch<'a, T, N> { + #[inline] + fn is_added(&self) -> bool { + self.ticks + .added_ticks + .iter() + .any(|x| x.is_older_than(self.ticks.last_change_tick, self.ticks.change_tick)) + } + + #[inline] + fn is_changed(&self) -> bool { + self.ticks + .changed_ticks + .iter() + .any(|x| x.is_older_than(self.ticks.last_change_tick, self.ticks.change_tick)) + } + + #[inline] + fn last_changed(&self) -> u32 { + self.ticks.last_change_tick + } +} + +impl<'a, T, const N: usize> DetectChangesMut for MutBatch<'a, T, N> { + type Inner = Batch; + + #[inline] + fn set_changed(&mut self) { + for ticks in self.ticks.changed_ticks.iter_mut() { + ticks.set_changed(self.ticks.change_tick); + } + } + + fn set_last_changed(&mut self, last_change_tick: u32) { + self.ticks.last_change_tick = last_change_tick; + } + + fn bypass_change_detection(&mut self) -> &mut Self::Inner { + self.value + } + + #[inline] + fn set_if_neq(&mut self, value: Target) + where + Self: Deref + DerefMut, + Target: PartialEq, + { + // This dereference is immutable, so does not trigger change detection + if *::deref(self) != value { + // `DerefMut` usage triggers change detection + *::deref_mut(self) = value; + } + } +} + +impl<'a, T, const N: usize> Deref for MutBatch<'a, T, N> { + type Target = Batch; + + #[inline] + fn deref(&self) -> &Self::Target { + self.value + } +} + +impl<'a, T, const N: usize> DerefMut for MutBatch<'a, T, N> { + #[inline] + fn deref_mut(&mut self) -> &mut Self::Target { + self.set_changed(); + self.value + } +} + +impl<'a, T, const N: usize> AsRef> for MutBatch<'a, T, N> { + #[inline] + fn as_ref(&self) -> &Batch { + self.deref() + } +} + +impl<'a, T, const N: usize> AsMut> for MutBatch<'a, T, N> { + #[inline] + fn as_mut(&mut self) -> &mut Batch { + self.deref_mut() + } +} + +impl<'a, T, const N: usize> MutBatch<'a, T, N> { + /// Consume `self` and return a mutable reference to the + /// contained value while marking `self` as "changed". + #[inline] + pub fn into_inner(mut self) -> &'a mut Batch { + self.set_changed(); + self.value + } +} + +impl<'a, T, const N: usize> std::fmt::Debug for MutBatch<'a, T, N> +where + Batch: std::fmt::Debug, + T: std::fmt::Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple(stringify!($name)).field(&self.value).finish() + } +} + /// Unique mutable borrow of resources or an entity's component. 
/// /// Similar to [`Mut`], but not generic over the component type, instead diff --git a/crates/bevy_ecs/src/query/batch.rs b/crates/bevy_ecs/src/query/batch.rs new file mode 100644 index 0000000000000..fc0ab1b992819 --- /dev/null +++ b/crates/bevy_ecs/src/query/batch.rs @@ -0,0 +1,234 @@ +use crate::{ + change_detection::{MutBatch, TicksBatch}, + component::{Component, Tick}, + entity::Entity, + ptr::Batch, + query::{AnyOf, ChangeTrackers, DebugCheckedUnwrap, WorldQuery}, + storage::TableRow, +}; + +use bevy_ecs_macros::all_tuples; + +use core::marker::PhantomData; + +/// The item type returned when a [`WorldQuery`] is iterated over in a batched fashion +pub type QueryBatch<'w, Q, const N: usize> = >::BatchItem<'w>; + +/// The read-only variant of the item type returned when a [`WorldQuery`] is iterated over in a batched fashion +pub type ROQueryBatch<'w, Q, const N: usize> = QueryBatch<'w, ::ReadOnly, N>; + +/// An extension of [`WorldQuery`] for batched queries. +pub trait WorldQueryBatch: WorldQuery { + type BatchItem<'w>; + + /// Retrieve a batch of size `N` from the current table. + /// # Safety + /// + /// `table_row_start` is a valid table row index for the current table + /// `table_row_start` + `N` is a valid table row index for the current table + /// `table_row_start` is a multiple of `N` + /// + /// Must always be called _after_ [`WorldQuery::set_table`]. + unsafe fn fetch_batched<'w>( + fetch: &mut ::Fetch<'w>, + entity_batch: &'w Batch, + table_row_start: TableRow, + len: usize, + ) -> Self::BatchItem<'w>; +} + +impl WorldQueryBatch for Entity { + type BatchItem<'w> = &'w Batch; + + #[inline] + unsafe fn fetch_batched<'w>( + _fetch: &mut ::Fetch<'w>, + entity_batch: &'w Batch, + _table_row_start: TableRow, + _len: usize, + ) -> Self::BatchItem<'w> { + entity_batch + } +} + +impl WorldQueryBatch for &T { + type BatchItem<'w> = &'w Batch; + + #[inline] + unsafe fn fetch_batched<'w>( + fetch: &mut ::Fetch<'w>, + _entity_batch: &'w Batch, + table_row_start: TableRow, + len: usize, + ) -> Self::BatchItem<'w> { + //TODO: when generalized const expresions are stable, want the following: + //gcd::euclid_usize(ptr::MAX_SIMD_ALIGNMENT, N * core::mem::size_of::()); + + let components = fetch.table_components.debug_checked_unwrap(); + + components.get_batch_deref::(table_row_start.index(), len) + } +} + +impl<'__w, T: Component, const N: usize> WorldQueryBatch for &'__w mut T { + type BatchItem<'w> = MutBatch<'w, T, N>; + + #[inline] + unsafe fn fetch_batched<'w>( + fetch: &mut ::Fetch<'w>, + _entity_batch: &'w Batch, + table_row_start: TableRow, + len: usize, + ) -> Self::BatchItem<'w> { + let (table_components, added_ticks, changed_ticks) = + fetch.table_data.debug_checked_unwrap(); + + MutBatch:: { + value: table_components.get_batch_deref_mut::(table_row_start.index(), len), + ticks: TicksBatch { + // SAFETY: [table_row_start..+batch.len()] is in range + added_ticks: added_ticks.get_batch_deref_mut::(table_row_start.index(), len), + changed_ticks: changed_ticks.get_batch_deref_mut::(table_row_start.index(), len), + change_tick: fetch.change_tick, + last_change_tick: fetch.last_change_tick, + }, + _marker: PhantomData, + } + } +} + +impl, const N: usize> WorldQueryBatch for Option { + type BatchItem<'w> = Option>; + + #[inline] + unsafe fn fetch_batched<'w>( + fetch: &mut ::Fetch<'w>, + entity_batch: &'w Batch, + table_row_start: TableRow, + len: usize, + ) -> Self::BatchItem<'w> { + if fetch.matches { + Some(T::fetch_batched( + &mut fetch.fetch, + entity_batch, + table_row_start, + 
len, + )) + } else { + None + } + } +} + +/// A batch of [`ChangeTrackers`]. This is used when performing queries with Change Trackers using the +/// [`Query::for_each_mut_batched`](crate::system::Query::for_each_mut_batched) and [`Query::for_each_batched`](crate::system::Query::for_each_batched) functions. +#[derive(Clone)] +pub struct ChangeTrackersBatch<'a, T, const N: usize> { + pub(crate) added_ticks: &'a Batch, + pub(crate) changed_ticks: &'a Batch, + pub(crate) last_change_tick: u32, + pub(crate) change_tick: u32, + marker: PhantomData, +} + +impl<'a, T: Component, const N: usize> ChangeTrackersBatch<'a, T, N> { + /// Returns true if this component has been added since the last execution of this system. + #[inline] + pub fn is_added(&self) -> bool { + self.added_ticks + .iter() + .any(|x| x.is_older_than(self.last_change_tick, self.change_tick)) + } + + /// Returns true if this component has been changed since the last execution of this system. + #[inline] + pub fn is_changed(&self) -> bool { + self.changed_ticks + .iter() + .any(|x| x.is_older_than(self.last_change_tick, self.change_tick)) + } +} + +impl WorldQueryBatch for ChangeTrackers { + type BatchItem<'w> = ChangeTrackersBatch<'w, T, N>; + + #[inline] + unsafe fn fetch_batched<'w>( + fetch: &mut ::Fetch<'w>, + _entity_batch: &'w Batch, + table_row_start: TableRow, + len: usize, + ) -> Self::BatchItem<'w> { + ChangeTrackersBatch { + added_ticks: { + let table_ticks = fetch.table_added.debug_checked_unwrap(); + + table_ticks.get_batch_deref::(table_row_start.index(), len) + }, + changed_ticks: { + let table_ticks = fetch.table_changed.debug_checked_unwrap(); + + table_ticks.get_batch_deref::(table_row_start.index(), len) + }, + marker: PhantomData, + last_change_tick: fetch.last_change_tick, + change_tick: fetch.change_tick, + } + } +} + +macro_rules! impl_tuple_fetch_batched { + ($(($name: ident, $state: ident)),*) => { + #[allow(unused_variables)] + #[allow(non_snake_case)] + #[allow(clippy::unused_unit)] + impl),*> WorldQueryBatch for ($($name,)*) + { + type BatchItem<'w> = ($($name::BatchItem<'w>,)*); + + #[inline] + unsafe fn fetch_batched<'w>( + _fetch: &mut ::Fetch<'w>, + _entity_batch: &'w Batch, + _table_row_start: TableRow, + _len: usize, + ) -> Self::BatchItem<'w> + { + let ($($name,)*) = _fetch; + ($($name::fetch_batched($name, _entity_batch, _table_row_start, _len),)*) + } + } + }; +} + +macro_rules! 
impl_anytuple_fetch_batched { + ($(($name: ident, $state: ident)),*) => { + #[allow(unused_variables)] + #[allow(non_snake_case)] + #[allow(clippy::unused_unit)] + impl),*> WorldQueryBatch for AnyOf<($($name,)*)> + { + type BatchItem<'w> = ($(Option<$name::BatchItem<'w>>,)*); + + #[inline] + unsafe fn fetch_batched<'w>( + _fetch: &mut ::Fetch<'w>, + _entity_batch: &'w Batch, + _table_row_start: TableRow, + _len: usize, + ) -> >::BatchItem<'w> + { + let ($($name,)*) = _fetch; + + ($( + $name.1.then(|| $name::fetch_batched(&mut $name.0, _entity_batch, _table_row_start, _len)), + )*) + + } + } + + }; +} + +all_tuples!(impl_tuple_fetch_batched, 0, 15, F, S); +all_tuples!(impl_anytuple_fetch_batched, 0, 15, F, S); diff --git a/crates/bevy_ecs/src/query/fetch.rs b/crates/bevy_ecs/src/query/fetch.rs index 5733ffd262829..adb3e94d95d2c 100644 --- a/crates/bevy_ecs/src/query/fetch.rs +++ b/crates/bevy_ecs/src/query/fetch.rs @@ -3,14 +3,15 @@ use crate::{ change_detection::{Ticks, TicksMut}, component::{Component, ComponentId, ComponentStorage, ComponentTicks, StorageType, Tick}, entity::Entity, + ptr::{ThinSlicePtr, UnsafeCellDeref}, query::{Access, DebugCheckedUnwrap, FilteredAccess}, storage::{ComponentSparseSet, Table, TableRow}, world::{Mut, Ref, World}, }; + use bevy_ecs_macros::all_tuples; pub use bevy_ecs_macros::WorldQuery; -use bevy_ptr::{ThinSlicePtr, UnsafeCellDeref}; -use std::{cell::UnsafeCell, marker::PhantomData}; +use core::{cell::UnsafeCell, marker::PhantomData}; /// Types that can be fetched from a [`World`] using a [`Query`]. /// @@ -518,9 +519,9 @@ unsafe impl ReadOnlyWorldQuery for Entity {} #[doc(hidden)] pub struct ReadFetch<'w, T> { // T::Storage = TableStorage - table_components: Option>>, + pub(crate) table_components: Option>>, // T::Storage = SparseStorage - sparse_set: Option<&'w ComponentSparseSet>, + pub(crate) sparse_set: Option<&'w ComponentSparseSet>, } /// SAFETY: `Self` is the same as `Self::ReadOnly` @@ -817,16 +818,15 @@ unsafe impl<'__w, T: Component> ReadOnlyWorldQuery for Ref<'__w, T> {} #[doc(hidden)] pub struct WriteFetch<'w, T> { // T::Storage = TableStorage - table_data: Option<( + pub(crate) table_data: Option<( ThinSlicePtr<'w, UnsafeCell>, ThinSlicePtr<'w, UnsafeCell>, ThinSlicePtr<'w, UnsafeCell>, )>, // T::Storage = SparseStorage - sparse_set: Option<&'w ComponentSparseSet>, - - last_change_tick: u32, - change_tick: u32, + pub(crate) sparse_set: Option<&'w ComponentSparseSet>, + pub(crate) last_change_tick: u32, + pub(crate) change_tick: u32, } /// SAFETY: access of `&T` is a subset of `&mut T` @@ -978,8 +978,8 @@ unsafe impl<'__w, T: Component> WorldQuery for &'__w mut T { #[doc(hidden)] pub struct OptionFetch<'w, T: WorldQuery> { - fetch: T::Fetch<'w>, - matches: bool, + pub(crate) fetch: T::Fetch<'w>, + pub(crate) matches: bool, } // SAFETY: defers to soundness of `T: WorldQuery` impl @@ -1119,7 +1119,7 @@ pub struct ChangeTrackers { pub(crate) component_ticks: ComponentTicks, pub(crate) last_change_tick: u32, pub(crate) change_tick: u32, - marker: PhantomData, + pub(crate) marker: PhantomData, } impl Clone for ChangeTrackers { @@ -1161,14 +1161,14 @@ impl ChangeTrackers { #[doc(hidden)] pub struct ChangeTrackersFetch<'w, T> { // T::Storage = TableStorage - table_added: Option>>, - table_changed: Option>>, + pub(crate) table_added: Option>>, + pub(crate) table_changed: Option>>, // T::Storage = SparseStorage - sparse_set: Option<&'w ComponentSparseSet>, + pub(crate) sparse_set: Option<&'w ComponentSparseSet>, - marker: PhantomData, - 
last_change_tick: u32, - change_tick: u32, + pub(crate) marker: PhantomData, + pub(crate) last_change_tick: u32, + pub(crate) change_tick: u32, } // SAFETY: `ROQueryFetch` is the same as `QueryFetch` @@ -1418,7 +1418,6 @@ macro_rules! impl_tuple_fetch { /// SAFETY: each item in the tuple is read only unsafe impl<$($name: ReadOnlyWorldQuery),*> ReadOnlyWorldQuery for ($($name,)*) {} - }; } diff --git a/crates/bevy_ecs/src/query/filter.rs b/crates/bevy_ecs/src/query/filter.rs index a067acbd89932..760deaa21837d 100644 --- a/crates/bevy_ecs/src/query/filter.rs +++ b/crates/bevy_ecs/src/query/filter.rs @@ -2,12 +2,12 @@ use crate::{ archetype::{Archetype, ArchetypeComponentId}, component::{Component, ComponentId, ComponentStorage, StorageType, Tick}, entity::Entity, + ptr::{ThinSlicePtr, UnsafeCellDeref}, query::{Access, DebugCheckedUnwrap, FilteredAccess, WorldQuery}, storage::{Column, ComponentSparseSet, Table, TableRow}, world::World, }; use bevy_ecs_macros::all_tuples; -use bevy_ptr::{ThinSlicePtr, UnsafeCellDeref}; use std::{cell::UnsafeCell, marker::PhantomData}; use super::ReadOnlyWorldQuery; diff --git a/crates/bevy_ecs/src/query/mod.rs b/crates/bevy_ecs/src/query/mod.rs index 1b3e1f4d08bca..2e28b55191653 100644 --- a/crates/bevy_ecs/src/query/mod.rs +++ b/crates/bevy_ecs/src/query/mod.rs @@ -1,4 +1,5 @@ mod access; +mod batch; mod fetch; mod filter; mod iter; @@ -6,6 +7,7 @@ mod par_iter; mod state; pub use access::*; +pub use batch::*; pub use fetch::*; pub use filter::*; pub use iter::*; @@ -77,6 +79,9 @@ mod tests { #[derive(Component, Debug, Eq, PartialEq, Clone, Copy)] struct D(usize); + #[derive(Component)] + struct E; + #[derive(Component, Debug, Eq, PartialEq, Clone, Copy)] #[component(storage = "SparseSet")] struct Sparse(usize); @@ -712,6 +717,121 @@ mod tests { } } + #[test] + fn batched_queries() { + let mut world = World::new(); + + world.spawn_batch( + (0..127) + .into_iter() + .map(|i| (A(4 * i), B(4 * i + 1), C(4 * i + 2), D(4 * i + 3))), + ); + + fn system_compute(mut q: Query<(&mut A, &B, &C, &D)>) { + let mut scalar_counter = 0; + let mut batch_counter = 0; + + q.for_each_mut_batched::<4>( + |(mut a, b, c, d)| { + assert_eq!(a.ticks.added.tick, 1); + assert_eq!(a.ticks.changed.tick, 1); + + a.0 += b.0 + c.0 + d.0; + scalar_counter += 1; + + assert_eq!(a.ticks.added.tick, 1); + assert_eq!(a.ticks.changed.tick, 2); + }, + |(mut a, b, c, d)| { + for tick in a.ticks.added_ticks.iter() { + assert_eq!(tick.tick, 1); + } + + for tick in a.ticks.changed_ticks.iter() { + assert_eq!(tick.tick, 1); + } + + assert_eq!( + *a, + [ + A(4 * batch_counter), + A(4 * (batch_counter + 1)), + A(4 * (batch_counter + 2)), + A(4 * (batch_counter + 3)) + ] + ); + + for (i, mut a_elem) in a.iter_mut().enumerate() { + a_elem.0 += b[i].0 + c[i].0 + d[i].0; + } + + for tick in a.ticks.added_ticks.iter() { + assert_eq!(tick.tick, 1); + } + + for tick in a.ticks.changed_ticks.iter() { + assert_eq!(tick.tick, 2); + } + + batch_counter += 4; + }, + ); + + assert_eq!(scalar_counter, 3); + assert_eq!(batch_counter, 124); + } + fn system_check(mut q: Query<&A>) { + let mut scalar_counter = 0; + let mut batch_counter = 0; + + q.for_each_batched::<4>( + |a| { + assert_eq!(*a, A(1990 + 16 * scalar_counter)); + + scalar_counter += 1; + }, + |a| { + assert_eq!( + *a, + [ + A(16 * batch_counter + 6), + A(16 * (batch_counter + 1) + 6), + A(16 * (batch_counter + 2) + 6), + A(16 * (batch_counter + 3) + 6) + ] + ); + + batch_counter += 4; + }, + ); + } + + world.increment_change_tick(); + + let mut 
system_compute = IntoSystem::into_system(system_compute); + system_compute.initialize(&mut world); + system_compute.run((), &mut world); + + let mut system_check = IntoSystem::into_system(system_check); + system_check.initialize(&mut world); + system_check.run((), &mut world); + } + + #[test] + fn batched_queries_zst() { + let mut world = World::new(); + + world.spawn_batch((0..127).into_iter().map(|_| E)); + + fn system_compute(mut q: Query<&mut E>) { + q.for_each_mut_batched::<4>(|mut e| *e = E, |mut e| e[1] = E); + } + + let mut system_compute = IntoSystem::into_system(system_compute); + system_compute.initialize(&mut world); + system_compute.run((), &mut world); + } + #[test] fn mut_to_immut_query_methods_have_immut_item() { #[derive(Component)] diff --git a/crates/bevy_ecs/src/query/state.rs b/crates/bevy_ecs/src/query/state.rs index b7b163a2f5340..052d0dfc806c6 100644 --- a/crates/bevy_ecs/src/query/state.rs +++ b/crates/bevy_ecs/src/query/state.rs @@ -10,13 +10,17 @@ use crate::{ storage::{TableId, TableRow}, world::{World, WorldId}, }; +use bevy_ptr::ThinSlicePtr; use bevy_tasks::ComputeTaskPool; #[cfg(feature = "trace")] use bevy_utils::tracing::Instrument; use fixedbitset::FixedBitSet; use std::{borrow::Borrow, fmt, mem::MaybeUninit}; -use super::{NopWorldQuery, QueryManyIter, ROQueryItem, ReadOnlyWorldQuery}; +use super::{ + NopWorldQuery, QueryBatch, QueryItem, QueryManyIter, ROQueryBatch, ROQueryItem, + ReadOnlyWorldQuery, WorldQueryBatch, +}; /// Provides scoped access to a [`World`] state according to a given [`WorldQuery`] and query filter. #[repr(C)] @@ -778,6 +782,29 @@ impl QueryState { } } + /// A read-only version of [`for_each_mut_batched`](Self::for_each_mut_batched). Detailed docs can be found there regarding how to use this function. + #[inline] + pub fn for_each_batched<'w, const N: usize>( + &'w mut self, + world: &'w mut World, + func: impl FnMut(ROQueryItem<'w, Q>), + func_batch: impl FnMut(ROQueryBatch<'w, Q, N>), + ) where + ::ReadOnly: WorldQueryBatch, + { + // SAFETY: query has unique world access + unsafe { + self.update_archetypes(world); + self.as_readonly().for_each_unchecked_manual_batched( + world, + func, + func_batch, + world.last_change_tick(), + world.read_change_tick(), + ); + } + } + /// Runs `func` on each query result for the given [`World`]. This is faster than the equivalent /// `iter_mut()` method, but cannot be chained like a normal [`Iterator`]. #[inline] @@ -789,6 +816,40 @@ impl QueryState { self.for_each_unchecked_manual(world, func, world.last_change_tick(), change_tick); } } + /// This is a batched version of [`for_each_mut`](Self::for_each_mut) that accepts a batch size `N`. + /// The advantage of using batching in queries is that it enables SIMD acceleration of your code to help you meet your performance goals. + /// This function accepts two arguments, `func`, and `func_batch` which represent the "scalar" and "vector" (or "batched") paths of your code respectively. + /// + /// ## Usage: + /// + /// * `N` must be a power of 2 + /// * `func` functions exactly as does in [`for_each_mut`](Self::for_each_mut) -- it receives "scalar" (non-batched) components. + /// * `func_batch` receives [`Batch`](bevy_ptr::Batch)es of `N` components. + /// + /// In other words, `func_batch` composes the "fast path" of your query, and `func` is the "slow path". + /// + /// See [`Query::for_each_mut_batched`](crate::system::Query::for_each_mut_batched) for a complete example of how to use this function. 
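+ ///
+ /// As a minimal sketch (the component types `A` and `B` and the surrounding `world`/`query_state` are placeholders, mirroring the usage in this module's tests):
+ ///
+ /// ```ignore
+ /// // Scalar path: called for rows that cannot form a full batch of 4.
+ /// // Batched path: called with `Batch`es of 4 components at a time.
+ /// query_state.for_each_mut_batched::<4>(
+ ///     &mut world,
+ ///     |(mut a, b)| a.0 += b.0,
+ ///     |(mut a, b)| {
+ ///         for (i, a_elem) in a.iter_mut().enumerate() {
+ ///             a_elem.0 += b[i].0;
+ ///         }
+ ///     },
+ /// );
+ /// ```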
+ #[inline] + pub fn for_each_mut_batched<'w, const N: usize>( + &'w mut self, + world: &'w mut World, + func: impl FnMut(QueryItem<'w, Q>), + func_batch: impl FnMut(QueryBatch<'w, Q, N>), + ) where + Q: WorldQueryBatch, + { + // SAFETY: query has unique world access + unsafe { + self.update_archetypes(world); + self.for_each_unchecked_manual_batched( + world, + func, + func_batch, + world.last_change_tick(), + world.read_change_tick(), + ); + } + } /// Runs `func` on each query result for the given [`World`]. This is faster than the equivalent /// iter() method, but cannot be chained like a normal [`Iterator`]. @@ -911,6 +972,138 @@ impl QueryState { } } + // TODO: when generic const expressions are stable, can encode batch alignment using this simple formula: + // gcd::euclid_usize(crate::ptr::batch::MAX_SIMD_ALIGNMENT, N * core::mem::size_of::()); + // This will only benefit architectures that encode alignment in opcodes or have operand alignment restrictions (SSE4.2 and below mainly) + #[inline] + pub(crate) unsafe fn for_each_unchecked_manual_batched< + 'w, + const N: usize, + FN: FnMut(QueryItem<'w, Q>), + FnBatch: FnMut(QueryBatch<'w, Q, N>), + >( + &self, + world: &'w World, + mut func: FN, + mut func_batch: FnBatch, + last_change_tick: u32, + change_tick: u32, + ) where + Q: WorldQueryBatch, + { + // NOTE: If you are changing query iteration code, remember to update the following places, where relevant: + // QueryIter, QueryIterationCursor, QueryManyIter, QueryCombinationIter, QueryState::for_each_unchecked_manual, QueryState::par_for_each_unchecked_manual + let mut fetch = Q::init_fetch(world, &self.fetch_state, last_change_tick, change_tick); + let mut filter = F::init_fetch(world, &self.filter_state, last_change_tick, change_tick); + + // Can't use this because it captures a mutable reference to fetch and filter + + let serial_portion = |entities: ThinSlicePtr<'w, Entity>, + fetch: &mut Q::Fetch<'w>, + filter: &mut F::Fetch<'w>, + func: &mut FN, + range| { + for table_index in range { + let entity = entities.get(table_index); + let row = TableRow::new(table_index); + if !F::filter_fetch(filter, *entity, row) { + continue; + } + let item = Q::fetch(fetch, *entity, row); + func(item); + } + }; + + let tables = &world.storages().tables; + if Q::IS_DENSE && F::IS_DENSE { + for table_id in &self.matched_table_ids { + let table = &tables[*table_id]; + let entities = ThinSlicePtr::from(table.entities()); + Q::set_table(&mut fetch, &self.fetch_state, table); + F::set_table(&mut filter, &self.filter_state, table); + + let mut table_index = 0; + + let batch_end = table.batchable_region_end::(); + + while table_index < batch_end { + // TODO PERF: since both the Query and the Filter are dense, can this be precomputed? + // NOTE: if F = (), this optimizes right out, so don't worry about performance in that case. + let mut unbatchable = None; + for i in 0..N { + let table_row = table_index + i; + let entity = entities.get(table_row); + + if !F::filter_fetch(&mut filter, *entity, TableRow::new(table_row)) { + // Cannot do a full batch, fallback to scalar. + // Already checked the filter against everything up until now. + // Therefore, do an *unchecked* serial portion. 
+ for p in table_index..table_row { + let row = TableRow::new(p); + let entity = entities.get(row.index()); + let item = Q::fetch(&mut fetch, *entity, row); + func(item); + } + + // Handle the rest after + unbatchable = Some(table_row..table_index + N); + break; + } + } + + if let Some(rest) = unbatchable { + serial_portion(entities, &mut fetch, &mut filter, &mut func, rest); + } else { + // TODO PERF: assume likely/hot path + let row = TableRow::new(table_index); + let entity_batch = entities.get_batch(row.index(), table.entity_count()); + + let batch = + Q::fetch_batched(&mut fetch, entity_batch, row, table.entity_count()); + func_batch(batch); + } + + table_index += N; + } + + //EPILOGUE: + serial_portion( + entities, + &mut fetch, + &mut filter, + &mut func, + batch_end..table.entity_count(), + ); + } + } else { + // TODO: accelerate with batching, but first need to figure out if it's worth trying to batch sparse queries + let archetypes = &world.archetypes; + for archetype_id in &self.matched_archetype_ids { + let archetype = archetypes.get(*archetype_id).debug_checked_unwrap(); + let table = tables.get(archetype.table_id()).debug_checked_unwrap(); + Q::set_archetype(&mut fetch, &self.fetch_state, archetype, table); + F::set_archetype(&mut filter, &self.filter_state, archetype, table); + + let entities = archetype.entities(); + for idx in 0..archetype.len() { + let archetype_entity = entities.get_unchecked(idx); + if !F::filter_fetch( + &mut filter, + archetype_entity.entity(), + archetype_entity.table_row(), + ) { + continue; + } + func(Q::fetch( + &mut fetch, + archetype_entity.entity(), + archetype_entity.table_row(), + )); + } + } + } + } + /// Runs `func` on each query result in parallel for the given [`World`], where the last change and /// the current change tick are given. This is faster than the equivalent /// iter() method, but cannot be chained like a normal [`Iterator`]. diff --git a/crates/bevy_ecs/src/storage/aligned_vec.rs b/crates/bevy_ecs/src/storage/aligned_vec.rs new file mode 100644 index 0000000000000..e748ef27b1b9b --- /dev/null +++ b/crates/bevy_ecs/src/storage/aligned_vec.rs @@ -0,0 +1,250 @@ +use core::alloc::Layout; +use core::borrow::{Borrow, BorrowMut}; +use core::marker::PhantomData; +use core::mem::needs_drop; +use core::ops::{Deref, DerefMut}; + +use core::cmp; +use core::slice::SliceIndex; + +use crate::ptr::OwningPtr; + +use super::blob_vec::BlobVec; + +/// A vector whose internal buffer is aligned to `MAX_SIMD_ALIGNMENT`. +/// Intended to support SIMD use cases. Aligning the data to `MAX_SIMD_ALIGNMENT` +/// allows for best-case alignment on accesses, which helps performance when using batched +/// queries. +/// +/// Used to densely store homogeneous ECS data whose type is known at compile time. +/// Built on `BlobVec`. It is not intended to be a drop-in replacement for Vec at this time. + +/* +NOTE: AlignedVec is ONLY implemented in terms of BlobVec because the Allocator API is not stable yet. +Once the Allocator API is stable, one could easily define AlignedVec as being a Vec with an allocator +that provides MAX_SIMD_ALIGNMENT as a guarantee, and remove almost all of the code in this file: + + type AlignedVec = Vec; + +As it stands, AlignedVec is a stand-in to provide just enough functionality to work for bevy_ecs. 
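+
+For illustration only (the allocator name below is hypothetical, and the `allocator_api` feature is still unstable),
+that future definition might look roughly like:
+
+    struct MaxSimdAlignmentAllocator; // an `Allocator` that over-aligns every allocation to MAX_SIMD_ALIGNMENT
+    type SimdAlignedVec<T> = Vec<T, MaxSimdAlignmentAllocator>;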
+*/ +pub(crate) struct SimdAlignedVec { + vec: BlobVec, + _marker: PhantomData, +} + +impl Default for SimdAlignedVec { + fn default() -> Self { + Self::new() + } +} + +impl std::fmt::Debug for SimdAlignedVec { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AlignedVec") + .field("vec", &self.vec) + .finish() + } +} + +impl SimdAlignedVec { + // SAFETY: The pointer points to a valid value of type `T` and it is safe to drop this value. + unsafe fn drop_ptr(x: OwningPtr<'_>) { + x.drop_as::(); + } + + pub fn with_capacity(capacity: usize) -> SimdAlignedVec { + Self { + // SAFETY: + // `drop` accurately reflects whether the contents of this Vec need to be dropped, and correctly performs the drop operation. + vec: unsafe { + BlobVec::new( + Layout::new::(), + needs_drop::().then_some(Self::drop_ptr as _), + capacity, + ) + }, + _marker: PhantomData, + } + } + + pub fn new() -> SimdAlignedVec { + Self::with_capacity(0) //Ensure a starting power-of-two capacity (for non-ZSTs) + } + + #[inline] + pub fn len(&self) -> usize { + self.vec.len() + } + + #[inline] + pub fn is_empty(&self) -> bool { + self.vec.len() == 0 + } + + #[inline] + pub fn capacity(&self) -> usize { + self.vec.capacity() + } + + /// # Safety + /// It is the caller's responsibility to ensure that `index` is < self.len() + #[inline] + pub unsafe fn get_unchecked(&self, index: usize) -> &>::Output { + debug_assert!(index < self.len()); + + self.vec.get_unchecked(index).deref() + } + + /// # Safety + /// It is the caller's responsibility to ensure that `index` is < self.len() + #[inline] + pub unsafe fn get_unchecked_mut( + &mut self, + index: usize, + ) -> &mut >::Output { + debug_assert!(index < self.len()); + + self.vec.get_unchecked_mut(index).deref_mut() + } + + //This function attempts to keep the same semantics as Vec's swap_remove function + pub fn swap_remove(&mut self, index: usize) -> T { + #[cold] + #[inline(never)] + fn assert_failed(index: usize, len: usize) -> ! { + panic!("swap_remove index (is {index}) should be < len (is {len})"); + } + let len = self.len(); + if index >= len { + assert_failed(index, len); + } + + // SAFETY: + // The index is guaranteed to be in bounds by this point. + unsafe { self.vec.swap_remove_and_forget_unchecked(index).read() } + } + + pub fn push(&mut self, value: T) { + // SAFETY: + // value is a valid owned instance of T, therefore it is safe to call push with it + OwningPtr::make(value, |ptr| unsafe { + self.vec.push(ptr); + }); + } + + pub fn reserve_exact(&mut self, additional: usize) { + self.vec.reserve_exact(additional); + } + + // From RawVec soruce code, for compatibility + const MIN_NON_ZERO_CAP: usize = if core::mem::size_of::() == 1 { + 8 + } else if core::mem::size_of::() <= 1024 { + 4 + } else { + 1 + }; + + //This function attempts to keep the same semantics as Vec's reserve function + pub fn reserve(&mut self, additional: usize) { + if core::mem::size_of::() == 0 { + // Since we return a capacity of `usize::MAX` when `elem_size` is + // 0, getting to here necessarily means the `AlignedVec` is overfull. + panic!("AlignedVec capacity overflow") + } + + // Nothing we can really do about these checks, sadly. + let required_cap = self.vec.len().checked_add(additional); + + if let Some(cap) = required_cap { + // This guarantees exponential growth. The doubling cannot overflow + // because `cap <= isize::MAX` and the type of `cap` is `usize`. 
+ let cap = cmp::max(self.vec.capacity() * 2, cap); + let cap = cmp::max(Self::MIN_NON_ZERO_CAP, cap); + + self.reserve_exact(cap - self.vec.len()); + } else { + panic!("AlignedVec capacity overflow") + } + } + + pub fn clear(&mut self) { + self.vec.clear(); + } +} + +impl Borrow<[T]> for SimdAlignedVec { + fn borrow(&self) -> &[T] { + self + } +} + +impl BorrowMut<[T]> for SimdAlignedVec { + fn borrow_mut(&mut self) -> &mut [T] { + self + } +} + +impl AsRef<[T]> for SimdAlignedVec { + fn as_ref(&self) -> &[T] { + self + } +} + +impl AsMut<[T]> for SimdAlignedVec { + fn as_mut(&mut self) -> &mut [T] { + self + } +} + +impl Deref for SimdAlignedVec { + type Target = [T]; + + #[inline] + fn deref(&self) -> &[T] { + // SAFETY: + // The vector represents an array of T with appropriate alignment. + // The vector is borrowed with an shared reference, guaranteeing only other shared references exist. + // Therefore, it is safe to provide a shared reference to its contents. + unsafe { + std::slice::from_raw_parts(self.vec.get_ptr().as_ptr() as *const T, self.vec.len()) + } + } +} + +impl DerefMut for SimdAlignedVec { + #[inline] + fn deref_mut(&mut self) -> &mut [T] { + // SAFETY: + // The vector represents an array of T with appropriate alignment. + // The vector is borrowed with a mutable reference, guaranteeing uniqueness. + // Therefore, it is safe to provide a mutable reference to its contents. + unsafe { + core::slice::from_raw_parts_mut( + self.vec.get_ptr_mut().as_ptr() as *mut T, + self.vec.len(), + ) + } + } +} + +impl<'a, T> IntoIterator for &'a mut SimdAlignedVec { + type Item = <&'a mut [T] as IntoIterator>::Item; + + type IntoIter = <&'a mut [T] as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.as_mut().iter_mut() + } +} + +impl<'a, T> IntoIterator for &'a SimdAlignedVec { + type Item = <&'a [T] as IntoIterator>::Item; + + type IntoIter = <&'a [T] as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.as_ref().iter() + } +} diff --git a/crates/bevy_ecs/src/storage/blob_vec.rs b/crates/bevy_ecs/src/storage/blob_vec.rs index 49b370b7b8e3a..d5ed7cb073688 100644 --- a/crates/bevy_ecs/src/storage/blob_vec.rs +++ b/crates/bevy_ecs/src/storage/blob_vec.rs @@ -7,6 +7,27 @@ use std::{ use bevy_ptr::{OwningPtr, Ptr, PtrMut}; use bevy_utils::OnDrop; +/// The maximum SIMD alignment for a given target. +/// `MAX_SIMD_ALIGNMENT` is 64 for the following reasons: +/// 1. This ensures that table columns are aligned to cache lines on x86 +/// 2. 64 is the maximum alignment required to use all instructions on all known CPU architectures. +/// This simplifies greatly handling cross platform alignment on a case by case basis; by aligning to the worst case, we align for all cases +/// 3. The overhead of aligning columns to 64 bytes is very small as columns will in general be much larger than this +pub const MAX_SIMD_ALIGNMENT: usize = 64; //Must be greater than zero! + +/* +If this is a problem, this can be replaced with code that looks something like the following: + + #[cfg(all(any(target_feature = "avx"), not(target_feature = "avx512f")))] + pub const MAX_SIMD_ALIGNMENT: usize = 32; + + #[cfg(any(target_feature = "avx512f"))] + pub const MAX_SIMD_ALIGNMENT: usize = 64; + + //All platforms get 16-byte alignment on tables guaranteed. 
+ #[cfg(not(any(target_feature = "avx512f")))] + pub const MAX_SIMD_ALIGNMENT: usize = 16; +*/ /// A flat, type-erased data storage type /// @@ -48,10 +69,9 @@ impl BlobVec { capacity: usize, ) -> BlobVec { let align = NonZeroUsize::new(item_layout.align()).expect("alignment must be > 0"); - let data = bevy_ptr::dangling_with_align(align); if item_layout.size() == 0 { BlobVec { - data, + data: bevy_ptr::dangling_with_align(align), capacity: usize::MAX, len: 0, item_layout, @@ -59,7 +79,9 @@ impl BlobVec { } } else { let mut blob_vec = BlobVec { - data, + data: bevy_ptr::dangling_with_align( + align.max(NonZeroUsize::new_unchecked(MAX_SIMD_ALIGNMENT)), // Want internal array of BlobVec to be at least aligned to MAX_SIMD_ALIGNMENT for performance. + ), capacity: 0, len: 0, item_layout, @@ -328,7 +350,7 @@ impl BlobVec { unsafe { PtrMut::new(self.data) } } - /// Get a reference to the entire [`BlobVec`] as if it were an array with elements of type `T` + /// Get a reference to the entire [`BlobVec`] as if it were an array with elements of type `T`. /// /// # Safety /// The type `T` must be the type of the items in this [`BlobVec`]. @@ -378,6 +400,9 @@ impl Drop for BlobVec { fn array_layout(layout: &Layout, n: usize) -> Option { let (array_layout, offset) = repeat_layout(layout, n)?; debug_assert_eq!(layout.size(), offset); + + //Note: NEEDED for batching. This is the layout of the array itself, not the layout of its elements. + let array_layout = array_layout.align_to(MAX_SIMD_ALIGNMENT).unwrap(); Some(array_layout) } diff --git a/crates/bevy_ecs/src/storage/mod.rs b/crates/bevy_ecs/src/storage/mod.rs index b2fab3fdcb590..085a9a842cc54 100644 --- a/crates/bevy_ecs/src/storage/mod.rs +++ b/crates/bevy_ecs/src/storage/mod.rs @@ -1,5 +1,6 @@ //! Storage layouts for ECS data. +pub(super) mod aligned_vec; mod blob_vec; mod resource; mod sparse_set; diff --git a/crates/bevy_ecs/src/storage/sparse_set.rs b/crates/bevy_ecs/src/storage/sparse_set.rs index d145122c33456..099f3bdbe4bb7 100644 --- a/crates/bevy_ecs/src/storage/sparse_set.rs +++ b/crates/bevy_ecs/src/storage/sparse_set.rs @@ -378,7 +378,7 @@ impl Default for SparseSet { } impl SparseSet { - pub const fn new() -> Self { + pub fn new() -> Self { Self { dense: Vec::new(), indices: Vec::new(), diff --git a/crates/bevy_ecs/src/storage/table.rs b/crates/bevy_ecs/src/storage/table.rs index fe2fd6cd783ba..c7ac0678a7fd3 100644 --- a/crates/bevy_ecs/src/storage/table.rs +++ b/crates/bevy_ecs/src/storage/table.rs @@ -1,10 +1,10 @@ use crate::{ component::{ComponentId, ComponentInfo, ComponentTicks, Components, Tick, TickCells}, entity::Entity, + ptr::{OwningPtr, Ptr, PtrMut, UnsafeCellDeref}, query::DebugCheckedUnwrap, - storage::{blob_vec::BlobVec, ImmutableSparseSet, SparseSet}, + storage::{aligned_vec::SimdAlignedVec, blob_vec::BlobVec, ImmutableSparseSet, SparseSet}, }; -use bevy_ptr::{OwningPtr, Ptr, PtrMut, UnsafeCellDeref}; use bevy_utils::HashMap; use std::alloc::Layout; use std::{ @@ -89,8 +89,8 @@ impl TableRow { #[derive(Debug)] pub struct Column { data: BlobVec, - added_ticks: Vec>, - changed_ticks: Vec>, + added_ticks: SimdAlignedVec>, + changed_ticks: SimdAlignedVec>, } impl Column { @@ -99,8 +99,8 @@ impl Column { Column { // SAFETY: component_info.drop() is valid for the types that will be inserted. 
data: unsafe { BlobVec::new(component_info.layout(), component_info.drop(), capacity) }, - added_ticks: Vec::with_capacity(capacity), - changed_ticks: Vec::with_capacity(capacity), + added_ticks: SimdAlignedVec::with_capacity(capacity), + changed_ticks: SimdAlignedVec::with_capacity(capacity), } } @@ -408,7 +408,7 @@ impl TableBuilder { pub fn build(self) -> Table { Table { columns: self.columns.into_immutable(), - entities: Vec::with_capacity(self.capacity), + entities: SimdAlignedVec::with_capacity(self.capacity), } } } @@ -427,7 +427,7 @@ impl TableBuilder { /// [`World`]: crate::world::World pub struct Table { columns: ImmutableSparseSet, - entities: Vec, + entities: SimdAlignedVec, } impl Table { @@ -612,6 +612,13 @@ impl Table { self.entities.is_empty() } + #[inline] + /// Returns the end index, exclusive, of the batchable region of this table with batch size `N`. + /// For example, if N = 8, and the table has 12 rows, then this would return `8`. + pub fn batchable_region_end(&self) -> usize { + (self.entity_count() / N) * N + } + pub(crate) fn check_change_ticks(&mut self, change_tick: u32) { for column in self.columns.values_mut() { column.check_change_ticks(change_tick); diff --git a/crates/bevy_ecs/src/system/query.rs b/crates/bevy_ecs/src/system/query.rs index 4d3a12ed49bc7..1560339379f9f 100644 --- a/crates/bevy_ecs/src/system/query.rs +++ b/crates/bevy_ecs/src/system/query.rs @@ -2,8 +2,9 @@ use crate::{ component::Component, entity::Entity, query::{ - BatchingStrategy, QueryCombinationIter, QueryEntityError, QueryIter, QueryManyIter, - QueryParIter, QuerySingleError, QueryState, ROQueryItem, ReadOnlyWorldQuery, WorldQuery, + BatchingStrategy, QueryBatch, QueryCombinationIter, QueryEntityError, QueryItem, QueryIter, + QueryManyIter, QueryParIter, QuerySingleError, QueryState, ROQueryBatch, ROQueryItem, + ReadOnlyWorldQuery, WorldQuery, WorldQueryBatch, }, world::{Mut, World}, }; @@ -696,6 +697,28 @@ impl<'w, 's, Q: WorldQuery, F: ReadOnlyWorldQuery> Query<'w, 's, Q, F> { }; } + /// See [`QueryState::for_each_batched`](QueryState::for_each_batched) for how to use this function. + #[inline] + pub fn for_each_batched<'a, const N: usize>( + &'a mut self, + func: impl FnMut(ROQueryItem<'a, Q>), + func_batch: impl FnMut(ROQueryBatch<'a, Q, N>), + ) where + ::ReadOnly: WorldQueryBatch, + { + // SAFETY: system runs without conflicts with other systems. same-system queries have runtime + // borrow checks when they conflict + unsafe { + self.state.as_readonly().for_each_unchecked_manual_batched( + self.world, + func, + func_batch, + self.last_change_tick, + self.change_tick, + ); + }; + } + /// Runs `f` on each query item. /// /// # Example @@ -734,6 +757,271 @@ impl<'w, 's, Q: WorldQuery, F: ReadOnlyWorldQuery> Query<'w, 's, Q, F> { }; } + /// This is a "batched" version of [`for_each_mut`](Self::for_each_mut) that accepts a batch size `N`, which should be a power of two. + /// The advantage of using batching in queries is that it enables SIMD acceleration (vectorization) of your code to help you meet your performance goals. + /// This function accepts two arguments, `func`, and `func_batch` which represent the "scalar" and "vector" (or "batched") paths of your code respectively. + /// Each "batch" contains `N` query results, in order. **Consider enabling AVX if you are on x86 when using this API**. 
+ /// + /// # A very brief introduction to SIMD + /// + /// SIMD, or Single Instruction, Multiple Data, is a paradigm that allows a single instruction to operate on multiple datums in parallel. + /// It is most commonly seen in "vector" instruction set extensions such as AVX and NEON, where it is possible to, for example, add + /// two arrays of `[f32; 4]` together in a single instruction. When used appropriately, SIMD is a very powerful tool that can greatly accelerate certain types of workloads. + /// An introductory treatment of SIMD can be found [on Wikipedia](https://en.wikipedia.org/wiki/Single_instruction,_multiple_data) for interested readers. + /// + /// [Vectorization](https://stackoverflow.com/questions/1422149/what-is-vectorization) is an informal term to describe optimizing code to leverage these SIMD instruction sets. + /// + /// # When should I consider batching for my query? + /// + /// The first thing you should consider is if you are meeting your performance goals. Batching a query is fundamentally an optimization, and if your application is meeting performance requirements + /// already, then (other than for your own entertainment) you won't get much benefit out of batching. If you are having performance problems though, the next step is to + /// use a [profiler](https://nnethercote.github.io/perf-book/profiling.html) to determine the running characteristics of your code. + /// If, after profiling your code, you have determined that a substantial amount of time is being processing a query, and it's hindering your performance goals, + /// then it might be worth it to consider batching to meet them. + /// + /// One of the main tradeoffs with batching your queries is that there will be an increased complexity from maintaining both code paths: `func` and `func_batch` + /// semantically should be doing the same thing, and it should always be possible to interchange them without visible program effects. + /// + /// # Getting maximum performance for your application + /// + /// Bevy aims to provide best-in-class performance for architectures that do not encode alignment into SIMD instructions. This includes (but is not limited to) AVX and onwards for x86, + /// ARM 64-bit, RISC-V, and `WebAssembly`g. The majority of architectures created since 2010 have this property. It is important to note that although these architectures + /// do not encode alignment in the instruction set itself, they still benefit from memory operands being naturally aligned. + /// + /// On other instruction sets, code generation may be worse than it could be due to alignment information not being statically available in batches. + /// For example, 32-bit ARM NEON instructions encode an alignment hint that is not present in the 64-bit ARM versions, + /// and this hint will be set to "unaligned" even if the data itself is aligned. Whether this incurs a performance penalty is implementation defined. + /// + /// **As a result, it is recommended, if you are on x86, to enable at minimum AVX support for maximum performance when executing batched queries**. + /// It's a good idea to enable in general, too. SSE4.2 and below will have slightly worse performance as more unaligned loads will be produced, + /// with work being done in registers, since it requires memory operands to be aligned whereas AVX relaxes this restriction. + /// + /// To enable AVX support for your application, add "-C target-feature=+avx" to your `RUSTFLAGS`. 
See the [Rust docs](https://doc.rust-lang.org/cargo/reference/config.html) + /// for details on how to set this as a default for your project. + /// + /// When the `generic_const_exprs` feature of Rust is stable, Bevy will be able to encode the alignment of the batch into the batch itself and provide maximum performance + /// on architectures that encode alignment into SIMD instruction opcodes as well. + /// + /// # What kinds of queries make sense to batch? + /// + /// Usually math related ones. Anything involving floats is a possible candidate. Depending on your component layout, you may need to perform a data layout conversion + /// to batch the query optimally. This Wikipedia page on ["array of struct" and "struct of array" layouts](https://en.wikipedia.org/wiki/AoS_and_SoA) is a good starter on + /// this topic, as is this [Intel blog post](https://www.intel.com/content/www/us/en/developer/articles/technical/memory-layout-transformations.html). + /// + /// Vectorizing code can be a very deep subject to get into. + /// Sometimes it can be very straightfoward to accomplish what you want to do, and other times it takes a bit of playing around to make your problem fit the SIMD model. + /// + /// # Will batching always make my queries faster? + /// + /// Unfortunately it will not. A suboptimally written batched query will probably perform worse than a straightforward `for_each_mut` query. Data layout conversion, + /// for example, carries overhead that may not always be worth it. Fortunately, your profiler can help you identify these situations. + /// + /// Think of batching as a tool in your performance toolbox rather than the preferred way of writing your queries. + /// + /// # What kinds of queries are batched right now? + /// + /// Currently, only "Dense" queries are actually batched; other queries will only use `func` and never call `func_batch`. This will improve + /// in the future. + /// + /// # Usage: + /// + /// * `N` should be a power of 2, and ideally be a multiple of your SIMD vector size. + /// * `func_batch` receives [`Batch`](bevy_ptr::Batch)es of `N` components. + /// * `func` functions exactly as does in [`for_each_mut`](Self::for_each_mut) -- it receives "scalar" (non-batched) components. + /// + /// In other words, `func_batch` composes the "fast path" of your query, and `func` is the "slow path". + /// + /// In general, when using this function, be mindful of the types of filters being used with your query, as these can fragment your batches + /// and cause the scalar path to be taken more often. + /// + /// **Note**: It is well known that [`array::map`](https://doc.rust-lang.org/std/primitive.array.html#method.map) optimizes poorly at the moment. + /// Avoid using it until the upstream issues are resolved: [#86912](https://github.com/rust-lang/rust/issues/86912) and [#102202](https://github.com/rust-lang/rust/issues/102202). + /// Manually unpack your batches in the meantime for optimal codegen. + /// + /// **Note**: It is always valid for the implementation of this function to only call `func`. Currently, batching is only supported for "Dense" queries. + /// Calling this function on any other query type will result in only the slow path being executed (e.g., queries with Sparse components.) + /// More query types may become batchable in the future. + /// + /// **Note**: Although this function provides the groundwork for writing performance-portable SIMD-accelerated queries, you will still need to take into account + /// your target architecture's capabilities. 
The batch size will likely need to be tuned for your application, for example. + /// When `std::simd` is stabilized in Rust, it will be possible to write code that is generic over the batch width, but some degree of tuning will likely always be + /// necessary. Think of this as a tool at your disposal to meet your performance goals. + /// + /// The following is an example of using batching to accelerate a simplified collision detection system. It is written using x86 AVX intrinsics, since `std::simd` is not stable + /// yet. You can, of course, use `std::simd` in your own code if you prefer, or adapt this example to other instruction sets. + /// ``` + /// # use bevy_ecs::prelude::*; + /// # use bevy_math::Vec3; + /// use core::arch::x86_64::*; + /// + /// #[derive(Component)] + /// struct Position(Vec3); + /// + /// #[derive(Component)] + /// #[repr(transparent)] + /// struct Health(f32); + /// + /// // A plane describing solid geometry, (x,y,z) = n with d such that nx + d = 0 + /// #[derive(Component)] + /// struct Wall(Vec3, f32); + /// + /// // AoS to SoA data layout conversion for x86 AVX. + /// // This code has been adapted from: + /// // https://www.intel.com/content/dam/develop/external/us/en/documents/normvec-181650.pdf + /// #[inline(always)] + /// // This example is written in a way that benefits from inlined data layout conversion. + /// fn aos_to_soa_83(aos_inner: &[Vec3; 8]) -> [__m256; 3] { + /// unsafe { + /// //# SAFETY: Vec3 is repr(C) for x86_64 + /// let mx0 = _mm_loadu_ps((aos_inner as *const Vec3 as *const f32).offset(0)); + /// let mx1 = _mm_loadu_ps((aos_inner as *const Vec3 as *const f32).offset(4)); + /// let mx2 = _mm_loadu_ps((aos_inner as *const Vec3 as *const f32).offset(8)); + /// let mx3 = _mm_loadu_ps((aos_inner as *const Vec3 as *const f32).offset(12)); + /// let mx4 = _mm_loadu_ps((aos_inner as *const Vec3 as *const f32).offset(16)); + /// let mx5 = _mm_loadu_ps((aos_inner as *const Vec3 as *const f32).offset(20)); + /// + /// let mut m03 = _mm256_castps128_ps256(mx0); // load lower halves + /// let mut m14 = _mm256_castps128_ps256(mx1); + /// let mut m25 = _mm256_castps128_ps256(mx2); + /// m03 = _mm256_insertf128_ps(m03, mx3, 1); // load upper halves + /// m14 = _mm256_insertf128_ps(m14, mx4, 1); + /// m25 = _mm256_insertf128_ps(m25, mx5, 1); + /// + /// let xy = _mm256_shuffle_ps::<0b10011110>(m14, m25); // upper x's and y's + /// let yz = _mm256_shuffle_ps::<0b01001001>(m03, m14); // lower y's and z's + /// let x = _mm256_shuffle_ps::<0b10001100>(m03, xy); + /// let y = _mm256_shuffle_ps::<0b11011000>(yz, xy); + /// let z = _mm256_shuffle_ps::<0b11001101>(yz, m25); + /// [x, y, z] + /// } + /// } + /// + /// // Perform collision detection against a set of Walls, forming a convex polygon. + /// // Each entity has a Position and some Health (initialized to 100.0). + /// // If the position of an entity is found to be outside of a Wall, decrement its "health" by 1.0. + /// // The effect is cumulative based on the number of walls. + /// // An entity entirely inside the convex polygon will have its health remain unchanged. + /// fn batched_collision_detection_system(mut pos_healths: Query<(&Position, &mut Health)>, + /// walls: Query<&Wall>) { + /// + /// // Conceptually, this system is executed using two loops: the outer "batched" loop receiving + /// // batches of 8 Positions and Health components at a time, and the inner loop iterating over + /// // the Walls.
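+ /// // Entities that cannot be grouped into a full batch of 8 (for example, leftover entities, or batches fragmented by query filters) are handled by the scalar closure instead.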
+ /// + /// // There's more than one way to vectorize this system -- this example may not be optimal. + /// pos_healths.for_each_mut_batched::<8>( + /// |(position, mut health)| { + /// // This forms the scalar path: it behaves just like `for_each_mut`. + /// + /// // Optional: disable change detection for more performance. + /// let health = &mut health.bypass_change_detection().0; + /// + /// // Test each (Position,Health) against each Wall. + /// walls.for_each(|wall| { + /// let plane = wall.0; + /// + /// // Test which side of the wall we are on + /// let dotproj = plane.dot(position.0); + /// + /// // Test against the Wall's displacement/discriminant value + /// if dotproj < wall.1 { + /// //Ouch! Take damage! + /// *health -= 1.0; + /// } + /// }); + /// }, + /// |(positions, mut healths)| { + /// // This forms the vector path: the closure receives a batch of + /// // 8 Positions and 8 Healths as arrays. + /// + /// // Optional: disable change detection for more performance. + /// let healths = healths.bypass_change_detection(); + /// + /// // Treat the Health batch as a batch of 8 f32s. + /// unsafe { + /// // # SAFETY: Health is repr(transparent)! + /// let healths_raw = healths as *mut Health as *mut f32; + /// let mut healths = _mm256_loadu_ps(healths_raw); + /// + /// // NOTE: array::map optimizes poorly -- it is recommended to unpack your arrays + /// // manually as shown to avoid spurious copies which will impact your performance. + /// let [p0, p1, p2, p3, p4, p5, p6, p7] = positions; + /// + /// // Perform data layout conversion from AoS to SoA. + /// // ps_x will receive all of the X components of the positions, + /// // ps_y will receive all of the Y components + /// // and ps_z will receive all of the Z's. + /// let [ps_x, ps_y, ps_z] = + /// aos_to_soa_83(&[p0.0, p1.0, p2.0, p3.0, p4.0, p5.0, p6.0, p7.0]); + /// + /// // Iterate over each wall without batching. + /// walls.for_each(|wall| { + /// // Test each wall against all 8 positions at once. The "broadcast" intrinsic + /// // helps us achieve this by duplicating the Wall's X coordinate over an entire + /// // vector register, e.g., [X X ... X]. The same goes for the Wall's Y and Z + /// // coordinates. + /// + /// // This is the exact same formula as implemented in the scalar path, but + /// // modified to be calculated in parallel across each lane. + /// + /// // Multiply all of the X coordinates of each Position against Wall's Normal X + /// let xs_dot = _mm256_mul_ps(ps_x, _mm256_broadcast_ss(&wall.0.x)); + /// // Multiply all of the Y coordinates of each Position against Wall's Normal Y + /// let ys_dot = _mm256_mul_ps(ps_y, _mm256_broadcast_ss(&wall.0.y)); + /// // Multiply all of the Z coordinates of each Position against Wall's Normal Z + /// let zs_dot = _mm256_mul_ps(ps_z, _mm256_broadcast_ss(&wall.0.z)); + /// + /// // Now add them together: the result is a vector register containing the dot + /// // product of each Position against the Wall's Normal vector. + /// let dotprojs = _mm256_add_ps(_mm256_add_ps(xs_dot, ys_dot), zs_dot); + /// + /// // Take the Wall's discriminant/displacement value and broadcast it like before. + /// let wall_d = _mm256_broadcast_ss(&wall.1); + /// + /// // Compare each dot product against the Wall's discriminant, using the + /// // "Less Than" relation as we did in the scalar code. + /// // The result will be either -1 or zero *as an integer*.
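+ /// // (A lane where the comparison holds is set to all ones, which is -1 when reinterpreted as a signed 32-bit integer; a lane where it does not hold is all zeros.)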
+ /// let cmp = _mm256_cmp_ps::<_CMP_LT_OS>(dotprojs, wall_d); + /// + /// // Convert the integer values back to f32 values (-1.0 or 0.0). + /// // These form the damage values for each entity. + /// let damages = _mm256_cvtepi32_ps(_mm256_castps_si256(cmp)); //-1.0 or 0.0 + /// + /// // Update the healths of each entity being processed with the results of the + /// // collision detection. + /// healths = _mm256_add_ps(healths, damages); + /// }); + /// // Now that all Walls have been processed, write the final updated Health values + /// // for this batch of entities back to main memory. + /// _mm256_storeu_ps(healths_raw, healths); + /// } + /// }, + /// ); + /// } + /// + /// # bevy_ecs::system::assert_is_system(batched_collision_detection_system); + /// ``` + #[inline] + pub fn for_each_mut_batched<'a, const N: usize>( + &'a mut self, + func: impl FnMut(QueryItem<'a, Q>), + func_batch: impl FnMut(QueryBatch<'a, Q, N>), + ) where + Q: WorldQueryBatch, + { + // SAFETY: system runs without conflicts with other systems. same-system queries have runtime + // borrow checks when they conflict + unsafe { + self.state.for_each_unchecked_manual_batched( + self.world, + func, + func_batch, + self.last_change_tick, + self.change_tick, + ); + }; + } + /// Returns a parallel iterator over the query results for the given [`World`]. /// /// This can only be called for read-only queries, see [`par_iter_mut`] for write-queries. diff --git a/crates/bevy_ecs/src/world/world_cell.rs b/crates/bevy_ecs/src/world/world_cell.rs index 73dabe387c0e3..a2d1cd97d9612 100644 --- a/crates/bevy_ecs/src/world/world_cell.rs +++ b/crates/bevy_ecs/src/world/world_cell.rs @@ -38,7 +38,7 @@ impl Default for ArchetypeComponentAccess { const UNIQUE_ACCESS: usize = 0; const BASE_ACCESS: usize = 1; impl ArchetypeComponentAccess { - const fn new() -> Self { + fn new() -> Self { Self { access: SparseSet::new(), } diff --git a/crates/bevy_ptr/src/lib.rs b/crates/bevy_ptr/src/lib.rs index e992cec575c2b..a8fcb7394b173 100644 --- a/crates/bevy_ptr/src/lib.rs +++ b/crates/bevy_ptr/src/lib.rs @@ -52,6 +52,11 @@ pub struct Ptr<'a, A: IsAligned = Aligned>(NonNull<u8>, PhantomData<(&'a u8, A)> /// the metadata and able to point to data that does not correspond to a Rust type. pub struct PtrMut<'a, A: IsAligned = Aligned>(NonNull<u8>, PhantomData<(&'a mut u8, A)>); +/// An `N`-sized batch of `T` components. The batched query interface makes use of this type. +/// In the future, `Batch` may have additional alignment information when the `generic_const_exprs` +/// language feature of Rust is stable. +pub type Batch<T, const N: usize> = [T; N]; + /// Type-erased Box-like pointer to some unknown type chosen when constructing this type. /// Conceptually represents ownership of whatever data is being pointed to and so is /// responsible for calling its `Drop` impl. This pointer is _not_ responsible for freeing @@ -344,7 +349,7 @@ impl<'a> OwningPtr<'a, Unaligned> { } } -/// Conceptually equivalent to `&'a [T]` but with length information cut out for performance reasons +/// Conceptually equivalent to `&'a [T]` but with length information cut out for performance reasons.
pub struct ThinSlicePtr<'a, T> { ptr: NonNull<T>, #[cfg(debug_assertions)] @@ -353,17 +358,124 @@ pub struct ThinSlicePtr<'a, T> { } impl<'a, T> ThinSlicePtr<'a, T> { + /// # Safety + /// The contents of the slice returned by this function must never be accessed #[inline] + pub unsafe fn dangling() -> Self { + let item_layout = core::alloc::Layout::new::<T>(); + + let dangling = NonNull::new(item_layout.align() as *mut T).unwrap(); + + Self { + ptr: dangling, + #[cfg(debug_assertions)] + len: 0, + _marker: PhantomData, + } + } + /// Indexes the slice without doing bounds checks /// /// # Safety /// `index` must be in-bounds. + #[inline] pub unsafe fn get(self, index: usize) -> &'a T { #[cfg(debug_assertions)] debug_assert!(index < self.len); &*self.ptr.as_ptr().add(index) } + + /// # Safety + /// `index` must be in bounds + /// `index + len` must be in bounds + #[inline] + pub unsafe fn get_slice(self, index: usize, len: usize) -> &'a [T] { + core::slice::from_raw_parts(self.ptr.as_ptr().add(index), len) + } + + /// Indexes the slice without doing bounds checks with a batch size of `N`. + /// + /// # Safety + /// `index` must be in-bounds. + /// `index` must be a multiple of `N`. + #[inline] + unsafe fn get_batch_raw<const N: usize>(self, index: usize, _len: usize) -> *const Batch<T, N> { + #[cfg(debug_assertions)] + debug_assert!(index + N < self.len); + #[cfg(debug_assertions)] + debug_assert_eq!(_len, self.len); + #[cfg(debug_assertions)] + debug_assert_eq!(index % N, 0); + + let off_ptr = self.ptr.as_ptr().add(index); + + // NOTE: ZSTs may cause this "slice" to point into nothingness. + // This sounds dangerous, but won't cause harm as nothing + // will actually access anything "in the slice". + // This is consistent with the semantics of Rust slices. + + // TODO: when pointer_is_aligned is standardized, we can just use ptr::is_aligned() + #[cfg(debug_assertions)] + debug_assert_eq!(off_ptr as usize % core::mem::align_of::<Batch<T, N>>(), 0); + + //SAFETY: off_ptr is not null + off_ptr as *const Batch<T, N> + } + + /// Indexes the slice without doing bounds checks with a batch size of `N`. + /// + /// # Safety + /// `index` must be in-bounds. + #[inline] + pub unsafe fn get_batch<const N: usize>(self, index: usize, len: usize) -> &'a Batch<T, N> { + &(*self.get_batch_raw(index, len)) + } +} + +impl<'a, T> ThinSlicePtr<'a, UnsafeCell<T>> { + /// Indexes the slice without doing bounds checks with a batch size of `N`. + /// The semantics are like `UnsafeCell` -- you must ensure the aliasing constraints are met. + /// + /// # Safety + /// `index` must be in-bounds. + /// `index` must be a multiple of `N`. + /// No other references exist to the batch of size `N` at `index` + #[inline] + pub unsafe fn get_batch_deref_mut<const N: usize>( + self, + index: usize, + len: usize, + ) -> &'a mut Batch<T, N> { + &mut *(self.as_deref().get_batch_raw::<N>(index, len) as *mut Batch<T, N>) + } + + /// Indexes the slice without doing bounds checks with a batch size of `N`. + /// The semantics are like `UnsafeCell` -- you must ensure the aliasing constraints are met. + /// + /// # Safety + /// `index` must be in-bounds. + /// `index` must be a multiple of `N`. + /// No mutable references exist to the batch of size `N` at `index` + #[inline] + pub unsafe fn get_batch_deref<const N: usize>( + self, + index: usize, + len: usize, + ) -> &'a Batch<T, N> { + &*(self.as_deref().get_batch_raw::<N>(index, len)) + } + + /// Get an immutable view of this `ThinSlicePtr`'s contents. Note that this is not a reference type.
+ #[inline] + pub fn as_deref(self) -> ThinSlicePtr<'a, T> { + ThinSlicePtr::<'a, T> { + ptr: self.ptr.cast::<T>(), + #[cfg(debug_assertions)] + len: self.len, + _marker: PhantomData, + } + } } impl<'a, T> Clone for ThinSlicePtr<'a, T> { @@ -399,6 +511,16 @@ pub fn dangling_with_align(align: NonZeroUsize) -> NonNull<u8> { debug_assert!(align.is_power_of_two(), "Alignment must be power of two."); // SAFETY: The pointer will not be null, since it was created // from the address of a `NonZeroUsize`. + + /*NOTE: Dangling pointers still need to be well aligned for the type when using slices (even though they are 0-length). + This is important for [`SimdAlignedVec`] and any function that would return a slice view of this BlobVec. + + Since neither strict_provenance nor alloc_layout_extra is stable, there is no way to construct a NonNull::dangling() + pointer from `item_layout` without using a pointer cast. This requires `-Zmiri-permissive-provenance` when testing, + otherwise Miri will issue a warning. + + TODO: Rewrite this when strict_provenance or alloc_layout_extra is stable. + */ unsafe { NonNull::new_unchecked(align.get() as *mut u8) } }
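To round out the picture of how the pieces in this diff fit together (the `Batch` alias, the `ThinSlicePtr` batch accessors, and `Query::for_each_mut_batched`), here is a minimal usage sketch that stays portable by letting the compiler autovectorize the batch path instead of using intrinsics. It is illustrative only and not part of the diff: the `Falling`/`Altitude` components, the `DT` constant, and the system name are invented for this example, and it assumes the closure item types behave as shown in the rustdoc example above (in particular, that the mutable batch item exposes `bypass_change_detection`).

```
use bevy_ecs::prelude::*;

#[derive(Component)]
#[repr(transparent)]
struct Falling(f32); // downward speed (illustrative)

#[derive(Component)]
#[repr(transparent)]
struct Altitude(f32); // height above ground (illustrative)

// Hypothetical fixed timestep, chosen only for this sketch.
const DT: f32 = 1.0 / 60.0;

// Integrate Altitude from Falling speed, in batches of 8 where possible.
fn integrate_altitude_system(mut query: Query<(&Falling, &mut Altitude)>) {
    query.for_each_mut_batched::<8>(
        // Scalar path: used for entities that cannot be grouped into a full batch.
        |(falling, mut altitude)| {
            altitude.bypass_change_detection().0 -= falling.0 * DT;
        },
        // Batch path: `fallings` and `altitudes` are `Batch`es (`[T; 8]`) of components.
        // A plain indexed loop is enough for the compiler to autovectorize on most targets.
        |(fallings, mut altitudes)| {
            let altitudes = altitudes.bypass_change_detection();
            for i in 0..8 {
                altitudes[i].0 -= fallings[i].0 * DT;
            }
        },
    );
}
```

As in the rustdoc example, `bevy_ecs::system::assert_is_system(integrate_altitude_system)` can be used in a test or doctest to check that this signature is a valid system.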