Skip to content

Commit be9668d

Browse files
committed
Use an interpreter in jump threading.
1 parent 25f8d01 commit be9668d

File tree

4 files changed

+197
-27
lines changed

4 files changed

+197
-27
lines changed

compiler/rustc_mir_transform/src/jump_threading.rs

+75-27
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,21 @@
3636
//! cost by `MAX_COST`.
3737
3838
use rustc_arena::DroplessArena;
39+
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
3940
use rustc_data_structures::fx::FxHashSet;
4041
use rustc_index::bit_set::BitSet;
4142
use rustc_index::IndexVec;
43+
use rustc_middle::mir::interpret::Scalar;
4244
use rustc_middle::mir::visit::Visitor;
4345
use rustc_middle::mir::*;
44-
use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
46+
use rustc_middle::ty::layout::LayoutOf;
47+
use rustc_middle::ty::{self, ScalarInt, TyCtxt};
4548
use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem};
49+
use rustc_span::DUMMY_SP;
4650
use rustc_target::abi::{TagEncoding, Variants};
4751

4852
use crate::cost_checker::CostChecker;
53+
use crate::dataflow_const_prop::DummyMachine;
4954

5055
pub struct JumpThreading;
5156

@@ -71,6 +76,7 @@ impl<'tcx> MirPass<'tcx> for JumpThreading {
7176
let mut finder = TOFinder {
7277
tcx,
7378
param_env,
79+
ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine),
7480
body,
7581
arena: &arena,
7682
map: &map,
@@ -88,7 +94,7 @@ impl<'tcx> MirPass<'tcx> for JumpThreading {
8894
debug!(?discr, ?bb);
8995

9096
let discr_ty = discr.ty(body, tcx).ty;
91-
let Ok(discr_layout) = tcx.layout_of(param_env.and(discr_ty)) else { continue };
97+
let Ok(discr_layout) = finder.ecx.layout_of(discr_ty) else { continue };
9298

9399
let Some(discr) = finder.map.find(discr.as_ref()) else { continue };
94100
debug!(?discr);
@@ -142,6 +148,7 @@ struct ThreadingOpportunity {
142148
struct TOFinder<'tcx, 'a> {
143149
tcx: TyCtxt<'tcx>,
144150
param_env: ty::ParamEnv<'tcx>,
151+
ecx: InterpCx<'tcx, 'tcx, DummyMachine>,
145152
body: &'a Body<'tcx>,
146153
map: &'a Map,
147154
loop_headers: &'a BitSet<BasicBlock>,
@@ -329,25 +336,72 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
329336
}
330337

331338
#[instrument(level = "trace", skip(self))]
332-
fn process_operand(
339+
fn process_immediate(
333340
&mut self,
334341
bb: BasicBlock,
335342
lhs: PlaceIndex,
336-
rhs: &Operand<'tcx>,
343+
rhs: ImmTy<'tcx>,
337344
state: &mut State<ConditionSet<'a>>,
338345
) -> Option<!> {
339346
let register_opportunity = |c: Condition| {
340347
debug!(?bb, ?c.target, "register");
341348
self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
342349
};
343350

351+
let conditions = state.try_get_idx(lhs, self.map)?;
352+
if let Immediate::Scalar(Scalar::Int(int)) = *rhs {
353+
conditions.iter_matches(int).for_each(register_opportunity);
354+
}
355+
356+
None
357+
}
358+
359+
#[instrument(level = "trace", skip(self))]
360+
fn process_operand(
361+
&mut self,
362+
bb: BasicBlock,
363+
lhs: PlaceIndex,
364+
rhs: &Operand<'tcx>,
365+
state: &mut State<ConditionSet<'a>>,
366+
) -> Option<!> {
344367
match rhs {
345368
// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
346369
Operand::Constant(constant) => {
347-
let conditions = state.try_get_idx(lhs, self.map)?;
348-
let constant =
349-
constant.const_.normalize(self.tcx, self.param_env).try_to_scalar_int()?;
350-
conditions.iter_matches(constant).for_each(register_opportunity);
370+
let constant = self.ecx.eval_mir_constant(&constant.const_, None, None).ok()?;
371+
self.map.for_each_projection_value(
372+
lhs,
373+
constant,
374+
&mut |elem, op| match elem {
375+
TrackElem::Field(idx) => self.ecx.project_field(op, idx.as_usize()).ok(),
376+
TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).ok(),
377+
TrackElem::Discriminant => {
378+
let variant = self.ecx.read_discriminant(op).ok()?;
379+
let discr_value =
380+
self.ecx.discriminant_for_variant(op.layout.ty, variant).ok()?;
381+
Some(discr_value.into())
382+
}
383+
TrackElem::DerefLen => {
384+
let op: OpTy<'_> = self.ecx.deref_pointer(op).ok()?.into();
385+
let len_usize = op.len(&self.ecx).ok()?;
386+
let layout = self.ecx.layout_of(self.tcx.types.usize).unwrap();
387+
Some(ImmTy::from_uint(len_usize, layout).into())
388+
}
389+
},
390+
&mut |place, op| {
391+
if let Some(conditions) = state.try_get_idx(place, self.map)
392+
&& let Ok(imm) = self.ecx.read_immediate_raw(op)
393+
&& let Some(imm) = imm.right()
394+
&& let Immediate::Scalar(Scalar::Int(int)) = *imm
395+
{
396+
conditions.iter_matches(int).for_each(|c: Condition| {
397+
self.opportunities.push(ThreadingOpportunity {
398+
chain: vec![bb],
399+
target: c.target,
400+
})
401+
})
402+
}
403+
},
404+
);
351405
}
352406
// Transfer the conditions on the copied rhs.
353407
Operand::Move(rhs) | Operand::Copy(rhs) => {
@@ -374,18 +428,6 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
374428
// Below, `lhs` is the return value of `mutated_statement`,
375429
// the place to which `conditions` apply.
376430

377-
let discriminant_for_variant = |enum_ty: Ty<'tcx>, variant_index| {
378-
let discr = enum_ty.discriminant_for_variant(self.tcx, variant_index)?;
379-
let discr_layout = self.tcx.layout_of(self.param_env.and(discr.ty)).ok()?;
380-
let scalar = ScalarInt::try_from_uint(discr.val, discr_layout.size)?;
381-
Some(Operand::const_from_scalar(
382-
self.tcx,
383-
discr.ty,
384-
scalar.into(),
385-
rustc_span::DUMMY_SP,
386-
))
387-
};
388-
389431
match &stmt.kind {
390432
// If we expect `discriminant(place) ?= A`,
391433
// we have an opportunity if `variant_index ?= A`.
@@ -395,7 +437,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
395437
// `SetDiscriminant` may be a no-op if the assigned variant is the untagged variant
396438
// of a niche encoding. If we cannot ensure that we write to the discriminant, do
397439
// nothing.
398-
let enum_layout = self.tcx.layout_of(self.param_env.and(enum_ty)).ok()?;
440+
let enum_layout = self.ecx.layout_of(enum_ty).ok()?;
399441
let writes_discriminant = match enum_layout.variants {
400442
Variants::Single { index } => {
401443
assert_eq!(index, *variant_index);
@@ -408,8 +450,8 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
408450
} => *variant_index != untagged_variant,
409451
};
410452
if writes_discriminant {
411-
let discr = discriminant_for_variant(enum_ty, *variant_index)?;
412-
self.process_operand(bb, discr_target, &discr, state)?;
453+
let discr = self.ecx.discriminant_for_variant(enum_ty, *variant_index).ok()?;
454+
self.process_immediate(bb, discr_target, discr, state)?;
413455
}
414456
}
415457
// If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`.
@@ -440,10 +482,16 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
440482
AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
441483
if let Some(discr_target) =
442484
self.map.apply(lhs, TrackElem::Discriminant)
443-
&& let Some(discr_value) =
444-
discriminant_for_variant(agg_ty, *variant_index)
485+
&& let Ok(discr_value) = self
486+
.ecx
487+
.discriminant_for_variant(agg_ty, *variant_index)
445488
{
446-
self.process_operand(bb, discr_target, &discr_value, state);
489+
self.process_immediate(
490+
bb,
491+
discr_target,
492+
discr_value,
493+
state,
494+
);
447495
}
448496
self.map.apply(lhs, TrackElem::Variant(*variant_index))?
449497
}
@@ -577,7 +625,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
577625

578626
let discr = discr.place()?;
579627
let discr_ty = discr.ty(self.body, self.tcx).ty;
580-
let discr_layout = self.tcx.layout_of(self.param_env.and(discr_ty)).ok()?;
628+
let discr_layout = self.ecx.layout_of(discr_ty).ok()?;
581629
let conditions = state.try_get(discr.as_ref(), self.map)?;
582630

583631
if let Some((value, _)) = targets.iter().find(|&(_, target)| target == target_bb) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
- // MIR for `aggregate` before JumpThreading
2+
+ // MIR for `aggregate` after JumpThreading
3+
4+
fn aggregate(_1: u8) -> u8 {
5+
debug x => _1;
6+
let mut _0: u8;
7+
let _2: u8;
8+
let _3: u8;
9+
let mut _4: (u8, u8);
10+
let mut _5: bool;
11+
let mut _6: u8;
12+
scope 1 {
13+
debug a => _2;
14+
debug b => _3;
15+
}
16+
17+
bb0: {
18+
StorageLive(_4);
19+
_4 = const _;
20+
StorageLive(_2);
21+
_2 = (_4.0: u8);
22+
StorageLive(_3);
23+
_3 = (_4.1: u8);
24+
StorageDead(_4);
25+
StorageLive(_5);
26+
StorageLive(_6);
27+
_6 = _2;
28+
_5 = Eq(move _6, const 7_u8);
29+
- switchInt(move _5) -> [0: bb2, otherwise: bb1];
30+
+ goto -> bb2;
31+
}
32+
33+
bb1: {
34+
StorageDead(_6);
35+
_0 = _3;
36+
goto -> bb3;
37+
}
38+
39+
bb2: {
40+
StorageDead(_6);
41+
_0 = _2;
42+
goto -> bb3;
43+
}
44+
45+
bb3: {
46+
StorageDead(_5);
47+
StorageDead(_3);
48+
StorageDead(_2);
49+
return;
50+
}
51+
}
52+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
- // MIR for `aggregate` before JumpThreading
2+
+ // MIR for `aggregate` after JumpThreading
3+
4+
fn aggregate(_1: u8) -> u8 {
5+
debug x => _1;
6+
let mut _0: u8;
7+
let _2: u8;
8+
let _3: u8;
9+
let mut _4: (u8, u8);
10+
let mut _5: bool;
11+
let mut _6: u8;
12+
scope 1 {
13+
debug a => _2;
14+
debug b => _3;
15+
}
16+
17+
bb0: {
18+
StorageLive(_4);
19+
_4 = const _;
20+
StorageLive(_2);
21+
_2 = (_4.0: u8);
22+
StorageLive(_3);
23+
_3 = (_4.1: u8);
24+
StorageDead(_4);
25+
StorageLive(_5);
26+
StorageLive(_6);
27+
_6 = _2;
28+
_5 = Eq(move _6, const 7_u8);
29+
- switchInt(move _5) -> [0: bb2, otherwise: bb1];
30+
+ goto -> bb2;
31+
}
32+
33+
bb1: {
34+
StorageDead(_6);
35+
_0 = _3;
36+
goto -> bb3;
37+
}
38+
39+
bb2: {
40+
StorageDead(_6);
41+
_0 = _2;
42+
goto -> bb3;
43+
}
44+
45+
bb3: {
46+
StorageDead(_5);
47+
StorageDead(_3);
48+
StorageDead(_2);
49+
return;
50+
}
51+
}
52+

tests/mir-opt/jump_threading.rs

+18
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,23 @@ fn disappearing_bb(x: u8) -> u8 {
453453
)
454454
}
455455

456+
/// Verify that we can thread jumps when we assign from an aggregate constant.
457+
fn aggregate(x: u8) -> u8 {
458+
// CHECK-LABEL: fn aggregate(
459+
// CHECK-NOT: switchInt(
460+
461+
const FOO: (u8, u8) = (5, 13);
462+
463+
let (a, b) = FOO;
464+
if a == 7 {
465+
b
466+
} else {
467+
a
468+
}
469+
}
470+
456471
fn main() {
472+
// CHECK-LABEL: fn main(
457473
too_complex(Ok(0));
458474
identity(Ok(0));
459475
custom_discr(false);
@@ -464,6 +480,7 @@ fn main() {
464480
mutable_ref();
465481
renumbered_bb(true);
466482
disappearing_bb(7);
483+
aggregate(7);
467484
}
468485

469486
// EMIT_MIR jump_threading.too_complex.JumpThreading.diff
@@ -476,3 +493,4 @@ fn main() {
476493
// EMIT_MIR jump_threading.mutable_ref.JumpThreading.diff
477494
// EMIT_MIR jump_threading.renumbered_bb.JumpThreading.diff
478495
// EMIT_MIR jump_threading.disappearing_bb.JumpThreading.diff
496+
// EMIT_MIR jump_threading.aggregate.JumpThreading.diff

0 commit comments

Comments
 (0)