Skip to content

simd intrinsics with mask: accept unsigned integer masks, and fix some of the errors #137953

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 11 additions & 14 deletions compiler/rustc_codegen_gcc/src/intrinsic/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -443,9 +443,14 @@ pub fn generic_simd_intrinsic<'a, 'gcc, 'tcx>(
m_len == v_len,
InvalidMonomorphization::MismatchedLengths { span, name, m_len, v_len }
);
// TODO: also support unsigned integers.
match *m_elem_ty.kind() {
ty::Int(_) => {}
_ => return_error!(InvalidMonomorphization::MaskType { span, name, ty: m_elem_ty }),
_ => return_error!(InvalidMonomorphization::MaskWrongElementType {
span,
name,
ty: m_elem_ty
}),
}
return Ok(bx.vector_select(args[0].immediate(), args[1].immediate(), args[2].immediate()));
}
Expand Down Expand Up @@ -987,19 +992,15 @@ pub fn generic_simd_intrinsic<'a, 'gcc, 'tcx>(
assert_eq!(pointer_count - 1, ptr_count(element_ty0));
assert_eq!(underlying_ty, non_ptr(element_ty0));

// The element type of the third argument must be a signed integer type of any width:
// The element type of the third argument must be an integer type of any width:
// TODO: also support unsigned integers.
let (_, element_ty2) = arg_tys[2].simd_size_and_type(bx.tcx());
match *element_ty2.kind() {
ty::Int(_) => (),
_ => {
require!(
false,
InvalidMonomorphization::ThirdArgElementType {
span,
name,
expected_element: element_ty2,
third_arg: arg_tys[2]
}
InvalidMonomorphization::MaskWrongElementType { span, name, ty: element_ty2 }
);
}
}
Expand Down Expand Up @@ -1105,17 +1106,13 @@ pub fn generic_simd_intrinsic<'a, 'gcc, 'tcx>(
assert_eq!(underlying_ty, non_ptr(element_ty0));

// The element type of the third argument must be a signed integer type of any width:
// TODO: also support unsigned integers.
match *element_ty2.kind() {
ty::Int(_) => (),
_ => {
require!(
false,
InvalidMonomorphization::ThirdArgElementType {
span,
name,
expected_element: element_ty2,
third_arg: arg_tys[2]
}
InvalidMonomorphization::MaskWrongElementType { span, name, ty: element_ty2 }
);
}
}
Expand Down
56 changes: 12 additions & 44 deletions compiler/rustc_codegen_llvm/src/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1184,18 +1184,6 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
}};
}

/// Returns the bitwidth of the `$ty` argument if it is an `Int` type.
macro_rules! require_int_ty {
($ty: expr, $diag: expr) => {
match $ty {
ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
_ => {
return_error!($diag);
}
}
};
}

/// Returns the bitwidth of the `$ty` argument if it is an `Int` or `Uint` type.
macro_rules! require_int_or_uint_ty {
($ty: expr, $diag: expr) => {
Expand Down Expand Up @@ -1476,9 +1464,9 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
m_len == v_len,
InvalidMonomorphization::MismatchedLengths { span, name, m_len, v_len }
);
let in_elem_bitwidth = require_int_ty!(
let in_elem_bitwidth = require_int_or_uint_ty!(
m_elem_ty.kind(),
InvalidMonomorphization::MaskType { span, name, ty: m_elem_ty }
InvalidMonomorphization::MaskWrongElementType { span, name, ty: m_elem_ty }
);
let m_i1s = vector_mask_to_bitmask(bx, args[0].immediate(), in_elem_bitwidth, m_len);
return Ok(bx.select(m_i1s, args[1].immediate(), args[2].immediate()));
Expand All @@ -1499,7 +1487,7 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
// Integer vector <i{in_bitwidth} x in_len>:
let in_elem_bitwidth = require_int_or_uint_ty!(
in_elem.kind(),
InvalidMonomorphization::VectorArgument { span, name, in_ty, in_elem }
InvalidMonomorphization::MaskWrongElementType { span, name, ty: in_elem }
);

let i1xn = vector_mask_to_bitmask(bx, args[0].immediate(), in_elem_bitwidth, in_len);
Expand Down Expand Up @@ -1723,14 +1711,9 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
}
);

let mask_elem_bitwidth = require_int_ty!(
let mask_elem_bitwidth = require_int_or_uint_ty!(
element_ty2.kind(),
InvalidMonomorphization::ThirdArgElementType {
span,
name,
expected_element: element_ty2,
third_arg: arg_tys[2]
}
InvalidMonomorphization::MaskWrongElementType { span, name, ty: element_ty2 }
);

// Alignment of T, must be a constant integer value:
Expand Down Expand Up @@ -1825,14 +1808,9 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
}
);

let m_elem_bitwidth = require_int_ty!(
let m_elem_bitwidth = require_int_or_uint_ty!(
mask_elem.kind(),
InvalidMonomorphization::ThirdArgElementType {
span,
name,
expected_element: values_elem,
third_arg: mask_ty,
}
InvalidMonomorphization::MaskWrongElementType { span, name, ty: mask_elem }
);

let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);
Expand Down Expand Up @@ -1915,14 +1893,9 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
}
);

let m_elem_bitwidth = require_int_ty!(
let m_elem_bitwidth = require_int_or_uint_ty!(
mask_elem.kind(),
InvalidMonomorphization::ThirdArgElementType {
span,
name,
expected_element: values_elem,
third_arg: mask_ty,
}
InvalidMonomorphization::MaskWrongElementType { span, name, ty: mask_elem }
);

let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);
Expand Down Expand Up @@ -2010,15 +1983,10 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
}
);

// The element type of the third argument must be a signed integer type of any width:
let mask_elem_bitwidth = require_int_ty!(
// The element type of the third argument must be an integer type of any width:
let mask_elem_bitwidth = require_int_or_uint_ty!(
element_ty2.kind(),
InvalidMonomorphization::ThirdArgElementType {
span,
name,
expected_element: element_ty2,
third_arg: arg_tys[2]
}
InvalidMonomorphization::MaskWrongElementType { span, name, ty: element_ty2 }
);

// Alignment of T, must be a constant integer value:
Expand Down
7 changes: 1 addition & 6 deletions compiler/rustc_codegen_ssa/messages.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,7 @@ codegen_ssa_invalid_monomorphization_inserted_type = invalid monomorphization of

codegen_ssa_invalid_monomorphization_invalid_bitmask = invalid monomorphization of `{$name}` intrinsic: invalid bitmask `{$mask_ty}`, expected `u{$expected_int_bits}` or `[u8; {$expected_bytes}]`

codegen_ssa_invalid_monomorphization_mask_type = invalid monomorphization of `{$name}` intrinsic: found mask element type is `{$ty}`, expected a signed integer type
.note = the mask may be widened, which only has the correct behavior for signed integers
codegen_ssa_invalid_monomorphization_mask_wrong_element_type = invalid monomorphization of `{$name}` intrinsic: expected mask element type to be an integer, found `{$ty}`

codegen_ssa_invalid_monomorphization_mismatched_lengths = invalid monomorphization of `{$name}` intrinsic: mismatched lengths: mask length `{$m_len}` != other vector length `{$v_len}`

Expand Down Expand Up @@ -166,8 +165,6 @@ codegen_ssa_invalid_monomorphization_simd_shuffle = invalid monomorphization of

codegen_ssa_invalid_monomorphization_simd_third = invalid monomorphization of `{$name}` intrinsic: expected SIMD third type, found non-SIMD `{$ty}`

codegen_ssa_invalid_monomorphization_third_arg_element_type = invalid monomorphization of `{$name}` intrinsic: expected element type `{$expected_element}` of third argument `{$third_arg}` to be a signed integer type

codegen_ssa_invalid_monomorphization_third_argument_length = invalid monomorphization of `{$name}` intrinsic: expected third argument with length {$in_len} (same as input type `{$in_ty}`), found `{$arg_ty}` with length {$out_len}

codegen_ssa_invalid_monomorphization_unrecognized_intrinsic = invalid monomorphization of `{$name}` intrinsic: unrecognized intrinsic `{$name}`
Expand All @@ -180,8 +177,6 @@ codegen_ssa_invalid_monomorphization_unsupported_symbol = invalid monomorphizati

codegen_ssa_invalid_monomorphization_unsupported_symbol_of_size = invalid monomorphization of `{$name}` intrinsic: unsupported {$symbol} from `{$in_ty}` with element `{$in_elem}` of size `{$size}` to `{$ret_ty}`

codegen_ssa_invalid_monomorphization_vector_argument = invalid monomorphization of `{$name}` intrinsic: vector argument `{$in_ty}`'s element type `{$in_elem}`, expected integer element type

codegen_ssa_invalid_no_sanitize = invalid argument for `no_sanitize`
.note = expected one of: `address`, `cfi`, `hwaddress`, `kcfi`, `memory`, `memtag`, `shadow-call-stack`, or `thread`

Expand Down
23 changes: 2 additions & 21 deletions compiler/rustc_codegen_ssa/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1059,24 +1059,14 @@ pub enum InvalidMonomorphization<'tcx> {
v_len: u64,
},

#[diag(codegen_ssa_invalid_monomorphization_mask_type, code = E0511)]
#[note]
MaskType {
#[diag(codegen_ssa_invalid_monomorphization_mask_wrong_element_type, code = E0511)]
MaskWrongElementType {
#[primary_span]
span: Span,
name: Symbol,
ty: Ty<'tcx>,
},

#[diag(codegen_ssa_invalid_monomorphization_vector_argument, code = E0511)]
VectorArgument {
#[primary_span]
span: Span,
name: Symbol,
in_ty: Ty<'tcx>,
in_elem: Ty<'tcx>,
},

#[diag(codegen_ssa_invalid_monomorphization_cannot_return, code = E0511)]
CannotReturn {
#[primary_span]
Expand All @@ -1099,15 +1089,6 @@ pub enum InvalidMonomorphization<'tcx> {
mutability: ExpectedPointerMutability,
},

#[diag(codegen_ssa_invalid_monomorphization_third_arg_element_type, code = E0511)]
ThirdArgElementType {
#[primary_span]
span: Span,
name: Symbol,
expected_element: Ty<'tcx>,
third_arg: Ty<'tcx>,
},

#[diag(codegen_ssa_invalid_monomorphization_unsupported_symbol_of_size, code = E0511)]
UnsupportedSymbolOfSize {
#[primary_span]
Expand Down
10 changes: 5 additions & 5 deletions library/core/src/intrinsics/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ pub unsafe fn simd_shuffle<T, U, V>(x: T, y: T, idx: U) -> V;
///
/// `U` must be a vector of pointers to the element type of `T`, with the same length as `T`.
///
/// `V` must be a vector of signed integers with the same length as `T` (but any element size).
/// `V` must be a vector of integers with the same length as `T` (but any element size).
///
/// For each pointer in `ptr`, if the corresponding value in `mask` is `!0`, read the pointer.
/// Otherwise if the corresponding value in `mask` is `0`, return the corresponding value from
Expand All @@ -292,7 +292,7 @@ pub unsafe fn simd_gather<T, U, V>(val: T, ptr: U, mask: V) -> T;
///
/// `U` must be a vector of pointers to the element type of `T`, with the same length as `T`.
///
/// `V` must be a vector of signed integers with the same length as `T` (but any element size).
/// `V` must be a vector of integers with the same length as `T` (but any element size).
///
/// For each pointer in `ptr`, if the corresponding value in `mask` is `!0`, write the
/// corresponding value in `val` to the pointer.
Expand All @@ -316,7 +316,7 @@ pub unsafe fn simd_scatter<T, U, V>(val: T, ptr: U, mask: V);
///
/// `U` must be a pointer to the element type of `T`
///
/// `V` must be a vector of signed integers with the same length as `T` (but any element size).
/// `V` must be a vector of integers with the same length as `T` (but any element size).
///
/// For each element, if the corresponding value in `mask` is `!0`, read the corresponding
/// pointer offset from `ptr`.
Expand All @@ -339,7 +339,7 @@ pub unsafe fn simd_masked_load<V, U, T>(mask: V, ptr: U, val: T) -> T;
///
/// `U` must be a pointer to the element type of `T`
///
/// `V` must be a vector of signed integers with the same length as `T` (but any element size).
/// `V` must be a vector of integers with the same length as `T` (but any element size).
///
/// For each element, if the corresponding value in `mask` is `!0`, write the corresponding
/// value in `val` to the pointer offset from `ptr`.
Expand Down Expand Up @@ -523,7 +523,7 @@ pub unsafe fn simd_bitmask<T, U>(x: T) -> U;
///
/// `T` must be a vector.
///
/// `M` must be a signed integer vector with the same length as `T` (but any element size).
/// `M` must be an integer vector with the same length as `T` (but any element size).
///
/// For each element, if the corresponding value in `mask` is `!0`, select the element from
/// `if_true`. If the corresponding value in `mask` is `0`, select the element from
Expand Down
5 changes: 5 additions & 0 deletions src/tools/miri/src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1295,6 +1295,11 @@ pub(crate) fn bool_to_simd_element(b: bool, size: Size) -> Scalar {
}

pub(crate) fn simd_element_to_bool(elem: ImmTy<'_>) -> InterpResult<'_, bool> {
assert!(
matches!(elem.layout.ty.kind(), ty::Int(_) | ty::Uint(_)),
"SIMD mask element type must be an integer, but this is `{}`",
elem.layout.ty
);
let val = elem.to_scalar().to_int(elem.layout.size)?;
interp_ok(match val {
0 => false,
Expand Down
13 changes: 13 additions & 0 deletions tests/codegen/simd-intrinsic/simd-intrinsic-generic-gather.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,19 @@ pub unsafe fn gather_f32x2(
simd_gather(values, pointers, mask)
}

// CHECK-LABEL: @gather_f32x2_unsigned
#[no_mangle]
pub unsafe fn gather_f32x2_unsigned(
pointers: Vec2<*const f32>,
mask: Vec2<u32>,
values: Vec2<f32>,
) -> Vec2<f32> {
// CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, {{<i32 31, i32 31>|splat \(i32 31\)}}
// CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1>
// CHECK: call <2 x float> @llvm.masked.gather.v2f32.v2p0(<2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> [[B]], <2 x float> {{.*}})
simd_gather(values, pointers, mask)
}

// CHECK-LABEL: @gather_pf32x2
#[no_mangle]
pub unsafe fn gather_pf32x2(
Expand Down
13 changes: 13 additions & 0 deletions tests/codegen/simd-intrinsic/simd-intrinsic-generic-masked-load.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,19 @@ pub unsafe fn load_f32x2(mask: Vec2<i32>, pointer: *const f32, values: Vec2<f32>
simd_masked_load(mask, pointer, values)
}

// CHECK-LABEL: @load_f32x2_unsigned
#[no_mangle]
pub unsafe fn load_f32x2_unsigned(
mask: Vec2<u32>,
pointer: *const f32,
values: Vec2<f32>,
) -> Vec2<f32> {
// CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, {{<i32 31, i32 31>|splat \(i32 31\)}}
// CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1>
// CHECK: call <2 x float> @llvm.masked.load.v2f32.p0(ptr {{.*}}, i32 4, <2 x i1> [[B]], <2 x float> {{.*}})
simd_masked_load(mask, pointer, values)
}

// CHECK-LABEL: @load_pf32x4
#[no_mangle]
pub unsafe fn load_pf32x4(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@ pub unsafe fn store_f32x2(mask: Vec2<i32>, pointer: *mut f32, values: Vec2<f32>)
simd_masked_store(mask, pointer, values)
}

// CHECK-LABEL: @store_f32x2_unsigned
#[no_mangle]
pub unsafe fn store_f32x2_unsigned(mask: Vec2<u32>, pointer: *mut f32, values: Vec2<f32>) {
// CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, {{<i32 31, i32 31>|splat \(i32 31\)}}
// CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1>
// CHECK: call void @llvm.masked.store.v2f32.p0(<2 x float> {{.*}}, ptr {{.*}}, i32 4, <2 x i1> [[B]])
simd_masked_store(mask, pointer, values)
}

// CHECK-LABEL: @store_pf32x4
#[no_mangle]
pub unsafe fn store_pf32x4(mask: Vec4<i32>, pointer: *mut *const f32, values: Vec4<*const f32>) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ pub unsafe fn scatter_f32x2(pointers: Vec2<*mut f32>, mask: Vec2<i32>, values: V
simd_scatter(values, pointers, mask)
}

// CHECK-LABEL: @scatter_f32x2_unsigned
#[no_mangle]
pub unsafe fn scatter_f32x2_unsigned(pointers: Vec2<*mut f32>, mask: Vec2<u32>, values: Vec2<f32>) {
// CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, {{<i32 31, i32 31>|splat \(i32 31\)}}
// CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1>
// CHECK: call void @llvm.masked.scatter.v2f32.v2p0(<2 x float> {{.*}}, <2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> [[B]]
simd_scatter(values, pointers, mask)
}

// CHECK-LABEL: @scatter_pf32x2
#[no_mangle]
pub unsafe fn scatter_pf32x2(
Expand Down
13 changes: 13 additions & 0 deletions tests/codegen/simd-intrinsic/simd-intrinsic-generic-select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ pub struct b8x4(pub [i8; 4]);
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct i32x4([i32; 4]);

#[repr(simd)]
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct u32x4([u32; 4]);

// CHECK-LABEL: @select_m8
#[no_mangle]
pub unsafe fn select_m8(m: b8x4, a: f32x4, b: f32x4) -> f32x4 {
Expand All @@ -40,6 +44,15 @@ pub unsafe fn select_m32(m: i32x4, a: f32x4, b: f32x4) -> f32x4 {
simd_select(m, a, b)
}

// CHECK-LABEL: @select_m32_unsigned
#[no_mangle]
pub unsafe fn select_m32_unsigned(m: u32x4, a: f32x4, b: f32x4) -> f32x4 {
// CHECK: [[A:%[0-9]+]] = lshr <4 x i32> %{{.*}}, {{<i32 31, i32 31, i32 31, i32 31>|splat \(i32 31\)}}
// CHECK: [[B:%[0-9]+]] = trunc <4 x i32> [[A]] to <4 x i1>
// CHECK: select <4 x i1> [[B]]
simd_select(m, a, b)
}

// CHECK-LABEL: @select_bitmask
#[no_mangle]
pub unsafe fn select_bitmask(m: i8, a: f32x8, b: f32x8) -> f32x8 {
Expand Down
Loading
Loading