diff --git a/compiler/rustc_middle/src/mir/interpret/allocation.rs b/compiler/rustc_middle/src/mir/interpret/allocation.rs index 17ff52cefcfa..616256b116e6 100644 --- a/compiler/rustc_middle/src/mir/interpret/allocation.rs +++ b/compiler/rustc_middle/src/mir/interpret/allocation.rs @@ -147,7 +147,7 @@ impl Allocation { Self { bytes, relocations: Relocations::new(), - init_mask: InitMask::new(size, true), + init_mask: InitMask::new_init(size), align, mutability, extra: (), @@ -180,7 +180,7 @@ impl Allocation { Ok(Allocation { bytes, relocations: Relocations::new(), - init_mask: InitMask::new(size, false), + init_mask: InitMask::new_uninit(size), align, mutability: Mutability::Mut, extra: (), @@ -629,15 +629,19 @@ impl InitMask { Size::from_bytes(block * InitMask::BLOCK_SIZE + bit) } - pub fn new(size: Size, state: bool) -> Self { + pub fn new_init(size: Size) -> Self { let mut m = InitMask { blocks: vec![], len: Size::ZERO }; - m.grow(size, state); + m.grow(size, true); m } + pub fn new_uninit(size: Size) -> Self { + InitMask { blocks: vec![], len: size } + } + pub fn set_range(&mut self, start: Size, end: Size, new_state: bool) { let len = self.len; - if end > len { + if end > len && new_state { self.grow(end - len, new_state); } self.set_range_inbounds(start, end, new_state); @@ -655,14 +659,16 @@ impl InitMask { (u64::MAX << bita) & (u64::MAX >> (64 - bitb)) }; if new_state { + self.ensure_blocks(blocka); self.blocks[blocka] |= range; - } else { - self.blocks[blocka] &= !range; + } else if let Some(block) = self.blocks.get_mut(blocka) { + *block &= !range; } return; } // across block boundaries if new_state { + self.ensure_blocks(blockb); // Set `bita..64` to `1`. self.blocks[blocka] |= u64::MAX << bita; // Set `0..bitb` to `1`. @@ -673,15 +679,17 @@ impl InitMask { for block in (blocka + 1)..blockb { self.blocks[block] = u64::MAX; } - } else { + } else if let Some(blocka_val) = self.blocks.get_mut(blocka) { // Set `bita..64` to `0`. - self.blocks[blocka] &= !(u64::MAX << bita); + *blocka_val &= !(u64::MAX << bita); // Set `0..bitb` to `0`. if bitb != 0 { - self.blocks[blockb] &= !(u64::MAX >> (64 - bitb)); + if let Some(blockb_val) = self.blocks.get_mut(blockb) { + *blockb_val &= !(u64::MAX >> (64 - bitb)); + } } // Fill in all the other blocks (much faster than one bit at a time). - for block in (blocka + 1)..blockb { + for block in (blocka + 1)..std::cmp::min(blockb, self.blocks.len()) { self.blocks[block] = 0; } } @@ -690,7 +698,10 @@ impl InitMask { #[inline] pub fn get(&self, i: Size) -> bool { let (block, bit) = Self::bit_index(i); - (self.blocks[block] & (1 << bit)) != 0 + match self.blocks.get(block) { + Some(block) => (*block & (1 << bit)) != 0, + None => false, + } } #[inline] @@ -702,10 +713,22 @@ impl InitMask { #[inline] fn set_bit(&mut self, block: usize, bit: usize, new_state: bool) { if new_state { + self.ensure_blocks(block); self.blocks[block] |= 1 << bit; - } else { - self.blocks[block] &= !(1 << bit); + } else if let Some(block) = self.blocks.get_mut(block) { + *block &= !(1 << bit); + } + } + + fn ensure_blocks(&mut self, block: usize) { + if block < self.blocks.len() { + return; } + let additional_blocks = block - self.blocks.len() + 1; + self.blocks.extend( + // FIXME(oli-obk): optimize this by repeating `new_state as Block`. + iter::repeat(0).take(usize::try_from(additional_blocks).unwrap()), + ); } pub fn grow(&mut self, amount: Size, new_state: bool) { @@ -716,10 +739,7 @@ impl InitMask { u64::try_from(self.blocks.len()).unwrap() * Self::BLOCK_SIZE - self.len.bytes(); if amount.bytes() > unused_trailing_bits { let additional_blocks = amount.bytes() / Self::BLOCK_SIZE + 1; - self.blocks.extend( - // FIXME(oli-obk): optimize this by repeating `new_state as Block`. - iter::repeat(0).take(usize::try_from(additional_blocks).unwrap()), - ); + self.ensure_blocks(self.blocks.len() + additional_blocks as usize - 1); } let start = self.len; self.len += amount; @@ -821,25 +841,31 @@ impl InitMask { // (c) 01000000|00000000|00000001 // ^~~~~~~~~~~~~~~~~~^ // start end - if let Some(i) = - search_block(init_mask.blocks[start_block], start_block, start_bit, is_init) - { - // If the range is less than a block, we may find a matching bit after `end`. - // - // For example, we shouldn't successfully find bit (2), because it's after `end`: - // - // (2) - // -------| - // (d) 00000001|00000000|00000001 - // ^~~~~^ - // start end - // - // An alternative would be to mask off end bits in the same way as we do for start bits, - // but performing this check afterwards is faster and simpler to implement. - if i < end { - return Some(i); - } else { + if let Some(&bits) = init_mask.blocks.get(start_block) { + if let Some(i) = search_block(bits, start_block, start_bit, is_init) { + // If the range is less than a block, we may find a matching bit after `end`. + // + // For example, we shouldn't successfully find bit (2), because it's after `end`: + // + // (2) + // -------| + // (d) 00000001|00000000|00000001 + // ^~~~~^ + // start end + // + // An alternative would be to mask off end bits in the same way as we do for start bits, + // but performing this check afterwards is faster and simpler to implement. + if i < end { + return Some(i); + } else { + return None; + } + } + } else { + if is_init { return None; + } else { + return Some(start); } } @@ -861,7 +887,8 @@ impl InitMask { // because both alternatives result in significantly worse codegen. // `end_block_inclusive + 1` is guaranteed not to wrap, because `end_block_inclusive <= end / BLOCK_SIZE`, // and `BLOCK_SIZE` (the number of bits per block) will always be at least 8 (1 byte). - for (&bits, block) in init_mask.blocks[start_block + 1..end_block_inclusive + 1] + for (&bits, block) in init_mask.blocks[start_block + 1 + ..std::cmp::min(end_block_inclusive + 1, init_mask.blocks.len())] .iter() .zip(start_block + 1..) { @@ -886,6 +913,9 @@ impl InitMask { } } } + if !is_init && end_block_inclusive >= init_mask.blocks.len() { + return Some(InitMask::size_from_bit_index(init_mask.blocks.len(), 0)); + } } None diff --git a/src/test/ui-fulldeps/uninit_mask.rs b/src/test/ui-fulldeps/uninit_mask.rs index 84ce291016aa..cf020c1a26dd 100644 --- a/src/test/ui-fulldeps/uninit_mask.rs +++ b/src/test/ui-fulldeps/uninit_mask.rs @@ -11,7 +11,7 @@ use rustc_middle::mir::interpret::InitMask; use rustc_target::abi::Size; fn main() { - let mut mask = InitMask::new(Size::from_bytes(500), false); + let mut mask = InitMask::new_uninit(Size::from_bytes(500)); assert!(!mask.get(Size::from_bytes(499))); mask.set(Size::from_bytes(499), true); assert!(mask.get(Size::from_bytes(499)));