diff --git a/crates/bevy_ecs/macros/src/component.rs b/crates/bevy_ecs/macros/src/component.rs index f9f1ad1fc67b1..8bd444304ecc7 100644 --- a/crates/bevy_ecs/macros/src/component.rs +++ b/crates/bevy_ecs/macros/src/component.rs @@ -1,7 +1,16 @@ use proc_macro::TokenStream; use proc_macro2::{Span, TokenStream as TokenStream2}; -use quote::quote; -use syn::{parse_macro_input, parse_quote, DeriveInput, ExprPath, Ident, LitStr, Path, Result}; +use quote::{quote, ToTokens}; +use std::collections::HashSet; +use syn::{ + parenthesized, + parse::Parse, + parse_macro_input, parse_quote, + punctuated::Punctuated, + spanned::Spanned, + token::{Comma, Paren}, + DeriveInput, ExprPath, Ident, LitStr, Path, Result, +}; pub fn derive_event(input: TokenStream) -> TokenStream { let mut ast = parse_macro_input!(input as DeriveInput); @@ -66,12 +75,55 @@ pub fn derive_component(input: TokenStream) -> TokenStream { .predicates .push(parse_quote! { Self: Send + Sync + 'static }); + let requires = &attrs.requires; + let mut register_required = Vec::with_capacity(attrs.requires.iter().len()); + let mut register_recursive_requires = Vec::with_capacity(attrs.requires.iter().len()); + if let Some(requires) = requires { + for require in requires { + let ident = &require.path; + register_recursive_requires.push(quote! { + <#ident as Component>::register_required_components(components, storages, required_components); + }); + if let Some(func) = &require.func { + register_required.push(quote! { + required_components.register(components, storages, || { let x: #ident = #func().into(); x }); + }); + } else { + register_required.push(quote! { + required_components.register(components, storages, <#ident as Default>::default); + }); + } + } + } let struct_name = &ast.ident; let (impl_generics, type_generics, where_clause) = &ast.generics.split_for_impl(); + let required_component_docs = attrs.requires.map(|r| { + let paths = r + .iter() + .map(|r| format!("[`{}`]", r.path.to_token_stream())) + .collect::>() + .join(", "); + let doc = format!("Required Components: {paths}. \n\n A component's Required Components are inserted whenever it is inserted. Note that this will also insert the required components _of_ the required components, recursively, in depth-first order."); + quote! { + #[doc = #doc] + } + }); + + // This puts `register_required` before `register_recursive_requires` to ensure that the constructors of _all_ top + // level components are initialized first, giving them precedence over recursively defined constructors for the same component type TokenStream::from(quote! { + #required_component_docs impl #impl_generics #bevy_ecs_path::component::Component for #struct_name #type_generics #where_clause { const STORAGE_TYPE: #bevy_ecs_path::component::StorageType = #storage; + fn register_required_components( + components: &mut #bevy_ecs_path::component::Components, + storages: &mut #bevy_ecs_path::storage::Storages, + required_components: &mut #bevy_ecs_path::component::RequiredComponents + ) { + #(#register_required)* + #(#register_recursive_requires)* + } #[allow(unused_variables)] fn register_component_hooks(hooks: &mut #bevy_ecs_path::component::ComponentHooks) { @@ -86,6 +138,8 @@ pub fn derive_component(input: TokenStream) -> TokenStream { pub const COMPONENT: &str = "component"; pub const STORAGE: &str = "storage"; +pub const REQUIRE: &str = "require"; + pub const ON_ADD: &str = "on_add"; pub const ON_INSERT: &str = "on_insert"; pub const ON_REPLACE: &str = "on_replace"; @@ -93,6 +147,7 @@ pub const ON_REMOVE: &str = "on_remove"; struct Attrs { storage: StorageTy, + requires: Option>, on_add: Option, on_insert: Option, on_replace: Option, @@ -105,6 +160,11 @@ enum StorageTy { SparseSet, } +struct Require { + path: Path, + func: Option, +} + // values for `storage` attribute const TABLE: &str = "Table"; const SPARSE_SET: &str = "SparseSet"; @@ -116,42 +176,77 @@ fn parse_component_attr(ast: &DeriveInput) -> Result { on_insert: None, on_replace: None, on_remove: None, + requires: None, }; - for meta in ast.attrs.iter().filter(|a| a.path().is_ident(COMPONENT)) { - meta.parse_nested_meta(|nested| { - if nested.path.is_ident(STORAGE) { - attrs.storage = match nested.value()?.parse::()?.value() { - s if s == TABLE => StorageTy::Table, - s if s == SPARSE_SET => StorageTy::SparseSet, - s => { - return Err(nested.error(format!( - "Invalid storage type `{s}`, expected '{TABLE}' or '{SPARSE_SET}'.", - ))); - } - }; - Ok(()) - } else if nested.path.is_ident(ON_ADD) { - attrs.on_add = Some(nested.value()?.parse::()?); - Ok(()) - } else if nested.path.is_ident(ON_INSERT) { - attrs.on_insert = Some(nested.value()?.parse::()?); - Ok(()) - } else if nested.path.is_ident(ON_REPLACE) { - attrs.on_replace = Some(nested.value()?.parse::()?); - Ok(()) - } else if nested.path.is_ident(ON_REMOVE) { - attrs.on_remove = Some(nested.value()?.parse::()?); - Ok(()) + let mut require_paths = HashSet::new(); + for attr in ast.attrs.iter() { + if attr.path().is_ident(COMPONENT) { + attr.parse_nested_meta(|nested| { + if nested.path.is_ident(STORAGE) { + attrs.storage = match nested.value()?.parse::()?.value() { + s if s == TABLE => StorageTy::Table, + s if s == SPARSE_SET => StorageTy::SparseSet, + s => { + return Err(nested.error(format!( + "Invalid storage type `{s}`, expected '{TABLE}' or '{SPARSE_SET}'.", + ))); + } + }; + Ok(()) + } else if nested.path.is_ident(ON_ADD) { + attrs.on_add = Some(nested.value()?.parse::()?); + Ok(()) + } else if nested.path.is_ident(ON_INSERT) { + attrs.on_insert = Some(nested.value()?.parse::()?); + Ok(()) + } else if nested.path.is_ident(ON_REPLACE) { + attrs.on_replace = Some(nested.value()?.parse::()?); + Ok(()) + } else if nested.path.is_ident(ON_REMOVE) { + attrs.on_remove = Some(nested.value()?.parse::()?); + Ok(()) + } else { + Err(nested.error("Unsupported attribute")) + } + })?; + } else if attr.path().is_ident(REQUIRE) { + let punctuated = + attr.parse_args_with(Punctuated::::parse_terminated)?; + for require in punctuated.iter() { + if !require_paths.insert(require.path.to_token_stream().to_string()) { + return Err(syn::Error::new( + require.path.span(), + "Duplicate required components are not allowed.", + )); + } + } + if let Some(current) = &mut attrs.requires { + current.extend(punctuated); } else { - Err(nested.error("Unsupported attribute")) + attrs.requires = Some(punctuated); } - })?; + } } Ok(attrs) } +impl Parse for Require { + fn parse(input: syn::parse::ParseStream) -> Result { + let path = input.parse::()?; + let func = if input.peek(Paren) { + let content; + parenthesized!(content in input); + let func = content.parse::()?; + Some(func) + } else { + None + }; + Ok(Require { path, func }) + } +} + fn storage_path(bevy_ecs_path: &Path, ty: StorageTy) -> TokenStream2 { let storage_type = match ty { StorageTy::Table => Ident::new("Table", Span::call_site()), diff --git a/crates/bevy_ecs/macros/src/lib.rs b/crates/bevy_ecs/macros/src/lib.rs index 60e8027756731..b5624cee24bfe 100644 --- a/crates/bevy_ecs/macros/src/lib.rs +++ b/crates/bevy_ecs/macros/src/lib.rs @@ -77,6 +77,7 @@ pub fn derive_bundle(input: TokenStream) -> TokenStream { let mut field_get_component_ids = Vec::new(); let mut field_get_components = Vec::new(); let mut field_from_components = Vec::new(); + let mut field_required_components = Vec::new(); for (((i, field_type), field_kind), field) in field_type .iter() .enumerate() @@ -88,6 +89,9 @@ pub fn derive_bundle(input: TokenStream) -> TokenStream { field_component_ids.push(quote! { <#field_type as #ecs_path::bundle::Bundle>::component_ids(components, storages, &mut *ids); }); + field_required_components.push(quote! { + <#field_type as #ecs_path::bundle::Bundle>::register_required_components(components, storages, required_components); + }); field_get_component_ids.push(quote! { <#field_type as #ecs_path::bundle::Bundle>::get_component_ids(components, &mut *ids); }); @@ -153,6 +157,14 @@ pub fn derive_bundle(input: TokenStream) -> TokenStream { #(#field_from_components)* } } + + fn register_required_components( + components: &mut #ecs_path::component::Components, + storages: &mut #ecs_path::storage::Storages, + required_components: &mut #ecs_path::component::RequiredComponents + ){ + #(#field_required_components)* + } } impl #impl_generics #ecs_path::bundle::DynamicBundle for #struct_name #ty_generics #where_clause { @@ -527,7 +539,7 @@ pub fn derive_resource(input: TokenStream) -> TokenStream { component::derive_resource(input) } -#[proc_macro_derive(Component, attributes(component))] +#[proc_macro_derive(Component, attributes(component, require))] pub fn derive_component(input: TokenStream) -> TokenStream { component::derive_component(input) } diff --git a/crates/bevy_ecs/src/archetype.rs b/crates/bevy_ecs/src/archetype.rs index baafd514db30c..b7cad4e389512 100644 --- a/crates/bevy_ecs/src/archetype.rs +++ b/crates/bevy_ecs/src/archetype.rs @@ -21,7 +21,7 @@ use crate::{ bundle::BundleId, - component::{ComponentId, Components, StorageType}, + component::{ComponentId, Components, RequiredComponentConstructor, StorageType}, entity::{Entity, EntityLocation}, observer::Observers, storage::{ImmutableSparseSet, SparseArray, SparseSet, SparseSetIndex, TableId, TableRow}, @@ -123,10 +123,31 @@ pub(crate) struct AddBundle { /// For each component iterated in the same order as the source [`Bundle`](crate::bundle::Bundle), /// indicate if the component is newly added to the target archetype or if it already existed pub bundle_status: Vec, + /// The set of additional required components that must be initialized immediately when adding this Bundle. + /// + /// The initial values are determined based on the provided constructor, falling back to the `Default` trait if none is given. + pub required_components: Vec, + /// The components added by this bundle. This includes any Required Components that are inserted when adding this bundle. pub added: Vec, + /// The components that were explicitly contributed by this bundle, but already existed in the archetype. This _does not_ include any + /// Required Components. pub existing: Vec, } +impl AddBundle { + pub(crate) fn iter_inserted(&self) -> impl Iterator + '_ { + self.added.iter().chain(self.existing.iter()).copied() + } + + pub(crate) fn iter_added(&self) -> impl Iterator + '_ { + self.added.iter().copied() + } + + pub(crate) fn iter_existing(&self) -> impl Iterator + '_ { + self.existing.iter().copied() + } +} + /// This trait is used to report the status of [`Bundle`](crate::bundle::Bundle) components /// being added to a given entity, relative to that entity's original archetype. /// See [`crate::bundle::BundleInfo::write_components`] for more info. @@ -208,6 +229,7 @@ impl Edges { bundle_id: BundleId, archetype_id: ArchetypeId, bundle_status: Vec, + required_components: Vec, added: Vec, existing: Vec, ) { @@ -216,6 +238,7 @@ impl Edges { AddBundle { archetype_id, bundle_status, + required_components, added, existing, }, diff --git a/crates/bevy_ecs/src/bundle.rs b/crates/bevy_ecs/src/bundle.rs index 6d3f04949a6c7..8311040bb3275 100644 --- a/crates/bevy_ecs/src/bundle.rs +++ b/crates/bevy_ecs/src/bundle.rs @@ -2,8 +2,6 @@ //! //! This module contains the [`Bundle`] trait and some other helper types. -use std::any::TypeId; - pub use bevy_ecs_macros::Bundle; use crate::{ @@ -11,7 +9,10 @@ use crate::{ AddBundle, Archetype, ArchetypeId, Archetypes, BundleComponentStatus, ComponentStatus, SpawnBundleStatus, }, - component::{Component, ComponentId, Components, StorageType, Tick}, + component::{ + Component, ComponentId, Components, RequiredComponentConstructor, RequiredComponents, + StorageType, Tick, + }, entity::{Entities, Entity, EntityLocation}, observer::Observers, prelude::World, @@ -19,12 +20,11 @@ use crate::{ storage::{SparseSetIndex, SparseSets, Storages, Table, TableRow}, world::{unsafe_world_cell::UnsafeWorldCell, ON_ADD, ON_INSERT, ON_REPLACE}, }; - use bevy_ptr::{ConstNonNull, OwningPtr}; use bevy_utils::{all_tuples, HashMap, HashSet, TypeIdMap}; #[cfg(feature = "track_change_detection")] use std::panic::Location; -use std::ptr::NonNull; +use std::{any::TypeId, ptr::NonNull}; /// The `Bundle` trait enables insertion and removal of [`Component`]s from an entity. /// @@ -174,6 +174,13 @@ pub unsafe trait Bundle: DynamicBundle + Send + Sync + 'static { // Ensure that the `OwningPtr` is used correctly F: for<'a> FnMut(&'a mut T) -> OwningPtr<'a>, Self: Sized; + + /// Registers components that are required by the components in this [`Bundle`]. + fn register_required_components( + _components: &mut Components, + _storages: &mut Storages, + _required_components: &mut RequiredComponents, + ); } /// The parts from [`Bundle`] that don't require statically knowing the components of the bundle. @@ -212,6 +219,14 @@ unsafe impl Bundle for C { unsafe { ptr.read() } } + fn register_required_components( + components: &mut Components, + storages: &mut Storages, + required_components: &mut RequiredComponents, + ) { + ::register_required_components(components, storages, required_components); + } + fn get_component_ids(components: &Components, ids: &mut impl FnMut(Option)) { ids(components.get_id(TypeId::of::())); } @@ -255,6 +270,14 @@ macro_rules! tuple_impl { // https://doc.rust-lang.org/reference/expressions.html#evaluation-order-of-operands unsafe { ($(<$name as Bundle>::from_components(ctx, func),)*) } } + + fn register_required_components( + _components: &mut Components, + _storages: &mut Storages, + _required_components: &mut RequiredComponents, + ) { + $(<$name as Bundle>::register_required_components(_components, _storages, _required_components);)* + } } impl<$($name: Bundle),*> DynamicBundle for ($($name,)*) { @@ -321,10 +344,17 @@ pub(crate) enum InsertMode { /// [`World`]: crate::world::World pub struct BundleInfo { id: BundleId, - // SAFETY: Every ID in this list must be valid within the World that owns the BundleInfo, - // must have its storage initialized (i.e. columns created in tables, sparse set created), - // and must be in the same order as the source bundle type writes its components in. + /// The list of all components contributed by the bundle (including Required Components). This is in + /// the order `[EXPLICIT_COMPONENTS][REQUIRED_COMPONENTS]` + /// + /// # Safety + /// Every ID in this list must be valid within the World that owns the [`BundleInfo`], + /// must have its storage initialized (i.e. columns created in tables, sparse set created), + /// and the range (0..`explicit_components_len`) must be in the same order as the source bundle + /// type writes its components in. component_ids: Vec, + required_components: Vec, + explicit_components_len: usize, } impl BundleInfo { @@ -338,7 +368,7 @@ impl BundleInfo { unsafe fn new( bundle_type_name: &'static str, components: &Components, - component_ids: Vec, + mut component_ids: Vec, id: BundleId, ) -> BundleInfo { let mut deduped = component_ids.clone(); @@ -367,11 +397,35 @@ impl BundleInfo { panic!("Bundle {bundle_type_name} has duplicate components: {names}"); } + let explicit_components_len = component_ids.len(); + let mut required_components = RequiredComponents::default(); + for component_id in component_ids.iter().copied() { + // SAFETY: caller has verified that all ids are valid + let info = unsafe { components.get_info_unchecked(component_id) }; + required_components.merge(info.required_components()); + } + required_components.remove_explicit_components(&component_ids); + let required_components = required_components + .0 + .into_iter() + .map(|(component_id, v)| { + // This adds required components to the component_ids list _after_ using that list to remove explicitly provided + // components. This ordering is important! + component_ids.push(component_id); + v + }) + .collect(); + // SAFETY: The caller ensures that component_ids: // - is valid for the associated world // - has had its storage initialized // - is in the same order as the source bundle type - BundleInfo { id, component_ids } + BundleInfo { + id, + component_ids, + required_components, + explicit_components_len, + } } /// Returns a value identifying the associated [`Bundle`] type. @@ -380,16 +434,49 @@ impl BundleInfo { self.id } - /// Returns the [ID](ComponentId) of each component stored in this bundle. + /// Returns the [ID](ComponentId) of each component explicitly defined in this bundle (ex: Required Components are excluded). + /// + /// For all components contributed by this bundle (including Required Components), see [`BundleInfo::contributed_components`] + #[inline] + pub fn explicit_components(&self) -> &[ComponentId] { + &self.component_ids[0..self.explicit_components_len] + } + + /// Returns the [ID](ComponentId) of each Required Component needed by this bundle. This _does not include_ Required Components that are + /// explicitly provided by the bundle. #[inline] - pub fn components(&self) -> &[ComponentId] { + pub fn required_components(&self) -> &[ComponentId] { + &self.component_ids[self.explicit_components_len..] + } + + /// Returns the [ID](ComponentId) of each component contributed by this bundle. This includes Required Components. + /// + /// For only components explicitly defined in this bundle, see [`BundleInfo::explicit_components`] + #[inline] + pub fn contributed_components(&self) -> &[ComponentId] { &self.component_ids } - /// Returns an iterator over the [ID](ComponentId) of each component stored in this bundle. + /// Returns an iterator over the [ID](ComponentId) of each component explicitly defined in this bundle (ex: this excludes Required Components). + + /// To iterate all components contributed by this bundle (including Required Components), see [`BundleInfo::iter_contributed_components`] + #[inline] + pub fn iter_explicit_components(&self) -> impl Iterator + '_ { + self.explicit_components().iter().copied() + } + + /// Returns an iterator over the [ID](ComponentId) of each component contributed by this bundle. This includes Required Components. + /// + /// To iterate only components explicitly defined in this bundle, see [`BundleInfo::iter_explicit_components`] #[inline] - pub fn iter_components(&self) -> impl Iterator + '_ { - self.component_ids.iter().cloned() + pub fn iter_contributed_components(&self) -> impl Iterator + '_ { + self.component_ids.iter().copied() + } + + /// Returns an iterator over the [ID](ComponentId) of each Required Component needed by this bundle. This _does not include_ Required Components that are + /// explicitly provided by the bundle. + pub fn iter_required_components(&self) -> impl Iterator + '_ { + self.required_components().iter().copied() } /// This writes components from a given [`Bundle`] to the given entity. @@ -410,11 +497,12 @@ impl BundleInfo { /// `entity`, `bundle` must match this [`BundleInfo`]'s type #[inline] #[allow(clippy::too_many_arguments)] - unsafe fn write_components( + unsafe fn write_components<'a, T: DynamicBundle, S: BundleComponentStatus>( &self, table: &mut Table, sparse_sets: &mut SparseSets, bundle_component_status: &S, + required_components: impl Iterator, entity: Entity, table_row: TableRow, change_tick: Tick, @@ -475,6 +563,74 @@ impl BundleInfo { } bundle_component += 1; }); + + for required_component in required_components { + required_component.initialize( + table, + sparse_sets, + change_tick, + table_row, + entity, + #[cfg(feature = "track_change_detection")] + caller, + ); + } + } + + /// Internal method to initialize a required component from an [`OwningPtr`]. This should ultimately be called + /// in the context of [`BundleInfo::write_components`], via [`RequiredComponentConstructor::initialize`]. + /// + /// # Safety + /// + /// `component_ptr` must point to a required component value that matches the given `component_id`. The `storage_type` must match + /// the type associated with `component_id`. The `entity` and `table_row` must correspond to an entity with an uninitialized + /// component matching `component_id`. + /// + /// This method _should not_ be called outside of [`BundleInfo::write_components`]. + /// For more information, read the [`BundleInfo::write_components`] safety docs. + /// This function inherits the safety requirements defined there. + #[allow(clippy::too_many_arguments)] + pub(crate) unsafe fn initialize_required_component( + table: &mut Table, + sparse_sets: &mut SparseSets, + change_tick: Tick, + table_row: TableRow, + entity: Entity, + component_id: ComponentId, + storage_type: StorageType, + component_ptr: OwningPtr, + #[cfg(feature = "track_change_detection")] caller: &'static Location<'static>, + ) { + { + match storage_type { + StorageType::Table => { + let column = + // SAFETY: If component_id is in required_components, BundleInfo::new requires that + // the target table contains the component. + unsafe { table.get_column_mut(component_id).debug_checked_unwrap() }; + column.initialize( + table_row, + component_ptr, + change_tick, + #[cfg(feature = "track_change_detection")] + caller, + ); + } + StorageType::SparseSet => { + let sparse_set = + // SAFETY: If component_id is in required_components, BundleInfo::new requires that + // a sparse set exists for the component. + unsafe { sparse_sets.get_mut(component_id).debug_checked_unwrap() }; + sparse_set.insert( + entity, + component_ptr, + change_tick, + #[cfg(feature = "track_change_detection")] + caller, + ); + } + } + } } /// Adds a bundle to the given archetype and returns the resulting archetype. This could be the @@ -495,15 +651,16 @@ impl BundleInfo { } let mut new_table_components = Vec::new(); let mut new_sparse_set_components = Vec::new(); - let mut bundle_status = Vec::with_capacity(self.component_ids.len()); + let mut bundle_status = Vec::with_capacity(self.explicit_components_len); + let mut added_required_components = Vec::new(); let mut added = Vec::new(); - let mut mutated = Vec::new(); + let mut existing = Vec::new(); let current_archetype = &mut archetypes[archetype_id]; - for component_id in self.component_ids.iter().cloned() { + for component_id in self.iter_explicit_components() { if current_archetype.contains(component_id) { bundle_status.push(ComponentStatus::Existing); - mutated.push(component_id); + existing.push(component_id); } else { bundle_status.push(ComponentStatus::Added); added.push(component_id); @@ -516,10 +673,34 @@ impl BundleInfo { } } + for (index, component_id) in self.iter_required_components().enumerate() { + if !current_archetype.contains(component_id) { + added_required_components.push(self.required_components[index].clone()); + added.push(component_id); + // SAFETY: component_id exists + let component_info = unsafe { components.get_info_unchecked(component_id) }; + match component_info.storage_type() { + StorageType::Table => { + new_table_components.push(component_id); + } + StorageType::SparseSet => { + new_sparse_set_components.push(component_id); + } + } + } + } + if new_table_components.is_empty() && new_sparse_set_components.is_empty() { let edges = current_archetype.edges_mut(); // the archetype does not change when we add this bundle - edges.insert_add_bundle(self.id, archetype_id, bundle_status, added, mutated); + edges.insert_add_bundle( + self.id, + archetype_id, + bundle_status, + added_required_components, + added, + existing, + ); archetype_id } else { let table_id; @@ -568,8 +749,9 @@ impl BundleInfo { self.id, new_archetype_id, bundle_status, + added_required_components, added, - mutated, + existing, ); new_archetype_id } @@ -706,7 +888,7 @@ impl<'w> BundleInserter<'w> { location: EntityLocation, bundle: T, insert_mode: InsertMode, - #[cfg(feature = "track_change_detection")] caller: &'static core::panic::Location<'static>, + #[cfg(feature = "track_change_detection")] caller: &'static Location<'static>, ) -> EntityLocation { let bundle_info = self.bundle_info.as_ref(); let add_bundle = self.add_bundle.as_ref(); @@ -720,13 +902,13 @@ impl<'w> BundleInserter<'w> { let mut deferred_world = self.world.into_deferred(); if insert_mode == InsertMode::Replace { - deferred_world.trigger_on_replace( - archetype, - entity, - add_bundle.existing.iter().copied(), - ); + deferred_world.trigger_on_replace(archetype, entity, add_bundle.iter_existing()); if archetype.has_replace_observer() { - deferred_world.trigger_observers(ON_REPLACE, entity, &add_bundle.existing); + deferred_world.trigger_observers( + ON_REPLACE, + entity, + add_bundle.iter_existing(), + ); } } } @@ -747,6 +929,7 @@ impl<'w> BundleInserter<'w> { table, sparse_sets, add_bundle, + add_bundle.required_components.iter(), entity, location.table_row, self.change_tick, @@ -788,6 +971,7 @@ impl<'w> BundleInserter<'w> { table, sparse_sets, add_bundle, + add_bundle.required_components.iter(), entity, result.table_row, self.change_tick, @@ -870,6 +1054,7 @@ impl<'w> BundleInserter<'w> { new_table, sparse_sets, add_bundle, + add_bundle.required_components.iter(), entity, move_result.new_row, self.change_tick, @@ -890,9 +1075,9 @@ impl<'w> BundleInserter<'w> { // SAFETY: All components in the bundle are guaranteed to exist in the World // as they must be initialized before creating the BundleInfo. unsafe { - deferred_world.trigger_on_add(new_archetype, entity, add_bundle.added.iter().cloned()); + deferred_world.trigger_on_add(new_archetype, entity, add_bundle.iter_added()); if new_archetype.has_add_observer() { - deferred_world.trigger_observers(ON_ADD, entity, &add_bundle.added); + deferred_world.trigger_observers(ON_ADD, entity, add_bundle.iter_added()); } match insert_mode { InsertMode::Replace => { @@ -900,13 +1085,13 @@ impl<'w> BundleInserter<'w> { deferred_world.trigger_on_insert( new_archetype, entity, - bundle_info.iter_components(), + add_bundle.iter_inserted(), ); if new_archetype.has_insert_observer() { deferred_world.trigger_observers( ON_INSERT, entity, - bundle_info.components(), + add_bundle.iter_inserted(), ); } } @@ -916,10 +1101,14 @@ impl<'w> BundleInserter<'w> { deferred_world.trigger_on_insert( new_archetype, entity, - add_bundle.added.iter().cloned(), + add_bundle.iter_added(), ); if new_archetype.has_insert_observer() { - deferred_world.trigger_observers(ON_INSERT, entity, &add_bundle.added); + deferred_world.trigger_observers( + ON_INSERT, + entity, + add_bundle.iter_added(), + ); } } } @@ -1017,6 +1206,7 @@ impl<'w> BundleSpawner<'w> { table, sparse_sets, &SpawnBundleStatus, + bundle_info.required_components.iter(), entity, table_row, self.change_tick, @@ -1036,13 +1226,29 @@ impl<'w> BundleSpawner<'w> { // SAFETY: All components in the bundle are guaranteed to exist in the World // as they must be initialized before creating the BundleInfo. unsafe { - deferred_world.trigger_on_add(archetype, entity, bundle_info.iter_components()); + deferred_world.trigger_on_add( + archetype, + entity, + bundle_info.iter_contributed_components(), + ); if archetype.has_add_observer() { - deferred_world.trigger_observers(ON_ADD, entity, bundle_info.components()); + deferred_world.trigger_observers( + ON_ADD, + entity, + bundle_info.iter_contributed_components(), + ); } - deferred_world.trigger_on_insert(archetype, entity, bundle_info.iter_components()); + deferred_world.trigger_on_insert( + archetype, + entity, + bundle_info.iter_contributed_components(), + ); if archetype.has_insert_observer() { - deferred_world.trigger_observers(ON_INSERT, entity, bundle_info.components()); + deferred_world.trigger_observers( + ON_INSERT, + entity, + bundle_info.iter_contributed_components(), + ); } }; @@ -1125,7 +1331,7 @@ impl Bundles { ) -> BundleId { let bundle_infos = &mut self.bundle_infos; let id = *self.bundle_ids.entry(TypeId::of::()).or_insert_with(|| { - let mut component_ids = Vec::new(); + let mut component_ids= Vec::new(); T::component_ids(components, storages, &mut |id| component_ids.push(id)); let id = BundleId(bundle_infos.len()); let bundle_info = diff --git a/crates/bevy_ecs/src/component.rs b/crates/bevy_ecs/src/component.rs index 161db19b9a0e9..53e64e7e2c1dc 100644 --- a/crates/bevy_ecs/src/component.rs +++ b/crates/bevy_ecs/src/component.rs @@ -3,9 +3,10 @@ use crate::{ self as bevy_ecs, archetype::ArchetypeFlags, + bundle::BundleInfo, change_detection::MAX_CHANGE_AGE, entity::Entity, - storage::{SparseSetIndex, Storages}, + storage::{SparseSetIndex, SparseSets, Storages, Table, TableRow}, system::{Local, Resource, SystemParam}, world::{DeferredWorld, FromWorld, World}, }; @@ -13,15 +14,18 @@ pub use bevy_ecs_macros::Component; use bevy_ptr::{OwningPtr, UnsafeCellDeref}; #[cfg(feature = "bevy_reflect")] use bevy_reflect::Reflect; -use bevy_utils::TypeIdMap; -use std::cell::UnsafeCell; +use bevy_utils::{HashMap, TypeIdMap}; +#[cfg(feature = "track_change_detection")] +use std::panic::Location; use std::{ alloc::Layout, any::{Any, TypeId}, borrow::Cow, marker::PhantomData, mem::needs_drop, + sync::Arc, }; +use std::{cell::UnsafeCell, fmt::Debug}; /// A data type that can be used to store data for an [entity]. /// @@ -93,6 +97,141 @@ use std::{ /// [`Table`]: crate::storage::Table /// [`SparseSet`]: crate::storage::SparseSet /// +/// # Required Components +/// +/// Components can specify Required Components. If some [`Component`] `A` requires [`Component`] `B`, then when `A` is inserted, +/// `B` will _also_ be initialized and inserted (if it was not manually specified). +/// +/// The [`Default`] constructor will be used to initialize the component, by default: +/// +/// ``` +/// # use bevy_ecs::prelude::*; +/// #[derive(Component)] +/// #[require(B)] +/// struct A; +/// +/// #[derive(Component, Default, PartialEq, Eq, Debug)] +/// struct B(usize); +/// +/// # let mut world = World::default(); +/// // This will implicitly also insert B with the Default constructor +/// let id = world.spawn(A).id(); +/// assert_eq!(&B(0), world.entity(id).get::().unwrap()); +/// +/// // This will _not_ implicitly insert B, because it was already provided +/// world.spawn((A, B(11))); +/// ``` +/// +/// Components can have more than one required component: +/// +/// ``` +/// # use bevy_ecs::prelude::*; +/// #[derive(Component)] +/// #[require(B, C)] +/// struct A; +/// +/// #[derive(Component, Default, PartialEq, Eq, Debug)] +/// #[require(C)] +/// struct B(usize); +/// +/// #[derive(Component, Default, PartialEq, Eq, Debug)] +/// struct C(u32); +/// +/// # let mut world = World::default(); +/// // This will implicitly also insert B and C with their Default constructors +/// let id = world.spawn(A).id(); +/// assert_eq!(&B(0), world.entity(id).get::().unwrap()); +/// assert_eq!(&C(0), world.entity(id).get::().unwrap()); +/// ``` +/// +/// You can also define a custom constructor: +/// +/// ``` +/// # use bevy_ecs::prelude::*; +/// #[derive(Component)] +/// #[require(B(init_b))] +/// struct A; +/// +/// #[derive(Component, PartialEq, Eq, Debug)] +/// struct B(usize); +/// +/// fn init_b() -> B { +/// B(10) +/// } +/// +/// # let mut world = World::default(); +/// // This will implicitly also insert B with the init_b() constructor +/// let id = world.spawn(A).id(); +/// assert_eq!(&B(10), world.entity(id).get::().unwrap()); +/// ``` +/// +/// Required components are _recursive_. This means, if a Required Component has required components, +/// those components will _also_ be inserted if they are missing: +/// +/// ``` +/// # use bevy_ecs::prelude::*; +/// #[derive(Component)] +/// #[require(B)] +/// struct A; +/// +/// #[derive(Component, Default, PartialEq, Eq, Debug)] +/// #[require(C)] +/// struct B(usize); +/// +/// #[derive(Component, Default, PartialEq, Eq, Debug)] +/// struct C(u32); +/// +/// # let mut world = World::default(); +/// // This will implicitly also insert B and C with their Default constructors +/// let id = world.spawn(A).id(); +/// assert_eq!(&B(0), world.entity(id).get::().unwrap()); +/// assert_eq!(&C(0), world.entity(id).get::().unwrap()); +/// ``` +/// +/// Note that cycles in the "component require tree" will result in stack overflows when attempting to +/// insert a component. +/// +/// This "multiple inheritance" pattern does mean that it is possible to have duplicate requires for a given type +/// at different levels of the inheritance tree: +/// +/// ``` +/// # use bevy_ecs::prelude::*; +/// #[derive(Component)] +/// struct X(usize); +/// +/// #[derive(Component, Default)] +/// #[require(X(x1))] +/// struct Y; +/// +/// fn x1() -> X { +/// X(1) +/// } +/// +/// #[derive(Component)] +/// #[require( +/// Y, +/// X(x2), +/// )] +/// struct Z; +/// +/// fn x2() -> X { +/// X(2) +/// } +/// +/// # let mut world = World::default(); +/// // In this case, the x2 constructor is used for X +/// let id = world.spawn(Z).id(); +/// assert_eq!(2, world.entity(id).get::().unwrap().0); +/// ``` +/// +/// In general, this shouldn't happen often, but when it does the algorithm is simple and predictable: +/// 1. Use all of the constructors (including default constructors) directly defined in the spawned component's require list +/// 2. In the order the requires are defined in `#[require()]`, recursively visit the require list of each of the components in the list (this is a depth Depth First Search). When a constructor is found, it will only be used if one has not already been found. +/// +/// From a user perspective, just think about this as the following: +/// 1. Specifying a required component constructor for Foo directly on a spawned component Bar will result in that constructor being used (and overriding existing constructors lower in the inheritance tree). This is the classic "inheritance override" behavior people expect. +/// 2. For cases where "multiple inheritance" results in constructor clashes, Components should be listed in "importance order". List a component earlier in the requirement list to initialize its inheritance tree earlier. +/// /// # Adding component's hooks /// /// See [`ComponentHooks`] for a detailed explanation of component's hooks. @@ -198,6 +337,14 @@ pub trait Component: Send + Sync + 'static { /// Called when registering this component, allowing mutable access to its [`ComponentHooks`]. fn register_component_hooks(_hooks: &mut ComponentHooks) {} + + /// Registers required components. + fn register_required_components( + _components: &mut Components, + _storages: &mut Storages, + _required_components: &mut RequiredComponents, + ) { + } } /// The storage used for a specific component type. @@ -408,6 +555,7 @@ pub struct ComponentInfo { id: ComponentId, descriptor: ComponentDescriptor, hooks: ComponentHooks, + required_components: RequiredComponents, } impl ComponentInfo { @@ -466,7 +614,8 @@ impl ComponentInfo { ComponentInfo { id, descriptor, - hooks: ComponentHooks::default(), + hooks: Default::default(), + required_components: Default::default(), } } @@ -491,6 +640,12 @@ impl ComponentInfo { pub fn hooks(&self) -> &ComponentHooks { &self.hooks } + + /// Retrieves the [`RequiredComponents`] collection, which contains all required components (and their constructors) + /// needed by this component. This includes _recursive_ required components. + pub fn required_components(&self) -> &RequiredComponents { + &self.required_components + } } /// A value which uniquely identifies the type of a [`Component`] or [`Resource`] within a @@ -570,7 +725,7 @@ pub struct ComponentDescriptor { } // We need to ignore the `drop` field in our `Debug` impl -impl std::fmt::Debug for ComponentDescriptor { +impl Debug for ComponentDescriptor { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("ComponentDescriptor") .field("name", &self.name) @@ -692,22 +847,32 @@ impl Components { /// * [`Components::init_component_with_descriptor()`] #[inline] pub fn init_component(&mut self, storages: &mut Storages) -> ComponentId { - let type_id = TypeId::of::(); - - let Components { - indices, - components, - .. - } = self; - *indices.entry(type_id).or_insert_with(|| { - let index = Components::init_component_inner( + let mut registered = false; + let id = { + let Components { + indices, components, - storages, - ComponentDescriptor::new::(), - ); - T::register_component_hooks(&mut components[index.index()].hooks); - index - }) + .. + } = self; + let type_id = TypeId::of::(); + *indices.entry(type_id).or_insert_with(|| { + let id = Components::init_component_inner( + components, + storages, + ComponentDescriptor::new::(), + ); + registered = true; + id + }) + }; + if registered { + let mut required_components = RequiredComponents::default(); + T::register_required_components(self, storages, &mut required_components); + let info = &mut self.components[id.index()]; + T::register_component_hooks(&mut info.hooks); + info.required_components = required_components; + } + id } /// Initializes a component described by `descriptor`. @@ -1133,3 +1298,150 @@ impl FromWorld for InitComponentId { } } } + +/// A Required Component constructor. See [`Component`] for details. +#[cfg(feature = "track_change_detection")] +#[derive(Clone)] +pub struct RequiredComponentConstructor( + pub Arc)>, +); + +/// A Required Component constructor. See [`Component`] for details. +#[cfg(not(feature = "track_change_detection"))] +#[derive(Clone)] +pub struct RequiredComponentConstructor( + pub Arc, +); + +impl RequiredComponentConstructor { + /// # Safety + /// This is intended to only be called in the context of [`BundleInfo::write_components`] to initialized required components. + /// Calling it _anywhere else_ should be considered unsafe. + /// + /// `table_row` and `entity` must correspond to a valid entity that currently needs a component initialized via the constructor stored + /// on this [`RequiredComponentConstructor`]. The stored constructor must correspond to a component on `entity` that needs initialization. + /// `table` and `sparse_sets` must correspond to storages on a world where `entity` needs this required component initialized. + /// + /// Again, don't call this anywhere but [`BundleInfo::write_components`]. + pub(crate) unsafe fn initialize( + &self, + table: &mut Table, + sparse_sets: &mut SparseSets, + change_tick: Tick, + table_row: TableRow, + entity: Entity, + #[cfg(feature = "track_change_detection")] caller: &'static Location<'static>, + ) { + (self.0)( + table, + sparse_sets, + change_tick, + table_row, + entity, + #[cfg(feature = "track_change_detection")] + caller, + ); + } +} + +/// The collection of metadata for components that are required for a given component. +/// +/// For more information, see the "Required Components" section of [`Component`]. +#[derive(Default, Clone)] +pub struct RequiredComponents(pub(crate) HashMap); + +impl Debug for RequiredComponents { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("RequiredComponents") + .field(&self.0.keys()) + .finish() + } +} + +impl RequiredComponents { + /// Registers a required component. If the component is already registered, the new registration + /// passed in the arguments will be ignored. + /// + /// # Safety + /// + /// `component_id` must match the type initialized by `constructor`. + /// `constructor` _must_ initialize a component for `component_id` in such a way that + /// matches the storage type of the component. It must only use the given `table_row` or `Entity` to + /// initialize the storage for `component_id` corresponding to the given entity. + pub unsafe fn register_dynamic( + &mut self, + component_id: ComponentId, + constructor: RequiredComponentConstructor, + ) { + self.0.entry(component_id).or_insert(constructor); + } + + /// Registers a required component. If the component is already registered, the new registration + /// passed in the arguments will be ignored. + pub fn register( + &mut self, + components: &mut Components, + storages: &mut Storages, + constructor: fn() -> C, + ) { + let component_id = components.init_component::(storages); + let erased: RequiredComponentConstructor = RequiredComponentConstructor(Arc::new( + move |table, + sparse_sets, + change_tick, + table_row, + entity, + #[cfg(feature = "track_change_detection")] caller| { + OwningPtr::make(constructor(), |ptr| { + // SAFETY: This will only be called in the context of `BundleInfo::write_components`, which will + // pass in a valid table_row and entity requiring a C constructor + // C::STORAGE_TYPE is the storage type associated with `component_id` / `C` + // `ptr` points to valid `C` data, which matches the type associated with `component_id` + unsafe { + BundleInfo::initialize_required_component( + table, + sparse_sets, + change_tick, + table_row, + entity, + component_id, + C::STORAGE_TYPE, + ptr, + #[cfg(feature = "track_change_detection")] + caller, + ); + } + }); + }, + )); + // SAFETY: + // `component_id` matches the type initialized by the `erased` constructor above. + // `erased` initializes a component for `component_id` in such a way that + // matches the storage type of the component. It only uses the given `table_row` or `Entity` to + // initialize the storage corresponding to the given entity. + unsafe { self.register_dynamic(component_id, erased) }; + } + + /// Iterates the ids of all required components. This includes recursive required components. + pub fn iter_ids(&self) -> impl Iterator + '_ { + self.0.keys().copied() + } + + /// Removes components that are explicitly provided in a given [`Bundle`]. These components should + /// be logically treated as normal components, not "required components". + /// + /// [`Bundle`]: crate::bundle::Bundle + pub(crate) fn remove_explicit_components(&mut self, components: &[ComponentId]) { + for component in components { + self.0.remove(component); + } + } + + // Merges `required_components` into this collection. This only inserts a required component + // if it _did not already exist_. + pub(crate) fn merge(&mut self, required_components: &RequiredComponents) { + for (id, constructor) in &required_components.0 { + self.0.entry(*id).or_insert_with(|| constructor.clone()); + } + } +} diff --git a/crates/bevy_ecs/src/entity/mod.rs b/crates/bevy_ecs/src/entity/mod.rs index 4250058ed8cda..0c5a2ab939476 100644 --- a/crates/bevy_ecs/src/entity/mod.rs +++ b/crates/bevy_ecs/src/entity/mod.rs @@ -144,7 +144,7 @@ type IdCursor = isize; /// [SemVer]: https://semver.org/ #[derive(Clone, Copy)] #[cfg_attr(feature = "bevy_reflect", derive(Reflect))] -#[cfg_attr(feature = "bevy_reflect", reflect_value(Hash, PartialEq))] +#[cfg_attr(feature = "bevy_reflect", reflect_value(Hash, PartialEq, Debug))] #[cfg_attr( all(feature = "bevy_reflect", feature = "serialize"), reflect_value(Serialize, Deserialize) diff --git a/crates/bevy_ecs/src/lib.rs b/crates/bevy_ecs/src/lib.rs index a08a62053e107..09e86f7711cca 100644 --- a/crates/bevy_ecs/src/lib.rs +++ b/crates/bevy_ecs/src/lib.rs @@ -1793,6 +1793,235 @@ mod tests { ); } + #[test] + fn required_components() { + #[derive(Component)] + #[require(Y)] + struct X; + + #[derive(Component)] + #[require(Z(new_z))] + struct Y { + value: String, + } + + #[derive(Component)] + struct Z(u32); + + impl Default for Y { + fn default() -> Self { + Self { + value: "hello".to_string(), + } + } + } + + fn new_z() -> Z { + Z(7) + } + + let mut world = World::new(); + let id = world.spawn(X).id(); + assert_eq!( + "hello", + world.entity(id).get::().unwrap().value, + "Y should have the default value" + ); + assert_eq!( + 7, + world.entity(id).get::().unwrap().0, + "Z should have the value provided by the constructor defined in Y" + ); + + let id = world + .spawn(( + X, + Y { + value: "foo".to_string(), + }, + )) + .id(); + assert_eq!( + "foo", + world.entity(id).get::().unwrap().value, + "Y should have the manually provided value" + ); + assert_eq!( + 7, + world.entity(id).get::().unwrap().0, + "Z should have the value provided by the constructor defined in Y" + ); + + let id = world.spawn((X, Z(8))).id(); + assert_eq!( + "hello", + world.entity(id).get::().unwrap().value, + "Y should have the default value" + ); + assert_eq!( + 8, + world.entity(id).get::().unwrap().0, + "Z should have the manually provided value" + ); + } + + #[test] + fn generic_required_components() { + #[derive(Component)] + #[require(Y)] + struct X; + + #[derive(Component, Default)] + struct Y { + value: T, + } + + let mut world = World::new(); + let id = world.spawn(X).id(); + assert_eq!( + 0, + world.entity(id).get::>().unwrap().value, + "Y should have the default value" + ); + } + + #[test] + fn required_components_spawn_nonexistent_hooks() { + #[derive(Component)] + #[require(Y)] + struct X; + + #[derive(Component, Default)] + struct Y; + + #[derive(Resource)] + struct A(usize); + + #[derive(Resource)] + struct I(usize); + + let mut world = World::new(); + world.insert_resource(A(0)); + world.insert_resource(I(0)); + world + .register_component_hooks::() + .on_add(|mut world, _, _| world.resource_mut::().0 += 1) + .on_insert(|mut world, _, _| world.resource_mut::().0 += 1); + + // Spawn entity and ensure Y was added + assert!(world.spawn(X).contains::()); + + assert_eq!(world.resource::().0, 1); + assert_eq!(world.resource::().0, 1); + } + + #[test] + fn required_components_insert_existing_hooks() { + #[derive(Component)] + #[require(Y)] + struct X; + + #[derive(Component, Default)] + struct Y; + + #[derive(Resource)] + struct A(usize); + + #[derive(Resource)] + struct I(usize); + + let mut world = World::new(); + world.insert_resource(A(0)); + world.insert_resource(I(0)); + world + .register_component_hooks::() + .on_add(|mut world, _, _| world.resource_mut::().0 += 1) + .on_insert(|mut world, _, _| world.resource_mut::().0 += 1); + + // Spawn entity and ensure Y was added + assert!(world.spawn_empty().insert(X).contains::()); + + assert_eq!(world.resource::().0, 1); + assert_eq!(world.resource::().0, 1); + } + + #[test] + fn required_components_take_leaves_required() { + #[derive(Component)] + #[require(Y)] + struct X; + + #[derive(Component, Default)] + struct Y; + + let mut world = World::new(); + let e = world.spawn(X).id(); + let _ = world.entity_mut(e).take::().unwrap(); + assert!(world.entity_mut(e).contains::()); + } + + #[test] + fn required_components_retain_keeps_required() { + #[derive(Component)] + #[require(Y)] + struct X; + + #[derive(Component, Default)] + struct Y; + + #[derive(Component, Default)] + struct Z; + + let mut world = World::new(); + let e = world.spawn((X, Z)).id(); + world.entity_mut(e).retain::(); + assert!(world.entity_mut(e).contains::()); + assert!(world.entity_mut(e).contains::()); + assert!(!world.entity_mut(e).contains::()); + } + + #[test] + fn required_components_spawn_then_insert_no_overwrite() { + #[derive(Component)] + #[require(Y)] + struct X; + + #[derive(Component, Default)] + struct Y(usize); + + let mut world = World::new(); + let id = world.spawn((X, Y(10))).id(); + world.entity_mut(id).insert(X); + + assert_eq!( + 10, + world.entity(id).get::().unwrap().0, + "Y should still have the manually provided value" + ); + } + + #[test] + fn dynamic_required_components() { + #[derive(Component)] + #[require(Y)] + struct X; + + #[derive(Component, Default)] + struct Y; + + let mut world = World::new(); + let x_id = world.init_component::(); + + let mut e = world.spawn_empty(); + + // SAFETY: x_id is a valid component id + bevy_ptr::OwningPtr::make(X, |ptr| unsafe { + e.insert_by_id(x_id, ptr); + }); + + assert!(e.contains::()); + } + // These structs are primarily compilation tests to test the derive macros. Because they are // never constructed, we have to manually silence the `dead_code` lint. #[allow(dead_code)] diff --git a/crates/bevy_ecs/src/system/mod.rs b/crates/bevy_ecs/src/system/mod.rs index bfb2dcd396cce..4eb8301d9b84a 100644 --- a/crates/bevy_ecs/src/system/mod.rs +++ b/crates/bevy_ecs/src/system/mod.rs @@ -1063,7 +1063,7 @@ mod tests { .get_id(TypeId::of::<(W, W)>()) .expect("Bundle used to spawn entity should exist"); let bundle_info = bundles.get(bundle_id).unwrap(); - let mut bundle_components = bundle_info.components().to_vec(); + let mut bundle_components = bundle_info.contributed_components().to_vec(); bundle_components.sort(); for component_id in &bundle_components { assert!( diff --git a/crates/bevy_ecs/src/world/deferred_world.rs b/crates/bevy_ecs/src/world/deferred_world.rs index 90205b2ed4360..f84b22a407310 100644 --- a/crates/bevy_ecs/src/world/deferred_world.rs +++ b/crates/bevy_ecs/src/world/deferred_world.rs @@ -370,13 +370,13 @@ impl<'w> DeferredWorld<'w> { &mut self, event: ComponentId, entity: Entity, - components: &[ComponentId], + components: impl Iterator, ) { Observers::invoke::<_>( self.reborrow(), event, entity, - components.iter().copied(), + components, &mut (), &mut false, ); diff --git a/crates/bevy_ecs/src/world/entity_ref.rs b/crates/bevy_ecs/src/world/entity_ref.rs index 0aa28241557e7..5adfb2d302021 100644 --- a/crates/bevy_ecs/src/world/entity_ref.rs +++ b/crates/bevy_ecs/src/world/entity_ref.rs @@ -956,7 +956,7 @@ impl<'w> EntityWorldMut<'w> { let removed_components = &mut world.removed_components; let entity = self.entity; - let mut bundle_components = bundle_info.iter_components(); + let mut bundle_components = bundle_info.iter_explicit_components(); // SAFETY: bundle components are iterated in order, which guarantees that the component type // matches let result = unsafe { @@ -1131,7 +1131,7 @@ impl<'w> EntityWorldMut<'w> { } let old_archetype = &world.archetypes[location.archetype_id]; - for component_id in bundle_info.iter_components() { + for component_id in bundle_info.iter_explicit_components() { if old_archetype.contains(component_id) { world.removed_components.send(component_id, entity); @@ -1180,7 +1180,7 @@ impl<'w> EntityWorldMut<'w> { self } - /// Removes any components except those in the [`Bundle`] from the entity. + /// Removes any components except those in the [`Bundle`] (and its Required Components) from the entity. /// /// See [`EntityCommands::retain`](crate::system::EntityCommands::retain) for more details. pub fn retain(&mut self) -> &mut Self { @@ -1194,9 +1194,10 @@ impl<'w> EntityWorldMut<'w> { let old_location = self.location; let old_archetype = &mut archetypes[old_location.archetype_id]; + // PERF: this could be stored in an Archetype Edge let to_remove = &old_archetype .components() - .filter(|c| !retained_bundle_info.components().contains(c)) + .filter(|c| !retained_bundle_info.contributed_components().contains(c)) .collect::>(); let remove_bundle = self.world.bundles.init_dynamic_info(components, to_remove); @@ -1261,19 +1262,11 @@ impl<'w> EntityWorldMut<'w> { unsafe { deferred_world.trigger_on_replace(archetype, self.entity, archetype.components()); if archetype.has_replace_observer() { - deferred_world.trigger_observers( - ON_REPLACE, - self.entity, - &archetype.components().collect::>(), - ); + deferred_world.trigger_observers(ON_REPLACE, self.entity, archetype.components()); } deferred_world.trigger_on_remove(archetype, self.entity, archetype.components()); if archetype.has_remove_observer() { - deferred_world.trigger_observers( - ON_REMOVE, - self.entity, - &archetype.components().collect::>(), - ); + deferred_world.trigger_observers(ON_REMOVE, self.entity, archetype.components()); } } @@ -1484,13 +1477,17 @@ unsafe fn trigger_on_replace_and_on_remove_hooks_and_observers( entity: Entity, bundle_info: &BundleInfo, ) { - deferred_world.trigger_on_replace(archetype, entity, bundle_info.iter_components()); + deferred_world.trigger_on_replace(archetype, entity, bundle_info.iter_explicit_components()); if archetype.has_replace_observer() { - deferred_world.trigger_observers(ON_REPLACE, entity, bundle_info.components()); + deferred_world.trigger_observers( + ON_REPLACE, + entity, + bundle_info.iter_explicit_components(), + ); } - deferred_world.trigger_on_remove(archetype, entity, bundle_info.iter_components()); + deferred_world.trigger_on_remove(archetype, entity, bundle_info.iter_explicit_components()); if archetype.has_remove_observer() { - deferred_world.trigger_observers(ON_REMOVE, entity, bundle_info.components()); + deferred_world.trigger_observers(ON_REMOVE, entity, bundle_info.iter_explicit_components()); } } @@ -2423,7 +2420,7 @@ unsafe fn remove_bundle_from_archetype( let current_archetype = &mut archetypes[archetype_id]; let mut removed_table_components = Vec::new(); let mut removed_sparse_set_components = Vec::new(); - for component_id in bundle_info.components().iter().cloned() { + for component_id in bundle_info.iter_explicit_components() { if current_archetype.contains(component_id) { // SAFETY: bundle components were already initialized by bundles.get_info let component_info = unsafe { components.get_info_unchecked(component_id) };