diff --git a/src/hyperlight_host/src/hypervisor/hyperv_linux.rs b/src/hyperlight_host/src/hypervisor/hyperv_linux.rs
index 62bce6425..0f8a68521 100644
--- a/src/hyperlight_host/src/hypervisor/hyperv_linux.rs
+++ b/src/hyperlight_host/src/hypervisor/hyperv_linux.rs
@@ -75,6 +75,7 @@ use super::{
 use super::{HyperlightExit, Hypervisor, InterruptHandle, LinuxInterruptHandle, VirtualCPU};
 #[cfg(gdb)]
 use crate::HyperlightError;
+use crate::hypervisor::get_memory_access_violation;
 use crate::mem::memory_region::{MemoryRegion, MemoryRegionFlags};
 use crate::mem::ptr::{GuestPtr, RawPtr};
 use crate::mem::shared_mem::HostSharedMemory;
@@ -312,12 +313,15 @@ pub(crate) struct HypervLinuxDriver {
     page_size: usize,
     vm_fd: VmFd,
     vcpu_fd: VcpuFd,
-    entrypoint: u64,
-    mem_regions: Vec<MemoryRegion>,
     orig_rsp: GuestPtr,
+    entrypoint: u64,
     interrupt_handle: Arc<LinuxInterruptHandle>,
     mem_mgr: Option<MemMgrWrapper<HostSharedMemory>>,
     host_funcs: Option<Arc<Mutex<FunctionRegistry>>>,
+
+    sandbox_regions: Vec<MemoryRegion>, // Initially mapped regions when sandbox is created
+    mmap_regions: Vec<MemoryRegion>,    // Later mapped regions
+
     #[cfg(gdb)]
     debug: Option<MshvDebug>,
     #[cfg(gdb)]
@@ -447,7 +451,8 @@ impl HypervLinuxDriver {
             page_size: 0,
             vm_fd,
             vcpu_fd,
-            mem_regions,
+            sandbox_regions: mem_regions,
+            mmap_regions: Vec::new(),
             entrypoint: entrypoint_ptr.absolute()?,
             orig_rsp: rsp_ptr,
             interrupt_handle: interrupt_handle.clone(),
@@ -540,8 +545,11 @@ impl Debug for HypervLinuxDriver {
         f.field("Entrypoint", &self.entrypoint)
             .field("Original RSP", &self.orig_rsp);

-        for region in &self.mem_regions {
-            f.field("Memory Region", &region);
+        for region in &self.sandbox_regions {
+            f.field("Sandbox Memory Region", &region);
+        }
+        for region in &self.mmap_regions {
+            f.field("Mapped Memory Region", &region);
         }

         let regs = self.vcpu_fd.get_regs();
@@ -631,20 +639,24 @@ impl Hypervisor for HypervLinuxDriver {
         }
         let mshv_region: mshv_user_mem_region = rgn.to_owned().into();
         self.vm_fd.map_user_memory(mshv_region)?;
-        self.mem_regions.push(rgn.to_owned());
+        self.mmap_regions.push(rgn.to_owned());
         Ok(())
     }

     #[instrument(err(Debug), skip_all, parent = Span::current(), level = "Trace")]
-    unsafe fn unmap_regions(&mut self, n: u64) -> Result<()> {
-        for rgn in self
-            .mem_regions
-            .split_off(self.mem_regions.len() - n as usize)
-        {
-            let mshv_region: mshv_user_mem_region = rgn.to_owned().into();
+    unsafe fn unmap_region(&mut self, region: &MemoryRegion) -> Result<()> {
+        if let Some(pos) = self.mmap_regions.iter().position(|r| r == region) {
+            let removed_region = self.mmap_regions.remove(pos);
+            let mshv_region: mshv_user_mem_region = removed_region.into();
             self.vm_fd.unmap_user_memory(mshv_region)?;
+            Ok(())
+        } else {
+            Err(new_error!("Tried to unmap region that is not mapped"))
         }
-        Ok(())
+    }
+
+    fn get_mapped_regions(&self) -> Box<dyn ExactSizeIterator<Item = &MemoryRegion> + '_> {
+        Box::new(self.mmap_regions.iter())
     }

     #[instrument(err(Debug), skip_all, parent = Span::current(), level = "Trace")]
@@ -867,9 +879,9 @@ impl Hypervisor for HypervLinuxDriver {
                     gpa, &self
                 );
-                match self.get_memory_access_violation(
+                match get_memory_access_violation(
                     gpa as usize,
-                    &self.mem_regions,
+                    self.sandbox_regions.iter().chain(self.mmap_regions.iter()),
                     access_info,
                 ) {
                     Some(access_info_violation) => access_info_violation,
@@ -999,7 +1011,7 @@ impl Hypervisor for HypervLinuxDriver {
         });

         Ok(Some(crashdump::CrashDumpContext::new(
-            &self.mem_regions,
+            &self.sandbox_regions,
             regs,
             xsave.buffer.to_vec(),
             self.entrypoint,
@@ -1180,7 +1192,7 @@ impl Drop for HypervLinuxDriver {
     #[instrument(skip_all, parent = Span::current(), level = "Trace")]
     fn drop(&mut self) {
         self.interrupt_handle.dropped.store(true, Ordering::Relaxed);
-        for region in &self.mem_regions {
+        for region in self.sandbox_regions.iter().chain(self.mmap_regions.iter()) {
             let mshv_region: mshv_user_mem_region = region.to_owned().into();
             match self.vm_fd.unmap_user_memory(mshv_region) {
                 Ok(_) => (),
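Note: the mshv driver now removes mappings by value rather than by count. A minimal standalone sketch of that lookup-and-remove pattern follows; the `Region` and `Driver` types are simplified stand-ins, and the actual `vm_fd.unmap_user_memory` call is elided.

```rust
#[derive(Debug, Clone, PartialEq, Eq)]
struct Region {
    start: usize,
    len: usize,
}

struct Driver {
    mmap_regions: Vec<Region>,
}

impl Driver {
    /// Remove `region` by value: find the first equal entry, take it out,
    /// then hand it to the (elided) hypervisor unmap call.
    fn unmap_region(&mut self, region: &Region) -> Result<(), String> {
        if let Some(pos) = self.mmap_regions.iter().position(|r| r == region) {
            let removed = self.mmap_regions.remove(pos);
            // the real driver would convert `removed` and call the VM fd here
            drop(removed);
            Ok(())
        } else {
            Err("Tried to unmap region that is not mapped".to_string())
        }
    }
}

fn main() {
    let region = Region { start: 0x2000, len: 0x1000 };
    let mut d = Driver { mmap_regions: vec![region.clone()] };
    assert!(d.unmap_region(&region).is_ok());
    assert!(d.unmap_region(&region).is_err()); // a second unmap must fail
}
```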
diff --git a/src/hyperlight_host/src/hypervisor/hyperv_windows.rs b/src/hyperlight_host/src/hypervisor/hyperv_windows.rs
index d4a0c06cd..0fc701a13 100644
--- a/src/hyperlight_host/src/hypervisor/hyperv_windows.rs
+++ b/src/hyperlight_host/src/hypervisor/hyperv_windows.rs
@@ -57,6 +57,7 @@ use super::{
 };
 use super::{HyperlightExit, Hypervisor, InterruptHandle, VirtualCPU};
 use crate::hypervisor::fpu::FP_CONTROL_WORD_DEFAULT;
+use crate::hypervisor::get_memory_access_violation;
 use crate::hypervisor::wrappers::WHvGeneralRegisters;
 use crate::mem::memory_region::{MemoryRegion, MemoryRegionFlags};
 use crate::mem::ptr::{GuestPtr, RawPtr};
@@ -281,10 +282,13 @@ pub(crate) struct HypervWindowsDriver {
     _surrogate_process: SurrogateProcess, // we need to keep a reference to the SurrogateProcess for the duration of the driver since otherwise it will be dropped, the memory mapping will be unmapped, and the surrogate process will be returned to the pool
     entrypoint: u64,
     orig_rsp: GuestPtr,
-    mem_regions: Vec<MemoryRegion>,
     interrupt_handle: Arc<WindowsInterruptHandle>,
     mem_mgr: Option<MemMgrWrapper<HostSharedMemory>>,
     host_funcs: Option<Arc<Mutex<FunctionRegistry>>>,
+
+    sandbox_regions: Vec<MemoryRegion>, // Initially mapped regions when sandbox is created
+    mmap_regions: Vec<MemoryRegion>,    // Later mapped regions
+
     #[cfg(gdb)]
     debug: Option<HypervDebug>,
     #[cfg(gdb)]
@@ -358,7 +362,8 @@ impl HypervWindowsDriver {
             _surrogate_process: surrogate_process,
             entrypoint,
             orig_rsp: GuestPtr::try_from(RawPtr::from(rsp))?,
-            mem_regions,
+            sandbox_regions: mem_regions,
+            mmap_regions: Vec::new(),
             interrupt_handle: interrupt_handle.clone(),
             mem_mgr: None,
             host_funcs: None,
@@ -457,8 +462,11 @@ impl Debug for HypervWindowsDriver {
         fs.field("Entrypoint", &self.entrypoint)
             .field("Original RSP", &self.orig_rsp);

-        for region in &self.mem_regions {
-            fs.field("Memory Region", &region);
+        for region in &self.sandbox_regions {
+            fs.field("Sandbox Memory Region", &region);
+        }
+        for region in &self.mmap_regions {
+            fs.field("Mapped Memory Region", &region);
         }

         // Get the registers
@@ -631,18 +639,17 @@ impl Hypervisor for HypervWindowsDriver {
     }

     #[instrument(err(Debug), skip_all, parent = Span::current(), level = "Trace")]
-    unsafe fn map_region(&mut self, _rgn: &MemoryRegion) -> Result<()> {
+    unsafe fn map_region(&mut self, _region: &MemoryRegion) -> Result<()> {
         log_then_return!("Mapping host memory into the guest not yet supported on this platform");
     }

     #[instrument(err(Debug), skip_all, parent = Span::current(), level = "Trace")]
-    unsafe fn unmap_regions(&mut self, n: u64) -> Result<()> {
-        if n > 0 {
-            log_then_return!(
-                "Mapping host memory into the guest not yet supported on this platform"
-            );
-        }
-        Ok(())
+    unsafe fn unmap_region(&mut self, _region: &MemoryRegion) -> Result<()> {
+        log_then_return!("Mapping host memory into the guest not yet supported on this platform");
+    }
+
+    fn get_mapped_regions(&self) -> Box<dyn ExactSizeIterator<Item = &MemoryRegion> + '_> {
+        Box::new(self.mmap_regions.iter())
     }

     #[instrument(err(Debug), skip_all, parent = Span::current(), level = "Trace")]
@@ -824,8 +831,11 @@ impl Hypervisor for HypervWindowsDriver {
                     gpa, access_info, &self
                 );

-                match self.get_memory_access_violation(gpa as usize, &self.mem_regions, access_info)
-                {
+                match get_memory_access_violation(
+                    gpa as usize,
+                    self.sandbox_regions.iter().chain(self.mmap_regions.iter()),
+                    access_info,
+                ) {
                     Some(access_info) => access_info,
                     None => HyperlightExit::Mmio(gpa),
                 }
@@ -934,7 +944,7 @@ impl Hypervisor for HypervWindowsDriver {
         });

         Ok(Some(crashdump::CrashDumpContext::new(
-            &self.mem_regions,
+            &self.sandbox_regions,
             regs,
             xsave,
             self.entrypoint,
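Note: on every driver the access-violation check now walks the sandbox regions and the dynamically mapped regions in a single pass via `Iterator::chain`. A small sketch of why one `find` over the chained iterators suffices; plain `Range<usize>` stands in for `MemoryRegion` here.

```rust
use std::ops::Range;

fn find_region<'a>(
    sandbox: &'a [Range<usize>],
    mmap: &'a [Range<usize>],
    gpa: usize,
) -> Option<&'a Range<usize>> {
    // chain() yields all sandbox regions first, then the dynamically mapped
    // ones, so a single find() covers both lists without building a merged Vec
    sandbox.iter().chain(mmap.iter()).find(|r| r.contains(&gpa))
}

fn main() {
    let sandbox = vec![0x0..0x1000];
    let mmap = vec![0x2000..0x3000];
    assert!(find_region(&sandbox, &mmap, 0x2800).is_some()); // in a mapped region
    assert!(find_region(&sandbox, &mmap, 0x1800).is_none()); // in neither list
}
```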
diff --git a/src/hyperlight_host/src/hypervisor/kvm.rs b/src/hyperlight_host/src/hypervisor/kvm.rs
index 8a2de0d40..26eae3e25 100644
--- a/src/hyperlight_host/src/hypervisor/kvm.rs
+++ b/src/hyperlight_host/src/hypervisor/kvm.rs
@@ -42,6 +42,7 @@ use super::{
 use super::{HyperlightExit, Hypervisor, InterruptHandle, LinuxInterruptHandle, VirtualCPU};
 #[cfg(gdb)]
 use crate::HyperlightError;
+use crate::hypervisor::get_memory_access_violation;
 use crate::mem::memory_region::{MemoryRegion, MemoryRegionFlags};
 use crate::mem::ptr::{GuestPtr, RawPtr};
 use crate::mem::shared_mem::HostSharedMemory;
@@ -294,10 +295,15 @@ pub(crate) struct KVMDriver {
     vcpu_fd: VcpuFd,
     entrypoint: u64,
     orig_rsp: GuestPtr,
-    mem_regions: Vec<MemoryRegion>,
     interrupt_handle: Arc<LinuxInterruptHandle>,
     mem_mgr: Option<MemMgrWrapper<HostSharedMemory>>,
     host_funcs: Option<Arc<Mutex<FunctionRegistry>>>,
+
+    sandbox_regions: Vec<MemoryRegion>, // Initially mapped regions when sandbox is created
+    mmap_regions: Vec<(MemoryRegion, u32)>, // Later mapped regions (region, slot number)
+    next_slot: u32, // Monotonically increasing slot number
+    freed_slots: Vec<u32>, // Reusable slots from unmapped regions
+
     #[cfg(gdb)]
     debug: Option<KvmDebug>,
     #[cfg(gdb)]
@@ -384,7 +390,10 @@ impl KVMDriver {
             vcpu_fd,
             entrypoint,
             orig_rsp: rsp_gp,
-            mem_regions,
+            next_slot: mem_regions.len() as u32,
+            sandbox_regions: mem_regions,
+            mmap_regions: Vec::new(),
+            freed_slots: Vec::new(),
             interrupt_handle: interrupt_handle.clone(),
             mem_mgr: None,
             host_funcs: None,
@@ -434,8 +443,11 @@ impl Debug for KVMDriver {
         let mut f = f.debug_struct("KVM Driver");

         // Output each memory region
-        for region in &self.mem_regions {
-            f.field("Memory Region", &region);
+        for region in &self.sandbox_regions {
+            f.field("Sandbox Memory Region", &region);
+        }
+        for region in &self.mmap_regions {
+            f.field("Mapped Memory Region", &region);
         }

         let regs = self.vcpu_fd.get_regs();
         // check that regs is OK and then set field in debug struct
@@ -517,25 +529,45 @@ impl Hypervisor for KVMDriver {
         }

         let mut kvm_region: kvm_userspace_memory_region = region.clone().into();
-        kvm_region.slot = self.mem_regions.len() as u32;
+
+        // Try to reuse a freed slot first, otherwise use next_slot
+        let slot = if let Some(freed_slot) = self.freed_slots.pop() {
+            freed_slot
+        } else {
+            let slot = self.next_slot;
+            self.next_slot += 1;
+            slot
+        };
+
+        kvm_region.slot = slot;
         unsafe { self.vm_fd.set_user_memory_region(kvm_region) }?;
-        self.mem_regions.push(region.to_owned());
+        self.mmap_regions.push((region.to_owned(), slot));
         Ok(())
     }

     #[instrument(err(Debug), skip_all, parent = Span::current(), level = "Trace")]
-    unsafe fn unmap_regions(&mut self, n: u64) -> Result<()> {
-        let n_keep = self.mem_regions.len() - n as usize;
-        for (k, region) in self.mem_regions.split_off(n_keep).iter().enumerate() {
-            let mut kvm_region: kvm_userspace_memory_region = region.clone().into();
-            kvm_region.slot = (n_keep + k) as u32;
+    unsafe fn unmap_region(&mut self, region: &MemoryRegion) -> Result<()> {
+        if let Some(idx) = self.mmap_regions.iter().position(|(r, _)| r == region) {
+            let (region, slot) = self.mmap_regions.remove(idx);
+            let mut kvm_region: kvm_userspace_memory_region = region.into();
+            kvm_region.slot = slot;
             // Setting memory_size to 0 unmaps the slot's region
             // From https://docs.kernel.org/virt/kvm/api.html
             // > Deleting a slot is done by passing zero for memory_size.
             kvm_region.memory_size = 0;
             unsafe { self.vm_fd.set_user_memory_region(kvm_region) }?;
+
+            // Add the freed slot to the reuse list
+            self.freed_slots.push(slot);
+
+            Ok(())
+        } else {
+            Err(new_error!("Tried to unmap region that is not mapped"))
         }
-        Ok(())
+    }
+
+    fn get_mapped_regions(&self) -> Box<dyn ExactSizeIterator<Item = &MemoryRegion> + '_> {
+        Box::new(self.mmap_regions.iter().map(|(region, _)| region))
     }

     #[instrument(err(Debug), skip_all, parent = Span::current(), level = "Trace")]
@@ -717,9 +749,11 @@ impl Hypervisor for KVMDriver {
             Ok(VcpuExit::MmioRead(addr, _)) => {
                 crate::debug!("KVM MMIO Read -Details: Address: {} \n {:#?}", addr, &self);

-                match self.get_memory_access_violation(
+                match get_memory_access_violation(
                     addr as usize,
-                    &self.mem_regions,
+                    self.sandbox_regions
+                        .iter()
+                        .chain(self.mmap_regions.iter().map(|(r, _)| r)),
                     MemoryRegionFlags::READ,
                 ) {
                     Some(access_violation_exit) => access_violation_exit,
@@ -729,9 +763,11 @@ impl Hypervisor for KVMDriver {
             Ok(VcpuExit::MmioWrite(addr, _)) => {
                 crate::debug!("KVM MMIO Write -Details: Address: {} \n {:#?}", addr, &self);

-                match self.get_memory_access_violation(
+                match get_memory_access_violation(
                     addr as usize,
-                    &self.mem_regions,
+                    self.sandbox_regions
+                        .iter()
+                        .chain(self.mmap_regions.iter().map(|(r, _)| r)),
                     MemoryRegionFlags::WRITE,
                 ) {
                     Some(access_violation_exit) => access_violation_exit,
@@ -847,7 +883,7 @@ impl Hypervisor for KVMDriver {
         // The [`CrashDumpContext`] accepts xsave as a vector of u8, so we need to convert the
         // xsave region to a vector of u8
         Ok(Some(crashdump::CrashDumpContext::new(
-            &self.mem_regions,
+            &self.sandbox_regions,
             regs,
             xsave
                 .region
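Note: KVM identifies each mapping by slot number, and slots freed by `unmap_region` are recycled before `next_slot` advances. A self-contained sketch of that allocator policy; `SlotAllocator` is an assumed simplification of the driver's `next_slot`/`freed_slots` bookkeeping.

```rust
struct SlotAllocator {
    next_slot: u32,        // monotonically increasing source of fresh slots
    freed_slots: Vec<u32>, // slots returned by earlier unmaps
}

impl SlotAllocator {
    fn alloc(&mut self) -> u32 {
        if let Some(slot) = self.freed_slots.pop() {
            slot // prefer a recycled slot
        } else {
            let slot = self.next_slot;
            self.next_slot += 1;
            slot
        }
    }

    fn free(&mut self, slot: u32) {
        self.freed_slots.push(slot);
    }
}

fn main() {
    // if the sandbox regions occupy slots 0..3, dynamic slots start at 3
    let mut a = SlotAllocator { next_slot: 3, freed_slots: Vec::new() };
    let s0 = a.alloc();
    assert_eq!(s0, 3);
    a.free(s0);
    assert_eq!(a.alloc(), 3); // the freed slot is reused first
    assert_eq!(a.alloc(), 4); // only then does next_slot advance
}
```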
diff --git a/src/hyperlight_host/src/hypervisor/mod.rs b/src/hyperlight_host/src/hypervisor/mod.rs
index 0fe452b6e..53b721c7c 100644
--- a/src/hyperlight_host/src/hypervisor/mod.rs
+++ b/src/hyperlight_host/src/hypervisor/mod.rs
@@ -157,8 +157,13 @@ pub(crate) trait Hypervisor: Debug + Send {
     /// requirements of at least one page for base and len.
     unsafe fn map_region(&mut self, rgn: &MemoryRegion) -> Result<()>;

-    /// Unmap the most recent `n` regions mapped by `map_region`
-    unsafe fn unmap_regions(&mut self, n: u64) -> Result<()>;
+    /// Unmap a memory region from the sandbox
+    unsafe fn unmap_region(&mut self, rgn: &MemoryRegion) -> Result<()>;
+
+    /// Get the currently mapped dynamic memory regions (not including sandbox regions)
+    ///
+    /// Note: Box needed for trait to be object-safe :(
+    fn get_mapped_regions(&self) -> Box<dyn ExactSizeIterator<Item = &MemoryRegion> + '_>;

     /// Dispatch a call from the host to the guest using the given pointer
     /// to the dispatch function _in the guest's address space_.
@@ -185,33 +190,6 @@ pub(crate) trait Hypervisor: Debug + Send {
     /// Run the vCPU
     fn run(&mut self) -> Result<HyperlightExit>;

-    /// Returns Some(HyperlightExit::AccessViolation(..)) if the given gpa doesn't have
-    /// access to its corresponding region. Returns None otherwise, or if the region is not found.
-    fn get_memory_access_violation(
-        &self,
-        gpa: usize,
-        mem_regions: &[MemoryRegion],
-        access_info: MemoryRegionFlags,
-    ) -> Option<HyperlightExit> {
-        // find the region containing the given gpa
-        let region = mem_regions
-            .iter()
-            .find(|region| region.guest_region.contains(&gpa));
-
-        if let Some(region) = region {
-            if !region.flags.contains(access_info)
-                || region.flags.contains(MemoryRegionFlags::STACK_GUARD)
-            {
-                return Some(HyperlightExit::AccessViolation(
-                    gpa as u64,
-                    access_info,
-                    region.flags,
-                ));
-            }
-        }
-        None
-    }
-
     /// Get InterruptHandle to underlying VM
     fn interrupt_handle(&self) -> Arc<dyn InterruptHandle>;
@@ -283,6 +261,30 @@ pub(crate) trait Hypervisor: Debug + Send {
     fn trace_info_as_mut(&mut self) -> &mut TraceInfo;
 }

+/// Returns Some(HyperlightExit::AccessViolation(..)) if the given gpa doesn't have
+/// access to its corresponding region. Returns None otherwise, or if the region is not found.
+pub(crate) fn get_memory_access_violation<'a>(
+    gpa: usize,
+    mut mem_regions: impl Iterator<Item = &'a MemoryRegion>,
+    access_info: MemoryRegionFlags,
+) -> Option<HyperlightExit> {
+    // find the region containing the given gpa
+    let region = mem_regions.find(|region| region.guest_region.contains(&gpa));
+
+    if let Some(region) = region {
+        if !region.flags.contains(access_info)
+            || region.flags.contains(MemoryRegionFlags::STACK_GUARD)
+        {
+            return Some(HyperlightExit::AccessViolation(
+                gpa as u64,
+                access_info,
+                region.flags,
+            ));
+        }
+    }
+    None
+}
+
 /// A virtual CPU that can be run until an exit occurs
 pub struct VirtualCPU {}
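Note: moving `get_memory_access_violation` out of the trait is what keeps `Hypervisor` usable as a trait object: a method generic over `impl Iterator` would make the trait not object-safe, while a boxed-iterator return type (as the "Box needed for trait to be object-safe" comment says) is fine. A compilable sketch of the pattern; the toy `Vm` trait and `u64` regions are stand-ins.

```rust
trait Vm {
    // Returning a boxed trait object keeps `Vm` object-safe while still
    // letting callers use .len() through ExactSizeIterator.
    fn mapped(&self) -> Box<dyn ExactSizeIterator<Item = &u64> + '_>;
}

struct FakeVm {
    regions: Vec<u64>,
}

impl Vm for FakeVm {
    fn mapped(&self) -> Box<dyn ExactSizeIterator<Item = &u64> + '_> {
        Box::new(self.regions.iter())
    }
}

// A free function can be generic over any iterator of regions; the same
// signature as a trait method would forbid `dyn Vm` entirely.
fn contains<'a>(mut regions: impl Iterator<Item = &'a u64>, gpa: u64) -> bool {
    regions.any(|&r| r == gpa)
}

fn main() {
    let vm: Box<dyn Vm> = Box::new(FakeVm { regions: vec![1, 2, 3] });
    assert_eq!(vm.mapped().len(), 3); // ExactSizeIterator through the Box
    assert!(contains(vm.mapped(), 2)); // boxed iterator feeds the free function
}
```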
diff --git a/src/hyperlight_host/src/mem/memory_region.rs b/src/hyperlight_host/src/mem/memory_region.rs
index f4c38b168..22f71d65b 100644
--- a/src/hyperlight_host/src/mem/memory_region.rs
+++ b/src/hyperlight_host/src/mem/memory_region.rs
@@ -52,7 +52,7 @@ pub(crate) const DEFAULT_GUEST_BLOB_MEM_FLAGS: MemoryRegionFlags = MemoryRegionF

 bitflags! {
     /// flags representing memory permission for a memory region
-    #[derive(Copy, Clone, Debug, PartialEq, Eq)]
+    #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
     pub struct MemoryRegionFlags: u32 {
         /// no permissions
         const NONE = 0;
@@ -154,7 +154,7 @@ impl TryFrom for MemoryRegionFlags {
 }

 // only used for debugging
-#[derive(Debug, PartialEq, Eq, Copy, Clone)]
+#[derive(Debug, PartialEq, Eq, Copy, Clone, Hash)]
 /// The type of memory region
 pub enum MemoryRegionType {
     /// The region contains the guest's page tables
@@ -181,7 +181,7 @@ pub enum MemoryRegionType {

 /// represents a single memory region inside the guest. All memory within a region has
 /// the same memory permissions
-#[derive(Debug, Clone, PartialEq, Eq)]
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 pub struct MemoryRegion {
     /// the range of guest memory addresses
     pub guest_region: Range<usize>,
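Note: deriving `Hash` on `MemoryRegion`, its flags, and its region type is what lets `restore` collect regions into `HashSet`s below; `Range<usize>` already implements `Hash` in std, so the derive alone is sufficient. A minimal illustration with a stand-in `Region` struct:

```rust
use std::collections::HashSet;
use std::ops::Range;

// Hash can be derived because every field type here, including
// Range<usize>, already implements Hash.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct Region {
    guest_region: Range<usize>,
    flags: u32,
}

fn main() {
    let mut set = HashSet::new();
    set.insert(Region { guest_region: 0x1000..0x2000, flags: 1 });
    // equal values hash identically, so membership tests just work
    assert!(set.contains(&Region { guest_region: 0x1000..0x2000, flags: 1 }));
    assert!(!set.contains(&Region { guest_region: 0x1000..0x2000, flags: 2 }));
}
```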
diff --git a/src/hyperlight_host/src/mem/mgr.rs b/src/hyperlight_host/src/mem/mgr.rs
index 4e4aa0112..526190d4c 100644
--- a/src/hyperlight_host/src/mem/mgr.rs
+++ b/src/hyperlight_host/src/mem/mgr.rs
@@ -27,8 +27,9 @@ use tracing::{Span, instrument};

 use super::exe::ExeInfo;
 use super::layout::SandboxMemoryLayout;
+use super::memory_region::MemoryRegion;
 #[cfg(feature = "init-paging")]
-use super::memory_region::{DEFAULT_GUEST_BLOB_MEM_FLAGS, MemoryRegion, MemoryRegionType};
+use super::memory_region::{DEFAULT_GUEST_BLOB_MEM_FLAGS, MemoryRegionType};
 use super::ptr::{GuestPtr, RawPtr};
 use super::ptr_offset::Offset;
 use super::shared_mem::{ExclusiveSharedMemory, GuestSharedMemory, HostSharedMemory, SharedMemory};
@@ -259,16 +260,16 @@ where
         }
     }

-    pub(crate) fn snapshot(&mut self) -> Result<SharedMemorySnapshot> {
-        SharedMemorySnapshot::new(&mut self.shared_mem, self.mapped_rgns)
+    /// Create a snapshot with the given mapped regions
+    pub(crate) fn snapshot(
+        &mut self,
+        mapped_regions: Vec<MemoryRegion>,
+    ) -> Result<SharedMemorySnapshot> {
+        SharedMemorySnapshot::new(&mut self.shared_mem, mapped_regions)
     }

     /// This function restores a memory snapshot from a given snapshot.
-    ///
-    /// Returns the number of memory regions mapped into the sandbox
-    /// that need to be unmapped in order for the restore to be
-    /// completed.
-    pub(crate) fn restore_snapshot(&mut self, snapshot: &SharedMemorySnapshot) -> Result<u64> {
+    pub(crate) fn restore_snapshot(&mut self, snapshot: &SharedMemorySnapshot) -> Result<()> {
         if self.shared_mem.mem_size() != snapshot.mem_size() {
             return Err(new_error!(
                 "Snapshot size does not match current memory size: {} != {}",
                 self.shared_mem.raw_mem_size(),
                 snapshot.mem_size()
             ));
         }
-        let old_rgns = self.mapped_rgns;
-        self.mapped_rgns = snapshot.restore_from_snapshot(&mut self.shared_mem)?;
-        Ok(old_rgns - self.mapped_rgns)
+        snapshot.restore_from_snapshot(&mut self.shared_mem)?;
+        Ok(())
     }

     /// Sets `addr` to the correct offset in the memory referenced by
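Note: the manager-level contract is now that the caller supplies the mapped-region list at snapshot time, and `restore_snapshot` only copies memory back, returning `()` instead of an unmap count. A toy sketch of that shape; `Manager` and `Snapshot` are stand-ins, not the real `SandboxMemoryManager`/`SharedMemorySnapshot`.

```rust
#[derive(Clone, Debug, PartialEq)]
struct Region(u64);

struct Snapshot {
    memory: Vec<u8>,
    regions: Vec<Region>, // captured alongside the memory at snapshot time
}

struct Manager {
    memory: Vec<u8>,
}

impl Manager {
    fn snapshot(&self, regions: Vec<Region>) -> Snapshot {
        Snapshot { memory: self.memory.clone(), regions }
    }

    fn restore_snapshot(&mut self, snap: &Snapshot) -> Result<(), String> {
        if self.memory.len() != snap.memory.len() {
            return Err("Snapshot size does not match current memory size".into());
        }
        self.memory.copy_from_slice(&snap.memory);
        Ok(()) // region reconciliation happens in the caller via snap.regions
    }
}

fn main() {
    let mut mgr = Manager { memory: vec![0u8; 16] };
    let snap = mgr.snapshot(vec![Region(1)]);
    mgr.memory[0] = 0xff;
    mgr.restore_snapshot(&snap).unwrap();
    assert_eq!(mgr.memory[0], 0); // memory rolled back
    assert_eq!(snap.regions, vec![Region(1)]); // regions preserved in the snapshot
}
```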
diff --git a/src/hyperlight_host/src/mem/shared_mem_snapshot.rs b/src/hyperlight_host/src/mem/shared_mem_snapshot.rs
index dfa54430c..8f55a8b25 100644
--- a/src/hyperlight_host/src/mem/shared_mem_snapshot.rs
+++ b/src/hyperlight_host/src/mem/shared_mem_snapshot.rs
@@ -16,6 +16,7 @@ limitations under the License.

 use tracing::{Span, instrument};

+use super::memory_region::MemoryRegion;
 use super::shared_mem::SharedMemory;
 use crate::Result;

@@ -24,21 +25,21 @@ use crate::Result;
 #[derive(Clone)]
 pub(crate) struct SharedMemorySnapshot {
     snapshot: Vec<u8>,
-    /// How many non-main-RAM regions were mapped when this snapshot was taken?
-    mapped_rgns: u64,
+    /// The memory regions that were mapped when this snapshot was taken (excluding initial sandbox regions)
+    regions: Vec<MemoryRegion>,
 }

 impl SharedMemorySnapshot {
     /// Take a snapshot of the memory in `shared_mem`, then create a new
     /// instance of `Self` with the snapshot stored therein.
     #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")]
-    pub(super) fn new<S: SharedMemory>(shared_mem: &mut S, mapped_rgns: u64) -> Result<Self> {
+    pub(super) fn new<S: SharedMemory>(
+        shared_mem: &mut S,
+        regions: Vec<MemoryRegion>,
+    ) -> Result<Self> {
         // TODO: Track dirty pages instead of copying entire memory
         let snapshot = shared_mem.with_exclusivity(|e| e.copy_all_to_vec())??;
-        Ok(Self {
-            snapshot,
-            mapped_rgns,
-        })
+        Ok(Self { snapshot, regions })
     }

     /// Take another snapshot of the internally-stored `SharedMemory`,
@@ -51,11 +52,16 @@ impl SharedMemorySnapshot {
     }

     /// Copy the memory from the internally-stored memory snapshot
-    /// into the internally-stored `SharedMemory`
+    /// into the internally-stored `SharedMemory`.
     #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")]
-    pub(super) fn restore_from_snapshot<S: SharedMemory>(&self, shared_mem: &mut S) -> Result<u64> {
+    pub(super) fn restore_from_snapshot<S: SharedMemory>(&self, shared_mem: &mut S) -> Result<()> {
         shared_mem.with_exclusivity(|e| e.copy_from_slice(self.snapshot.as_slice(), 0))??;
-        Ok(self.mapped_rgns)
+        Ok(())
+    }
+
+    /// Get the mapped regions from this snapshot
+    pub(crate) fn regions(&self) -> &[MemoryRegion] {
+        &self.regions
     }

     /// Return the size of the snapshot in bytes.
@@ -78,7 +84,7 @@ mod tests {
         let data2 = data1.iter().map(|b| b + 1).collect::<Vec<u8>>();
         let mut gm = ExclusiveSharedMemory::new(PAGE_SIZE_USIZE).unwrap();
         gm.copy_from_slice(data1.as_slice(), 0).unwrap();
-        let mut snap = super::SharedMemorySnapshot::new(&mut gm, 0).unwrap();
+        let mut snap = super::SharedMemorySnapshot::new(&mut gm, Vec::new()).unwrap();
         {
             // after the first snapshot is taken, make sure gm has the equivalent
             // of data1
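Note: the double `?` in `with_exclusivity(|e| ...)??` suggests a nested `Result`: one layer for obtaining exclusive access to the shared memory, one from the closure's own fallible work. A hypothetical stand-in (not the real `SharedMemory` API) showing why two `?`s are needed:

```rust
// Pretend lock-and-run helper: the outer Result models acquiring
// exclusivity, the closure's return value is passed through untouched.
fn with_exclusivity<T>(f: impl FnOnce(&mut [u8]) -> T) -> Result<T, String> {
    let mut buf = vec![0u8; 4];
    Ok(f(&mut buf))
}

fn copy_out() -> Result<Vec<u8>, String> {
    // the closure itself returns a Result, hence Result<Result<_, _>, _>
    let nested: Result<Result<Vec<u8>, String>, String> =
        with_exclusivity(|mem| Ok(mem.to_vec()));
    let flat = nested??; // first ? unwraps the lock result, second the copy
    Ok(flat)
}

fn main() {
    assert_eq!(copy_out().unwrap(), vec![0u8; 4]);
}
```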
diff --git a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs
index f9514dcd2..b21ab6361 100644
--- a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs
+++ b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */

+use std::collections::HashSet;
 #[cfg(unix)]
 use std::os::fd::AsRawFd;
 #[cfg(unix)]
@@ -88,18 +89,35 @@ impl MultiUseSandbox {
     /// Create a snapshot of the current state of the sandbox's memory.
     #[instrument(err(Debug), skip_all, parent = Span::current())]
     pub fn snapshot(&mut self) -> Result<Snapshot> {
-        let snapshot = self.mem_mgr.unwrap_mgr_mut().snapshot()?;
-        Ok(Snapshot { inner: snapshot })
+        let mapped_regions_iter = self.vm.get_mapped_regions();
+        let mapped_regions_vec: Vec<MemoryRegion> = mapped_regions_iter.cloned().collect();
+        let memory_snapshot = self.mem_mgr.unwrap_mgr_mut().snapshot(mapped_regions_vec)?;
+        Ok(Snapshot {
+            inner: memory_snapshot,
+        })
     }

     /// Restore the sandbox's memory to the state captured in the given snapshot.
     #[instrument(err(Debug), skip_all, parent = Span::current())]
     pub fn restore(&mut self, snapshot: &Snapshot) -> Result<()> {
-        let rgns_to_unmap = self
-            .mem_mgr
+        self.mem_mgr
             .unwrap_mgr_mut()
             .restore_snapshot(&snapshot.inner)?;
-        unsafe { self.vm.unmap_regions(rgns_to_unmap)? };
+
+        let current_regions: HashSet<_> = self.vm.get_mapped_regions().cloned().collect();
+        let snapshot_regions: HashSet<_> = snapshot.inner.regions().iter().cloned().collect();
+
+        let regions_to_unmap = current_regions.difference(&snapshot_regions);
+        let regions_to_map = snapshot_regions.difference(&current_regions);
+
+        for region in regions_to_unmap {
+            unsafe { self.vm.unmap_region(region)? };
+        }
+
+        for region in regions_to_map {
+            unsafe { self.vm.map_region(region)? };
+        }
+
         Ok(())
     }
@@ -694,4 +712,57 @@ mod tests {
             region_type: MemoryRegionType::Heap,
         }
     }
+
+    #[cfg(target_os = "linux")]
+    fn allocate_guest_memory() -> GuestSharedMemory {
+        page_aligned_memory(b"test data for snapshot")
+    }
+
+    #[test]
+    #[cfg(target_os = "linux")]
+    fn snapshot_restore_handles_remapping_correctly() {
+        let mut sbox: MultiUseSandbox = {
+            let path = simple_guest_as_string().unwrap();
+            let u_sbox = UninitializedSandbox::new(GuestBinary::FilePath(path), None).unwrap();
+            u_sbox.evolve().unwrap()
+        };
+
+        // 1. Take snapshot 1 with no additional regions mapped
+        let snapshot1 = sbox.snapshot().unwrap();
+        assert_eq!(sbox.vm.get_mapped_regions().len(), 0);
+
+        // 2. Map a memory region
+        let map_mem = allocate_guest_memory();
+        let guest_base = 0x200000000_usize;
+        let region = region_for_memory(&map_mem, guest_base, MemoryRegionFlags::READ);
+
+        unsafe { sbox.map_region(&region).unwrap() };
+        assert_eq!(sbox.vm.get_mapped_regions().len(), 1);
+
+        // 3. Take snapshot 2 with 1 region mapped
+        let snapshot2 = sbox.snapshot().unwrap();
+        assert_eq!(sbox.vm.get_mapped_regions().len(), 1);
+
+        // 4. Restore to snapshot 1 (should unmap the region)
+        sbox.restore(&snapshot1).unwrap();
+        assert_eq!(sbox.vm.get_mapped_regions().len(), 0);
+
+        // 5. Restore forward to snapshot 2 (should remap the region)
+        sbox.restore(&snapshot2).unwrap();
+        assert_eq!(sbox.vm.get_mapped_regions().len(), 1);
+
+        // Verify the region is the same
+        let mut restored_regions = sbox.vm.get_mapped_regions();
+        assert_eq!(*restored_regions.next().unwrap(), region);
+        assert!(restored_regions.next().is_none());
+        drop(restored_regions);
+
+        // 6. Try to map the region again (should fail since it's already mapped)
+        let err = unsafe { sbox.map_region(&region) };
+        assert!(
+            err.is_err(),
+            "Expected error when remapping existing region: {:?}",
+            err
+        );
+    }
 }
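Note: the heart of the new `restore` is plain set reconciliation over the snapshot's region list: anything mapped now but absent from the snapshot gets unmapped, anything the snapshot had but is missing now gets remapped. A distilled, runnable version of that logic, with `u64` standing in for `MemoryRegion`:

```rust
use std::collections::HashSet;

// Returns (to_unmap, to_map) as two set differences, mirroring restore().
fn reconcile(current: &HashSet<u64>, snapshot: &HashSet<u64>) -> (Vec<u64>, Vec<u64>) {
    let to_unmap: Vec<u64> = current.difference(snapshot).copied().collect();
    let to_map: Vec<u64> = snapshot.difference(current).copied().collect();
    (to_unmap, to_map)
}

fn main() {
    let current: HashSet<u64> = [1, 2, 3].into();
    let snapshot: HashSet<u64> = [2, 4].into();
    let (mut to_unmap, mut to_map) = reconcile(&current, &snapshot);
    to_unmap.sort();
    to_map.sort();
    assert_eq!(to_unmap, vec![1, 3]); // mapped now, not in the snapshot
    assert_eq!(to_map, vec![4]);      // in the snapshot, not mapped now
}
```

A nice property of this reconciliation is that restores work in both directions, which is exactly what the `snapshot_restore_handles_remapping_correctly` test above exercises: rolling back unmaps the extra region, and restoring forward maps it again.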