Skip to content

Commit 4f0f4c7

Browse files
committed
Auto merge of #143784 - scottmcm:enums-again-new-ex2, r=<try>
Simplify codegen for niche-encoded variant tests Inspired by #139729, this attempts to be a much-simpler and more-localized change while still making a difference. (Specifically, this does not try to solve the problem with select-sinking, leaving that to be fixed by llvm/llvm-project#134024 -- once it gets released -- instead of in rustc's codegen.) What this *does* improve is checking for the variant in a 3+ variant enum when that variant is the type providing the niche. Something like `if let Foo::WithBool(_) = ...` previously compiled to `ugt(add(x, -2), 2)`, which is non-trivial to think about because it's depending on the unsigned wrapping to shift the 0/1 up above 2. With this PR it compiles to just `ult(x, 2)`, which is probably what you'd have written yourself if you were doing it by hand to look for "is this byte a bool?". That's done by leaving most of the codegen alone, but adding a couple new special cases to the `is_niche` check. The default looks at the relative discriminant, but in the common cases where there's no wraparound involved, we can just check the original value, rather than the offsetted one. The first commit just adds some tests, so the best way to see the effect of this change is to look at the second commit and [how it updates the test expectations](da52d97#diff-14bab05dc3e3448a531a97fafed38bf775095cc68f7997af1721a4c3dc58eb47R218-R223).
2 parents bfc046a + d5bcfb3 commit 4f0f4c7

File tree

4 files changed

+623
-38
lines changed

4 files changed

+623
-38
lines changed

compiler/rustc_abi/src/lib.rs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ use std::fmt;
4343
#[cfg(feature = "nightly")]
4444
use std::iter::Step;
4545
use std::num::{NonZeroUsize, ParseIntError};
46-
use std::ops::{Add, AddAssign, Deref, Mul, RangeInclusive, Sub};
46+
use std::ops::{Add, AddAssign, Deref, Mul, RangeFull, RangeInclusive, Sub};
4747
use std::str::FromStr;
4848

4949
use bitflags::bitflags;
@@ -1391,12 +1391,45 @@ impl WrappingRange {
13911391
}
13921392

13931393
/// Returns `true` if `size` completely fills the range.
1394+
///
1395+
/// Note that this is *not* the same as `self == WrappingRange::full(size)`.
1396+
/// Niche calculations can produce full ranges which are not the canonical one;
1397+
/// for example `Option<NonZero<u16>>` gets `valid_range: (..=0) | (1..)`.
13941398
#[inline]
13951399
fn is_full_for(&self, size: Size) -> bool {
13961400
let max_value = size.unsigned_int_max();
13971401
debug_assert!(self.start <= max_value && self.end <= max_value);
13981402
self.start == (self.end.wrapping_add(1) & max_value)
13991403
}
1404+
1405+
/// Checks whether this range is considered non-wrapping when the values are
1406+
/// interpreted as *unsigned* numbers of width `size`.
1407+
///
1408+
/// Returns `Ok(true)` if there's no wrap-around, `Ok(false)` if there is,
1409+
/// and `Err(..)` if the range is full so it depends how you think about it.
1410+
#[inline]
1411+
pub fn no_unsigned_wraparound(&self, size: Size) -> Result<bool, RangeFull> {
1412+
if self.is_full_for(size) { Err(..) } else { Ok(self.start <= self.end) }
1413+
}
1414+
1415+
/// Checks whether this range is considered non-wrapping when the values are
1416+
/// interpreted as *signed* numbers of width `size`.
1417+
///
1418+
/// This is heavily dependent on the `size`, as `100..=200` does wrap when
1419+
/// interpreted as `i8`, but doesn't when interpreted as `i16`.
1420+
///
1421+
/// Returns `Ok(true)` if there's no wrap-around, `Ok(false)` if there is,
1422+
/// and `Err(..)` if the range is full so it depends how you think about it.
1423+
#[inline]
1424+
pub fn no_signed_wraparound(&self, size: Size) -> Result<bool, RangeFull> {
1425+
if self.is_full_for(size) {
1426+
Err(..)
1427+
} else {
1428+
let start: i128 = size.sign_extend(self.start);
1429+
let end: i128 = size.sign_extend(self.end);
1430+
Ok(start <= end)
1431+
}
1432+
}
14001433
}
14011434

14021435
impl fmt::Debug for WrappingRange {

compiler/rustc_codegen_ssa/src/mir/operand.rs

Lines changed: 50 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,7 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> {
486486
// value and the variant index match, since that's all `Niche` can encode.
487487

488488
let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
489+
let niche_start_const = bx.cx().const_uint_big(tag_llty, niche_start);
489490

490491
// We have a subrange `niche_start..=niche_end` inside `range`.
491492
// If the value of the tag is inside this subrange, it's a
@@ -511,35 +512,44 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> {
511512
// } else {
512513
// untagged_variant
513514
// }
514-
let niche_start = bx.cx().const_uint_big(tag_llty, niche_start);
515-
let is_niche = bx.icmp(IntPredicate::IntEQ, tag, niche_start);
515+
let is_niche = bx.icmp(IntPredicate::IntEQ, tag, niche_start_const);
516516
let tagged_discr =
517517
bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64);
518518
(is_niche, tagged_discr, 0)
519519
} else {
520520
// The special cases don't apply, so we'll have to go with
521521
// the general algorithm.
522-
let relative_discr = bx.sub(tag, bx.cx().const_uint_big(tag_llty, niche_start));
522+
523+
let tag_range = tag_scalar.valid_range(&dl);
524+
let tag_size = tag_scalar.size(&dl);
525+
let niche_end = u128::from(relative_max).wrapping_add(niche_start);
526+
let niche_end = tag_size.truncate(niche_end);
527+
528+
let relative_discr = bx.sub(tag, niche_start_const);
523529
let cast_tag = bx.intcast(relative_discr, cast_to, false);
524-
let is_niche = bx.icmp(
525-
IntPredicate::IntULE,
526-
relative_discr,
527-
bx.cx().const_uint(tag_llty, relative_max as u64),
528-
);
529-
530-
// Thanks to parameter attributes and load metadata, LLVM already knows
531-
// the general valid range of the tag. It's possible, though, for there
532-
// to be an impossible value *in the middle*, which those ranges don't
533-
// communicate, so it's worth an `assume` to let the optimizer know.
534-
if niche_variants.contains(&untagged_variant)
535-
&& bx.cx().sess().opts.optimize != OptLevel::No
536-
{
537-
let impossible =
538-
u64::from(untagged_variant.as_u32() - niche_variants.start().as_u32());
539-
let impossible = bx.cx().const_uint(tag_llty, impossible);
540-
let ne = bx.icmp(IntPredicate::IntNE, relative_discr, impossible);
541-
bx.assume(ne);
542-
}
530+
let is_niche = if tag_range.no_unsigned_wraparound(tag_size) == Ok(true) {
531+
if niche_start == tag_range.start {
532+
let niche_end_const = bx.cx().const_uint_big(tag_llty, niche_end);
533+
bx.icmp(IntPredicate::IntULE, tag, niche_end_const)
534+
} else {
535+
assert_eq!(niche_end, tag_range.end);
536+
bx.icmp(IntPredicate::IntUGE, tag, niche_start_const)
537+
}
538+
} else if tag_range.no_signed_wraparound(tag_size) == Ok(true) {
539+
if niche_start == tag_range.start {
540+
let niche_end_const = bx.cx().const_uint_big(tag_llty, niche_end);
541+
bx.icmp(IntPredicate::IntSLE, tag, niche_end_const)
542+
} else {
543+
assert_eq!(niche_end, tag_range.end);
544+
bx.icmp(IntPredicate::IntSGE, tag, niche_start_const)
545+
}
546+
} else {
547+
bx.icmp(
548+
IntPredicate::IntULE,
549+
relative_discr,
550+
bx.cx().const_uint(tag_llty, relative_max as u64),
551+
)
552+
};
543553

544554
(is_niche, cast_tag, niche_variants.start().as_u32() as u128)
545555
};
@@ -550,11 +560,24 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> {
550560
bx.add(tagged_discr, bx.cx().const_uint_big(cast_to, delta))
551561
};
552562

553-
let discr = bx.select(
554-
is_niche,
555-
tagged_discr,
556-
bx.cx().const_uint(cast_to, untagged_variant.as_u32() as u64),
557-
);
563+
let untagged_variant_const =
564+
bx.cx().const_uint(cast_to, u64::from(untagged_variant.as_u32()));
565+
566+
// Thanks to parameter attributes and load metadata, LLVM already knows
567+
// the general valid range of the tag. It's possible, though, for there
568+
// to be an impossible value *in the middle*, which those ranges don't
569+
// communicate, so it's worth an `assume` to let the optimizer know.
570+
// Most importantly, this means when optimizing a variant test like
571+
// `SELECT(is_niche, complex, CONST) == CONST` it's ok to simplify that
572+
// to `!is_niche` because the `complex` part can't possibly match.
573+
if niche_variants.contains(&untagged_variant)
574+
&& bx.cx().sess().opts.optimize != OptLevel::No
575+
{
576+
let ne = bx.icmp(IntPredicate::IntNE, tagged_discr, untagged_variant_const);
577+
bx.assume(ne);
578+
}
579+
580+
let discr = bx.select(is_niche, tagged_discr, untagged_variant_const);
558581

559582
// In principle we could insert assumes on the possible range of `discr`, but
560583
// currently in LLVM this isn't worth it because the original `tag` will
Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
//@ compile-flags: -Copt-level=3 -Zmerge-functions=disabled
2+
//@ min-llvm-version: 20
3+
//@ only-64bit
4+
5+
// The `derive(PartialEq)` on enums with field-less variants compares discriminants,
6+
// so make sure we emit that in some reasonable way.
7+
8+
#![crate_type = "lib"]
9+
#![feature(ascii_char)]
10+
#![feature(core_intrinsics)]
11+
#![feature(repr128)]
12+
13+
use std::ascii::Char as AC;
14+
use std::cmp::Ordering;
15+
use std::intrinsics::discriminant_value;
16+
use std::num::NonZero;
17+
18+
// A type that's bigger than `isize`, unlike the usual cases that have small tags.
19+
#[repr(u128)]
20+
pub enum Giant {
21+
Two = 2,
22+
Three = 3,
23+
Four = 4,
24+
}
25+
26+
#[unsafe(no_mangle)]
27+
pub fn opt_bool_eq_discr(a: Option<bool>, b: Option<bool>) -> bool {
28+
// CHECK-LABEL: @opt_bool_eq_discr(
29+
// CHECK: %[[A:.+]] = icmp ne i8 %a, 2
30+
// CHECK: %[[B:.+]] = icmp eq i8 %b, 2
31+
// CHECK: %[[R:.+]] = xor i1 %[[A]], %[[B]]
32+
// CHECK: ret i1 %[[R]]
33+
34+
discriminant_value(&a) == discriminant_value(&b)
35+
}
36+
37+
#[unsafe(no_mangle)]
38+
pub fn opt_ord_eq_discr(a: Option<Ordering>, b: Option<Ordering>) -> bool {
39+
// CHECK-LABEL: @opt_ord_eq_discr(
40+
// CHECK: %[[A:.+]] = icmp ne i8 %a, 2
41+
// CHECK: %[[B:.+]] = icmp eq i8 %b, 2
42+
// CHECK: %[[R:.+]] = xor i1 %[[A]], %[[B]]
43+
// CHECK: ret i1 %[[R]]
44+
45+
discriminant_value(&a) == discriminant_value(&b)
46+
}
47+
48+
#[unsafe(no_mangle)]
49+
pub fn opt_nz32_eq_discr(a: Option<NonZero<u32>>, b: Option<NonZero<u32>>) -> bool {
50+
// CHECK-LABEL: @opt_nz32_eq_discr(
51+
// CHECK: %[[A:.+]] = icmp ne i32 %a, 0
52+
// CHECK: %[[B:.+]] = icmp eq i32 %b, 0
53+
// CHECK: %[[R:.+]] = xor i1 %[[A]], %[[B]]
54+
// CHECK: ret i1 %[[R]]
55+
56+
discriminant_value(&a) == discriminant_value(&b)
57+
}
58+
59+
#[unsafe(no_mangle)]
60+
pub fn opt_ac_eq_discr(a: Option<AC>, b: Option<AC>) -> bool {
61+
// CHECK-LABEL: @opt_ac_eq_discr(
62+
// CHECK: %[[A:.+]] = icmp ne i8 %a, -128
63+
// CHECK: %[[B:.+]] = icmp eq i8 %b, -128
64+
// CHECK: %[[R:.+]] = xor i1 %[[A]], %[[B]]
65+
// CHECK: ret i1 %[[R]]
66+
67+
discriminant_value(&a) == discriminant_value(&b)
68+
}
69+
70+
#[unsafe(no_mangle)]
71+
pub fn opt_giant_eq_discr(a: Option<Giant>, b: Option<Giant>) -> bool {
72+
// CHECK-LABEL: @opt_giant_eq_discr(
73+
// CHECK: %[[A:.+]] = icmp ne i128 %a, 1
74+
// CHECK: %[[B:.+]] = icmp eq i128 %b, 1
75+
// CHECK: %[[R:.+]] = xor i1 %[[A]], %[[B]]
76+
// CHECK: ret i1 %[[R]]
77+
78+
discriminant_value(&a) == discriminant_value(&b)
79+
}
80+
81+
pub enum Mid<T> {
82+
Before,
83+
Thing(T),
84+
After,
85+
}
86+
87+
#[unsafe(no_mangle)]
88+
pub fn mid_bool_eq_discr(a: Mid<bool>, b: Mid<bool>) -> bool {
89+
// CHECK-LABEL: @mid_bool_eq_discr(
90+
91+
// CHECK: %[[A_REL_DISCR:.+]] = add nsw i8 %a, -2
92+
// CHECK: %[[A_IS_NICHE:.+]] = icmp samesign ugt i8 %a, 1
93+
// CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i8 %[[A_REL_DISCR]], 1
94+
// CHECK: tail call void @llvm.assume(i1 %[[A_NOT_HOLE]])
95+
// CHECK: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i8 %[[A_REL_DISCR]], i8 1
96+
97+
// CHECK: %[[B_REL_DISCR:.+]] = add nsw i8 %b, -2
98+
// CHECK: %[[B_IS_NICHE:.+]] = icmp samesign ugt i8 %b, 1
99+
// CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i8 %[[B_REL_DISCR]], 1
100+
// CHECK: tail call void @llvm.assume(i1 %[[B_NOT_HOLE]])
101+
// CHECK: %[[B_DISCR:.+]] = select i1 %[[B_IS_NICHE]], i8 %[[B_REL_DISCR]], i8 1
102+
103+
// CHECK: ret i1 %[[R]]
104+
discriminant_value(&a) == discriminant_value(&b)
105+
}
106+
107+
#[unsafe(no_mangle)]
108+
pub fn mid_ord_eq_discr(a: Mid<Ordering>, b: Mid<Ordering>) -> bool {
109+
// CHECK-LABEL: @mid_ord_eq_discr(
110+
111+
// CHECK: %[[A_REL_DISCR:.+]] = add nsw i8 %a, -2
112+
// CHECK: %[[A_IS_NICHE:.+]] = icmp sgt i8 %a, 1
113+
// CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i8 %[[A_REL_DISCR]], 1
114+
// CHECK: tail call void @llvm.assume(i1 %[[A_NOT_HOLE]])
115+
// CHECK: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i8 %[[A_REL_DISCR]], i8 1
116+
117+
// CHECK: %[[B_REL_DISCR:.+]] = add nsw i8 %b, -2
118+
// CHECK: %[[B_IS_NICHE:.+]] = icmp sgt i8 %b, 1
119+
// CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i8 %[[B_REL_DISCR]], 1
120+
// CHECK: tail call void @llvm.assume(i1 %[[B_NOT_HOLE]])
121+
// CHECK: %[[B_DISCR:.+]] = select i1 %[[B_IS_NICHE]], i8 %[[B_REL_DISCR]], i8 1
122+
123+
// CHECK: %[[R:.+]] = icmp eq i8 %[[A_DISCR]], %[[B_DISCR]]
124+
// CHECK: ret i1 %[[R]]
125+
discriminant_value(&a) == discriminant_value(&b)
126+
}
127+
128+
#[unsafe(no_mangle)]
129+
pub fn mid_nz32_eq_discr(a: Mid<NonZero<u32>>, b: Mid<NonZero<u32>>) -> bool {
130+
// CHECK-LABEL: @mid_nz32_eq_discr(
131+
// CHECK: %[[R:.+]] = icmp eq i32 %a.0, %b.0
132+
// CHECK: ret i1 %[[R]]
133+
discriminant_value(&a) == discriminant_value(&b)
134+
}
135+
136+
#[unsafe(no_mangle)]
137+
pub fn mid_ac_eq_discr(a: Mid<AC>, b: Mid<AC>) -> bool {
138+
// CHECK-LABEL: @mid_ac_eq_discr(
139+
140+
// CHECK: %[[A_REL_DISCR:.+]] = xor i8 %a, -128
141+
// CHECK: %[[A_IS_NICHE:.+]] = icmp slt i8 %a, 0
142+
// CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i8 %a, -127
143+
// CHECK: tail call void @llvm.assume(i1 %[[A_NOT_HOLE]])
144+
// CHECK: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i8 %[[A_REL_DISCR]], i8 1
145+
146+
// CHECK: %[[B_REL_DISCR:.+]] = xor i8 %b, -128
147+
// CHECK: %[[B_IS_NICHE:.+]] = icmp slt i8 %b, 0
148+
// CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i8 %b, -127
149+
// CHECK: tail call void @llvm.assume(i1 %[[B_NOT_HOLE]])
150+
// CHECK: %[[B_DISCR:.+]] = select i1 %[[B_IS_NICHE]], i8 %[[B_REL_DISCR]], i8 1
151+
152+
// CHECK: %[[R:.+]] = icmp eq i8 %[[A_DISCR]], %[[B_DISCR]]
153+
// CHECK: ret i1 %[[R]]
154+
discriminant_value(&a) == discriminant_value(&b)
155+
}
156+
157+
// FIXME: This should be improved once our LLVM fork picks up the fix for
158+
// <https://github.com/llvm/llvm-project/issues/134024>
159+
#[unsafe(no_mangle)]
160+
pub fn mid_giant_eq_discr(a: Mid<Giant>, b: Mid<Giant>) -> bool {
161+
// CHECK-LABEL: @mid_giant_eq_discr(
162+
163+
// CHECK: %[[A_TRUNC:.+]] = trunc nuw nsw i128 %a to i64
164+
// CHECK: %[[A_REL_DISCR:.+]] = add nsw i64 %[[A_TRUNC]], -5
165+
// CHECK: %[[A_IS_NICHE:.+]] = icmp samesign ugt i128 %a, 4
166+
// CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i64 %[[A_REL_DISCR]], 1
167+
// CHECK: tail call void @llvm.assume(i1 %[[A_NOT_HOLE]])
168+
// CHECK: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i64 %[[A_REL_DISCR]], i64 1
169+
170+
// CHECK: %[[B_TRUNC:.+]] = trunc nuw nsw i128 %b to i64
171+
// CHECK: %[[B_REL_DISCR:.+]] = add nsw i64 %[[B_TRUNC]], -5
172+
// CHECK: %[[B_IS_NICHE:.+]] = icmp samesign ugt i128 %b, 4
173+
// CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i64 %[[B_REL_DISCR]], 1
174+
// CHECK: tail call void @llvm.assume(i1 %[[B_NOT_HOLE]])
175+
// CHECK: %[[B_DISCR:.+]] = select i1 %[[B_IS_NICHE]], i64 %[[B_REL_DISCR]], i64 1
176+
177+
// CHECK: %[[R:.+]] = icmp eq i64 %[[A_DISCR]], %[[B_DISCR]]
178+
// CHECK: ret i1 %[[R]]
179+
discriminant_value(&a) == discriminant_value(&b)
180+
}
181+
182+
// In niche-encoded enums, testing for the untagged variant should optimize to a
183+
// straight-forward comparison looking for the natural range of the payload value.
184+
185+
#[unsafe(no_mangle)]
186+
pub fn mid_bool_is_thing(a: Mid<bool>) -> bool {
187+
// CHECK-LABEL: @mid_bool_is_thing(
188+
// CHECK: %[[R:.+]] = icmp samesign ult i8 %a, 2
189+
// CHECK: ret i1 %[[R]]
190+
discriminant_value(&a) == 1
191+
}
192+
193+
#[unsafe(no_mangle)]
194+
pub fn mid_ord_is_thing(a: Mid<Ordering>) -> bool {
195+
// CHECK-LABEL: @mid_ord_is_thing(
196+
// CHECK: %[[R:.+]] = icmp slt i8 %a, 2
197+
// CHECK: ret i1 %[[R]]
198+
discriminant_value(&a) == 1
199+
}
200+
201+
#[unsafe(no_mangle)]
202+
pub fn mid_nz32_is_thing(a: Mid<NonZero<u32>>) -> bool {
203+
// CHECK-LABEL: @mid_nz32_is_thing(
204+
// CHECK: %[[R:.+]] = icmp eq i32 %a.0, 1
205+
// CHECK: ret i1 %[[R]]
206+
discriminant_value(&a) == 1
207+
}
208+
209+
#[unsafe(no_mangle)]
210+
pub fn mid_ac_is_thing(a: Mid<AC>) -> bool {
211+
// CHECK-LABEL: @mid_ac_is_thing(
212+
// CHECK: %[[R:.+]] = icmp sgt i8 %a, -1
213+
// CHECK: ret i1 %[[R]]
214+
discriminant_value(&a) == 1
215+
}
216+
217+
#[unsafe(no_mangle)]
218+
pub fn mid_giant_is_thing(a: Mid<Giant>) -> bool {
219+
// CHECK-LABEL: @mid_giant_is_thing(
220+
// CHECK: %[[R:.+]] = icmp samesign ult i128 %a, 5
221+
// CHECK: ret i1 %[[R]]
222+
discriminant_value(&a) == 1
223+
}

0 commit comments

Comments
 (0)