diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9816c46a947..4370783d2d4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,8 +1,8 @@ on: push: - branches: ["master", "tokio-*.x"] + branches: [ "master", "tokio-*.x" ] pull_request: - branches: ["master", "tokio-*.x"] + branches: [ "master", "tokio-*.x" ] name: CI @@ -107,7 +107,7 @@ jobs: - name: Install Rust ${{ env.rust_stable }} uses: dtolnay/rust-toolchain@stable with: - toolchain: ${{ env.rust_stable }} + toolchain: ${{ env.rust_stable }} - name: Install cargo-nextest uses: taiki-e/install-action@v2 with: @@ -139,7 +139,7 @@ jobs: - name: Install Rust ${{ env.rust_stable }} uses: dtolnay/rust-toolchain@stable with: - toolchain: ${{ env.rust_stable }} + toolchain: ${{ env.rust_stable }} - name: Install cargo-nextest uses: taiki-e/install-action@v2 with: @@ -169,7 +169,7 @@ jobs: - name: Install Rust ${{ env.rust_nightly }} uses: dtolnay/rust-toolchain@stable with: - toolchain: ${{ env.rust_nightly }} + toolchain: ${{ env.rust_nightly }} - name: Install cargo-nextest uses: taiki-e/install-action@v2 with: @@ -197,7 +197,7 @@ jobs: - name: Install Rust ${{ env.rust_stable }} uses: dtolnay/rust-toolchain@stable with: - toolchain: ${{ env.rust_stable }} + toolchain: ${{ env.rust_stable }} - name: Install cargo-hack uses: taiki-e/install-action@v2 with: @@ -237,7 +237,7 @@ jobs: - name: Install Rust ${{ env.rust_stable }} uses: dtolnay/rust-toolchain@stable with: - toolchain: ${{ env.rust_stable }} + toolchain: ${{ env.rust_stable }} - name: Enable parking_lot send_guard feature # Inserts the line "plsend = ["parking_lot/send_guard"]" right after [features] @@ -256,7 +256,7 @@ jobs: - name: Install Rust ${{ env.rust_stable }} uses: dtolnay/rust-toolchain@stable with: - toolchain: 1.82 + toolchain: 1.82 - name: Install Valgrind uses: taiki-e/install-action@valgrind @@ -295,7 +295,7 @@ jobs: - name: Install Rust ${{ env.rust_stable }} uses: dtolnay/rust-toolchain@stable with: - toolchain: ${{ env.rust_stable }} + toolchain: ${{ env.rust_stable }} - name: Install cargo-nextest uses: taiki-e/install-action@v2 @@ -329,7 +329,7 @@ jobs: - name: Install Rust ${{ env.rust_stable }} uses: dtolnay/rust-toolchain@stable with: - toolchain: ${{ env.rust_stable }} + toolchain: ${{ env.rust_stable }} - name: Install cargo-nextest uses: taiki-e/install-action@v2 @@ -363,7 +363,7 @@ jobs: - name: Install Rust ${{ env.rust_stable }} uses: dtolnay/rust-toolchain@stable with: - toolchain: ${{ env.rust_stable }} + toolchain: ${{ env.rust_stable }} - name: Install cargo-nextest uses: taiki-e/install-action@v2 with: @@ -842,10 +842,10 @@ jobs: toolchain: ${{ env.rust_stable }} - uses: Swatinem/rust-cache@v2 - name: build --cfg loom - run: cargo test --no-run --lib --features full + run: cargo test --no-run --lib --release --features full working-directory: tokio env: - RUSTFLAGS: --cfg loom --cfg tokio_unstable -Dwarnings + RUSTFLAGS: --cfg loom --cfg tokio_unstable -Dwarnings -Cdebug-assertions check-readme: name: Check README diff --git a/tokio/src/lib.rs b/tokio/src/lib.rs index 6b0f48bd105..cf287b0dac8 100644 --- a/tokio/src/lib.rs +++ b/tokio/src/lib.rs @@ -351,10 +351,7 @@ //! - [`task::Builder`] //! - Some methods on [`task::JoinSet`] //! - [`runtime::RuntimeMetrics`] -//! - [`runtime::Builder::on_task_spawn`] -//! - [`runtime::Builder::on_task_terminate`] //! - [`runtime::Builder::unhandled_panic`] -//! - [`runtime::TaskMeta`] //! //! This flag enables **unstable** features. The public API of these features //! may break in 1.x releases. To enable these features, the `--cfg diff --git a/tokio/src/runtime/blocking/pool.rs b/tokio/src/runtime/blocking/pool.rs index 23180dc5245..990a1fd4a7b 100644 --- a/tokio/src/runtime/blocking/pool.rs +++ b/tokio/src/runtime/blocking/pool.rs @@ -375,10 +375,15 @@ impl Spawner { F: FnOnce() -> R + Send + 'static, R: Send + 'static, { + // let parent = with_c let id = task::Id::next(); let fut = blocking_task::>(BlockingTask::new(func), spawn_meta, id.as_u64()); + #[cfg(tokio_unstable)] + let (task, handle) = task::unowned(fut, BlockingSchedule::new(rt), id, None); + + #[cfg(not(tokio_unstable))] let (task, handle) = task::unowned(fut, BlockingSchedule::new(rt), id); let spawned = self.spawn_task(Task::new(task, is_mandatory), rt); diff --git a/tokio/src/runtime/blocking/schedule.rs b/tokio/src/runtime/blocking/schedule.rs index 0e97c5aeaf4..aad4da4f3f4 100644 --- a/tokio/src/runtime/blocking/schedule.rs +++ b/tokio/src/runtime/blocking/schedule.rs @@ -1,7 +1,9 @@ #[cfg(feature = "test-util")] use crate::runtime::scheduler; -use crate::runtime::task::{self, Task, TaskHarnessScheduleHooks}; +use crate::runtime::task::{self, Task}; use crate::runtime::Handle; +#[cfg(tokio_unstable)] +use crate::runtime::{OptionalTaskHooksFactory, OptionalTaskHooksFactoryRef}; /// `task::Schedule` implementation that does nothing (except some bookkeeping /// in test-util builds). This is unique to the blocking scheduler as tasks @@ -12,7 +14,8 @@ use crate::runtime::Handle; pub(crate) struct BlockingSchedule { #[cfg(feature = "test-util")] handle: Handle, - hooks: TaskHarnessScheduleHooks, + #[cfg(tokio_unstable)] + hooks_factory: OptionalTaskHooksFactory, } impl BlockingSchedule { @@ -31,9 +34,8 @@ impl BlockingSchedule { BlockingSchedule { #[cfg(feature = "test-util")] handle: handle.clone(), - hooks: TaskHarnessScheduleHooks { - task_terminate_callback: handle.inner.hooks().task_terminate_callback.clone(), - }, + #[cfg(tokio_unstable)] + hooks_factory: handle.inner.hooks_factory(), } } } @@ -58,9 +60,13 @@ impl task::Schedule for BlockingSchedule { unreachable!(); } - fn hooks(&self) -> TaskHarnessScheduleHooks { - TaskHarnessScheduleHooks { - task_terminate_callback: self.hooks.task_terminate_callback.clone(), - } + #[cfg(tokio_unstable)] + fn hooks_factory(&self) -> OptionalTaskHooksFactory { + self.hooks_factory.clone() + } + + #[cfg(tokio_unstable)] + fn hooks_factory_ref(&self) -> OptionalTaskHooksFactoryRef<'_> { + self.hooks_factory.as_ref().map(AsRef::as_ref) } } diff --git a/tokio/src/runtime/builder.rs b/tokio/src/runtime/builder.rs index 994fcfa5c73..4fbe5565dca 100644 --- a/tokio/src/runtime/builder.rs +++ b/tokio/src/runtime/builder.rs @@ -1,15 +1,19 @@ #![cfg_attr(loom, allow(unused_imports))] +use crate::runtime::blocking::BlockingPool; use crate::runtime::handle::Handle; -use crate::runtime::{blocking, driver, Callback, HistogramBuilder, Runtime, TaskCallback}; +use crate::runtime::scheduler::CurrentThread; +use crate::runtime::{blocking, driver, Callback, HistogramBuilder, Runtime}; #[cfg(tokio_unstable)] -use crate::runtime::{metrics::HistogramConfiguration, LocalOptions, LocalRuntime, TaskMeta}; +use crate::runtime::{ + metrics::HistogramConfiguration, LocalOptions, LocalRuntime, OptionalTaskHooksFactory, + TaskHookHarnessFactory, +}; use crate::util::rand::{RngSeed, RngSeedGenerator}; - -use crate::runtime::blocking::BlockingPool; -use crate::runtime::scheduler::CurrentThread; use std::fmt; use std::io; +#[cfg(tokio_unstable)] +use std::sync::Arc; use std::thread::ThreadId; use std::time::Duration; @@ -85,19 +89,8 @@ pub struct Builder { /// To run after each thread is unparked. pub(super) after_unpark: Option, - /// To run before each task is spawned. - pub(super) before_spawn: Option, - - /// To run before each poll #[cfg(tokio_unstable)] - pub(super) before_poll: Option, - - /// To run after each poll - #[cfg(tokio_unstable)] - pub(super) after_poll: Option, - - /// To run after each task is terminated. - pub(super) after_termination: Option, + pub(super) task_hook_harness_factory: OptionalTaskHooksFactory, /// Customizable keep alive timeout for `BlockingPool` pub(super) keep_alive: Option, @@ -287,13 +280,8 @@ impl Builder { before_park: None, after_unpark: None, - before_spawn: None, - after_termination: None, - #[cfg(tokio_unstable)] - before_poll: None, - #[cfg(tokio_unstable)] - after_poll: None, + task_hook_harness_factory: None, keep_alive: None, @@ -685,188 +673,19 @@ impl Builder { self } - /// Executes function `f` just before a task is spawned. - /// - /// `f` is called within the Tokio context, so functions like - /// [`tokio::spawn`](crate::spawn) can be called, and may result in this callback being - /// invoked immediately. - /// - /// This can be used for bookkeeping or monitoring purposes. - /// - /// Note: There can only be one spawn callback for a runtime; calling this function more - /// than once replaces the last callback defined, rather than adding to it. - /// - /// This *does not* support [`LocalSet`](crate::task::LocalSet) at this time. - /// - /// **Note**: This is an [unstable API][unstable]. The public API of this type - /// may break in 1.x releases. See [the documentation on unstable - /// features][unstable] for details. - /// - /// [unstable]: crate#unstable-features - /// - /// # Examples - /// - /// ``` - /// # use tokio::runtime; - /// # pub fn main() { - /// let runtime = runtime::Builder::new_current_thread() - /// .on_task_spawn(|_| { - /// println!("spawning task"); - /// }) - /// .build() - /// .unwrap(); - /// - /// runtime.block_on(async { - /// tokio::task::spawn(std::future::ready(())); - /// - /// for _ in 0..64 { - /// tokio::task::yield_now().await; - /// } - /// }) - /// # } - /// ``` - #[cfg(all(not(loom), tokio_unstable))] - #[cfg_attr(docsrs, doc(cfg(tokio_unstable)))] - pub fn on_task_spawn(&mut self, f: F) -> &mut Self - where - F: Fn(&TaskMeta<'_>) + Send + Sync + 'static, - { - self.before_spawn = Some(std::sync::Arc::new(f)); - self - } - - /// Executes function `f` just before a task is polled - /// - /// `f` is called within the Tokio context, so functions like - /// [`tokio::spawn`](crate::spawn) can be called, and may result in this callback being - /// invoked immediately. - /// - /// **Note**: This is an [unstable API][unstable]. The public API of this type - /// may break in 1.x releases. See [the documentation on unstable - /// features][unstable] for details. - /// - /// [unstable]: crate#unstable-features - /// - /// # Examples - /// - /// ``` - /// # use std::sync::{atomic::AtomicUsize, Arc}; - /// # use tokio::task::yield_now; - /// # pub fn main() { - /// let poll_start_counter = Arc::new(AtomicUsize::new(0)); - /// let poll_start = poll_start_counter.clone(); - /// let rt = tokio::runtime::Builder::new_multi_thread() - /// .enable_all() - /// .on_before_task_poll(move |meta| { - /// println!("task {} is about to be polled", meta.id()) - /// }) - /// .build() - /// .unwrap(); - /// let task = rt.spawn(async { - /// yield_now().await; - /// }); - /// let _ = rt.block_on(task); - /// - /// # } - /// ``` - #[cfg(tokio_unstable)] - pub fn on_before_task_poll(&mut self, f: F) -> &mut Self - where - F: Fn(&TaskMeta<'_>) + Send + Sync + 'static, - { - self.before_poll = Some(std::sync::Arc::new(f)); - self - } - - /// Executes function `f` just after a task is polled - /// - /// `f` is called within the Tokio context, so functions like - /// [`tokio::spawn`](crate::spawn) can be called, and may result in this callback being - /// invoked immediately. - /// - /// **Note**: This is an [unstable API][unstable]. The public API of this type - /// may break in 1.x releases. See [the documentation on unstable - /// features][unstable] for details. - /// - /// [unstable]: crate#unstable-features - /// - /// # Examples - /// - /// ``` - /// # use std::sync::{atomic::AtomicUsize, Arc}; - /// # use tokio::task::yield_now; - /// # pub fn main() { - /// let poll_stop_counter = Arc::new(AtomicUsize::new(0)); - /// let poll_stop = poll_stop_counter.clone(); - /// let rt = tokio::runtime::Builder::new_multi_thread() - /// .enable_all() - /// .on_after_task_poll(move |meta| { - /// println!("task {} completed polling", meta.id()); - /// }) - /// .build() - /// .unwrap(); - /// let task = rt.spawn(async { - /// yield_now().await; - /// }); - /// let _ = rt.block_on(task); - /// - /// # } - /// ``` - #[cfg(tokio_unstable)] - pub fn on_after_task_poll(&mut self, f: F) -> &mut Self - where - F: Fn(&TaskMeta<'_>) + Send + Sync + 'static, - { - self.after_poll = Some(std::sync::Arc::new(f)); - self - } - - /// Executes function `f` just after a task is terminated. - /// - /// `f` is called within the Tokio context, so functions like - /// [`tokio::spawn`](crate::spawn) can be called. - /// - /// This can be used for bookkeeping or monitoring purposes. - /// - /// Note: There can only be one task termination callback for a runtime; calling this - /// function more than once replaces the last callback defined, rather than adding to it. + /// Factory method for producing "fallback" task hook harnesses. /// - /// This *does not* support [`LocalSet`](crate::task::LocalSet) at this time. - /// - /// **Note**: This is an [unstable API][unstable]. The public API of this type - /// may break in 1.x releases. See [the documentation on unstable - /// features][unstable] for details. - /// - /// [unstable]: crate#unstable-features - /// - /// # Examples - /// - /// ``` - /// # use tokio::runtime; - /// # pub fn main() { - /// let runtime = runtime::Builder::new_current_thread() - /// .on_task_terminate(|_| { - /// println!("killing task"); - /// }) - /// .build() - /// .unwrap(); - /// - /// runtime.block_on(async { - /// tokio::task::spawn(std::future::ready(())); - /// - /// for _ in 0..64 { - /// tokio::task::yield_now().await; - /// } - /// }) - /// # } - /// ``` + /// The order of operations for assigning the hook harness for a task are as follows: + /// 1. [`crate::task::spawn_with_hooks`], if used. + /// 2. [`crate::runtime::task_hooks::TaskHookHarnessFactory`], if it returns something other than [Option::None]. + /// 3. This function. #[cfg(all(not(loom), tokio_unstable))] #[cfg_attr(docsrs, doc(cfg(tokio_unstable)))] - pub fn on_task_terminate(&mut self, f: F) -> &mut Self + pub fn hook_harness_factory(&mut self, hooks: T) -> &mut Self where - F: Fn(&TaskMeta<'_>) + Send + Sync + 'static, + T: TaskHookHarnessFactory + Send + Sync + 'static, { - self.after_termination = Some(std::sync::Arc::new(f)); + self.task_hook_harness_factory = Some(Arc::new(hooks)); self } @@ -1475,12 +1294,8 @@ impl Builder { Config { before_park: self.before_park.clone(), after_unpark: self.after_unpark.clone(), - before_spawn: self.before_spawn.clone(), - #[cfg(tokio_unstable)] - before_poll: self.before_poll.clone(), #[cfg(tokio_unstable)] - after_poll: self.after_poll.clone(), - after_termination: self.after_termination.clone(), + task_hook_factory: self.task_hook_harness_factory.clone(), global_queue_interval: self.global_queue_interval, event_interval: self.event_interval, #[cfg(tokio_unstable)] @@ -1628,12 +1443,8 @@ cfg_rt_multi_thread! { Config { before_park: self.before_park.clone(), after_unpark: self.after_unpark.clone(), - before_spawn: self.before_spawn.clone(), - #[cfg(tokio_unstable)] - before_poll: self.before_poll.clone(), #[cfg(tokio_unstable)] - after_poll: self.after_poll.clone(), - after_termination: self.after_termination.clone(), + task_hook_factory: self.task_hook_harness_factory.clone(), global_queue_interval: self.global_queue_interval, event_interval: self.event_interval, #[cfg(tokio_unstable)] diff --git a/tokio/src/runtime/config.rs b/tokio/src/runtime/config.rs index b79df96e1e2..8537adc5dcd 100644 --- a/tokio/src/runtime/config.rs +++ b/tokio/src/runtime/config.rs @@ -2,7 +2,10 @@ any(not(all(tokio_unstable, feature = "full")), target_family = "wasm"), allow(dead_code) )] -use crate::runtime::{Callback, TaskCallback}; + +use crate::runtime::Callback; +#[cfg(tokio_unstable)] +use crate::runtime::OptionalTaskHooksFactory; use crate::util::RngSeedGenerator; pub(crate) struct Config { @@ -18,19 +21,9 @@ pub(crate) struct Config { /// Callback for a worker unparking itself pub(crate) after_unpark: Option, - /// To run before each task is spawned. - pub(crate) before_spawn: Option, - - /// To run after each task is terminated. - pub(crate) after_termination: Option, - - /// To run before each poll - #[cfg(tokio_unstable)] - pub(crate) before_poll: Option, - - /// To run after each poll + /// Called on task spawn to generate the attached task hook harness. #[cfg(tokio_unstable)] - pub(crate) after_poll: Option, + pub(crate) task_hook_factory: OptionalTaskHooksFactory, /// The multi-threaded scheduler includes a per-worker LIFO slot used to /// store the last scheduled task. This can improve certain usage patterns, diff --git a/tokio/src/runtime/context.rs b/tokio/src/runtime/context.rs index e8f17bb374a..c0fcc64aa15 100644 --- a/tokio/src/runtime/context.rs +++ b/tokio/src/runtime/context.rs @@ -1,10 +1,14 @@ +#[cfg(all(feature = "rt", tokio_unstable))] +use crate::loom::cell::UnsafeCell; use crate::loom::thread::AccessError; +#[cfg(all(feature = "rt", tokio_unstable))] +use crate::runtime::{OptionalTaskHooksMut, OptionalTaskHooksWeak, TaskHookHarness}; use crate::task::coop; - -use std::cell::Cell; - #[cfg(any(feature = "rt", feature = "macros", feature = "time"))] use crate::util::rand::FastRand; +use std::cell::Cell; +#[cfg(all(feature = "rt", tokio_unstable))] +use std::ptr::NonNull; cfg_rt! { mod blocking; @@ -49,6 +53,10 @@ struct Context { #[cfg(feature = "rt")] current_task_id: Cell>, + /// Tracks the current set of task hooks, + #[cfg(all(feature = "rt", tokio_unstable))] + current_task_hooks: OptionalTaskHooksWeak, + /// Tracks if the current thread is currently driving a runtime. /// Note, that if this is set to "entered", the current scheduler /// handle may not reference the runtime currently executing. This @@ -92,6 +100,9 @@ tokio_thread_local! { #[cfg(feature = "rt")] current_task_id: Cell::new(None), + #[cfg(all(feature = "rt", tokio_unstable))] + current_task_hooks: UnsafeCell::new(None), + // Tracks if the current thread is currently driving a runtime. // Note, that if this is set to "entered", the current scheduler // handle may not reference the runtime currently executing. This @@ -139,6 +150,16 @@ pub(crate) fn budget(f: impl FnOnce(&Cell) -> R) -> Result>) -> Result { + CONTEXT.try_with(|ctx| { + ctx.current_task_hooks.with_mut(|x| { + unsafe { + *x = hooks; + } + }) + })?; + + Ok(SetTaskHooksGuard) + } + + #[track_caller] + #[cfg(tokio_unstable)] + pub(super) fn clear_task_hooks() -> Result<(), AccessError> { + CONTEXT.try_with(|ctx| { + ctx.current_task_hooks.with_mut(|x| { + unsafe { + *x = None; + } + }) + })?; + + Ok(()) + } + + #[track_caller] + #[cfg(tokio_unstable)] + pub(super) fn with_task_hooks(f: impl FnOnce(OptionalTaskHooksMut<'_>) -> R) -> Result { + CONTEXT.try_with(|ctx| { + ctx.current_task_hooks.with_mut(|ptr| { + let hooks = unsafe { &mut *ptr }; + unsafe { + f(hooks.as_mut().map(|x| x.as_mut())) + } + }) + }) + } + #[track_caller] pub(crate) fn defer(waker: &Waker) { with_scheduler(|maybe_scheduler| { diff --git a/tokio/src/runtime/handle.rs b/tokio/src/runtime/handle.rs index 7aaba2ff243..7ca76578dd9 100644 --- a/tokio/src/runtime/handle.rs +++ b/tokio/src/runtime/handle.rs @@ -1,5 +1,5 @@ #[cfg(tokio_unstable)] -use crate::runtime; +use crate::runtime::{self, OptionalTaskHooks}; use crate::runtime::{context, scheduler, RuntimeFlavor, RuntimeMetrics}; /// Handle to the runtime. @@ -191,6 +191,13 @@ impl Handle { F::Output: Send + 'static, { let fut_size = mem::size_of::(); + #[cfg(tokio_unstable)] + return if fut_size > BOX_FUTURE_THRESHOLD { + self.spawn_named(Box::pin(future), SpawnMeta::new_unnamed(fut_size), None) + } else { + self.spawn_named(future, SpawnMeta::new_unnamed(fut_size), None) + }; + #[cfg(not(tokio_unstable))] if fut_size > BOX_FUTURE_THRESHOLD { self.spawn_named(Box::pin(future), SpawnMeta::new_unnamed(fut_size)) } else { @@ -329,7 +336,12 @@ impl Handle { } #[track_caller] - pub(crate) fn spawn_named(&self, future: F, _meta: SpawnMeta<'_>) -> JoinHandle + pub(crate) fn spawn_named( + &self, + future: F, + _meta: SpawnMeta<'_>, + #[cfg(tokio_unstable)] parent: OptionalTaskHooks, + ) -> JoinHandle where F: Future + Send + 'static, F::Output: Send + 'static, @@ -345,6 +357,9 @@ impl Handle { let future = super::task::trace::Trace::root(future); #[cfg(all(tokio_unstable, feature = "tracing"))] let future = crate::util::trace::task(future, "task", _meta, id.as_u64()); + #[cfg(tokio_unstable)] + return self.inner.spawn(future, id, parent); + #[cfg(not(tokio_unstable))] self.inner.spawn(future, id) } @@ -354,6 +369,7 @@ impl Handle { &self, future: F, _meta: SpawnMeta<'_>, + #[cfg(tokio_unstable)] hooks_override: OptionalTaskHooks, ) -> JoinHandle where F: Future + 'static, @@ -370,6 +386,9 @@ impl Handle { let future = super::task::trace::Trace::root(future); #[cfg(all(tokio_unstable, feature = "tracing"))] let future = crate::util::trace::task(future, "task", _meta, id.as_u64()); + #[cfg(tokio_unstable)] + return self.inner.spawn_local(future, id, hooks_override); + #[cfg(not(tokio_unstable))] self.inner.spawn_local(future, id) } diff --git a/tokio/src/runtime/local_runtime/runtime.rs b/tokio/src/runtime/local_runtime/runtime.rs index 358a771956b..11fdc097b17 100644 --- a/tokio/src/runtime/local_runtime/runtime.rs +++ b/tokio/src/runtime/local_runtime/runtime.rs @@ -155,9 +155,9 @@ impl LocalRuntime { // safety: spawn_local can only be called from `LocalRuntime`, which this is unsafe { if std::mem::size_of::() > BOX_FUTURE_THRESHOLD { - self.handle.spawn_local_named(Box::pin(future), meta) + self.handle.spawn_local_named(Box::pin(future), meta, None) } else { - self.handle.spawn_local_named(future, meta) + self.handle.spawn_local_named(future, meta, None) } } } diff --git a/tokio/src/runtime/mod.rs b/tokio/src/runtime/mod.rs index 78a0114f48e..026bdd7ef68 100644 --- a/tokio/src/runtime/mod.rs +++ b/tokio/src/runtime/mod.rs @@ -380,13 +380,10 @@ cfg_rt! { pub use dump::Dump; } - mod task_hooks; - pub(crate) use task_hooks::{TaskHooks, TaskCallback}; cfg_unstable! { - pub use task_hooks::TaskMeta; + mod task_hooks; + pub use task_hooks::*; } - #[cfg(not(tokio_unstable))] - pub(crate) use task_hooks::TaskMeta; mod handle; pub use handle::{EnterGuard, Handle, TryCurrentError}; diff --git a/tokio/src/runtime/runtime.rs b/tokio/src/runtime/runtime.rs index 2f2b07d322c..9af21d31249 100644 --- a/tokio/src/runtime/runtime.rs +++ b/tokio/src/runtime/runtime.rs @@ -233,6 +233,15 @@ impl Runtime { F::Output: Send + 'static, { let fut_size = mem::size_of::(); + #[cfg(tokio_unstable)] + return if fut_size > BOX_FUTURE_THRESHOLD { + self.handle + .spawn_named(Box::pin(future), SpawnMeta::new_unnamed(fut_size), None) + } else { + self.handle + .spawn_named(future, SpawnMeta::new_unnamed(fut_size), None) + }; + #[cfg(not(tokio_unstable))] if fut_size > BOX_FUTURE_THRESHOLD { self.handle .spawn_named(Box::pin(future), SpawnMeta::new_unnamed(fut_size)) diff --git a/tokio/src/runtime/scheduler/current_thread/mod.rs b/tokio/src/runtime/scheduler/current_thread/mod.rs index 72e12fae895..0f018ada1c4 100644 --- a/tokio/src/runtime/scheduler/current_thread/mod.rs +++ b/tokio/src/runtime/scheduler/current_thread/mod.rs @@ -1,17 +1,19 @@ use crate::loom::sync::atomic::AtomicBool; use crate::loom::sync::Arc; +#[cfg(tokio_unstable)] +use crate::runtime::context::with_task_hooks; use crate::runtime::driver::{self, Driver}; use crate::runtime::scheduler::{self, Defer, Inject}; -use crate::runtime::task::{ - self, JoinHandle, OwnedTasks, Schedule, Task, TaskHarnessScheduleHooks, -}; +use crate::runtime::task::{self, JoinHandle, OwnedTasks, Schedule, Task}; +use crate::runtime::{blocking, context, Config, MetricsBatch, SchedulerMetrics, WorkerMetrics}; +#[cfg(tokio_unstable)] use crate::runtime::{ - blocking, context, Config, MetricsBatch, SchedulerMetrics, TaskHooks, TaskMeta, WorkerMetrics, + OnChildTaskSpawnContext, OnTopLevelTaskSpawnContext, OptionalTaskHooks, + OptionalTaskHooksFactory, OptionalTaskHooksFactoryRef, }; use crate::sync::notify::Notify; use crate::util::atomic_cell::AtomicCell; use crate::util::{waker_ref, RngSeedGenerator, Wake, WakerRef}; - use std::cell::RefCell; use std::collections::VecDeque; use std::future::{poll_fn, Future}; @@ -20,7 +22,7 @@ use std::task::Poll::{Pending, Ready}; use std::task::Waker; use std::thread::ThreadId; use std::time::Duration; -use std::{fmt, thread}; +use std::{fmt, panic, thread}; /// Executes tasks on the current thread pub(crate) struct CurrentThread { @@ -47,7 +49,8 @@ pub(crate) struct Handle { pub(crate) seed_generator: RngSeedGenerator, /// User-supplied hooks to invoke for things - pub(crate) task_hooks: TaskHooks, + #[cfg(tokio_unstable)] + pub(crate) task_hooks: OptionalTaskHooksFactory, /// If this is a `LocalRuntime`, flags the owning thread ID. pub(crate) local_tid: Option, @@ -142,14 +145,8 @@ impl CurrentThread { .unwrap_or(DEFAULT_GLOBAL_QUEUE_INTERVAL); let handle = Arc::new(Handle { - task_hooks: TaskHooks { - task_spawn_callback: config.before_spawn.clone(), - task_terminate_callback: config.after_termination.clone(), - #[cfg(tokio_unstable)] - before_poll_callback: config.before_poll.clone(), - #[cfg(tokio_unstable)] - after_poll_callback: config.after_poll.clone(), - }, + #[cfg(tokio_unstable)] + task_hooks: config.task_hook_factory.clone(), shared: Shared { inject: Inject::new(), owned: OwnedTasks::new(1), @@ -448,19 +445,65 @@ impl Handle { pub(crate) fn spawn( me: &Arc, future: F, - id: crate::runtime::task::Id, + id: task::Id, + #[cfg(tokio_unstable)] hooks_override: OptionalTaskHooks, ) -> JoinHandle where F: crate::future::Future + Send + 'static, F::Output: Send + 'static, { - let (handle, notified) = me.shared.owned.bind(future, me.clone(), id); - - me.task_hooks.spawn(&TaskMeta { - id, - _phantom: Default::default(), + // preference order for hook selection: + // 1. "hook override" - comes from builder + // 2. parent task's hook + // 3. runtime hook factory + #[cfg(tokio_unstable)] + let hooks = hooks_override.or_else(|| { + with_task_hooks(|parent| { + parent + .map(|parent| { + if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { + parent + .on_child_spawn(&mut OnChildTaskSpawnContext { + id, + _phantom: Default::default(), + }) + .hooks + })) { + r + } else { + None + } + }) + .flatten() + }) + .ok() + .flatten() + .or_else(|| { + if let Some(hooks) = me.hooks_factory_ref() { + if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { + hooks + .on_top_level_spawn(&mut OnTopLevelTaskSpawnContext { + id, + _phantom: Default::default(), + }) + .hooks + })) { + r + } else { + None + } + } else { + None + } + }) }); + #[cfg(tokio_unstable)] + let (handle, notified) = me.shared.owned.bind(future, me.clone(), id, hooks); + + #[cfg(not(tokio_unstable))] + let (handle, notified) = me.shared.owned.bind(future, me.clone(), id); + if let Some(notified) = notified { me.schedule(notified); } @@ -477,18 +520,66 @@ impl Handle { pub(crate) unsafe fn spawn_local( me: &Arc, future: F, - id: crate::runtime::task::Id, + id: task::Id, + #[cfg(tokio_unstable)] hooks_override: OptionalTaskHooks, ) -> JoinHandle where F: crate::future::Future + 'static, F::Output: 'static, { - let (handle, notified) = me.shared.owned.bind_local(future, me.clone(), id); + // preference order for hook selection: + // 1. "hook override" - comes from builder + // 2. parent task's hook + // 3. runtime hook factory + #[cfg(tokio_unstable)] + let hooks = hooks_override.or_else(|| { + with_task_hooks(|parent| { + parent + .map(|parent| { + if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { + parent + .on_child_spawn(&mut OnChildTaskSpawnContext { + id, + _phantom: Default::default(), + }) + .hooks + })) { + r + } else { + None + } + }) + .flatten() + }) + .ok() + .flatten() + .or_else(|| { + if let Some(hooks) = me.hooks_factory_ref() { + if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { + hooks + .on_top_level_spawn(&mut OnTopLevelTaskSpawnContext { + id, + _phantom: Default::default(), + }) + .hooks + })) { + r + } else { + None + } + } else { + None + } + }) + }); - me.task_hooks.spawn(&TaskMeta { + let (handle, notified) = me.shared.owned.bind_local( + future, + me.clone(), id, - _phantom: Default::default(), - }); + #[cfg(tokio_unstable)] + hooks, + ); if let Some(notified) = notified { me.schedule(notified); @@ -654,10 +745,14 @@ impl Schedule for Arc { }); } - fn hooks(&self) -> TaskHarnessScheduleHooks { - TaskHarnessScheduleHooks { - task_terminate_callback: self.task_hooks.task_terminate_callback.clone(), - } + #[cfg(tokio_unstable)] + fn hooks_factory(&self) -> OptionalTaskHooksFactory { + self.task_hooks.clone() + } + + #[cfg(tokio_unstable)] + fn hooks_factory_ref(&self) -> OptionalTaskHooksFactoryRef<'_> { + self.task_hooks.as_ref().map(AsRef::as_ref) } cfg_unstable! { @@ -770,17 +865,8 @@ impl CoreGuard<'_> { let task = context.handle.shared.owned.assert_owner(task); - #[cfg(tokio_unstable)] - let task_id = task.task_id(); - let (c, ()) = context.run_task(core, || { - #[cfg(tokio_unstable)] - context.handle.task_hooks.poll_start_callback(task_id); - task.run(); - - #[cfg(tokio_unstable)] - context.handle.task_hooks.poll_stop_callback(task_id); }); core = c; diff --git a/tokio/src/runtime/scheduler/mod.rs b/tokio/src/runtime/scheduler/mod.rs index 8241b57c1de..6b2b2cf645a 100644 --- a/tokio/src/runtime/scheduler/mod.rs +++ b/tokio/src/runtime/scheduler/mod.rs @@ -8,8 +8,6 @@ cfg_rt! { pub(crate) mod inject; pub(crate) use inject::Inject; - use crate::runtime::TaskHooks; - use crate::runtime::WorkerMetrics; } @@ -25,6 +23,10 @@ cfg_rt_multi_thread! { } use crate::runtime::driver; +#[cfg(all(feature = "rt", tokio_unstable))] +use crate::runtime::task::Schedule; +#[cfg(all(feature = "rt", tokio_unstable))] +use crate::runtime::{OptionalTaskHooks, OptionalTaskHooksFactory}; #[derive(Debug, Clone)] pub(crate) enum Handle { @@ -117,11 +119,24 @@ cfg_rt! { } } - pub(crate) fn spawn(&self, future: F, id: Id) -> JoinHandle + pub(crate) fn spawn(&self, + future: F, + id: Id, + #[cfg(tokio_unstable)] + hooks_override: OptionalTaskHooks + ) -> JoinHandle where F: Future + Send + 'static, F::Output: Send + 'static, { + #[cfg(tokio_unstable)] + return match self { + Handle::CurrentThread(h) => current_thread::Handle::spawn(h, future, id, hooks_override), + + #[cfg(feature = "rt-multi-thread")] + Handle::MultiThread(h) => multi_thread::Handle::spawn(h, future, id, hooks_override), + }; + #[cfg(not(tokio_unstable))] match self { Handle::CurrentThread(h) => current_thread::Handle::spawn(h, future, id), @@ -136,12 +151,15 @@ cfg_rt! { /// This should only be called in `LocalRuntime` if the runtime has been verified to be owned /// by the current thread. #[allow(irrefutable_let_patterns)] - pub(crate) unsafe fn spawn_local(&self, future: F, id: Id) -> JoinHandle + pub(crate) unsafe fn spawn_local(&self, future: F, id: Id, #[cfg(tokio_unstable)] hooks_override: OptionalTaskHooks) -> JoinHandle where F: Future + 'static, F::Output: 'static, { if let Handle::CurrentThread(h) = self { + #[cfg(tokio_unstable)] + return current_thread::Handle::spawn_local(h, future, id, hooks_override); + #[cfg(not(tokio_unstable))] current_thread::Handle::spawn_local(h, future, id) } else { panic!("Only current_thread and LocalSet have spawn_local internals implemented") @@ -169,12 +187,9 @@ cfg_rt! { } } - pub(crate) fn hooks(&self) -> &TaskHooks { - match self { - Handle::CurrentThread(h) => &h.task_hooks, - #[cfg(feature = "rt-multi-thread")] - Handle::MultiThread(h) => &h.task_hooks, - } + #[cfg(tokio_unstable)] + pub(crate) fn hooks_factory(&self) -> OptionalTaskHooksFactory { + match_flavor!(self, Handle(h) => h.hooks_factory()) } } diff --git a/tokio/src/runtime/scheduler/multi_thread/handle.rs b/tokio/src/runtime/scheduler/multi_thread/handle.rs index 4075713c979..030910d30f3 100644 --- a/tokio/src/runtime/scheduler/multi_thread/handle.rs +++ b/tokio/src/runtime/scheduler/multi_thread/handle.rs @@ -1,14 +1,20 @@ use crate::future::Future; use crate::loom::sync::Arc; +#[cfg(tokio_unstable)] +use crate::runtime::context::with_task_hooks; use crate::runtime::scheduler::multi_thread::worker; +#[cfg(tokio_unstable)] +use crate::runtime::task::Schedule; use crate::runtime::{ blocking, driver, task::{self, JoinHandle}, - TaskHooks, TaskMeta, }; +#[cfg(tokio_unstable)] +use crate::runtime::{OnChildTaskSpawnContext, OnTopLevelTaskSpawnContext, OptionalTaskHooks}; use crate::util::RngSeedGenerator; - use std::fmt; +#[cfg(tokio_unstable)] +use std::panic; mod metrics; @@ -29,18 +35,24 @@ pub(crate) struct Handle { /// Current random number generator seed pub(crate) seed_generator: RngSeedGenerator, - - /// User-supplied hooks to invoke for things - pub(crate) task_hooks: TaskHooks, } impl Handle { /// Spawns a future onto the thread pool - pub(crate) fn spawn(me: &Arc, future: F, id: task::Id) -> JoinHandle + pub(crate) fn spawn( + me: &Arc, + future: F, + id: task::Id, + #[cfg(tokio_unstable)] hooks_override: OptionalTaskHooks, + ) -> JoinHandle where F: crate::future::Future + Send + 'static, F::Output: Send + 'static, { + #[cfg(tokio_unstable)] + return Self::bind_new_task(me, future, id, hooks_override); + + #[cfg(not(tokio_unstable))] Self::bind_new_task(me, future, id) } @@ -48,17 +60,69 @@ impl Handle { self.close(); } - pub(super) fn bind_new_task(me: &Arc, future: T, id: task::Id) -> JoinHandle + pub(super) fn bind_new_task( + me: &Arc, + future: T, + id: task::Id, + #[cfg(tokio_unstable)] hooks_override: OptionalTaskHooks, + ) -> JoinHandle where T: Future + Send + 'static, T::Output: Send + 'static, { - let (handle, notified) = me.shared.owned.bind(future, me.clone(), id); + // preference order for hook selection: + // 1. "hook override" - comes from builder + // 2. parent task's hook + // 3. runtime hook factory + #[cfg(tokio_unstable)] + let hooks = hooks_override.or_else(|| { + with_task_hooks(|parent| { + parent + .map(|parent| { + if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { + parent + .on_child_spawn(&mut OnChildTaskSpawnContext { + id, + _phantom: Default::default(), + }) + .hooks + })) { + r + } else { + None + } + }) + .flatten() + }) + .ok() + .flatten() + .or_else(|| { + if let Some(hooks) = me.hooks_factory_ref() { + if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { + hooks + .on_top_level_spawn(&mut OnTopLevelTaskSpawnContext { + id, + _phantom: Default::default(), + }) + .hooks + })) { + r + } else { + None + } + } else { + None + } + }) + }); - me.task_hooks.spawn(&TaskMeta { + let (handle, notified) = me.shared.owned.bind( + future, + me.clone(), id, - _phantom: Default::default(), - }); + #[cfg(tokio_unstable)] + hooks, + ); me.schedule_option_task_without_yield(notified); diff --git a/tokio/src/runtime/scheduler/multi_thread/worker.rs b/tokio/src/runtime/scheduler/multi_thread/worker.rs index e33b9baea2c..2bd2438f378 100644 --- a/tokio/src/runtime/scheduler/multi_thread/worker.rs +++ b/tokio/src/runtime/scheduler/multi_thread/worker.rs @@ -58,13 +58,15 @@ use crate::loom::sync::{Arc, Mutex}; use crate::runtime; +use crate::runtime::context; use crate::runtime::scheduler::multi_thread::{ idle, queue, Counters, Handle, Idle, Overflow, Parker, Stats, TraceStatus, Unparker, }; use crate::runtime::scheduler::{inject, Defer, Lock}; -use crate::runtime::task::{OwnedTasks, TaskHarnessScheduleHooks}; +use crate::runtime::task::OwnedTasks; use crate::runtime::{blocking, driver, scheduler, task, Config, SchedulerMetrics, WorkerMetrics}; -use crate::runtime::{context, TaskHooks}; +#[cfg(tokio_unstable)] +use crate::runtime::{OptionalTaskHooksFactory, OptionalTaskHooksFactoryRef}; use crate::task::coop; use crate::util::atomic_cell::AtomicCell; use crate::util::rand::{FastRand, RngSeedGenerator}; @@ -281,7 +283,6 @@ pub(super) fn create( let remotes_len = remotes.len(); let handle = Arc::new(Handle { - task_hooks: TaskHooks::from_config(&config), shared: Shared { remotes: remotes.into_boxed_slice(), inject, @@ -570,9 +571,6 @@ impl Context { } fn run_task(&self, task: Notified, mut core: Box) -> RunResult { - #[cfg(tokio_unstable)] - let task_id = task.task_id(); - let task = self.worker.handle.shared.owned.assert_owner(task); // Make sure the worker is not in the **searching** state. This enables @@ -592,16 +590,8 @@ impl Context { // Run the task coop::budget(|| { - // Unlike the poll time above, poll start callback is attached to the task id, - // so it is tightly associated with the actual poll invocation. - #[cfg(tokio_unstable)] - self.worker.handle.task_hooks.poll_start_callback(task_id); - task.run(); - #[cfg(tokio_unstable)] - self.worker.handle.task_hooks.poll_stop_callback(task_id); - let mut lifo_polls = 0; // As long as there is budget remaining and a task exists in the @@ -665,16 +655,7 @@ impl Context { *self.core.borrow_mut() = Some(core); let task = self.worker.handle.shared.owned.assert_owner(task); - #[cfg(tokio_unstable)] - let task_id = task.task_id(); - - #[cfg(tokio_unstable)] - self.worker.handle.task_hooks.poll_start_callback(task_id); - task.run(); - - #[cfg(tokio_unstable)] - self.worker.handle.task_hooks.poll_stop_callback(task_id); } }) } @@ -1063,10 +1044,18 @@ impl task::Schedule for Arc { self.schedule_task(task, false); } - fn hooks(&self) -> TaskHarnessScheduleHooks { - TaskHarnessScheduleHooks { - task_terminate_callback: self.task_hooks.task_terminate_callback.clone(), - } + #[cfg(tokio_unstable)] + fn hooks_factory(&self) -> OptionalTaskHooksFactory { + self.shared.config.task_hook_factory.clone() + } + + #[cfg(tokio_unstable)] + fn hooks_factory_ref(&self) -> OptionalTaskHooksFactoryRef<'_> { + self.shared + .config + .task_hook_factory + .as_ref() + .map(AsRef::as_ref) } fn yield_now(&self, task: Notified) { diff --git a/tokio/src/runtime/task/core.rs b/tokio/src/runtime/task/core.rs index 5d3ca0e00c9..7f122581c29 100644 --- a/tokio/src/runtime/task/core.rs +++ b/tokio/src/runtime/task/core.rs @@ -14,7 +14,9 @@ use crate::loom::cell::UnsafeCell; use crate::runtime::context; use crate::runtime::task::raw::{self, Vtable}; use crate::runtime::task::state::State; -use crate::runtime::task::{Id, Schedule, TaskHarnessScheduleHooks}; +use crate::runtime::task::{Id, Schedule}; +#[cfg(tokio_unstable)] +use crate::runtime::OptionalTaskHooks; use crate::util::linked_list; use std::num::NonZeroU64; @@ -186,7 +188,8 @@ pub(super) struct Trailer { /// Consumer task waiting on completion of this task. pub(super) waker: UnsafeCell>, /// Optional hooks needed in the harness. - pub(super) hooks: TaskHarnessScheduleHooks, + #[cfg(tokio_unstable)] + pub(super) hooks: UnsafeCell, } generate_addr_of_methods! { @@ -208,7 +211,13 @@ pub(super) enum Stage { impl Cell { /// Allocates a new task cell, containing the header, trailer, and core /// structures. - pub(super) fn new(future: T, scheduler: S, state: State, task_id: Id) -> Box> { + pub(super) fn new( + future: T, + scheduler: S, + state: State, + task_id: Id, + #[cfg(tokio_unstable)] hooks: OptionalTaskHooks, + ) -> Box> { // Separated into a non-generic function to reduce LLVM codegen fn new_header( state: State, @@ -229,7 +238,13 @@ impl Cell { let tracing_id = future.id(); let vtable = raw::vtable::(); let result = Box::new(Cell { - trailer: Trailer::new(scheduler.hooks()), + #[cfg(tokio_unstable)] + trailer: Trailer::new( + #[cfg(tokio_unstable)] + hooks, + ), + #[cfg(not(tokio_unstable))] + trailer: Trailer::new(), header: new_header( state, vtable, @@ -462,11 +477,12 @@ impl Header { } impl Trailer { - fn new(hooks: TaskHarnessScheduleHooks) -> Self { + fn new(#[cfg(tokio_unstable)] hooks: OptionalTaskHooks) -> Self { Trailer { waker: UnsafeCell::new(None), owned: linked_list::Pointers::new(), - hooks, + #[cfg(tokio_unstable)] + hooks: UnsafeCell::new(hooks), } } diff --git a/tokio/src/runtime/task/harness.rs b/tokio/src/runtime/task/harness.rs index 9bf73b74fbf..f039731a51c 100644 --- a/tokio/src/runtime/task/harness.rs +++ b/tokio/src/runtime/task/harness.rs @@ -1,10 +1,12 @@ use crate::future::Future; +#[cfg(tokio_unstable)] +use crate::runtime::context::with_task_hooks; use crate::runtime::task::core::{Cell, Core, Header, Trailer}; use crate::runtime::task::state::{Snapshot, State}; use crate::runtime::task::waker::waker_ref; use crate::runtime::task::{Id, JoinError, Notified, RawTask, Schedule, Task}; - -use crate::runtime::TaskMeta; +#[cfg(tokio_unstable)] +use crate::runtime::{AfterTaskPollContext, OnTaskTerminateContext}; use std::any::Any; use std::mem; use std::mem::ManuallyDrop; @@ -150,8 +152,21 @@ where /// All necessary state checks and transitions are performed. /// Panics raised while polling the future are handled. pub(super) fn poll(self) { + let res = self.poll_inner(); + + #[cfg(tokio_unstable)] + let _ = with_task_hooks(|t| { + if let Some(hooks) = t { + let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { + hooks.after_poll(&mut AfterTaskPollContext { + _phantom: Default::default(), + }) + })); + } + }); + // We pass our ref-count to `poll_inner`. - match self.poll_inner() { + match res { PollFuture::Notified => { // The `poll_inner` call has given us two ref-counts back. // We give one of them to a new task and call `yield_now`. @@ -367,14 +382,16 @@ where // // We call this in a separate block so that it runs after the task appears to have // completed and will still run if the destructor panics. - if let Some(f) = self.trailer().hooks.task_terminate_callback.as_ref() { - let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { - f(&TaskMeta { - id: self.core().task_id, - _phantom: Default::default(), - }) - })); - } + #[cfg(tokio_unstable)] + let _ = with_task_hooks(|t| { + if let Some(hooks) = t { + let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { + hooks.on_task_terminate(&mut OnTaskTerminateContext { + _phantom: Default::default(), + }) + })); + } + }); // The task has completed execution and will no longer be scheduled. let num_release = self.release(); diff --git a/tokio/src/runtime/task/list.rs b/tokio/src/runtime/task/list.rs index 54bfc01aafb..91e89dd4ffa 100644 --- a/tokio/src/runtime/task/list.rs +++ b/tokio/src/runtime/task/list.rs @@ -13,9 +13,10 @@ use crate::util::linked_list::{Link, LinkedList}; use crate::util::sharded_list; use crate::loom::sync::atomic::{AtomicBool, Ordering}; +#[cfg(tokio_unstable)] +use crate::runtime::OptionalTaskHooks; use std::marker::PhantomData; use std::num::NonZeroU64; - // The id from the module below is used to verify whether a given task is stored // in this OwnedTasks, or some other task. The counter starts at one so we can // use `None` for tasks not owned by any list. @@ -91,13 +92,20 @@ impl OwnedTasks { task: T, scheduler: S, id: super::Id, + #[cfg(tokio_unstable)] hooks: OptionalTaskHooks, ) -> (JoinHandle, Option>) where S: Schedule, T: Future + Send + 'static, T::Output: Send + 'static, { - let (task, notified, join) = super::new_task(task, scheduler, id); + let (task, notified, join) = super::new_task( + task, + scheduler, + id, + #[cfg(tokio_unstable)] + hooks, + ); let notified = unsafe { self.bind_inner(task, notified) }; (join, notified) } @@ -111,13 +119,20 @@ impl OwnedTasks { task: T, scheduler: S, id: super::Id, + #[cfg(tokio_unstable)] parent: OptionalTaskHooks, ) -> (JoinHandle, Option>) where S: Schedule, T: Future + 'static, T::Output: 'static, { - let (task, notified, join) = super::new_task(task, scheduler, id); + let (task, notified, join) = super::new_task( + task, + scheduler, + id, + #[cfg(tokio_unstable)] + parent, + ); let notified = unsafe { self.bind_inner(task, notified) }; (join, notified) } @@ -258,12 +273,16 @@ impl LocalOwnedTasks { task: T, scheduler: S, id: super::Id, + #[cfg(tokio_unstable)] parent: OptionalTaskHooks, ) -> (JoinHandle, Option>) where S: Schedule, T: Future + 'static, T::Output: 'static, { + #[cfg(tokio_unstable)] + let (task, notified, join) = super::new_task(task, scheduler, id, parent); + #[cfg(not(tokio_unstable))] let (task, notified, join) = super::new_task(task, scheduler, id); unsafe { diff --git a/tokio/src/runtime/task/mod.rs b/tokio/src/runtime/task/mod.rs index 7d314c3b176..7cf2f1e98f7 100644 --- a/tokio/src/runtime/task/mod.rs +++ b/tokio/src/runtime/task/mod.rs @@ -221,10 +221,10 @@ cfg_taskdump! { } use crate::future::Future; +#[cfg(tokio_unstable)] +use crate::runtime::{OptionalTaskHooks, OptionalTaskHooksFactory, OptionalTaskHooksFactoryRef}; use crate::util::linked_list; use crate::util::sharded_list; - -use crate::runtime::TaskCallback; use std::marker::PhantomData; use std::ptr::NonNull; use std::{fmt, mem}; @@ -256,13 +256,6 @@ pub(crate) struct LocalNotified { _not_send: PhantomData<*const ()>, } -impl LocalNotified { - #[cfg(tokio_unstable)] - pub(crate) fn task_id(&self) -> Id { - self.task.id() - } -} - /// A task that is not owned by any `OwnedTasks`. Used for blocking tasks. /// This type holds two ref-counts. pub(crate) struct UnownedTask { @@ -277,12 +270,6 @@ unsafe impl Sync for UnownedTask {} /// Task result sent back. pub(crate) type Result = std::result::Result; -/// Hooks for scheduling tasks which are needed in the task harness. -#[derive(Clone)] -pub(crate) struct TaskHarnessScheduleHooks { - pub(crate) task_terminate_callback: Option, -} - pub(crate) trait Schedule: Sync + Sized + 'static { /// The task has completed work and is ready to be released. The scheduler /// should release it immediately and return it. The task module will batch @@ -294,7 +281,11 @@ pub(crate) trait Schedule: Sync + Sized + 'static { /// Schedule the task fn schedule(&self, task: Notified); - fn hooks(&self) -> TaskHarnessScheduleHooks; + #[cfg(tokio_unstable)] + fn hooks_factory(&self) -> OptionalTaskHooksFactory; + + #[cfg(tokio_unstable)] + fn hooks_factory_ref(&self) -> OptionalTaskHooksFactoryRef<'_>; /// Schedule the task to run in the near future, yielding the thread to /// other tasks. @@ -317,13 +308,19 @@ cfg_rt! { task: T, scheduler: S, id: Id, + #[cfg(tokio_unstable)] + hooks: OptionalTaskHooks ) -> (Task, Notified, JoinHandle) where S: Schedule, T: Future + 'static, T::Output: 'static, { + #[cfg(tokio_unstable)] + let raw = RawTask::new::(task, scheduler, id, hooks); + #[cfg(not(tokio_unstable))] let raw = RawTask::new::(task, scheduler, id); + let task = Task { raw, _p: PhantomData, @@ -341,12 +338,16 @@ cfg_rt! { /// only when the task is not going to be stored in an `OwnedTasks` list. /// /// Currently only blocking tasks use this method. - pub(crate) fn unowned(task: T, scheduler: S, id: Id) -> (UnownedTask, JoinHandle) + pub(crate) fn unowned(task: T, scheduler: S, id: Id, #[cfg(tokio_unstable)] hooks: OptionalTaskHooks) -> (UnownedTask, JoinHandle) where S: Schedule, T: Send + Future + 'static, T::Output: Send + 'static, { + #[cfg(tokio_unstable)] + let (task, notified, join) = new_task(task, scheduler, id, hooks); + + #[cfg(not(tokio_unstable))] let (task, notified, join) = new_task(task, scheduler, id); // This transfers the ref-count of task and notified into an UnownedTask. @@ -459,6 +460,7 @@ impl LocalNotified { /// Runs the task. pub(crate) fn run(self) { let raw = self.task.raw; + mem::forget(self); raw.poll(); } diff --git a/tokio/src/runtime/task/raw.rs b/tokio/src/runtime/task/raw.rs index 6699551f3ec..90b3477ecc8 100644 --- a/tokio/src/runtime/task/raw.rs +++ b/tokio/src/runtime/task/raw.rs @@ -1,7 +1,12 @@ use crate::future::Future; +#[cfg(tokio_unstable)] +use crate::runtime::context::set_task_hooks; use crate::runtime::task::core::{Core, Trailer}; use crate::runtime::task::{Cell, Harness, Header, Id, Schedule, State}; - +#[cfg(tokio_unstable)] +use crate::runtime::{BeforeTaskPollContext, OptionalTaskHooks, TaskHookHarness}; +#[cfg(tokio_unstable)] +use std::panic; use std::ptr::NonNull; use std::task::{Poll, Waker}; @@ -157,12 +162,24 @@ const fn get_id_offset( } impl RawTask { - pub(super) fn new(task: T, scheduler: S, id: Id) -> RawTask + pub(super) fn new( + task: T, + scheduler: S, + id: Id, + #[cfg(tokio_unstable)] hooks: OptionalTaskHooks, + ) -> RawTask where T: Future, S: Schedule, { - let ptr = Box::into_raw(Cell::<_, S>::new(task, scheduler, State::new(), id)); + let ptr = Box::into_raw(Cell::<_, S>::new( + task, + scheduler, + State::new(), + id, + #[cfg(tokio_unstable)] + hooks, + )); let ptr = unsafe { NonNull::new_unchecked(ptr.cast()) }; RawTask { ptr } @@ -197,8 +214,27 @@ impl RawTask { /// Safety: mutual exclusion is required to call this function. pub(crate) fn poll(self) { - let vtable = self.header().vtable; - unsafe { (vtable.poll)(self.ptr) } + #[cfg(tokio_unstable)] + let _guard = self.trailer().hooks.with_mut(|ptr| unsafe { + ptr.as_mut().and_then(|x| { + x.as_mut().map(|x| { + let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { + x.before_poll(&mut BeforeTaskPollContext { + _phantom: Default::default(), + }) + })); + + set_task_hooks(NonNull::new( + (&mut **x) as *mut (dyn TaskHookHarness + Send + Sync + 'static), + )) + }) + }) + }); + + unsafe { + let vtable = self.header().vtable; + (vtable.poll)(self.ptr); + } } pub(super) fn schedule(self) { diff --git a/tokio/src/runtime/task_hooks.rs b/tokio/src/runtime/task_hooks.rs deleted file mode 100644 index 13865ed515d..00000000000 --- a/tokio/src/runtime/task_hooks.rs +++ /dev/null @@ -1,81 +0,0 @@ -use std::marker::PhantomData; - -use super::Config; - -impl TaskHooks { - pub(crate) fn spawn(&self, meta: &TaskMeta<'_>) { - if let Some(f) = self.task_spawn_callback.as_ref() { - f(meta) - } - } - - #[allow(dead_code)] - pub(crate) fn from_config(config: &Config) -> Self { - Self { - task_spawn_callback: config.before_spawn.clone(), - task_terminate_callback: config.after_termination.clone(), - #[cfg(tokio_unstable)] - before_poll_callback: config.before_poll.clone(), - #[cfg(tokio_unstable)] - after_poll_callback: config.after_poll.clone(), - } - } - - #[cfg(tokio_unstable)] - #[inline] - pub(crate) fn poll_start_callback(&self, id: super::task::Id) { - if let Some(poll_start) = &self.before_poll_callback { - (poll_start)(&TaskMeta { - id, - _phantom: std::marker::PhantomData, - }) - } - } - - #[cfg(tokio_unstable)] - #[inline] - pub(crate) fn poll_stop_callback(&self, id: super::task::Id) { - if let Some(poll_stop) = &self.after_poll_callback { - (poll_stop)(&TaskMeta { - id, - _phantom: std::marker::PhantomData, - }) - } - } -} - -#[derive(Clone)] -pub(crate) struct TaskHooks { - pub(crate) task_spawn_callback: Option, - pub(crate) task_terminate_callback: Option, - #[cfg(tokio_unstable)] - pub(crate) before_poll_callback: Option, - #[cfg(tokio_unstable)] - pub(crate) after_poll_callback: Option, -} - -/// Task metadata supplied to user-provided hooks for task events. -/// -/// **Note**: This is an [unstable API][unstable]. The public API of this type -/// may break in 1.x releases. See [the documentation on unstable -/// features][unstable] for details. -/// -/// [unstable]: crate#unstable-features -#[allow(missing_debug_implementations)] -#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] -pub struct TaskMeta<'a> { - /// The opaque ID of the task. - pub(crate) id: super::task::Id, - pub(crate) _phantom: PhantomData<&'a ()>, -} - -impl<'a> TaskMeta<'a> { - /// Return the opaque ID of the task. - #[cfg_attr(not(tokio_unstable), allow(unreachable_pub, dead_code))] - pub fn id(&self) -> super::task::Id { - self.id - } -} - -/// Runs on specific task-related events -pub(crate) type TaskCallback = std::sync::Arc) + Send + Sync>; diff --git a/tokio/src/runtime/task_hooks/mod.rs b/tokio/src/runtime/task_hooks/mod.rs new file mode 100644 index 00000000000..51b70cf5a91 --- /dev/null +++ b/tokio/src/runtime/task_hooks/mod.rs @@ -0,0 +1,156 @@ +use super::task; +use crate::loom::cell::UnsafeCell; +use std::marker::PhantomData; +use std::ptr::NonNull; +use std::sync::Arc; + +/// A factory which produces new [`TaskHookHarness`] objects for tasks which either have been +/// spawned in "detached mode" via [`crate::task::spawn_with_hooks`], or which were spawned from outside the runtime or +/// from another context where no [`TaskHookHarness`] was present. +pub trait TaskHookHarnessFactory { + /// Runs a hook which may produce a new [`TaskHookHarness`] object which the runtime will attach to a given task. + fn on_top_level_spawn(&self, ctx: &mut OnTopLevelTaskSpawnContext<'_>) + -> OnTopLevelSpawnAction; +} + +/// Trait for user-provided "harness" objects which are attached to tasks and provide hook +/// implementations. +#[allow(unused_variables)] +pub trait TaskHookHarness { + /// Pre-poll task hook which runs arbitrary user logic. + fn before_poll(&mut self, ctx: &mut BeforeTaskPollContext<'_>) -> BeforeTaskPollAction { + BeforeTaskPollAction::default() + } + + /// Post-poll task hook which runs arbitrary user logic. + fn after_poll(&mut self, ctx: &mut AfterTaskPollContext<'_>) -> AfterTaskPollAction { + AfterTaskPollAction::default() + } + + /// Task hook which runs when this task spawns a child, unless that child is explicitly spawned + /// detached from the parent. + /// + /// This hook creates a harness for the child, or detaches the child from any instrumentation. + fn on_child_spawn(&mut self, ctx: &mut OnChildTaskSpawnContext<'_>) -> OnChildSpawnAction { + OnChildSpawnAction::default() + } + + /// Task hook which runs on task termination. + fn on_task_terminate(&mut self, ctx: &mut OnTaskTerminateContext<'_>) -> OnTaskTerminateAction { + OnTaskTerminateAction::default() + } +} + +pub(crate) type OptionalTaskHooksFactory = + Option>; +pub(crate) type OptionalTaskHooks = Option>; + +pub(crate) type OptionalTaskHooksWeak = + UnsafeCell>>; + +pub(crate) type OptionalTaskHooksMut<'a> = + Option<&'a mut (dyn TaskHookHarness + Send + Sync + 'static)>; +pub(crate) type OptionalTaskHooksFactoryRef<'a> = + Option<&'a (dyn TaskHookHarnessFactory + Send + Sync + 'static)>; + +#[allow(missing_debug_implementations, missing_docs)] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +pub struct OnTopLevelTaskSpawnContext<'a> { + pub(crate) id: task::Id, + pub(crate) _phantom: PhantomData<&'a ()>, +} + +impl<'a> OnTopLevelTaskSpawnContext<'a> { + /// Returns the ID of the task. + pub fn id(&self) -> task::Id { + self.id + } +} + +#[allow(missing_debug_implementations, missing_docs)] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +pub struct OnChildTaskSpawnContext<'a> { + pub(crate) id: task::Id, + pub(crate) _phantom: PhantomData<&'a ()>, +} + +impl<'a> OnChildTaskSpawnContext<'a> { + /// Returns the ID of the task. + pub fn id(&self) -> task::Id { + self.id + } +} + +#[allow(missing_debug_implementations, missing_docs)] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +pub struct OnTaskTerminateContext<'a> { + pub(crate) _phantom: PhantomData<&'a ()>, +} + +#[allow(missing_debug_implementations, missing_docs)] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +pub struct BeforeTaskPollContext<'a> { + pub(crate) _phantom: PhantomData<&'a ()>, +} + +#[allow(missing_debug_implementations, missing_docs)] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +pub struct AfterTaskPollContext<'a> { + pub(crate) _phantom: PhantomData<&'a ()>, +} + +#[derive(Default)] +#[allow(missing_debug_implementations, missing_docs)] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +#[non_exhaustive] +pub struct OnTopLevelSpawnAction { + pub(crate) hooks: Option>, +} + +impl OnTopLevelSpawnAction { + /// Pass in a set of task hooks for the task. + pub fn set_hooks(&mut self, hooks: T) -> &mut Self + where + T: TaskHookHarness + Send + Sync + 'static, + { + self.hooks = Some(Box::new(hooks)); + self + } +} + +#[derive(Default)] +#[allow(missing_debug_implementations, missing_docs)] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +#[non_exhaustive] +pub struct OnChildSpawnAction { + pub(crate) hooks: Option>, +} + +impl OnChildSpawnAction { + /// Pass in a set of task hooks for the child task. + pub fn set_hooks(&mut self, hooks: T) -> &mut Self + where + T: TaskHookHarness + Send + Sync + 'static, + { + self.hooks = Some(Box::new(hooks)); + self + } +} + +#[derive(Default)] +#[allow(missing_debug_implementations, missing_docs)] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +#[non_exhaustive] +pub struct OnTaskTerminateAction {} + +#[derive(Default)] +#[allow(missing_debug_implementations, missing_docs)] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +#[non_exhaustive] +pub struct BeforeTaskPollAction {} + +#[derive(Default)] +#[allow(missing_debug_implementations, missing_docs)] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +#[non_exhaustive] +pub struct AfterTaskPollAction {} diff --git a/tokio/src/runtime/tests/mod.rs b/tokio/src/runtime/tests/mod.rs index 6fcf8a2ec09..6115c6c429b 100644 --- a/tokio/src/runtime/tests/mod.rs +++ b/tokio/src/runtime/tests/mod.rs @@ -6,7 +6,9 @@ use self::noop_scheduler::NoopSchedule; use self::unowned_wrapper::unowned; mod noop_scheduler { - use crate::runtime::task::{self, Task, TaskHarnessScheduleHooks}; + use crate::runtime::task::{self, Task}; + #[cfg(tokio_unstable)] + use crate::runtime::{OptionalTaskHooksFactory, OptionalTaskHooksFactoryRef}; /// `task::Schedule` implementation that does nothing, for testing. pub(crate) struct NoopSchedule; @@ -20,10 +22,14 @@ mod noop_scheduler { unreachable!(); } - fn hooks(&self) -> TaskHarnessScheduleHooks { - TaskHarnessScheduleHooks { - task_terminate_callback: None, - } + #[cfg(tokio_unstable)] + fn hooks_factory(&self) -> OptionalTaskHooksFactory { + None + } + + #[cfg(tokio_unstable)] + fn hooks_factory_ref(&self) -> OptionalTaskHooksFactoryRef<'_> { + None } } } @@ -41,6 +47,9 @@ mod unowned_wrapper { use tracing::Instrument; let span = tracing::trace_span!("test_span"); let task = task.instrument(span); + #[cfg(tokio_unstable)] + let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule, Id::next(), None); + #[cfg(not(tokio_unstable))] let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule, Id::next()); (task.into_notified(), handle) } @@ -51,6 +60,9 @@ mod unowned_wrapper { T: std::future::Future + Send + 'static, T::Output: Send + 'static, { + #[cfg(tokio_unstable)] + let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule, Id::next(), None); + #[cfg(not(tokio_unstable))] let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule, Id::next()); (task.into_notified(), handle) } diff --git a/tokio/src/runtime/tests/queue.rs b/tokio/src/runtime/tests/queue.rs index 9047f4ad7af..b44e8992fd9 100644 --- a/tokio/src/runtime/tests/queue.rs +++ b/tokio/src/runtime/tests/queue.rs @@ -1,5 +1,4 @@ use crate::runtime::scheduler::multi_thread::{queue, Stats}; - use std::cell::RefCell; use std::thread; use std::time::Duration; diff --git a/tokio/src/runtime/tests/task.rs b/tokio/src/runtime/tests/task.rs index ea48b8e5199..4cf0de69cf0 100644 --- a/tokio/src/runtime/tests/task.rs +++ b/tokio/src/runtime/tests/task.rs @@ -1,8 +1,7 @@ -use crate::runtime::task::{ - self, unowned, Id, JoinHandle, OwnedTasks, Schedule, Task, TaskHarnessScheduleHooks, -}; +use crate::runtime::task::{self, unowned, Id, JoinHandle, OwnedTasks, Schedule, Task}; use crate::runtime::tests::NoopSchedule; - +#[cfg(tokio_unstable)] +use crate::runtime::{OptionalTaskHooksFactory, OptionalTaskHooksFactoryRef}; use std::collections::VecDeque; use std::future::Future; use std::sync::atomic::{AtomicBool, Ordering}; @@ -447,9 +446,13 @@ impl Schedule for Runtime { self.0.core.try_lock().unwrap().queue.push_back(task); } - fn hooks(&self) -> TaskHarnessScheduleHooks { - TaskHarnessScheduleHooks { - task_terminate_callback: None, - } + #[cfg(tokio_unstable)] + fn hooks_factory(&self) -> OptionalTaskHooksFactory { + None + } + + #[cfg(tokio_unstable)] + fn hooks_factory_ref(&self) -> OptionalTaskHooksFactoryRef<'_> { + None } } diff --git a/tokio/src/task/builder.rs b/tokio/src/task/builder.rs index 6053352a01c..c34a2ca462e 100644 --- a/tokio/src/task/builder.rs +++ b/tokio/src/task/builder.rs @@ -44,8 +44,12 @@ use std::{future::Future, io, mem}; /// loop { /// let (socket, _) = listener.accept().await?; /// -/// tokio::task::Builder::new() -/// .name("tcp connection handler") +/// let mut builder = tokio::task::Builder::new(); +/// +/// builder +/// .name("tcp connection handler"); +/// +/// builder /// .spawn(async move { /// // Process each socket concurrently. /// process(socket).await @@ -71,8 +75,9 @@ impl<'a> Builder<'a> { } /// Assigns a name to the task which will be spawned. - pub fn name(&self, name: &'a str) -> Self { - Self { name: Some(name) } + pub fn name(&mut self, name: &'a str) -> &mut Self { + self.name = Some(name); + self } /// Spawns a task with this builder's settings on the current runtime. @@ -91,9 +96,9 @@ impl<'a> Builder<'a> { { let fut_size = mem::size_of::(); Ok(if fut_size > BOX_FUTURE_THRESHOLD { - super::spawn::spawn_inner(Box::pin(future), SpawnMeta::new(self.name, fut_size)) + super::spawn::spawn_inner(Box::pin(future), SpawnMeta::new(self.name, fut_size), None) } else { - super::spawn::spawn_inner(future, SpawnMeta::new(self.name, fut_size)) + super::spawn::spawn_inner(future, SpawnMeta::new(self.name, fut_size), None) }) } @@ -112,9 +117,9 @@ impl<'a> Builder<'a> { { let fut_size = mem::size_of::(); Ok(if fut_size > BOX_FUTURE_THRESHOLD { - handle.spawn_named(Box::pin(future), SpawnMeta::new(self.name, fut_size)) + handle.spawn_named(Box::pin(future), SpawnMeta::new(self.name, fut_size), None) } else { - handle.spawn_named(future, SpawnMeta::new(self.name, fut_size)) + handle.spawn_named(future, SpawnMeta::new(self.name, fut_size), None) }) } @@ -140,9 +145,13 @@ impl<'a> Builder<'a> { { let fut_size = mem::size_of::(); Ok(if fut_size > BOX_FUTURE_THRESHOLD { - super::local::spawn_local_inner(Box::pin(future), SpawnMeta::new(self.name, fut_size)) + super::local::spawn_local_inner( + Box::pin(future), + SpawnMeta::new(self.name, fut_size), + None, + ) } else { - super::local::spawn_local_inner(future, SpawnMeta::new(self.name, fut_size)) + super::local::spawn_local_inner(future, SpawnMeta::new(self.name, fut_size), None) }) } diff --git a/tokio/src/task/join_set.rs b/tokio/src/task/join_set.rs index a156719a067..2c43b9c989d 100644 --- a/tokio/src/task/join_set.rs +++ b/tokio/src/task/join_set.rs @@ -641,9 +641,13 @@ where #[cfg_attr(docsrs, doc(cfg(all(tokio_unstable, feature = "tracing"))))] impl<'a, T: 'static> Builder<'a, T> { /// Assigns a name to the task which will be spawned. - pub fn name(self, name: &'a str) -> Self { - let builder = self.builder.name(name); - Self { builder, ..self } + pub fn name(mut self, name: &'a str) -> Self { + self.builder.name(name); + + Self { + builder: self.builder, + ..self + } } /// Spawn the provided task with this builder's settings and store it in the diff --git a/tokio/src/task/local.rs b/tokio/src/task/local.rs index 95bd6404bec..f100e6dc2ca 100644 --- a/tokio/src/task/local.rs +++ b/tokio/src/task/local.rs @@ -1,9 +1,11 @@ //! Runs `!Send` futures on the current thread. use crate::loom::cell::UnsafeCell; use crate::loom::sync::{Arc, Mutex}; +use crate::runtime::task::{self, JoinHandle, LocalOwnedTasks, Task}; #[cfg(tokio_unstable)] -use crate::runtime; -use crate::runtime::task::{self, JoinHandle, LocalOwnedTasks, Task, TaskHarnessScheduleHooks}; +use crate::runtime::{ + self, OptionalTaskHooks, OptionalTaskHooksFactory, OptionalTaskHooksFactoryRef, +}; use crate::runtime::{context, ThreadId, BOX_FUTURE_THRESHOLD}; use crate::sync::AtomicWaker; use crate::util::trace::SpawnMeta; @@ -371,6 +373,13 @@ cfg_rt! { F::Output: 'static, { let fut_size = std::mem::size_of::(); + #[cfg(tokio_unstable)] + if fut_size > BOX_FUTURE_THRESHOLD { + spawn_local_inner(Box::pin(future), SpawnMeta::new_unnamed(fut_size), None) + } else { + spawn_local_inner(future, SpawnMeta::new_unnamed(fut_size), None) + } + #[cfg(not(tokio_unstable))] if fut_size > BOX_FUTURE_THRESHOLD { spawn_local_inner(Box::pin(future), SpawnMeta::new_unnamed(fut_size)) } else { @@ -380,7 +389,7 @@ cfg_rt! { #[track_caller] - pub(super) fn spawn_local_inner(future: F, meta: SpawnMeta<'_>) -> JoinHandle + pub(super) fn spawn_local_inner(future: F, meta: SpawnMeta<'_>, #[cfg(tokio_unstable)] hooks_override: OptionalTaskHooks) -> JoinHandle where F: Future + 'static, F::Output: 'static { @@ -412,6 +421,9 @@ cfg_rt! { let task = crate::util::trace::task(future, "task", meta, id.as_u64()); // safety: we have verified that this is a `LocalRuntime` owned by the current thread + #[cfg(tokio_unstable)] + unsafe { handle.spawn_local(task, id, hooks_override) } + #[cfg(not(tokio_unstable))] unsafe { handle.spawn_local(task, id) } } else { match CURRENT.with(|LocalData { ctx, .. }| ctx.get()) { @@ -1004,6 +1016,15 @@ impl Context { let future = crate::util::trace::task(future, "local", meta, id.as_u64()); // Safety: called from the thread that owns the `LocalSet` + #[cfg(tokio_unstable)] + let (handle, notified) = { + self.shared.local_state.assert_called_from_owner_thread(); + self.shared + .local_state + .owned + .bind(future, self.shared.clone(), id, None) + }; + #[cfg(not(tokio_unstable))] let (handle, notified) = { self.shared.local_state.assert_called_from_owner_thread(); self.shared @@ -1117,11 +1138,15 @@ impl task::Schedule for Arc { Shared::schedule(self, task); } - // localset does not currently support task hooks - fn hooks(&self) -> TaskHarnessScheduleHooks { - TaskHarnessScheduleHooks { - task_terminate_callback: None, - } + #[cfg(tokio_unstable)] + fn hooks_factory(&self) -> OptionalTaskHooksFactory { + None + } + + // localset does not support task hooks + #[cfg(tokio_unstable)] + fn hooks_factory_ref(&self) -> OptionalTaskHooksFactoryRef<'_> { + None } cfg_unstable! { diff --git a/tokio/src/task/mod.rs b/tokio/src/task/mod.rs index f0c6f71c15a..601793f2f99 100644 --- a/tokio/src/task/mod.rs +++ b/tokio/src/task/mod.rs @@ -311,6 +311,10 @@ cfg_rt! { pub use crate::runtime::task::{Id, id, try_id}; + cfg_unstable! { + pub use spawn::spawn_with_hooks; + } + cfg_trace! { mod builder; pub use builder::Builder; diff --git a/tokio/src/task/spawn.rs b/tokio/src/task/spawn.rs index 7c748226121..ad1fc6eb19c 100644 --- a/tokio/src/task/spawn.rs +++ b/tokio/src/task/spawn.rs @@ -1,4 +1,6 @@ use crate::runtime::BOX_FUTURE_THRESHOLD; +#[cfg(tokio_unstable)] +use crate::runtime::{OptionalTaskHooks, TaskHookHarness}; use crate::task::JoinHandle; use crate::util::trace::SpawnMeta; @@ -169,6 +171,13 @@ cfg_rt! { F::Output: Send + 'static, { let fut_size = std::mem::size_of::(); + #[cfg(tokio_unstable)] + if fut_size > BOX_FUTURE_THRESHOLD { + spawn_inner(Box::pin(future), SpawnMeta::new_unnamed(fut_size), None) + } else { + spawn_inner(future, SpawnMeta::new_unnamed(fut_size), None) + } + #[cfg(not(tokio_unstable))] if fut_size > BOX_FUTURE_THRESHOLD { spawn_inner(Box::pin(future), SpawnMeta::new_unnamed(fut_size)) } else { @@ -176,8 +185,26 @@ cfg_rt! { } } + /// Spawn a future with a custom set of task hooks + #[track_caller] + #[cfg(tokio_unstable)] + pub fn spawn_with_hooks(future: F, hooks: T) -> JoinHandle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + T: TaskHookHarness + Send + Sync + 'static, + { + let fut_size = std::mem::size_of::(); + + if fut_size > BOX_FUTURE_THRESHOLD { + spawn_inner(Box::pin(future), SpawnMeta::new_unnamed(fut_size), Some(Box::new(hooks))) + } else { + spawn_inner(future, SpawnMeta::new_unnamed(fut_size), Some(Box::new(hooks))) + } + } + #[track_caller] - pub(super) fn spawn_inner(future: T, meta: SpawnMeta<'_>) -> JoinHandle + pub(super) fn spawn_inner(future: T, meta: SpawnMeta<'_>, #[cfg(tokio_unstable)] hooks_override: OptionalTaskHooks) -> JoinHandle where T: Future + Send + 'static, T::Output: Send + 'static, @@ -199,6 +226,13 @@ cfg_rt! { let id = task::Id::next(); let task = crate::util::trace::task(future, "task", meta, id.as_u64()); + #[cfg(tokio_unstable)] + return match context::with_current(|handle| handle.spawn(task, id, hooks_override)) { + Ok(join_handle) => join_handle, + Err(e) => panic!("{}", e), + }; + + #[cfg(not(tokio_unstable))] match context::with_current(|handle| handle.spawn(task, id)) { Ok(join_handle) => join_handle, Err(e) => panic!("{}", e), diff --git a/tokio/tests/rt_poll_callbacks.rs b/tokio/tests/rt_poll_callbacks.rs deleted file mode 100644 index 8ccff385772..00000000000 --- a/tokio/tests/rt_poll_callbacks.rs +++ /dev/null @@ -1,128 +0,0 @@ -#![allow(unknown_lints, unexpected_cfgs)] -#![cfg(tokio_unstable)] - -use std::sync::{atomic::AtomicUsize, Arc, Mutex}; - -use tokio::task::yield_now; - -#[cfg(not(target_os = "wasi"))] -#[test] -fn callbacks_fire_multi_thread() { - let poll_start_counter = Arc::new(AtomicUsize::new(0)); - let poll_stop_counter = Arc::new(AtomicUsize::new(0)); - let poll_start = poll_start_counter.clone(); - let poll_stop = poll_stop_counter.clone(); - - let before_task_poll_callback_task_id: Arc>> = - Arc::new(Mutex::new(None)); - let after_task_poll_callback_task_id: Arc>> = - Arc::new(Mutex::new(None)); - - let before_task_poll_callback_task_id_ref = Arc::clone(&before_task_poll_callback_task_id); - let after_task_poll_callback_task_id_ref = Arc::clone(&after_task_poll_callback_task_id); - let rt = tokio::runtime::Builder::new_multi_thread() - .enable_all() - .on_before_task_poll(move |task_meta| { - before_task_poll_callback_task_id_ref - .lock() - .unwrap() - .replace(task_meta.id()); - poll_start_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - }) - .on_after_task_poll(move |task_meta| { - after_task_poll_callback_task_id_ref - .lock() - .unwrap() - .replace(task_meta.id()); - poll_stop_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - }) - .build() - .unwrap(); - let task = rt.spawn(async { - yield_now().await; - yield_now().await; - yield_now().await; - }); - - let spawned_task_id = task.id(); - - rt.block_on(task).expect("task should succeed"); - // We need to drop the runtime to guarantee the workers have exited (and thus called the callback) - drop(rt); - - assert_eq!( - before_task_poll_callback_task_id.lock().unwrap().unwrap(), - spawned_task_id - ); - assert_eq!( - after_task_poll_callback_task_id.lock().unwrap().unwrap(), - spawned_task_id - ); - let actual_count = 4; - assert_eq!( - poll_start.load(std::sync::atomic::Ordering::Relaxed), - actual_count, - "unexpected number of poll starts" - ); - assert_eq!( - poll_stop.load(std::sync::atomic::Ordering::Relaxed), - actual_count, - "unexpected number of poll stops" - ); -} - -#[test] -fn callbacks_fire_current_thread() { - let poll_start_counter = Arc::new(AtomicUsize::new(0)); - let poll_stop_counter = Arc::new(AtomicUsize::new(0)); - let poll_start = poll_start_counter.clone(); - let poll_stop = poll_stop_counter.clone(); - - let before_task_poll_callback_task_id: Arc>> = - Arc::new(Mutex::new(None)); - let after_task_poll_callback_task_id: Arc>> = - Arc::new(Mutex::new(None)); - - let before_task_poll_callback_task_id_ref = Arc::clone(&before_task_poll_callback_task_id); - let after_task_poll_callback_task_id_ref = Arc::clone(&after_task_poll_callback_task_id); - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .on_before_task_poll(move |task_meta| { - before_task_poll_callback_task_id_ref - .lock() - .unwrap() - .replace(task_meta.id()); - poll_start_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - }) - .on_after_task_poll(move |task_meta| { - after_task_poll_callback_task_id_ref - .lock() - .unwrap() - .replace(task_meta.id()); - poll_stop_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - }) - .build() - .unwrap(); - - let task = rt.spawn(async { - yield_now().await; - yield_now().await; - yield_now().await; - }); - - let spawned_task_id = task.id(); - - let _ = rt.block_on(task); - drop(rt); - - assert_eq!( - before_task_poll_callback_task_id.lock().unwrap().unwrap(), - spawned_task_id - ); - assert_eq!( - after_task_poll_callback_task_id.lock().unwrap().unwrap(), - spawned_task_id - ); - assert_eq!(poll_start.load(std::sync::atomic::Ordering::Relaxed), 4); - assert_eq!(poll_stop.load(std::sync::atomic::Ordering::Relaxed), 4); -} diff --git a/tokio/tests/task_builder.rs b/tokio/tests/task_builder.rs index c700f229f9f..63cd9d925f1 100644 --- a/tokio/tests/task_builder.rs +++ b/tokio/tests/task_builder.rs @@ -8,22 +8,22 @@ use tokio::{ #[test] async fn spawn_with_name() { - let result = Builder::new() - .name("name") - .spawn(async { "task executed" }) - .unwrap() - .await; + let mut b = Builder::new(); + + b.name("name"); + + let result = b.spawn(async { "task executed" }).unwrap().await; assert_eq!(result.unwrap(), "task executed"); } #[test] async fn spawn_blocking_with_name() { - let result = Builder::new() - .name("name") - .spawn_blocking(|| "task executed") - .unwrap() - .await; + let mut b = Builder::new(); + + b.name("name"); + + let result = b.spawn_blocking(|| "task executed").unwrap().await; assert_eq!(result.unwrap(), "task executed"); } @@ -33,11 +33,11 @@ async fn spawn_local_with_name() { let unsend_data = Rc::new("task executed"); let result = LocalSet::new() .run_until(async move { - Builder::new() - .name("name") - .spawn_local(async move { unsend_data }) - .unwrap() - .await + let mut b = Builder::new(); + + b.name("name"); + + b.spawn_local(async move { unsend_data }).unwrap().await }) .await; diff --git a/tokio/tests/task_hooks.rs b/tokio/tests/task_hooks.rs index 185b9126cca..e2127388e58 100644 --- a/tokio/tests/task_hooks.rs +++ b/tokio/tests/task_hooks.rs @@ -1,75 +1,471 @@ -#![warn(rust_2018_idioms)] -#![cfg(all(feature = "full", tokio_unstable, target_has_atomic = "64"))] +#![cfg(all( + feature = "full", + tokio_unstable, + target_has_atomic = "64", + not(target_arch = "wasm32") +))] -use std::collections::HashSet; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; +use tokio::runtime; +use tokio::runtime::{ + AfterTaskPollAction, AfterTaskPollContext, BeforeTaskPollAction, BeforeTaskPollContext, + OnChildSpawnAction, OnChildTaskSpawnContext, OnTaskTerminateAction, OnTaskTerminateContext, + OnTopLevelSpawnAction, OnTopLevelTaskSpawnContext, TaskHookHarness, TaskHookHarnessFactory, +}; -use tokio::runtime::Builder; +#[test] +fn runtime_default_factory() { + let ct = runtime::Builder::new_current_thread(); + let mt = runtime::Builder::new_multi_thread(); + + run_runtime_default_factory(ct); + run_runtime_default_factory(mt); +} + +#[test] +fn parent_child_chaining() { + let ct = runtime::Builder::new_current_thread(); + let mt = runtime::Builder::new_multi_thread(); + + run_parent_child_chaining(ct); + run_parent_child_chaining(mt); +} + +#[test] +fn before_poll() { + let ct = runtime::Builder::new_current_thread(); + let mt = runtime::Builder::new_multi_thread(); + + run_before_poll(ct); + run_before_poll(mt); +} + +#[test] +fn after_poll() { + let ct = runtime::Builder::new_current_thread(); + let mt = runtime::Builder::new_multi_thread(); + + run_after_poll(ct); + run_after_poll(mt); +} + +#[test] +fn terminate() { + let ct = runtime::Builder::new_current_thread(); + + run_terminate(ct); +} + +#[test] +fn hook_switching() { + let ct = runtime::Builder::new_current_thread(); + let mt = runtime::Builder::new_multi_thread(); + + run_hook_switching(ct); + run_hook_switching(mt); +} -const TASKS: usize = 8; -const ITERATIONS: usize = 64; -/// Assert that the spawn task hook always fires when set. #[test] -fn spawn_task_hook_fires() { - let count = Arc::new(AtomicUsize::new(0)); - let count2 = Arc::clone(&count); +fn override_hooks() { + let ct = runtime::Builder::new_current_thread(); + let mt = runtime::Builder::new_multi_thread(); + + run_override(ct); + run_override(mt); +} - let ids = Arc::new(Mutex::new(HashSet::new())); - let ids2 = Arc::clone(&ids); +fn run_runtime_default_factory(mut builder: runtime::Builder) { + struct TestFactory { + counter: Arc, + } - let runtime = Builder::new_current_thread() - .on_task_spawn(move |data| { - ids2.lock().unwrap().insert(data.id()); + impl TaskHookHarnessFactory for TestFactory { + fn on_top_level_spawn( + &self, + _ctx: &mut OnTopLevelTaskSpawnContext<'_>, + ) -> OnTopLevelSpawnAction { + self.counter.fetch_add(1, Ordering::SeqCst); - count2.fetch_add(1, Ordering::SeqCst); + Default::default() + } + } + + let counter = Arc::new(AtomicUsize::new(0)); + + let rt = builder + .hook_harness_factory(TestFactory { + counter: counter.clone(), }) .build() .unwrap(); - for _ in 0..TASKS { - runtime.spawn(std::future::pending::<()>()); + rt.spawn(async {}); + + assert_eq!(counter.load(Ordering::SeqCst), 1); + + let handle = rt.handle(); + + handle.spawn(async {}); + + assert_eq!(counter.load(Ordering::SeqCst), 2); + + rt.block_on(async {}); + + assert_eq!(counter.load(Ordering::SeqCst), 2); + + rt.block_on(async { tokio::spawn(async {}) }); + + assert_eq!(counter.load(Ordering::SeqCst), 3); + + // block on a future which spawns a future and waits for it, which in turn spawns another future + // + // this checks that stuff works from on-worker within a multithreaded runtime + let _ = rt.block_on(async { tokio::spawn(async { tokio::spawn(async {}) }).await }); + + assert_eq!(counter.load(Ordering::SeqCst), 5); +} + +fn run_parent_child_chaining(mut builder: runtime::Builder) { + struct TestFactory { + parent_spawns: Arc, + child_spawns: Arc, } - let count_realized = count.load(Ordering::SeqCst); - assert_eq!( - TASKS, count_realized, - "Total number of spawned task hook invocations was incorrect, expected {TASKS}, got {}", - count_realized - ); + struct TestHooks { + spawns: Arc, + } + + impl TaskHookHarnessFactory for TestFactory { + fn on_top_level_spawn( + &self, + _ctx: &mut OnTopLevelTaskSpawnContext<'_>, + ) -> OnTopLevelSpawnAction { + self.parent_spawns.fetch_add(1, Ordering::SeqCst); + + let mut a = OnTopLevelSpawnAction::default(); - let count_ids_realized = ids.lock().unwrap().len(); + a.set_hooks(TestHooks { + spawns: self.child_spawns.clone(), + }); + + a + } + } - assert_eq!( - TASKS, count_ids_realized, - "Total number of spawned task hook invocations was incorrect, expected {TASKS}, got {}", - count_realized - ); + impl TaskHookHarness for TestHooks { + fn on_child_spawn(&mut self, _ctx: &mut OnChildTaskSpawnContext<'_>) -> OnChildSpawnAction { + self.spawns.fetch_add(1, Ordering::SeqCst); + + let mut a = OnChildSpawnAction::default(); + + a.set_hooks(Self { + spawns: self.spawns.clone(), + }); + + a + } + } + + let parent_spawns = Arc::new(AtomicUsize::new(0)); + let child_spawns = Arc::new(AtomicUsize::new(0)); + + let rt = builder + .hook_harness_factory(TestFactory { + parent_spawns: parent_spawns.clone(), + child_spawns: child_spawns.clone(), + }) + .build() + .unwrap(); + + rt.spawn(async {}); + + assert_eq!(parent_spawns.load(Ordering::SeqCst), 1); + assert_eq!(child_spawns.load(Ordering::SeqCst), 0); + + let _ = rt.block_on(async { tokio::spawn(async { tokio::spawn(async {}) }).await }); + + assert_eq!(parent_spawns.load(Ordering::SeqCst), 2); + assert_eq!(child_spawns.load(Ordering::SeqCst), 1); } -/// Assert that the terminate task hook always fires when set. -#[test] -fn terminate_task_hook_fires() { - let count = Arc::new(AtomicUsize::new(0)); - let count2 = Arc::clone(&count); +fn run_before_poll(mut builder: runtime::Builder) { + struct TestFactory { + polls: Arc, + } + + struct TestHooks { + polls: Arc, + } + + impl TaskHookHarnessFactory for TestFactory { + fn on_top_level_spawn( + &self, + _ctx: &mut OnTopLevelTaskSpawnContext<'_>, + ) -> OnTopLevelSpawnAction { + let mut a = OnTopLevelSpawnAction::default(); + + a.set_hooks(TestHooks { + polls: self.polls.clone(), + }); + + a + } + } + + impl TaskHookHarness for TestHooks { + fn before_poll(&mut self, _ctx: &mut BeforeTaskPollContext<'_>) -> BeforeTaskPollAction { + self.polls.fetch_add(1, Ordering::SeqCst); + + Default::default() + } + } + + let polls = Arc::new(AtomicUsize::new(0)); + + let rt = builder + .hook_harness_factory(TestFactory { + polls: polls.clone(), + }) + .build() + .unwrap(); + + rt.block_on(async {}); + assert_eq!(polls.load(Ordering::SeqCst), 0); + + let _ = rt.block_on(async { tokio::spawn(async {}).await }); + assert_eq!(polls.load(Ordering::SeqCst), 1); + + let _ = rt.block_on(async { tokio::spawn(async { tokio::spawn(async {}).await }).await }); + assert_eq!(polls.load(Ordering::SeqCst), 4); +} - let runtime = Builder::new_current_thread() - .on_task_terminate(move |_data| { - count2.fetch_add(1, Ordering::SeqCst); +fn run_after_poll(mut builder: runtime::Builder) { + struct TestFactory { + polls: Arc, + } + + struct TestHooks { + polls: Arc, + } + + impl TaskHookHarnessFactory for TestFactory { + fn on_top_level_spawn( + &self, + _ctx: &mut OnTopLevelTaskSpawnContext<'_>, + ) -> OnTopLevelSpawnAction { + let mut a = OnTopLevelSpawnAction::default(); + + a.set_hooks(TestHooks { + polls: self.polls.clone(), + }); + + a + } + } + + impl TaskHookHarness for TestHooks { + fn after_poll(&mut self, _ctx: &mut AfterTaskPollContext<'_>) -> AfterTaskPollAction { + self.polls.fetch_add(1, Ordering::SeqCst); + + Default::default() + } + } + + let polls = Arc::new(AtomicUsize::new(0)); + + let rt = builder + .hook_harness_factory(TestFactory { + polls: polls.clone(), + }) + .build() + .unwrap(); + + rt.block_on(async {}); + assert_eq!(polls.load(Ordering::SeqCst), 0); + + let _ = rt.block_on(async { tokio::spawn(async {}).await }); + assert_eq!(polls.load(Ordering::SeqCst), 1); + + let _ = rt.block_on(async { tokio::spawn(async { tokio::spawn(async {}).await }).await }); + assert_eq!(polls.load(Ordering::SeqCst), 4); +} + +fn run_terminate(mut builder: runtime::Builder) { + struct TestFactory { + terminations: Arc, + } + + struct TestHooks { + terminations: Arc, + } + + impl TaskHookHarnessFactory for TestFactory { + fn on_top_level_spawn( + &self, + _ctx: &mut OnTopLevelTaskSpawnContext<'_>, + ) -> OnTopLevelSpawnAction { + let mut a = OnTopLevelSpawnAction::default(); + + a.set_hooks(TestHooks { + terminations: self.terminations.clone(), + }); + + a + } + } + + impl TaskHookHarness for TestHooks { + fn on_task_terminate( + &mut self, + _ctx: &mut OnTaskTerminateContext<'_>, + ) -> OnTaskTerminateAction { + self.terminations.fetch_add(1, Ordering::SeqCst); + + Default::default() + } + } + + let terminations = Arc::new(AtomicUsize::new(0)); + + let rt = builder + .hook_harness_factory(TestFactory { + terminations: terminations.clone(), + }) + .build() + .unwrap(); + + let _ = rt.block_on(async { tokio::spawn(async { tokio::spawn(async {}).await }).await }); + + assert_eq!(terminations.load(Ordering::SeqCst), 2); +} + +fn run_hook_switching(mut builder: runtime::Builder) { + struct TestFactory { + next_id: Arc, + flag: Arc, + } + + struct TestHooks { + id: usize, + flag: Arc, + } + + impl TaskHookHarnessFactory for TestFactory { + fn on_top_level_spawn( + &self, + _ctx: &mut OnTopLevelTaskSpawnContext<'_>, + ) -> OnTopLevelSpawnAction { + let mut a = OnTopLevelSpawnAction::default(); + + a.set_hooks(TestHooks { + id: self.next_id.fetch_add(1, Ordering::SeqCst), + flag: self.flag.clone(), + }); + + a + } + } + + impl TaskHookHarness for TestHooks { + fn before_poll(&mut self, _ctx: &mut BeforeTaskPollContext<'_>) -> BeforeTaskPollAction { + self.flag.store(self.id, Ordering::SeqCst); + + Default::default() + } + } + + let polls = Arc::new(AtomicUsize::new(0)); + + let rt = builder + .hook_harness_factory(TestFactory { + next_id: Arc::new(Default::default()), + flag: polls.clone(), }) .build() .unwrap(); - for _ in 0..TASKS { - runtime.spawn(std::future::ready(())); + let _ = rt.block_on(async { tokio::spawn(async {}).await }); + assert_eq!(polls.load(Ordering::SeqCst), 0); + + let _ = rt.block_on(async { tokio::spawn(async { tokio::spawn(async {}).await }).await }); + assert_eq!(polls.load(Ordering::SeqCst), 1); + + let _ = rt.block_on(async { tokio::spawn(async {}).await }); + assert_eq!(polls.load(Ordering::SeqCst), 3); +} + +fn run_override(mut builder: runtime::Builder) { + struct TestFactory { + counter: Arc, + } + + struct TestHooks { + counter: Arc, } - runtime.block_on(async { - // tick the runtime a bunch to close out tasks - for _ in 0..ITERATIONS { - tokio::task::yield_now().await; + impl TaskHookHarness for TestHooks { + fn before_poll(&mut self, _ctx: &mut BeforeTaskPollContext<'_>) -> BeforeTaskPollAction { + self.counter.fetch_add(1, Ordering::SeqCst); + + Default::default() + } + + fn on_child_spawn(&mut self, _ctx: &mut OnChildTaskSpawnContext<'_>) -> OnChildSpawnAction { + let mut a = OnChildSpawnAction::default(); + + a.set_hooks(Self { + counter: self.counter.clone(), + }); + + a } + } + + impl TaskHookHarnessFactory for TestFactory { + fn on_top_level_spawn( + &self, + _ctx: &mut OnTopLevelTaskSpawnContext<'_>, + ) -> OnTopLevelSpawnAction { + self.counter.fetch_add(1, Ordering::SeqCst); + + Default::default() + } + } + + let factory_counter = Arc::new(AtomicUsize::new(0)); + let builder_counter = Arc::new(AtomicUsize::new(0)); + + let rt = builder + .hook_harness_factory(TestFactory { + counter: factory_counter.clone(), + }) + .build() + .unwrap(); + + rt.spawn(async {}); + + assert_eq!(factory_counter.load(Ordering::SeqCst), 1); + + let _ = rt.block_on(async { + tokio::task::spawn_with_hooks( + async {}, + TestHooks { + counter: builder_counter.clone(), + }, + ) + .await + }); + + assert_eq!(factory_counter.load(Ordering::SeqCst), 1); + assert_eq!(builder_counter.load(Ordering::SeqCst), 1); + + let _ = rt.block_on(async { + let counter = builder_counter.clone(); + tokio::spawn(async { tokio::task::spawn_with_hooks(async {}, TestHooks { counter }).await }) + .await }); - assert_eq!(TASKS, count.load(Ordering::SeqCst)); + assert_eq!(factory_counter.load(Ordering::SeqCst), 2); + assert_eq!(builder_counter.load(Ordering::SeqCst), 2); } diff --git a/tokio/tests/tracing_task.rs b/tokio/tests/tracing_task.rs index a9317bf5b12..f2adf573a9d 100644 --- a/tokio/tests/tracing_task.rs +++ b/tokio/tests/tracing_task.rs @@ -64,9 +64,11 @@ async fn task_builder_name_recorded() { { let _guard = tracing::subscriber::set_default(subscriber); - task::Builder::new() - .name("test-task") - .spawn(futures::future::ready(())) + let mut b = task::Builder::new(); + + b.name("test-task"); + + b.spawn(futures::future::ready(())) .unwrap() .await .expect("failed to await join handle");