Skip to content

Generate bools as bools instead of u8 #809

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 30, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 67 additions & 98 deletions crates/rustc_codegen_spirv/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,9 @@ pub(crate) fn provide(providers: &mut Providers) {
let TyAndLayout { ty, mut layout } =
(rustc_interface::DEFAULT_QUERY_PROVIDERS.layout_of)(tcx, key)?;

// FIXME(eddyb) make use of this - at this point, it's just a placeholder.
#[allow(clippy::match_single_binding)]
#[allow(clippy::match_like_matches_macro)]
let hide_niche = match ty.kind() {
ty::Bool => true,
_ => false,
};

Expand Down Expand Up @@ -284,13 +284,6 @@ enum PointeeDefState {
/// provides a uniform way of translating them.
pub trait ConvSpirvType<'tcx> {
fn spirv_type(&self, span: Span, cx: &CodegenCx<'tcx>) -> Word;
/// spirv (and llvm) do not allow storing booleans in memory, they are abstract unsized values.
/// So, if we're dealing with a "memory type", convert bool to u8. The opposite is an
/// "immediate type", which keeps bools as bools. See also the functions `from_immediate` and
/// `to_immediate`, which convert between the two.
fn spirv_type_immediate(&self, span: Span, cx: &CodegenCx<'tcx>) -> Word {
self.spirv_type(span, cx)
}
}

impl<'tcx> ConvSpirvType<'tcx> for PointeeTy<'tcx> {
Expand All @@ -302,14 +295,6 @@ impl<'tcx> ConvSpirvType<'tcx> for PointeeTy<'tcx> {
.spirv_type(span, cx),
}
}
fn spirv_type_immediate(&self, span: Span, cx: &CodegenCx<'tcx>) -> Word {
match *self {
PointeeTy::Ty(ty) => ty.spirv_type_immediate(span, cx),
PointeeTy::Fn(ty) => cx
.fn_abi_of_fn_ptr(ty, ty::List::empty())
.spirv_type_immediate(span, cx),
}
}
}

impl<'tcx> ConvSpirvType<'tcx> for FnAbi<'tcx, Ty<'tcx>> {
Expand All @@ -318,9 +303,7 @@ impl<'tcx> ConvSpirvType<'tcx> for FnAbi<'tcx, Ty<'tcx>> {

let return_type = match self.ret.mode {
PassMode::Ignore => SpirvType::Void.def(span, cx),
PassMode::Direct(_) | PassMode::Pair(..) => {
self.ret.layout.spirv_type_immediate(span, cx)
}
PassMode::Direct(_) | PassMode::Pair(..) => self.ret.layout.spirv_type(span, cx),
PassMode::Cast(_) | PassMode::Indirect { .. } => span_bug!(
span,
"query hooks should've made this `PassMode` impossible: {:#?}",
Expand All @@ -331,14 +314,10 @@ impl<'tcx> ConvSpirvType<'tcx> for FnAbi<'tcx, Ty<'tcx>> {
for arg in &self.args {
let arg_type = match arg.mode {
PassMode::Ignore => continue,
PassMode::Direct(_) => arg.layout.spirv_type_immediate(span, cx),
PassMode::Direct(_) => arg.layout.spirv_type(span, cx),
PassMode::Pair(_, _) => {
argument_types.push(scalar_pair_element_backend_type(
cx, span, arg.layout, 0, true,
));
argument_types.push(scalar_pair_element_backend_type(
cx, span, arg.layout, 1, true,
));
argument_types.push(scalar_pair_element_backend_type(cx, span, arg.layout, 0));
argument_types.push(scalar_pair_element_backend_type(cx, span, arg.layout, 1));
continue;
}
PassMode::Cast(_) | PassMode::Indirect { .. } => span_bug!(
Expand All @@ -359,77 +338,69 @@ impl<'tcx> ConvSpirvType<'tcx> for FnAbi<'tcx, Ty<'tcx>> {
}

impl<'tcx> ConvSpirvType<'tcx> for TyAndLayout<'tcx> {
fn spirv_type(&self, span: Span, cx: &CodegenCx<'tcx>) -> Word {
trans_type_impl(cx, span, *self, false)
}
fn spirv_type_immediate(&self, span: Span, cx: &CodegenCx<'tcx>) -> Word {
trans_type_impl(cx, span, *self, true)
}
}

fn trans_type_impl<'tcx>(
cx: &CodegenCx<'tcx>,
mut span: Span,
ty: TyAndLayout<'tcx>,
is_immediate: bool,
) -> Word {
if let TyKind::Adt(adt, substs) = *ty.ty.kind() {
if span == DUMMY_SP {
span = cx.tcx.def_span(adt.did);
}
fn spirv_type(&self, mut span: Span, cx: &CodegenCx<'tcx>) -> Word {
if let TyKind::Adt(adt, substs) = *self.ty.kind() {
if span == DUMMY_SP {
span = cx.tcx.def_span(adt.did);
}

let attrs = AggregatedSpirvAttributes::parse(cx, cx.tcx.get_attrs(adt.did));
let attrs = AggregatedSpirvAttributes::parse(cx, cx.tcx.get_attrs(adt.did));

if let Some(intrinsic_type_attr) = attrs.intrinsic_type.map(|attr| attr.value) {
if let Ok(spirv_type) = trans_intrinsic_type(cx, span, ty, substs, intrinsic_type_attr)
{
return spirv_type;
if let Some(intrinsic_type_attr) = attrs.intrinsic_type.map(|attr| attr.value) {
if let Ok(spirv_type) =
trans_intrinsic_type(cx, span, *self, substs, intrinsic_type_attr)
{
return spirv_type;
}
}
}
}

// Note: ty.layout is orthogonal to ty.ty, e.g. `ManuallyDrop<Result<isize, isize>>` has abi
// `ScalarPair`.
// There's a few layers that we go through here. First we inspect layout.abi, then if relevant, layout.fields, etc.
match ty.abi {
Abi::Uninhabited => SpirvType::Adt {
def_id: def_id_for_spirv_type_adt(ty),
size: Some(Size::ZERO),
align: Align::from_bytes(0).unwrap(),
field_types: Vec::new(),
field_offsets: Vec::new(),
field_names: None,
}
.def_with_name(cx, span, TyLayoutNameKey::from(ty)),
Abi::Scalar(ref scalar) => trans_scalar(cx, span, ty, scalar, Size::ZERO, is_immediate),
Abi::ScalarPair(ref a, ref b) => {
// Note: We can't use auto_struct_layout here because the spirv types here might be undefined due to
// recursive pointer types.
let a_offset = Size::ZERO;
let b_offset = a.value.size(cx).align_to(b.value.align(cx).abi);
// Note! Do not pass through is_immediate here - they're wrapped in a struct, hence, not immediate.
let a = trans_scalar(cx, span, ty, a, a_offset, false);
let b = trans_scalar(cx, span, ty, b, b_offset, false);
let size = if ty.is_unsized() { None } else { Some(ty.size) };
SpirvType::Adt {
def_id: def_id_for_spirv_type_adt(ty),
size,
align: ty.align.abi,
field_types: vec![a, b],
field_offsets: vec![a_offset, b_offset],
// Note: ty.layout is orthogonal to ty.ty, e.g. `ManuallyDrop<Result<isize, isize>>` has abi
// `ScalarPair`.
// There's a few layers that we go through here. First we inspect layout.abi, then if relevant, layout.fields, etc.
match self.abi {
Abi::Uninhabited => SpirvType::Adt {
def_id: def_id_for_spirv_type_adt(*self),
size: Some(Size::ZERO),
align: Align::from_bytes(0).unwrap(),
field_types: Vec::new(),
field_offsets: Vec::new(),
field_names: None,
}
.def_with_name(cx, span, TyLayoutNameKey::from(ty))
}
Abi::Vector { ref element, count } => {
let elem_spirv = trans_scalar(cx, span, ty, element, Size::ZERO, false);
SpirvType::Vector {
element: elem_spirv,
count: count as u32,
.def_with_name(cx, span, TyLayoutNameKey::from(*self)),
Abi::Scalar(ref scalar) => trans_scalar(cx, span, *self, scalar, Size::ZERO),
Abi::ScalarPair(ref a, ref b) => {
// Note: We can't use auto_struct_layout here because the spirv types here might be undefined due to
// recursive pointer types.
let a_offset = Size::ZERO;
let b_offset = a.value.size(cx).align_to(b.value.align(cx).abi);
let a = trans_scalar(cx, span, *self, a, a_offset);
let b = trans_scalar(cx, span, *self, b, b_offset);
let size = if self.is_unsized() {
None
} else {
Some(self.size)
};
SpirvType::Adt {
def_id: def_id_for_spirv_type_adt(*self),
size,
align: self.align.abi,
field_types: vec![a, b],
field_offsets: vec![a_offset, b_offset],
field_names: None,
}
.def_with_name(cx, span, TyLayoutNameKey::from(*self))
}
Abi::Vector { ref element, count } => {
let elem_spirv = trans_scalar(cx, span, *self, element, Size::ZERO);
SpirvType::Vector {
element: elem_spirv,
count: count as u32,
}
.def(span, cx)
}
.def(span, cx)
Abi::Aggregate { sized: _ } => trans_aggregate(cx, span, *self),
}
Abi::Aggregate { sized: _ } => trans_aggregate(cx, span, ty),
}
}

Expand All @@ -440,7 +411,6 @@ pub fn scalar_pair_element_backend_type<'tcx>(
span: Span,
ty: TyAndLayout<'tcx>,
index: usize,
is_immediate: bool,
) -> Word {
let [a, b] = match &ty.layout.abi {
Abi::ScalarPair(a, b) => [a, b],
Expand All @@ -455,7 +425,7 @@ pub fn scalar_pair_element_backend_type<'tcx>(
1 => a.value.size(cx).align_to(b.value.align(cx).abi),
_ => unreachable!(),
};
trans_scalar(cx, span, ty, [a, b][index], offset, is_immediate)
trans_scalar(cx, span, ty, [a, b][index], offset)
}

/// A "scalar" is a basic building block: bools, ints, floats, pointers. (i.e. not something complex like a struct)
Expand All @@ -471,9 +441,8 @@ fn trans_scalar<'tcx>(
ty: TyAndLayout<'tcx>,
scalar: &Scalar,
offset: Size,
is_immediate: bool,
) -> Word {
if is_immediate && scalar.is_bool() {
if scalar.is_bool() {
return SpirvType::Bool.def(span, cx);
}

Expand Down Expand Up @@ -608,7 +577,7 @@ fn trans_aggregate<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>
}
}
FieldsShape::Array { stride, count } => {
let element_type = trans_type_impl(cx, span, ty.field(cx, 0), false);
let element_type = ty.field(cx, 0).spirv_type(span, cx);
if ty.is_unsized() {
// There's a potential for this array to be sized, but the element to be unsized, e.g. `[[u8]; 5]`.
// However, I think rust disallows all these cases, so assert this here.
Expand Down Expand Up @@ -676,7 +645,7 @@ fn trans_struct<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>) -
let mut field_names = Vec::new();
for i in ty.fields.index_by_increasing_offset() {
let field_ty = ty.field(cx, i);
field_types.push(trans_type_impl(cx, span, field_ty, false));
field_types.push(field_ty.spirv_type(span, cx));
let offset = ty.fields.offset(i);
field_offsets.push(offset);
if let Variants::Single { index } = ty.variants {
Expand Down Expand Up @@ -887,7 +856,7 @@ fn trans_intrinsic_type<'tcx>(
// The spirv type of it will be generated by querying the type of the first generic.
if let Some(image_ty) = substs.types().next() {
// TODO: enforce that the generic param is an image type?
let image_type = trans_type_impl(cx, span, cx.layout_of(image_ty), false);
let image_type = cx.layout_of(image_ty).spirv_type(span, cx);
Ok(SpirvType::SampledImage { image_type }.def(span, cx))
} else {
cx.tcx
Expand All @@ -907,7 +876,7 @@ fn trans_intrinsic_type<'tcx>(
// We use a generic to indicate the underlying element type.
// The spirv type of it will be generated by querying the type of the first generic.
if let Some(elem_ty) = substs.types().next() {
let element = trans_type_impl(cx, span, cx.layout_of(elem_ty), false);
let element = cx.layout_of(elem_ty).spirv_type(span, cx);
Ok(SpirvType::RuntimeArray { element }.def(span, cx))
} else {
cx.tcx
Expand All @@ -922,7 +891,7 @@ fn trans_intrinsic_type<'tcx>(
.expect("#[spirv(matrix)] must be added to a type which has DefId");

let field_types = (0..ty.fields.count())
.map(|i| trans_type_impl(cx, span, ty.field(cx, i), false))
.map(|i| ty.field(cx, i).spirv_type(span, cx))
.collect::<Vec<_>>();
if field_types.len() < 2 {
cx.tcx
Expand Down
24 changes: 13 additions & 11 deletions crates/rustc_codegen_spirv/src/builder/builder_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -782,22 +782,24 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
result
}

// rustc has the concept of an immediate vs. memory type - bools are compiled to LLVM bools as
// immediates, but if they're behind a pointer, they're compiled to u8. The reason for this is
// because LLVM is bad at bools behind pointers (something something u1 bitmasking on load).
//
// SPIR-V allows bools behind *some* pointers, and disallows others - specifically, it allows
// bools behind the storage classes Workgroup, CrossWorkgroup, Private, Function, Input, and
// Output. In other words, "For stuff the CPU can't see, bools are OK. For stuff the CPU *can*
// see, no bools allowed". So, we always compile rust bools to SPIR-V bools instead of u8 as
// rustc does, even if they're behind a pointer, and error if bools are in an interface (the
// user should choose u8, u32, or something else instead). That means that immediate types and
// memory types are the same, and no conversion needs to happen here.
fn from_immediate(&mut self, val: Self::Value) -> Self::Value {
if self.lookup_type(val.ty) == SpirvType::Bool {
let i8 = SpirvType::Integer(8, false).def(self.span(), self);
self.zext(val, i8)
} else {
val
}
val
}

// silly clippy, we can't rename this!
#[allow(clippy::wrong_self_convention)]
fn to_immediate_scalar(&mut self, val: Self::Value, scalar: Scalar) -> Self::Value {
if scalar.is_bool() {
let bool = SpirvType::Bool.def(self.span(), self);
return self.trunc(val, bool);
}
fn to_immediate_scalar(&mut self, val: Self::Value, _scalar: Scalar) -> Self::Value {
val
}

Expand Down
16 changes: 3 additions & 13 deletions crates/rustc_codegen_spirv/src/builder_spirv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -407,25 +407,15 @@ impl BuilderSpirv {
}

pub fn dump_module_str(&self) -> String {
let mut module = self.builder.borrow().module_ref().clone();
let mut header = rspirv::dr::ModuleHeader::new(0);
header.set_version(0, 0);
module.header = Some(header);
module.disassemble()
self.builder.borrow().module_ref().disassemble()
}

/// Helper function useful to place right before a crash, to debug the module state.
pub fn dump_module(&self, path: impl AsRef<Path>) {
let mut module = self.builder.borrow().module_ref().clone();
let mut header = rspirv::dr::ModuleHeader::new(0);
header.set_version(0, 0);
module.header = Some(header);
let disas = module.disassemble();
println!("{}", disas);
let spirv_module = module.assemble();
let module = self.builder.borrow().module_ref().assemble();
File::create(path)
.unwrap()
.write_all(spirv_tools::binary::from_binary(&spirv_module))
.write_all(spirv_tools::binary::from_binary(&module))
.unwrap();
}

Expand Down
Loading