diff --git a/crates/bevy_tasks/Cargo.toml b/crates/bevy_tasks/Cargo.toml
index f86fb78d3bc23..1d3b79a96b7c6 100644
--- a/crates/bevy_tasks/Cargo.toml
+++ b/crates/bevy_tasks/Cargo.toml
@@ -8,6 +8,10 @@ repository = "https://github.com/bevyengine/bevy"
 license = "MIT OR Apache-2.0"
 keywords = ["bevy"]
 
+[features]
+default = ["tokio"]
+tokio = ["dep:tokio"]
+
 [dependencies]
 futures-lite = "1.4.0"
 async-executor = "1.3.0"
@@ -19,5 +23,8 @@ concurrent-queue = "1.2.2"
 [target.'cfg(target_arch = "wasm32")'.dependencies]
 wasm-bindgen-futures = "0.4"
 
+[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
+tokio = { version = "1.22", optional = true, features = ["rt-multi-thread"] }
+
 [dev-dependencies]
 instant = { version = "0.1", features = ["wasm-bindgen"] }
diff --git a/crates/bevy_tasks/src/lib.rs b/crates/bevy_tasks/src/lib.rs
index 802f6c267b7cf..3676004445cc7 100644
--- a/crates/bevy_tasks/src/lib.rs
+++ b/crates/bevy_tasks/src/lib.rs
@@ -4,14 +4,25 @@ mod slice;
 pub use slice::{ParallelSlice, ParallelSliceMut};
 
+#[cfg(any(target_arch = "wasm32", not(feature = "tokio")))]
 mod task;
+#[cfg(any(target_arch = "wasm32", not(feature = "tokio")))]
 pub use task::Task;
+#[cfg(all(not(target_arch = "wasm32"), feature = "tokio"))]
+mod tokio_task;
+#[cfg(all(not(target_arch = "wasm32"), feature = "tokio"))]
+pub use tokio_task::Task;
 
-#[cfg(not(target_arch = "wasm32"))]
+#[cfg(all(not(target_arch = "wasm32"), not(feature = "tokio")))]
 mod task_pool;
-#[cfg(not(target_arch = "wasm32"))]
+#[cfg(all(not(target_arch = "wasm32"), not(feature = "tokio")))]
 pub use task_pool::{Scope, TaskPool, TaskPoolBuilder};
 
+#[cfg(all(not(target_arch = "wasm32"), feature = "tokio"))]
+mod tokio_task_pool;
+#[cfg(all(not(target_arch = "wasm32"), feature = "tokio"))]
+pub use tokio_task_pool::{Scope, TaskPool, TaskPoolBuilder};
+
 #[cfg(target_arch = "wasm32")]
 mod single_threaded_task_pool;
 #[cfg(target_arch = "wasm32")]
diff --git a/crates/bevy_tasks/src/task_pool.rs b/crates/bevy_tasks/src/task_pool.rs
index 099f96e93d006..518b56461d980 100644
--- a/crates/bevy_tasks/src/task_pool.rs
+++ b/crates/bevy_tasks/src/task_pool.rs
@@ -369,7 +369,7 @@ pub struct Scope<'scope, 'env: 'scope, T> {
     env: PhantomData<&'env mut &'env ()>,
 }
 
-impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> {
+impl<'scope, 'env, T: Send + 'static> Scope<'scope, 'env, T> {
     /// Spawns a scoped future onto the thread pool. The scope *must* outlive
     /// the provided future. The results of the future will be returned as a part of
     /// [`TaskPool::scope`]'s return value.
diff --git a/crates/bevy_tasks/src/tokio_task.rs b/crates/bevy_tasks/src/tokio_task.rs
new file mode 100644
index 0000000000000..15ec368e0ca99
--- /dev/null
+++ b/crates/bevy_tasks/src/tokio_task.rs
@@ -0,0 +1,81 @@
+use std::{
+    future::Future,
+    pin::Pin,
+    task::{Context, Poll},
+};
+
+use futures_lite::FutureExt;
+use tokio::task::JoinHandle;
+
+/// Wraps `tokio::task::JoinHandle`, a spawned future.
+///
+/// Tasks are also futures themselves and yield the output of the spawned future.
+///
+/// When a task is dropped, it gets canceled and won't be polled again. To cancel a task a bit
+/// more gracefully and wait until it stops running, use the [`cancel()`][Task::cancel()] method.
+///
+/// Tasks that panic get immediately canceled. Awaiting a canceled task also causes a panic.
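+///
+/// A minimal usage sketch (illustrative only; it assumes the `tokio` feature is enabled so
+/// `TaskPool::spawn` returns this wrapper, and uses `futures_lite::future::block_on`, which
+/// can poll the handle from outside the runtime):
+///
+/// ```
+/// use bevy_tasks::TaskPool;
+///
+/// let pool = TaskPool::new();
+/// // The future starts running on the tokio runtime right away; the returned
+/// // `Task` is a handle that yields the output when awaited.
+/// let task = pool.spawn(async { 1 + 1 });
+/// assert_eq!(futures_lite::future::block_on(task), 2);
+/// // Dropping a `Task` aborts it; use `.detach()` to let it keep running instead.
+/// ```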
+#[derive(Debug)]
+#[must_use = "Tasks are canceled when dropped, use `.detach()` to run them in the background."]
+pub struct Task<T>(Option<JoinHandle<T>>);
+
+impl<T> Task<T> {
+    /// Creates a new task from a given `tokio::task::JoinHandle`
+    pub fn new(task: JoinHandle<T>) -> Self {
+        Self(Some(task))
+    }
+
+    /// Detaches the task to let it keep running in the background. Dropping the
+    /// inner `JoinHandle` without aborting it leaves the task running on the runtime.
+    pub fn detach(mut self) {
+        drop(self.0.take());
+    }
+
+    /// Cancels the task and waits for it to stop running.
+    ///
+    /// Returns the task's output if it completed just before it got canceled, or [`None`] if
+    /// it didn't complete.
+    ///
+    /// While it's possible to simply drop the [`Task`] to cancel it, this is a cleaner way of
+    /// canceling because it also waits for the task to stop running.
+    pub async fn cancel(mut self) -> Option<T> {
+        let handle = self.0.take()?;
+        // Abort before awaiting: awaiting a tokio `JoinHandle` on its own only
+        // waits for completion, it does not cancel the task.
+        handle.abort();
+        handle.await.ok()
+    }
+
+    /// Returns `true` if the current task is finished.
+    ///
+    /// Unlike poll, it doesn't resolve the final value; it just checks if the task has finished.
+    /// Note that in a multithreaded environment, the task could finish immediately after calling this function.
+    pub fn is_finished(&self) -> bool {
+        self.0
+            .as_ref()
+            .map(|handle| handle.is_finished())
+            .unwrap_or(true)
+    }
+}
+
+impl<T> Future for Task<T> {
+    type Output = T;
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        if let Some(handle) = self.0.as_mut() {
+            match handle.poll(cx) {
+                Poll::Ready(Ok(result)) => Poll::Ready(result),
+                Poll::Ready(Err(err)) => panic!("Task has failed: {}", err),
+                Poll::Pending => Poll::Pending,
+            }
+        } else {
+            unreachable!("Polling dropped task");
+        }
+    }
+}
+
+impl<T> Drop for Task<T> {
+    fn drop(&mut self) {
+        if let Some(handle) = self.0.take() {
+            handle.abort();
+        }
+    }
+}
\ No newline at end of file
diff --git a/crates/bevy_tasks/src/tokio_task_pool.rs b/crates/bevy_tasks/src/tokio_task_pool.rs
new file mode 100644
index 0000000000000..cd252c8446d2b
--- /dev/null
+++ b/crates/bevy_tasks/src/tokio_task_pool.rs
@@ -0,0 +1,548 @@
+use std::{
+    future::Future,
+    marker::PhantomData,
+    mem,
+    sync::{atomic::{AtomicUsize, Ordering}, Arc},
+    pin::Pin,
+};
+
+use tokio::task::JoinHandle;
+use tokio::runtime::{Runtime, Builder};
+use concurrent_queue::ConcurrentQueue;
+use futures_lite::{future, pin};
+
+use crate::tokio_task::Task;
+
+/// Used to create a [`TaskPool`]
+#[derive(Debug)]
+#[must_use]
+pub struct TaskPoolBuilder {
+    thread_count: Option<usize>,
+    builder: Builder,
+}
+
+impl Default for TaskPoolBuilder {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl TaskPoolBuilder {
+    /// Creates a new [`TaskPoolBuilder`] instance
+    pub fn new() -> Self {
+        let mut builder = Builder::new_multi_thread();
+        builder.enable_all().thread_name(String::from("TaskPool"));
+        Self {
+            thread_count: None,
+            builder,
+        }
+    }
+
+    /// Override the number of threads created for the pool. If unset, we default to the number
+    /// of logical cores of the system
+    pub fn num_threads(mut self, num_threads: usize) -> Self {
+        self.thread_count = Some(num_threads);
+        self.builder.worker_threads(num_threads);
+        self
+    }
+
+    /// Override the stack size of the threads created for the pool
+    pub fn stack_size(mut self, stack_size: usize) -> Self {
+        self.builder.thread_stack_size(stack_size);
+        self
+    }
+
+    /// Override the name of the threads created for the pool. If set, threads will
+    /// be named `<thread_name> (<index>)`, e.g. `MyThreadPool (2)`
+    pub fn thread_name(mut self, thread_name: String) -> Self {
+        let counter = Arc::new(AtomicUsize::new(0));
+        self.builder.thread_name_fn(move || {
+            let thread_count = counter.fetch_add(1, Ordering::Relaxed);
+            format!("{} ({})", thread_name, thread_count)
+        });
+        self
+    }
+
+    /// Creates a new [`TaskPool`] based on the current options.
+    pub fn build(mut self) -> TaskPool {
+        TaskPool {
+            thread_count: self.thread_count.unwrap_or_else(crate::available_parallelism),
+            runtime: self.builder.build().unwrap(),
+        }
+    }
+}
+
+/// A thread pool for executing tasks. Tasks are futures that are being automatically driven by
+/// the pool on threads owned by the pool.
+#[derive(Debug)]
+pub struct TaskPool {
+    runtime: Runtime,
+    thread_count: usize,
+}
+
+impl TaskPool {
+    /// Create a `TaskPool` with the default configuration.
+    pub fn new() -> Self {
+        TaskPoolBuilder::new().build()
+    }
+
+    /// Return the number of threads owned by the task pool
+    pub fn thread_num(&self) -> usize {
+        self.thread_count
+    }
+
+    /// Allows spawning non-`'static` futures on the thread pool. The function takes a callback,
+    /// passing a scope object into it. The scope object provided to the callback can be used
+    /// to spawn tasks. This function will await the completion of all tasks before returning.
+    ///
+    /// This is similar to `rayon::scope` and `crossbeam::scope`
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// use bevy_tasks::TaskPool;
+    ///
+    /// let pool = TaskPool::new();
+    /// let mut x = 0;
+    /// let results = pool.scope(|s| {
+    ///     s.spawn(async {
+    ///         // you can borrow the spawner inside a task and spawn tasks from within the task
+    ///         s.spawn(async {
+    ///             // borrow x and mutate it.
+    ///             x = 2;
+    ///             // return a value from the task
+    ///             1
+    ///         });
+    ///         // return some other value from the first task
+    ///         0
+    ///     });
+    /// });
+    ///
+    /// // The ordering of results is non-deterministic if you spawn from within tasks as above.
+    /// // If you're doing this, you'll have to write your code to not depend on the ordering.
+    /// assert!(results.contains(&0));
+    /// assert!(results.contains(&1));
+    ///
+    /// // The ordering is deterministic if you only spawn directly from the closure function.
+    /// let results = pool.scope(|s| {
+    ///     s.spawn(async { 0 });
+    ///     s.spawn(async { 1 });
+    /// });
+    /// assert_eq!(&results[..], &[0, 1]);
+    ///
+    /// // You can access x after scope runs, since it was only temporarily borrowed in the scope.
+    /// assert_eq!(x, 2);
+    /// ```
+    ///
+    /// # Lifetimes
+    ///
+    /// The [`Scope`] object takes two lifetimes: `'scope` and `'env`.
+    ///
+    /// The `'scope` lifetime represents the lifetime of the scope. That is the time during
+    /// which the provided closure and tasks that are spawned into the scope are run.
+    ///
+    /// The `'env` lifetime represents the lifetime of whatever is borrowed by the scope.
+    /// Thus this lifetime must outlive `'scope`.
+    ///
+    /// ```compile_fail
+    /// use bevy_tasks::TaskPool;
+    /// fn scope_escapes_closure() {
+    ///     let pool = TaskPool::new();
+    ///     let foo = Box::new(42);
+    ///     pool.scope(|scope| {
+    ///         std::thread::spawn(move || {
+    ///             // UB. This could spawn on the scope after `.scope` returns and the internal Scope is dropped.
+    ///             scope.spawn(async move {
+    ///                 assert_eq!(*foo, 42);
+    ///             });
+    ///         });
+    ///     });
+    /// }
+    /// ```
+    ///
+    /// ```compile_fail
+    /// use bevy_tasks::TaskPool;
+    /// fn cannot_borrow_from_closure() {
+    ///     let pool = TaskPool::new();
+    ///     pool.scope(|scope| {
+    ///         let x = 1;
+    ///         let y = &x;
+    ///         scope.spawn(async move {
+    ///             assert_eq!(*y, 1);
+    ///         });
+    ///     });
+    /// }
+    /// ```
+    pub fn scope<'env, F, T>(&self, f: F) -> Vec<T>
+    where
+        F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, T>),
+        T: Send + 'static,
+    {
+        // SAFETY: This safety comment applies to all references transmuted to 'env.
+        // Any futures spawned with these references need to return before this function completes.
+        // This is guaranteed because we drive all the futures spawned onto the Scope
+        // to completion in this function. However, Rust has no way of knowing this, so we
+        // transmute the lifetimes to 'env here to appease the compiler as it is unable to validate safety.
+        let runtime: &Runtime = &self.runtime;
+        let runtime: &'env Runtime = unsafe { mem::transmute(runtime) };
+        let task_scope_runtime: Runtime = Builder::new_current_thread().build().unwrap();
+        let task_scope_runtime: &'env Runtime = unsafe { mem::transmute(&task_scope_runtime) };
+        let spawned: ConcurrentQueue<JoinHandle<T>> = ConcurrentQueue::unbounded();
+        let spawned_ref: &'env ConcurrentQueue<JoinHandle<T>> =
+            unsafe { mem::transmute(&spawned) };
+
+        let scope = Scope {
+            runtime,
+            task_scope_runtime,
+            spawned: spawned_ref,
+            scope: PhantomData,
+            env: PhantomData,
+        };
+
+        let scope_ref: &'env Scope<'_, 'env, T> = unsafe { mem::transmute(&scope) };
+
+        f(scope_ref);
+
+        if spawned.is_empty() {
+            Vec::new()
+        } else {
+            let get_results = async {
+                let mut results = Vec::with_capacity(spawned_ref.len());
+                while let Ok(task) = spawned_ref.pop() {
+                    results.push(task.await.unwrap());
+                }
+
+                results
+            };
+
+            // Pin the future on the stack.
+            pin!(get_results);
+
+            let _guard = task_scope_runtime.enter();
+            loop {
+                if let Some(result) = self.runtime.block_on(future::poll_once(&mut get_results)) {
+                    break result;
+                }
+                if let Some(result) =
+                    task_scope_runtime.block_on(future::poll_once(&mut get_results))
+                {
+                    break result;
+                }
+            }
+        }
+    }
+
+    /// Spawns a static future onto the thread pool. The returned [`Task`] is a future. It can also
+    /// be canceled and "detached", allowing it to continue running without having to be polled by
+    /// the end-user.
+    ///
+    /// If the provided future is non-`Send`, [`TaskPool::spawn_local`] should be used instead.
+    pub fn spawn<T>(&self, future: impl Future<Output = T> + Send + 'static) -> Task<T>
+    where
+        T: Send + 'static,
+    {
+        Task::new(self.runtime.spawn(future))
+    }
+
+    /// Spawns a static future on the thread-local async executor for the current thread. The task
+    /// will run entirely on the thread the task was spawned on. The returned [`Task`] is a future.
+    /// It can also be canceled and "detached", allowing it to continue running without having
+    /// to be polled by the end-user. Users should generally prefer to use [`TaskPool::spawn`]
+    /// instead, unless the provided future is not `Send`.
+    ///
+    /// Note that `tokio::task::spawn_local` panics unless called from within a
+    /// `tokio::task::LocalSet` context.
+    pub fn spawn_local<T>(&self, future: impl Future<Output = T> + 'static) -> Task<T>
+    where
+        T: 'static,
+    {
+        Task::new(tokio::task::spawn_local(future))
+    }
+
+    /// Runs a function with the local executor. Typically used to tick
+    /// the local executor on the main thread as it needs to share time with
+    /// other things.
+    ///
+    /// ```rust
+    /// use bevy_tasks::TaskPool;
+    ///
+    /// TaskPool::new().with_local_executor(|local_executor| {
+    ///     local_executor.try_tick();
+    /// });
+    /// ```
+    pub fn with_local_executor<F, R>(&self, f: F) -> R
+    where
+        F: FnOnce(&async_executor::LocalExecutor) -> R,
+    {
+        // TODO: Implement this... somehow. For now, hand the callback a fresh, empty
+        // local executor so the API matches `task_pool::TaskPool`.
+        let dummy = async_executor::LocalExecutor::new();
+        f(&dummy)
+    }
+}
+
+impl Default for TaskPool {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+/// A `TaskPool` scope for running one or more non-`'static` futures.
+///
+/// For more information, see [`TaskPool::scope`].
+#[derive(Debug)]
+pub struct Scope<'scope, 'env: 'scope, T> {
+    runtime: &'scope Runtime,
+    task_scope_runtime: &'scope Runtime,
+    spawned: &'scope ConcurrentQueue<JoinHandle<T>>,
+    // make `Scope` invariant over 'scope and 'env
+    scope: PhantomData<&'scope mut &'scope ()>,
+    env: PhantomData<&'env mut &'env ()>,
+}
+
+impl<'scope, 'env, T: Send + 'static> Scope<'scope, 'env, T> {
+    /// Spawns a scoped future onto the thread pool. The scope *must* outlive
+    /// the provided future. The results of the future will be returned as a part of
+    /// [`TaskPool::scope`]'s return value.
+    ///
+    /// For futures that should run on the thread `scope` is called on, [`Scope::spawn_on_scope`]
+    /// should be used instead.
+    ///
+    /// For more information, see [`TaskPool::scope`].
+    pub fn spawn<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
+        let fut: Pin<Box<dyn Future<Output = T> + 'scope + Send>> = Box::pin(f);
+        let fut: Pin<Box<dyn Future<Output = T> + 'static + Send>> =
+            unsafe { mem::transmute(fut) };
+        let task = self.runtime.spawn(fut);
+        // ConcurrentQueue only errors when closed or full, but we never close it
+        // and use an unbounded queue, so pushing should never fail.
+        if let Err(err) = self.spawned.push(task) {
+            panic!("Failed to schedule task on scope: {}", err);
+        }
+    }
+
+    /// Spawns a scoped future onto the thread the scope is run on. The scope *must* outlive
+    /// the provided future. The results of the future will be returned as a part of
+    /// [`TaskPool::scope`]'s return value. Users should generally prefer to use
+    /// [`Scope::spawn`] instead, unless the provided future needs to run on the scope's thread.
+    ///
+    /// For more information, see [`TaskPool::scope`].
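+    ///
+    /// A minimal sketch of the difference from [`Scope::spawn`], mirroring the
+    /// `test_thread_locality` test below (illustrative; not run as a doctest):
+    ///
+    /// ```no_run
+    /// use bevy_tasks::TaskPool;
+    ///
+    /// let pool = TaskPool::new();
+    /// let spawner = std::thread::current().id();
+    /// pool.scope(|scope| {
+    ///     scope.spawn_on_scope(async move {
+    ///         // This future is driven on the thread that called `scope`.
+    ///         assert_eq!(std::thread::current().id(), spawner);
+    ///     });
+    /// });
+    /// ```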
+    pub fn spawn_on_scope<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
+        let fut: Pin<Box<dyn Future<Output = T> + 'scope + Send>> = Box::pin(f);
+        let fut: Pin<Box<dyn Future<Output = T> + 'static + Send>> =
+            unsafe { mem::transmute(fut) };
+        let task = self.task_scope_runtime.spawn(fut);
+        // ConcurrentQueue only errors when closed or full, but we never close it
+        // and use an unbounded queue, so pushing should never fail.
+        if let Err(err) = self.spawned.push(task) {
+            panic!("Failed to schedule task on scope: {}", err);
+        }
+    }
+}
+
+impl<'scope, 'env, T> Drop for Scope<'scope, 'env, T>
+where
+    T: 'scope,
+{
+    fn drop(&mut self) {
+        // `JoinHandle::abort` is synchronous, so there is no need to block on a future here.
+        while let Ok(task) = self.spawned.pop() {
+            task.abort();
+        }
+    }
+}
+
+#[cfg(test)]
+#[allow(clippy::disallowed_types)]
+mod tests {
+    use super::*;
+    use std::sync::{
+        atomic::{AtomicBool, AtomicI32, Ordering},
+        Barrier,
+    };
+
+    #[test]
+    fn test_spawn() {
+        let pool = TaskPool::new();
+
+        let foo = Box::new(42);
+        let foo = &*foo;
+
+        let count = Arc::new(AtomicI32::new(0));
+
+        let outputs = pool.scope(|scope| {
+            for _ in 0..100 {
+                let count_clone = count.clone();
+                scope.spawn(async move {
+                    if *foo != 42 {
+                        panic!("not 42!?!?")
+                    } else {
+                        count_clone.fetch_add(1, Ordering::Relaxed);
+                        *foo
+                    }
+                });
+            }
+        });
+
+        for output in &outputs {
+            assert_eq!(*output, 42);
+        }
+
+        assert_eq!(outputs.len(), 100);
+        assert_eq!(count.load(Ordering::Relaxed), 100);
+    }
+
+    #[test]
+    fn test_mixed_spawn_on_scope_and_spawn() {
+        let pool = TaskPool::new();
+
+        let foo = Box::new(42);
+        let foo = &*foo;
+
+        let local_count = Arc::new(AtomicI32::new(0));
+        let non_local_count = Arc::new(AtomicI32::new(0));
+
+        let outputs = pool.scope(|scope| {
+            for i in 0..100 {
+                if i % 2 == 0 {
+                    let count_clone = non_local_count.clone();
+                    scope.spawn(async move {
+                        if *foo != 42 {
+                            panic!("not 42!?!?")
+                        } else {
+                            count_clone.fetch_add(1, Ordering::Relaxed);
+                            *foo
+                        }
+                    });
+                } else {
+                    let count_clone = local_count.clone();
+                    scope.spawn_on_scope(async move {
+                        if *foo != 42 {
+                            panic!("not 42!?!?")
+                        } else {
+                            count_clone.fetch_add(1, Ordering::Relaxed);
+                            *foo
+                        }
+                    });
+                }
+            }
+        });
+
+        for output in &outputs {
+            assert_eq!(*output, 42);
+        }
+
+        assert_eq!(outputs.len(), 100);
+        assert_eq!(local_count.load(Ordering::Relaxed), 50);
+        assert_eq!(non_local_count.load(Ordering::Relaxed), 50);
+    }
+
+    #[test]
+    fn test_thread_locality() {
+        let pool = Arc::new(TaskPool::new());
+        let count = Arc::new(AtomicI32::new(0));
+        let barrier = Arc::new(Barrier::new(101));
+        let thread_check_failed = Arc::new(AtomicBool::new(false));
+
+        for _ in 0..100 {
+            let inner_barrier = barrier.clone();
+            let count_clone = count.clone();
+            let inner_pool = pool.clone();
+            let inner_thread_check_failed = thread_check_failed.clone();
+            std::thread::spawn(move || {
+                inner_pool.scope(|scope| {
+                    let inner_count_clone = count_clone.clone();
+                    scope.spawn(async move {
+                        inner_count_clone.fetch_add(1, Ordering::Release);
+                    });
+                    let spawner = std::thread::current().id();
+                    let inner_count_clone = count_clone.clone();
+                    scope.spawn_on_scope(async move {
+                        inner_count_clone.fetch_add(1, Ordering::Release);
+                        if std::thread::current().id() != spawner {
+                            // NOTE: This check is using an atomic rather than simply panicking the
+                            // thread to avoid deadlocking the barrier on failure
+                            inner_thread_check_failed.store(true, Ordering::Release);
+                        }
+                    });
+                });
+                inner_barrier.wait();
+            });
+        }
+        barrier.wait();
+        assert!(!thread_check_failed.load(Ordering::Acquire));
+        assert_eq!(count.load(Ordering::Acquire), 200);
+    }
+
+    #[test]
+    fn test_nested_spawn() {
+        let pool = TaskPool::new();
+
+        let foo = Box::new(42);
+        let foo = &*foo;
+
+        let count = Arc::new(AtomicI32::new(0));
+
+        let outputs: Vec<i32> = pool.scope(|scope| {
+            for _ in 0..10 {
+                let count_clone = count.clone();
+                scope.spawn(async move {
+                    for _ in 0..10 {
+                        let count_clone_clone = count_clone.clone();
+                        scope.spawn(async move {
+                            if *foo != 42 {
+                                panic!("not 42!?!?")
+                            } else {
+                                count_clone_clone.fetch_add(1, Ordering::Relaxed);
+                                *foo
+                            }
+                        });
+                    }
+                    *foo
+                });
+            }
+        });
+
+        for output in &outputs {
+            assert_eq!(*output, 42);
+        }
+
+        // The outer loop spawns 10 tasks and the inner loops spawn 100 more, so there are
+        // 110 results, but only the 100 inner tasks increment the counter.
+        assert_eq!(outputs.len(), 110);
+        assert_eq!(count.load(Ordering::Relaxed), 100);
+    }
+
+    #[test]
+    fn test_nested_locality() {
+        let pool = Arc::new(TaskPool::new());
+        let count = Arc::new(AtomicI32::new(0));
+        let barrier = Arc::new(Barrier::new(101));
+        let thread_check_failed = Arc::new(AtomicBool::new(false));
+
+        for _ in 0..100 {
+            let inner_barrier = barrier.clone();
+            let count_clone = count.clone();
+            let inner_pool = pool.clone();
+            let inner_thread_check_failed = thread_check_failed.clone();
+            std::thread::spawn(move || {
+                inner_pool.scope(|scope| {
+                    let spawner = std::thread::current().id();
+                    let inner_count_clone = count_clone.clone();
+                    scope.spawn(async move {
+                        inner_count_clone.fetch_add(1, Ordering::Release);
+
+                        // spawning on the scope from another thread runs the futures on the scope's thread
+                        scope.spawn_on_scope(async move {
+                            inner_count_clone.fetch_add(1, Ordering::Release);
+                            if std::thread::current().id() != spawner {
+                                // NOTE: This check is using an atomic rather than simply panicking the
+                                // thread to avoid deadlocking the barrier on failure
+                                inner_thread_check_failed.store(true, Ordering::Release);
+                            }
+                        });
+                    });
+                });
+                inner_barrier.wait();
+            });
+        }
+        barrier.wait();
+        assert!(!thread_check_failed.load(Ordering::Acquire));
+        assert_eq!(count.load(Ordering::Acquire), 200);
+    }
+}