diff --git a/Cargo.lock b/Cargo.lock index b0d5d47173..824a5a51ed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2672,6 +2672,7 @@ name = "guestmem" version = "0.0.0" dependencies = [ "inspect", + "minircu", "pal_event", "sparse_mmap", "thiserror 2.0.12", @@ -4207,6 +4208,19 @@ dependencies = [ name = "minimal_rt_build" version = "0.0.0" +[[package]] +name = "minircu" +version = "0.0.0" +dependencies = [ + "event-listener", + "libc", + "pal_async", + "parking_lot", + "test_with_tracing", + "tracelimit", + "windows-sys 0.59.0", +] + [[package]] name = "miniz_oxide" version = "0.8.8" @@ -7965,6 +7979,7 @@ dependencies = [ "libc", "memory_range", "mesh", + "minircu", "pal", "pal_async", "pal_uring", diff --git a/Cargo.toml b/Cargo.toml index 4af0881c48..0a9d0a4884 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -123,6 +123,7 @@ mesh_remote = { path = "support/mesh/mesh_remote" } mesh_rpc = { path = "support/mesh/mesh_rpc" } mesh_worker = { path = "support/mesh/mesh_worker" } mesh_tracing = { path = "support/mesh_tracing" } +minircu = { path = "support/minircu" } open_enum = { path = "support/open_enum" } openssl_kdf = { path = "support/openssl_kdf" } openssl_crypto_only = { path = "support/openssl_crypto_only" } diff --git a/openhcl/underhill_mem/Cargo.toml b/openhcl/underhill_mem/Cargo.toml index 7051608bd4..e84deaf5ab 100644 --- a/openhcl/underhill_mem/Cargo.toml +++ b/openhcl/underhill_mem/Cargo.toml @@ -7,7 +7,7 @@ edition.workspace = true rust-version.workspace = true [target.'cfg(target_os = "linux")'.dependencies] -guestmem.workspace = true +guestmem = { workspace = true, features = ["bitmap"] } hcl.workspace = true hv1_structs.workspace = true hvdef.workspace = true diff --git a/openhcl/underhill_mem/src/lib.rs b/openhcl/underhill_mem/src/lib.rs index c15d5cae20..8a41d8eb11 100644 --- a/openhcl/underhill_mem/src/lib.rs +++ b/openhcl/underhill_mem/src/lib.rs @@ -532,7 +532,14 @@ impl ProtectIsolatedMemory for HardwareIsolatedMemoryProtector { clear_bitmap.update_bitmap(range, false); } - // TODO SNP: flush concurrent accessors. + // There may be other threads concurrently accessing these pages. We + // cannot change the page visibility state until these threads have + // stopped those accesses. Flush the RCU domain that `guestmem` uses in + // order to flush any threads accessing the pages. After this, we are + // guaranteed no threads are accessing these pages (unless the pages are + // also locked), since no bitmap currently allows access. 
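+        //
+        // A sketch of the ordering relied on here (using the names above):
+        //   1. clear_bitmap.update_bitmap(range, false)  -- revoke access in the bitmaps
+        //   2. guestmem::rcu().synchronize_blocking()    -- wait out in-flight accessors
+        //   3. only then change the page visibility/acceptance state below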
+ guestmem::rcu().synchronize_blocking(); + if let IsolationType::Snp = self.acceptor.isolation { // We need to ensure that the guest TLB has been fully flushed since // the unaccept operation is not guaranteed to do so in hardware, diff --git a/openhcl/virt_mshv_vtl/Cargo.toml b/openhcl/virt_mshv_vtl/Cargo.toml index 2163afddb6..5512c7baa8 100644 --- a/openhcl/virt_mshv_vtl/Cargo.toml +++ b/openhcl/virt_mshv_vtl/Cargo.toml @@ -33,6 +33,7 @@ x86emu.workspace = true inspect_counters.workspace = true inspect = { workspace = true, features = ["std"] } mesh.workspace = true +minircu.workspace = true pal_async.workspace = true pal_uring.workspace = true pal.workspace = true diff --git a/openhcl/virt_mshv_vtl/src/processor/mod.rs b/openhcl/virt_mshv_vtl/src/processor/mod.rs index 3d79c73986..aa39ae294d 100644 --- a/openhcl/virt_mshv_vtl/src/processor/mod.rs +++ b/openhcl/virt_mshv_vtl/src/processor/mod.rs @@ -860,6 +860,10 @@ impl<'p, T: Backing> Processor for UhProcessor<'p, T> { .into(); } + // Quiesce RCU before running the VP to avoid having to synchronize with + // this CPU during memory protection updates. + minircu::global().quiesce(); + T::run_vp(self, dev, &mut stop).await?; self.kernel_returns += 1; } diff --git a/support/minircu/Cargo.toml b/support/minircu/Cargo.toml new file mode 100644 index 0000000000..fa4feef006 --- /dev/null +++ b/support/minircu/Cargo.toml @@ -0,0 +1,25 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +[package] +name = "minircu" +rust-version.workspace = true +edition.workspace = true + +[dependencies] +event-listener.workspace = true +parking_lot.workspace = true +tracelimit.workspace = true + +[target.'cfg(target_os = "linux")'.dependencies] +libc.workspace = true + +[target.'cfg(windows)'.dependencies] +windows-sys = { workspace = true, features = ["Win32_System_Threading"] } + +[dev-dependencies] +pal_async.workspace = true +test_with_tracing.workspace = true + +[lints] +workspace = true diff --git a/support/minircu/src/lib.rs b/support/minircu/src/lib.rs new file mode 100644 index 0000000000..8d2774e522 --- /dev/null +++ b/support/minircu/src/lib.rs @@ -0,0 +1,661 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Minimal RCU (Read-Copy-Update) implementation +//! +//! This crate provides a minimal Read-Copy-Update (RCU) synchronization +//! mechanism specifically designed for OpenVMM use cases. RCU is a +//! synchronization technique that allows multiple readers to access shared data +//! concurrently with writers by ensuring that writers create new versions of +//! data while readers continue using old versions. +//! +//! This is similar to a reader-writer lock except that readers never wait: +//! writers publish the new version of the data and then wait for all readers to +//! finish using the old version before freeing it. This allows for very low +//! overhead on the read side, as readers do not need to acquire locks. +//! +//! ## Usage +//! +//! Basic usage with the global domain: +//! +//! ```rust +//! // Execute code in a read-side critical section +//! let result = minircu::global().run(|| { +//! // Access shared data safely here. +//! 42 +//! }); +//! +//! // Wait for all current readers to finish their critical sections. +//! // This is typically called by writers after updating data. +//! minircu::global().synchronize_blocking(); +//! ``` +//! +//! ## Quiescing +//! +//! To optimize synchronization, threads can explicitly quiesce when it is not +//! 
expected to enter a critical section for a while. The RCU domain can skip +//! issuing a memory barrier when all threads are quiesced. +//! +//! ```rust +//! use minircu::global; +//! +//! // Mark the current thread as quiesced. +//! global().quiesce(); +//! ``` +//! +//! ## Asynchronous Support +//! +//! The crate provides async-compatible methods for quiescing and +//! synchronization: +//! +//! ```rust +//! use minircu::global; +//! +//! async fn example() { +//! // Quiesce whenever future returns Poll::Pending +//! global().quiesce_on_pending(async { +//! loop { +//! // Async code here. +//! global().run(|| { +//! // Access shared data safely here. +//! }); +//! } +//! }).await; +//! +//! // Asynchronous synchronization +//! global().synchronize(|duration| async move { +//! // This should be a sleep call, e.g. using tokio::time::sleep. +//! std::future::pending().await +//! }).await; +//! } +//! ``` +//! +//! ## Gotchas +//! +//! * Avoid blocking or long-running operations in critical sections as they can +//! delay writers or cause deadlocks. +//! * Never call [`synchronize`](RcuDomain::synchronize) or +//! [`synchronize_blocking`](RcuDomain::synchronize_blocking) from within a critical +//! section (will panic). +//! * For best performance, ensure all threads in your process call `quiesce` +//! when a thread is going to sleep or block. +//! +//! ## Implementation Notes +//! +//! On Windows and Linux, the read-side critical section avoids any processor +//! memory barriers. It achieves this by having the write side broadcast a +//! memory barrier to all threads in the process when needed for +//! synchronization, via the `membarrier` syscall on Linux and +//! `FlushProcessWriteBuffers` on Windows. +//! +//! On other platforms, which do not support this functionality, the read-side +//! critical section uses a memory fence. This makes the read side more +//! expensive on these platforms, but it is still cheaper than a mutex or +//! reader-writer lock. + +// UNSAFETY: needed to access TLS from a remote thread and to call platform APIs +// for issuing process-wide memory barriers. +#![expect(unsafe_code)] + +/// Provides the environment-specific `membarrier` and `access_fence` +/// implementations. +#[cfg_attr(target_os = "linux", path = "linux.rs")] +#[cfg_attr(windows, path = "windows.rs")] +#[cfg_attr(not(any(windows, target_os = "linux")), path = "other.rs")] +mod sys; + +use event_listener::Event; +use event_listener::Listener; +use parking_lot::Mutex; +use std::cell::Cell; +use std::future::Future; +use std::future::poll_fn; +use std::ops::Deref; +use std::pin::pin; +use std::sync::atomic::AtomicU64; +use std::sync::atomic::Ordering::Acquire; +use std::sync::atomic::Ordering::Relaxed; +use std::sync::atomic::Ordering::Release; +use std::sync::atomic::Ordering::SeqCst; +use std::sync::atomic::fence; +use std::task::Poll; +use std::thread::LocalKey; +use std::thread::Thread; +use std::time::Duration; +use std::time::Instant; + +/// Defines a new RCU domain, which can be synchronized with separately from +/// other domains. +/// +/// Usually you just want to use [`global`], the global domain. +/// +/// Don't export this until we have a use case. We may want to make `quiesce` +/// apply to all domains, or something like that. +macro_rules! define_rcu_domain { + ($(#[$a:meta])* $vis:vis $name:ident) => { + $(#[$a])* + $vis const fn $name() -> $crate::RcuDomain { + static DATA: $crate::RcuData = $crate::RcuData::new(); + thread_local! 
{ + static TLS: $crate::ThreadData = const { $crate::ThreadData::new() }; + } + $crate::RcuDomain::new(&TLS, &DATA) + } + }; +} + +define_rcu_domain! { + /// The global RCU domain. + pub global +} + +/// An RCU synchronization domain. +#[derive(Copy, Clone)] +pub struct RcuDomain { + tls: &'static LocalKey, + data: &'static RcuData, +} + +impl std::fmt::Debug for RcuDomain { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let Self { tls: _, data } = self; + f.debug_struct("RcuDomain").field("data", data).finish() + } +} + +/// Domain-global RCU state. +#[doc(hidden)] +#[derive(Debug)] +pub struct RcuData { + /// The threads that have registered with this domain. + threads: Mutex>, + /// The current sequence number. + seq: AtomicU64, + /// The event that is signaled when a thread exits a critical section and + /// there has been a sequence number update. + event: Event, + /// The number of membarriers issued. + membarriers: AtomicU64, +} + +/// The entry in the thread list for a registered thread. +#[derive(Debug)] +struct ThreadEntry { + /// The pointer to the sequence number for this thread. The [`ThreadData`] + /// TLS destructor will remove this entry, so this is safe to dereference. + seq_ptr: TlsRef, + /// The last sequence number that a synchronizer can know this thread has + /// observed, without issuing membarriers or looking at the thread's TLS + /// data. + observed_seq: u64, + /// The thread that this entry is for. Used for debugging and tracing. + thread: Thread, +} + +/// A pointer representing a valid reference to a value. +struct TlsRef(*const T); + +impl Deref for TlsRef { + type Target = T; + + fn deref(&self) -> &Self::Target { + // SAFETY: This is known to point to valid TLS data for its lifetime, since the TLS + // drop implementation will remove this entry from the list. + unsafe { &*self.0 } + } +} + +// SAFETY: Since this represents a reference to T, it is `Send` if `&T` is +// `Send`. +unsafe impl Send for TlsRef where for<'a> &'a T: Send {} +// SAFETY: Since this represents a reference to T, it is `Sync` if `&T` is +// `Sync`. +unsafe impl Sync for TlsRef where for<'a> &'a T: Sync {} + +impl std::fmt::Debug for TlsRef { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + (**self).fmt(f) + } +} + +impl RcuData { + /// Used by [`define_rcu_domain!`] to create a new RCU domain. + #[doc(hidden)] + pub const fn new() -> Self { + RcuData { + threads: Mutex::new(Vec::new()), + seq: AtomicU64::new(SEQ_FIRST), + event: Event::new(), + membarriers: AtomicU64::new(0), + } + } +} + +/// The per-thread TLS data. +#[doc(hidden)] +pub struct ThreadData { + /// The current sequence number for the thread. + current_seq: AtomicU64, + /// The RCU domain this thread is registered with. + data: Cell>, +} + +impl std::fmt::Debug for ThreadData { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let Self { + current_seq: my_seq, + data: _, + } = self; + f.debug_struct("ThreadData") + .field("my_seq", my_seq) + .finish() + } +} + +impl Drop for ThreadData { + fn drop(&mut self) { + if let Some(data) = self.data.get() { + { + let mut threads = data.threads.lock(); + let i = threads + .iter() + .position(|x| x.seq_ptr.0 == &self.current_seq) + .unwrap(); + threads.swap_remove(i); + } + data.event.notify(!0usize); + } + } +} + +impl ThreadData { + /// Used by [`define_rcu_domain!`] to create a new RCU domain. 
+ #[doc(hidden)] + pub const fn new() -> Self { + ThreadData { + current_seq: AtomicU64::new(SEQ_NONE), + data: Cell::new(None), + } + } +} + +/// The thread has not yet registered with the RCU domain. +const SEQ_NONE: u64 = 0; +/// The bit set when the thread in a critical section. +const SEQ_MASK_BUSY: u64 = 1; +/// The value the sequence number is incremented by each synchronize call. +const SEQ_INCREMENT: u64 = 2; +/// The sequence value for a quiesced thread. The thread will issue a full +/// memory barrier when leaving this state. +const SEQ_QUIESCED: u64 = 2; +/// The first actual sequence number. +const SEQ_FIRST: u64 = 4; + +impl RcuDomain { + #[doc(hidden)] + pub const fn new(tls: &'static LocalKey, data: &'static RcuData) -> Self { + RcuDomain { tls, data } + } + + /// Runs `f` in a critical section. Calls to + /// [`synchronize`](Self::synchronize) or + /// [`synchronize_blocking`](Self::synchronize_blocking) for the same RCU root will + /// block until `f` returns. + /// + /// In general, you should avoid blocking the thread in `f`, since that can + /// slow calls to [`synchronize`](Self::synchronize) and can potentially + /// cause deadlocks. + pub fn run(self, f: F) -> R + where + F: FnOnce() -> R, + { + self.tls.with(|x| x.run(self.data, f)) + } + + /// Quiesce the current thread. + /// + /// This can speed up calls to [`synchronize`](Self::synchronize) or + /// [`synchronize_blocking`](Self::synchronize_blocking) by allowing the RCU domain + /// to skip issuing a membarrier if all threads are quiesced. In return, the + /// first call to [`run`](Self::run) after this will be slower, as it will + /// need to issue a memory barrier to leave the quiesced state. + pub fn quiesce(self) { + self.tls.with(|x| { + x.quiesce(self.data); + }); + } + + /// Runs `fut`, calling [`quiesce`](Self::quiesce) on the current thread + /// each time `fut` returns `Poll::Pending`. + pub async fn quiesce_on_pending(self, fut: Fut) -> Fut::Output + where + Fut: Future, + { + let mut fut = pin!(fut); + poll_fn(|cx| { + self.tls.with(|x| { + let r = fut.as_mut().poll(cx); + x.quiesce(self.data); + r + }) + }) + .await + } + + #[track_caller] + fn prepare_to_wait(&self) -> Option { + // Quiesce this thread so we don't wait on ourselves. + { + let this_seq = self.tls.with(|x| x.quiesce(self.data)); + assert!( + this_seq == SEQ_NONE || this_seq == SEQ_QUIESCED, + "called synchronize() inside a critical section, {this_seq:#x}", + ); + } + // Update the domain's sequence number. + let seq = self.data.seq.fetch_add(SEQ_INCREMENT, SeqCst) + SEQ_INCREMENT; + // We need to make sure all threads are quiesced, not busy, or have + // observed the new sequence number. To do this, we must synchronize the + // global sequence number update with changes to each thread's local + // sequence number. To do that, we will issue a membarrier, to broadcast + // a memory barrier to all threads in the process. + // + // First, try to avoid the membarrier if possible--if all threads are quiesced, + // then there is no need to issue a membarrier, because quiesced threads will issue + // a memory barrier when they leave the quiesced state. + if self + .data + .threads + .lock() + .iter_mut() + .all(|t| Self::is_thread_ready(t, seq, false)) + { + return None; + } + // Keep a count for diagnostics purposes. 
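+        // (The unit tests below read this counter to verify whether a
+        // membarrier was actually needed.)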
+ self.data.membarriers.fetch_add(1, Relaxed); + sys::membarrier(); + Some(seq) + } + + /// Synchronizes the RCU domain, blocking asynchronously until all threads + /// have exited their critical sections and observed the new sequence + /// number. + /// + /// `sleep` should be a function that sleeps for the specified duration. + pub async fn synchronize(self, mut sleep: impl AsyncFnMut(Duration)) { + let Some(seq) = self.prepare_to_wait() else { + return; + }; + let mut wait = pin!(self.wait_threads_ready(seq)); + let mut timeout = Duration::from_millis(100); + loop { + let mut sleep = pin!(sleep(timeout)); + let ready = poll_fn(|cx| { + if let Poll::Ready(()) = wait.as_mut().poll(cx) { + Poll::Ready(true) + } else if let Poll::Ready(()) = sleep.as_mut().poll(cx) { + Poll::Ready(false) + } else { + Poll::Pending + } + }) + .await; + if ready { + break; + } + self.warn_stall(seq); + if timeout < Duration::from_secs(10) { + timeout *= 2; + } + } + } + + /// Like [`synchronize`](Self::synchronize), but blocks the current thread + /// synchronously. + #[track_caller] + pub fn synchronize_blocking(self) { + let Some(seq) = self.prepare_to_wait() else { + return; + }; + let mut timeout = Duration::from_millis(10); + while !self.wait_threads_ready_sync(seq, Instant::now() + timeout) { + self.warn_stall(seq); + if timeout < Duration::from_secs(10) { + timeout *= 2; + } + } + } + + fn warn_stall(&self, target: u64) { + for thread in &mut *self.data.threads.lock() { + if !Self::is_thread_ready(thread, target, true) { + tracelimit::warn_ratelimited!(thread = thread.thread.name(), "rcu stall"); + } + } + } + + async fn wait_threads_ready(&self, target: u64) { + loop { + let event = self.data.event.listen(); + if self.all_threads_ready(target, true) { + break; + } + event.await; + } + } + + #[must_use] + fn wait_threads_ready_sync(&self, target: u64, deadline: Instant) -> bool { + loop { + let event = self.data.event.listen(); + if self.all_threads_ready(target, true) { + break; + } + if event.wait_deadline(deadline).is_none() { + return false; + } + } + true + } + + fn all_threads_ready(&self, target: u64, issued_barrier: bool) -> bool { + self.data + .threads + .lock() + .iter_mut() + .all(|thread| Self::is_thread_ready(thread, target, issued_barrier)) + } + + fn is_thread_ready(thread: &mut ThreadEntry, target: u64, issued_barrier: bool) -> bool { + if thread.observed_seq >= target { + return true; + } + let seq = thread.seq_ptr.load(Relaxed); + assert_ne!(seq, SEQ_NONE); + if seq & !SEQ_MASK_BUSY < target { + if seq & SEQ_MASK_BUSY != 0 { + // The thread is actively running in a critical section. + return false; + } + if seq != SEQ_QUIESCED { + // The thread is not quiesced. If a barrier was issued, then it + // has observed the new sequence number. It may be busy (but + // this CPU has not observed the write yet), but it must be busy + // with a newer sequence number. + // + // If a barrier was not issued, then it is possible that the + // thread is busy with an older sequence number. In this case, + // we will need to issue a membarrier to observe the value of + // the busy bit accurately. + assert!(seq >= SEQ_FIRST, "{seq}"); + if !issued_barrier { + return false; + } + } + } + thread.observed_seq = target; + true + } +} + +impl ThreadData { + fn run(&self, data: &'static RcuData, f: F) -> R + where + F: FnOnce() -> R, + { + // Mark the thread as busy. 
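+        // The low bit of `current_seq` is the busy flag (SEQ_MASK_BUSY); the
+        // remaining bits are the last domain sequence number this thread has
+        // observed, or SEQ_QUIESCED/SEQ_NONE.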
+ let seq = self.current_seq.load(Relaxed); + self.current_seq.store(seq | SEQ_MASK_BUSY, Relaxed); + if seq < SEQ_FIRST { + // The thread was quiesced or not registered. Register it now. + if seq == SEQ_NONE { + self.start(data, seq); + } else { + debug_assert!(seq == SEQ_QUIESCED || seq & SEQ_MASK_BUSY != 0, "{seq:#x}"); + } + // Use a full memory barrier to ensure the write side observes that + // the thread is no longer quiesced before calling `f`. + fence(SeqCst); + } + // Ensure accesses in `f` are bounded by setting the busy bit. Note that + // this and other fences are just compiler fences; the write side must + // call `membarrier` to dynamically turn them into processor memory + // barriers, so to speak. + sys::access_fence(Acquire); + let r = f(); + sys::access_fence(Release); + // Clear the busy bit. + self.current_seq.store(seq, Relaxed); + // Ensure the busy bit clear is visible to the write side, then read the + // new sequence number, to synchronize with the sequence update path. + sys::access_fence(SeqCst); + let new_seq = data.seq.load(Relaxed); + if new_seq != seq { + // The domain's current sequence number has changed. Update it and + // wake up any waiters. + self.update_seq(data, seq, new_seq); + } + r + } + + #[inline(never)] + fn start(&self, data: &'static RcuData, seq: u64) { + if seq == SEQ_NONE { + // Add the thread to the list of known threads in this domain. + assert!(self.data.get().is_none()); + data.threads.lock().push(ThreadEntry { + seq_ptr: TlsRef(&self.current_seq), + observed_seq: SEQ_NONE, + thread: std::thread::current(), + }); + // Remember the domain so that we can remove the thread from the list + // when it exits. + self.data.set(Some(data)); + } + } + + #[inline(never)] + fn update_seq(&self, data: &'static RcuData, seq: u64, new_seq: u64) { + if seq & SEQ_MASK_BUSY != 0 { + // Nested call. Skip. + return; + } + assert!( + new_seq >= SEQ_FIRST && new_seq & SEQ_MASK_BUSY == 0, + "{new_seq}" + ); + self.current_seq.store(new_seq, Relaxed); + // Wake up any waiters. We don't know how many threads are still in a + // critical section, so just wake up the writers every time and let them + // figure it out. 
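+        // (`!0usize` is usize::MAX, i.e. wake every current listener.)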
+ data.event.notify(!0usize); + } + + fn quiesce(&self, data: &'static RcuData) -> u64 { + let seq = self.current_seq.load(Relaxed); + if seq >= SEQ_FIRST && seq & SEQ_MASK_BUSY == 0 { + self.current_seq.store(SEQ_QUIESCED, Relaxed); + data.event.notify(!0usize); + SEQ_QUIESCED + } else { + seq + } + } +} + +#[cfg(test)] +mod tests { + use crate::RcuDomain; + use pal_async::DefaultDriver; + use pal_async::DefaultPool; + use pal_async::async_test; + use pal_async::task::Spawn; + use pal_async::timer::PolledTimer; + use std::sync::atomic::Ordering; + use test_with_tracing::test; + + async fn sync(driver: &DefaultDriver, rcu: RcuDomain) { + let mut timer = PolledTimer::new(driver); + rcu.synchronize(async |timeout| { + timer.sleep(timeout).await; + }) + .await + } + + #[async_test] + async fn test_rcu_single(driver: DefaultDriver) { + define_rcu_domain!(test_rcu); + + test_rcu().run(|| {}); + sync(&driver, test_rcu()).await; + } + + #[async_test] + async fn test_rcu_nested(driver: DefaultDriver) { + define_rcu_domain!(test_rcu); + + test_rcu().run(|| { + test_rcu().run(|| {}); + }); + sync(&driver, test_rcu()).await; + } + + #[async_test] + async fn test_rcu_multi(driver: DefaultDriver) { + define_rcu_domain!(test_rcu); + + let (thread, thread_driver) = DefaultPool::spawn_on_thread("test"); + thread_driver + .spawn("test", async { test_rcu().run(|| {}) }) + .await; + + assert_eq!(test_rcu().data.membarriers.load(Ordering::Relaxed), 0); + sync(&driver, test_rcu()).await; + assert_eq!(test_rcu().data.membarriers.load(Ordering::Relaxed), 1); + + drop(thread_driver); + thread.join().unwrap(); + } + + #[async_test] + async fn test_rcu_multi_quiesce(driver: DefaultDriver) { + define_rcu_domain!(test_rcu); + + let (thread, thread_driver) = DefaultPool::spawn_on_thread("test"); + thread_driver + .spawn( + "test", + test_rcu().quiesce_on_pending(async { test_rcu().run(|| {}) }), + ) + .await; + + assert_eq!(test_rcu().data.membarriers.load(Ordering::Relaxed), 0); + test_rcu().quiesce(); + sync(&driver, test_rcu()).await; + assert_eq!(test_rcu().data.membarriers.load(Ordering::Relaxed), 0); + + drop(thread_driver); + thread.join().unwrap(); + } +} diff --git a/support/minircu/src/linux.rs b/support/minircu/src/linux.rs new file mode 100644 index 0000000000..3119c155e9 --- /dev/null +++ b/support/minircu/src/linux.rs @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use libc::SYS_membarrier; +use libc::syscall; + +// Use a compiler fence on the read side since we have a working membarrier +// implementation. +pub use std::sync::atomic::compiler_fence as access_fence; + +pub fn membarrier() { + // Use the membarrier syscall to ensure that all other threads in the + // process have observed the writes made by this thread. + // + // This could be quite expensive with lots of threads, but most of the + // threads in a VMM should be idle most of the time. However, In OpenVMM on + // a host, this could be problematic--KVM and MSHV VP threads will probably + // not be considered idle by the membarrier implementation. + // + // Luckily, in the OpenHCL environment VP threads are usually idle (to + // prevent unnecessary scheduler ticks), so this should be a non-issue. 
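+    //
+    // MEMBARRIER_CMD_PRIVATE_EXPEDITED fails with EPERM until the process has
+    // registered its intent to use it, so register on the first EPERM and
+    // retry.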
+ let r = match membarrier_syscall(libc::MEMBARRIER_CMD_PRIVATE_EXPEDITED) { + Err(err) if err.raw_os_error() == Some(libc::EPERM) => { + membarrier_syscall(libc::MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED) + .expect("failed to register for membarrier use"); + membarrier_syscall(libc::MEMBARRIER_CMD_PRIVATE_EXPEDITED) + } + r => r, + }; + r.expect("failed to issue membarrier syscall"); +} + +fn membarrier_syscall(cmd: libc::c_int) -> std::io::Result<()> { + // SAFETY: no special requirements for the syscall. + let r = unsafe { syscall(SYS_membarrier, cmd, 0, 0) }; + if r < 0 { + return Err(std::io::Error::last_os_error()); + } + Ok(()) +} diff --git a/support/minircu/src/other.rs b/support/minircu/src/other.rs new file mode 100644 index 0000000000..1acc225720 --- /dev/null +++ b/support/minircu/src/other.rs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +// Use a memory barrier on the read side since we don't have a working +// membarrier implementation to force a barrier remotely from the write side. +pub use std::sync::atomic::fence as access_fence; + +pub fn membarrier() { + // No suitable implementation on this platform. +} diff --git a/support/minircu/src/windows.rs b/support/minircu/src/windows.rs new file mode 100644 index 0000000000..3865b4d844 --- /dev/null +++ b/support/minircu/src/windows.rs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use windows_sys::Win32::System::Threading::FlushProcessWriteBuffers; + +// Use a compiler fence on the read side since we have a working membarrier +// implementation. +pub use std::sync::atomic::compiler_fence as access_fence; + +pub fn membarrier() { + // Use the FlushProcessWriteBuffers function to ensure that all other threads in the process + // have observed the writes made by this thread. + + // SAFETY: no special requirements for the call. 
+ unsafe { FlushProcessWriteBuffers() } +} diff --git a/vm/vmcore/guestmem/Cargo.toml b/vm/vmcore/guestmem/Cargo.toml index eb542abcea..abc76b1c55 100644 --- a/vm/vmcore/guestmem/Cargo.toml +++ b/vm/vmcore/guestmem/Cargo.toml @@ -6,10 +6,14 @@ name = "guestmem" edition.workspace = true rust-version.workspace = true +[features] +bitmap = ["dep:minircu"] + [dependencies] inspect.workspace = true pal_event.workspace = true sparse_mmap.workspace = true +minircu = { workspace = true, optional = true } thiserror.workspace = true zerocopy.workspace = true diff --git a/vm/vmcore/guestmem/fuzz/Cargo.toml b/vm/vmcore/guestmem/fuzz/Cargo.toml index 626d620ba6..9a339b367e 100644 --- a/vm/vmcore/guestmem/fuzz/Cargo.toml +++ b/vm/vmcore/guestmem/fuzz/Cargo.toml @@ -8,7 +8,7 @@ edition.workspace = true rust-version.workspace = true [dependencies] -guestmem.workspace = true +guestmem = { workspace = true, features = ["bitmap"] } sparse_mmap.workspace = true xtask_fuzz.workspace = true diff --git a/vm/vmcore/guestmem/src/lib.rs b/vm/vmcore/guestmem/src/lib.rs index 3e317e0b9e..d71d5a69d7 100644 --- a/vm/vmcore/guestmem/src/lib.rs +++ b/vm/vmcore/guestmem/src/lib.rs @@ -13,6 +13,7 @@ use self::ranges::PagedRange; use inspect::Inspect; use pal_event::Event; use sparse_mmap::AsMappableRef; +use std::any::Any; use std::fmt::Debug; use std::io; use std::ops::Deref; @@ -21,7 +22,6 @@ use std::ops::Range; use std::ptr::NonNull; use std::sync::Arc; use std::sync::atomic::AtomicU8; -use std::sync::atomic::Ordering; use thiserror::Error; use zerocopy::FromBytes; use zerocopy::FromZeros; @@ -328,7 +328,12 @@ pub unsafe trait GuestMemoryAccess: 'static + Send + Sync { /// fails, then the associated `*_fallback` routine is called to handle the /// error. /// - /// TODO: add a synchronization scheme. + /// Bitmap checks are performed under the [`rcu()`] RCU domain, with relaxed + /// accesses. After a thread updates the bitmap to be more restrictive, it + /// must call [`minircu::global().synchronize()`] to ensure that all threads + /// see the update before taking any action that depends on the bitmap + /// update being visible. + #[cfg(feature = "bitmap")] fn access_bitmap(&self) -> Option { None } @@ -455,6 +460,110 @@ pub unsafe trait GuestMemoryAccess: 'static + Send + Sync { } } +trait DynGuestMemoryAccess: 'static + Send + Sync + Any { + fn subrange( + &self, + offset: u64, + len: u64, + allow_preemptive_locking: bool, + ) -> Result, GuestMemoryBackingError>; + + fn page_fault( + &self, + address: u64, + len: usize, + write: bool, + bitmap_failure: bool, + ) -> PageFaultAction; + + /// # Safety + /// See [`GuestMemoryAccess::read_fallback`]. + unsafe fn read_fallback( + &self, + addr: u64, + dest: *mut u8, + len: usize, + ) -> Result<(), GuestMemoryBackingError>; + + /// # Safety + /// See [`GuestMemoryAccess::write_fallback`]. 
+ unsafe fn write_fallback( + &self, + addr: u64, + src: *const u8, + len: usize, + ) -> Result<(), GuestMemoryBackingError>; + + fn fill_fallback(&self, addr: u64, val: u8, len: usize) -> Result<(), GuestMemoryBackingError>; + + fn compare_exchange_fallback( + &self, + addr: u64, + current: &mut [u8], + new: &[u8], + ) -> Result; + + fn expose_va(&self, address: u64, len: u64) -> Result<(), GuestMemoryBackingError>; +} + +impl DynGuestMemoryAccess for T { + fn subrange( + &self, + offset: u64, + len: u64, + allow_preemptive_locking: bool, + ) -> Result, GuestMemoryBackingError> { + self.subrange(offset, len, allow_preemptive_locking) + } + + fn page_fault( + &self, + address: u64, + len: usize, + write: bool, + bitmap_failure: bool, + ) -> PageFaultAction { + self.page_fault(address, len, write, bitmap_failure) + } + + unsafe fn read_fallback( + &self, + addr: u64, + dest: *mut u8, + len: usize, + ) -> Result<(), GuestMemoryBackingError> { + // SAFETY: guaranteed by caller. + unsafe { self.read_fallback(addr, dest, len) } + } + + unsafe fn write_fallback( + &self, + addr: u64, + src: *const u8, + len: usize, + ) -> Result<(), GuestMemoryBackingError> { + // SAFETY: guaranteed by caller. + unsafe { self.write_fallback(addr, src, len) } + } + + fn fill_fallback(&self, addr: u64, val: u8, len: usize) -> Result<(), GuestMemoryBackingError> { + self.fill_fallback(addr, val, len) + } + + fn compare_exchange_fallback( + &self, + addr: u64, + current: &mut [u8], + new: &[u8], + ) -> Result { + self.compare_exchange_fallback(addr, current, new) + } + + fn expose_va(&self, address: u64, len: u64) -> Result<(), GuestMemoryBackingError> { + self.expose_va(address, len) + } +} + /// The action to take after [`GuestMemoryAccess::page_fault`] returns to /// continue the operation. pub enum PageFaultAction { @@ -467,6 +576,7 @@ pub enum PageFaultAction { } /// Returned by [`GuestMemoryAccess::access_bitmap`]. +#[cfg(feature = "bitmap")] pub struct BitmapInfo { /// A pointer to the bitmap for read access. pub read_bitmap: NonNull, @@ -491,6 +601,7 @@ unsafe impl GuestMemoryAccess for Arc { self.as_ref().max_address() } + #[cfg(feature = "bitmap")] fn access_bitmap(&self) -> Option { self.as_ref().access_bitmap() } @@ -604,6 +715,7 @@ unsafe impl GuestMemoryAccess for GuestMemoryAccessRange { self.len } + #[cfg(feature = "bitmap")] fn access_bitmap(&self) -> Option { let region = &self.base.regions[self.region]; region.bitmaps.map(|bitmaps| { @@ -749,19 +861,7 @@ impl MultiRegionGuestMemoryAccess { } // SAFETY: `mapping()` is unreachable and panics if called. 
-unsafe impl GuestMemoryAccess for MultiRegionGuestMemoryAccess { - fn mapping(&self) -> Option> { - unreachable!() - } - - fn max_address(&self) -> u64 { - unreachable!() - } - - fn access_bitmap(&self) -> Option { - unreachable!() - } - +impl DynGuestMemoryAccess for MultiRegionGuestMemoryAccess { fn subrange( &self, offset: u64, @@ -814,8 +914,19 @@ unsafe impl GuestMemoryAccess for MultiRegionGuestMemoryAc region.expose_va(offset_in_region, len) } - fn base_iova(&self) -> Option { - unreachable!() + fn page_fault( + &self, + address: u64, + len: usize, + write: bool, + bitmap_failure: bool, + ) -> PageFaultAction { + match self.region(address, len as u64) { + Ok((region, offset_in_region)) => { + region.page_fault(offset_in_region, len, write, bitmap_failure) + } + Err(err) => PageFaultAction::Fail(err.err), + } } } @@ -830,7 +941,7 @@ pub struct GuestMemory { inner: Arc, } -struct GuestMemoryInner { +struct GuestMemoryInner { region_def: RegionDefinition, regions: Vec, debug_name: Arc, @@ -850,7 +961,9 @@ impl Debug for GuestMemoryInner { #[derive(Debug, Copy, Clone, Default)] struct MemoryRegion { mapping: Option, + #[cfg(feature = "bitmap")] bitmaps: Option<[SendPtrU8; 3]>, + #[cfg(feature = "bitmap")] bitmap_start: u8, len: u64, base_iova: Option, @@ -886,18 +999,24 @@ unsafe impl Sync for SendPtrU8 {} impl MemoryRegion { fn new(imp: &impl GuestMemoryAccess) -> Self { - let bitmap_info = imp.access_bitmap(); - let bitmaps = bitmap_info.as_ref().map(|bm| { - [ - SendPtrU8(bm.read_bitmap), - SendPtrU8(bm.write_bitmap), - SendPtrU8(bm.execute_bitmap), - ] - }); - let bitmap_start = bitmap_info.map_or(0, |bi| bi.bit_offset); + #[cfg(feature = "bitmap")] + let (bitmaps, bitmap_start) = { + let bitmap_info = imp.access_bitmap(); + let bitmaps = bitmap_info.as_ref().map(|bm| { + [ + SendPtrU8(bm.read_bitmap), + SendPtrU8(bm.write_bitmap), + SendPtrU8(bm.execute_bitmap), + ] + }); + let bitmap_start = bitmap_info.map_or(0, |bi| bi.bit_offset); + (bitmaps, bitmap_start) + }; Self { mapping: imp.mapping().map(SendPtrU8), + #[cfg(feature = "bitmap")] bitmaps, + #[cfg(feature = "bitmap")] bitmap_start, len: imp.max_address(), base_iova: imp.base_iova(), @@ -916,6 +1035,10 @@ impl MemoryRegion { len: u64, ) -> Result<(), u64> { debug_assert!(self.len >= offset + len); + #[cfg(not(feature = "bitmap"))] + let _ = access_type; + + #[cfg(feature = "bitmap")] if let Some(bitmaps) = &self.bitmaps { let SendPtrU8(bitmap) = bitmaps[access_type as usize]; let start = offset / PAGE_SIZE64; @@ -932,7 +1055,7 @@ impl MemoryRegion { .cast_const() .cast::() .add(bit_offset as usize / 8)) - .load(Ordering::Relaxed) + .load(std::sync::atomic::Ordering::Relaxed) & (1 << (bit_offset % 8)) }; if bit == 0 { @@ -982,6 +1105,19 @@ pub enum MultiRegionError { BackingTooLarge { backing_size: u64, region_size: u64 }, } +/// The RCU domain memory accesses occur under. Updates to any memory access +/// bitmaps must be synchronized under this domain. +/// +/// See [`GuestMemoryAccess::access_bitmap`] for more details. +/// +/// This is currently the global domain, but this is reexported here to make +/// calling code clearer. +#[cfg(feature = "bitmap")] +pub fn rcu() -> minircu::RcuDomain { + // Use the global domain unless we find a reason to do something else. + minircu::global() +} + impl GuestMemory { /// Returns a new instance using `imp` as the backing. /// @@ -1216,6 +1352,7 @@ impl GuestMemory { /// mapped. 
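+    /// With the `bitmap` feature enabled, a backing that exposes an access
+    /// bitmap never reports a full mapping, since every access must be checked
+    /// against the bitmap.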
pub fn full_mapping(&self) -> Option<(*mut u8, usize)> { if let [region] = self.inner.regions.as_slice() { + #[cfg(feature = "bitmap")] if region.bitmaps.is_some() { return None; } @@ -1288,33 +1425,42 @@ impl GuestMemory { mut f: impl FnMut(&mut P, *mut u8) -> Result, fallback: impl FnOnce(&mut P) -> Result, ) -> Result { - let Some(mapping) = self.mapping_range(access_type, gpa, len)? else { - return fallback(&mut param); - }; + let op = || { + let Some(mapping) = self.mapping_range(access_type, gpa, len)? else { + return fallback(&mut param); + }; - // Try until the fault fails to resolve. - loop { - match f(&mut param, mapping) { - Ok(t) => return Ok(t), - Err(fault) => { - match self.inner.imp.page_fault( - gpa + fault.offset() as u64, - len - fault.offset(), - access_type == AccessType::Write, - false, - ) { - PageFaultAction::Fail(err) => { - return Err(GuestMemoryBackingError::new( - gpa + fault.offset() as u64, - err, - )); + // Try until the fault fails to resolve. + loop { + match f(&mut param, mapping) { + Ok(t) => return Ok(t), + Err(fault) => { + match self.inner.imp.page_fault( + gpa + fault.offset() as u64, + len - fault.offset(), + access_type == AccessType::Write, + false, + ) { + PageFaultAction::Fail(err) => { + return Err(GuestMemoryBackingError::new( + gpa + fault.offset() as u64, + err, + )); + } + PageFaultAction::Retry => {} + PageFaultAction::Fallback => return fallback(&mut param), } - PageFaultAction::Retry => {} - PageFaultAction::Fallback => return fallback(&mut param), } } } - } + }; + // If the `bitmap` feature is enabled, run the function in an RCU + // critical section. This will allow callers to flush concurrent + // accesses after bitmap updates. + #[cfg(feature = "bitmap")] + return rcu().run(op); + #[cfg(not(feature = "bitmap"))] + op() } /// # Safety @@ -2219,7 +2365,6 @@ pub trait UnmapRom: Send + Sync { #[cfg(test)] #[expect(clippy::undocumented_unsafe_blocks)] mod tests { - use crate::BitmapInfo; use crate::GuestMemory; use crate::PAGE_SIZE64; use crate::PageFaultAction; @@ -2234,6 +2379,7 @@ mod tests { /// when attempting to access them. pub struct GuestMemoryMapping { mapping: SparseMapping, + #[cfg(feature = "bitmap")] bitmap: Option>, } @@ -2246,8 +2392,9 @@ mod tests { self.mapping.len() as u64 } - fn access_bitmap(&self) -> Option { - self.bitmap.as_ref().map(|bm| BitmapInfo { + #[cfg(feature = "bitmap")] + fn access_bitmap(&self) -> Option { + self.bitmap.as_ref().map(|bm| crate::BitmapInfo { read_bitmap: NonNull::new(bm.as_ptr().cast_mut()).unwrap(), write_bitmap: NonNull::new(bm.as_ptr().cast_mut()).unwrap(), execute_bitmap: NonNull::new(bm.as_ptr().cast_mut()).unwrap(), @@ -2275,6 +2422,7 @@ mod tests { GuestMemoryMapping { mapping, + #[cfg(feature = "bitmap")] bitmap: None, } } @@ -2329,6 +2477,7 @@ mod tests { mapping.alloc(0, len).unwrap(); let mapping = Arc::new(GuestMemoryMapping { mapping, + #[cfg(feature = "bitmap")] bitmap: None, }); let region_len = 1 << 30; @@ -2349,6 +2498,7 @@ mod tests { gm.read_at(3 * region_len, &mut b).unwrap_err(); } + #[cfg(feature = "bitmap")] #[test] fn test_bitmap() { let len = PAGE_SIZE * 4;