diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 94fbd3f13f3b..019ef333397b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -491,6 +491,16 @@ jobs: - name: Run cargo test run: cargo xtask --deny-warnings --backend ${{ matrix.backend }} test ${{ matrix.package }} + loom-tests: + name: rtic-sync loom tests + runs-on: ubuntu-22.04 + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Run cargo test + run: RUSTFLAGS="--cfg loom" cargo test -p rtic-sync --release --lib + # Build documentation, check links docs: name: build docs diff --git a/rtic-sync/CHANGELOG.md b/rtic-sync/CHANGELOG.md index 6638a58efde7..bc10371699dc 100644 --- a/rtic-sync/CHANGELOG.md +++ b/rtic-sync/CHANGELOG.md @@ -7,6 +7,10 @@ For each category, _Added_, _Changed_, _Fixed_ add new entries at the top! ## [Unreleased] +### Changed + +- Add `loom` support. + ## v1.3.2 - 2025-03-16 ### Fixed diff --git a/rtic-sync/Cargo.toml b/rtic-sync/Cargo.toml index 60d8be2a8394..c78356ccc39b 100644 --- a/rtic-sync/Cargo.toml +++ b/rtic-sync/Cargo.toml @@ -25,15 +25,23 @@ portable-atomic = { version = "1", default-features = false } embedded-hal = { version = "1.0.0" } embedded-hal-async = { version = "1.0.0" } embedded-hal-bus = { version = "0.2.0", features = ["async"] } - defmt-03 = { package = "defmt", version = "0.3", optional = true } [dev-dependencies] cassette = "0.3.0" static_cell = "2.1.0" -tokio = { version = "1", features = ["rt", "macros", "time"] } + +[target.'cfg(not(loom))'.dev-dependencies] +tokio = { version = "1", features = ["rt", "macros", "time"], default-features = false } [features] default = [] -testing = ["critical-section/std", "rtic-common/testing"] +testing = ["rtic-common/testing"] defmt-03 = ["dep:defmt-03", "embedded-hal/defmt-03", "embedded-hal-async/defmt-03", "embedded-hal-bus/defmt-03"] + +[lints.rust] +unexpected_cfgs = { level = "allow", check-cfg = ['cfg(loom)'] } + +[target.'cfg(loom)'.dependencies] +loom = { version = "0.7.2", features = [ "futures" ] } +critical-section = { version = "1", features = [ "restore-state-bool" ] } diff --git a/rtic-sync/src/arbiter.rs b/rtic-sync/src/arbiter.rs index 768e2000c98b..60559dffab86 100644 --- a/rtic-sync/src/arbiter.rs +++ b/rtic-sync/src/arbiter.rs @@ -381,6 +381,7 @@ pub mod i2c { } } +#[cfg(not(loom))] #[cfg(test)] mod tests { use super::*; diff --git a/rtic-sync/src/channel.rs b/rtic-sync/src/channel.rs index d3a64b6f1c2b..516948018405 100644 --- a/rtic-sync/src/channel.rs +++ b/rtic-sync/src/channel.rs @@ -1,138 +1,17 @@ //! An async aware MPSC channel that can be used on no-alloc systems. -use core::{ - cell::UnsafeCell, - future::poll_fn, - mem::MaybeUninit, - pin::Pin, - ptr, - sync::atomic::{fence, Ordering}, - task::{Poll, Waker}, -}; -#[doc(hidden)] -pub use critical_section; -use heapless::Deque; -use rtic_common::{ - dropper::OnDrop, wait_queue::DoublyLinkedList, wait_queue::Link, - waker_registration::CriticalSectionWakerRegistration as WakerRegistration, -}; - -#[cfg(feature = "defmt-03")] -use crate::defmt; - -type WaitQueueData = (Waker, SlotPtr); -type WaitQueue = DoublyLinkedList; - -/// An MPSC channel for use in no-alloc systems. `N` sets the size of the queue. -/// -/// This channel uses critical sections, however there are extremely small and all `memcpy` -/// operations of `T` are done without critical sections. -pub struct Channel { - // Here are all indexes that are not used in `slots` and ready to be allocated. - freeq: UnsafeCell>, - // Here are wakers and indexes to slots that are ready to be dequeued by the receiver. - readyq: UnsafeCell>, - // Waker for the receiver. - receiver_waker: WakerRegistration, - // Storage for N `T`s, so we don't memcpy around a lot of `T`s. - slots: [UnsafeCell>; N], - // If there is no room in the queue a `Sender`s can wait for there to be place in the queue. - wait_queue: WaitQueue, - // Keep track of the receiver. - receiver_dropped: UnsafeCell, - // Keep track of the number of senders. - num_senders: UnsafeCell, -} - -unsafe impl Send for Channel {} - -unsafe impl Sync for Channel {} - -struct UnsafeAccess<'a, const N: usize> { - freeq: &'a mut Deque, - readyq: &'a mut Deque, - receiver_dropped: &'a mut bool, - num_senders: &'a mut usize, -} - -impl Default for Channel { - fn default() -> Self { - Self::new() - } -} +#[allow(clippy::module_inception)] +mod channel; +pub use channel::Channel; -impl Channel { - const _CHECK: () = assert!(N < 256, "This queue support a maximum of 255 entries"); - - /// Create a new channel. - pub const fn new() -> Self { - Self { - freeq: UnsafeCell::new(Deque::new()), - readyq: UnsafeCell::new(Deque::new()), - receiver_waker: WakerRegistration::new(), - slots: [const { UnsafeCell::new(MaybeUninit::uninit()) }; N], - wait_queue: WaitQueue::new(), - receiver_dropped: UnsafeCell::new(false), - num_senders: UnsafeCell::new(0), - } - } - - /// Split the queue into a `Sender`/`Receiver` pair. - pub fn split(&mut self) -> (Sender<'_, T, N>, Receiver<'_, T, N>) { - // Fill free queue - for idx in 0..N as u8 { - assert!(!self.freeq.get_mut().is_full()); - - // SAFETY: This safe as the loop goes from 0 to the capacity of the underlying queue. - unsafe { - self.freeq.get_mut().push_back_unchecked(idx); - } - } - - assert!(self.freeq.get_mut().is_full()); - - // There is now 1 sender - *self.num_senders.get_mut() = 1; - - (Sender(self), Receiver(self)) - } - - fn access<'a>(&'a self, _cs: critical_section::CriticalSection) -> UnsafeAccess<'a, N> { - // SAFETY: This is safe as are in a critical section. - unsafe { - UnsafeAccess { - freeq: &mut *self.freeq.get(), - readyq: &mut *self.readyq.get(), - receiver_dropped: &mut *self.receiver_dropped.get(), - num_senders: &mut *self.num_senders.get(), - } - } - } +mod sender; +pub use sender::{Sender, TrySendError}; - /// Return free slot `slot` to the channel. - /// - /// This will do one of two things: - /// 1. If there are any waiting `send`-ers, wake the longest-waiting one and hand it `slot`. - /// 2. else, insert `slot` into `self.freeq`. - /// - /// SAFETY: `slot` must be a `u8` that is obtained by dequeueing from [`Self::readyq`]. - unsafe fn return_free_slot(&self, slot: u8) { - critical_section::with(|cs| { - fence(Ordering::SeqCst); +mod receiver; +pub use receiver::{ReceiveError, Receiver}; - // If someone is waiting in the `wait_queue`, wake the first one up & hand it the free slot. - if let Some((wait_head, mut freeq_slot)) = self.wait_queue.pop() { - // SAFETY: `freeq_slot` is valid for writes: we are in a critical - // section & the `SlotPtr` lives for at least the duration of the wait queue link. - unsafe { freeq_slot.replace(Some(slot), cs) }; - wait_head.wake(); - } else { - assert!(!self.access(cs).freeq.is_full()); - unsafe { self.access(cs).freeq.push_back_unchecked(slot) } - } - }) - } -} +#[doc(hidden)] +pub use critical_section; /// Creates a split channel with `'static` lifetime. #[macro_export] @@ -159,580 +38,3 @@ macro_rules! make_channel { } }}; } - -// -------- Sender - -/// Error state for when the receiver has been dropped. -#[cfg_attr(feature = "defmt-03", derive(defmt::Format))] -pub struct NoReceiver(pub T); - -/// Errors that 'try_send` can have. -#[cfg_attr(feature = "defmt-03", derive(defmt::Format))] -pub enum TrySendError { - /// Error state for when the receiver has been dropped. - NoReceiver(T), - /// Error state when the queue is full. - Full(T), -} - -impl core::fmt::Debug for NoReceiver -where - T: core::fmt::Debug, -{ - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - write!(f, "NoReceiver({:?})", self.0) - } -} - -impl core::fmt::Debug for TrySendError -where - T: core::fmt::Debug, -{ - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - match self { - TrySendError::NoReceiver(v) => write!(f, "NoReceiver({v:?})"), - TrySendError::Full(v) => write!(f, "Full({v:?})"), - } - } -} - -impl PartialEq for TrySendError -where - T: PartialEq, -{ - fn eq(&self, other: &Self) -> bool { - match (self, other) { - (TrySendError::NoReceiver(v1), TrySendError::NoReceiver(v2)) => v1.eq(v2), - (TrySendError::NoReceiver(_), TrySendError::Full(_)) => false, - (TrySendError::Full(_), TrySendError::NoReceiver(_)) => false, - (TrySendError::Full(v1), TrySendError::Full(v2)) => v1.eq(v2), - } - } -} - -/// A `Sender` can send to the channel and can be cloned. -pub struct Sender<'a, T, const N: usize>(&'a Channel); - -unsafe impl Send for Sender<'_, T, N> {} - -/// This is needed to make the async closure in `send` accept that we "share" -/// the link possible between threads. -#[derive(Clone)] -struct LinkPtr(*mut Option>); - -impl LinkPtr { - /// This will dereference the pointer stored within and give out an `&mut`. - unsafe fn get(&mut self) -> &mut Option> { - &mut *self.0 - } -} - -unsafe impl Send for LinkPtr {} - -unsafe impl Sync for LinkPtr {} - -/// This is needed to make the async closure in `send` accept that we "share" -/// the link possible between threads. -#[derive(Clone)] -struct SlotPtr(*mut Option); - -impl SlotPtr { - /// Replace the value of this slot with `new_value`, and return - /// the old value. - /// - /// SAFETY: the pointer in this `SlotPtr` must be valid for writes. - unsafe fn replace( - &mut self, - new_value: Option, - _cs: critical_section::CriticalSection, - ) -> Option { - // SAFETY: we are in a critical section. - unsafe { core::ptr::replace(self.0, new_value) } - } -} - -unsafe impl Send for SlotPtr {} - -unsafe impl Sync for SlotPtr {} - -impl core::fmt::Debug for Sender<'_, T, N> { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - write!(f, "Sender") - } -} - -#[cfg(feature = "defmt-03")] -impl defmt::Format for Sender<'_, T, N> { - fn format(&self, f: defmt::Formatter) { - defmt::write!(f, "Sender",) - } -} - -impl Sender<'_, T, N> { - #[inline(always)] - fn send_footer(&mut self, idx: u8, val: T) { - // Write the value to the slots, note; this memcpy is not under a critical section. - unsafe { - ptr::write( - self.0.slots.get_unchecked(idx as usize).get() as *mut T, - val, - ) - } - - // Write the value into the ready queue. - critical_section::with(|cs| { - assert!(!self.0.access(cs).readyq.is_full()); - unsafe { self.0.access(cs).readyq.push_back_unchecked(idx) } - }); - - fence(Ordering::SeqCst); - - // If there is a receiver waker, wake it. - self.0.receiver_waker.wake(); - } - - /// Try to send a value, non-blocking. If the channel is full this will return an error. - pub fn try_send(&mut self, val: T) -> Result<(), TrySendError> { - // If the wait queue is not empty, we can't try to push into the queue. - if !self.0.wait_queue.is_empty() { - return Err(TrySendError::Full(val)); - } - - // No receiver available. - if self.is_closed() { - return Err(TrySendError::NoReceiver(val)); - } - - let idx = - if let Some(idx) = critical_section::with(|cs| self.0.access(cs).freeq.pop_front()) { - idx - } else { - return Err(TrySendError::Full(val)); - }; - - self.send_footer(idx, val); - - Ok(()) - } - - /// Send a value. If there is no place left in the queue this will wait until there is. - /// If the receiver does not exist this will return an error. - pub async fn send(&mut self, val: T) -> Result<(), NoReceiver> { - let mut free_slot_ptr: Option = None; - let mut link_ptr: Option> = None; - - // Make this future `Drop`-safe. - // SAFETY(link_ptr): Shadow the original definition of `link_ptr` so we can't abuse it. - let mut link_ptr = LinkPtr(core::ptr::addr_of_mut!(link_ptr)); - // SAFETY(freed_slot): Shadow the original definition of `free_slot_ptr` so we can't abuse it. - let mut free_slot_ptr = SlotPtr(core::ptr::addr_of_mut!(free_slot_ptr)); - - let mut link_ptr2 = link_ptr.clone(); - let mut free_slot_ptr2 = free_slot_ptr.clone(); - let dropper = OnDrop::new(|| { - // SAFETY: We only run this closure and dereference the pointer if we have - // exited the `poll_fn` below in the `drop(dropper)` call. The other dereference - // of this pointer is in the `poll_fn`. - if let Some(link) = unsafe { link_ptr2.get() } { - link.remove_from_list(&self.0.wait_queue); - } - - // Return our potentially-unused free slot. - // Potentially unnecessary c-s because our link was already popped, so there - // is no way for anything else to access the free slot ptr. Gotta think - // about this a bit more... - critical_section::with(|cs| { - if let Some(freed_slot) = unsafe { free_slot_ptr2.replace(None, cs) } { - // SAFETY: freed slot is passed to us from `return_free_slot`, which either - // directly (through `try_recv`), or indirectly (through another `return_free_slot`) - // comes from `readyq`. - unsafe { self.0.return_free_slot(freed_slot) }; - } - }); - }); - - let idx = poll_fn(|cx| { - // Do all this in one critical section, else there can be race conditions - critical_section::with(|cs| { - if self.is_closed() { - return Poll::Ready(Err(())); - } - - let wq_empty = self.0.wait_queue.is_empty(); - let freeq_empty = self.0.access(cs).freeq.is_empty(); - - // SAFETY: This pointer is only dereferenced here and on drop of the future - // which happens outside this `poll_fn`'s stack frame. - let link = unsafe { link_ptr.get() }; - - // We are already in the wait queue. - if let Some(link) = link { - if link.is_popped() { - // SAFETY: `free_slot_ptr` is valid for writes until the end of this future. - let slot = unsafe { free_slot_ptr.replace(None, cs) }; - - // If our link is popped, then: - // 1. We were popped by `return_free_lot` and provided us with a slot. - // 2. We were popped by `Receiver::drop` and it did not provide us with a slot, and the channel is closed. - if let Some(slot) = slot { - Poll::Ready(Ok(slot)) - } else { - Poll::Ready(Err(())) - } - } else { - Poll::Pending - } - } - // We are not in the wait queue, but others are, or there is currently no free - // slot available. - else if !wq_empty || freeq_empty { - // Place the link in the wait queue. - let link_ref = - link.insert(Link::new((cx.waker().clone(), free_slot_ptr.clone()))); - - // SAFETY(new_unchecked): The address to the link is stable as it is defined - // outside this stack frame. - // SAFETY(push): `link_ref` lifetime comes from `link_ptr` and `free_slot_ptr` that - // are shadowed and we make sure in `dropper` that the link is removed from the queue - // before dropping `link_ptr` AND `dropper` makes sure that the shadowed - // `ptr`s live until the end of the stack frame. - unsafe { self.0.wait_queue.push(Pin::new_unchecked(link_ref)) }; - - Poll::Pending - } - // We are not in the wait queue, no one else is waiting, and there is a free slot available. - else { - assert!(!self.0.access(cs).freeq.is_empty()); - let slot = unsafe { self.0.access(cs).freeq.pop_back_unchecked() }; - Poll::Ready(Ok(slot)) - } - }) - }) - .await; - - // Make sure the link is removed from the queue. - drop(dropper); - - if let Ok(idx) = idx { - self.send_footer(idx, val); - - Ok(()) - } else { - Err(NoReceiver(val)) - } - } - - /// Returns true if there is no `Receiver`s. - pub fn is_closed(&self) -> bool { - critical_section::with(|cs| *self.0.access(cs).receiver_dropped) - } - - /// Is the queue full. - pub fn is_full(&self) -> bool { - critical_section::with(|cs| self.0.access(cs).freeq.is_empty()) - } - - /// Is the queue empty. - pub fn is_empty(&self) -> bool { - critical_section::with(|cs| self.0.access(cs).freeq.is_full()) - } -} - -impl Drop for Sender<'_, T, N> { - fn drop(&mut self) { - // Count down the reference counter - let num_senders = critical_section::with(|cs| { - *self.0.access(cs).num_senders -= 1; - - *self.0.access(cs).num_senders - }); - - // If there are no senders, wake the receiver to do error handling. - if num_senders == 0 { - self.0.receiver_waker.wake(); - } - } -} - -impl Clone for Sender<'_, T, N> { - fn clone(&self) -> Self { - // Count up the reference counter - critical_section::with(|cs| *self.0.access(cs).num_senders += 1); - - Self(self.0) - } -} - -// -------- Receiver - -/// A receiver of the channel. There can only be one receiver at any time. -pub struct Receiver<'a, T, const N: usize>(&'a Channel); - -unsafe impl Send for Receiver<'_, T, N> {} - -impl core::fmt::Debug for Receiver<'_, T, N> { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - write!(f, "Receiver") - } -} - -#[cfg(feature = "defmt-03")] -impl defmt::Format for Receiver<'_, T, N> { - fn format(&self, f: defmt::Formatter) { - defmt::write!(f, "Receiver",) - } -} - -/// Possible receive errors. -#[cfg_attr(feature = "defmt-03", derive(defmt::Format))] -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum ReceiveError { - /// Error state for when all senders has been dropped. - NoSender, - /// Error state for when the queue is empty. - Empty, -} - -impl Receiver<'_, T, N> { - /// Receives a value if there is one in the channel, non-blocking. - pub fn try_recv(&mut self) -> Result { - // Try to get a ready slot. - let ready_slot = critical_section::with(|cs| self.0.access(cs).readyq.pop_front()); - - if let Some(rs) = ready_slot { - // Read the value from the slots, note; this memcpy is not under a critical section. - let r = unsafe { ptr::read(self.0.slots.get_unchecked(rs as usize).get() as *const T) }; - - // Return the index to the free queue after we've read the value. - // SAFETY: `rs` comes directly from `readyq`. - unsafe { self.0.return_free_slot(rs) }; - - Ok(r) - } else if self.is_closed() { - Err(ReceiveError::NoSender) - } else { - Err(ReceiveError::Empty) - } - } - - /// Receives a value, waiting if the queue is empty. - /// If all senders are dropped this will error with `NoSender`. - pub async fn recv(&mut self) -> Result { - // There was nothing in the queue, setup the waiting. - poll_fn(|cx| { - // Register waker. - // TODO: Should it happen here or after the if? This might cause a spurious wake. - self.0.receiver_waker.register(cx.waker()); - - // Try to dequeue. - match self.try_recv() { - Ok(val) => { - return Poll::Ready(Ok(val)); - } - Err(ReceiveError::NoSender) => { - return Poll::Ready(Err(ReceiveError::NoSender)); - } - _ => {} - } - - Poll::Pending - }) - .await - } - - /// Returns true if there are no `Sender`s. - pub fn is_closed(&self) -> bool { - critical_section::with(|cs| *self.0.access(cs).num_senders == 0) - } - - /// Is the queue full. - pub fn is_full(&self) -> bool { - critical_section::with(|cs| self.0.access(cs).readyq.is_full()) - } - - /// Is the queue empty. - pub fn is_empty(&self) -> bool { - critical_section::with(|cs| self.0.access(cs).readyq.is_empty()) - } -} - -impl Drop for Receiver<'_, T, N> { - fn drop(&mut self) { - // Mark the receiver as dropped and wake all waiters - critical_section::with(|cs| *self.0.access(cs).receiver_dropped = true); - - while let Some((waker, _)) = self.0.wait_queue.pop() { - waker.wake(); - } - } -} - -#[cfg(test)] -mod tests { - use cassette::Cassette; - - use super::*; - - #[test] - fn empty() { - let (mut s, mut r) = make_channel!(u32, 10); - - assert!(s.is_empty()); - assert!(r.is_empty()); - - s.try_send(1).unwrap(); - - assert!(!s.is_empty()); - assert!(!r.is_empty()); - - r.try_recv().unwrap(); - - assert!(s.is_empty()); - assert!(r.is_empty()); - } - - #[test] - fn full() { - let (mut s, mut r) = make_channel!(u32, 3); - - for _ in 0..3 { - assert!(!s.is_full()); - assert!(!r.is_full()); - - s.try_send(1).unwrap(); - } - - assert!(s.is_full()); - assert!(r.is_full()); - - for _ in 0..3 { - r.try_recv().unwrap(); - - assert!(!s.is_full()); - assert!(!r.is_full()); - } - } - - #[test] - fn send_recieve() { - let (mut s, mut r) = make_channel!(u32, 10); - - for i in 0..10 { - s.try_send(i).unwrap(); - } - - assert_eq!(s.try_send(11), Err(TrySendError::Full(11))); - - for i in 0..10 { - assert_eq!(r.try_recv().unwrap(), i); - } - - assert_eq!(r.try_recv(), Err(ReceiveError::Empty)); - } - - #[test] - fn closed_recv() { - let (s, mut r) = make_channel!(u32, 10); - - drop(s); - - assert!(r.is_closed()); - - assert_eq!(r.try_recv(), Err(ReceiveError::NoSender)); - } - - #[test] - fn closed_sender() { - let (mut s, r) = make_channel!(u32, 10); - - drop(r); - - assert!(s.is_closed()); - - assert_eq!(s.try_send(11), Err(TrySendError::NoReceiver(11))); - } - - #[tokio::test] - async fn stress_channel() { - const NUM_RUNS: usize = 1_000; - const QUEUE_SIZE: usize = 10; - - let (s, mut r) = make_channel!(u32, QUEUE_SIZE); - let mut v = std::vec::Vec::new(); - - for i in 0..NUM_RUNS { - let mut s = s.clone(); - - v.push(tokio::spawn(async move { - s.send(i as _).await.unwrap(); - })); - } - - let mut map = std::collections::BTreeSet::new(); - - for _ in 0..NUM_RUNS { - map.insert(r.recv().await.unwrap()); - } - - assert_eq!(map.len(), NUM_RUNS); - - for v in v { - v.await.unwrap(); - } - } - - fn make() { - let _ = make_channel!(u32, 10); - } - - #[test] - #[should_panic] - fn double_make_channel() { - make(); - make(); - } - - #[test] - fn tuple_channel() { - let _ = make_channel!((i32, u32), 10); - } - - fn freeq(channel: &Channel, f: F) -> R - where - F: FnOnce(&mut Deque) -> R, - { - critical_section::with(|cs| f(channel.access(cs).freeq)) - } - - #[test] - fn dropping_waked_send_returns_freeq_item() { - let (mut tx, mut rx) = make_channel!(u8, 1); - - tx.try_send(0).unwrap(); - assert!(freeq(&rx.0, |q| q.is_empty())); - - // Running this in a separate thread scope to ensure that `pinned_future` is dropped fully. - // - // Calling drop explicitly gets hairy because dropping things behind a `Pin` is not easy. - std::thread::scope(|scope| { - scope.spawn(|| { - let pinned_future = core::pin::pin!(tx.send(1)); - let mut future = Cassette::new(pinned_future); - - future.poll_on(); - - assert!(freeq(&rx.0, |q| q.is_empty())); - assert!(!rx.0.wait_queue.is_empty()); - - assert_eq!(rx.try_recv(), Ok(0)); - - assert!(freeq(&rx.0, |q| q.is_empty())); - }); - }); - - assert!(!freeq(&rx.0, |q| q.is_empty())); - - // Make sure that rx & tx are alive until here for good measure. - drop((tx, rx)); - } -} diff --git a/rtic-sync/src/channel/channel.rs b/rtic-sync/src/channel/channel.rs new file mode 100644 index 000000000000..f2b52321c345 --- /dev/null +++ b/rtic-sync/src/channel/channel.rs @@ -0,0 +1,596 @@ +use core::{ + mem::MaybeUninit, + pin::Pin, + ptr, + sync::atomic::{fence, Ordering}, + task::Waker, +}; + +use heapless::Deque; +use rtic_common::{ + wait_queue::{DoublyLinkedList, Link}, + waker_registration::CriticalSectionWakerRegistration as WakerRegistration, +}; + +use super::{Receiver, Sender}; + +use crate::unsafecell::UnsafeCell; + +pub(crate) type WaitQueueData = (Waker, FreeSlotPtr); +pub(crate) type WaitQueue = DoublyLinkedList; + +macro_rules! cs_access { + ($name:ident, $field:ident, $type:ty) => { + /// Access the value mutably. + /// + /// SAFETY: this function must not be called recursively within `f`. + unsafe fn $name(&self, _cs: critical_section::CriticalSection, f: F) -> R + where + F: FnOnce(&mut $type) -> R, + { + self.$field.with_mut(|v| { + let v = unsafe { &mut *v }; + f(v) + }) + } + }; +} + +/// A free slot. +#[derive(Debug)] +pub(crate) struct FreeSlot(u8); + +/// A pointer to a free slot. +/// +/// This struct exists to enforce lifetime/safety requirements, and to ensure +/// that [`FreeSlot`]s can only be created/updated by this module. +#[derive(Clone)] +pub(crate) struct FreeSlotPtr(*mut Option); + +impl FreeSlotPtr { + /// SAFETY: `inner` must be valid until the [`Link`] containing this [`FreeSlotPtr`] is popped. + /// Additionally, this [`FreeSlotPtr`] must have exclusive access to the data pointed to by + /// `inner`. + pub unsafe fn new(inner: *mut Option) -> Self { + Self(inner) + } + + /// Replace the value of this slot with `new_value`, and return + /// the old value. + /// + /// SAFETY: the pointer in this [`FreeSlotPtr`] must be valid for writes. + pub(crate) unsafe fn take( + &mut self, + cs: critical_section::CriticalSection, + ) -> Option { + self.replace(None, cs) + } + + /// Replace the value of this slot with `new_value`, and return + /// the old value. + /// + /// SAFETY: the pointer in this [`FreeSlotPtr`] must be valid for writes, and `new_value` must + /// be obtained from `freeq`. + unsafe fn replace( + &mut self, + new_value: Option, + _cs: critical_section::CriticalSection, + ) -> Option { + // SAFETY: we are in a critical section. + unsafe { core::ptr::replace(self.0, new_value) } + } +} + +unsafe impl Send for FreeSlotPtr {} + +unsafe impl Sync for FreeSlotPtr {} + +/// An MPSC channel for use in no-alloc systems. `N` sets the size of the queue. +/// +/// This channel uses critical sections, however there are extremely small and all `memcpy` +/// operations of `T` are done without critical sections. +pub struct Channel { + // Here are all indexes that are not used in `slots` and ready to be allocated. + freeq: UnsafeCell>, + // Here are wakers and indexes to slots that are ready to be dequeued by the receiver. + readyq: UnsafeCell>, + // Waker for the receiver. + receiver_waker: WakerRegistration, + // Storage for N `T`s, so we don't memcpy around a lot of `T`s. + slots: [UnsafeCell>; N], + // If there is no room in the queue a `Sender`s can wait for there to be place in the queue. + wait_queue: WaitQueue, + // Keep track of the receiver. + receiver_dropped: UnsafeCell, + // Keep track of the number of senders. + num_senders: UnsafeCell, +} + +unsafe impl Send for Channel {} + +unsafe impl Sync for Channel {} + +impl Default for Channel { + fn default() -> Self { + Self::new() + } +} + +impl Channel { + const _CHECK: () = assert!(N < 256, "This queue support a maximum of 255 entries"); + + /// Create a new channel. + #[cfg(not(loom))] + pub const fn new() -> Self { + Self { + freeq: UnsafeCell::new(Deque::new()), + readyq: UnsafeCell::new(Deque::new()), + receiver_waker: WakerRegistration::new(), + slots: [const { UnsafeCell::new(MaybeUninit::uninit()) }; N], + wait_queue: WaitQueue::new(), + receiver_dropped: UnsafeCell::new(false), + num_senders: UnsafeCell::new(0), + } + } + + /// Create a new channel. + #[cfg(loom)] + pub fn new() -> Self { + Self { + freeq: UnsafeCell::new(Deque::new()), + readyq: UnsafeCell::new(Deque::new()), + receiver_waker: WakerRegistration::new(), + slots: core::array::from_fn(|_| UnsafeCell::new(MaybeUninit::uninit())), + wait_queue: WaitQueue::new(), + receiver_dropped: UnsafeCell::new(false), + num_senders: UnsafeCell::new(0), + } + } + + /// Split the queue into a `Sender`/`Receiver` pair. + pub fn split(&mut self) -> (Sender<'_, T, N>, Receiver<'_, T, N>) { + // SAFETY: we have exclusive access to `self`. + let freeq = self.freeq.get_mut(); + let freeq = unsafe { freeq.deref() }; + + // Fill free queue + for idx in 0..N as u8 { + assert!(!freeq.is_full()); + + // SAFETY: This safe as the loop goes from 0 to the capacity of the underlying queue. + unsafe { + freeq.push_back_unchecked(idx); + } + } + + assert!(freeq.is_full()); + + // There is now 1 sender + // SAFETY: we have exclusive access to `self`. + unsafe { *self.num_senders.get_mut().deref() = 1 }; + + (Sender(self), Receiver(self)) + } + + cs_access!(access_freeq, freeq, Deque); + cs_access!(access_readyq, readyq, Deque); + cs_access!(access_receiver_dropped, receiver_dropped, bool); + cs_access!(access_num_senders, num_senders, usize); + + /// SAFETY: this function must not be called recursively in `f`. + pub(crate) unsafe fn freeq(&self, f: F) -> R + where + F: FnOnce(&Deque) -> R, + { + critical_section::with(|cs| self.access_freeq(cs, |v| f(&v))) + } + + /// SAFETY: this function must not be called recursively in `f`. + pub(crate) unsafe fn readyq(&self, f: F) -> R + where + F: FnOnce(&Deque) -> R, + { + critical_section::with(|cs| self.access_readyq(cs, |v| f(&v))) + } + + pub(crate) fn num_senders(&self) -> usize { + critical_section::with(|cs| unsafe { + // SAFETY: `self.access_num_senders` is not called recursively. + self.access_num_senders(cs, |v| *v) + }) + } + + pub(crate) fn receiver_dropped(&self) -> bool { + critical_section::with(|cs| unsafe { + // SAFETY: `self.receiver_dropped` is not called recursively. + self.access_receiver_dropped(cs, |v| *v) + }) + } + + /// Return free slot `slot` to the channel. + /// + /// This will do one of two things: + /// 1. If there are any waiting `send`-ers, wake the longest-waiting one and hand it `slot`. + /// 2. else, insert `slot` into the free queue. + /// + /// SAFETY: `slot` must be obtained from this exact channel instance. + pub(crate) unsafe fn return_free_slot(&self, slot: FreeSlot) { + critical_section::with(|cs| { + fence(Ordering::SeqCst); + + // If a sender is waiting in the `wait_queue`, wake the first one up & hand it the free slot. + if let Some((wait_head, mut freeq_slot)) = self.wait_queue.pop() { + // SAFETY: `freeq_slot` is valid for writes: we are in a critical + // section & the `FreeSlotPtr` lives for at least the duration of the wait queue link. + unsafe { freeq_slot.replace(Some(slot), cs) }; + wait_head.wake(); + } else { + // SAFETY: `self.freeq` is not called recursively. + unsafe { + self.access_freeq(cs, |freeq| { + assert!(!freeq.is_full()); + // SAFETY: `freeq` is not full. + freeq.push_back_unchecked(slot.0); + }); + } + } + }); + } + + /// Send a value using the given `slot` in this channel. + /// + /// SAFETY: `slot` must be obtained from this exact channel instance. + #[inline(always)] + pub(crate) unsafe fn send_value(&self, slot: FreeSlot, val: T) { + let slot = slot.0; + + // Write the value to the slots, note; this memcpy is not under a critical section. + unsafe { + let first_element = self.slots.get_unchecked(slot as usize).get_mut(); + let ptr = first_element.deref().as_mut_ptr(); + ptr::write(ptr, val) + } + + // Write the value into the ready queue. + critical_section::with(|cs| { + // SAFETY: `self.readyq` is not called recursively. + unsafe { + self.access_readyq(cs, |readyq| { + assert!(!readyq.is_full()); + // SAFETY: ready is not full. + readyq.push_back_unchecked(slot); + }); + } + }); + + fence(Ordering::SeqCst); + + // If there is a receiver waker, wake it. + self.receiver_waker.wake(); + } + + /// Pop the value of a ready slot to make it available to a receiver. + /// + /// Internally, this function does these things: + /// 1. Pop a ready slot from the ready queue. + /// 2. If available, read the data from the backing slot storage. + /// 3. If available, return the now-free slot to the free queue. + pub(crate) fn receive_value(&self) -> Option { + let ready_slot = critical_section::with(|cs| unsafe { + // SAFETY: `self.readyq` is not called recursively. + self.access_readyq(cs, |q| q.pop_front()) + }); + + if let Some(rs) = ready_slot { + let r = unsafe { + let first_element = self.slots.get_unchecked(rs as usize).get_mut(); + let ptr = first_element.deref().as_ptr(); + ptr::read(ptr) + }; + + // Return the index to the free queue after we've read the value. + // SAFETY: `rs` is now a free slot obtained from this channel. + unsafe { self.return_free_slot(FreeSlot(rs)) }; + + Some(r) + } else { + None + } + } + + /// Register a new waiter in the wait queue. + /// + /// SAFETY: `link` must be valid until it is popped. + pub(crate) unsafe fn push_wait_queue(&self, link: Pin<&Link>) { + self.wait_queue.push(link); + } + + pub(crate) fn remove_from_wait_queue(&self, link: &Link) { + link.remove_from_list(&self.wait_queue); + } + + /// Pop a free slot. + pub(crate) fn pop_free_slot(&self) -> Option { + let slot = critical_section::with(|cs| unsafe { + // SAFETY: `self.freeq` is not called recursively. + self.access_freeq(cs, |q| q.pop_front()) + }); + slot.map(FreeSlot) + } + + pub(crate) fn drop_receiver(&self) { + // Mark the receiver as dropped and wake all waiters + critical_section::with(|cs| unsafe { + // SAFTEY: `self.receiver_dropped` is not called recursively. + self.access_receiver_dropped(cs, |v| *v = true) + }); + + while let Some((waker, _)) = self.wait_queue.pop() { + waker.wake(); + } + } + + pub(crate) fn register_receiver_waker(&self, waker: &Waker) { + self.receiver_waker.register(waker); + } + + pub(crate) fn drop_sender(&self) { + // Count down the reference counter + let num_senders = critical_section::with(|cs| unsafe { + // SAFETY: `self.num_senders` is not called recursively. + self.access_num_senders(cs, |s| { + *s -= 1; + *s + }) + }); + + // If there are no senders, wake the receiver to do error handling. + if num_senders == 0 { + self.receiver_waker.wake(); + } + } + + pub(crate) fn clone_sender(&self) { + // Count up the reference counter + critical_section::with(|cs| unsafe { + // SAFETY: `self.num_senders` is not called recursively. + self.access_num_senders(cs, |v| *v += 1) + }); + } +} + +#[cfg(test)] +#[cfg(not(loom))] +mod tests { + use crate::{ + channel::{ReceiveError, TrySendError}, + make_channel, + }; + use cassette::Cassette; + use heapless::Deque; + + use super::Channel; + + #[test] + fn empty() { + let (mut s, mut r) = make_channel!(u32, 10); + + assert!(s.is_empty()); + assert!(r.is_empty()); + + s.try_send(1).unwrap(); + + assert!(!s.is_empty()); + assert!(!r.is_empty()); + + r.try_recv().unwrap(); + + assert!(s.is_empty()); + assert!(r.is_empty()); + } + + #[test] + fn full() { + let (mut s, mut r) = make_channel!(u32, 3); + + for _ in 0..3 { + assert!(!s.is_full()); + assert!(!r.is_full()); + + s.try_send(1).unwrap(); + } + + assert!(s.is_full()); + assert!(r.is_full()); + + for _ in 0..3 { + r.try_recv().unwrap(); + + assert!(!s.is_full()); + assert!(!r.is_full()); + } + } + + #[test] + fn send_recieve() { + let (mut s, mut r) = make_channel!(u32, 10); + + for i in 0..10 { + s.try_send(i).unwrap(); + } + + assert_eq!(s.try_send(11), Err(TrySendError::Full(11))); + + for i in 0..10 { + assert_eq!(r.try_recv().unwrap(), i); + } + + assert_eq!(r.try_recv(), Err(ReceiveError::Empty)); + } + + #[test] + fn closed_recv() { + let (s, mut r) = make_channel!(u32, 10); + + drop(s); + + assert!(r.is_closed()); + + assert_eq!(r.try_recv(), Err(ReceiveError::NoSender)); + } + + #[test] + fn closed_sender() { + let (mut s, r) = make_channel!(u32, 10); + + drop(r); + + assert!(s.is_closed()); + + assert_eq!(s.try_send(11), Err(TrySendError::NoReceiver(11))); + } + + fn make() { + let _ = make_channel!(u32, 10); + } + + #[test] + #[should_panic] + fn double_make_channel() { + make(); + make(); + } + + #[test] + fn tuple_channel() { + let _ = make_channel!((i32, u32), 10); + } + + fn freeq(channel: &Channel, f: F) -> R + where + F: FnOnce(&mut Deque) -> R, + { + critical_section::with(|cs| unsafe { channel.access_freeq(cs, f) }) + } + + #[test] + fn dropping_waked_send_returns_freeq_item() { + let (mut tx, mut rx) = make_channel!(u8, 1); + + tx.try_send(0).unwrap(); + assert!(freeq(&rx.0, |q| q.is_empty())); + + // Running this in a separate thread scope to ensure that `pinned_future` is dropped fully. + // + // Calling drop explicitly gets hairy because dropping things behind a `Pin` is not easy. + std::thread::scope(|scope| { + scope.spawn(|| { + let pinned_future = core::pin::pin!(tx.send(1)); + let mut future = Cassette::new(pinned_future); + + future.poll_on(); + + assert!(freeq(&rx.0, |q| q.is_empty())); + assert!(!rx.0.wait_queue.is_empty()); + + assert_eq!(rx.try_recv(), Ok(0)); + + assert!(freeq(&rx.0, |q| q.is_empty())); + }); + }); + + assert!(!freeq(&rx.0, |q| q.is_empty())); + + // Make sure that rx & tx are alive until here for good measure. + drop((tx, rx)); + } +} + +#[cfg(not(loom))] +#[cfg(test)] +mod tokio_tests { + use crate::make_channel; + + #[tokio::test] + async fn stress_channel() { + const NUM_RUNS: usize = 1_000; + const QUEUE_SIZE: usize = 10; + + let (s, mut r) = make_channel!(u32, QUEUE_SIZE); + let mut v = std::vec::Vec::new(); + + for i in 0..NUM_RUNS { + let mut s = s.clone(); + + v.push(tokio::spawn(async move { + s.send(i as _).await.unwrap(); + })); + } + + let mut map = std::collections::BTreeSet::new(); + + for _ in 0..NUM_RUNS { + map.insert(r.recv().await.unwrap()); + } + + assert_eq!(map.len(), NUM_RUNS); + + for v in v { + v.await.unwrap(); + } + } +} + +#[cfg(test)] +#[cfg(loom)] +mod loom_test { + use cassette::Cassette; + use loom::thread; + + #[macro_export] + #[allow(missing_docs)] + macro_rules! make_loom_channel { + ($type:ty, $size:expr) => {{ + let channel: crate::channel::Channel<$type, $size> = super::Channel::new(); + let boxed = Box::new(channel); + let boxed = Box::leak(boxed); + + // SAFETY: This is safe as we hide the static mut from others to access it. + // Only this point is where the mutable access happens. + boxed.split() + }}; + } + + // This test tests the following scenarios: + // 1. Receiver is dropped while concurrent senders are waiting to send. + // 2. Concurrent senders are competing for the same free slot. + #[test] + pub fn concurrent_send_while_full_and_drop() { + loom::model(|| { + let (mut tx, mut rx) = make_loom_channel!([u8; 20], 1); + let mut cloned = tx.clone(); + + tx.try_send([1; 20]).unwrap(); + + let handle1 = thread::spawn(move || { + let future = std::pin::pin!(tx.send([1; 20])); + let mut future = Cassette::new(future); + if future.poll_on().is_none() { + future.poll_on(); + } + }); + + rx.try_recv().ok(); + + let future = std::pin::pin!(cloned.send([1; 20])); + let mut future = Cassette::new(future); + if future.poll_on().is_none() { + future.poll_on(); + } + + drop(rx); + + handle1.join().unwrap(); + }); + } +} diff --git a/rtic-sync/src/channel/receiver.rs b/rtic-sync/src/channel/receiver.rs new file mode 100644 index 000000000000..53ec671c311b --- /dev/null +++ b/rtic-sync/src/channel/receiver.rs @@ -0,0 +1,98 @@ +use core::{future::poll_fn, task::Poll}; + +use super::Channel; + +#[cfg(feature = "defmt-03")] +use crate::defmt; + +/// Possible receive errors. +#[cfg_attr(feature = "defmt-03", derive(defmt::Format))] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum ReceiveError { + /// Error state for when all senders has been dropped. + NoSender, + /// Error state for when the queue is empty. + Empty, +} + +/// A receiver of the channel. There can only be one receiver at any time. +pub struct Receiver<'a, T, const N: usize>(pub(crate) &'a Channel); + +unsafe impl Send for Receiver<'_, T, N> {} + +impl core::fmt::Debug for Receiver<'_, T, N> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "Receiver") + } +} + +#[cfg(feature = "defmt-03")] +impl defmt::Format for Receiver<'_, T, N> { + fn format(&self, f: defmt::Formatter) { + defmt::write!(f, "Receiver",) + } +} + +impl Receiver<'_, T, N> { + /// Receives a value if there is one in the channel, non-blocking. + pub fn try_recv(&mut self) -> Result { + // Try to get a ready slot. + let ready_slot = self.0.receive_value(); + + if let Some(value) = ready_slot { + Ok(value) + } else if self.is_closed() { + Err(ReceiveError::NoSender) + } else { + Err(ReceiveError::Empty) + } + } + + /// Receives a value, waiting if the queue is empty. + /// If all senders are dropped this will error with `NoSender`. + pub async fn recv(&mut self) -> Result { + // There was nothing in the queue, setup the waiting. + poll_fn(|cx| { + // Register waker. + // TODO: Should it happen here or after the if? This might cause a spurious wake. + self.0.register_receiver_waker(cx.waker()); + + // Try to dequeue. + match self.try_recv() { + Ok(val) => { + return Poll::Ready(Ok(val)); + } + Err(ReceiveError::NoSender) => { + return Poll::Ready(Err(ReceiveError::NoSender)); + } + _ => {} + } + + Poll::Pending + }) + .await + } + + /// Returns true if there are no `Sender`s. + pub fn is_closed(&self) -> bool { + self.0.num_senders() == 0 + } + + /// Is the queue full. + pub fn is_full(&self) -> bool { + // SAFETY: `self.0.readyq` is not called recursively. + unsafe { self.0.readyq(|q| q.is_full()) } + } + + /// Is the queue empty. + pub fn is_empty(&self) -> bool { + // SAFETY: `self.0.readyq` is not called recursively. + unsafe { self.0.readyq(|q| q.is_empty()) } + } +} + +impl Drop for Receiver<'_, T, N> { + fn drop(&mut self) { + self.0.drop_receiver(); + } +} diff --git a/rtic-sync/src/channel/sender.rs b/rtic-sync/src/channel/sender.rs new file mode 100644 index 000000000000..858628a6ed25 --- /dev/null +++ b/rtic-sync/src/channel/sender.rs @@ -0,0 +1,251 @@ +use core::{future::poll_fn, pin::Pin, task::Poll}; + +use rtic_common::{dropper::OnDrop, wait_queue::Link}; + +use super::{ + channel::{FreeSlot, FreeSlotPtr, WaitQueueData}, + Channel, +}; + +#[cfg(feature = "defmt-03")] +use crate::defmt; + +/// This is needed to make the async closure in `send` accept that we "share" +/// the link possible between threads. +#[derive(Clone)] +struct LinkPtr(*mut Option>); + +impl LinkPtr { + /// This will dereference the pointer stored within and give out an `&mut`. + unsafe fn get(&mut self) -> &mut Option> { + &mut *self.0 + } +} + +unsafe impl Send for LinkPtr {} + +unsafe impl Sync for LinkPtr {} + +/// Error state for when the receiver has been dropped. +#[cfg_attr(feature = "defmt-03", derive(defmt::Format))] +pub struct NoReceiver(pub T); + +/// Errors that 'try_send` can have. +#[cfg_attr(feature = "defmt-03", derive(defmt::Format))] +pub enum TrySendError { + /// Error state for when the receiver has been dropped. + NoReceiver(T), + /// Error state when the queue is full. + Full(T), +} + +impl core::fmt::Debug for NoReceiver +where + T: core::fmt::Debug, +{ + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "NoReceiver({:?})", self.0) + } +} + +impl core::fmt::Debug for TrySendError +where + T: core::fmt::Debug, +{ + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + TrySendError::NoReceiver(v) => write!(f, "NoReceiver({v:?})"), + TrySendError::Full(v) => write!(f, "Full({v:?})"), + } + } +} + +impl PartialEq for TrySendError +where + T: PartialEq, +{ + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (TrySendError::NoReceiver(v1), TrySendError::NoReceiver(v2)) => v1.eq(v2), + (TrySendError::NoReceiver(_), TrySendError::Full(_)) => false, + (TrySendError::Full(_), TrySendError::NoReceiver(_)) => false, + (TrySendError::Full(v1), TrySendError::Full(v2)) => v1.eq(v2), + } + } +} + +/// A `Sender` can send to the channel and can be cloned. +pub struct Sender<'a, T, const N: usize>(pub(crate) &'a Channel); + +unsafe impl Send for Sender<'_, T, N> {} + +impl core::fmt::Debug for Sender<'_, T, N> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "Sender") + } +} + +#[cfg(feature = "defmt-03")] +impl defmt::Format for Sender<'_, T, N> { + fn format(&self, f: defmt::Formatter) { + defmt::write!(f, "Sender",) + } +} + +impl Sender<'_, T, N> { + /// Try to send a value, non-blocking. If the channel is full this will return an error. + pub fn try_send(&mut self, val: T) -> Result<(), TrySendError> { + // If the wait queue is not empty, we can't try to push into the queue. + // TODO: this no longer seems necessary: freeq items are sent directly to + // queueing `send`s. + // if !self.0.wait_queue.is_empty() { + // return Err(TrySendError::Full(val)); + // } + + // No receiver available. + if self.is_closed() { + return Err(TrySendError::NoReceiver(val)); + } + + let idx = if let Some(idx) = self.0.pop_free_slot() { + idx + } else { + return Err(TrySendError::Full(val)); + }; + + unsafe { self.0.send_value(idx, val) }; + + Ok(()) + } + + /// Send a value. If there is no place left in the queue this will wait until there is. + /// If the receiver does not exist this will return an error. + pub async fn send(&mut self, val: T) -> Result<(), NoReceiver> { + let mut free_slot_ptr: Option = None; + let mut link_ptr: Option> = None; + + // Make this future `Drop`-safe. + // SAFETY(link_ptr): Shadow the original definition of `link_ptr` so we can't abuse it. + let mut link_ptr = LinkPtr(core::ptr::addr_of_mut!(link_ptr)); + // SAFETY(new): `free_slot_ptr` is alive until at least after `link_ptr` is popped. + let mut free_slot_ptr = unsafe { FreeSlotPtr::new(core::ptr::addr_of_mut!(free_slot_ptr)) }; + + let mut link_ptr2 = link_ptr.clone(); + let mut free_slot_ptr2 = free_slot_ptr.clone(); + let dropper = OnDrop::new(|| { + // SAFETY: We only run this closure and dereference the pointer if we have + // exited the `poll_fn` below in the `drop(dropper)` call. The other dereference + // of this pointer is in the `poll_fn`. + if let Some(link) = unsafe { link_ptr2.get() } { + self.0.remove_from_wait_queue(link); + } + + // Return our potentially-unused free slot. + // Potentially unnecessary c-s because our link was already popped, so there + // is no way for anything else to access the free slot ptr. Gotta think + // about this a bit more... + critical_section::with(|cs| { + if let Some(freed_slot) = unsafe { free_slot_ptr2.take(cs) } { + // SAFETY: `freed_slot` is a free slot in our referenced channel. + unsafe { self.0.return_free_slot(freed_slot) }; + } + }); + }); + + let idx = poll_fn(|cx| { + // Do all this in one critical section, else there can be race conditions + critical_section::with(|cs| { + if self.is_closed() { + return Poll::Ready(Err(())); + } + + // SAFETY: This pointer is only dereferenced here and on drop of the future + // which happens outside this `poll_fn`'s stack frame. + let link = unsafe { link_ptr.get() }; + + // We are already in the wait queue. + if let Some(link) = link { + if link.is_popped() { + // SAFETY: `free_slot_ptr` is valid for writes until the end of this future. + let slot = unsafe { free_slot_ptr.take(cs) }; + + // If our link is popped, then: + // 1. We were popped by `return_free_lot` and provided us with a slot. + // 2. We were popped by `drop_receiver` and it did not provide us with a slot, and the channel is closed. + if let Some(slot) = slot { + Poll::Ready(Ok(slot)) + } else { + Poll::Ready(Err(())) + } + } else { + Poll::Pending + } + } + // A free slot is available. + else if let Some(free_slot) = self.0.pop_free_slot() { + Poll::Ready(Ok(free_slot)) + } + // We are not in the wait queue, and no free slot is available. + else { + // Place the link in the wait queue. + let link_ref = + link.insert(Link::new((cx.waker().clone(), free_slot_ptr.clone()))); + + // SAFETY(new_unchecked): The address to the link is stable as it is defined + // outside this stack frame. + // SAFETY(push): `link_ref` lifetime comes from `link_ptr` and `free_slot_ptr` that + // are shadowed and we make sure in `dropper` that the link is removed from the queue + // before dropping `link_ptr` AND `dropper` makes sure that the shadowed + // `ptr`s live until the end of the stack frame. + unsafe { self.0.push_wait_queue(Pin::new_unchecked(link_ref)) }; + + Poll::Pending + } + }) + }) + .await; + + // Make sure the link is removed from the queue. + drop(dropper); + + if let Ok(slot) = idx { + // SAFETY: `slot` is provided through a `SlotPtr` or comes from `pop_free_slot`. + unsafe { self.0.send_value(slot, val) }; + + Ok(()) + } else { + Err(NoReceiver(val)) + } + } + + /// Returns true if there is no `Receiver`s. + pub fn is_closed(&self) -> bool { + self.0.receiver_dropped() + } + + /// Is the queue full. + pub fn is_full(&self) -> bool { + // SAFETY: `self.0.freeq` is not called recursively. + unsafe { self.0.freeq(|q| q.is_empty()) } + } + + /// Is the queue empty. + pub fn is_empty(&self) -> bool { + // SAFETY: `self.0.freeq` is not called recursively. + unsafe { self.0.freeq(|q| q.is_full()) } + } +} + +impl Drop for Sender<'_, T, N> { + fn drop(&mut self) { + self.0.drop_sender(); + } +} + +impl Clone for Sender<'_, T, N> { + fn clone(&self) -> Self { + self.0.clone_sender(); + + Self(self.0) + } +} diff --git a/rtic-sync/src/lib.rs b/rtic-sync/src/lib.rs index f8845888ed58..c2f323f04484 100644 --- a/rtic-sync/src/lib.rs +++ b/rtic-sync/src/lib.rs @@ -1,6 +1,6 @@ //! Synchronization primitives for asynchronous contexts. -#![no_std] +#![cfg_attr(not(loom), no_std)] #![deny(missing_docs)] #[cfg(feature = "defmt-03")] @@ -11,6 +11,11 @@ pub mod channel; pub use portable_atomic; pub mod signal; +mod unsafecell; + #[cfg(test)] #[macro_use] extern crate std; + +#[cfg(loom)] +mod loom_cs; diff --git a/rtic-sync/src/loom_cs.rs b/rtic-sync/src/loom_cs.rs new file mode 100644 index 000000000000..3291f52ff9db --- /dev/null +++ b/rtic-sync/src/loom_cs.rs @@ -0,0 +1,69 @@ +//! A loom-based implementation of CriticalSection, effectively copied from the critical_section::std module. + +use core::cell::RefCell; +use core::mem::MaybeUninit; + +use loom::cell::Cell; +use loom::sync::{Mutex, MutexGuard}; + +loom::lazy_static! { + static ref GLOBAL_MUTEX: Mutex<()> = Mutex::new(()); + // This is initialized if a thread has acquired the CS, uninitialized otherwise. + static ref GLOBAL_GUARD: RefCell>> = RefCell::new(MaybeUninit::uninit()); +} + +loom::thread_local!(static IS_LOCKED: Cell = Cell::new(false)); + +struct StdCriticalSection; +critical_section::set_impl!(StdCriticalSection); + +unsafe impl critical_section::Impl for StdCriticalSection { + unsafe fn acquire() -> bool { + // Allow reentrancy by checking thread local state + IS_LOCKED.with(|l| { + if l.get() { + // CS already acquired in the current thread. + return true; + } + + // Note: it is fine to set this flag *before* acquiring the mutex because it's thread local. + // No other thread can see its value, there's no potential for races. + // This way, we hold the mutex for slightly less time. + l.set(true); + + // Not acquired in the current thread, acquire it. + let guard = match GLOBAL_MUTEX.lock() { + Ok(guard) => guard, + Err(err) => { + // Ignore poison on the global mutex in case a panic occurred + // while the mutex was held. + err.into_inner() + } + }; + GLOBAL_GUARD.borrow_mut().write(guard); + + false + }) + } + + unsafe fn release(nested_cs: bool) { + if !nested_cs { + // SAFETY: As per the acquire/release safety contract, release can only be called + // if the critical section is acquired in the current thread, + // in which case we know the GLOBAL_GUARD is initialized. + // + // We have to `assume_init_read` then drop instead of `assume_init_drop` because: + // - drop requires exclusive access (&mut) to the contents + // - mutex guard drop first unlocks the mutex, then returns. In between those, there's a brief + // moment where the mutex is unlocked but a `&mut` to the contents exists. + // - During this moment, another thread can go and use GLOBAL_GUARD, causing `&mut` aliasing. + #[allow(let_underscore_lock)] + let _ = GLOBAL_GUARD.borrow_mut().assume_init_read(); + + // Note: it is fine to clear this flag *after* releasing the mutex because it's thread local. + // No other thread can see its value, there's no potential for races. + // This way, we hold the mutex for slightly less time. + IS_LOCKED.with(|l| l.set(false)); + } + } +} diff --git a/rtic-sync/src/signal.rs b/rtic-sync/src/signal.rs index afe49bdb04f3..3ab41eff1c95 100644 --- a/rtic-sync/src/signal.rs +++ b/rtic-sync/src/signal.rs @@ -168,10 +168,10 @@ macro_rules! make_signal { } #[cfg(test)] +#[cfg(not(loom))] mod tests { - use static_cell::StaticCell; - use super::*; + use static_cell::StaticCell; #[test] fn empty() { diff --git a/rtic-sync/src/unsafecell.rs b/rtic-sync/src/unsafecell.rs new file mode 100644 index 000000000000..e1774f8fa16d --- /dev/null +++ b/rtic-sync/src/unsafecell.rs @@ -0,0 +1,43 @@ +//! Compat layer for [`core::cell::UnsafeCell`] and `loom::cell::UnsafeCell`. + +#[cfg(loom)] +pub use loom::cell::UnsafeCell; + +#[cfg(not(loom))] +pub use core::UnsafeCell; + +#[cfg(not(loom))] +mod core { + /// An [`core::cell::UnsafeCell`] wrapper that provides compatibility with + /// loom's UnsafeCell. + #[derive(Debug)] + pub struct UnsafeCell(core::cell::UnsafeCell); + + impl UnsafeCell { + /// Create a new `UnsafeCell`. + pub const fn new(data: T) -> UnsafeCell { + UnsafeCell(core::cell::UnsafeCell::new(data)) + } + + /// Access the contents of the `UnsafeCell` through a mut pointer. + pub fn get_mut(&self) -> MutPtr { + MutPtr(self.0.get()) + } + + pub unsafe fn with_mut(&self, f: F) -> R + where + F: FnOnce(*mut T) -> R, + { + f(self.0.get()) + } + } + + pub struct MutPtr(*mut T); + + impl MutPtr { + #[allow(clippy::mut_from_ref)] + pub unsafe fn deref(&self) -> &mut T { + &mut *self.0 + } + } +}