Skip to content

Commit 26cb001

Browse files
committed
simd intrinsics with mask: accept unsigned integer masks, and fix some of the errors
1 parent 81d8edc commit 26cb001

18 files changed

+118
-150
lines changed

compiler/rustc_codegen_gcc/src/intrinsic/simd.rs

+10-16
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,11 @@ pub fn generic_simd_intrinsic<'a, 'gcc, 'tcx>(
445445
);
446446
match *m_elem_ty.kind() {
447447
ty::Int(_) => {}
448-
_ => return_error!(InvalidMonomorphization::MaskType { span, name, ty: m_elem_ty }),
448+
_ => return_error!(InvalidMonomorphization::MaskWrongElementType {
449+
span,
450+
name,
451+
ty: m_elem_ty
452+
}),
449453
}
450454
return Ok(bx.vector_select(args[0].immediate(), args[1].immediate(), args[2].immediate()));
451455
}
@@ -987,19 +991,14 @@ pub fn generic_simd_intrinsic<'a, 'gcc, 'tcx>(
987991
assert_eq!(pointer_count - 1, ptr_count(element_ty0));
988992
assert_eq!(underlying_ty, non_ptr(element_ty0));
989993

990-
// The element type of the third argument must be a signed integer type of any width:
994+
// The element type of the third argument must be an integer type of any width:
991995
let (_, element_ty2) = arg_tys[2].simd_size_and_type(bx.tcx());
992996
match *element_ty2.kind() {
993-
ty::Int(_) => (),
997+
ty::Int(_) | ty::Uint(_) => (),
994998
_ => {
995999
require!(
9961000
false,
997-
InvalidMonomorphization::ThirdArgElementType {
998-
span,
999-
name,
1000-
expected_element: element_ty2,
1001-
third_arg: arg_tys[2]
1002-
}
1001+
InvalidMonomorphization::MaskWrongElementType { span, name, ty: element_ty2 }
10031002
);
10041003
}
10051004
}
@@ -1106,16 +1105,11 @@ pub fn generic_simd_intrinsic<'a, 'gcc, 'tcx>(
11061105

11071106
// The element type of the third argument must be an integer type of any width:
11081107
match *element_ty2.kind() {
1109-
ty::Int(_) => (),
1108+
ty::Int(_) | ty::Uint(_) => (),
11101109
_ => {
11111110
require!(
11121111
false,
1113-
InvalidMonomorphization::ThirdArgElementType {
1114-
span,
1115-
name,
1116-
expected_element: element_ty2,
1117-
third_arg: arg_tys[2]
1118-
}
1112+
InvalidMonomorphization::MaskWrongElementType { span, name, ty: element_ty2 }
11191113
);
11201114
}
11211115
}

compiler/rustc_codegen_llvm/src/intrinsic.rs

+12-44
Original file line numberDiff line numberDiff line change
@@ -1180,18 +1180,6 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
11801180
}};
11811181
}
11821182

1183-
/// Returns the bitwidth of the `$ty` argument if it is an `Int` type.
1184-
macro_rules! require_int_ty {
1185-
($ty: expr, $diag: expr) => {
1186-
match $ty {
1187-
ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
1188-
_ => {
1189-
return_error!($diag);
1190-
}
1191-
}
1192-
};
1193-
}
1194-
11951183
/// Returns the bitwidth of the `$ty` argument if it is an `Int` or `Uint` type.
11961184
macro_rules! require_int_or_uint_ty {
11971185
($ty: expr, $diag: expr) => {
@@ -1472,9 +1460,9 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
14721460
m_len == v_len,
14731461
InvalidMonomorphization::MismatchedLengths { span, name, m_len, v_len }
14741462
);
1475-
let in_elem_bitwidth = require_int_ty!(
1463+
let in_elem_bitwidth = require_int_or_uint_ty!(
14761464
m_elem_ty.kind(),
1477-
InvalidMonomorphization::MaskType { span, name, ty: m_elem_ty }
1465+
InvalidMonomorphization::MaskWrongElementType { span, name, ty: m_elem_ty }
14781466
);
14791467
let m_i1s = vector_mask_to_bitmask(bx, args[0].immediate(), in_elem_bitwidth, m_len);
14801468
return Ok(bx.select(m_i1s, args[1].immediate(), args[2].immediate()));
@@ -1495,7 +1483,7 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
14951483
// Integer vector <i{in_bitwidth} x in_len>:
14961484
let in_elem_bitwidth = require_int_or_uint_ty!(
14971485
in_elem.kind(),
1498-
InvalidMonomorphization::VectorArgument { span, name, in_ty, in_elem }
1486+
InvalidMonomorphization::MaskWrongElementType { span, name, ty: in_elem }
14991487
);
15001488

15011489
let i1xn = vector_mask_to_bitmask(bx, args[0].immediate(), in_elem_bitwidth, in_len);
@@ -1719,14 +1707,9 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
17191707
}
17201708
);
17211709

1722-
let mask_elem_bitwidth = require_int_ty!(
1710+
let mask_elem_bitwidth = require_int_or_uint_ty!(
17231711
element_ty2.kind(),
1724-
InvalidMonomorphization::ThirdArgElementType {
1725-
span,
1726-
name,
1727-
expected_element: element_ty2,
1728-
third_arg: arg_tys[2]
1729-
}
1712+
InvalidMonomorphization::MaskWrongElementType { span, name, ty: element_ty2 }
17301713
);
17311714

17321715
// Alignment of T, must be a constant integer value:
@@ -1821,14 +1804,9 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
18211804
}
18221805
);
18231806

1824-
let m_elem_bitwidth = require_int_ty!(
1807+
let m_elem_bitwidth = require_int_or_uint_ty!(
18251808
mask_elem.kind(),
1826-
InvalidMonomorphization::ThirdArgElementType {
1827-
span,
1828-
name,
1829-
expected_element: values_elem,
1830-
third_arg: mask_ty,
1831-
}
1809+
InvalidMonomorphization::MaskWrongElementType { span, name, ty: mask_elem }
18321810
);
18331811

18341812
let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);
@@ -1911,14 +1889,9 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
19111889
}
19121890
);
19131891

1914-
let m_elem_bitwidth = require_int_ty!(
1892+
let m_elem_bitwidth = require_int_or_uint_ty!(
19151893
mask_elem.kind(),
1916-
InvalidMonomorphization::ThirdArgElementType {
1917-
span,
1918-
name,
1919-
expected_element: values_elem,
1920-
third_arg: mask_ty,
1921-
}
1894+
InvalidMonomorphization::MaskWrongElementType { span, name, ty: mask_elem }
19221895
);
19231896

19241897
let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);
@@ -2006,15 +1979,10 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
20061979
}
20071980
);
20081981

2009-
// The element type of the third argument must be a signed integer type of any width:
2010-
let mask_elem_bitwidth = require_int_ty!(
1982+
// The element type of the third argument must be an integer type of any width:
1983+
let mask_elem_bitwidth = require_int_or_uint_ty!(
20111984
element_ty2.kind(),
2012-
InvalidMonomorphization::ThirdArgElementType {
2013-
span,
2014-
name,
2015-
expected_element: element_ty2,
2016-
third_arg: arg_tys[2]
2017-
}
1985+
InvalidMonomorphization::MaskWrongElementType { span, name, ty: element_ty2 }
20181986
);
20191987

20201988
// Alignment of T, must be a constant integer value:

compiler/rustc_codegen_ssa/messages.ftl

+1-6
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,7 @@ codegen_ssa_invalid_monomorphization_inserted_type = invalid monomorphization of
119119
120120
codegen_ssa_invalid_monomorphization_invalid_bitmask = invalid monomorphization of `{$name}` intrinsic: invalid bitmask `{$mask_ty}`, expected `u{$expected_int_bits}` or `[u8; {$expected_bytes}]`
121121
122-
codegen_ssa_invalid_monomorphization_mask_type = invalid monomorphization of `{$name}` intrinsic: found mask element type is `{$ty}`, expected a signed integer type
123-
.note = the mask may be widened, which only has the correct behavior for signed integers
122+
codegen_ssa_invalid_monomorphization_mask_wrong_element_type = invalid monomorphization of `{$name}` intrinsic: found mask element type is `{$ty}`, expected an integer type
124123
125124
codegen_ssa_invalid_monomorphization_mismatched_lengths = invalid monomorphization of `{$name}` intrinsic: mismatched lengths: mask length `{$m_len}` != other vector length `{$v_len}`
126125
@@ -152,8 +151,6 @@ codegen_ssa_invalid_monomorphization_simd_shuffle = invalid monomorphization of
152151
153152
codegen_ssa_invalid_monomorphization_simd_third = invalid monomorphization of `{$name}` intrinsic: expected SIMD third type, found non-SIMD `{$ty}`
154153
155-
codegen_ssa_invalid_monomorphization_third_arg_element_type = invalid monomorphization of `{$name}` intrinsic: expected element type `{$expected_element}` of third argument `{$third_arg}` to be a signed integer type
156-
157154
codegen_ssa_invalid_monomorphization_third_argument_length = invalid monomorphization of `{$name}` intrinsic: expected third argument with length {$in_len} (same as input type `{$in_ty}`), found `{$arg_ty}` with length {$out_len}
158155
159156
codegen_ssa_invalid_monomorphization_unrecognized_intrinsic = invalid monomorphization of `{$name}` intrinsic: unrecognized intrinsic `{$name}`
@@ -166,8 +163,6 @@ codegen_ssa_invalid_monomorphization_unsupported_symbol = invalid monomorphizati
166163
167164
codegen_ssa_invalid_monomorphization_unsupported_symbol_of_size = invalid monomorphization of `{$name}` intrinsic: unsupported {$symbol} from `{$in_ty}` with element `{$in_elem}` of size `{$size}` to `{$ret_ty}`
168165
169-
codegen_ssa_invalid_monomorphization_vector_argument = invalid monomorphization of `{$name}` intrinsic: vector argument `{$in_ty}`'s element type `{$in_elem}`, expected integer element type
170-
171166
codegen_ssa_invalid_no_sanitize = invalid argument for `no_sanitize`
172167
.note = expected one of: `address`, `cfi`, `hwaddress`, `kcfi`, `memory`, `memtag`, `shadow-call-stack`, or `thread`
173168

compiler/rustc_codegen_ssa/src/errors.rs

+2-21
Original file line numberDiff line numberDiff line change
@@ -956,24 +956,14 @@ pub enum InvalidMonomorphization<'tcx> {
956956
v_len: u64,
957957
},
958958

959-
#[diag(codegen_ssa_invalid_monomorphization_mask_type, code = E0511)]
960-
#[note]
961-
MaskType {
959+
#[diag(codegen_ssa_invalid_monomorphization_mask_wrong_element_type, code = E0511)]
960+
MaskWrongElementType {
962961
#[primary_span]
963962
span: Span,
964963
name: Symbol,
965964
ty: Ty<'tcx>,
966965
},
967966

968-
#[diag(codegen_ssa_invalid_monomorphization_vector_argument, code = E0511)]
969-
VectorArgument {
970-
#[primary_span]
971-
span: Span,
972-
name: Symbol,
973-
in_ty: Ty<'tcx>,
974-
in_elem: Ty<'tcx>,
975-
},
976-
977967
#[diag(codegen_ssa_invalid_monomorphization_cannot_return, code = E0511)]
978968
CannotReturn {
979969
#[primary_span]
@@ -996,15 +986,6 @@ pub enum InvalidMonomorphization<'tcx> {
996986
mutability: ExpectedPointerMutability,
997987
},
998988

999-
#[diag(codegen_ssa_invalid_monomorphization_third_arg_element_type, code = E0511)]
1000-
ThirdArgElementType {
1001-
#[primary_span]
1002-
span: Span,
1003-
name: Symbol,
1004-
expected_element: Ty<'tcx>,
1005-
third_arg: Ty<'tcx>,
1006-
},
1007-
1008989
#[diag(codegen_ssa_invalid_monomorphization_unsupported_symbol_of_size, code = E0511)]
1009990
UnsupportedSymbolOfSize {
1010991
#[primary_span]

library/core/src/intrinsics/simd.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ pub unsafe fn simd_shuffle<T, U, V>(_x: T, _y: T, _idx: U) -> V;
271271
///
272272
/// `U` must be a vector of pointers to the element type of `T`, with the same length as `T`.
273273
///
274-
/// `V` must be a vector of signed integers with the same length as `T` (but any element size).
274+
/// `V` must be a vector of integers with the same length as `T` (but any element size).
275275
///
276276
/// For each pointer in `ptr`, if the corresponding value in `mask` is `!0`, read the pointer.
277277
/// Otherwise if the corresponding value in `mask` is `0`, return the corresponding value from
@@ -292,7 +292,7 @@ pub unsafe fn simd_gather<T, U, V>(_val: T, _ptr: U, _mask: V) -> T;
292292
///
293293
/// `U` must be a vector of pointers to the element type of `T`, with the same length as `T`.
294294
///
295-
/// `V` must be a vector of signed integers with the same length as `T` (but any element size).
295+
/// `V` must be a vector of integers with the same length as `T` (but any element size).
296296
///
297297
/// For each pointer in `ptr`, if the corresponding value in `mask` is `!0`, write the
298298
/// corresponding value in `val` to the pointer.
@@ -316,7 +316,7 @@ pub unsafe fn simd_scatter<T, U, V>(_val: T, _ptr: U, _mask: V);
316316
///
317317
/// `U` must be a pointer to the element type of `T`
318318
///
319-
/// `V` must be a vector of signed integers with the same length as `T` (but any element size).
319+
/// `V` must be a vector of integers with the same length as `T` (but any element size).
320320
///
321321
/// For each element, if the corresponding value in `mask` is `!0`, read the corresponding
322322
/// pointer offset from `ptr`.
@@ -339,7 +339,7 @@ pub unsafe fn simd_masked_load<V, U, T>(_mask: V, _ptr: U, _val: T) -> T;
339339
///
340340
/// `U` must be a pointer to the element type of `T`
341341
///
342-
/// `V` must be a vector of signed integers with the same length as `T` (but any element size).
342+
/// `V` must be a vector of integers with the same length as `T` (but any element size).
343343
///
344344
/// For each element, if the corresponding value in `mask` is `!0`, write the corresponding
345345
/// value in `val` to the pointer offset from `ptr`.
@@ -523,7 +523,7 @@ pub unsafe fn simd_bitmask<T, U>(_x: T) -> U;
523523
///
524524
/// `T` must be a vector.
525525
///
526-
/// `M` must be a signed integer vector with the same length as `T` (but any element size).
526+
/// `M` must be an integer vector with the same length as `T` (but any element size).
527527
///
528528
/// For each element, if the corresponding value in `mask` is `!0`, select the element from
529529
/// `if_true`. If the corresponding value in `mask` is `0`, select the element from

src/tools/miri/src/helpers.rs

+5
Original file line numberDiff line numberDiff line change
@@ -1295,6 +1295,11 @@ pub(crate) fn bool_to_simd_element(b: bool, size: Size) -> Scalar {
12951295
}
12961296

12971297
pub(crate) fn simd_element_to_bool(elem: ImmTy<'_>) -> InterpResult<'_, bool> {
1298+
assert!(
1299+
matches!(elem.layout.ty.kind(), ty::Int(_) | ty::Uint(_)),
1300+
"SIMD mask element type must be an integer, but this is `{}`",
1301+
elem.layout.ty
1302+
);
12981303
let val = elem.to_scalar().to_int(elem.layout.size)?;
12991304
interp_ok(match val {
13001305
0 => false,

tests/codegen/simd-intrinsic/simd-intrinsic-generic-gather.rs

+13
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,19 @@ pub unsafe fn gather_f32x2(
2929
simd_gather(values, pointers, mask)
3030
}
3131

32+
// CHECK-LABEL: @gather_f32x2_unsigned
33+
#[no_mangle]
34+
pub unsafe fn gather_f32x2_unsigned(
35+
pointers: Vec2<*const f32>,
36+
mask: Vec2<u32>,
37+
values: Vec2<f32>,
38+
) -> Vec2<f32> {
39+
// CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, {{<i32 31, i32 31>|splat \(i32 31\)}}
40+
// CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1>
41+
// CHECK: call <2 x float> @llvm.masked.gather.v2f32.v2p0(<2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> [[B]], <2 x float> {{.*}})
42+
simd_gather(values, pointers, mask)
43+
}
44+
3245
// CHECK-LABEL: @gather_pf32x2
3346
#[no_mangle]
3447
pub unsafe fn gather_pf32x2(

tests/codegen/simd-intrinsic/simd-intrinsic-generic-masked-load.rs

+13
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,19 @@ pub unsafe fn load_f32x2(mask: Vec2<i32>, pointer: *const f32, values: Vec2<f32>
2323
simd_masked_load(mask, pointer, values)
2424
}
2525

26+
// CHECK-LABEL: @load_f32x2_unsigned
27+
#[no_mangle]
28+
pub unsafe fn load_f32x2_unsigned(
29+
mask: Vec2<u32>,
30+
pointer: *const f32,
31+
values: Vec2<f32>,
32+
) -> Vec2<f32> {
33+
// CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, {{<i32 31, i32 31>|splat \(i32 31\)}}
34+
// CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1>
35+
// CHECK: call <2 x float> @llvm.masked.load.v2f32.p0(ptr {{.*}}, i32 4, <2 x i1> [[B]], <2 x float> {{.*}})
36+
simd_masked_load(mask, pointer, values)
37+
}
38+
2639
// CHECK-LABEL: @load_pf32x4
2740
#[no_mangle]
2841
pub unsafe fn load_pf32x4(

tests/codegen/simd-intrinsic/simd-intrinsic-generic-masked-store.rs

+9
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,15 @@ pub unsafe fn store_f32x2(mask: Vec2<i32>, pointer: *mut f32, values: Vec2<f32>)
2323
simd_masked_store(mask, pointer, values)
2424
}
2525

26+
// CHECK-LABEL: @store_f32x2_unsigned
27+
#[no_mangle]
28+
pub unsafe fn store_f32x2_unsigned(mask: Vec2<u32>, pointer: *mut f32, values: Vec2<f32>) {
29+
// CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, {{<i32 31, i32 31>|splat \(i32 31\)}}
30+
// CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1>
31+
// CHECK: call void @llvm.masked.store.v2f32.p0(<2 x float> {{.*}}, ptr {{.*}}, i32 4, <2 x i1> [[B]])
32+
simd_masked_store(mask, pointer, values)
33+
}
34+
2635
// CHECK-LABEL: @store_pf32x4
2736
#[no_mangle]
2837
pub unsafe fn store_pf32x4(mask: Vec4<i32>, pointer: *mut *const f32, values: Vec4<*const f32>) {

tests/codegen/simd-intrinsic/simd-intrinsic-generic-scatter.rs

+9
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@ pub unsafe fn scatter_f32x2(pointers: Vec2<*mut f32>, mask: Vec2<i32>, values: V
2525
simd_scatter(values, pointers, mask)
2626
}
2727

28+
// CHECK-LABEL: @scatter_f32x2_unsigned
29+
#[no_mangle]
30+
pub unsafe fn scatter_f32x2_unsigned(pointers: Vec2<*mut f32>, mask: Vec2<u32>, values: Vec2<f32>) {
31+
// CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, {{<i32 31, i32 31>|splat \(i32 31\)}}
32+
// CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1>
33+
// CHECK: call void @llvm.masked.scatter.v2f32.v2p0(<2 x float> {{.*}}, <2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> [[B]]
34+
simd_scatter(values, pointers, mask)
35+
}
36+
2837
// CHECK-LABEL: @scatter_pf32x2
2938
#[no_mangle]
3039
pub unsafe fn scatter_pf32x2(

tests/codegen/simd-intrinsic/simd-intrinsic-generic-select.rs

+13
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ pub struct b8x4(pub [i8; 4]);
2222
#[derive(Copy, Clone, PartialEq, Debug)]
2323
pub struct i32x4([i32; 4]);
2424

25+
#[repr(simd)]
26+
#[derive(Copy, Clone, PartialEq, Debug)]
27+
pub struct u32x4([u32; 4]);
28+
2529
// CHECK-LABEL: @select_m8
2630
#[no_mangle]
2731
pub unsafe fn select_m8(m: b8x4, a: f32x4, b: f32x4) -> f32x4 {
@@ -40,6 +44,15 @@ pub unsafe fn select_m32(m: i32x4, a: f32x4, b: f32x4) -> f32x4 {
4044
simd_select(m, a, b)
4145
}
4246

47+
// CHECK-LABEL: @select_m32_unsigned
48+
#[no_mangle]
49+
pub unsafe fn select_m32_unsigned(m: u32x4, a: f32x4, b: f32x4) -> f32x4 {
50+
// CHECK: [[A:%[0-9]+]] = lshr <4 x i32> %{{.*}}, {{<i32 31, i32 31, i32 31, i32 31>|splat \(i32 31\)}}
51+
// CHECK: [[B:%[0-9]+]] = trunc <4 x i32> [[A]] to <4 x i1>
52+
// CHECK: select <4 x i1> [[B]]
53+
simd_select(m, a, b)
54+
}
55+
4356
// CHECK-LABEL: @select_bitmask
4457
#[no_mangle]
4558
pub unsafe fn select_bitmask(m: i8, a: f32x8, b: f32x8) -> f32x8 {

0 commit comments

Comments
 (0)