diff --git a/.github/workflows/dep_rust.yml b/.github/workflows/dep_rust.yml index 07c9f3615..db765021c 100644 --- a/.github/workflows/dep_rust.yml +++ b/.github/workflows/dep_rust.yml @@ -78,7 +78,7 @@ jobs: if: ${{ inputs.docs_only == 'false' }} timeout-minutes: 60 strategy: - fail-fast: true + fail-fast: false matrix: hypervisor: [hyperv, 'hyperv-ws2025', mshv, mshv3, kvm] # hyperv is windows, mshv and kvm are linux cpu: [amd, intel] diff --git a/Cargo.lock b/Cargo.lock index 449f9b17e..8e05239b5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1090,6 +1090,7 @@ dependencies = [ "kvm-ioctls", "lazy_static", "libc", + "lockfree", "log", "metrics", "metrics-exporter-prometheus", @@ -1564,6 +1565,15 @@ dependencies = [ "scopeguard", ] +[[package]] +name = "lockfree" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74ee94b5ad113c7cb98c5a040f783d0952ee4fe100993881d1673c2cb002dd23" +dependencies = [ + "owned-alloc", +] + [[package]] name = "log" version = "0.4.27" @@ -1921,6 +1931,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "owned-alloc" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30fceb411f9a12ff9222c5f824026be368ff15dc2f13468d850c7d3f502205d6" + [[package]] name = "page_size" version = "0.6.0" diff --git a/docs/signal-handlers-development-notes.md b/docs/signal-handlers-development-notes.md index fca9d31a9..e5c48b57b 100644 --- a/docs/signal-handlers-development-notes.md +++ b/docs/signal-handlers-development-notes.md @@ -1,11 +1,12 @@ # Signal Handling in Hyperlight -Hyperlight registers custom signal handlers to intercept and manage specific signals, primarily `SIGSYS` and `SIGRTMIN`. Here's an overview of the registration process: -- **Preserving Old Handlers**: When registering a new signal handler, Hyperlight first retrieves and stores the existing handler using `OnceCell`. This allows Hyperlight to delegate signals to the original handler if necessary. -- **Custom Handlers**: - - **`SIGSYS` Handler**: Captures disallowed syscalls enforced by seccomp. If the signal originates from a hyperlight thread, Hyperlight logs the syscall details. Otherwise, it delegates the signal to the previously registered handler. - - **`SIGRTMIN` Handler**: Utilized for inter-thread signaling, such as execution cancellation. Similar to SIGSYS, it distinguishes between application and non-hyperlight threads to determine how to handle the signal. -- **Thread Differentiation**: Hyperlight uses thread-local storage (IS_HYPERLIGHT_THREAD) to identify whether the current thread is a hyperlight thread. This distinction ensures that signals are handled appropriately based on the thread's role. +Hyperlight registers custom signal handlers to intercept and manage specific signals, primarily `SIGSYS`, `SIGRTMIN`, and `SIGSEGV`. Here's an overview of the registration process: + +- **Preserving Old Handlers**: When registering a new signal handler, Hyperlight first retrieves and stores the existing handler using either `OnceCell` or a `static AtomicPtr`. This allows Hyperlight to delegate signals to the original handler if necessary. +- **Custom Handlers**: +  - **`SIGSYS` Handler**: Captures disallowed syscalls enforced by seccomp. If the signal originates from a hyperlight thread, Hyperlight logs the syscall details.
Otherwise, it delegates the signal to the previously registered handler. +  - **`SIGRTMIN` Handler**: Utilized for inter-thread signaling, such as execution cancellation. Similar to SIGSYS, it distinguishes between application and non-hyperlight threads to determine how to handle the signal. +  - **`SIGSEGV` Handler**: Handles segmentation faults for dirty page tracking of host memory mapped into a VM. If the signal applies to an address that is mapped to a VM, it is processed by Hyperlight; otherwise, it is passed to the original handler. ## Potential Issues and Considerations @@ -15,3 +16,14 @@ Hyperlight registers custom signal handlers to intercept and manage specific sig - **Invalidation of `old_handler`**: The stored old_handler reference may no longer point to a valid handler, causing undefined behavior when Hyperlight attempts to delegate signals. - **Loss of Custom Handling**: Hyperlight's custom handler might not be invoked as expected, disrupting its ability to enforce syscall restrictions or manage inter-thread signals. +### Debugging and Signal Handling + +By default, when debugging a host application/test/example with GDB or LLDB, the debugger will handle the `SIGSEGV` signal by breaking when it is raised. To prevent this and let Hyperlight handle the signal, enter the following in the debug console: + +#### LLDB + +```process handle SIGSEGV -n true -p true -s false``` + +#### GDB + +```handle SIGSEGV nostop noprint pass``` diff --git a/src/hyperlight_common/src/mem.rs b/src/hyperlight_common/src/mem.rs index 4e5448ae9..4dffc6c4a 100644 --- a/src/hyperlight_common/src/mem.rs +++ b/src/hyperlight_common/src/mem.rs @@ -17,6 +17,8 @@ limitations under the License. pub const PAGE_SHIFT: u64 = 12; pub const PAGE_SIZE: u64 = 1 << 12; pub const PAGE_SIZE_USIZE: usize = 1 << 12; +// The number of pages in 1 "block". A single u64 can be used as a bitmap to keep track of all dirty pages in a block. +pub const PAGES_IN_BLOCK: usize = 64; /// A memory region in the guest address space #[derive(Debug, Clone, Copy)] diff --git a/src/hyperlight_host/Cargo.toml b/src/hyperlight_host/Cargo.toml index bc744c05c..ebf395d43 100644 --- a/src/hyperlight_host/Cargo.toml +++ b/src/hyperlight_host/Cargo.toml @@ -44,6 +44,7 @@ anyhow = "1.0" metrics = "0.24.2" serde_json = "1.0" elfcore = "2.0" +lockfree = "0.5" [target.'cfg(windows)'.dependencies] windows = { version = "0.61", features = [ diff --git a/src/hyperlight_host/benches/benchmarks.rs b/src/hyperlight_host/benches/benchmarks.rs index c9160ff52..a02ad96cc 100644 --- a/src/hyperlight_host/benches/benchmarks.rs +++ b/src/hyperlight_host/benches/benchmarks.rs @@ -79,37 +79,65 @@ fn guest_call_benchmark(c: &mut Criterion) { group.finish(); } -fn guest_call_benchmark_large_param(c: &mut Criterion) { +fn guest_call_benchmark_large_params(c: &mut Criterion) { let mut group = c.benchmark_group("guest_functions_with_large_parameters"); #[cfg(target_os = "windows")] group.sample_size(10); // This benchmark is very slow on Windows, so we reduce the sample size to avoid long test runs.
- // This benchmark includes time to first clone a vector and string, so it is not a "pure' benchmark of the guest call, but it's still useful - group.bench_function("guest_call_with_large_parameters", |b| { - const SIZE: usize = 50 * 1024 * 1024; // 50 MB - let large_vec = vec![0u8; SIZE]; - let large_string = unsafe { String::from_utf8_unchecked(large_vec.clone()) }; // Safety: indeed above vec is valid utf8 + // Helper function to create a benchmark for a specific size + let create_benchmark = |group: &mut criterion::BenchmarkGroup<_>, size_mb: usize| { + let benchmark_name = format!("guest_call_with_2_large_parameters_{}mb each", size_mb); + group.bench_function(&benchmark_name, |b| { + let size = size_mb * 1024 * 1024; // Convert MB to bytes + let large_vec = vec![0u8; size]; + let large_string = unsafe { String::from_utf8_unchecked(large_vec.clone()) }; // Safety: indeed above vec is valid utf8 - let mut config = SandboxConfiguration::default(); - config.set_input_data_size(2 * SIZE + (1024 * 1024)); // 2 * SIZE + 1 MB, to allow 1MB for the rest of the serialized function call - config.set_heap_size(SIZE as u64 * 15); + let mut config = SandboxConfiguration::default(); + config.set_input_data_size(2 * size + (1024 * 1024)); - let sandbox = UninitializedSandbox::new( - GuestBinary::FilePath(simple_guest_as_string().unwrap()), - Some(config), - ) - .unwrap(); - let mut sandbox = sandbox.evolve(Noop::default()).unwrap(); + if size < 50 * 1024 * 1024 { + config.set_heap_size(size as u64 * 16); + } else { + config.set_heap_size(size as u64 * 11); // Set to 1GB for larger sizes + } - b.iter(|| { - sandbox - .call_guest_function_by_name::<()>( - "LargeParameters", - (large_vec.clone(), large_string.clone()), - ) - .unwrap() + let sandbox = UninitializedSandbox::new( + GuestBinary::FilePath(simple_guest_as_string().unwrap()), + Some(config), + ) + .unwrap(); + let mut sandbox = sandbox.evolve(Noop::default()).unwrap(); + + b.iter_custom(|iters| { + let mut total_duration = std::time::Duration::new(0, 0); + + for _ in 0..iters { + // Clone the data (not measured) + let vec_clone = large_vec.clone(); + let string_clone = large_string.clone(); + + // Measure only the guest function call + let start = std::time::Instant::now(); + sandbox + .call_guest_function_by_name::<()>( + "LargeParameters", + (vec_clone, string_clone), + ) + .unwrap(); + total_duration += start.elapsed(); + } + + total_duration + }); }); - }); + }; + + // Create benchmarks for different sizes + create_benchmark(&mut group, 5); // 5MB + create_benchmark(&mut group, 10); // 10MB + create_benchmark(&mut group, 20); // 20MB + create_benchmark(&mut group, 40); // 40MB + create_benchmark(&mut group, 60); // 60MB group.finish(); } @@ -153,9 +181,143 @@ fn sandbox_benchmark(c: &mut Criterion) { group.finish(); } +fn sandbox_heap_size_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("sandbox_heap_sizes"); + + // Helper function to create sandbox with specific heap size + let create_sandbox_with_heap_size = |heap_size_mb: Option| { + let path = simple_guest_as_string().unwrap(); + let config = if let Some(size_mb) = heap_size_mb { + let mut config = SandboxConfiguration::default(); + config.set_heap_size(size_mb * 1024 * 1024); // Convert MB to bytes + Some(config) + } else { + None + }; + + let uninit_sandbox = + UninitializedSandbox::new(GuestBinary::FilePath(path), config).unwrap(); + uninit_sandbox.evolve(Noop::default()).unwrap() + }; + + // Benchmark sandbox creation with default heap size + 
group.bench_function("create_sandbox_default_heap", |b| { + b.iter_with_large_drop(|| create_sandbox_with_heap_size(None)); + }); + + // Benchmark sandbox creation with 50MB heap + group.bench_function("create_sandbox_50mb_heap", |b| { + b.iter_with_large_drop(|| create_sandbox_with_heap_size(Some(50))); + }); + + // Benchmark sandbox creation with 100MB heap + group.bench_function("create_sandbox_100mb_heap", |b| { + b.iter_with_large_drop(|| create_sandbox_with_heap_size(Some(100))); + }); + + // Benchmark sandbox creation with 250MB heap + group.bench_function("create_sandbox_250mb_heap", |b| { + b.iter_with_large_drop(|| create_sandbox_with_heap_size(Some(250))); + }); + + // Benchmark sandbox creation with 500MB heap + group.bench_function("create_sandbox_500mb_heap", |b| { + b.iter_with_large_drop(|| create_sandbox_with_heap_size(Some(500))); + }); + + // Benchmark sandbox creation with 995MB heap (close to the limit of 1GB for a Sandbox ) + group.bench_function("create_sandbox_995mb_heap", |b| { + b.iter_with_large_drop(|| create_sandbox_with_heap_size(Some(995))); + }); + + group.finish(); +} + +fn guest_call_heap_size_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("guest_call_heap_sizes"); + + // Helper function to create sandbox with specific heap size + let create_sandbox_with_heap_size = |heap_size_mb: Option| { + let path = simple_guest_as_string().unwrap(); + let config = if let Some(size_mb) = heap_size_mb { + let mut config = SandboxConfiguration::default(); + config.set_heap_size(size_mb * 1024 * 1024); // Convert MB to bytes + Some(config) + } else { + None + }; + + let uninit_sandbox = + UninitializedSandbox::new(GuestBinary::FilePath(path), config).unwrap(); + uninit_sandbox.evolve(Noop::default()).unwrap() + }; + + // Benchmark guest function call with default heap size + group.bench_function("guest_call_default_heap", |b| { + let mut sandbox = create_sandbox_with_heap_size(None); + b.iter(|| { + sandbox + .call_guest_function_by_name::("Echo", "hello\n".to_string()) + .unwrap() + }); + }); + + // Benchmark guest function call with 50MB heap + group.bench_function("guest_call_50mb_heap", |b| { + let mut sandbox = create_sandbox_with_heap_size(Some(50)); + b.iter(|| { + sandbox + .call_guest_function_by_name::("Echo", "hello\n".to_string()) + .unwrap() + }); + }); + + // Benchmark guest function call with 100MB heap + group.bench_function("guest_call_100mb_heap", |b| { + let mut sandbox = create_sandbox_with_heap_size(Some(100)); + b.iter(|| { + sandbox + .call_guest_function_by_name::("Echo", "hello\n".to_string()) + .unwrap() + }); + }); + + // Benchmark guest function call with 250MB heap + group.bench_function("guest_call_250mb_heap", |b| { + let mut sandbox = create_sandbox_with_heap_size(Some(250)); + b.iter(|| { + sandbox + .call_guest_function_by_name::("Echo", "hello\n".to_string()) + .unwrap() + }); + }); + + // Benchmark guest function call with 500MB heap + group.bench_function("guest_call_500mb_heap", |b| { + let mut sandbox = create_sandbox_with_heap_size(Some(500)); + b.iter(|| { + sandbox + .call_guest_function_by_name::("Echo", "hello\n".to_string()) + .unwrap() + }); + }); + + // Benchmark guest function call with 995MB heap + group.bench_function("guest_call_995mb_heap", |b| { + let mut sandbox = create_sandbox_with_heap_size(Some(995)); + b.iter(|| { + sandbox + .call_guest_function_by_name::("Echo", "hello\n".to_string()) + .unwrap() + }); + }); + + group.finish(); +} + criterion_group! 
{ name = benches; config = Criterion::default(); - targets = guest_call_benchmark, sandbox_benchmark, guest_call_benchmark_large_param + targets = guest_call_benchmark, sandbox_benchmark, sandbox_heap_size_benchmark, guest_call_benchmark_large_params, guest_call_heap_size_benchmark } criterion_main!(benches); diff --git a/src/hyperlight_host/src/hypervisor/hyperv_linux.rs b/src/hyperlight_host/src/hypervisor/hyperv_linux.rs index 90e91f496..be34d71d6 100644 --- a/src/hyperlight_host/src/hypervisor/hyperv_linux.rs +++ b/src/hyperlight_host/src/hypervisor/hyperv_linux.rs @@ -29,6 +29,8 @@ use std::sync::Arc; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use log::{LevelFilter, error}; +#[cfg(mshv3)] +use mshv_bindings::MSHV_GPAP_ACCESS_OP_CLEAR; #[cfg(mshv2)] use mshv_bindings::hv_message; use mshv_bindings::{ @@ -76,6 +78,9 @@ use crate::sandbox::SandboxConfiguration; use crate::sandbox::uninitialized::SandboxRuntimeConfig; use crate::{Result, log_then_return, new_error}; +#[cfg(mshv2)] +const CLEAR_DIRTY_BIT_FLAG: u64 = 0b100; + #[cfg(gdb)] mod debug { use std::sync::{Arc, Mutex}; @@ -302,6 +307,7 @@ pub(crate) struct HypervLinuxDriver { vcpu_fd: VcpuFd, entrypoint: u64, mem_regions: Vec, + n_initial_regions: usize, orig_rsp: GuestPtr, interrupt_handle: Arc, @@ -351,6 +357,7 @@ impl HypervLinuxDriver { vm_fd.initialize()?; vm_fd }; + vm_fd.enable_dirty_page_tracking()?; let mut vcpu_fd = vm_fd.create_vcpu(0)?; @@ -391,13 +398,41 @@ impl HypervLinuxDriver { (None, None) }; + let mut base_pfn = u64::MAX; + let mut total_size: usize = 0; + mem_regions.iter().try_for_each(|region| { - let mshv_region = region.to_owned().into(); + let mshv_region: mshv_user_mem_region = region.to_owned().into(); + if base_pfn == u64::MAX { + base_pfn = mshv_region.guest_pfn; + } + total_size += mshv_region.size as usize; vm_fd.map_user_memory(mshv_region) })?; Self::setup_initial_sregs(&mut vcpu_fd, pml4_ptr.absolute()?)?; + // get/clear the dirty page bitmap, mshv sets all the bit dirty at initialization + // if we dont clear them then we end up taking a complete snapsot of memory page by page which gets + // progressively slower as the sandbox size increases + // the downside of doing this here is that the call to get_dirty_log will takes longer as the number of pages increase + // but for larger sandboxes its easily cheaper than copying all the pages + + // Clear dirty bits for each memory region separately since they may not be contiguous + for region in &mem_regions { + let mshv_region: mshv_user_mem_region = region.to_owned().into(); + let region_size = region.guest_region.len(); + + #[cfg(mshv2)] + vm_fd.get_dirty_log(mshv_region.guest_pfn, region_size, CLEAR_DIRTY_BIT_FLAG)?; + #[cfg(mshv3)] + vm_fd.get_dirty_log( + mshv_region.guest_pfn, + region_size, + MSHV_GPAP_ACCESS_OP_CLEAR as u8, + )?; + } + let interrupt_handle = Arc::new(LinuxInterruptHandle { running: AtomicU64::new(0), cancel_requested: AtomicBool::new(false), @@ -428,6 +463,7 @@ impl HypervLinuxDriver { page_size: 0, vm_fd, vcpu_fd, + n_initial_regions: mem_regions.len(), mem_regions, entrypoint: entrypoint_ptr.absolute()?, orig_rsp: rsp_ptr, @@ -863,6 +899,50 @@ impl Hypervisor for HypervLinuxDriver { self.interrupt_handle.clone() } + // TODO: Implement getting additional host-mapped dirty pages. + fn get_and_clear_dirty_pages(&mut self) -> Result<(Vec, Option>>)> { + let first_mshv_region: mshv_user_mem_region = self + .mem_regions + .first() + .ok_or(new_error!( + "tried to get dirty page bitmap of 0-sized region" + ))? 
+ .to_owned() + .into(); + + let n_contiguous = self + .mem_regions + .windows(2) + .take_while(|window| window[0].guest_region.end == window[1].guest_region.start) + .count() + + 1; // +1 because windows(2) gives us n-1 pairs for n regions + + if n_contiguous != self.n_initial_regions { + return Err(new_error!( + "get_and_clear_dirty_pages: not all regions are contiguous, expected {} but got {}", + self.n_initial_regions, + n_contiguous + )); + } + + let sandbox_total_size = self + .mem_regions + .iter() + .take(n_contiguous) + .map(|r| r.guest_region.len()) + .sum(); + + let sandbox_dirty_pages = self.vm_fd.get_dirty_log( + first_mshv_region.guest_pfn, + sandbox_total_size, + #[cfg(mshv2)] + CLEAR_DIRTY_BIT_FLAG, + #[cfg(mshv3)] + (MSHV_GPAP_ACCESS_OP_CLEAR as u8), + )?; + Ok((sandbox_dirty_pages, None)) + } + #[cfg(crashdump)] fn crashdump_context(&self) -> Result> { if self.rt_cfg.guest_core_dump { @@ -1113,7 +1193,8 @@ mod tests { return; } const MEM_SIZE: usize = 0x3000; - let gm = shared_mem_with_code(CODE.as_slice(), MEM_SIZE, 0).unwrap(); + let mut gm = shared_mem_with_code(CODE.as_slice(), MEM_SIZE, 0).unwrap(); + gm.stop_tracking_dirty_pages().unwrap(); let rsp_ptr = GuestPtr::try_from(0).unwrap(); let pml4_ptr = GuestPtr::try_from(0).unwrap(); let entrypoint_ptr = GuestPtr::try_from(0).unwrap(); diff --git a/src/hyperlight_host/src/hypervisor/hyperv_windows.rs b/src/hyperlight_host/src/hypervisor/hyperv_windows.rs index cd0398854..1af477d57 100644 --- a/src/hyperlight_host/src/hypervisor/hyperv_windows.rs +++ b/src/hyperlight_host/src/hypervisor/hyperv_windows.rs @@ -55,6 +55,7 @@ use super::{ use super::{HyperlightExit, Hypervisor, InterruptHandle, VirtualCPU}; use crate::hypervisor::fpu::FP_CONTROL_WORD_DEFAULT; use crate::hypervisor::wrappers::WHvGeneralRegisters; +use crate::mem::bitmap::new_page_bitmap; use crate::mem::memory_region::{MemoryRegion, MemoryRegionFlags}; use crate::mem::ptr::{GuestPtr, RawPtr}; #[cfg(crashdump)] @@ -606,13 +607,21 @@ impl Hypervisor for HypervWindowsDriver { Ok(()) } + fn get_and_clear_dirty_pages(&mut self) -> Result<(Vec, Option>>)> { + // For now we just mark all pages dirty which is the equivalent of taking a full snapshot + let total_size = self.mem_regions.iter().map(|r| r.guest_region.len()).sum(); + Ok((new_page_bitmap(total_size, true)?, None)) + } + #[instrument(err(Debug), skip_all, parent = Span::current(), level = "Trace")] unsafe fn map_region(&mut self, _rgn: &MemoryRegion) -> Result<()> { - log_then_return!("Mapping host memory into the guest not yet supported on this platform"); + // TODO: when adding support, also update `get_and_clear_dirty_pages`, see kvm/mshv for details + log_then_return!("Mapping host memory into the guest not yet supported on this platform."); } #[instrument(err(Debug), skip_all, parent = Span::current(), level = "Trace")] unsafe fn unmap_regions(&mut self, n: u64) -> Result<()> { + // TODO: when adding support, also update `get_and_clear_dirty_pages`, see kvm/mshv for details if n > 0 { log_then_return!( "Mapping host memory into the guest not yet supported on this platform" diff --git a/src/hyperlight_host/src/hypervisor/kvm.rs b/src/hyperlight_host/src/hypervisor/kvm.rs index 0802ecb6b..39ef8dbe4 100644 --- a/src/hyperlight_host/src/hypervisor/kvm.rs +++ b/src/hyperlight_host/src/hypervisor/kvm.rs @@ -21,6 +21,7 @@ use std::sync::Arc; use std::sync::Mutex; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use hyperlight_common::mem::{PAGE_SIZE_USIZE, PAGES_IN_BLOCK}; use 
kvm_bindings::{kvm_fpu, kvm_regs, kvm_userspace_memory_region}; use kvm_ioctls::Cap::UserMemory; use kvm_ioctls::{Kvm, VcpuExit, VcpuFd, VmFd}; @@ -43,7 +44,8 @@ use super::{ use super::{HyperlightExit, Hypervisor, InterruptHandle, LinuxInterruptHandle, VirtualCPU}; #[cfg(gdb)] use crate::HyperlightError; -use crate::mem::memory_region::{MemoryRegion, MemoryRegionFlags}; +use crate::mem::bitmap::{bit_index_iterator, new_page_bitmap}; +use crate::mem::memory_region::{MemoryRegion, MemoryRegionFlags, MemoryRegionType}; use crate::mem::ptr::{GuestPtr, RawPtr}; use crate::sandbox::SandboxConfiguration; #[cfg(crashdump)] @@ -290,6 +292,7 @@ pub(crate) struct KVMDriver { entrypoint: u64, orig_rsp: GuestPtr, mem_regions: Vec, + n_initial_regions: usize, interrupt_handle: Arc, #[cfg(gdb)] @@ -372,6 +375,7 @@ impl KVMDriver { vcpu_fd, entrypoint, orig_rsp: rsp_gp, + n_initial_regions: mem_regions.len(), mem_regions, interrupt_handle: interrupt_handle.clone(), #[cfg(gdb)] @@ -750,6 +754,61 @@ impl Hypervisor for KVMDriver { self.interrupt_handle.clone() } + // TODO: Implement getting additional host-mapped dirty pages. + fn get_and_clear_dirty_pages(&mut self) -> Result<(Vec, Option>>)> { + let n_contiguous = self + .mem_regions + .windows(2) + .take_while(|window| window[0].guest_region.end == window[1].guest_region.start) + .count() + + 1; // +1 because windows(2) gives us n-1 pairs for n regions + + if n_contiguous != self.n_initial_regions { + return Err(new_error!( + "get_and_clear_dirty_pages: not all regions are contiguous, expected {} but got {}", + self.n_initial_regions, + n_contiguous + )); + } + let mut page_indices = vec![]; + let mut current_page = 0; + + // Iterate over all memory regions and get the dirty pages for each region ignoring guard pages which cannot be dirty + for (i, mem_region) in self.mem_regions.iter().take(n_contiguous).enumerate() { + let num_pages = mem_region.guest_region.len() / PAGE_SIZE_USIZE; + let bitmap = match mem_region.flags { + MemoryRegionFlags::READ => { + // read-only page. It can never be dirty so return zero dirty pages. + new_page_bitmap(mem_region.guest_region.len(), false)? + } + _ => { + if mem_region.region_type == MemoryRegionType::GuardPage { + // Trying to get dirty pages for a guard page region results in a VMMSysError(2) + new_page_bitmap(mem_region.guest_region.len(), false)? + } else { + // Get the dirty bitmap for the memory region + self.vm_fd + .get_dirty_log(i as u32, mem_region.guest_region.len())? + } + } + }; + for page_idx in bit_index_iterator(&bitmap) { + page_indices.push(current_page + page_idx); + } + current_page += num_pages; + } + + // convert vec of page indices to vec of blocks + let mut sandbox_dirty_pages = new_page_bitmap(current_page * PAGE_SIZE_USIZE, false)?; + for page_idx in page_indices { + let block_idx = page_idx / PAGES_IN_BLOCK; + let bit_idx = page_idx % PAGES_IN_BLOCK; + sandbox_dirty_pages[block_idx] |= 1 << bit_idx; + } + + Ok((sandbox_dirty_pages, None)) + } + #[cfg(crashdump)] fn crashdump_context(&self) -> Result> { if self.rt_cfg.guest_core_dump { diff --git a/src/hyperlight_host/src/hypervisor/mod.rs b/src/hyperlight_host/src/hypervisor/mod.rs index ecf6acbc5..b8f348878 100644 --- a/src/hyperlight_host/src/hypervisor/mod.rs +++ b/src/hyperlight_host/src/hypervisor/mod.rs @@ -196,6 +196,14 @@ pub(crate) trait Hypervisor: Debug + Sync + Send { None } + /// Get dirty pages as a bitmap (Vec). + /// Each bit in a u64 represents a page. + /// This also clears the bitflags, marking the pages as non-dirty. 
+ /// The Vec in the tuple is the bitmap of the first contiguous memory regions, which represents the sandbox itself. + /// The Vec> in the tuple are the host-mapped regions, which aren't necessarily contiguous, and not yet implemented + #[allow(clippy::type_complexity)] + fn get_and_clear_dirty_pages(&mut self) -> Result<(Vec, Option>>)>; + /// Get InterruptHandle to underlying VM fn interrupt_handle(&self) -> Arc; @@ -507,6 +515,7 @@ pub(crate) mod tests { #[cfg(gdb)] use crate::hypervisor::DbgMemAccessHandlerCaller; use crate::mem::ptr::RawPtr; + use crate::mem::shared_mem::SharedMemory; use crate::sandbox::uninitialized::GuestBinary; #[cfg(any(crashdump, gdb))] use crate::sandbox::uninitialized::SandboxRuntimeConfig; @@ -558,9 +567,33 @@ pub(crate) mod tests { let sandbox = UninitializedSandbox::new(GuestBinary::FilePath(filename.clone()), Some(config))?; let (_hshm, mut gshm) = sandbox.mgr.build(); + + let regions = gshm.layout.get_memory_regions(&gshm.shared_mem)?; + + // Set up shared memory to calculate rsp_ptr + #[cfg(feature = "init-paging")] + let rsp_ptr = { + use crate::mem::ptr::GuestPtr; + let rsp_u64 = gshm.set_up_page_tables(®ions)?; + let rsp_raw = RawPtr::from(rsp_u64); + GuestPtr::try_from(rsp_raw) + }?; + #[cfg(not(feature = "init-paging"))] + let rsp_ptr = { + use crate::mem::ptr::GuestPtr; + use crate::mem::ptr_offset::Offset; + GuestPtr::try_from(Offset::from(0)) + }?; + + // We need to stop tracking dirty pages from the host side before we start the guest + gshm.shared_mem + .with_exclusivity(|e| e.stop_tracking_dirty_pages())??; + let mut vm = set_up_hypervisor_partition( &mut gshm, &config, + rsp_ptr, + regions, #[cfg(any(crashdump, gdb))] &rt_cfg, )?; diff --git a/src/hyperlight_host/src/mem/bitmap.rs b/src/hyperlight_host/src/mem/bitmap.rs new file mode 100644 index 000000000..37f33fb07 --- /dev/null +++ b/src/hyperlight_host/src/mem/bitmap.rs @@ -0,0 +1,296 @@ +/* +Copyright 2025 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +use std::cmp::Ordering; + +use hyperlight_common::mem::{PAGE_SIZE_USIZE, PAGES_IN_BLOCK}; +use termcolor::{Color, ColorChoice, ColorSpec, StandardStream, WriteColor}; + +use super::layout::SandboxMemoryLayout; +use crate::{Result, log_then_return}; + +// Contains various helper functions for dealing with bitmaps. + +/// Returns a new bitmap of pages. If `init_dirty` is true, all pages are marked as dirty, otherwise all pages are clean. +/// Will return an error if given size is 0. 
+pub fn new_page_bitmap(size_in_bytes: usize, init_dirty: bool) -> Result> { + if size_in_bytes == 0 { + log_then_return!("Tried to create a bitmap with size 0."); + } + let num_pages = size_in_bytes.div_ceil(PAGE_SIZE_USIZE); + let num_blocks = num_pages.div_ceil(PAGES_IN_BLOCK); + match init_dirty { + false => Ok(vec![0; num_blocks]), + true => { + let mut bitmap = vec![!0u64; num_blocks]; // all pages are dirty + let num_unused_bits = num_blocks * PAGES_IN_BLOCK - num_pages; + // set the unused bits to 0, could cause problems otherwise + #[allow(clippy::unwrap_used)] + let last_block = bitmap.last_mut().unwrap(); // unwrap is safe since size_in_bytes>0 + *last_block >>= num_unused_bits; + Ok(bitmap) + } + } +} + +/// Returns the union (bitwise OR) of two bitmaps. The resulting bitmap will have the same length +/// as the longer of the two input bitmaps. +pub(crate) fn bitmap_union(bitmap: &[u64], other_bitmap: &[u64]) -> Vec { + let min_len = bitmap.len().min(other_bitmap.len()); + let max_len = bitmap.len().max(other_bitmap.len()); + + let mut result = vec![0; max_len]; + + for i in 0..min_len { + result[i] = bitmap[i] | other_bitmap[i]; + } + + match bitmap.len().cmp(&other_bitmap.len()) { + Ordering::Greater => { + result[min_len..].copy_from_slice(&bitmap[min_len..]); + } + Ordering::Less => { + result[min_len..].copy_from_slice(&other_bitmap[min_len..]); + } + Ordering::Equal => {} + } + + result +} + +// Used as a helper struct to implement an iterator on. +struct SetBitIndices<'a> { + bitmap: &'a [u64], + block_index: usize, // one block is 1 u64, which is 64 pages + current: u64, // the current block we are iterating over, or 0 if first iteration +} + +/// Iterates over the zero-based indices of the set bits in the given bitmap. +pub(crate) fn bit_index_iterator(bitmap: &[u64]) -> impl Iterator + '_ { + SetBitIndices { + bitmap, + block_index: 0, + current: 0, + } +} + +impl Iterator for SetBitIndices<'_> { + type Item = usize; + + fn next(&mut self) -> Option { + while self.current == 0 { + // will always enter this on first iteration because current is initialized to 0 + if self.block_index >= self.bitmap.len() { + // no more blocks to iterate over + return None; + } + self.current = self.bitmap[self.block_index]; + self.block_index += 1; + } + let trailing_zeros = self.current.trailing_zeros(); + self.current &= self.current - 1; // Clear the least significant set bit + Some((self.block_index - 1) * 64 + trailing_zeros as usize) // block_index guaranteed to be > 0 at this point + } +} + +// Unused but useful for debugging +// Prints the dirty bitmap in a human-readable format, coloring each page according to its region +// NOTE: Might need to be updated if the memory layout changes +#[allow(dead_code)] +pub(crate) fn print_dirty_bitmap(bitmap: &[u64], layout: &SandboxMemoryLayout) { + let mut stdout = StandardStream::stdout(ColorChoice::Auto); + + // Helper function to determine which memory region a page belongs to + fn get_region_info(page_index: usize, layout: &SandboxMemoryLayout) -> (&'static str, Color) { + let page_offset = page_index * PAGE_SIZE_USIZE; + + // Check each memory region in order, using available methods and approximations + if page_offset >= layout.init_data_offset { + ("INIT_DATA", Color::Ansi256(129)) // Purple + } else if page_offset >= layout.get_top_of_user_stack_offset() { + ("STACK", Color::Ansi256(208)) // Orange + } else if page_offset >= layout.get_guard_page_offset() { + ("GUARD_PAGE", Color::White) + } else if page_offset >= 
layout.guest_heap_buffer_offset { + ("HEAP", Color::Red) + } else if page_offset >= layout.output_data_buffer_offset { + ("OUTPUT_DATA", Color::Green) + } else if page_offset >= layout.input_data_buffer_offset { + ("INPUT_DATA", Color::Blue) + } else if page_offset >= layout.host_function_definitions_buffer_offset { + ("HOST_FUNC_DEF", Color::Cyan) + } else if page_offset >= layout.peb_address { + ("PEB", Color::Magenta) + } else if page_offset >= layout.get_guest_code_offset() { + ("CODE", Color::Yellow) + } else { + // Everything up to and including guest code should be PAGE_TABLES + ("PAGE_TABLES", Color::Ansi256(14)) // Bright cyan + } + } + + let mut num_dirty_pages = 0; + for &block in bitmap.iter() { + num_dirty_pages += block.count_ones() as usize; + } + + for (i, &block) in bitmap.iter().enumerate() { + if block != 0 { + print!("Block {:3}: ", i); + + // Print each bit in the block with appropriate color + for bit_pos in 0..64 { + let bit_mask = 1u64 << bit_pos; + let page_index = i * 64 + bit_pos; + let (_region_name, color) = get_region_info(page_index, layout); + + let mut color_spec = ColorSpec::new(); + color_spec.set_fg(Some(color)); + + if block & bit_mask != 0 { + // Make 1s bold with dark background to stand out from 0s + color_spec.set_bold(true).set_bg(Some(Color::Black)); + let _ = stdout.set_color(&color_spec); + print!("1"); + } else { + // 0s are colored but not bold, no background + let _ = stdout.set_color(&color_spec); + print!("0"); + } + let _ = stdout.reset(); + } + + // Print a legend for this block showing which regions are represented + let mut regions_in_block = std::collections::HashMap::new(); + for bit_pos in 0..64 { + let bit_mask = 1u64 << bit_pos; + if block & bit_mask != 0 { + let page_index = i * 64 + bit_pos; + let (region_name, color) = get_region_info(page_index, layout); + regions_in_block.insert(region_name, color); + } + } + + if !regions_in_block.is_empty() { + print!(" ["); + let mut sorted_regions: Vec<_> = regions_in_block.iter().collect(); + sorted_regions.sort_by_key(|(name, _)| *name); + for (i, (region_name, color)) in sorted_regions.iter().enumerate() { + if i > 0 { + print!(", "); + } + let mut color_spec = ColorSpec::new(); + color_spec.set_fg(Some(**color)).set_bold(true); + let _ = stdout.set_color(&color_spec); + print!("{}", region_name); + let _ = stdout.reset(); + } + print!("]"); + } + println!(); + } + } + // Print the total number of dirty pages + println!("Total dirty pages: {}", num_dirty_pages); +} + +#[cfg(test)] +mod tests { + use hyperlight_common::mem::PAGE_SIZE_USIZE; + + use crate::Result; + use crate::mem::bitmap::{bit_index_iterator, bitmap_union, new_page_bitmap}; + + #[test] + fn new_page_bitmap_test() -> Result<()> { + let bitmap = new_page_bitmap(1, false)?; + assert_eq!(bitmap.len(), 1); + assert_eq!(bitmap[0], 0); + + let bitmap = new_page_bitmap(1, true)?; + assert_eq!(bitmap.len(), 1); + assert_eq!(bitmap[0], 1); + + let bitmap = new_page_bitmap(32 * PAGE_SIZE_USIZE, false)?; + assert_eq!(bitmap.len(), 1); + assert_eq!(bitmap[0], 0); + + let bitmap = new_page_bitmap(32 * PAGE_SIZE_USIZE, true)?; + assert_eq!(bitmap.len(), 1); + assert_eq!(bitmap[0], 0x0000_0000_FFFF_FFFF); + Ok(()) + } + + #[test] + fn page_iterator() { + let data = vec![0b1000010100, 0b01, 0b100000000000000011]; + let mut iter = bit_index_iterator(&data); + assert_eq!(iter.next(), Some(2)); + assert_eq!(iter.next(), Some(4)); + assert_eq!(iter.next(), Some(9)); + assert_eq!(iter.next(), Some(64)); + assert_eq!(iter.next(), Some(128)); + 
assert_eq!(iter.next(), Some(129)); + assert_eq!(iter.next(), Some(145)); + assert_eq!(iter.next(), None); + + let data_2 = vec![0, 0, 0]; + let mut iter_2 = bit_index_iterator(&data_2); + assert_eq!(iter_2.next(), None); + + let data_3 = vec![0, 0, 0b1, 1 << 63]; + let mut iter_3 = bit_index_iterator(&data_3); + assert_eq!(iter_3.next(), Some(128)); + assert_eq!(iter_3.next(), Some(255)); + assert_eq!(iter_3.next(), None); + + let data_4 = vec![]; + let mut iter_4 = bit_index_iterator(&data_4); + assert_eq!(iter_4.next(), None); + } + + #[test] + fn union() -> Result<()> { + let a = 0b1000010100; + let b = 0b01; + let c = 0b100000000000000011; + let d = 0b101010100000011000000011; + let e = 0b000000000000001000000000000000000000; + let f = 0b100000000000000001010000000001010100000000000; + let bitmap = vec![a, b, c]; + let other_bitmap = vec![d, e, f]; + let union = bitmap_union(&bitmap, &other_bitmap); + assert_eq!(union, vec![a | d, b | e, c | f]); + + // different length + let union = bitmap_union(&[a], &[d, e, f]); + assert_eq!(union, vec![a | d, e, f]); + + let union = bitmap_union(&[a, b, c], &[d]); + assert_eq!(union, vec![a | d, b, c]); + + let union = bitmap_union(&[], &[d, e]); + assert_eq!(union, vec![d, e]); + + let union = bitmap_union(&[a, b, c], &[]); + assert_eq!(union, vec![a, b, c]); + + let union = bitmap_union(&[], &[]); + let empty: Vec = vec![]; + assert_eq!(union, empty); + + Ok(()) + } +} diff --git a/src/hyperlight_host/src/mem/dirty_page_tracking.rs b/src/hyperlight_host/src/mem/dirty_page_tracking.rs new file mode 100644 index 000000000..6d95ad232 --- /dev/null +++ b/src/hyperlight_host/src/mem/dirty_page_tracking.rs @@ -0,0 +1,59 @@ +/* +Copyright 2025 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +use std::sync::Arc; + +use tracing::{Span, instrument}; + +#[cfg(target_os = "linux")] +pub use super::linux_dirty_page_tracker::LinuxDirtyPageTracker as PlatformDirtyPageTracker; +use super::shared_mem::HostMapping; +#[cfg(target_os = "windows")] +pub use super::windows_dirty_page_tracker::WindowsDirtyPageTracker as PlatformDirtyPageTracker; +use crate::Result; + +/// Trait defining the interface for dirty page tracking implementations +pub trait DirtyPageTracking { + #[cfg(test)] + fn get_dirty_pages(&self) -> Result>; + fn uninstall(self) -> Result>; +} + +/// Cross-platform dirty page tracker that delegates to platform-specific implementations +#[derive(Debug)] +pub struct DirtyPageTracker { + inner: PlatformDirtyPageTracker, +} + +impl DirtyPageTracker { + /// Create a new dirty page tracker for the given shared memory + #[instrument(skip_all, parent = Span::current(), level = "Trace")] + pub fn new(mapping: Arc) -> Result { + let inner = PlatformDirtyPageTracker::new(mapping)?; + Ok(Self { inner }) + } +} + +impl DirtyPageTracking for DirtyPageTracker { + fn uninstall(self) -> Result> { + self.inner.stop_tracking_and_get_dirty_pages() + } + + #[cfg(test)] + fn get_dirty_pages(&self) -> Result> { + self.inner.get_dirty_pages() + } +} diff --git a/src/hyperlight_host/src/mem/layout.rs b/src/hyperlight_host/src/mem/layout.rs index 04edc9bcc..b5be468fa 100644 --- a/src/hyperlight_host/src/mem/layout.rs +++ b/src/hyperlight_host/src/mem/layout.rs @@ -111,10 +111,10 @@ pub(crate) struct SandboxMemoryLayout { pub(crate) host_function_definitions_buffer_offset: usize, pub(super) input_data_buffer_offset: usize, pub(super) output_data_buffer_offset: usize, - guest_heap_buffer_offset: usize, + pub(super) guest_heap_buffer_offset: usize, guard_page_offset: usize, guest_user_stack_buffer_offset: usize, // the lowest address of the user stack - init_data_offset: usize, + pub(super) init_data_offset: usize, // other pub(crate) peb_address: usize, @@ -246,7 +246,7 @@ impl SandboxMemoryLayout { pub(crate) const BASE_ADDRESS: usize = 0x0; // the offset into a sandbox's input/output buffer where the stack starts - const STACK_POINTER_SIZE_BYTES: u64 = 8; + pub(crate) const STACK_POINTER_SIZE_BYTES: u64 = 8; /// Create a new `SandboxMemoryLayout` with the given /// `SandboxConfiguration`, code size and stack/heap size. @@ -397,7 +397,7 @@ impl SandboxMemoryLayout { /// Get the offset in guest memory to the output data pointer. #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_output_data_pointer_offset(&self) -> usize { + pub(super) fn get_output_data_pointer_offset(&self) -> usize { // This field is immediately after the output data size field, // which is a `u64`. self.get_output_data_size_offset() + size_of::() @@ -429,7 +429,7 @@ impl SandboxMemoryLayout { /// Get the offset in guest memory to the input data pointer. #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_input_data_pointer_offset(&self) -> usize { + pub(super) fn get_input_data_pointer_offset(&self) -> usize { // The input data pointer is immediately after the input // data size field in the input data `GuestMemoryRegion` struct which is a `u64`. 
self.get_input_data_size_offset() + size_of::() diff --git a/src/hyperlight_host/src/mem/linux_dirty_page_tracker.rs b/src/hyperlight_host/src/mem/linux_dirty_page_tracker.rs new file mode 100644 index 000000000..531449777 --- /dev/null +++ b/src/hyperlight_host/src/mem/linux_dirty_page_tracker.rs @@ -0,0 +1,841 @@ +/* +Copyright 2025 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +use std::ptr; +use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering}; +use std::sync::{Arc, OnceLock}; + +use hyperlight_common::mem::PAGE_SIZE_USIZE; +use libc::{PROT_READ, PROT_WRITE, mprotect}; +use lockfree::map::Map; +use log::error; + +use crate::mem::shared_mem::HostMapping; +use crate::{Result, new_error}; + +// Tracker metadata stored in global lock-free storage +struct TrackerData { + pid: u32, + base_addr: usize, + size: usize, + num_pages: usize, + dirty_pages: Vec, +} + +// Global lock-free collection to store tracker data for signal handler to access + +static TRACKERS: OnceLock> = OnceLock::new(); + +// Helper function to get or initialize the global trackers map +// lockfree::Map is truly lock-free and safe for signal handlers +fn get_trackers() -> &'static Map { + TRACKERS.get_or_init(Map::new) +} + +/// Global tracker ID counter +static NEXT_TRACKER_ID: AtomicUsize = AtomicUsize::new(1); + +/// Original SIGSEGV handler to chain to (stored atomically for async signal safety) +static ORIGINAL_SIGSEGV_HANDLER: AtomicPtr = AtomicPtr::new(ptr::null_mut()); + +/// Whether our SIGSEGV handler is installed +static HANDLER_INSTALLED: AtomicBool = AtomicBool::new(false); + +/// Dirty page tracker for Linux +/// This tracks which pages have been written for a memory region once new has been called +/// It marks pages as RO and then uses SIGSEGV to detect writes to pages, then updates the page to RW and notes the page index as dirty by writing details to global lock-free storage +/// +/// A user calls get_dirty_pages to get a list of dirty pages to get details of the pages that were written to since the tracker was created +/// +/// Once a user has called get_dirty_pages, this tracker is destroyed and will not track changes any longer +#[derive(Debug)] +pub struct LinuxDirtyPageTracker { + /// Unique ID for this tracker + id: usize, + /// Base address of the memory region being tracked + base_addr: usize, + /// Size of the memory region in bytes + size: usize, + /// Keep a reference to the HostMapping to ensure memory lifetime + _mapping: Arc, +} + +// DirtyPageTracker should be Send because: +// 1. The Arc ensures the memory stays valid +// 2. The tracker handles synchronization properly +// 3. 
This is needed for threaded sandbox initialization +unsafe impl Send for LinuxDirtyPageTracker {} + +impl LinuxDirtyPageTracker { + /// Create a new dirty page tracker for the given shared memory + pub(super) fn new(mapping: Arc) -> Result { + if mapping.size == 0 { + return Err(new_error!("Cannot track empty memory region")); + } + + if mapping.ptr as usize % PAGE_SIZE_USIZE != 0 { + return Err(new_error!("Base address must be page-aligned")); + } + let base_addr = mapping.ptr as usize + PAGE_SIZE_USIZE; // Start after the first page to avoid tracking guard page + let size = mapping.size - 2 * PAGE_SIZE_USIZE; // Exclude guard pages at start and end + + // Get the current process ID + let current_pid = std::process::id(); + + // Check that there is not already a tracker that includes this address range + // within the same process (virtual addresses are only unique per process) + for guard in get_trackers().iter() { + let tracker_data = guard.val(); + + // Only check for overlaps within the same process + if tracker_data.pid == current_pid { + let existing_start = tracker_data.base_addr; + let existing_end = tracker_data.base_addr + tracker_data.size; + let new_start = base_addr; + let new_end = base_addr.wrapping_add(size); + + // Check for overlap: two ranges [a,b) and [c,d) overlap if max(a,c) < min(b,d) + // Equivalently: they DON'T overlap if b <= c || d <= a + // So they DO overlap if !(b <= c || d <= a) which is (b > c && d > a) + if new_end > existing_start && existing_end > new_start { + return Err(new_error!( + "Address range [{:#x}, {:#x}) overlaps with existing tracker [{:#x}, {:#x}) in process {}", + new_start, + new_end, + existing_start, + existing_end, + current_pid + )); + } + } + } + + let num_pages = size.div_ceil(PAGE_SIZE_USIZE); + let id = NEXT_TRACKER_ID.fetch_add(1, Ordering::Relaxed); + + // Create atomic array for dirty page tracking + let dirty_pages: Vec = (0..num_pages).map(|_| AtomicBool::new(false)).collect(); + + // Create tracker data + let tracker_data = TrackerData { + pid: current_pid, + base_addr, + size, + num_pages, + dirty_pages, + }; + + // Install global SIGSEGV handler if not already installed + Self::ensure_sigsegv_handler_installed()?; + + // Write protect the memory region to make it read-only so we get SIGSEGV on writes + let result = unsafe { mprotect(base_addr as *mut libc::c_void, size, PROT_READ) }; + + if result != 0 { + return Err(new_error!( + "Failed to write-protect memory for dirty tracking: {}", + std::io::Error::last_os_error() + )); + } + + get_trackers().insert(id, tracker_data); + + Ok(Self { + id, + base_addr, + size, + _mapping: mapping, + }) + } + + /// Get all dirty page indices for this tracker. + /// NOTE: This is not a bitmap, but a vector of indices where each index corresponds to a page that has been written to. + #[cfg(test)] + pub(super) fn get_dirty_pages(&self) -> Result> { + let res: Vec = if let Some(tracker_data) = get_trackers().get(&self.id) { + let mut dirty_pages = Vec::new(); + let tracker_data = tracker_data.val(); + for (idx, dirty) in tracker_data.dirty_pages.iter().enumerate() { + if dirty.load(Ordering::Acquire) { + dirty_pages.push(idx); + } + } + dirty_pages + } else { + return Err(new_error!( + "Tried to get dirty pages from tracker, but no tracker data found" + )); + }; + + Ok(res) + } + + /// Get all dirty page indices for this tracker. + /// NOTE: This is not a bitmap, but a vector of indices where each index corresponds to a page that has been written to. 
+ pub(super) fn stop_tracking_and_get_dirty_pages(self) -> Result> { + let res: Vec = if let Some(tracker_data) = get_trackers().get(&self.id) { + let mut dirty_pages = Vec::new(); + let tracker_data = tracker_data.val(); + for (idx, dirty) in tracker_data.dirty_pages.iter().enumerate() { + if dirty.load(Ordering::Acquire) { + dirty_pages.push(idx); + } + } + dirty_pages + } else { + return Err(new_error!( + "Tried to get dirty pages from tracker, but no tracker data found" + )); + }; + + // explicit to document intent + drop(self); + + Ok(res) + } + + /// Install global SIGSEGV handler if not already installed + fn ensure_sigsegv_handler_installed() -> Result<()> { + // Use compare_exchange to ensure only one thread does the installation + match HANDLER_INSTALLED.compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire) { + Ok(_) => { + // We won the race - we're responsible for installation + + // Get the current handler before installing ours + let mut original = Box::new(unsafe { std::mem::zeroed::() }); + + unsafe { + let result = libc::sigaction( + libc::SIGSEGV, + std::ptr::null(), + original.as_mut() as *mut libc::sigaction, + ); + + if result != 0 { + // Reset the flag on error + HANDLER_INSTALLED.store(false, Ordering::Release); + return Err(new_error!( + "Failed to get original SIGSEGV handler: {}", + std::io::Error::last_os_error() + )); + } + } + + // Install our handler + if let Err(e) = vmm_sys_util::signal::register_signal_handler( + libc::SIGSEGV, + Self::sigsegv_handler, + ) { + // Reset the flag on error + HANDLER_INSTALLED.store(false, Ordering::Release); + return Err(new_error!("Failed to register SIGSEGV handler: {}", e)); + } + + // Store original handler pointer atomically + let original_ptr = Box::into_raw(original); + ORIGINAL_SIGSEGV_HANDLER.store(original_ptr, Ordering::Release); + + Ok(()) + } + Err(_) => { + // Another thread already installed it, we're done + Ok(()) + } + } + } + + /// MINIMAL async signal safe SIGSEGV handler for dirty page tracking + /// This handler uses only async signal safe operations: + /// - Atomic loads/stores + /// - mprotect (async signal safe) + /// - Simple pointer arithmetic + /// - global lock-free storage (lockfree::Map) + /// - `getpid()` to check process ownership + extern "C" fn sigsegv_handler( + signal: libc::c_int, + info: *mut libc::siginfo_t, + context: *mut libc::c_void, + ) { + unsafe { + if signal != libc::SIGSEGV || info.is_null() { + Self::call_original_handler(signal, info, context); + return; + } + + let fault_addr = (*info).si_addr() as usize; + + // Check all trackers in global lock-free storage + // lockfree::Map::iter() is guaranteed to be async-signal-safe + let mut handled = false; + for guard in get_trackers().iter() { + let tracker_data = guard.val(); + + // Only handle faults for trackers in the current process + // We compare the stored PID with the current process PID + // getpid() is async-signal-safe, but we can avoid the call by checking + // if the fault address is within this tracker's range first + if fault_addr < tracker_data.base_addr + || fault_addr >= tracker_data.base_addr + tracker_data.size + { + continue; // Fault not in this tracker's range + } + + // Now verify this tracker belongs to the current process + let current_pid = libc::getpid() as u32; + if tracker_data.pid != current_pid { + continue; + } + + // We know the fault is in this tracker's range and it's our process + // Calculate page index + let page_offset = fault_addr - tracker_data.base_addr; + let page_idx = 
page_offset / PAGE_SIZE_USIZE; + + if page_idx < tracker_data.num_pages { + // Mark page dirty atomically (async signal safe) + tracker_data.dirty_pages[page_idx].store(true, Ordering::Relaxed); + + // Make page writable (mprotect is async signal safe) + let page_addr = tracker_data.base_addr + (page_idx * PAGE_SIZE_USIZE); + let result = mprotect( + page_addr as *mut libc::c_void, + PAGE_SIZE_USIZE, + PROT_READ | PROT_WRITE, + ); + + handled = result == 0; + break; // Found the tracker, stop searching + } + } + + // If not handled by any of our trackers, chain to original handler + if !handled { + Self::call_original_handler(signal, info, context); + } + } + } + + /// Call the original SIGSEGV handler if available (async signal safe) + fn call_original_handler( + signal: libc::c_int, + info: *mut libc::siginfo_t, + context: *mut libc::c_void, + ) { + unsafe { + let handler_ptr = ORIGINAL_SIGSEGV_HANDLER.load(Ordering::Acquire); + if !handler_ptr.is_null() { + let original = &*handler_ptr; + if original.sa_sigaction != 0 { + let handler_fn: extern "C" fn( + libc::c_int, + *mut libc::siginfo_t, + *mut libc::c_void, + ) = std::mem::transmute(original.sa_sigaction); + handler_fn(signal, info, context); + } + } + } + } + + #[cfg(test)] + /// Check if a memory address falls within this tracker's region + fn contains_address(&self, addr: usize) -> bool { + addr >= self.base_addr && addr < self.base_addr + self.size + } +} + +impl Drop for LinuxDirtyPageTracker { + fn drop(&mut self) { + // Remove this tracker's metadata from global lock-free storage + if get_trackers().remove(&self.id).is_none() { + error!("Tracker {} not found in global storage", self.id); + } + + // Restore memory protection + unsafe { + mprotect( + self.base_addr as *mut libc::c_void, + self.size, + PROT_READ | PROT_WRITE, + ); + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + use std::ptr::null_mut; + use std::sync::{Arc, Barrier}; + use std::thread; + + use libc::{MAP_ANONYMOUS, MAP_FAILED, MAP_PRIVATE, PROT_READ, PROT_WRITE, mmap}; + use rand::{Rng, rng}; + + use super::*; + use crate::mem::shared_mem::{ExclusiveSharedMemory, HostMapping, SharedMemory}; + + const PAGE_SIZE: usize = 4096; + + /// Helper to create page-aligned memory for testing + /// Returns (pointer, size) tuple + fn create_aligned_memory(size: usize) -> Arc { + let addr = unsafe { + mmap( + null_mut(), + size, + PROT_READ | PROT_WRITE, + MAP_ANONYMOUS | MAP_PRIVATE, + -1, + 0, + ) + }; + + if addr == MAP_FAILED { + panic!("Failed to allocate aligned memory with mmap"); + } + + // HostMapping is only non-Send/Sync because raw pointers + // are not ("as a lint", as the Rust docs say). We don't + // want to mark HostMapping Send/Sync immediately, because + // that could socially imply that it's "safe" to use + // unsafe accesses from multiple threads at once. Instead, we + // directly impl Send and Sync on this type. Since this + // type does have Send and Sync manually impl'd, the Arc + // is not pointless as the lint suggests. 
+ #[allow(clippy::arc_with_non_send_sync)] + Arc::new(HostMapping { + ptr: addr as *mut u8, + size, + }) + } + + #[test] + fn test_tracker_creation() { + let mut memory = ExclusiveSharedMemory::new(5 * 4096).unwrap(); + memory.stop_tracking_dirty_pages().unwrap(); + } + + #[test] + fn test_get_dirty_pages_initially_empty() { + let mut memory = ExclusiveSharedMemory::new(5 * 4096).unwrap(); + + let bitmap = memory + .stop_tracking_dirty_pages() + .expect("Failed to stop tracking dirty pages"); + + assert!(bitmap.is_empty(), "Dirty pages should be empty initially"); + } + + #[test] + fn test_random_page_dirtying() { + const MEMORY_SIZE: usize = 4096; + let mut memory = ExclusiveSharedMemory::new(MEMORY_SIZE).unwrap(); + + let bitmap = memory.get_dirty_pages().expect("Failed to get dirty pages"); + + assert!(bitmap.is_empty(), "Dirty pages should be empty initially"); + + let mem = memory.as_mut_slice(); + let five_random_idx = rand::rng() + .sample_iter(rand::distr::Uniform::new(0, MEMORY_SIZE).unwrap()) + .take(5) + .collect::>(); + + println!("Random indices: {:?}", &five_random_idx); + + for idx in &five_random_idx { + mem[*idx] = 1; // Write to random indices + } + let dirty_pages = memory + .stop_tracking_dirty_pages() + .expect("Failed to stop tracking dirty pages"); + assert!( + !dirty_pages.is_empty(), + "Dirty pages should not be empty after writes" + ); + for idx in five_random_idx { + let page_idx = idx / PAGE_SIZE; + assert!( + dirty_pages.contains(&page_idx), + "Page {} should be dirty after writing to index {}", + page_idx, + idx + ); + } + } + + #[test] + fn test_multiple_trackers_different_regions() { + const MEMORY_SIZE: usize = PAGE_SIZE * 4; + + let mut memory1 = ExclusiveSharedMemory::new(MEMORY_SIZE).unwrap(); + let mut memory2 = ExclusiveSharedMemory::new(MEMORY_SIZE).unwrap(); + + // Verify initial state is clean + let bitmap1 = memory1 + .get_dirty_pages() + .expect("Failed to get dirty pages"); + let bitmap2 = memory2 + .get_dirty_pages() + .expect("Failed to get dirty pages"); + assert!(bitmap1.is_empty(), "Dirty pages should be empty initially"); + assert!(bitmap2.is_empty(), "Dirty pages should be empty initially"); + + // Write to different memory regions + let mem1 = memory1.as_mut_slice(); + let mem2 = memory2.as_mut_slice(); + + mem1[100] = 1; // Write to offset 100 in first memory region (page 0) + mem2[PAGE_SIZE + 200] = 2; // Write to offset 200 in second memory region (page 1) + + let dirty1 = memory1.stop_tracking_dirty_pages().unwrap(); + let dirty2 = memory2.stop_tracking_dirty_pages().unwrap(); + + // Verify each tracker only reports pages that were actually written to + // Memory1: wrote to offset 100, which is in page 0 + assert!(dirty1.contains(&0), "Memory 1 should have page 0 dirty"); + assert_eq!(dirty1.len(), 1, "Memory 1 should only have 1 dirty page"); + + // Memory2: wrote to offset 200, which is in page 1 + assert!(dirty2.contains(&1), "Memory 2 should have page 1 dirty"); + assert_eq!(dirty2.len(), 1, "Memory 2 should only have 1 dirty page"); + } + + #[test] + fn test_cleanup_on_drop() { + const MEMORY_SIZE: usize = PAGE_SIZE * 2; + let mut memory = ExclusiveSharedMemory::new(MEMORY_SIZE).unwrap(); + + // Verify initial state is clean + let bitmap = memory.get_dirty_pages().expect("Failed to get dirty pages"); + assert!(bitmap.is_empty(), "Dirty pages should be empty initially"); + + // Get memory slice - this should work initially + let mem = memory.as_mut_slice(); + + // Memory should be read-only during tracking (writes will trigger 
SIGSEGV but get handled) + // Write to memory to verify tracking works - this should succeed due to signal handler + mem[100] = 42; + + let raw_addr = memory.raw_ptr(); + let raw_size = memory.raw_mem_size(); + // Verify the write was tracked + let dirty_pages_before_stop = memory.get_dirty_pages().expect("Failed to get dirty pages"); + assert!( + !dirty_pages_before_stop.is_empty(), + "Should have dirty pages after write" + ); + assert!( + dirty_pages_before_stop.contains(&0), + "Page 0 should be dirty" + ); + + drop(memory); // Explicitly drop the memory + + // now try mmap the memory again, it should work + let res = unsafe { + libc::mmap( + raw_addr as *mut libc::c_void, + raw_size, + PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, + -1, + 0, + ) + }; + + assert!( + res != MAP_FAILED, + "Failed to remap memory after tracker drop: {}", + std::io::Error::last_os_error() + ); + } + + #[test] + fn test_page_boundaries() { + const MEMORY_SIZE: usize = PAGE_SIZE * 3; + let mut memory = ExclusiveSharedMemory::new(MEMORY_SIZE).unwrap(); + + // Verify initial state is clean + let bitmap = memory.get_dirty_pages().expect("Failed to get dirty pages"); + assert!(bitmap.is_empty(), "Dirty pages should be empty initially"); + + let mem = memory.as_mut_slice(); + + // Write to different offsets within the first tracked page + // Remember: tracker excludes the first page (guard page), so we need to offset by PAGE_SIZE + let offsets = [0, 1, 100, 1000, PAGE_SIZE - 1]; + + for &offset in &offsets { + // Write to the first tracked page (which is the second page in the memory region) + mem[offset] = offset as u8; + } + + let dirty_pages = memory.stop_tracking_dirty_pages().unwrap(); + + // All writes to the same page should result in the same page being dirty + assert!( + !dirty_pages.is_empty(), + "Should have dirty pages after writes" + ); + assert!( + dirty_pages.contains(&0), + "Page 0 should be dirty after writes to first tracked page" + ); + + // Since all writes were to the same page, we should only have one dirty page + assert_eq!( + dirty_pages.len(), + 1, + "Should only have one dirty page since all writes were to the same page" + ); + + // Now test writing to different pages + let mut memory2 = ExclusiveSharedMemory::new(MEMORY_SIZE).unwrap(); + let mem2 = memory2.as_mut_slice(); + + // Write to first tracked page (page 0 in tracker terms) + mem2[100] = 1; + // Write to second tracked page (page 1 in tracker terms) - this is the third page in memory + mem2[PAGE_SIZE] = 2; + + let dirty_pages2 = memory2.stop_tracking_dirty_pages().unwrap(); + + assert_eq!(dirty_pages2.len(), 2, "Should have two dirty pages"); + assert!(dirty_pages2.contains(&0), "Page 0 should be dirty"); + assert!(dirty_pages2.contains(&1), "Page 1 should be dirty"); + } + + #[test] + fn test_concurrent_trackers() { + const NUM_THREADS: usize = 50; + const UPDATES_PER_THREAD: usize = 500; + const MIN_MEMORY_SIZE: usize = 1024 * 1024; // 1MB + const MAX_MEMORY_SIZE: usize = 10 * 1024 * 1024; // 10MB + + // Create barrier for synchronization + let start_writing_barrier = Arc::new(Barrier::new(NUM_THREADS)); + + let mut handles = Vec::new(); + + for thread_id in 0..NUM_THREADS { + let start_writing_barrier = Arc::clone(&start_writing_barrier); + + let handle = thread::spawn(move || { + let mut rng = rng(); + + let memory_size = rng.random_range(MIN_MEMORY_SIZE..=MAX_MEMORY_SIZE); + + // Ensure memory size is page-aligned + let memory_size = (memory_size + PAGE_SIZE - 1) & !(PAGE_SIZE - 1); + + let mut memory = 
ExclusiveSharedMemory::new(memory_size) + .expect("Failed to create shared memory"); + + // Wait for all threads to finish allocating before starting writes + start_writing_barrier.wait(); + + // Track which pages we write to (in tracker page indices) + let mut pages_written = HashSet::new(); + let mut total_writes = 0; + + // Perform random memory updates in a scope to ensure slice is dropped + { + let mem = memory.as_mut_slice(); + + for _ in 0..UPDATES_PER_THREAD { + // Generate random offset within the entire slice + let write_offset = rng.random_range(0..mem.len()); + + // Calculate which tracker page this corresponds to + let tracker_page_idx = write_offset / PAGE_SIZE; + + // Generate random value to write + let value = rng.random::(); + + // Write to memory to trigger dirty tracking + mem[write_offset] = value; + + // Track this page as written to (HashSet handles duplicates) + pages_written.insert(tracker_page_idx); + total_writes += 1; + } + } // mem goes out of scope here + + // Final verification: check that ALL pages we wrote to are marked as dirty + let final_dirty_pages = memory.stop_tracking_dirty_pages().unwrap(); + + // Check that every page we wrote to is marked as dirty + for &page_idx in &pages_written { + assert!( + final_dirty_pages.contains(&page_idx), + "Thread {}: Page {} was written but not marked dirty. Pages written: {:?}, Pages dirty: {:?}", + thread_id, + page_idx, + pages_written, + final_dirty_pages + ); + } + + // Verify that dirty pages don't contain extra pages we didn't write to + for &dirty_page in &final_dirty_pages { + assert!( + pages_written.contains(&dirty_page), + "Thread {}: Found dirty page {} that was not written to. Pages written: {:?}", + thread_id, + dirty_page, + pages_written + ); + } + + // Additional check: verify that pages we didn't write to are not dirty + for page_idx in 0..(memory_size / PAGE_SIZE) { + if !pages_written.contains(&page_idx) { + assert!( + !final_dirty_pages.contains(&page_idx), + "Thread {}: Page {} was not written but is marked dirty. Pages written: {:?}, Pages dirty: {:?}", + thread_id, + page_idx, + pages_written, + final_dirty_pages + ); + } + } + + // Verify that the number of unique dirty pages matches unique pages written + let dirty_pages_set: HashSet = final_dirty_pages.into_iter().collect(); + assert_eq!( + pages_written.len(), + dirty_pages_set.len(), + "Thread {}: Mismatch between unique pages written ({}) and unique dirty pages ({}). \ + Total writes: {}, Pages written: {:?}, Dirty pages: {:?}", + thread_id, + pages_written.len(), + dirty_pages_set.len(), + total_writes, + pages_written, + dirty_pages_set + ); + + // Verify that dirty pages don't contain extra pages we didn't write to + for &dirty_page in &dirty_pages_set { + assert!( + pages_written.contains(&dirty_page), + "Thread {}: Found dirty page {} that was not written to. 
Pages written: {:?}", + thread_id, + dirty_page, + pages_written + ); + } + + (pages_written.len(), dirty_pages_set.len(), total_writes) + }); + + handles.push(handle); + } + + // Wait for all threads to complete and collect results + let mut total_unique_pages_written = 0; + let mut total_unique_dirty_pages = 0; + let mut total_write_operations = 0; + + for (thread_id, handle) in handles.into_iter().enumerate() { + let (unique_pages_written, unique_dirty_pages, write_operations) = handle + .join() + .unwrap_or_else(|_| panic!("Thread {} panicked", thread_id)); + + total_unique_pages_written += unique_pages_written; + total_unique_dirty_pages += unique_dirty_pages; + total_write_operations += write_operations; + } + + println!("Concurrent test completed:"); + println!(" {} threads", NUM_THREADS); + println!(" {} updates per thread", UPDATES_PER_THREAD); + println!(" {} total write operations", total_write_operations); + println!( + " {} total unique pages written", + total_unique_pages_written + ); + println!( + " {} total unique dirty pages detected", + total_unique_dirty_pages + ); + + // Verify that we detected the expected number of dirty pages + assert!( + total_unique_dirty_pages > 0, + "No dirty pages detected across all threads" + ); + assert_eq!( + total_unique_pages_written, total_unique_dirty_pages, + "Mismatch between unique pages written and unique dirty pages detected" + ); + + // The total write operations should normally be much higher than unique pages (due to multiple writes to same pages) + assert!( + total_write_operations >= total_unique_pages_written, + "Total write operations ({}) should be >= unique pages written ({})", + total_write_operations, + total_unique_pages_written + ); + } + + #[test] + fn test_tracker_contains_address() { + const MEMORY_SIZE: usize = PAGE_SIZE * 10; + let mapping = create_aligned_memory(MEMORY_SIZE); + let tracker = LinuxDirtyPageTracker::new(mapping.clone()).unwrap(); + + let base = mapping.ptr as usize; + + // Test all addresses in the memory region + for offset in 0..MEMORY_SIZE { + let address = base + offset; + + // First page (guard page) and last page (guard page) should not be contained + let is_first_page = offset < PAGE_SIZE; + let is_last_page = offset >= MEMORY_SIZE - PAGE_SIZE; + + if is_first_page || is_last_page { + assert!( + !tracker.contains_address(address), + "Address at offset {} (page {}) should not be contained (guard page)", + offset, + offset / PAGE_SIZE + ); + } else { + assert!( + tracker.contains_address(address), + "Address at offset {} (page {}) should be contained", + offset, + offset / PAGE_SIZE + ); + } + } + + // try some random addresses far from the base address + assert!( + !tracker.contains_address(base - 213217), + "Address far from base should not be contained" + ); + assert!( + !tracker.contains_address(base + MEMORY_SIZE + 12345), + "Address far from end should not be contained" + ); + } +} diff --git a/src/hyperlight_host/src/mem/memory_region.rs b/src/hyperlight_host/src/mem/memory_region.rs index b46426c3b..6b2c04f42 100644 --- a/src/hyperlight_host/src/mem/memory_region.rs +++ b/src/hyperlight_host/src/mem/memory_region.rs @@ -31,7 +31,7 @@ use bitflags::bitflags; use hyperlight_common::mem::PAGE_SHIFT; use hyperlight_common::mem::PAGE_SIZE_USIZE; #[cfg(kvm)] -use kvm_bindings::{KVM_MEM_READONLY, kvm_userspace_memory_region}; +use kvm_bindings::{KVM_MEM_LOG_DIRTY_PAGES, KVM_MEM_READONLY, kvm_userspace_memory_region}; #[cfg(mshv2)] use mshv_bindings::{ HV_MAP_GPA_EXECUTABLE, 
HV_MAP_GPA_PERMISSIONS_NONE, HV_MAP_GPA_READABLE, HV_MAP_GPA_WRITABLE, @@ -326,7 +326,7 @@ impl From for kvm_bindings::kvm_userspace_memory_region { userspace_addr: region.host_region.start as u64, flags: match perm_flags { MemoryRegionFlags::READ => KVM_MEM_READONLY, - _ => 0, // normal, RWX + _ => KVM_MEM_LOG_DIRTY_PAGES, // normal, RWX }, } } diff --git a/src/hyperlight_host/src/mem/mgr.rs b/src/hyperlight_host/src/mem/mgr.rs index 90cb76573..b6eb77719 100644 --- a/src/hyperlight_host/src/mem/mgr.rs +++ b/src/hyperlight_host/src/mem/mgr.rs @@ -24,6 +24,7 @@ use hyperlight_common::flatbuffer_wrappers::function_types::ReturnValue; use hyperlight_common::flatbuffer_wrappers::guest_error::GuestError; use hyperlight_common::flatbuffer_wrappers::guest_log_data::GuestLogData; use hyperlight_common::flatbuffer_wrappers::host_function_details::HostFunctionDetails; +use hyperlight_common::mem::PAGES_IN_BLOCK; use tracing::{Span, instrument}; use super::exe::ExeInfo; @@ -33,8 +34,8 @@ use super::memory_region::{DEFAULT_GUEST_BLOB_MEM_FLAGS, MemoryRegion, MemoryReg use super::ptr::{GuestPtr, RawPtr}; use super::ptr_offset::Offset; use super::shared_mem::{ExclusiveSharedMemory, GuestSharedMemory, HostSharedMemory, SharedMemory}; -use super::shared_mem_snapshot::SharedMemorySnapshot; -use crate::HyperlightError::NoMemorySnapshot; +use super::shared_memory_snapshot_manager::SharedMemorySnapshotManager; +use crate::mem::bitmap::{bitmap_union, new_page_bitmap}; use crate::sandbox::SandboxConfiguration; use crate::sandbox::uninitialized::GuestBlob; use crate::{Result, log_then_return, new_error}; @@ -75,9 +76,8 @@ pub(crate) struct SandboxMemoryManager { pub(crate) entrypoint_offset: Offset, /// How many memory regions were mapped after sandbox creation pub(crate) mapped_rgns: u64, - /// A vector of memory snapshots that can be used to save and restore the state of the memory - /// This is used by the Rust Sandbox implementation (rather than the mem_snapshot field above which only exists to support current C API) - snapshots: Arc>>, + /// Shared memory snapshots that can be used to save and restore the state of the memory + snapshot_manager: Arc>>, } impl SandboxMemoryManager @@ -98,7 +98,7 @@ where load_addr, entrypoint_offset, mapped_rgns: 0, - snapshots: Arc::new(Mutex::new(Vec::new())), + snapshot_manager: Arc::new(Mutex::new(None)), } } @@ -107,17 +107,12 @@ where &mut self.shared_mem } - /// Set up the hypervisor partition in the given `SharedMemory` parameter - /// `shared_mem`, with the given memory size `mem_size` + /// Set up the page tables in the shared memory // TODO: This should perhaps happen earlier and use an // ExclusiveSharedMemory from the beginning. #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] #[cfg(feature = "init-paging")] - pub(crate) fn set_up_shared_memory( - &mut self, - mem_size: u64, - regions: &mut [MemoryRegion], - ) -> Result { + pub(crate) fn set_up_page_tables(&mut self, regions: &[MemoryRegion]) -> Result { let rsp: u64 = self.layout.get_top_of_user_stack_offset() as u64 + SandboxMemoryLayout::BASE_ADDRESS as u64 + self.layout.stack_size as u64 @@ -126,6 +121,7 @@ where // test from `sandbox_host_tests` fails. We should investigate this further. // See issue #498 for more details. 
- 0x28; + let mem_size = self.shared_mem.mem_size(); self.shared_mem.with_exclusivity(|shared_mem| { // Create PDL4 table with only 1 PML4E @@ -154,8 +150,6 @@ where // We can use the memory size to calculate the number of PTs we need // We round up mem_size/2MB - let mem_size = usize::try_from(mem_size)?; - let num_pages: usize = mem_size.div_ceil(AMOUNT_OF_MEMORY_PER_PT); // Create num_pages PT with 512 PTEs @@ -265,14 +259,36 @@ where } } - /// this function will create a memory snapshot and push it onto the stack of snapshots - /// It should be used when you want to save the state of the memory, for example, when evolving a sandbox to a new state - pub(crate) fn push_state(&mut self) -> Result<()> { - let snapshot = SharedMemorySnapshot::new(&mut self.shared_mem, self.mapped_rgns)?; - self.snapshots + /// this function will create an initial snapshot and then create the SnapshotManager + pub(crate) fn create_initial_snapshot( + &mut self, + vm_dirty_bitmap: &[u64], + host_dirty_page_idx: &[usize], + layout: &SandboxMemoryLayout, + ) -> Result<()> { + let mut existing_snapshot_manager = self + .snapshot_manager .try_lock() - .map_err(|e| new_error!("Error locking at {}:{}: {}", file!(), line!(), e))? - .push(snapshot); + .map_err(|e| new_error!("Error locking at {}:{}: {}", file!(), line!(), e))?; + + if existing_snapshot_manager.is_some() { + log_then_return!("Snapshot manager already initialized, not creating a new one"); + } + + // covert vec of page indices to bitmap + let mut res = new_page_bitmap(self.shared_mem.mem_size(), false)?; + for page_idx in host_dirty_page_idx { + let block_idx = page_idx / PAGES_IN_BLOCK; + let bit_idx = page_idx % PAGES_IN_BLOCK; + res[block_idx] |= 1 << bit_idx; + } + + // merge the host dirty page map into the dirty bitmap + let merged = bitmap_union(&res, vm_dirty_bitmap); + + let mut snapshot_manager = SharedMemorySnapshotManager::new(&mut self.shared_mem, layout)?; + snapshot_manager.create_new_snapshot(&mut self.shared_mem, &merged, self.mapped_rgns)?; + existing_snapshot_manager.replace(snapshot_manager); Ok(()) } @@ -284,35 +300,62 @@ where /// Returns the number of memory regions mapped into the sandbox /// that need to be unmapped in order for the restore to be /// completed. 
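+    // Editorial note: the `dirty_bitmap: &[u64]` arguments taken by the restore functions below
+    // use the same packing as `create_initial_snapshot` above: one u64 block covers
+    // PAGES_IN_BLOCK (64) pages. A minimal sketch of that packing, using a hypothetical
+    // `mark_dirty` helper that is not part of this change:
+    //
+    //     fn mark_dirty(bitmap: &mut [u64], page_idx: usize) {
+    //         let block_idx = page_idx / PAGES_IN_BLOCK; // which u64 holds this page's bit
+    //         let bit_idx = page_idx % PAGES_IN_BLOCK;   // which bit within that u64
+    //         bitmap[block_idx] |= 1 << bit_idx;         // e.g. page 70 -> block 1, bit 6
+    //     }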
- pub(crate) fn restore_state_from_last_snapshot(&mut self) -> Result { - let mut snapshots = self - .snapshots + pub(crate) fn restore_state_from_last_snapshot(&mut self, dirty_bitmap: &[u64]) -> Result { + let mut snapshot_manager = self + .snapshot_manager .try_lock() .map_err(|e| new_error!("Error locking at {}:{}: {}", file!(), line!(), e))?; - let last = snapshots.last_mut(); - if last.is_none() { - log_then_return!(NoMemorySnapshot); + + match snapshot_manager.as_mut() { + None => { + log_then_return!("Snapshot manager not initialized"); + } + Some(snapshot_manager) => { + let old_rgns = self.mapped_rgns; + self.mapped_rgns = + snapshot_manager.restore_from_snapshot(&mut self.shared_mem, dirty_bitmap)?; + Ok(old_rgns - self.mapped_rgns) + } } - #[allow(clippy::unwrap_used)] // We know that last is not None because we checked it above - let snapshot = last.unwrap(); - let old_rgns = self.mapped_rgns; - self.mapped_rgns = snapshot.restore_from_snapshot(&mut self.shared_mem)?; - Ok(old_rgns - self.mapped_rgns) } /// this function pops the last snapshot off the stack and restores the memory to the previous state /// It should be used when you want to restore the state of the memory to a previous state and do not need to retain that state /// for example when devolving a sandbox to a previous state. - pub(crate) fn pop_and_restore_state_from_snapshot(&mut self) -> Result { - let last = self - .snapshots + pub(crate) fn pop_and_restore_state_from_snapshot( + &mut self, + dirty_bitmap: &[u64], + ) -> Result { + let mut snapshot_manager = self + .snapshot_manager .try_lock() - .map_err(|e| new_error!("Error locking at {}:{}: {}", file!(), line!(), e))? - .pop(); - if last.is_none() { - log_then_return!(NoMemorySnapshot); + .map_err(|e| new_error!("Error locking at {}:{}: {}", file!(), line!(), e))?; + + match snapshot_manager.as_mut() { + None => { + log_then_return!("Snapshot manager not initialized"); + } + Some(snapshot_manager) => snapshot_manager + .pop_and_restore_state_from_snapshot(&mut self.shared_mem, dirty_bitmap), + } + } + + pub(crate) fn push_state(&mut self, dirty_bitmap: &[u64]) -> Result<()> { + let mut snapshot_manager = self + .snapshot_manager + .try_lock() + .map_err(|e| new_error!("Error locking at {}:{}: {}", file!(), line!(), e))?; + + match snapshot_manager.as_mut() { + None => { + log_then_return!("Snapshot manager not initialized"); + } + Some(snapshot_manager) => snapshot_manager.create_new_snapshot( + &mut self.shared_mem, + dirty_bitmap, + self.mapped_rgns, + ), } - self.restore_state_from_last_snapshot() } /// Sets `addr` to the correct offset in the memory referenced by @@ -440,7 +483,7 @@ impl SandboxMemoryManager { load_addr: self.load_addr.clone(), entrypoint_offset: self.entrypoint_offset, mapped_rgns: 0, - snapshots: Arc::new(Mutex::new(Vec::new())), + snapshot_manager: Arc::new(Mutex::new(None)), }, SandboxMemoryManager { shared_mem: gshm, @@ -448,7 +491,7 @@ impl SandboxMemoryManager { load_addr: self.load_addr.clone(), entrypoint_offset: self.entrypoint_offset, mapped_rgns: 0, - snapshots: Arc::new(Mutex::new(Vec::new())), + snapshot_manager: Arc::new(Mutex::new(None)), }, ) } diff --git a/src/hyperlight_host/src/mem/mod.rs b/src/hyperlight_host/src/mem/mod.rs index 1bcc03eae..a7c36e1ea 100644 --- a/src/hyperlight_host/src/mem/mod.rs +++ b/src/hyperlight_host/src/mem/mod.rs @@ -14,17 +14,25 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ +/// Various helper functions for working with bitmaps +pub(crate) mod bitmap; +/// a module for tracking dirty pages in the host. +pub(crate) mod dirty_page_tracking; /// A simple ELF loader pub(crate) mod elf; /// A generic wrapper for executable files (PE, ELF, etc) pub(crate) mod exe; /// Functionality to establish a sandbox's memory layout. pub mod layout; +#[cfg(target_os = "linux")] +mod linux_dirty_page_tracker; /// memory regions to be mapped inside a vm pub mod memory_region; /// Functionality that wraps a `SandboxMemoryLayout` and a /// `SandboxMemoryConfig` to mutate a sandbox's memory as necessary. pub mod mgr; +/// A compact snapshot representation for memory pages +pub(crate) mod page_snapshot; /// Structures to represent pointers into guest and host memory pub mod ptr; /// Structures to represent memory address spaces into which pointers @@ -35,9 +43,10 @@ pub mod ptr_offset; /// A wrapper around unsafe functionality to create and initialize /// a memory region for a guest running in a sandbox. pub mod shared_mem; -/// A wrapper around a `SharedMemory` and a snapshot in time -/// of the memory therein -pub mod shared_mem_snapshot; /// Utilities for writing shared memory tests #[cfg(test)] pub(crate) mod shared_mem_tests; +/// A wrapper around a `SharedMemory` to manage snapshots of the memory +pub mod shared_memory_snapshot_manager; +#[cfg(target_os = "windows")] +mod windows_dirty_page_tracker; diff --git a/src/hyperlight_host/src/mem/page_snapshot.rs b/src/hyperlight_host/src/mem/page_snapshot.rs new file mode 100644 index 000000000..6b2383c8e --- /dev/null +++ b/src/hyperlight_host/src/mem/page_snapshot.rs @@ -0,0 +1,88 @@ +/* +Copyright 2025 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +use std::collections::HashMap; + +use hyperlight_common::mem::PAGE_SIZE_USIZE; + +/// A compact snapshot representation that stores pages in a contiguous buffer +/// with an index for efficient lookup. +/// +/// This struct is designed to efficiently store and retrieve memory snapshots +/// by using a contiguous buffer for all page data combined with a HashMap index +/// for page lookups. This approach reduces memory overhead +/// compared to storing pages individually. +/// +/// # Clone Derivation +/// +/// This struct derives `Clone` because it's stored in `Vec` within +/// `SharedMemorySnapshotManager`, which itself derives `Clone`. +#[derive(Clone)] +pub(super) struct PageSnapshot { + /// Maps page numbers to their offset within the buffer (in page units) + page_index: HashMap, // page_number -> buffer_offset_in_pages + /// Contiguous buffer containing all the page data + buffer: Vec, + /// How many non-main-RAM regions were mapped when this snapshot was taken? 
+ mapped_rgns: u64, +} + +impl PageSnapshot { + /// Create a snapshot from a list of page numbers with pre-allocated buffer + pub(super) fn with_pages_and_buffer( + page_numbers: Vec, + buffer: Vec, + mapped_rgns: u64, + ) -> Self { + let page_count = page_numbers.len(); + let mut page_index = HashMap::with_capacity(page_count); + + // Map each page number to its offset in the buffer + for (buffer_offset, page_num) in page_numbers.into_iter().enumerate() { + page_index.insert(page_num, buffer_offset); + } + + Self { + page_index, + buffer, + mapped_rgns, + } + } + + /// Get page data by page number, returns None if page is not in snapshot + pub(super) fn get_page(&self, page_num: usize) -> Option<&[u8]> { + self.page_index.get(&page_num).map(|&buffer_offset| { + let start = buffer_offset * PAGE_SIZE_USIZE; + let end = start + PAGE_SIZE_USIZE; + &self.buffer[start..end] + }) + } + + /// Get an iterator over all page numbers in this snapshot + pub(super) fn page_numbers(&self) -> impl Iterator + '_ { + self.page_index.keys().copied() + } + + /// Get the maximum page number in this snapshot, or None if empty + pub(super) fn max_page(&self) -> Option { + self.page_index.keys().max().copied() + } + + /// Get the number of mapped regions when this snapshot was taken + pub(super) fn mapped_rgns(&self) -> u64 { + self.mapped_rgns + } +} diff --git a/src/hyperlight_host/src/mem/shared_mem.rs b/src/hyperlight_host/src/mem/shared_mem.rs index 50c809f44..afb7a49fb 100644 --- a/src/hyperlight_host/src/mem/shared_mem.rs +++ b/src/hyperlight_host/src/mem/shared_mem.rs @@ -19,7 +19,7 @@ use std::ffi::c_void; use std::io::Error; #[cfg(target_os = "linux")] use std::ptr::null_mut; -use std::sync::{Arc, RwLock}; +use std::sync::{Arc, Mutex, RwLock}; use hyperlight_common::mem::PAGE_SIZE_USIZE; use tracing::{Span, instrument}; @@ -39,6 +39,7 @@ use windows::core::PCSTR; use crate::HyperlightError::MemoryAllocationFailed; #[cfg(target_os = "windows")] use crate::HyperlightError::{MemoryRequestTooBig, WindowsAPIError}; +use crate::mem::dirty_page_tracking::{DirtyPageTracker, DirtyPageTracking}; use crate::{Result, log_then_return, new_error}; /// Makes sure that the given `offset` and `size` are within the bounds of the memory with size `mem_size`. @@ -91,8 +92,8 @@ macro_rules! generate_writer { /// Send or Sync, since it doesn't ensure any particular synchronization. #[derive(Debug)] pub struct HostMapping { - ptr: *mut u8, - size: usize, + pub(crate) ptr: *mut u8, + pub(crate) size: usize, #[cfg(target_os = "windows")] handle: HANDLE, } @@ -133,6 +134,7 @@ impl Drop for HostMapping { #[derive(Debug)] pub struct ExclusiveSharedMemory { region: Arc, + signal_dirty_bitmap_tracker: Arc>>, } unsafe impl Send for ExclusiveSharedMemory {} @@ -147,6 +149,7 @@ unsafe impl Send for ExclusiveSharedMemory {} #[derive(Debug)] pub struct GuestSharedMemory { region: Arc, + signal_dirty_bitmap_tracker: Arc>>, /// The lock that indicates this shared memory is being used by non-Rust code /// /// This lock _must_ be held whenever the guest is executing, @@ -298,6 +301,8 @@ unsafe impl Send for GuestSharedMemory {} #[derive(Clone, Debug)] pub struct HostSharedMemory { region: Arc, + signal_dirty_bitmap_tracker: Arc>>, + lock: Arc>, } unsafe impl Send for HostSharedMemory {} @@ -370,23 +375,58 @@ impl ExclusiveSharedMemory { return Err(MprotectFailed(Error::last_os_error().raw_os_error())); } + // HostMapping is only non-Send/Sync because raw pointers + // are not ("as a lint", as the Rust docs say). 
We don't + // want to mark HostMapping Send/Sync immediately, because + // that could socially imply that it's "safe" to use + // unsafe accesses from multiple threads at once. Instead, we + // directly impl Send and Sync on this type. Since this + // type does have Send and Sync manually impl'd, the Arc + // is not pointless as the lint suggests. + #[allow(clippy::arc_with_non_send_sync)] + let host_mapping = Arc::new(HostMapping { + ptr: addr as *mut u8, + size: total_size, + }); + + let dirty_page_tracker = Arc::new(Mutex::new(Some(DirtyPageTracker::new(Arc::clone( + &host_mapping, + ))?))); + Ok(Self { - // HostMapping is only non-Send/Sync because raw pointers - // are not ("as a lint", as the Rust docs say). We don't - // want to mark HostMapping Send/Sync immediately, because - // that could socially imply that it's "safe" to use - // unsafe accesses from multiple threads at once. Instead, we - // directly impl Send and Sync on this type. Since this - // type does have Send and Sync manually impl'd, the Arc - // is not pointless as the lint suggests. - #[allow(clippy::arc_with_non_send_sync)] - region: Arc::new(HostMapping { - ptr: addr as *mut u8, - size: total_size, - }), + region: host_mapping, + signal_dirty_bitmap_tracker: dirty_page_tracker, }) } + #[cfg(test)] + pub(crate) fn get_dirty_pages(&self) -> Result> { + self.signal_dirty_bitmap_tracker + .try_lock() + .map_err(|_| new_error!("Failed to acquire lock on dirty page tracker"))? + .as_ref() + .ok_or_else(|| { + new_error!("Dirty page tracker was not initialized, cannot get dirty pages") + })? + .get_dirty_pages() + .map_err(|e| new_error!("Failed to get dirty pages: {}", e)) + } + + /// Stop tracking dirty pages in the shared memory region. + pub(crate) fn stop_tracking_dirty_pages(&mut self) -> Result> { + self.signal_dirty_bitmap_tracker + .try_lock() + .map_err(|_| new_error!("Failed to acquire lock on dirty page tracker"))? + .take() + .ok_or_else(|| { + new_error!( + "Dirty page tracker was not initialized, cannot stop tracking dirty pages" + ) + })? + .uninstall() + .map_err(|e| new_error!("Failed to stop tracking dirty pages: {}", e)) + } + /// Create a new region of shared memory with the given minimum /// size in bytes. The region will be surrounded by guard pages. /// @@ -484,21 +524,28 @@ impl ExclusiveSharedMemory { log_then_return!(WindowsAPIError(e.clone())); } + // HostMapping is only non-Send/Sync because raw pointers + // are not ("as a lint", as the Rust docs say). We don't + // want to mark HostMapping Send/Sync immediately, because + // that could socially imply that it's "safe" to use + // unsafe accesses from multiple threads at once. Instead, we + // directly impl Send and Sync on this type. Since this + // type does have Send and Sync manually impl'd, the Arc + // is not pointless as the lint suggests. + #[allow(clippy::arc_with_non_send_sync)] + let host_mapping = Arc::new(HostMapping { + ptr: addr.Value as *mut u8, + size: total_size, + handle, + }); + + let dirty_page_tracker = Arc::new(Mutex::new(Some(DirtyPageTracker::new(Arc::clone( + &host_mapping, + ))?))); + Ok(Self { - // HostMapping is only non-Send/Sync because raw pointers - // are not ("as a lint", as the Rust docs say). We don't - // want to mark HostMapping Send/Sync immediately, because - // that could socially imply that it's "safe" to use - // unsafe accesses from multiple threads at once. Instead, we - // directly impl Send and Sync on this type. 
Since this - // type does have Send and Sync manually impl'd, the Arc - // is not pointless as the lint suggests. - #[allow(clippy::arc_with_non_send_sync)] - region: Arc::new(HostMapping { - ptr: addr.Value as *mut u8, - size: total_size, - handle, - }), + region: host_mapping, + signal_dirty_bitmap_tracker: dirty_page_tracker, }) } @@ -613,6 +660,15 @@ impl ExclusiveSharedMemory { Ok(()) } + /// Copies bytes from `self` to `dst` starting at offset + #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] + pub fn copy_to_slice(&self, dst: &mut [u8], offset: usize) -> Result<()> { + let data = self.as_slice(); + bounds_check!(offset, dst.len(), data.len()); + dst.copy_from_slice(&data[offset..offset + dst.len()]); + Ok(()) + } + /// Return the address of memory at an offset to this `SharedMemory` checking /// that the memory is within the bounds of the `SharedMemory`. #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] @@ -621,6 +677,15 @@ impl ExclusiveSharedMemory { Ok(self.base_addr() + offset) } + /// Fill the memory in the range `[offset, offset + len)` with `value` + #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] + pub fn zero_fill(&mut self, offset: usize, len: usize) -> Result<()> { + bounds_check!(offset, len, self.mem_size()); + let data = self.as_mut_slice(); + data[offset..offset + len].fill(0); + Ok(()) + } + generate_reader!(read_u8, u8); generate_reader!(read_i8, i8); generate_reader!(read_u16, u16); @@ -654,10 +719,12 @@ impl ExclusiveSharedMemory { ( HostSharedMemory { region: self.region.clone(), + signal_dirty_bitmap_tracker: self.signal_dirty_bitmap_tracker.clone(), lock: lock.clone(), }, GuestSharedMemory { region: self.region.clone(), + signal_dirty_bitmap_tracker: self.signal_dirty_bitmap_tracker.clone(), lock: lock.clone(), }, ) @@ -740,6 +807,7 @@ impl SharedMemory for GuestSharedMemory { fn region(&self) -> &HostMapping { &self.region } + fn with_exclusivity T>( &mut self, f: F, @@ -750,6 +818,7 @@ impl SharedMemory for GuestSharedMemory { .map_err(|e| new_error!("Error locking at {}:{}: {}", file!(), line!(), e))?; let mut excl = ExclusiveSharedMemory { region: self.region.clone(), + signal_dirty_bitmap_tracker: self.signal_dirty_bitmap_tracker.clone(), }; let ret = f(&mut excl); drop(excl); @@ -982,6 +1051,7 @@ impl SharedMemory for HostSharedMemory { fn region(&self) -> &HostMapping { &self.region } + fn with_exclusivity T>( &mut self, f: F, @@ -992,6 +1062,7 @@ impl SharedMemory for HostSharedMemory { .map_err(|e| new_error!("Error locking at {}:{}: {}", file!(), line!(), e))?; let mut excl = ExclusiveSharedMemory { region: self.region.clone(), + signal_dirty_bitmap_tracker: self.signal_dirty_bitmap_tracker.clone(), }; let ret = f(&mut excl); drop(excl); diff --git a/src/hyperlight_host/src/mem/shared_mem_snapshot.rs b/src/hyperlight_host/src/mem/shared_mem_snapshot.rs deleted file mode 100644 index ac2bdc6b5..000000000 --- a/src/hyperlight_host/src/mem/shared_mem_snapshot.rs +++ /dev/null @@ -1,104 +0,0 @@ -/* -Copyright 2025 The Hyperlight Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -*/ - -use tracing::{Span, instrument}; - -use super::shared_mem::SharedMemory; -use crate::Result; - -/// A wrapper around a `SharedMemory` reference and a snapshot -/// of the memory therein -#[derive(Clone)] -pub(super) struct SharedMemorySnapshot { - snapshot: Vec, - /// How many non-main-RAM regions were mapped when this snapshot was taken? - mapped_rgns: u64, -} - -impl SharedMemorySnapshot { - /// Take a snapshot of the memory in `shared_mem`, then create a new - /// instance of `Self` with the snapshot stored therein. - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - pub(super) fn new(shared_mem: &mut S, mapped_rgns: u64) -> Result { - // TODO: Track dirty pages instead of copying entire memory - let snapshot = shared_mem.with_exclusivity(|e| e.copy_all_to_vec())??; - Ok(Self { - snapshot, - mapped_rgns, - }) - } - - /// Take another snapshot of the internally-stored `SharedMemory`, - /// then store it internally. - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - - pub(super) fn replace_snapshot(&mut self, shared_mem: &mut S) -> Result<()> { - self.snapshot = shared_mem.with_exclusivity(|e| e.copy_all_to_vec())??; - Ok(()) - } - - /// Copy the memory from the internally-stored memory snapshot - /// into the internally-stored `SharedMemory` - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - pub(super) fn restore_from_snapshot( - &mut self, - shared_mem: &mut S, - ) -> Result { - shared_mem.with_exclusivity(|e| e.copy_from_slice(self.snapshot.as_slice(), 0))??; - Ok(self.mapped_rgns) - } -} - -#[cfg(test)] -mod tests { - use hyperlight_common::mem::PAGE_SIZE_USIZE; - - use crate::mem::shared_mem::ExclusiveSharedMemory; - - #[test] - fn restore_replace() { - let mut data1 = vec![b'a', b'b', b'c']; - data1.resize_with(PAGE_SIZE_USIZE, || 0); - let data2 = data1.iter().map(|b| b + 1).collect::>(); - let mut gm = ExclusiveSharedMemory::new(PAGE_SIZE_USIZE).unwrap(); - gm.copy_from_slice(data1.as_slice(), 0).unwrap(); - let mut snap = super::SharedMemorySnapshot::new(&mut gm, 0).unwrap(); - { - // after the first snapshot is taken, make sure gm has the equivalent - // of data1 - assert_eq!(data1, gm.copy_all_to_vec().unwrap()); - } - - { - // modify gm with data2 rather than data1 and restore from - // snapshot. we should have the equivalent of data1 again - gm.copy_from_slice(data2.as_slice(), 0).unwrap(); - assert_eq!(data2, gm.copy_all_to_vec().unwrap()); - snap.restore_from_snapshot(&mut gm).unwrap(); - assert_eq!(data1, gm.copy_all_to_vec().unwrap()); - } - { - // modify gm with data2, then retake the snapshot and restore - // from the new snapshot. we should have the equivalent of data2 - gm.copy_from_slice(data2.as_slice(), 0).unwrap(); - assert_eq!(data2, gm.copy_all_to_vec().unwrap()); - snap.replace_snapshot(&mut gm).unwrap(); - assert_eq!(data2, gm.copy_all_to_vec().unwrap()); - snap.restore_from_snapshot(&mut gm).unwrap(); - assert_eq!(data2, gm.copy_all_to_vec().unwrap()); - } - } -} diff --git a/src/hyperlight_host/src/mem/shared_memory_snapshot_manager.rs b/src/hyperlight_host/src/mem/shared_memory_snapshot_manager.rs new file mode 100644 index 000000000..49d588594 --- /dev/null +++ b/src/hyperlight_host/src/mem/shared_memory_snapshot_manager.rs @@ -0,0 +1,1286 @@ +/* +Copyright 2025 The Hyperlight Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +use hyperlight_common::mem::{PAGE_SIZE_USIZE, PAGES_IN_BLOCK}; +use tracing::{Span, instrument}; + +use super::page_snapshot::PageSnapshot; +use super::shared_mem::SharedMemory; +use crate::mem::bitmap::{bit_index_iterator, bitmap_union}; +use crate::mem::layout::SandboxMemoryLayout; +use crate::{Result, new_error}; + +/// A wrapper around a `SharedMemory` reference and a snapshot +/// of the memory therein +pub(super) struct SharedMemorySnapshotManager { + /// A vector of snapshots, each snapshot contains only the dirty pages in a compact format. + /// The initial snapshot is a delta from zeroing the memory on allocation + /// Subsequent snapshots are deltas from the previous state (i.e. only the dirty pages are stored) + snapshots: Vec, + /// The offsets of the input and output data buffers in the memory layout are stored + /// this allows us to reset the input and output buffers to their initial state (i.e. zeroed) + /// each time we restore from a snapshot, EVEN if the input/output buffers are not explicitly marked dirty by the host. + input_data_size: usize, + output_data_size: usize, + output_data_buffer_offset: usize, + input_data_buffer_offset: usize, +} + +impl SharedMemorySnapshotManager { + /// Take a snapshot of the memory in `shared_mem`, then create a new + /// instance of `Self` with the snapshot stored therein. 
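+    // Editorial note: a minimal sketch of how this manager appears to be driven, based on the
+    // methods defined below and the tests at the end of this file (the exact flow in the
+    // sandbox is an assumption here):
+    //
+    //     let mut mgr = SharedMemorySnapshotManager::new(&mut shared_mem, &layout)?;
+    //     mgr.create_new_snapshot(&mut shared_mem, &dirty_bitmap, mapped_rgns)?;    // push a delta
+    //     mgr.restore_from_snapshot(&mut shared_mem, &dirty_bitmap)?;               // roll back to it
+    //     mgr.pop_and_restore_state_from_snapshot(&mut shared_mem, &dirty_bitmap)?; // drop it and roll back further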
+ #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] + pub(super) fn new( + shared_mem: &mut S, + layout: &SandboxMemoryLayout, + ) -> Result { + // Get the input output buffer details from the layout so that they can be reset to their initial state + let input_data_size_offset = layout.get_input_data_size_offset(); + let output_data_size_offset = layout.get_output_data_size_offset(); + let output_data_buffer_offset = layout.get_output_data_pointer_offset(); + let input_data_buffer_offset = layout.get_input_data_pointer_offset(); + + // Read the input and output data sizes and pointers from memory + let ( + input_data_size, + output_data_size, + output_data_buffer_offset, + input_data_buffer_offset, + ) = shared_mem.with_exclusivity(|e| -> Result<(usize, usize, usize, usize)> { + Ok(( + e.read_usize(input_data_size_offset)?, + e.read_usize(output_data_size_offset)?, + e.read_usize(output_data_buffer_offset)?, + e.read_usize(input_data_buffer_offset)?, + )) + })??; + + Ok(Self { + snapshots: vec![], + input_data_size, + output_data_size, + output_data_buffer_offset, + input_data_buffer_offset, + }) + } + + pub(super) fn create_new_snapshot( + &mut self, + shared_mem: &mut S, + dirty_page_bitmap: &[u64], + mapped_rgns: u64, + ) -> Result<()> { + if dirty_page_bitmap.is_empty() { + return Err(new_error!( + "Tried to build snapshot from empty dirty page bitmap" + )); + } + + let mut dirty_pages: Vec = bit_index_iterator(dirty_page_bitmap).collect(); + + // Pre-allocate buffer for all pages + let page_count = dirty_pages.len(); + let total_size = page_count * PAGE_SIZE_USIZE; + let mut buffer = vec![0u8; total_size]; + + // if the total size is equal to the shared memory size, we can optimize the copy + if total_size == shared_mem.mem_size() { + // Copy the entire memory region in one go + shared_mem.with_exclusivity(|e| e.copy_to_slice(&mut buffer, 0))??; + } else { + // Sort pages for deterministic ordering and to enable consecutive page optimization + dirty_pages.sort_unstable(); + + let mut buffer_offset = 0; + let mut i = 0; + + while i < dirty_pages.len() { + let start_page = dirty_pages[i]; + let mut consecutive_count = 1; + + // Find consecutive pages + while i + consecutive_count < dirty_pages.len() + && dirty_pages[i + consecutive_count] == start_page + consecutive_count + { + consecutive_count += 1; + } + + // Calculate memory positions + let memory_offset = start_page * PAGE_SIZE_USIZE; + let copy_size = consecutive_count * PAGE_SIZE_USIZE; + let buffer_end = buffer_offset + copy_size; + + // Single copy operation for consecutive pages directly into final buffer + shared_mem.with_exclusivity(|e| { + e.copy_to_slice(&mut buffer[buffer_offset..buffer_end], memory_offset) + })??; + // copy_operations += 1; + + buffer_offset += copy_size; + i += consecutive_count; + } + } + + // Create the snapshot with the pre-allocated buffer + let snapshot = PageSnapshot::with_pages_and_buffer(dirty_pages, buffer, mapped_rgns); + self.snapshots.push(snapshot); + Ok(()) + } + + /// Copy the memory from the internally-stored memory snapshot + /// into the internally-stored `SharedMemory` + #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] + pub(super) fn restore_from_snapshot( + &mut self, + shared_mem: &mut S, + dirty_bitmap: &[u64], + ) -> Result { + // check the each index in the dirty bitmap and restore only the corresponding pages from the snapshots vector + // starting at the last snapshot look for the page in each snapshot if it exists and 
restore it + // if it does not exist set the page to zero + if self.snapshots.is_empty() { + return Err(crate::HyperlightError::NoMemorySnapshot); + } + + // Collect dirty pages and sort them for consecutive page optimization + let mut dirty_pages: Vec = bit_index_iterator(dirty_bitmap).collect(); + dirty_pages.sort_unstable(); + + let mut i = 0; + while i < dirty_pages.len() { + let start_page = dirty_pages[i]; + let mut consecutive_count = 1; + + // Find consecutive pages + while i + consecutive_count < dirty_pages.len() + && dirty_pages[i + consecutive_count] == start_page + consecutive_count + { + consecutive_count += 1; + } + + // Build buffer for consecutive pages + let mut buffer = vec![0u8; consecutive_count * PAGE_SIZE_USIZE]; + let mut buffer_offset = 0; + + for page_idx in 0..consecutive_count { + let page = start_page + page_idx; + + // Check for the page in every snapshot starting from the last one + for snapshot in self.snapshots.iter().rev() { + if let Some(data) = snapshot.get_page(page) { + buffer[buffer_offset..buffer_offset + PAGE_SIZE_USIZE] + .copy_from_slice(data); + break; + } + } + + buffer_offset += PAGE_SIZE_USIZE; + + // If the page was not found in any snapshot, it will be now be zero in the buffer as we skip over it above and didnt write any data + // This is the correct state as the page was not dirty in any snapshot which means it should be zeroed (the initial state) + } + + // Single copy operation for all consecutive pages + let memory_offset = start_page * PAGE_SIZE_USIZE; + shared_mem.with_exclusivity(|e| e.copy_from_slice(&buffer, memory_offset))??; + + i += consecutive_count; + } + // Reset input/output buffers these need to set to their initial state each time a snapshot is restored to clear any previous io/data that may be in the buffers + shared_mem.with_exclusivity(|e| { + e.zero_fill(self.input_data_buffer_offset, self.input_data_size)?; + e.zero_fill(self.output_data_buffer_offset, self.output_data_size)?; + e.write_u64( + self.input_data_buffer_offset, + SandboxMemoryLayout::STACK_POINTER_SIZE_BYTES, + )?; + e.write_u64( + self.output_data_buffer_offset, + SandboxMemoryLayout::STACK_POINTER_SIZE_BYTES, + ) + })??; + + #[allow(clippy::unwrap_used)] + Ok(self.snapshots.last().unwrap().mapped_rgns()) + } + + pub(super) fn pop_and_restore_state_from_snapshot( + &mut self, + shared_mem: &mut S, + dirty_bitmap: &[u64], + ) -> Result { + // Check that there is a snapshot to restore from + if self.snapshots.is_empty() { + return Err(crate::HyperlightError::NoMemorySnapshot); + } + // Get the last snapshot index + let last_snapshot_index = self.snapshots.len() - 1; + let last_snapshot_bitmap = self.get_bitmap_from_snapshot(last_snapshot_index); + // merge the last snapshot bitmap with the dirty bitmap + let merged_bitmap = bitmap_union(&last_snapshot_bitmap, dirty_bitmap); + + // drop the last snapshot then restore the state from the merged bitmap + if self.snapshots.pop().is_none() { + return Err(crate::HyperlightError::NoMemorySnapshot); + } + + // restore the state from the last snapshot + self.restore_from_snapshot(shared_mem, &merged_bitmap) + } + + fn get_bitmap_from_snapshot(&self, snapshot_index: usize) -> Vec { + // Get the snapshot at the given index + if snapshot_index < self.snapshots.len() { + let snapshot = &self.snapshots[snapshot_index]; + // Create a bitmap from the snapshot + let max_page = snapshot.max_page().unwrap_or_default(); + let num_blocks = max_page.div_ceil(PAGES_IN_BLOCK); + let mut bitmap = vec![0u64; num_blocks]; + for page 
in snapshot.page_numbers() { + let block = page / PAGES_IN_BLOCK; + let offset = page % PAGES_IN_BLOCK; + if block < bitmap.len() { + bitmap[block] |= 1 << offset; + } + } + bitmap + } else { + vec![] + } + } +} + +#[cfg(test)] +mod tests { + use hyperlight_common::mem::PAGE_SIZE_USIZE; + + use super::super::layout::SandboxMemoryLayout; + use crate::mem::bitmap::new_page_bitmap; + use crate::mem::shared_mem::{ExclusiveSharedMemory, SharedMemory}; + use crate::sandbox::SandboxConfiguration; + + fn create_test_layout() -> SandboxMemoryLayout { + let cfg = SandboxConfiguration::default(); + // Create a layout with large init_data area for testing (64KB for plenty of test pages) + let init_data_size = 64 * 1024; // 64KB = 16 pages of 4KB each + SandboxMemoryLayout::new(cfg, 4096, 16384, 16384, init_data_size, None).unwrap() + } + + fn create_test_shared_memory_with_layout( + layout: &SandboxMemoryLayout, + ) -> ExclusiveSharedMemory { + let memory_size = layout.get_memory_size().unwrap(); + let mut shared_mem = ExclusiveSharedMemory::new(memory_size).unwrap(); + + // Initialize the memory with the full layout to ensure it's properly set up + layout + .write( + &mut shared_mem, + SandboxMemoryLayout::BASE_ADDRESS, + memory_size, + ) + .unwrap(); + + shared_mem + } + + /// Get safe memory area for testing - uses init_data area which is safe to modify + fn get_safe_test_area( + layout: &SandboxMemoryLayout, + shared_mem: &mut ExclusiveSharedMemory, + ) -> (usize, usize) { + // The init_data area is positioned after the guest stack in the memory layout + // We can safely use this area for testing as it's designed for initialization data + // Read the actual init_data buffer offset and size from memory + let init_data_size_offset = layout.get_init_data_size_offset(); + let init_data_pointer_offset = layout.get_init_data_pointer_offset(); + + let (init_data_size, init_data_buffer_offset) = shared_mem + .with_exclusivity(|e| -> crate::Result<(usize, usize)> { + Ok(( + e.read_usize(init_data_size_offset)?, + e.read_usize(init_data_pointer_offset)?, + )) + }) + .unwrap() + .unwrap(); + + (init_data_buffer_offset, init_data_size) + } + + #[test] + fn test_single_snapshot_restore() { + let layout = create_test_layout(); + let mut shared_mem = create_test_shared_memory_with_layout(&layout); + + // Get safe init_data area for testing + let (init_data_offset, init_data_size) = get_safe_test_area(&layout, &mut shared_mem); + + // Use a safe page well within the init_data area + let safe_offset = init_data_offset + PAGE_SIZE_USIZE; // Skip first page for extra safety + + // Ensure we have enough space for testing + assert!( + init_data_size >= 2 * PAGE_SIZE_USIZE, + "Init data area too small for testing: {} bytes", + init_data_size + ); + assert!( + safe_offset + PAGE_SIZE_USIZE <= init_data_offset + init_data_size, + "Safe offset exceeds init_data bounds" + ); + + // Initial data - only initialize safe page, leave other pages as zero + let initial_data = vec![0xAA; PAGE_SIZE_USIZE]; + shared_mem + .copy_from_slice(&initial_data, safe_offset) + .unwrap(); + + // Stop tracking and get dirty pages bitmap + let dirty_pages_vec = shared_mem.get_dirty_pages().unwrap(); + + // Convert to bitmap format + let mut dirty_pages = new_page_bitmap(shared_mem.mem_size(), false).unwrap(); + for page_idx in dirty_pages_vec { + let block = page_idx / 64; + let bit = page_idx % 64; + if block < dirty_pages.len() { + dirty_pages[block] |= 1 << bit; + } + } + + // Create snapshot + let mut snapshot_manager = + 
super::SharedMemorySnapshotManager::new(&mut shared_mem, &layout).unwrap(); + + snapshot_manager + .create_new_snapshot(&mut shared_mem, &dirty_pages, 0) + .unwrap(); + + // Modify memory + let modified_data = vec![0xBB; PAGE_SIZE_USIZE]; + shared_mem + .copy_from_slice(&modified_data, safe_offset) + .unwrap(); + + // Verify modification + let mut current_data = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut current_data, safe_offset) + .unwrap(); + assert_eq!(current_data, modified_data); + + // Restore from snapshot + snapshot_manager + .restore_from_snapshot(&mut shared_mem, &dirty_pages) + .unwrap(); + + // Verify restoration + let mut restored_data = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut restored_data, safe_offset) + .unwrap(); + assert_eq!(restored_data, initial_data); + } + + #[test] + fn test_multiple_snapshots_and_restores() { + let layout = create_test_layout(); + let mut shared_mem = create_test_shared_memory_with_layout(&layout); + + // Get safe init_data area for testing + let (init_data_offset, init_data_size) = get_safe_test_area(&layout, &mut shared_mem); + + // Use a safe page well within the init_data area + let safe_offset = init_data_offset + PAGE_SIZE_USIZE; // Skip first page for extra safety + + // Ensure we have enough space for testing + assert!( + init_data_size >= 2 * PAGE_SIZE_USIZE, + "Init data area too small for testing" + ); + assert!( + safe_offset + PAGE_SIZE_USIZE <= init_data_offset + init_data_size, + "Safe offset exceeds init_data bounds" + ); + + // State 1: Initial state + let state1_data = vec![0x11; PAGE_SIZE_USIZE]; + shared_mem + .copy_from_slice(&state1_data, safe_offset) + .unwrap(); + + // Stop tracking and get dirty pages bitmap + let dirty_pages_vec = shared_mem.get_dirty_pages().unwrap(); + + // Convert to bitmap format + let mut dirty_pages = new_page_bitmap(shared_mem.mem_size(), false).unwrap(); + for page_idx in dirty_pages_vec { + let block = page_idx / 64; + let bit = page_idx % 64; + if block < dirty_pages.len() { + dirty_pages[block] |= 1 << bit; + } + } + + // Create initial snapshot (State 1) + let mut snapshot_manager = + super::SharedMemorySnapshotManager::new(&mut shared_mem, &layout).unwrap(); + + snapshot_manager + .create_new_snapshot(&mut shared_mem, &dirty_pages, 0) + .unwrap(); + + // State 2: Modify and create second snapshot + let state2_data = vec![0x22; PAGE_SIZE_USIZE]; + shared_mem + .copy_from_slice(&state2_data, safe_offset) + .unwrap(); + let dirty_pages_vec2 = shared_mem.get_dirty_pages().unwrap(); + + let mut dirty_pages2 = new_page_bitmap(shared_mem.mem_size(), false).unwrap(); + for page_idx in dirty_pages_vec2 { + let block = page_idx / 64; + let bit = page_idx % 64; + if block < dirty_pages2.len() { + dirty_pages2[block] |= 1 << bit; + } + } + + snapshot_manager + .create_new_snapshot(&mut shared_mem, &dirty_pages2, 0) + .unwrap(); + + // State 3: Modify again + let state3_data = vec![0x33; PAGE_SIZE_USIZE]; + shared_mem + .copy_from_slice(&state3_data, safe_offset) + .unwrap(); + + // Verify we're in state 3 + let mut current_data = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut current_data, safe_offset) + .unwrap(); + assert_eq!(current_data, state3_data); + + // Restore to state 2 (most recent snapshot) + snapshot_manager + .restore_from_snapshot(&mut shared_mem, &dirty_pages2) + .unwrap(); + let mut restored_data_state2 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut restored_data_state2, safe_offset) + .unwrap(); + 
assert_eq!(restored_data_state2, state2_data); + + // Pop state 2 and restore to state 1 + snapshot_manager + .pop_and_restore_state_from_snapshot(&mut shared_mem, &dirty_pages) + .unwrap(); + let mut restored_data_state1 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut restored_data_state1, safe_offset) + .unwrap(); + assert_eq!(restored_data_state1, state1_data); + } + + #[test] + fn test_multiple_pages_snapshot_restore() { + let layout = create_test_layout(); + let mut shared_mem = create_test_shared_memory_with_layout(&layout); + + // Get safe init_data area for testing + let (init_data_offset, init_data_size) = get_safe_test_area(&layout, &mut shared_mem); + + // Ensure we have enough space for 4 test pages + assert!( + init_data_size >= 6 * PAGE_SIZE_USIZE, + "Init data area too small for testing multiple pages" + ); + + // Use page offsets within the init_data area, skipping first page for safety + let base_page = (init_data_offset + PAGE_SIZE_USIZE) / PAGE_SIZE_USIZE; + let page_offsets = [base_page, base_page + 1, base_page + 2, base_page + 3]; + + let page_data = [ + vec![0xAA; PAGE_SIZE_USIZE], + vec![0xBB; PAGE_SIZE_USIZE], + vec![0xCC; PAGE_SIZE_USIZE], + vec![0xDD; PAGE_SIZE_USIZE], + ]; + + // Start tracking dirty pages + + // Initialize data in init_data pages + for (i, &page_offset) in page_offsets.iter().enumerate() { + let offset = page_offset * PAGE_SIZE_USIZE; + assert!( + offset + PAGE_SIZE_USIZE <= shared_mem.mem_size(), + "Page offset {} exceeds memory bounds", + page_offset + ); + assert!( + offset >= init_data_offset + && offset + PAGE_SIZE_USIZE <= init_data_offset + init_data_size, + "Page offset {} is outside init_data bounds", + page_offset + ); + shared_mem.copy_from_slice(&page_data[i], offset).unwrap(); + } + + // Stop tracking and get dirty pages bitmap + let dirty_pages_vec = shared_mem.get_dirty_pages().unwrap(); + + // Convert to bitmap format + let mut dirty_pages = new_page_bitmap(shared_mem.mem_size(), false).unwrap(); + for page_idx in dirty_pages_vec { + let block = page_idx / 64; + let bit = page_idx % 64; + if block < dirty_pages.len() { + dirty_pages[block] |= 1 << bit; + } + } + + // Create snapshot + let mut snapshot_manager = + super::SharedMemorySnapshotManager::new(&mut shared_mem, &layout).unwrap(); + snapshot_manager + .create_new_snapshot(&mut shared_mem, &dirty_pages, 0) + .unwrap(); + + // Modify first and third pages + let modified_data = [vec![0x11; PAGE_SIZE_USIZE], vec![0x22; PAGE_SIZE_USIZE]]; + shared_mem + .copy_from_slice(&modified_data[0], page_offsets[0] * PAGE_SIZE_USIZE) + .unwrap(); + shared_mem + .copy_from_slice(&modified_data[1], page_offsets[2] * PAGE_SIZE_USIZE) + .unwrap(); + + // Restore from snapshot + snapshot_manager + .restore_from_snapshot(&mut shared_mem, &dirty_pages) + .unwrap(); + + // Verify restoration + for (i, &page_offset) in page_offsets.iter().enumerate() { + let mut restored_data = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut restored_data, page_offset * PAGE_SIZE_USIZE) + .unwrap(); + assert_eq!( + restored_data, page_data[i], + "Page {} should be restored to original data", + i + ); + } + } + + #[test] + fn test_sequential_modifications_with_snapshots() { + let layout = create_test_layout(); + let mut shared_mem = create_test_shared_memory_with_layout(&layout); + + // Get safe init_data area for testing + let (init_data_offset, init_data_size) = get_safe_test_area(&layout, &mut shared_mem); + + // Use safe page offsets within init_data area + let safe_offset1 = 
init_data_offset + PAGE_SIZE_USIZE; // Skip first page for safety + let safe_offset2 = init_data_offset + 2 * PAGE_SIZE_USIZE; + + // Ensure we have enough space for testing + assert!( + init_data_size >= 3 * PAGE_SIZE_USIZE, + "Init data area too small for testing" + ); + assert!( + safe_offset2 + PAGE_SIZE_USIZE <= init_data_offset + init_data_size, + "Safe offsets exceed init_data bounds" + ); + + // Start tracking dirty pages + + // Cycle 1: Set initial data + let cycle1_page0 = (0..PAGE_SIZE_USIZE) + .map(|i| (i % 256) as u8) + .collect::>(); + let cycle1_page1 = vec![0x01; PAGE_SIZE_USIZE]; + shared_mem + .copy_from_slice(&cycle1_page0, safe_offset1) + .unwrap(); + shared_mem + .copy_from_slice(&cycle1_page1, safe_offset2) + .unwrap(); + + // Stop tracking and get dirty pages bitmap + let dirty_pages_vec1 = shared_mem.get_dirty_pages().unwrap(); + + // Convert to bitmap format + let mut dirty_pages = new_page_bitmap(shared_mem.mem_size(), false).unwrap(); + for page_idx in dirty_pages_vec1 { + let block = page_idx / 64; + let bit = page_idx % 64; + if block < dirty_pages.len() { + dirty_pages[block] |= 1 << bit; + } + } + + let mut snapshot_manager = + super::SharedMemorySnapshotManager::new(&mut shared_mem, &layout).unwrap(); + snapshot_manager + .create_new_snapshot(&mut shared_mem, &dirty_pages, 0) + .unwrap(); + + // Cycle 2: Modify and snapshot + let cycle2_page0 = vec![0x02; PAGE_SIZE_USIZE]; + let cycle2_page1 = (0..PAGE_SIZE_USIZE) + .map(|i| ((i + 100) % 256) as u8) + .collect::>(); + shared_mem + .copy_from_slice(&cycle2_page0, safe_offset1) + .unwrap(); + shared_mem + .copy_from_slice(&cycle2_page1, safe_offset2) + .unwrap(); + + let dirty_pages_vec2 = shared_mem.get_dirty_pages().unwrap(); + + let mut dirty_pages2 = new_page_bitmap(shared_mem.mem_size(), false).unwrap(); + for page_idx in dirty_pages_vec2 { + let block = page_idx / 64; + let bit = page_idx % 64; + if block < dirty_pages2.len() { + dirty_pages2[block] |= 1 << bit; + } + } + + snapshot_manager + .create_new_snapshot(&mut shared_mem, &dirty_pages2, 0) + .unwrap(); + + // Cycle 3: Modify again + let cycle3_page0 = vec![0x03; PAGE_SIZE_USIZE]; + let cycle3_page1 = vec![0x33; PAGE_SIZE_USIZE]; + shared_mem + .copy_from_slice(&cycle3_page0, safe_offset1) + .unwrap(); + shared_mem + .copy_from_slice(&cycle3_page1, safe_offset2) + .unwrap(); + + // Verify current state (cycle 3) + let mut current_page0 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut current_page0, safe_offset1) + .unwrap(); + let mut current_page1 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut current_page1, safe_offset2) + .unwrap(); + assert_eq!(current_page0, cycle3_page0); + assert_eq!(current_page1, cycle3_page1); + + // Restore to cycle 2 + snapshot_manager + .restore_from_snapshot(&mut shared_mem, &dirty_pages2) + .unwrap(); + let mut restored_page0 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut restored_page0, safe_offset1) + .unwrap(); + let mut restored_page1 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut restored_page1, safe_offset2) + .unwrap(); + assert_eq!(restored_page0, cycle2_page0); + assert_eq!(restored_page1, cycle2_page1); + + // Pop cycle 2 and restore to cycle 1 + snapshot_manager + .pop_and_restore_state_from_snapshot(&mut shared_mem, &dirty_pages) + .unwrap(); + let mut restored_page0 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut restored_page0, safe_offset1) + .unwrap(); + let mut restored_page1 = vec![0u8; PAGE_SIZE_USIZE]; + 
shared_mem + .copy_to_slice(&mut restored_page1, safe_offset2) + .unwrap(); + assert_eq!(restored_page0, cycle1_page0); + assert_eq!(restored_page1, cycle1_page1); + } + + #[test] + fn test_restore_with_zero_pages() { + let layout = create_test_layout(); + let mut shared_mem = create_test_shared_memory_with_layout(&layout); + + // Get safe init_data area for testing + let (init_data_offset, init_data_size) = get_safe_test_area(&layout, &mut shared_mem); + + // Ensure we have enough space for testing + assert!( + init_data_size >= 3 * PAGE_SIZE_USIZE, + "Init data area too small for testing" + ); + + // Only initialize one page in the init_data area + let page1_offset = init_data_offset + PAGE_SIZE_USIZE; // Skip first page for safety + let page1_data = vec![0xFF; PAGE_SIZE_USIZE]; + shared_mem + .copy_from_slice(&page1_data, page1_offset) + .unwrap(); + + // Stop tracking and get dirty pages bitmap + let dirty_pages_vec = shared_mem.stop_tracking_dirty_pages().unwrap(); + + // Convert to bitmap format + let mut dirty_pages_snapshot = new_page_bitmap(shared_mem.mem_size(), false).unwrap(); + for page_idx in dirty_pages_vec { + let block = page_idx / 64; + let bit = page_idx % 64; + if block < dirty_pages_snapshot.len() { + dirty_pages_snapshot[block] |= 1 << bit; + } + } + + let mut snapshot_manager = + super::SharedMemorySnapshotManager::new(&mut shared_mem, &layout).unwrap(); + snapshot_manager + .create_new_snapshot(&mut shared_mem, &dirty_pages_snapshot, 0) + .unwrap(); + + // Modify pages in init_data area + let page0_offset = init_data_offset; + let page2_offset = init_data_offset + 2 * PAGE_SIZE_USIZE; + + let modified_page0 = vec![0xAA; PAGE_SIZE_USIZE]; + let modified_page1 = vec![0xBB; PAGE_SIZE_USIZE]; + let modified_page2 = vec![0xCC; PAGE_SIZE_USIZE]; + shared_mem + .copy_from_slice(&modified_page0, page0_offset) + .unwrap(); + shared_mem + .copy_from_slice(&modified_page1, page1_offset) + .unwrap(); + shared_mem + .copy_from_slice(&modified_page2, page2_offset) + .unwrap(); + + // Create dirty page map for all test pages + let mut dirty_pages_restore = new_page_bitmap(shared_mem.mem_size(), false).unwrap(); + let page0_idx = page0_offset / PAGE_SIZE_USIZE; + let page1_idx = page1_offset / PAGE_SIZE_USIZE; + let page2_idx = page2_offset / PAGE_SIZE_USIZE; + + // Mark all test pages as dirty for restore + for &page_idx in &[page0_idx, page1_idx, page2_idx] { + let block = page_idx / 64; + let bit = page_idx % 64; + if block < dirty_pages_restore.len() { + dirty_pages_restore[block] |= 1 << bit; + } + } + + // Restore from snapshot + snapshot_manager + .restore_from_snapshot(&mut shared_mem, &dirty_pages_restore) + .unwrap(); + + // Verify restoration + let mut restored_page0 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut restored_page0, page0_offset) + .unwrap(); + let mut restored_page1 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut restored_page1, page1_offset) + .unwrap(); + let mut restored_page2 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut restored_page2, page2_offset) + .unwrap(); + + // Page 0 and 2 should be zeroed (not in snapshot), page 1 should be restored + assert_eq!(restored_page0, vec![0u8; PAGE_SIZE_USIZE]); + assert_eq!(restored_page1, page1_data); + assert_eq!(restored_page2, vec![0u8; PAGE_SIZE_USIZE]); + } + + #[test] + fn test_empty_snapshot_error() { + let layout = create_test_layout(); + let mut shared_mem = create_test_shared_memory_with_layout(&layout); + let memory_size = shared_mem.mem_size(); + + 
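+        // Editorial note: this test drives the empty-snapshot error paths directly. The
+        // `new_page_bitmap(memory_size, true)` call below presumably yields a bitmap with every
+        // page bit set (the `false` form used in the other tests starts out all-clear), so both
+        // restore calls are expected to fail with `NoMemorySnapshot` before touching any page.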
// Create snapshot manager with no snapshots + let mut snapshot_manager = super::SharedMemorySnapshotManager { + snapshots: vec![], + input_data_size: 0, + output_data_size: 0, + output_data_buffer_offset: 0, + input_data_buffer_offset: 0, + }; + + let dirty_pages = new_page_bitmap(memory_size, true).unwrap(); + + // Should return error when trying to restore from empty snapshots + let result = snapshot_manager.restore_from_snapshot(&mut shared_mem, &dirty_pages); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + crate::HyperlightError::NoMemorySnapshot + )); + + // Should return error when trying to pop from empty snapshots + let result = + snapshot_manager.pop_and_restore_state_from_snapshot(&mut shared_mem, &dirty_pages); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + crate::HyperlightError::NoMemorySnapshot + )); + } + + #[test] + fn test_complex_workflow_simulation() { + let layout = create_test_layout(); + let mut shared_mem = create_test_shared_memory_with_layout(&layout); + + // Get safe init_data area for testing + let (init_data_offset, init_data_size) = get_safe_test_area(&layout, &mut shared_mem); + + // Ensure we have enough space for 4 test pages + assert!( + init_data_size >= 6 * PAGE_SIZE_USIZE, + "Init data area too small for testing" + ); + + // Start tracking dirty pages + + // Use the init_data area - this is safe and won't interfere with other layout structures + let base_page = (init_data_offset + PAGE_SIZE_USIZE) / PAGE_SIZE_USIZE; // Skip first page for safety + let page_offsets = [base_page, base_page + 1, base_page + 2, base_page + 3]; + + // Initialize memory with pattern in init_data area + for (i, &page_offset) in page_offsets.iter().enumerate() { + let data = vec![i as u8; PAGE_SIZE_USIZE]; + let offset = page_offset * PAGE_SIZE_USIZE; + assert!( + offset + PAGE_SIZE_USIZE <= shared_mem.mem_size(), + "Page offset {} exceeds memory bounds", + page_offset + ); + assert!( + offset >= init_data_offset + && offset + PAGE_SIZE_USIZE <= init_data_offset + init_data_size, + "Page offset {} is outside init_data bounds", + page_offset + ); + shared_mem.copy_from_slice(&data, offset).unwrap(); + } + + // Stop tracking and get dirty pages bitmap + let dirty_pages_vec = shared_mem.get_dirty_pages().unwrap(); + + // Convert to bitmap format + let mut dirty_pages = new_page_bitmap(shared_mem.mem_size(), false).unwrap(); + for page_idx in dirty_pages_vec { + let block = page_idx / 64; + let bit = page_idx % 64; + if block < dirty_pages.len() { + dirty_pages[block] |= 1 << bit; + } + } + + // Create initial checkpoint + let mut snapshot_manager = + super::SharedMemorySnapshotManager::new(&mut shared_mem, &layout).unwrap(); + snapshot_manager + .create_new_snapshot(&mut shared_mem, &dirty_pages, 0) + .unwrap(); + + // Simulate function call 1: modify pages 0 and 2 + let func1_page0 = vec![0x10; PAGE_SIZE_USIZE]; + let func1_page2 = vec![0x12; PAGE_SIZE_USIZE]; + shared_mem + .copy_from_slice(&func1_page0, page_offsets[0] * PAGE_SIZE_USIZE) + .unwrap(); + shared_mem + .copy_from_slice(&func1_page2, page_offsets[2] * PAGE_SIZE_USIZE) + .unwrap(); + + let dirty_pages_vec1 = shared_mem.get_dirty_pages().unwrap(); + + let mut dirty_pages1 = new_page_bitmap(shared_mem.mem_size(), false).unwrap(); + for page_idx in dirty_pages_vec1 { + let block = page_idx / 64; + let bit = page_idx % 64; + if block < dirty_pages1.len() { + dirty_pages1[block] |= 1 << bit; + } + } + + // Checkpoint after function 1 + snapshot_manager + 
.create_new_snapshot(&mut shared_mem, &dirty_pages1, 0) + .unwrap(); + + // Simulate function call 2: modify pages 1 and 3 + let func2_page1 = vec![0x21; PAGE_SIZE_USIZE]; + let func2_page3 = vec![0x23; PAGE_SIZE_USIZE]; + shared_mem + .copy_from_slice(&func2_page1, page_offsets[1] * PAGE_SIZE_USIZE) + .unwrap(); + shared_mem + .copy_from_slice(&func2_page3, page_offsets[3] * PAGE_SIZE_USIZE) + .unwrap(); + + let dirty_pages_vec2 = shared_mem.get_dirty_pages().unwrap(); + + let mut dirty_pages2 = new_page_bitmap(shared_mem.mem_size(), false).unwrap(); + for page_idx in dirty_pages_vec2 { + let block = page_idx / 64; + let bit = page_idx % 64; + if block < dirty_pages2.len() { + dirty_pages2[block] |= 1 << bit; + } + } + + // Checkpoint after function 2 + snapshot_manager + .create_new_snapshot(&mut shared_mem, &dirty_pages2, 0) + .unwrap(); + + // Simulate function call 3: modify all pages + for (i, &page_offset) in page_offsets.iter().enumerate() { + let data = vec![0x30 + i as u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_from_slice(&data, page_offset * PAGE_SIZE_USIZE) + .unwrap(); + } + + // Verify current state (after function 3) + for (i, &page_offset) in page_offsets.iter().enumerate() { + let mut current = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut current, page_offset * PAGE_SIZE_USIZE) + .unwrap(); + let expected = vec![0x30 + i as u8; PAGE_SIZE_USIZE]; + assert_eq!(current, expected); + } + + // Create a bitmap that includes all pages that were modified in function 3 + let mut dirty_pages_all_func3 = new_page_bitmap(shared_mem.mem_size(), false).unwrap(); + for &page_offset in &page_offsets { + let block = page_offset / 64; + let bit = page_offset % 64; + if block < dirty_pages_all_func3.len() { + dirty_pages_all_func3[block] |= 1 << bit; + } + } + + // Rollback to after function 2 + snapshot_manager + .restore_from_snapshot(&mut shared_mem, &dirty_pages_all_func3) + .unwrap(); + + // Verify state after function 2 + let mut page0_after_func2 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut page0_after_func2, page_offsets[0] * PAGE_SIZE_USIZE) + .unwrap(); + assert_eq!(page0_after_func2, func1_page0); + + let mut page1_after_func2 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut page1_after_func2, page_offsets[1] * PAGE_SIZE_USIZE) + .unwrap(); + assert_eq!(page1_after_func2, func2_page1); + + let mut page2_after_func2 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut page2_after_func2, page_offsets[2] * PAGE_SIZE_USIZE) + .unwrap(); + assert_eq!(page2_after_func2, func1_page2); + + let mut page3_after_func2 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut page3_after_func2, page_offsets[3] * PAGE_SIZE_USIZE) + .unwrap(); + assert_eq!(page3_after_func2, func2_page3); + + // Rollback to after function 1 + // Need to create a bitmap that includes all pages that could have been modified + let mut combined_dirty_pages1 = dirty_pages.clone(); + for i in 0..combined_dirty_pages1.len().min(dirty_pages1.len()) { + combined_dirty_pages1[i] |= dirty_pages1[i]; + } + for i in 0..combined_dirty_pages1.len().min(dirty_pages2.len()) { + combined_dirty_pages1[i] |= dirty_pages2[i]; + } + + snapshot_manager + .pop_and_restore_state_from_snapshot(&mut shared_mem, &combined_dirty_pages1) + .unwrap(); + + // Verify state after function 1 + let mut page0_after_func1 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut page0_after_func1, page_offsets[0] * PAGE_SIZE_USIZE) + .unwrap(); + 
assert_eq!(page0_after_func1, func1_page0); + + let mut page1_after_func1 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut page1_after_func1, page_offsets[1] * PAGE_SIZE_USIZE) + .unwrap(); + assert_eq!(page1_after_func1, vec![1u8; PAGE_SIZE_USIZE]); // Original + + let mut page2_after_func1 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut page2_after_func1, page_offsets[2] * PAGE_SIZE_USIZE) + .unwrap(); + assert_eq!(page2_after_func1, func1_page2); + + let mut page3_after_func1 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut page3_after_func1, page_offsets[3] * PAGE_SIZE_USIZE) + .unwrap(); + assert_eq!(page3_after_func1, vec![3u8; PAGE_SIZE_USIZE]); // Original + + // Rollback to initial state + // Need to create a bitmap that includes all pages that could have been modified + let mut combined_dirty_pages_all = dirty_pages.clone(); + for i in 0..combined_dirty_pages_all.len().min(dirty_pages1.len()) { + combined_dirty_pages_all[i] |= dirty_pages1[i]; + } + for i in 0..combined_dirty_pages_all.len().min(dirty_pages2.len()) { + combined_dirty_pages_all[i] |= dirty_pages2[i]; + } + + snapshot_manager + .pop_and_restore_state_from_snapshot(&mut shared_mem, &combined_dirty_pages_all) + .unwrap(); + + // Verify initial state + for (i, &page_offset) in page_offsets.iter().enumerate() { + let mut current = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut current, page_offset * PAGE_SIZE_USIZE) + .unwrap(); + let expected = vec![i as u8; PAGE_SIZE_USIZE]; + assert_eq!(current, expected); + } + } + + #[test] + fn test_unchanged_data_verification() { + let layout = create_test_layout(); + let mut shared_mem = create_test_shared_memory_with_layout(&layout); + + // Get safe init_data area for testing + let (init_data_offset, init_data_size) = get_safe_test_area(&layout, &mut shared_mem); + + // Ensure we have enough space for 6 test pages + assert!( + init_data_size >= 8 * PAGE_SIZE_USIZE, + "Init data area too small for testing" + ); + + // Start tracking dirty pages + + // Initialize all pages with different patterns - use safe offsets within init_data area + let base_page = (init_data_offset + PAGE_SIZE_USIZE) / PAGE_SIZE_USIZE; // Skip first page for safety + let page_offsets = [ + base_page, + base_page + 1, + base_page + 2, + base_page + 3, + base_page + 4, + base_page + 5, + ]; + let initial_patterns = [ + vec![0xAA; PAGE_SIZE_USIZE], // Page 0 + vec![0xBB; PAGE_SIZE_USIZE], // Page 1 + vec![0xCC; PAGE_SIZE_USIZE], // Page 2 + vec![0xDD; PAGE_SIZE_USIZE], // Page 3 + vec![0xEE; PAGE_SIZE_USIZE], // Page 4 + vec![0xFF; PAGE_SIZE_USIZE], // Page 5 + ]; + + for (i, pattern) in initial_patterns.iter().enumerate() { + let offset = page_offsets[i] * PAGE_SIZE_USIZE; + assert!( + offset + PAGE_SIZE_USIZE <= shared_mem.mem_size(), + "Page offset {} exceeds memory bounds", + page_offsets[i] + ); + assert!( + offset >= init_data_offset + && offset + PAGE_SIZE_USIZE <= init_data_offset + init_data_size, + "Page offset {} is outside init_data bounds", + page_offsets[i] + ); + shared_mem.copy_from_slice(pattern, offset).unwrap(); + } + + // Stop tracking and get dirty pages bitmap + let dirty_pages_vec = shared_mem.get_dirty_pages().unwrap(); + + // Convert to bitmap format - only track specific pages (1, 3, 5) + let mut dirty_pages = new_page_bitmap(shared_mem.mem_size(), false).unwrap(); + let tracked_pages = [1, 3, 5]; // Only track these pages for snapshot + for &tracked_page_idx in &tracked_pages { + let actual_page = 
page_offsets[tracked_page_idx]; + if dirty_pages_vec.contains(&actual_page) { + let block = actual_page / 64; + let bit = actual_page % 64; + if block < dirty_pages.len() { + dirty_pages[block] |= 1 << bit; + } + } + } + + // Create snapshot + let mut snapshot_manager = + super::SharedMemorySnapshotManager::new(&mut shared_mem, &layout).unwrap(); + snapshot_manager + .create_new_snapshot(&mut shared_mem, &dirty_pages, 0) + .unwrap(); + + // Modify only the dirty pages + let modified_patterns = [ + vec![0x11; PAGE_SIZE_USIZE], // Page 1 modified + vec![0x33; PAGE_SIZE_USIZE], // Page 3 modified + vec![0x55; PAGE_SIZE_USIZE], // Page 5 modified + ]; + + shared_mem + .copy_from_slice(&modified_patterns[0], page_offsets[1] * PAGE_SIZE_USIZE) + .unwrap(); + shared_mem + .copy_from_slice(&modified_patterns[1], page_offsets[3] * PAGE_SIZE_USIZE) + .unwrap(); + shared_mem + .copy_from_slice(&modified_patterns[2], page_offsets[5] * PAGE_SIZE_USIZE) + .unwrap(); + + // Verify that untracked pages (0, 2, 4) remain unchanged + let unchanged_pages = [0, 2, 4]; + for &page_idx in &unchanged_pages { + let mut current_data = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut current_data, page_offsets[page_idx] * PAGE_SIZE_USIZE) + .unwrap(); + assert_eq!( + current_data, initial_patterns[page_idx], + "Page {} should remain unchanged after modification", + page_idx + ); + } + + // Verify that tracked pages were modified + let changed_pages = [ + (1, &modified_patterns[0]), + (3, &modified_patterns[1]), + (5, &modified_patterns[2]), + ]; + for &(page_idx, expected) in &changed_pages { + let mut current_data = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut current_data, page_offsets[page_idx] * PAGE_SIZE_USIZE) + .unwrap(); + assert_eq!( + current_data, *expected, + "Page {} should be modified", + page_idx + ); + } + + // Restore from snapshot + snapshot_manager + .restore_from_snapshot(&mut shared_mem, &dirty_pages) + .unwrap(); + + // Verify tracked pages are restored to their original state + for &page_idx in &tracked_pages { + let mut restored_data = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut restored_data, page_offsets[page_idx] * PAGE_SIZE_USIZE) + .unwrap(); + assert_eq!( + restored_data, initial_patterns[page_idx], + "Page {} should be restored to initial pattern after snapshot restore", + page_idx + ); + } + + // Test partial dirty bitmap restoration + let mut partial_dirty = new_page_bitmap(shared_mem.mem_size(), false).unwrap(); + // Only mark page 1 as dirty for restoration + let page1_actual = page_offsets[1]; + let block = page1_actual / 64; + let bit = page1_actual % 64; + if block < partial_dirty.len() { + partial_dirty[block] |= 1 << bit; + } + + // Modify multiple pages again + shared_mem + .copy_from_slice(&modified_patterns[0], page_offsets[1] * PAGE_SIZE_USIZE) + .unwrap(); + shared_mem + .copy_from_slice(&modified_patterns[1], page_offsets[3] * PAGE_SIZE_USIZE) + .unwrap(); + + // Restore with partial dirty bitmap (only page 1) + snapshot_manager + .restore_from_snapshot(&mut shared_mem, &partial_dirty) + .unwrap(); + + // Verify page 1 is restored but page 3 remains modified + let mut page1_data = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut page1_data, page_offsets[1] * PAGE_SIZE_USIZE) + .unwrap(); + assert_eq!(page1_data, initial_patterns[1], "Page 1 should be restored"); + + let mut page3_data = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut page3_data, page_offsets[3] * PAGE_SIZE_USIZE) + .unwrap(); 
+ assert_eq!( + page3_data, modified_patterns[1], + "Page 3 should remain modified since it wasn't in restoration dirty bitmap" + ); + + // Verify all other pages remain in their expected state + for page_idx in [0, 2, 4, 5] { + let mut current_data = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut current_data, page_offsets[page_idx] * PAGE_SIZE_USIZE) + .unwrap(); + assert_eq!( + current_data, initial_patterns[page_idx], + "Page {} should remain in initial state", + page_idx + ); + } + } +} diff --git a/src/hyperlight_host/src/mem/windows_dirty_page_tracker.rs b/src/hyperlight_host/src/mem/windows_dirty_page_tracker.rs new file mode 100644 index 000000000..5177f7f92 --- /dev/null +++ b/src/hyperlight_host/src/mem/windows_dirty_page_tracker.rs @@ -0,0 +1,68 @@ +/* +Copyright 2025 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +use std::sync::Arc; + +use hyperlight_common::mem::PAGE_SIZE_USIZE; +use tracing::{Span, instrument}; + +use super::bitmap::{bit_index_iterator, new_page_bitmap}; +use super::dirty_page_tracking::DirtyPageTracking; +use super::shared_mem::HostMapping; +use crate::Result; + +/// Windows implementation of dirty page tracking +#[derive(Debug)] +pub struct WindowsDirtyPageTracker { + size: usize, +} + +// DirtyPageTracker should be Send because: +// 1. The Arc ensures the memory stays valid +// 2. The tracker handles synchronization properly +// 3. 
This is needed for threaded sandbox initialization +unsafe impl Send for WindowsDirtyPageTracker {} + +impl WindowsDirtyPageTracker { + /// Create a new Windows dirty page tracker + #[instrument(skip_all, parent = Span::current(), level = "Trace")] + pub fn new(mapping: Arc<HostMapping>) -> Result<Self> { + let size = mapping.size - 2 * PAGE_SIZE_USIZE; // Exclude guard pages at start and end + + Ok(Self { size }) + } +} + +impl DirtyPageTracking for WindowsDirtyPageTracker { + #[cfg(test)] + fn get_dirty_pages(&self) -> Result<Vec<usize>> { + let bitmap = new_page_bitmap(self.size, true)?; + Ok(bit_index_iterator(&bitmap).collect()) + } + + fn uninstall(self) -> Result<Vec<usize>> { + let bitmap = new_page_bitmap(self.size, true)?; + Ok(bit_index_iterator(&bitmap).collect()) + } +} + +impl WindowsDirtyPageTracker { + /// Stop tracking dirty pages and return the list of dirty pages + pub fn stop_tracking_and_get_dirty_pages(self) -> Result<Vec<usize>> { + let bitmap = new_page_bitmap(self.size, true)?; + Ok(bit_index_iterator(&bitmap).collect()) + } +} diff --git a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs index c2959262b..fc6ffe56a 100644 --- a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs +++ b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs @@ -270,7 +270,8 @@ impl MultiUseSandbox { #[instrument(err(Debug), skip_all, parent = Span::current(), level = "Trace")] pub(crate) fn restore_state(&mut self) -> Result<()> { let mem_mgr = self.mem_mgr.unwrap_mgr_mut(); - let rgns_to_unmap = mem_mgr.restore_state_from_last_snapshot()?; + let (dirty_sandbox_pages, _) = self.vm.get_and_clear_dirty_pages()?; + let rgns_to_unmap = mem_mgr.restore_state_from_last_snapshot(&dirty_sandbox_pages)?; unsafe { self.vm.unmap_regions(rgns_to_unmap)? }; Ok(()) } @@ -354,10 +355,11 @@ impl DevolvableSandbox) -> Result { + let (dirty_sandbox_pages, _) = self.vm.get_and_clear_dirty_pages()?; let rgns_to_unmap = self .mem_mgr .unwrap_mgr_mut() - .pop_and_restore_state_from_snapshot()?; + .pop_and_restore_state_from_snapshot(&dirty_sandbox_pages)?; unsafe { self.vm.unmap_regions(rgns_to_unmap)?
}; Ok(self) } @@ -389,7 +391,10 @@ where let mut ctx = self.new_call_context(); transition_func.call(&mut ctx)?; let mut sbox = ctx.finish_no_reset(); - sbox.mem_mgr.unwrap_mgr_mut().push_state()?; + let (dirty_sandbox_pages, _) = sbox.vm.get_and_clear_dirty_pages()?; + sbox.mem_mgr + .unwrap_mgr_mut() + .push_state(&dirty_sandbox_pages)?; Ok(sbox) } } @@ -736,6 +741,7 @@ mod tests { let len = src.len().div_ceil(PAGE_SIZE_USIZE) * PAGE_SIZE_USIZE; let mut mem = ExclusiveSharedMemory::new(len).unwrap(); + mem.stop_tracking_dirty_pages().unwrap(); mem.copy_from_slice(src, 0).unwrap(); let (_, guest_mem) = mem.build(); diff --git a/src/hyperlight_host/src/sandbox/uninitialized.rs b/src/hyperlight_host/src/sandbox/uninitialized.rs index e27f91ff2..fd5dc1341 100644 --- a/src/hyperlight_host/src/sandbox/uninitialized.rs +++ b/src/hyperlight_host/src/sandbox/uninitialized.rs @@ -250,17 +250,15 @@ impl UninitializedSandbox { } }; - let mut mem_mgr_wrapper = { - let mut mgr = UninitializedSandbox::load_guest_binary( - sandbox_cfg, - &guest_binary, - guest_blob.as_ref(), - )?; - - let stack_guard = Self::create_stack_guard(); - mgr.set_stack_guard(&stack_guard)?; - MemMgrWrapper::new(mgr, stack_guard) - }; + let mut mgr = UninitializedSandbox::load_guest_binary( + sandbox_cfg, + &guest_binary, + guest_blob.as_ref(), + )?; + + let stack_guard = Self::create_stack_guard(); + mgr.set_stack_guard(&stack_guard)?; + let mut mem_mgr_wrapper = MemMgrWrapper::new(mgr, stack_guard); mem_mgr_wrapper.write_memory_layout()?; diff --git a/src/hyperlight_host/src/sandbox/uninitialized_evolve.rs b/src/hyperlight_host/src/sandbox/uninitialized_evolve.rs index a37f747e2..e5fffca07 100644 --- a/src/hyperlight_host/src/sandbox/uninitialized_evolve.rs +++ b/src/hyperlight_host/src/sandbox/uninitialized_evolve.rs @@ -29,12 +29,11 @@ use crate::HyperlightError::NoHypervisorFound; use crate::hypervisor::Hypervisor; use crate::hypervisor::handlers::{MemAccessHandlerCaller, OutBHandlerCaller}; use crate::mem::layout::SandboxMemoryLayout; +use crate::mem::memory_region::MemoryRegion; use crate::mem::mgr::SandboxMemoryManager; use crate::mem::ptr::{GuestPtr, RawPtr}; use crate::mem::ptr_offset::Offset; -use crate::mem::shared_mem::GuestSharedMemory; -#[cfg(feature = "init-paging")] -use crate::mem::shared_mem::SharedMemory; +use crate::mem::shared_mem::{GuestSharedMemory, SharedMemory}; #[cfg(gdb)] use crate::sandbox::config::DebugInfo; use crate::sandbox::host_funcs::FunctionRegistry; @@ -70,15 +69,11 @@ where Arc>, Arc>, RawPtr, + &[usize], // dirty host pages (indices, not bitmap) ) -> Result, { let (hshm, mut gshm) = u_sbox.mgr.build(); - let mut vm = set_up_hypervisor_partition( - &mut gshm, - &u_sbox.config, - #[cfg(any(crashdump, gdb))] - &u_sbox.rt_cfg, - )?; + let outb_hdl = outb_handler_wrapper(hshm.clone(), u_sbox.host_funcs.clone()); let seed = { @@ -99,6 +94,32 @@ where #[cfg(target_os = "linux")] setup_signal_handlers(&u_sbox.config)?; + let regions = gshm.layout.get_memory_regions(&gshm.shared_mem)?; + + // Set up shared memory before stopping dirty page tracking to ensure page table setup is tracked + #[cfg(feature = "init-paging")] + let rsp_ptr = { + let rsp_u64 = gshm.set_up_page_tables(®ions)?; + let rsp_raw = RawPtr::from(rsp_u64); + GuestPtr::try_from(rsp_raw) + }?; + #[cfg(not(feature = "init-paging"))] + let rsp_ptr = GuestPtr::try_from(Offset::from(0))?; + + // before entering VM (and before mapping memory into VM), stop tracking dirty pages from the host side + let dirty_host_pages_idx = gshm + 
.get_shared_mem_mut() + .with_exclusivity(|e| e.stop_tracking_dirty_pages())??; + + let mut vm = set_up_hypervisor_partition( + &mut gshm, + &u_sbox.config, + rsp_ptr, + regions, + #[cfg(any(crashdump, gdb))] + &u_sbox.rt_cfg, + )?; + vm.initialise( peb_addr, seed, @@ -122,6 +143,7 @@ where outb_hdl, mem_access_hdl, RawPtr::from(dispatch_function_addr), + &dirty_host_pages_idx, ) } @@ -129,9 +151,15 @@ where pub(super) fn evolve_impl_multi_use(u_sbox: UninitializedSandbox) -> Result<MultiUseSandbox> { evolve_impl( u_sbox, - |hf, mut hshm, vm, out_hdl, mem_hdl, dispatch_ptr| { + |hf, mut hshm, mut vm, out_hdl, mem_hdl, dispatch_ptr, host_dirty_pages_idx| { { - hshm.as_mut().push_state()?; + let (sandbox_dirty_pages_bitmap, _) = vm.get_and_clear_dirty_pages()?; + let layout = hshm.unwrap_mgr().layout; + hshm.as_mut().create_initial_snapshot( + &sandbox_dirty_pages_bitmap, + host_dirty_pages_idx, + &layout, + )?; } Ok(MultiUseSandbox::from_uninit( hf, @@ -150,19 +178,10 @@ pub(super) fn evolve_impl_multi_use(u_sbox: UninitializedSandbox) -> Result<MultiUseSandbox> pub(crate) fn set_up_hypervisor_partition( mgr: &mut SandboxMemoryManager<GuestSharedMemory>, #[cfg_attr(target_os = "windows", allow(unused_variables))] config: &SandboxConfiguration, + rsp_ptr: GuestPtr, + regions: Vec<MemoryRegion>, #[cfg(any(crashdump, gdb))] rt_cfg: &SandboxRuntimeConfig, ) -> Result<Box<dyn Hypervisor>> { - #[cfg(feature = "init-paging")] - let rsp_ptr = { - let mut regions = mgr.layout.get_memory_regions(&mgr.shared_mem)?; - let mem_size = u64::try_from(mgr.shared_mem.mem_size())?; - let rsp_u64 = mgr.set_up_shared_memory(mem_size, &mut regions)?; - let rsp_raw = RawPtr::from(rsp_u64); - GuestPtr::try_from(rsp_raw) - }?; - #[cfg(not(feature = "init-paging"))] - let rsp_ptr = GuestPtr::try_from(Offset::from(0))?; - let regions = mgr.layout.get_memory_regions(&mgr.shared_mem)?; let base_ptr = GuestPtr::try_from(Offset::from(0))?; let pml4_ptr = { let pml4_offset_u64 = u64::try_from(SandboxMemoryLayout::PML4_OFFSET)?; diff --git a/src/hyperlight_host/tests/integration_test.rs b/src/hyperlight_host/tests/integration_test.rs index 2db32b4a0..200c53793 100644 --- a/src/hyperlight_host/tests/integration_test.rs +++ b/src/hyperlight_host/tests/integration_test.rs @@ -786,3 +786,32 @@ fn log_test_messages(levelfilter: Option) { .unwrap(); } } + +#[test] +// Test to ensure that the state of a sandbox is reset after each function call. +// This uses the simpleguest and calls the "echo" function 1000 times with a 64-character string. +// The fact that we can successfully call the function 1000 times and get consistent +// results indicates that the sandbox state is being properly reset between calls. +// If there were state leaks, we would expect to see failures or inconsistent behavior +// as the calls accumulate; specifically, the input buffer would fill up and cause an error. +// If the default size of the input buffer is changed, this test should be updated accordingly. +fn sandbox_state_reset_between_calls() { + let mut sbox = new_uninit().unwrap().evolve(Noop::default()).unwrap(); + + // Create a 64-character test string + let test_string = "A".repeat(64); + + // Call the echo function 1000 times + for i in 0..1000 { + let result = sbox + .call_guest_function_by_name::<String>("Echo", test_string.clone()) + .unwrap(); + + // Verify that the echo function returns the same string we sent + assert_eq!( + result, test_string, + "Echo function returned unexpected result on iteration {}: expected '{}', got '{}'", + i, test_string, result + ); + } +}
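Note on the bitmap convention used throughout the tests above: a dirty-page bitmap is a `Vec<u64>` in which each u64 "block" covers 64 pages, so a page index maps to `block = page_idx / 64` and `bit = page_idx % 64`, while the trackers hand the same information back as a `Vec<usize>` of page indices. The sketch below illustrates that round trip in isolation; it assumes 4 KiB pages, and the helper names are illustrative only, not the crate's actual `new_page_bitmap` / `bit_index_iterator` API.

```rust
// Standalone sketch of the page-index <-> bitmap convention mirrored by the tests above.
// Assumptions: 4 KiB pages and 64 pages per u64 block; helper names are illustrative only.
const PAGES_PER_BLOCK: usize = 64;
const PAGE_SIZE: usize = 4096;

/// Allocate an all-clear bitmap large enough to cover `mem_size` bytes.
fn empty_bitmap(mem_size: usize) -> Vec<u64> {
    let pages = mem_size.div_ceil(PAGE_SIZE);
    vec![0u64; pages.div_ceil(PAGES_PER_BLOCK)]
}

/// Mark a single page index as dirty (the `/ 64` and `% 64` arithmetic from the tests).
fn mark_dirty(bitmap: &mut [u64], page_idx: usize) {
    let (block, bit) = (page_idx / PAGES_PER_BLOCK, page_idx % PAGES_PER_BLOCK);
    if block < bitmap.len() {
        bitmap[block] |= 1 << bit;
    }
}

/// Recover the dirty page indices from the bitmap, i.e. the `Vec<usize>` form returned by the trackers.
fn dirty_indices(bitmap: &[u64]) -> Vec<usize> {
    bitmap
        .iter()
        .enumerate()
        .flat_map(|(block, &bits)| {
            (0..PAGES_PER_BLOCK)
                .filter(move |&bit| bits & (1u64 << bit) != 0)
                .map(move |bit| block * PAGES_PER_BLOCK + bit)
        })
        .collect()
}

fn main() {
    let mut bm = empty_bitmap(1 << 20); // 1 MiB => 256 pages => 4 blocks
    mark_dirty(&mut bm, 3); // block 0, bit 3
    mark_dirty(&mut bm, 130); // block 2, bit 2
    assert_eq!(dirty_indices(&bm), vec![3, 130]);
    println!("dirty pages: {:?}", dirty_indices(&bm));
}
```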