diff --git a/src/thread.rs b/src/thread.rs index 976802c..519262a 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -14,11 +14,7 @@ use std::{fmt, io}; use tracing::trace; /// Mock implementation of `std::thread::JoinHandle`. -pub struct JoinHandle { - result: Arc>>>, - notify: rt::Notify, - thread: Thread, -} +pub struct JoinHandle(JoinHandleInner<'static, T>); /// Mock implementation of `std::thread::Thread`. #[derive(Clone, Debug)] @@ -128,7 +124,7 @@ where F: 'static, T: 'static, { - spawn_internal(f, None, location!()) + JoinHandle(spawn_internal_static(f, None, location!())) } /// Mock implementation of `std::thread::park`. @@ -142,38 +138,6 @@ pub fn park() { rt::park(location!()); } -fn spawn_internal(f: F, name: Option, location: Location) -> JoinHandle -where - F: FnOnce() -> T, - F: 'static, - T: 'static, -{ - let result = Arc::new(Mutex::new(None)); - let notify = rt::Notify::new(true, false); - - let id = { - let name = name.clone(); - let result = result.clone(); - rt::spawn(move || { - rt::execution(|execution| { - init_current(execution, name); - }); - - *result.lock().unwrap() = Some(Ok(f())); - notify.notify(location); - }) - }; - - JoinHandle { - result, - notify, - thread: Thread { - id: ThreadId { id }, - name, - }, - } -} - impl Builder { /// Generates the base configuration for spawning a thread, from which /// configuration methods can be chained. @@ -206,7 +170,27 @@ impl Builder { F: Send + 'static, T: Send + 'static, { - Ok(spawn_internal(f, self.name, location!())) + Ok(JoinHandle(spawn_internal_static(f, self.name, location!()))) + } +} + +impl Builder { + /// Spawns a new scoped thread using the settings set through this `Builder`. + pub fn spawn_scoped<'scope, 'env, F, T>( + self, + scope: &'scope Scope<'scope, 'env>, + f: F, + ) -> io::Result> + where + F: FnOnce() -> T + Send + 'scope, + T: Send + 'scope, + { + Ok(ScopedJoinHandle( + // Safety: the call to this function requires a `&'scope Scope` + // which can only be constructed by `scope()`, which ensures that + // all spawned threads are joined before the `Scope` is destroyed. + unsafe { spawn_internal(f, self.name, Some(scope.data.clone()), location!()) }, + )) } } @@ -214,13 +198,12 @@ impl JoinHandle { /// Waits for the associated thread to finish. #[track_caller] pub fn join(self) -> std::thread::Result { - self.notify.wait(location!()); - self.result.lock().unwrap().take().unwrap() + self.0.join() } /// Gets a handle to the underlying [`Thread`] pub fn thread(&self) -> &Thread { - &self.thread + self.0.thread() } } @@ -301,3 +284,220 @@ impl fmt::Debug for LocalKey { f.pad("LocalKey { .. }") } } + +/// A scope for spawning scoped threads. +/// +/// See [`scope`] for more details. +#[derive(Debug)] +pub struct Scope<'scope, 'env: 'scope> { + data: Arc, + scope: PhantomData<&'scope mut &'scope ()>, + env: PhantomData<&'env mut &'env ()>, +} + +/// An owned permission to join on a scoped thread (block on its termination). +/// +/// See [`Scope::spawn`] for details. +#[derive(Debug)] +pub struct ScopedJoinHandle<'scope, T>(JoinHandleInner<'scope, T>); + +/// Create a scope for spawning scoped threads. +/// +/// Mock implementation of [`std::thread::scope`]. +#[track_caller] +pub fn scope<'env, F, T>(f: F) -> T +where + F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> T, +{ + let scope = Scope { + data: Arc::new(ScopeData { + running_threads: Mutex::default(), + main_thread: current(), + }), + env: PhantomData, + scope: PhantomData, + }; + + // Run `f`, but catch panics so we can make sure to wait for all the threads to join. + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(&scope))); + + // Wait until all the threads are finished. This is required to fulfill + // the safety requirements of `spawn_internal`. + let running = loop { + { + let running = scope.data.running_threads.lock().unwrap(); + if running.count == 0 { + break running; + } + } + park(); + }; + + for notify in &running.notify_on_finished { + notify.wait(location!()) + } + + // Throw any panic from `f`, or the return value of `f` if no thread panicked. + match result { + Err(e) => std::panic::resume_unwind(e), + Ok(result) => result, + } +} + +impl<'scope, 'env> Scope<'scope, 'env> { + /// Spawns a new thread within a scope, returning a [`ScopedJoinHandle`] for it. + /// + /// See [`std::thread::Scope`] and [`std::thread::scope`] for details. + pub fn spawn(&'scope self, f: F) -> ScopedJoinHandle<'scope, T> + where + F: FnOnce() -> T + Send + 'scope, + T: Send + 'scope, + { + Builder::new() + .spawn_scoped(self, f) + .expect("failed to spawn thread") + } +} + +impl<'scope, T> ScopedJoinHandle<'scope, T> { + /// Extracts a handle to the underlying thread. + pub fn thread(&self) -> &Thread { + self.0.thread() + } + + /// Waits for the associated thread to finish. + pub fn join(self) -> std::thread::Result { + self.0.join() + } +} + +/// Handle for joining on a thread with a scope. +#[derive(Debug)] +struct JoinHandleInner<'scope, T> { + data: Arc>, + notify: rt::Notify, + thread: Thread, +} + +/// Spawns a thread without a local scope. +fn spawn_internal_static( + f: F, + name: Option, + location: Location, +) -> JoinHandleInner<'static, T> +where + F: FnOnce() -> T, + F: 'static, + T: 'static, +{ + // Safety: the requirements of `spawn_internal` are trivially satisfied + // since there is no `scope`. + unsafe { spawn_internal(f, name, None, location) } +} + +/// Spawns a thread with an optional scope. +/// +/// The caller must ensure that if `scope` is not None, the provided closure +/// finishes before `'scope` ends. +unsafe fn spawn_internal<'scope, F, T>( + f: F, + name: Option, + scope: Option>, + location: Location, +) -> JoinHandleInner<'scope, T> +where + F: FnOnce() -> T, + F: 'scope, + T: 'scope, +{ + let scope_notify = scope + .clone() + .map(|scope| (scope.add_running_thread(), scope)); + let thread_data = Arc::new(ThreadData::new()); + let notify = rt::Notify::new(true, false); + + let id = { + let name = name.clone(); + let thread_data = thread_data.clone(); + let body: Box = Box::new(move || { + rt::execution(|execution| { + init_current(execution, name); + }); + + *thread_data.result.lock().unwrap() = Some(Ok(f())); + notify.notify(location); + + if let Some((notifier, scope)) = scope_notify { + notifier.notify(location!()); + scope.remove_running_thread() + } + }); + rt::spawn(std::mem::transmute::<_, Box>(body)) + }; + + JoinHandleInner { + data: thread_data, + notify, + thread: Thread { + id: ThreadId { id }, + name, + }, + } +} + +/// Data for a running thread. +#[derive(Debug)] +struct ThreadData<'scope, T> { + result: Mutex>>, + _marker: PhantomData>, +} + +impl<'scope, T> ThreadData<'scope, T> { + fn new() -> Self { + Self { + result: Mutex::new(None), + _marker: PhantomData, + } + } +} + +impl<'scope, T> JoinHandleInner<'scope, T> { + fn join(self) -> std::thread::Result { + self.notify.wait(location!()); + self.data.result.lock().unwrap().take().unwrap() + } + + fn thread(&self) -> &Thread { + &self.thread + } +} + +#[derive(Default, Debug)] +struct ScopeThreads { + count: usize, + notify_on_finished: Vec, +} + +#[derive(Debug)] +struct ScopeData { + running_threads: Mutex, + main_thread: Thread, +} + +impl ScopeData { + fn add_running_thread(&self) -> rt::Notify { + let mut running = self.running_threads.lock().unwrap(); + running.count += 1; + let notify = rt::Notify::new(true, false); + running.notify_on_finished.push(notify); + notify + } + + fn remove_running_thread(&self) { + let mut running = self.running_threads.lock().unwrap(); + running.count -= 1; + if running.count == 0 { + self.main_thread.unpark() + } + } +} diff --git a/tests/thread_api.rs b/tests/thread_api.rs index 95405e4..c1ddd7f 100644 --- a/tests/thread_api.rs +++ b/tests/thread_api.rs @@ -123,3 +123,101 @@ fn park_unpark_std() { std::thread::park(); println!("it did not deadlock"); } + +fn incrementer(a: &loom::sync::atomic::AtomicUsize) -> impl FnOnce() + '_ { + move || { + let _ = a.fetch_add(1, loom::sync::atomic::Ordering::Relaxed); + } +} + +#[test] +fn scoped_thread() { + loom::model(|| { + const SPAWN_COUNT: usize = 3; + let a = loom::sync::atomic::AtomicUsize::new(0); + thread::scope(|scope| { + for _i in 0..SPAWN_COUNT { + let _handle = scope.spawn(incrementer(&a)); + } + }); + assert_eq!(a.load(loom::sync::atomic::Ordering::Relaxed), SPAWN_COUNT); + }) +} + +#[test] +fn scoped_thread_builder() { + loom::model(|| { + const SPAWN_COUNT: usize = 3; + let a = loom::sync::atomic::AtomicUsize::new(0); + thread::scope(|scope| { + for _i in 0..SPAWN_COUNT { + thread::Builder::new() + .spawn_scoped(scope, incrementer(&a)) + .unwrap(); + } + }); + assert_eq!(a.load(loom::sync::atomic::Ordering::Relaxed), SPAWN_COUNT); + }) +} + +#[test] +fn scoped_thread_join() { + loom::model(|| { + const JOIN_COUNT: usize = 2; + let a = loom::sync::atomic::AtomicUsize::new(0); + thread::scope(|scope| { + let handles = [(); JOIN_COUNT].map(|()| scope.spawn(incrementer(&a))); + + // Spawn another thread that might increment `a` before the first + // threads finish. + let _other_handle = scope.spawn(incrementer(&a)); + + for h in handles { + h.join().unwrap() + } + let a = a.load(loom::sync::atomic::Ordering::Relaxed); + assert!(a == JOIN_COUNT || a == JOIN_COUNT + 1); + }); + assert_eq!( + a.load(loom::sync::atomic::Ordering::Relaxed), + JOIN_COUNT + 1 + ); + }) +} + +#[test] +fn multiple_scopes() { + loom::model(|| { + let a = loom::sync::atomic::AtomicUsize::new(0); + + thread::scope(|scope| { + let _handle = scope.spawn(incrementer(&a)); + }); + assert_eq!(a.load(loom::sync::atomic::Ordering::Relaxed), 1); + + thread::scope(|scope| { + let _handle = scope.spawn(incrementer(&a)); + }); + assert_eq!(a.load(loom::sync::atomic::Ordering::Relaxed), 2); + }) +} + +#[test] +fn scoped_and_unscoped_threads() { + loom::model(|| { + let a = loom::sync::Arc::new(loom::sync::atomic::AtomicUsize::new(0)); + + let unscoped_handle = thread::scope(|scope| { + let _handle = scope.spawn(incrementer(&a)); + let a = a.clone(); + loom::thread::spawn(move || incrementer(&a)()) + }); + + let v = a.load(loom::sync::atomic::Ordering::Relaxed); + assert!(v == 1 || v == 2, "{}", v); + + unscoped_handle.join().unwrap(); + let v = a.load(loom::sync::atomic::Ordering::Relaxed); + assert_eq!(v, 2); + }) +}