@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
14
14
limitations under the License.
15
15
*/
16
16
17
+ use std:: collections:: HashSet ;
17
18
#[ cfg( unix) ]
18
19
use std:: os:: fd:: AsRawFd ;
19
20
#[ cfg( unix) ]
@@ -95,18 +96,35 @@ impl MultiUseSandbox {
95
96
/// Create a snapshot of the current state of the sandbox's memory.
96
97
#[ instrument( err( Debug ) , skip_all, parent = Span :: current( ) ) ]
97
98
pub fn snapshot ( & mut self ) -> Result < Snapshot > {
98
- let snapshot = self . mem_mgr . unwrap_mgr_mut ( ) . snapshot ( ) ?;
99
- Ok ( Snapshot { inner : snapshot } )
99
+ let mapped_regions_iter = self . vm . get_mapped_regions ( ) ;
100
+ let mapped_regions_vec: Vec < MemoryRegion > = mapped_regions_iter. cloned ( ) . collect ( ) ;
101
+ let memory_snapshot = self . mem_mgr . unwrap_mgr_mut ( ) . snapshot ( mapped_regions_vec) ?;
102
+ Ok ( Snapshot {
103
+ inner : memory_snapshot,
104
+ } )
100
105
}
101
106
102
107
/// Restore the sandbox's memory to the state captured in the given snapshot.
103
108
#[ instrument( err( Debug ) , skip_all, parent = Span :: current( ) ) ]
104
109
pub fn restore ( & mut self , snapshot : & Snapshot ) -> Result < ( ) > {
105
- let rgns_to_unmap = self
106
- . mem_mgr
110
+ self . mem_mgr
107
111
. unwrap_mgr_mut ( )
108
112
. restore_snapshot ( & snapshot. inner ) ?;
109
- unsafe { self . vm . unmap_regions ( rgns_to_unmap) ? } ;
113
+
114
+ let current_regions: HashSet < _ > = self . vm . get_mapped_regions ( ) . cloned ( ) . collect ( ) ;
115
+ let snapshot_regions: HashSet < _ > = snapshot. inner . regions ( ) . iter ( ) . cloned ( ) . collect ( ) ;
116
+
117
+ let regions_to_unmap = current_regions. difference ( & snapshot_regions) ;
118
+ let regions_to_map = snapshot_regions. difference ( & current_regions) ;
119
+
120
+ for region in regions_to_unmap {
121
+ unsafe { self . vm . unmap_region ( region) ? } ;
122
+ }
123
+
124
+ for region in regions_to_map {
125
+ unsafe { self . vm . map_region ( region) ? } ;
126
+ }
127
+
110
128
Ok ( ( ) )
111
129
}
112
130
@@ -645,4 +663,57 @@ mod tests {
645
663
region_type : MemoryRegionType :: Heap ,
646
664
}
647
665
}
666
+
667
+ #[ cfg( target_os = "linux" ) ]
668
+ fn allocate_guest_memory ( ) -> GuestSharedMemory {
669
+ page_aligned_memory ( b"test data for snapshot" )
670
+ }
671
+
672
+ #[ test]
673
+ #[ cfg( target_os = "linux" ) ]
674
+ fn snapshot_restore_handles_remapping_correctly ( ) {
675
+ let mut sbox: MultiUseSandbox = {
676
+ let path = simple_guest_as_string ( ) . unwrap ( ) ;
677
+ let u_sbox = UninitializedSandbox :: new ( GuestBinary :: FilePath ( path) , None ) . unwrap ( ) ;
678
+ u_sbox. evolve ( ) . unwrap ( )
679
+ } ;
680
+
681
+ // 1. Take snapshot 1 with no additional regions mapped
682
+ let snapshot1 = sbox. snapshot ( ) . unwrap ( ) ;
683
+ assert_eq ! ( sbox. vm. get_mapped_regions( ) . len( ) , 0 ) ;
684
+
685
+ // 2. Map a memory region
686
+ let map_mem = allocate_guest_memory ( ) ;
687
+ let guest_base = 0x200000000_usize ;
688
+ let region = region_for_memory ( & map_mem, guest_base) ;
689
+
690
+ unsafe { sbox. map_region ( & region) . unwrap ( ) } ;
691
+ assert_eq ! ( sbox. vm. get_mapped_regions( ) . len( ) , 1 ) ;
692
+
693
+ // 3. Take snapshot 2 with 1 region mapped
694
+ let snapshot2 = sbox. snapshot ( ) . unwrap ( ) ;
695
+ assert_eq ! ( sbox. vm. get_mapped_regions( ) . len( ) , 1 ) ;
696
+
697
+ // 4. Restore to snapshot 1 (should unmap the region)
698
+ sbox. restore ( & snapshot1) . unwrap ( ) ;
699
+ assert_eq ! ( sbox. vm. get_mapped_regions( ) . len( ) , 0 ) ;
700
+
701
+ // 5. Restore forward to snapshot 2 (should remap the region)
702
+ sbox. restore ( & snapshot2) . unwrap ( ) ;
703
+ assert_eq ! ( sbox. vm. get_mapped_regions( ) . len( ) , 1 ) ;
704
+
705
+ // Verify the region is the same
706
+ let mut restored_regions = sbox. vm . get_mapped_regions ( ) ;
707
+ assert_eq ! ( * restored_regions. next( ) . unwrap( ) , region) ;
708
+ assert ! ( restored_regions. next( ) . is_none( ) ) ;
709
+ drop ( restored_regions) ;
710
+
711
+ // 6. Try map the region again (should fail since already mapped)
712
+ let err = unsafe { sbox. map_region ( & region) } ;
713
+ assert ! (
714
+ err. is_err( ) ,
715
+ "Expected error when remapping existing region: {:?}" ,
716
+ err
717
+ ) ;
718
+ }
648
719
}
0 commit comments