Skip to content

Commit 299b5e2

Browse files
Do not emit undefined lshr/ashr for Neon shifts (#1238)
1 parent 05e32d4 commit 299b5e2

File tree

4 files changed

+66
-21
lines changed

4 files changed

+66
-21
lines changed

crates/core_arch/src/aarch64/neon/mod.rs

+10-4
Original file line numberDiff line numberDiff line change
@@ -2772,7 +2772,8 @@ pub unsafe fn vshld_n_u64<const N: i32>(a: u64) -> u64 {
27722772
#[rustc_legacy_const_generics(1)]
27732773
pub unsafe fn vshrd_n_s64<const N: i32>(a: i64) -> i64 {
27742774
static_assert!(N : i32 where N >= 1 && N <= 64);
2775-
a >> N
2775+
let n: i32 = if N == 64 { 63 } else { N };
2776+
a >> n
27762777
}
27772778

27782779
/// Unsigned shift right
@@ -2782,7 +2783,12 @@ pub unsafe fn vshrd_n_s64<const N: i32>(a: i64) -> i64 {
27822783
#[rustc_legacy_const_generics(1)]
27832784
pub unsafe fn vshrd_n_u64<const N: i32>(a: u64) -> u64 {
27842785
static_assert!(N : i32 where N >= 1 && N <= 64);
2785-
a >> N
2786+
let n: i32 = if N == 64 {
2787+
return 0;
2788+
} else {
2789+
N
2790+
};
2791+
a >> n
27862792
}
27872793

27882794
/// Signed shift right and accumulate
@@ -2792,7 +2798,7 @@ pub unsafe fn vshrd_n_u64<const N: i32>(a: u64) -> u64 {
27922798
#[rustc_legacy_const_generics(2)]
27932799
pub unsafe fn vsrad_n_s64<const N: i32>(a: i64, b: i64) -> i64 {
27942800
static_assert!(N : i32 where N >= 1 && N <= 64);
2795-
a + (b >> N)
2801+
a + vshrd_n_s64::<N>(b)
27962802
}
27972803

27982804
/// Unsigned shift right and accumulate
@@ -2802,7 +2808,7 @@ pub unsafe fn vsrad_n_s64<const N: i32>(a: i64, b: i64) -> i64 {
28022808
#[rustc_legacy_const_generics(2)]
28032809
pub unsafe fn vsrad_n_u64<const N: i32>(a: u64, b: u64) -> u64 {
28042810
static_assert!(N : i32 where N >= 1 && N <= 64);
2805-
a + (b >> N)
2811+
a + vshrd_n_u64::<N>(b)
28062812
}
28072813

28082814
/// Shift Left and Insert (immediate)

crates/core_arch/src/arm_shared/neon/generated.rs

+32-16
Original file line numberDiff line numberDiff line change
@@ -21987,7 +21987,8 @@ pub unsafe fn vshll_n_u32<const N: i32>(a: uint32x2_t) -> uint64x2_t {
2198721987
#[rustc_legacy_const_generics(1)]
2198821988
pub unsafe fn vshr_n_s8<const N: i32>(a: int8x8_t) -> int8x8_t {
2198921989
static_assert!(N : i32 where N >= 1 && N <= 8);
21990-
simd_shr(a, vdup_n_s8(N.try_into().unwrap()))
21990+
let n: i32 = if N == 8 { 7 } else { N };
21991+
simd_shr(a, vdup_n_s8(n.try_into().unwrap()))
2199121992
}
2199221993

2199321994
/// Shift right
@@ -21999,7 +22000,8 @@ pub unsafe fn vshr_n_s8<const N: i32>(a: int8x8_t) -> int8x8_t {
2199922000
#[rustc_legacy_const_generics(1)]
2200022001
pub unsafe fn vshrq_n_s8<const N: i32>(a: int8x16_t) -> int8x16_t {
2200122002
static_assert!(N : i32 where N >= 1 && N <= 8);
22002-
simd_shr(a, vdupq_n_s8(N.try_into().unwrap()))
22003+
let n: i32 = if N == 8 { 7 } else { N };
22004+
simd_shr(a, vdupq_n_s8(n.try_into().unwrap()))
2200322005
}
2200422006

2200522007
/// Shift right
@@ -22011,7 +22013,8 @@ pub unsafe fn vshrq_n_s8<const N: i32>(a: int8x16_t) -> int8x16_t {
2201122013
#[rustc_legacy_const_generics(1)]
2201222014
pub unsafe fn vshr_n_s16<const N: i32>(a: int16x4_t) -> int16x4_t {
2201322015
static_assert!(N : i32 where N >= 1 && N <= 16);
22014-
simd_shr(a, vdup_n_s16(N.try_into().unwrap()))
22016+
let n: i32 = if N == 16 { 15 } else { N };
22017+
simd_shr(a, vdup_n_s16(n.try_into().unwrap()))
2201522018
}
2201622019

2201722020
/// Shift right
@@ -22023,7 +22026,8 @@ pub unsafe fn vshr_n_s16<const N: i32>(a: int16x4_t) -> int16x4_t {
2202322026
#[rustc_legacy_const_generics(1)]
2202422027
pub unsafe fn vshrq_n_s16<const N: i32>(a: int16x8_t) -> int16x8_t {
2202522028
static_assert!(N : i32 where N >= 1 && N <= 16);
22026-
simd_shr(a, vdupq_n_s16(N.try_into().unwrap()))
22029+
let n: i32 = if N == 16 { 15 } else { N };
22030+
simd_shr(a, vdupq_n_s16(n.try_into().unwrap()))
2202722031
}
2202822032

2202922033
/// Shift right
@@ -22035,7 +22039,8 @@ pub unsafe fn vshrq_n_s16<const N: i32>(a: int16x8_t) -> int16x8_t {
2203522039
#[rustc_legacy_const_generics(1)]
2203622040
pub unsafe fn vshr_n_s32<const N: i32>(a: int32x2_t) -> int32x2_t {
2203722041
static_assert!(N : i32 where N >= 1 && N <= 32);
22038-
simd_shr(a, vdup_n_s32(N.try_into().unwrap()))
22042+
let n: i32 = if N == 32 { 31 } else { N };
22043+
simd_shr(a, vdup_n_s32(n.try_into().unwrap()))
2203922044
}
2204022045

2204122046
/// Shift right
@@ -22047,7 +22052,8 @@ pub unsafe fn vshr_n_s32<const N: i32>(a: int32x2_t) -> int32x2_t {
2204722052
#[rustc_legacy_const_generics(1)]
2204822053
pub unsafe fn vshrq_n_s32<const N: i32>(a: int32x4_t) -> int32x4_t {
2204922054
static_assert!(N : i32 where N >= 1 && N <= 32);
22050-
simd_shr(a, vdupq_n_s32(N.try_into().unwrap()))
22055+
let n: i32 = if N == 32 { 31 } else { N };
22056+
simd_shr(a, vdupq_n_s32(n.try_into().unwrap()))
2205122057
}
2205222058

2205322059
/// Shift right
@@ -22059,7 +22065,8 @@ pub unsafe fn vshrq_n_s32<const N: i32>(a: int32x4_t) -> int32x4_t {
2205922065
#[rustc_legacy_const_generics(1)]
2206022066
pub unsafe fn vshr_n_s64<const N: i32>(a: int64x1_t) -> int64x1_t {
2206122067
static_assert!(N : i32 where N >= 1 && N <= 64);
22062-
simd_shr(a, vdup_n_s64(N.try_into().unwrap()))
22068+
let n: i32 = if N == 64 { 63 } else { N };
22069+
simd_shr(a, vdup_n_s64(n.try_into().unwrap()))
2206322070
}
2206422071

2206522072
/// Shift right
@@ -22071,7 +22078,8 @@ pub unsafe fn vshr_n_s64<const N: i32>(a: int64x1_t) -> int64x1_t {
2207122078
#[rustc_legacy_const_generics(1)]
2207222079
pub unsafe fn vshrq_n_s64<const N: i32>(a: int64x2_t) -> int64x2_t {
2207322080
static_assert!(N : i32 where N >= 1 && N <= 64);
22074-
simd_shr(a, vdupq_n_s64(N.try_into().unwrap()))
22081+
let n: i32 = if N == 64 { 63 } else { N };
22082+
simd_shr(a, vdupq_n_s64(n.try_into().unwrap()))
2207522083
}
2207622084

2207722085
/// Shift right
@@ -22083,7 +22091,8 @@ pub unsafe fn vshrq_n_s64<const N: i32>(a: int64x2_t) -> int64x2_t {
2208322091
#[rustc_legacy_const_generics(1)]
2208422092
pub unsafe fn vshr_n_u8<const N: i32>(a: uint8x8_t) -> uint8x8_t {
2208522093
static_assert!(N : i32 where N >= 1 && N <= 8);
22086-
simd_shr(a, vdup_n_u8(N.try_into().unwrap()))
22094+
let n: i32 = if N == 8 { return vdup_n_u8(0); } else { N };
22095+
simd_shr(a, vdup_n_u8(n.try_into().unwrap()))
2208722096
}
2208822097

2208922098
/// Shift right
@@ -22095,7 +22104,8 @@ pub unsafe fn vshr_n_u8<const N: i32>(a: uint8x8_t) -> uint8x8_t {
2209522104
#[rustc_legacy_const_generics(1)]
2209622105
pub unsafe fn vshrq_n_u8<const N: i32>(a: uint8x16_t) -> uint8x16_t {
2209722106
static_assert!(N : i32 where N >= 1 && N <= 8);
22098-
simd_shr(a, vdupq_n_u8(N.try_into().unwrap()))
22107+
let n: i32 = if N == 8 { return vdupq_n_u8(0); } else { N };
22108+
simd_shr(a, vdupq_n_u8(n.try_into().unwrap()))
2209922109
}
2210022110

2210122111
/// Shift right
@@ -22107,7 +22117,8 @@ pub unsafe fn vshrq_n_u8<const N: i32>(a: uint8x16_t) -> uint8x16_t {
2210722117
#[rustc_legacy_const_generics(1)]
2210822118
pub unsafe fn vshr_n_u16<const N: i32>(a: uint16x4_t) -> uint16x4_t {
2210922119
static_assert!(N : i32 where N >= 1 && N <= 16);
22110-
simd_shr(a, vdup_n_u16(N.try_into().unwrap()))
22120+
let n: i32 = if N == 16 { return vdup_n_u16(0); } else { N };
22121+
simd_shr(a, vdup_n_u16(n.try_into().unwrap()))
2211122122
}
2211222123

2211322124
/// Shift right
@@ -22119,7 +22130,8 @@ pub unsafe fn vshr_n_u16<const N: i32>(a: uint16x4_t) -> uint16x4_t {
2211922130
#[rustc_legacy_const_generics(1)]
2212022131
pub unsafe fn vshrq_n_u16<const N: i32>(a: uint16x8_t) -> uint16x8_t {
2212122132
static_assert!(N : i32 where N >= 1 && N <= 16);
22122-
simd_shr(a, vdupq_n_u16(N.try_into().unwrap()))
22133+
let n: i32 = if N == 16 { return vdupq_n_u16(0); } else { N };
22134+
simd_shr(a, vdupq_n_u16(n.try_into().unwrap()))
2212322135
}
2212422136

2212522137
/// Shift right
@@ -22131,7 +22143,8 @@ pub unsafe fn vshrq_n_u16<const N: i32>(a: uint16x8_t) -> uint16x8_t {
2213122143
#[rustc_legacy_const_generics(1)]
2213222144
pub unsafe fn vshr_n_u32<const N: i32>(a: uint32x2_t) -> uint32x2_t {
2213322145
static_assert!(N : i32 where N >= 1 && N <= 32);
22134-
simd_shr(a, vdup_n_u32(N.try_into().unwrap()))
22146+
let n: i32 = if N == 32 { return vdup_n_u32(0); } else { N };
22147+
simd_shr(a, vdup_n_u32(n.try_into().unwrap()))
2213522148
}
2213622149

2213722150
/// Shift right
@@ -22143,7 +22156,8 @@ pub unsafe fn vshr_n_u32<const N: i32>(a: uint32x2_t) -> uint32x2_t {
2214322156
#[rustc_legacy_const_generics(1)]
2214422157
pub unsafe fn vshrq_n_u32<const N: i32>(a: uint32x4_t) -> uint32x4_t {
2214522158
static_assert!(N : i32 where N >= 1 && N <= 32);
22146-
simd_shr(a, vdupq_n_u32(N.try_into().unwrap()))
22159+
let n: i32 = if N == 32 { return vdupq_n_u32(0); } else { N };
22160+
simd_shr(a, vdupq_n_u32(n.try_into().unwrap()))
2214722161
}
2214822162

2214922163
/// Shift right
@@ -22155,7 +22169,8 @@ pub unsafe fn vshrq_n_u32<const N: i32>(a: uint32x4_t) -> uint32x4_t {
2215522169
#[rustc_legacy_const_generics(1)]
2215622170
pub unsafe fn vshr_n_u64<const N: i32>(a: uint64x1_t) -> uint64x1_t {
2215722171
static_assert!(N : i32 where N >= 1 && N <= 64);
22158-
simd_shr(a, vdup_n_u64(N.try_into().unwrap()))
22172+
let n: i32 = if N == 64 { return vdup_n_u64(0); } else { N };
22173+
simd_shr(a, vdup_n_u64(n.try_into().unwrap()))
2215922174
}
2216022175

2216122176
/// Shift right
@@ -22167,7 +22182,8 @@ pub unsafe fn vshr_n_u64<const N: i32>(a: uint64x1_t) -> uint64x1_t {
2216722182
#[rustc_legacy_const_generics(1)]
2216822183
pub unsafe fn vshrq_n_u64<const N: i32>(a: uint64x2_t) -> uint64x2_t {
2216922184
static_assert!(N : i32 where N >= 1 && N <= 64);
22170-
simd_shr(a, vdupq_n_u64(N.try_into().unwrap()))
22185+
let n: i32 = if N == 64 { return vdupq_n_u64(0); } else { N };
22186+
simd_shr(a, vdupq_n_u64(n.try_into().unwrap()))
2217122187
}
2217222188

2217322189
/// Shift right narrow

crates/stdarch-gen/neon.spec

+2-1
Original file line numberDiff line numberDiff line change
@@ -6785,7 +6785,8 @@ name = vshr
67856785
n-suffix
67866786
constn = N
67876787
multi_fn = static_assert-N-1-bits
6788-
multi_fn = simd_shr, a, {vdup-nself-noext, N.try_into().unwrap()}
6788+
multi_fn = fix_right_shift_imm-N-bits
6789+
multi_fn = simd_shr, a, {vdup-nself-noext, n.try_into().unwrap()}
67896790
a = 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60, 64
67906791
n = 2
67916792
validate 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16

crates/stdarch-gen/src/main.rs

+22
Original file line numberDiff line numberDiff line change
@@ -2664,6 +2664,28 @@ fn get_call(
26642664
);
26652665
}
26662666
}
2667+
if fn_name.starts_with("fix_right_shift_imm") {
2668+
let fn_format: Vec<_> = fn_name.split('-').map(|v| v.to_string()).collect();
2669+
let lim = if fn_format[2] == "bits" {
2670+
type_bits(in_t[1]).to_string()
2671+
} else {
2672+
fn_format[2].clone()
2673+
};
2674+
let fixed = if in_t[1].starts_with('u') {
2675+
format!("return vdup{nself}(0);", nself = type_to_n_suffix(in_t[1]))
2676+
} else {
2677+
(lim.parse::<i32>().unwrap() - 1).to_string()
2678+
};
2679+
2680+
return format!(
2681+
r#"let {name}: i32 = if {const_name} == {upper} {{ {fixed} }} else {{ N }};"#,
2682+
name = fn_format[1].to_lowercase(),
2683+
const_name = fn_format[1],
2684+
upper = lim,
2685+
fixed = fixed,
2686+
);
2687+
}
2688+
26672689
if fn_name.starts_with("matchn") {
26682690
let fn_format: Vec<_> = fn_name.split('-').map(|v| v.to_string()).collect();
26692691
let len = match &*fn_format[1] {

0 commit comments

Comments
 (0)