diff --git a/benches/benches/bevy_ecs/iteration/heavy_compute.rs b/benches/benches/bevy_ecs/iteration/heavy_compute.rs index 9a53092903f48..1692db9ce15d7 100644 --- a/benches/benches/bevy_ecs/iteration/heavy_compute.rs +++ b/benches/benches/bevy_ecs/iteration/heavy_compute.rs @@ -20,7 +20,7 @@ pub fn heavy_compute(c: &mut Criterion) { group.warm_up_time(std::time::Duration::from_millis(500)); group.measurement_time(std::time::Duration::from_secs(4)); group.bench_function("base", |b| { - ComputeTaskPool::get_or_init(TaskPool::default); + ComputeTaskPool::get_or_default(); let mut world = World::default(); diff --git a/benches/benches/bevy_ecs/iteration/par_iter_simple.rs b/benches/benches/bevy_ecs/iteration/par_iter_simple.rs index 76489e33a84a3..89a30ad5652a2 100644 --- a/benches/benches/bevy_ecs/iteration/par_iter_simple.rs +++ b/benches/benches/bevy_ecs/iteration/par_iter_simple.rs @@ -26,7 +26,7 @@ fn insert_if_bit_enabled(entity: &mut EntityWorldMut, i: u16) { impl<'w> Benchmark<'w> { pub fn new(fragment: u16) -> Self { - ComputeTaskPool::get_or_init(TaskPool::default); + ComputeTaskPool::get_or_default(); let mut world = World::new(); diff --git a/crates/bevy_asset/src/processor/mod.rs b/crates/bevy_asset/src/processor/mod.rs index 380b1b2b4fd1a..82d288cc8cef9 100644 --- a/crates/bevy_asset/src/processor/mod.rs +++ b/crates/bevy_asset/src/processor/mod.rs @@ -442,7 +442,7 @@ impl AssetProcessor { #[cfg(all(not(target_arch = "wasm32"), feature = "multi-threaded"))] async fn process_assets_internal<'scope>( &'scope self, - scope: &'scope bevy_tasks::Scope<'scope, '_, ()>, + scope: &'scope bevy_tasks::StaticScope<'scope, '_, ()>, source: &'scope AssetSource, path: PathBuf, ) -> Result<(), AssetReaderError> { diff --git a/crates/bevy_core/src/task_pool_options.rs b/crates/bevy_core/src/task_pool_options.rs index 276902fb499da..7a2a19c3e561b 100644 --- a/crates/bevy_core/src/task_pool_options.rs +++ b/crates/bevy_core/src/task_pool_options.rs @@ -1,5 +1,7 @@ -use bevy_tasks::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool, TaskPoolBuilder}; -use bevy_utils::tracing::trace; +use bevy_tasks::{ + AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool, TaskPoolBuilder, TaskPoolInitializationError, +}; +use bevy_utils::tracing::{trace, warn}; /// Defines a simple way to determine how many threads to use given the number of remaining cores /// and number of total cores @@ -80,6 +82,18 @@ impl Default for TaskPoolOptions { } } +fn handle_initialization_error(name: &str, res: Result<(), TaskPoolInitializationError>) { + match res { + Ok(()) => {} + Err(TaskPoolInitializationError::AlreadyInitialized) => { + warn!("{} already initialized.", name); + } + Err(err) => { + panic!("Error while initializing: {}", err); + } + } +} + impl TaskPoolOptions { /// Create a configuration that forces using the given number of threads. pub fn with_num_threads(thread_count: usize) -> Self { @@ -107,12 +121,14 @@ impl TaskPoolOptions { trace!("IO Threads: {}", io_threads); remaining_threads = remaining_threads.saturating_sub(io_threads); - IoTaskPool::get_or_init(|| { - TaskPoolBuilder::default() - .num_threads(io_threads) - .thread_name("IO Task Pool".to_string()) - .build() - }); + handle_initialization_error( + "IO Task Pool", + IoTaskPool::get().init( + TaskPoolBuilder::default() + .num_threads(io_threads) + .thread_name("IO Task Pool".to_string()), + ), + ); } { @@ -124,12 +140,14 @@ impl TaskPoolOptions { trace!("Async Compute Threads: {}", async_compute_threads); remaining_threads = remaining_threads.saturating_sub(async_compute_threads); - AsyncComputeTaskPool::get_or_init(|| { - TaskPoolBuilder::default() - .num_threads(async_compute_threads) - .thread_name("Async Compute Task Pool".to_string()) - .build() - }); + handle_initialization_error( + "Async Task Pool", + AsyncComputeTaskPool::get().init( + TaskPoolBuilder::default() + .num_threads(async_compute_threads) + .thread_name("Async Compute Task Pool".to_string()), + ), + ); } { @@ -141,12 +159,14 @@ impl TaskPoolOptions { trace!("Compute Threads: {}", compute_threads); - ComputeTaskPool::get_or_init(|| { - TaskPoolBuilder::default() - .num_threads(compute_threads) - .thread_name("Compute Task Pool".to_string()) - .build() - }); + handle_initialization_error( + "Compute Task Pool", + ComputeTaskPool::get().init( + TaskPoolBuilder::default() + .num_threads(compute_threads) + .thread_name("Compute Task Pool".to_string()), + ), + ); } } } diff --git a/crates/bevy_ecs/src/lib.rs b/crates/bevy_ecs/src/lib.rs index 7eb5cf1a3bd1b..f165662487467 100644 --- a/crates/bevy_ecs/src/lib.rs +++ b/crates/bevy_ecs/src/lib.rs @@ -73,7 +73,7 @@ mod tests { system::Resource, world::{EntityRef, Mut, World}, }; - use bevy_tasks::{ComputeTaskPool, TaskPool}; + use bevy_tasks::ComputeTaskPool; use std::num::NonZeroU32; use std::{ any::TypeId, @@ -405,7 +405,7 @@ mod tests { #[test] fn par_for_each_dense() { - ComputeTaskPool::get_or_init(TaskPool::default); + ComputeTaskPool::get_or_default(); let mut world = World::new(); let e1 = world.spawn(A(1)).id(); let e2 = world.spawn(A(2)).id(); @@ -428,7 +428,7 @@ mod tests { #[test] fn par_for_each_sparse() { - ComputeTaskPool::get_or_init(TaskPool::default); + ComputeTaskPool::get_or_default(); let mut world = World::new(); let e1 = world.spawn(SparseStored(1)).id(); let e2 = world.spawn(SparseStored(2)).id(); diff --git a/crates/bevy_ecs/src/query/state.rs b/crates/bevy_ecs/src/query/state.rs index 84fd805fa6ca9..1882ec943d4c9 100644 --- a/crates/bevy_ecs/src/query/state.rs +++ b/crates/bevy_ecs/src/query/state.rs @@ -1312,7 +1312,7 @@ impl QueryState { /// #[derive(Component, PartialEq, Debug)] /// struct A(usize); /// - /// # bevy_tasks::ComputeTaskPool::get_or_init(|| bevy_tasks::TaskPool::new()); + /// # bevy_tasks::ComputeTaskPool::get_or_default(); /// /// let mut world = World::new(); /// diff --git a/crates/bevy_ecs/src/schedule/executor/multi_threaded.rs b/crates/bevy_ecs/src/schedule/executor/multi_threaded.rs index fa9a19058081e..cd516f80dc11c 100644 --- a/crates/bevy_ecs/src/schedule/executor/multi_threaded.rs +++ b/crates/bevy_ecs/src/schedule/executor/multi_threaded.rs @@ -3,7 +3,7 @@ use std::{ sync::{Arc, Mutex, MutexGuard}, }; -use bevy_tasks::{ComputeTaskPool, Scope, TaskPool, ThreadExecutor}; +use bevy_tasks::{ComputeTaskPool, StaticScope, TaskPool, ThreadExecutor}; use bevy_utils::default; use bevy_utils::syncunsafecell::SyncUnsafeCell; #[cfg(feature = "trace")] @@ -132,7 +132,7 @@ pub struct ExecutorState { #[derive(Copy, Clone)] struct Context<'scope, 'env, 'sys> { environment: &'env Environment<'env, 'sys>, - scope: &'scope Scope<'scope, 'env, ()>, + scope: &'scope StaticScope<'scope, 'env, ()>, } impl Default for MultiThreadedExecutor { @@ -218,17 +218,13 @@ impl SystemExecutor for MultiThreadedExecutor { let environment = &Environment::new(self, schedule, world); - ComputeTaskPool::get_or_init(TaskPool::default).scope_with_executor( - false, - thread_executor, - |scope| { - let context = Context { environment, scope }; + ComputeTaskPool::get().scope_with_executor(false, thread_executor, |scope| { + let context = Context { environment, scope }; - // The first tick won't need to process finished systems, but we still need to run the loop in - // tick_executor() in case a system completes while the first tick still holds the mutex. - context.tick_executor(); - }, - ); + // The first tick won't need to process finished systems, but we still need to run the loop in + // tick_executor() in case a system completes while the first tick still holds the mutex. + context.tick_executor(); + }); // End the borrows of self and world in environment by copying out the reference to systems. let systems = environment.systems; diff --git a/crates/bevy_ecs/src/schedule/mod.rs b/crates/bevy_ecs/src/schedule/mod.rs index b38f7adb67923..0cab0eca6c48b 100644 --- a/crates/bevy_ecs/src/schedule/mod.rs +++ b/crates/bevy_ecs/src/schedule/mod.rs @@ -100,12 +100,12 @@ mod tests { #[test] #[cfg(not(miri))] fn parallel_execution() { - use bevy_tasks::{ComputeTaskPool, TaskPool}; + use bevy_tasks::ComputeTaskPool; use std::sync::{Arc, Barrier}; let mut world = World::default(); let mut schedule = Schedule::default(); - let thread_count = ComputeTaskPool::get_or_init(TaskPool::default).thread_num(); + let thread_count = ComputeTaskPool::get_or_default().thread_num(); let barrier = Arc::new(Barrier::new(thread_count)); diff --git a/crates/bevy_tasks/Cargo.toml b/crates/bevy_tasks/Cargo.toml index 98c4edbb8d49d..f9c344ea9ea1a 100644 --- a/crates/bevy_tasks/Cargo.toml +++ b/crates/bevy_tasks/Cargo.toml @@ -10,14 +10,19 @@ keywords = ["bevy"] [features] multi-threaded = ["dep:async-channel", "dep:async-task", "dep:concurrent-queue"] +trace = ["tracing"] [dependencies] futures-lite = "2.0.1" -async-executor = "1.7.2" +async-executor = { git = "https://github.com/james7132/async-executor", branch = "leaked-executor", features = [ + "static", +] } async-channel = { version = "2.2.0", optional = true } async-io = { version = "2.0.0", optional = true } async-task = { version = "4.2.0", optional = true } concurrent-queue = { version = "2.0.0", optional = true } +tracing = { version = "0.1", optional = true } +thiserror = "1.0" [target.'cfg(target_arch = "wasm32")'.dependencies] wasm-bindgen-futures = "0.4" diff --git a/crates/bevy_tasks/src/lib.rs b/crates/bevy_tasks/src/lib.rs index 34011532d6b96..d8344c51db82d 100644 --- a/crates/bevy_tasks/src/lib.rs +++ b/crates/bevy_tasks/src/lib.rs @@ -11,15 +11,21 @@ pub use slice::{ParallelSlice, ParallelSliceMut}; mod task; pub use task::Task; +#[cfg(all(not(target_arch = "wasm32"), feature = "multi-threaded"))] +mod static_task_pool; #[cfg(all(not(target_arch = "wasm32"), feature = "multi-threaded"))] mod task_pool; #[cfg(all(not(target_arch = "wasm32"), feature = "multi-threaded"))] +pub use static_task_pool::{StaticScope, StaticTaskPool}; +#[cfg(all(not(target_arch = "wasm32"), feature = "multi-threaded"))] pub use task_pool::{Scope, TaskPool, TaskPoolBuilder}; #[cfg(any(target_arch = "wasm32", not(feature = "multi-threaded")))] mod single_threaded_task_pool; #[cfg(any(target_arch = "wasm32", not(feature = "multi-threaded")))] -pub use single_threaded_task_pool::{FakeTask, Scope, TaskPool, TaskPoolBuilder, ThreadExecutor}; +pub use single_threaded_task_pool::{ + FakeTask, Scope, StaticScope, StaticTaskPool, TaskPool, TaskPoolBuilder, ThreadExecutor, +}; mod usages; #[cfg(not(target_arch = "wasm32"))] @@ -41,6 +47,7 @@ mod iter; pub use iter::ParallelIterator; pub use futures_lite; +use thiserror::Error; #[allow(missing_docs)] pub mod prelude { @@ -55,6 +62,20 @@ pub mod prelude { use std::num::NonZeroUsize; +/// Potential errors when initializing a [`StaticTaskPool`]. +#[derive(Error, Debug)] +pub enum TaskPoolInitializationError { + /// The task pool was already initialized and cannot be changed after initialization. + #[error("The task pool is already initialized.")] + AlreadyInitialized, + /// The task pool would have been initialized with zero threads. + #[error("The task pool would have been initialized with zero threads.")] + ZeroThreads, + /// Initialization failed to spawn a thread. + #[error("Failed to spawn thread: {0:?}")] + ThreadSpawnError(#[from] std::io::Error), +} + /// Gets the logical CPU core count available to the current process. /// /// This is identical to [`std::thread::available_parallelism`], except diff --git a/crates/bevy_tasks/src/single_threaded_task_pool.rs b/crates/bevy_tasks/src/single_threaded_task_pool.rs index 3a32c9e286211..0d5644287ad2a 100644 --- a/crates/bevy_tasks/src/single_threaded_task_pool.rs +++ b/crates/bevy_tasks/src/single_threaded_task_pool.rs @@ -1,3 +1,4 @@ +use crate::TaskPoolInitializationError; use std::sync::Arc; use std::{cell::RefCell, future::Future, marker::PhantomData, mem, rc::Rc}; @@ -5,6 +6,14 @@ thread_local! { static LOCAL_EXECUTOR: async_executor::LocalExecutor<'static> = async_executor::LocalExecutor::new(); } +/// A [`TaskPool`] optimized for use in static variables. +pub type StaticTaskPool = TaskPool; + +/// A [`StaticTaskPool`] scope for running one or more non-`'static` futures. +/// +/// For more information, see [`TaskPool::scope`]. +pub type StaticScope<'scope, 'env, T> = Scope<'scope, 'env, T>; + /// Used to create a [`TaskPool`]. #[derive(Debug, Default, Clone)] pub struct TaskPoolBuilder {} @@ -25,8 +34,8 @@ impl<'a> ThreadExecutor<'a> { impl TaskPoolBuilder { /// Creates a new `TaskPoolBuilder` instance - pub fn new() -> Self { - Self::default() + pub const fn new() -> Self { + Self {} } /// No op on the single threaded task pool @@ -45,7 +54,7 @@ impl TaskPoolBuilder { } /// Creates a new [`TaskPool`] - pub fn build(self) -> TaskPool { + pub const fn build(self) -> TaskPool { TaskPool::new_internal() } } @@ -62,15 +71,28 @@ impl TaskPool { } /// Create a `TaskPool` with the default configuration. - pub fn new() -> Self { + pub const fn new() -> Self { TaskPoolBuilder::new().build() } #[allow(unused_variables)] - fn new_internal() -> Self { + const fn new_internal() -> Self { Self {} } + /// Checks if the threads in the task pool have been started or not. This always returns + /// true in single threaded builds. + pub fn is_initialized(&self) -> bool { + true + } + + /// Initializes the task pool with the provided builder. This always is a no-op + /// true in single threaded builds. + #[allow(unused_variables)] + pub fn init(&self, builder: TaskPoolBuilder) -> Result<(), TaskPoolInitializationError> { + Ok(()) + } + /// Return the number of threads owned by the task pool pub fn thread_num(&self) -> usize { 1 diff --git a/crates/bevy_tasks/src/static_task_pool.rs b/crates/bevy_tasks/src/static_task_pool.rs new file mode 100644 index 0000000000000..0da0f482ff20f --- /dev/null +++ b/crates/bevy_tasks/src/static_task_pool.rs @@ -0,0 +1,574 @@ +use crate::{ + block_on, Task, TaskPool, TaskPoolBuilder, TaskPoolInitializationError, ThreadExecutor, + ThreadExecutorTicker, +}; + +use async_executor::StaticExecutor; +use async_task::FallibleTask; +use concurrent_queue::ConcurrentQueue; +use futures_lite::FutureExt; + +use std::future::Future; +use std::marker::PhantomData; +use std::panic::AssertUnwindSafe; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Mutex, +}; +use std::thread::JoinHandle; + +/// A [`TaskPool`] optimized for use in `static` variables. +#[derive(Debug)] +pub struct StaticTaskPool { + executor: StaticExecutor, + threads: Mutex>>, + thread_count: AtomicUsize, +} + +impl StaticTaskPool { + #[allow(clippy::new_without_default)] + pub(crate) const fn new() -> Self { + Self { + executor: StaticExecutor::new(), + threads: Mutex::new(Vec::new()), + thread_count: AtomicUsize::new(0), + } + } + + /// The number of threads active in the task pool. + pub fn thread_num(&self) -> usize { + self.thread_count.load(Ordering::Relaxed) + } + + /// Checks if the threads for the task pool have been started. + pub fn is_initialized(&self) -> bool { + self.thread_num() > 0 + } + + /// Initializes the task pool with the `builder`. + /// + /// Retuns an error if the task pool was already initialized or provided a builder + /// that yields zero threads. + pub fn init( + &'static self, + builder: TaskPoolBuilder, + ) -> Result<(), TaskPoolInitializationError> { + let mut join_handles = self.threads.lock().unwrap(); + + if !join_handles.is_empty() { + // TODO: figure out a way to support reconfiguring/reinitializing StaticTaskPools. + return Err(TaskPoolInitializationError::AlreadyInitialized); + } + + let num_threads = builder + .num_threads + .unwrap_or_else(crate::available_parallelism); + + if num_threads == 0 { + return Err(TaskPoolInitializationError::ZeroThreads); + } + + *join_handles = Vec::with_capacity(num_threads); + for i in 0..num_threads { + let thread_name = if let Some(thread_name) = builder.thread_name.as_deref() { + format!("{thread_name} ({i})") + } else { + format!("TaskPool ({i})") + }; + let mut thread_builder = std::thread::Builder::new().name(thread_name); + + if let Some(stack_size) = builder.stack_size { + thread_builder = thread_builder.stack_size(stack_size); + } + + let on_thread_spawn = builder.on_thread_spawn.clone(); + + let res = thread_builder.spawn(move || { + TaskPool::LOCAL_EXECUTOR.with(|local_executor| { + if let Some(on_thread_spawn) = on_thread_spawn { + on_thread_spawn(); + drop(on_thread_spawn); + } + loop { + let res = std::panic::catch_unwind(|| { + let tick_forever = async move { + loop { + local_executor.tick().await; + } + }; + block_on(self.executor.run(tick_forever)); + }); + if res.is_ok() { + break; + } + } + }); + }); + match res { + Ok(join_handle) => { + join_handles.push(join_handle); + } + Err(join_handle) => { + *join_handles = Vec::new(); + return Err(join_handle.into()); + } + } + } + self.thread_count.store(num_threads, Ordering::Relaxed); + Ok(()) + } + + /// Allows spawning non-`'static` futures on the thread pool. The function takes a callback, + /// passing a scope object into it. The scope object provided to the callback can be used + /// to spawn tasks. This function will await the completion of all tasks before returning. + /// + /// This is similar to [`thread::scope`] and `rayon::scope`. + /// + /// # Example + /// + /// ``` + /// use bevy_tasks::TaskPool; + /// + /// let pool = TaskPool::new(); + /// let mut x = 0; + /// let results = pool.scope(|s| { + /// s.spawn(async { + /// // you can borrow the spawner inside a task and spawn tasks from within the task + /// s.spawn(async { + /// // borrow x and mutate it. + /// x = 2; + /// // return a value from the task + /// 1 + /// }); + /// // return some other value from the first task + /// 0 + /// }); + /// }); + /// + /// // The ordering of results is non-deterministic if you spawn from within tasks as above. + /// // If you're doing this, you'll have to write your code to not depend on the ordering. + /// assert!(results.contains(&0)); + /// assert!(results.contains(&1)); + /// + /// // The ordering is deterministic if you only spawn directly from the closure function. + /// let results = pool.scope(|s| { + /// s.spawn(async { 0 }); + /// s.spawn(async { 1 }); + /// }); + /// assert_eq!(&results[..], &[0, 1]); + /// + /// // You can access x after scope runs, since it was only temporarily borrowed in the scope. + /// assert_eq!(x, 2); + /// ``` + /// + /// # Lifetimes + /// + /// The [`Scope`] object takes two lifetimes: `'scope` and `'env`. + /// + /// The `'scope` lifetime represents the lifetime of the scope. That is the time during + /// which the provided closure and tasks that are spawned into the scope are run. + /// + /// The `'env` lifetime represents the lifetime of whatever is borrowed by the scope. + /// Thus this lifetime must outlive `'scope`. + /// + /// ```compile_fail + /// use bevy_tasks::TaskPool; + /// fn scope_escapes_closure() { + /// let pool = TaskPool::new(); + /// let foo = Box::new(42); + /// pool.scope(|scope| { + /// std::thread::spawn(move || { + /// // UB. This could spawn on the scope after `.scope` returns and the internal Scope is dropped. + /// scope.spawn(async move { + /// assert_eq!(*foo, 42); + /// }); + /// }); + /// }); + /// } + /// ``` + /// + /// ```compile_fail + /// use bevy_tasks::TaskPool; + /// fn cannot_borrow_from_closure() { + /// let pool = TaskPool::new(); + /// pool.scope(|scope| { + /// let x = 1; + /// let y = &x; + /// scope.spawn(async move { + /// assert_eq!(*y, 1); + /// }); + /// }); + /// } + pub fn scope<'env, F, T>(&'static self, f: F) -> Vec + where + F: for<'scope> FnOnce(&'scope StaticScope<'scope, 'env, T>), + T: Send + 'static, + { + TaskPool::THREAD_EXECUTOR.with(|scope_executor| { + self.scope_with_executor_inner(true, scope_executor, scope_executor, f) + }) + } + + /// This allows passing an external executor to spawn tasks on. When you pass an external executor + /// [`Scope::spawn_on_scope`] spawns is then run on the thread that [`ThreadExecutor`] is being ticked on. + /// If [`None`] is passed the scope will use a [`ThreadExecutor`] that is ticked on the current thread. + /// + /// When `tick_task_pool_executor` is set to `true`, the multithreaded task stealing executor is ticked on the scope + /// thread. Disabling this can be useful when finishing the scope is latency sensitive. Pulling tasks from + /// global executor can run tasks unrelated to the scope and delay when the scope returns. + /// + /// See [`Self::scope`] for more details in general about how scopes work. + pub fn scope_with_executor<'env, F, T>( + &'static self, + tick_task_pool_executor: bool, + external_executor: Option<&ThreadExecutor>, + f: F, + ) -> Vec + where + F: for<'scope> FnOnce(&'scope StaticScope<'scope, 'env, T>), + T: Send + 'static, + { + TaskPool::THREAD_EXECUTOR.with(|scope_executor| { + // If a `external_executor` is passed use that. Otherwise get the executor stored + // in the `THREAD_EXECUTOR` thread local. + if let Some(external_executor) = external_executor { + self.scope_with_executor_inner( + tick_task_pool_executor, + external_executor, + scope_executor, + f, + ) + } else { + self.scope_with_executor_inner( + tick_task_pool_executor, + scope_executor, + scope_executor, + f, + ) + } + }) + } + + #[allow(unsafe_code)] + fn scope_with_executor_inner<'env, F, T>( + &'static self, + tick_task_pool_executor: bool, + external_executor: &ThreadExecutor, + scope_executor: &ThreadExecutor, + f: F, + ) -> Vec + where + F: for<'scope> FnOnce(&'scope StaticScope<'scope, 'env, T>), + T: Send + 'static, + { + // SAFETY: This safety comment applies to all references transmuted to 'env. + // Any futures spawned with these references need to return before this function completes. + // This is guaranteed because we drive all the futures spawned onto the Scope + // to completion in this function. However, rust has no way of knowing this so we + // transmute the lifetimes to 'env here to appease the compiler as it is unable to validate safety. + // Any usages of the references passed into `Scope` must be accessed through + // the transmuted reference for the rest of this function. + // SAFETY: As above, all futures must complete in this function so we can change the lifetime + let external_executor: &'env ThreadExecutor<'env> = + unsafe { std::mem::transmute(external_executor) }; + // SAFETY: As above, all futures must complete in this function so we can change the lifetime + let scope_executor: &'env ThreadExecutor<'env> = + unsafe { std::mem::transmute(scope_executor) }; + let spawned: ConcurrentQueue>>> = + ConcurrentQueue::unbounded(); + // shadow the variable so that the owned value cannot be used for the rest of the function + // SAFETY: As above, all futures must complete in this function so we can change the lifetime + let spawned: &'env ConcurrentQueue< + FallibleTask>>, + > = unsafe { std::mem::transmute(&spawned) }; + + let scope = StaticScope { + executor: &self.executor, + external_executor, + scope_executor, + spawned, + scope: PhantomData, + env: PhantomData, + }; + + // shadow the variable so that the owned value cannot be used for the rest of the function + // SAFETY: As above, all futures must complete in this function so we can change the lifetime + let scope: &'env StaticScope<'_, 'env, T> = unsafe { std::mem::transmute(&scope) }; + + f(scope); + + if spawned.is_empty() { + Vec::new() + } else { + block_on(async move { + let get_results = async { + let mut results = Vec::with_capacity(spawned.len()); + while let Ok(task) = spawned.pop() { + if let Some(res) = task.await { + match res { + Ok(res) => results.push(res), + Err(payload) => std::panic::resume_unwind(payload), + } + } else { + panic!("Failed to catch panic!"); + } + } + results + }; + + let tick_task_pool_executor = tick_task_pool_executor || self.thread_num() == 0; + + // we get this from a thread local so we should always be on the scope executors thread. + // note: it is possible `scope_executor` and `external_executor` is the same executor, + // in that case, we should only tick one of them, otherwise, it may cause deadlock. + let scope_ticker = scope_executor.ticker().unwrap(); + let external_ticker = if !external_executor.is_same(scope_executor) { + external_executor.ticker() + } else { + None + }; + + match (external_ticker, tick_task_pool_executor) { + (Some(external_ticker), true) => { + Self::execute_global_external_scope( + &self.executor, + external_ticker, + scope_ticker, + get_results, + ) + .await + } + (Some(external_ticker), false) => { + Self::execute_external_scope(external_ticker, scope_ticker, get_results) + .await + } + // either external_executor is none or it is same as scope_executor + (None, true) => { + Self::execute_global_scope(&self.executor, scope_ticker, get_results).await + } + (None, false) => Self::execute_scope(scope_ticker, get_results).await, + } + }) + } + } + + #[inline] + async fn execute_global_external_scope<'scope, 'ticker, T>( + executor: &'static StaticExecutor, + external_ticker: ThreadExecutorTicker<'scope, 'ticker>, + scope_ticker: ThreadExecutorTicker<'scope, 'ticker>, + get_results: impl Future>, + ) -> Vec { + // we restart the executors if a task errors. if a scoped + // task errors it will panic the scope on the call to get_results + let execute_forever = async move { + loop { + let tick_forever = async { + loop { + external_ticker.tick().or(scope_ticker.tick()).await; + } + }; + // we don't care if it errors. If a scoped task errors it will propagate + // to get_results + let _result = AssertUnwindSafe(executor.run(tick_forever)) + .catch_unwind() + .await + .is_ok(); + } + }; + execute_forever.or(get_results).await + } + + #[inline] + async fn execute_external_scope<'scope, 'ticker, T>( + external_ticker: ThreadExecutorTicker<'scope, 'ticker>, + scope_ticker: ThreadExecutorTicker<'scope, 'ticker>, + get_results: impl Future>, + ) -> Vec { + let execute_forever = async { + loop { + let tick_forever = async { + loop { + external_ticker.tick().or(scope_ticker.tick()).await; + } + }; + let _result = AssertUnwindSafe(tick_forever).catch_unwind().await.is_ok(); + } + }; + execute_forever.or(get_results).await + } + + #[inline] + async fn execute_global_scope<'scope, 'ticker, T>( + executor: &'static StaticExecutor, + scope_ticker: ThreadExecutorTicker<'scope, 'ticker>, + get_results: impl Future>, + ) -> Vec { + let execute_forever = async { + loop { + let tick_forever = async { + loop { + scope_ticker.tick().await; + } + }; + let _result = AssertUnwindSafe(executor.run(tick_forever)) + .catch_unwind() + .await + .is_ok(); + } + }; + execute_forever.or(get_results).await + } + + #[inline] + async fn execute_scope<'scope, 'ticker, T>( + scope_ticker: ThreadExecutorTicker<'scope, 'ticker>, + get_results: impl Future>, + ) -> Vec { + let execute_forever = async { + loop { + let tick_forever = async { + loop { + scope_ticker.tick().await; + } + }; + let _result = AssertUnwindSafe(tick_forever).catch_unwind().await.is_ok(); + } + }; + execute_forever.or(get_results).await + } + + /// Spawns a static future onto the thread pool. The returned [`Task`] is a + /// future that can be polled for the result. It can also be canceled and + /// "detached", allowing the task to continue running even if dropped. In + /// any case, the pool will execute the task even without polling by the + /// end-user. + /// + /// If the provided future is non-`Send`, [`TaskPool::spawn_local`] should + /// be used instead. + pub fn spawn(&'static self, future: impl Future + Send + 'static) -> Task + where + T: Send + 'static, + { + Task::new(self.executor.spawn(future)) + } + + /// Spawns a static future on the thread-local async executor for the + /// current thread. The task will run entirely on the thread the task was + /// spawned on. + /// + /// The returned [`Task`] is a future that can be polled for the + /// result. It can also be canceled and "detached", allowing the task to + /// continue running even if dropped. In any case, the pool will execute the + /// task even without polling by the end-user. + /// + /// Users should generally prefer to use [`TaskPool::spawn`] instead, + /// unless the provided future is not `Send`. + pub fn spawn_local(&self, future: impl Future + 'static) -> Task + where + T: 'static, + { + Task::new(TaskPool::LOCAL_EXECUTOR.with(|executor| executor.spawn(future))) + } + + /// Runs a function with the local executor. Typically used to tick + /// the local executor on the main thread as it needs to share time with + /// other things. + /// + /// ``` + /// use bevy_tasks::TaskPool; + /// + /// TaskPool::new().with_local_executor(|local_executor| { + /// local_executor.try_tick(); + /// }); + /// ``` + pub fn with_local_executor(&self, f: F) -> R + where + F: FnOnce(&async_executor::LocalExecutor) -> R, + { + TaskPool::LOCAL_EXECUTOR.with(f) + } +} + +/// A [`StaticTaskPool`] scope for running one or more non-`'static` futures. +/// +/// For more information, see [`TaskPool::scope`]. +#[derive(Debug)] +pub struct StaticScope<'scope, 'env: 'scope, T> { + executor: &'static StaticExecutor, + external_executor: &'scope ThreadExecutor<'scope>, + scope_executor: &'scope ThreadExecutor<'scope>, + spawned: &'scope ConcurrentQueue>>>, + // make `Scope` invariant over 'scope and 'env + scope: PhantomData<&'scope mut &'scope ()>, + env: PhantomData<&'env mut &'env ()>, +} + +impl<'scope, 'env, T: Send + 'static> StaticScope<'scope, 'env, T> { + /// Spawns a scoped future onto the thread pool. The scope *must* outlive + /// the provided future. The results of the future will be returned as a part of + /// [`TaskPool::scope`]'s return value. + /// + /// For futures that should run on the thread `scope` is called on [`Scope::spawn_on_scope`] should be used + /// instead. + /// + /// For more information, see [`TaskPool::scope`]. + #[allow(unsafe_code)] + pub fn spawn + 'scope + Send>(&self, f: Fut) { + // SAFETY: T lasts for the full 'static lifetime. + let task = unsafe { + self.executor + .spawn_scoped(AssertUnwindSafe(f).catch_unwind()) + .fallible() + }; + // ConcurrentQueue only errors when closed or full, but we never + // close and use an unbounded queue, so it is safe to unwrap + self.spawned.push(task).unwrap(); + } + + /// Spawns a scoped future onto the thread the scope is run on. The scope *must* outlive + /// the provided future. The results of the future will be returned as a part of + /// [`TaskPool::scope`]'s return value. Users should generally prefer to use + /// [`Scope::spawn`] instead, unless the provided future needs to run on the scope's thread. + /// + /// For more information, see [`TaskPool::scope`]. + pub fn spawn_on_scope + 'scope + Send>(&self, f: Fut) { + let task = self + .scope_executor + .spawn(AssertUnwindSafe(f).catch_unwind()) + .fallible(); + // ConcurrentQueue only errors when closed or full, but we never + // close and use an unbounded queue, so it is safe to unwrap + self.spawned.push(task).unwrap(); + } + + /// Spawns a scoped future onto the thread of the external thread executor. + /// This is typically the main thread. The scope *must* outlive + /// the provided future. The results of the future will be returned as a part of + /// [`TaskPool::scope`]'s return value. Users should generally prefer to use + /// [`Scope::spawn`] instead, unless the provided future needs to run on the external thread. + /// + /// For more information, see [`TaskPool::scope`]. + pub fn spawn_on_external + 'scope + Send>(&self, f: Fut) { + let task = self + .external_executor + .spawn(AssertUnwindSafe(f).catch_unwind()) + .fallible(); + // ConcurrentQueue only errors when closed or full, but we never + // close and use an unbounded queue, so it is safe to unwrap + self.spawned.push(task).unwrap(); + } +} + +impl<'scope, 'env, T> Drop for StaticScope<'scope, 'env, T> +where + T: 'scope, +{ + fn drop(&mut self) { + block_on(async { + while let Ok(task) = self.spawned.pop() { + task.cancel().await; + } + }); + } +} diff --git a/crates/bevy_tasks/src/task_pool.rs b/crates/bevy_tasks/src/task_pool.rs index 1e58f128ca2ec..1db203e5794b8 100644 --- a/crates/bevy_tasks/src/task_pool.rs +++ b/crates/bevy_tasks/src/task_pool.rs @@ -33,15 +33,15 @@ impl Drop for CallOnDrop { pub struct TaskPoolBuilder { /// If set, we'll set up the thread pool to use at most `num_threads` threads. /// Otherwise use the logical core count of the system - num_threads: Option, + pub(crate) num_threads: Option, /// If set, we'll use the given stack size rather than the system default - stack_size: Option, + pub(crate) stack_size: Option, /// Allows customizing the name of the threads - helpful for debugging. If set, threads will /// be named ` ()`, i.e. `"MyThreadPool (2)"`. - thread_name: Option, + pub(crate) thread_name: Option, - on_thread_spawn: Option>, - on_thread_destroy: Option>, + pub(crate) on_thread_spawn: Option>, + pub(crate) on_thread_destroy: Option>, } impl TaskPoolBuilder { @@ -116,8 +116,8 @@ pub struct TaskPool { impl TaskPool { thread_local! { - static LOCAL_EXECUTOR: async_executor::LocalExecutor<'static> = const { async_executor::LocalExecutor::new() }; - static THREAD_EXECUTOR: Arc> = Arc::new(ThreadExecutor::new()); + pub(crate) static LOCAL_EXECUTOR: async_executor::LocalExecutor<'static> = const { async_executor::LocalExecutor::new() }; + pub(crate) static THREAD_EXECUTOR: Arc> = Arc::new(ThreadExecutor::new()); } /// Each thread should only create one `ThreadExecutor`, otherwise, there are good chances they will deadlock diff --git a/crates/bevy_tasks/src/usages.rs b/crates/bevy_tasks/src/usages.rs index fda3092b8ebc8..895a2f2877df9 100644 --- a/crates/bevy_tasks/src/usages.rs +++ b/crates/bevy_tasks/src/usages.rs @@ -1,48 +1,34 @@ -use super::TaskPool; -use std::{ops::Deref, sync::OnceLock}; +use crate::{StaticTaskPool, TaskPoolBuilder, TaskPoolInitializationError}; macro_rules! taskpool { ($(#[$attr:meta])* ($static:ident, $type:ident)) => { - static $static: OnceLock<$type> = OnceLock::new(); + static $static: $type = $type(StaticTaskPool::new()); $(#[$attr])* #[derive(Debug)] - pub struct $type(TaskPool); + pub struct $type(StaticTaskPool); impl $type { - #[doc = concat!(" Gets the global [`", stringify!($type), "`] instance, or initializes it with `f`.")] - pub fn get_or_init(f: impl FnOnce() -> TaskPool) -> &'static Self { - $static.get_or_init(|| Self(f())) - } - - #[doc = concat!(" Attempts to get the global [`", stringify!($type), "`] instance, \ - or returns `None` if it is not initialized.")] - pub fn try_get() -> Option<&'static Self> { - $static.get() - } - #[doc = concat!(" Gets the global [`", stringify!($type), "`] instance.")] - #[doc = ""] - #[doc = " # Panics"] - #[doc = " Panics if the global instance has not been initialized yet."] - pub fn get() -> &'static Self { - $static.get().expect( - concat!( - "The ", - stringify!($type), - " has not been initialized yet. Please call ", - stringify!($type), - "::get_or_init beforehand." - ) - ) + pub fn get() -> &'static StaticTaskPool { + &$static.0 } - } - impl Deref for $type { - type Target = TaskPool; + /// Gets the global instance, or initializes it with the provided builder if + /// it hasn't already been initialized. + pub fn get_or_init(builder: TaskPoolBuilder) -> &'static StaticTaskPool { + let pool = &$static.0; + match pool.init(builder) { + Ok(()) => pool, + Err(TaskPoolInitializationError::AlreadyInitialized) => pool, + Err(err) => panic!("Error while initializing task pool: {}", err), + } + } - fn deref(&self) -> &Self::Target { - &self.0 + /// Gets the global instance, or initializes it with the default configuration if + /// it hasn't already been initialized. + pub fn get_or_default() -> &'static StaticTaskPool { + Self::get_or_init(Default::default()) } } }; @@ -83,23 +69,18 @@ taskpool! { #[cfg(not(target_arch = "wasm32"))] pub fn tick_global_task_pools_on_main_thread() { COMPUTE_TASK_POOL - .get() - .unwrap() + .0 .with_local_executor(|compute_local_executor| { ASYNC_COMPUTE_TASK_POOL - .get() - .unwrap() + .0 .with_local_executor(|async_local_executor| { - IO_TASK_POOL - .get() - .unwrap() - .with_local_executor(|io_local_executor| { - for _ in 0..100 { - compute_local_executor.try_tick(); - async_local_executor.try_tick(); - io_local_executor.try_tick(); - } - }); + IO_TASK_POOL.0.with_local_executor(|io_local_executor| { + for _ in 0..100 { + compute_local_executor.try_tick(); + async_local_executor.try_tick(); + io_local_executor.try_tick(); + } + }); }); }); } diff --git a/crates/bevy_transform/src/systems.rs b/crates/bevy_transform/src/systems.rs index 401a32cb22cc4..59b046ee4f0a9 100644 --- a/crates/bevy_transform/src/systems.rs +++ b/crates/bevy_transform/src/systems.rs @@ -186,7 +186,7 @@ mod test { use bevy_ecs::prelude::*; use bevy_ecs::world::CommandQueue; use bevy_math::{vec3, Vec3}; - use bevy_tasks::{ComputeTaskPool, TaskPool}; + use bevy_tasks::ComputeTaskPool; use crate::systems::*; use crate::TransformBundle; @@ -194,7 +194,7 @@ mod test { #[test] fn correct_parent_removed() { - ComputeTaskPool::get_or_init(TaskPool::default); + ComputeTaskPool::get_or_default(); let mut world = World::default(); let offset_global_transform = |offset| GlobalTransform::from(Transform::from_xyz(offset, offset, offset)); @@ -249,7 +249,7 @@ mod test { #[test] fn did_propagate() { - ComputeTaskPool::get_or_init(TaskPool::default); + ComputeTaskPool::get_or_default(); let mut world = World::default(); let mut schedule = Schedule::default(); @@ -327,7 +327,7 @@ mod test { #[test] fn correct_children() { - ComputeTaskPool::get_or_init(TaskPool::default); + ComputeTaskPool::get_or_default(); let mut world = World::default(); let mut schedule = Schedule::default(); @@ -405,7 +405,7 @@ mod test { #[test] fn correct_transforms_when_no_children() { let mut app = App::new(); - ComputeTaskPool::get_or_init(TaskPool::default); + ComputeTaskPool::get_or_default(); app.add_systems(Update, (sync_simple_transforms, propagate_transforms)); @@ -450,7 +450,7 @@ mod test { #[test] #[should_panic] fn panic_when_hierarchy_cycle() { - ComputeTaskPool::get_or_init(TaskPool::default); + ComputeTaskPool::get_or_default(); // We cannot directly edit Parent and Children, so we use a temp world to break // the hierarchy's invariants. let mut temp = World::new();