Skip to content

Commit a40adb3

Browse files
committed
Local copy propagation
Based on ideas from 76723.
1 parent 19857d9 commit a40adb3

File tree

3 files changed

+271
-0
lines changed

3 files changed

+271
-0
lines changed
Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
//! A intra-block copy propagation pass.
2+
//!
3+
//! Given an assignment `_a = _b` replaces subsequent uses of destination `_a` with source `_b`, as
4+
//! long as neither `a` nor `_b` had been modified in the intervening statements.
5+
//!
6+
//! The implementation processes block statements & terminator in the execution order. For each
7+
//! local it keeps track of a source that defined its current value. When it encounters a copy use
8+
//! of a local, it verifies that source had not been modified since the assignment and replaces the
9+
//! local with the source.
10+
//!
11+
//! To detect modifications, each local has a generation number that is increased after each direct
12+
//! modification. The local generation number is recorded at the time of the assignment and
13+
//! verified before the propagation to ensure that the local remains unchanged since the
14+
//! assignment.
15+
//!
16+
//! Instead of detecting indirect modifications, locals that have their address taken never
17+
//! participate in copy propagation.
18+
//!
19+
//! When moving in-between the blocks, all recorded values are invalidated. To do that in O(1)
20+
//! time, generation numbers have a global component that is increased after each block.
21+
22+
use crate::transform::MirPass;
23+
use crate::util::ever_borrowed_locals;
24+
use rustc_index::bit_set::BitSet;
25+
use rustc_index::vec::IndexVec;
26+
use rustc_middle::mir::visit::*;
27+
use rustc_middle::mir::*;
28+
use rustc_middle::ty::TyCtxt;
29+
30+
pub struct CopyPropagation;
31+
32+
impl<'tcx> MirPass<'tcx> for CopyPropagation {
33+
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
34+
copy_move_operands(tcx, body);
35+
propagate_copies(tcx, body);
36+
}
37+
}
38+
39+
fn propagate_copies(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
40+
let mut values = LocalValues {
41+
borrowed_locals: ever_borrowed_locals(body),
42+
values: IndexVec::from_elem_n(LocalValue::default(), body.local_decls.len()),
43+
block_generation: 0,
44+
};
45+
for (block, data) in body.basic_blocks_mut().iter_enumerated_mut() {
46+
for (statement_index, statement) in data.statements.iter_mut().enumerate() {
47+
let location = Location { block, statement_index };
48+
InvalidateModifiedLocals { values: &mut values }.visit_statement(statement, location);
49+
CopyPropagate { tcx, values: &mut values }.visit_statement(statement, location);
50+
values.record_assignment(statement);
51+
}
52+
53+
let location = Location { block, statement_index: data.statements.len() };
54+
InvalidateModifiedLocals { values: &mut values }
55+
.visit_terminator(data.terminator_mut(), location);
56+
CopyPropagate { tcx, values: &mut values }
57+
.visit_terminator(data.terminator_mut(), location);
58+
values.invalidate_all();
59+
}
60+
}
61+
62+
struct LocalValues {
63+
/// Locals that have their address taken. They do not participate in copy propagation.
64+
borrowed_locals: BitSet<Local>,
65+
/// A symbolic value of each local.
66+
values: IndexVec<Local, LocalValue>,
67+
/// Block generation number. Used to invalidate locals' values in-between the blocks in O(1) time.
68+
block_generation: u32,
69+
}
70+
71+
/// A symbolic value of a local variable.
72+
#[derive(Copy, Clone, Default)]
73+
struct LocalValue {
74+
/// Generation of the current value.
75+
generation: Generation,
76+
/// Generation of the source value at the time of the assignment.
77+
src_generation: Generation,
78+
/// If present the current value of this local is a result of assignment `this = src`.
79+
src: Option<Local>,
80+
}
81+
82+
#[derive(Copy, Clone, Default, PartialEq, Eq)]
83+
struct Generation {
84+
/// Local generation number. Increased after each mutation.
85+
local: u32,
86+
/// Block generation number. Increased in-between the blocks.
87+
block: u32,
88+
}
89+
90+
impl LocalValues {
91+
/// Invalidates all locals' values.
92+
fn invalidate_all(&mut self) {
93+
assert!(self.block_generation != u32::MAX);
94+
self.block_generation += 1;
95+
}
96+
97+
/// Invalidates the local's value.
98+
fn invalidate_local(&mut self, local: Local) {
99+
let value = &mut self.values[local];
100+
assert!(value.generation.local != u32::MAX);
101+
value.generation.local += 1;
102+
value.src_generation = Generation::default();
103+
value.src = None;
104+
}
105+
106+
fn record_assignment(&mut self, statement: &Statement<'tcx>) {
107+
let (place, rvalue) = match statement.kind {
108+
StatementKind::Assign(box (ref place, ref rvalue)) => (place, rvalue),
109+
_ => return,
110+
};
111+
112+
// Record only complete definitions of local variables.
113+
let dst = match place.as_local() {
114+
Some(dst) => dst,
115+
None => return,
116+
};
117+
// Reject borrowed destinations.
118+
if self.borrowed_locals.contains(dst) {
119+
return;
120+
}
121+
122+
let src = match rvalue {
123+
Rvalue::Use(Operand::Copy(src)) => src,
124+
_ => return,
125+
};
126+
let src = match src.as_local() {
127+
Some(src) => src,
128+
None => return,
129+
};
130+
// Reject borrowed sources.
131+
if self.borrowed_locals.contains(src) {
132+
return;
133+
}
134+
135+
// Record `dst = src` assignment.
136+
let src_generation = self.values[src].generation;
137+
let value = &mut self.values[dst];
138+
value.generation.local += 1;
139+
value.generation.block = self.block_generation;
140+
value.src = Some(src);
141+
value.src_generation = src_generation;
142+
}
143+
144+
/// Replaces a use of dst with its current value.
145+
fn propagate_local(&mut self, dst: &mut Local) {
146+
let dst_value = &self.values[*dst];
147+
148+
let src = match dst_value.src {
149+
Some(src) => src,
150+
None => return,
151+
};
152+
// Last definition of dst was of the form `dst = src`.
153+
154+
// Check that dst was defined in this block.
155+
if dst_value.generation.block != self.block_generation {
156+
return;
157+
}
158+
// Check that src still has the same value.
159+
if dst_value.src_generation != self.values[src].generation {
160+
return;
161+
}
162+
163+
// Propagate
164+
*dst = src;
165+
}
166+
}
167+
168+
/// Invalidates locals that could be modified during execution of visited MIR.
169+
struct InvalidateModifiedLocals<'a> {
170+
values: &'a mut LocalValues,
171+
}
172+
173+
impl<'tcx, 'a> Visitor<'tcx> for InvalidateModifiedLocals<'a> {
174+
fn visit_local(&mut self, local: &Local, context: PlaceContext, _location: Location) {
175+
match context {
176+
PlaceContext::MutatingUse(_)
177+
| PlaceContext::NonMutatingUse(NonMutatingUseContext::Move)
178+
| PlaceContext::NonUse(NonUseContext::StorageLive | NonUseContext::StorageDead) => {
179+
self.values.invalidate_local(*local)
180+
}
181+
182+
PlaceContext::NonMutatingUse(_)
183+
| PlaceContext::NonUse(NonUseContext::AscribeUserTy | NonUseContext::VarDebugInfo) => {}
184+
}
185+
}
186+
}
187+
188+
/// Replaces copy uses of locals with their current value.
189+
struct CopyPropagate<'tcx, 'a> {
190+
tcx: TyCtxt<'tcx>,
191+
values: &'a mut LocalValues,
192+
}
193+
194+
impl<'tcx, 'a> MutVisitor<'tcx> for CopyPropagate<'tcx, 'a> {
195+
fn tcx(&self) -> TyCtxt<'tcx> {
196+
self.tcx
197+
}
198+
199+
fn visit_local(&mut self, local: &mut Local, context: PlaceContext, _location: Location) {
200+
match context {
201+
PlaceContext::NonMutatingUse(
202+
NonMutatingUseContext::Copy | NonMutatingUseContext::Inspect,
203+
) => self.values.propagate_local(local),
204+
_ => {}
205+
}
206+
}
207+
}
208+
209+
/// Transforms move operands into copy operands.
210+
fn copy_move_operands<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
211+
let mut visitor = CopyMoveOperands { tcx };
212+
for (block, data) in body.basic_blocks_mut().iter_enumerated_mut() {
213+
visitor.visit_basic_block_data(block, data);
214+
}
215+
}
216+
217+
struct CopyMoveOperands<'tcx> {
218+
tcx: TyCtxt<'tcx>,
219+
}
220+
221+
impl<'tcx> MutVisitor<'tcx> for CopyMoveOperands<'tcx> {
222+
fn tcx(&self) -> TyCtxt<'tcx> {
223+
self.tcx
224+
}
225+
226+
fn visit_operand(&mut self, operand: &mut Operand<'tcx>, _location: Location) {
227+
if let Operand::Move(place) = operand {
228+
*operand = Operand::Copy(*place);
229+
}
230+
}
231+
232+
fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, location: Location) {
233+
if let TerminatorKind::Call { .. } = terminator.kind {
234+
// When a move operand is used in a call terminator and ABI passes value by a
235+
// reference, the code generation uses provided operand in place instead of making a
236+
// copy. To avoid introducing extra copies, we retain move operands in call
237+
// terminators.
238+
} else {
239+
self.super_terminator(terminator, location)
240+
}
241+
}
242+
}

compiler/rustc_mir/src/transform/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ pub mod cleanup_post_borrowck;
2424
pub mod const_debuginfo;
2525
pub mod const_goto;
2626
pub mod const_prop;
27+
pub mod copy_propagation;
2728
pub mod coverage;
2829
pub mod deaggregator;
2930
pub mod deduplicate_blocks;
@@ -517,6 +518,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
517518
&simplify::SimplifyCfg::new("final"),
518519
&nrvo::RenameReturnPlace,
519520
&const_debuginfo::ConstDebugInfo,
521+
&copy_propagation::CopyPropagation,
520522
&simplify::SimplifyLocals,
521523
&multiple_return_terminators::MultipleReturnTerminators,
522524
&deduplicate_blocks::DeduplicateBlocks,

src/test/mir-opt/copy_propagation.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// compile-flags: --crate-type=lib
2+
3+
// EMIT_MIR copy_propagation.write.CopyPropagation.diff
4+
pub fn write<T: Copy>(dst: &mut T, value: T) {
5+
*dst = value;
6+
}
7+
8+
// EMIT_MIR copy_propagation.id.CopyPropagation.diff
9+
pub fn id<T: Copy>(mut a: T) -> T {
10+
// Not optimized.
11+
a = a;
12+
a
13+
}
14+
15+
// EMIT_MIR copy_propagation.chain.CopyPropagation.diff
16+
pub fn chain<T: Copy>(mut a: T) -> T {
17+
let mut b;
18+
let mut c;
19+
b = a;
20+
c = b;
21+
a = c;
22+
b = a;
23+
c = b;
24+
25+
let d = c;
26+
d
27+
}

0 commit comments

Comments
 (0)