Skip to content

Commit 2e4de94

Browse files
authored
Generate bools as bools instead of u8 (#809)
* Generate bools as bools instead of u8 * convert bool->int select to cast
1 parent c6b7560 commit 2e4de94

23 files changed

+236
-214
lines changed

crates/rustc_codegen_spirv/src/abi.rs

Lines changed: 67 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,9 @@ pub(crate) fn provide(providers: &mut Providers) {
150150
let TyAndLayout { ty, mut layout } =
151151
(rustc_interface::DEFAULT_QUERY_PROVIDERS.layout_of)(tcx, key)?;
152152

153-
// FIXME(eddyb) make use of this - at this point, it's just a placeholder.
154-
#[allow(clippy::match_single_binding)]
153+
#[allow(clippy::match_like_matches_macro)]
155154
let hide_niche = match ty.kind() {
155+
ty::Bool => true,
156156
_ => false,
157157
};
158158

@@ -284,13 +284,6 @@ enum PointeeDefState {
284284
/// provides a uniform way of translating them.
285285
pub trait ConvSpirvType<'tcx> {
286286
fn spirv_type(&self, span: Span, cx: &CodegenCx<'tcx>) -> Word;
287-
/// spirv (and llvm) do not allow storing booleans in memory, they are abstract unsized values.
288-
/// So, if we're dealing with a "memory type", convert bool to u8. The opposite is an
289-
/// "immediate type", which keeps bools as bools. See also the functions `from_immediate` and
290-
/// `to_immediate`, which convert between the two.
291-
fn spirv_type_immediate(&self, span: Span, cx: &CodegenCx<'tcx>) -> Word {
292-
self.spirv_type(span, cx)
293-
}
294287
}
295288

296289
impl<'tcx> ConvSpirvType<'tcx> for PointeeTy<'tcx> {
@@ -302,14 +295,6 @@ impl<'tcx> ConvSpirvType<'tcx> for PointeeTy<'tcx> {
302295
.spirv_type(span, cx),
303296
}
304297
}
305-
fn spirv_type_immediate(&self, span: Span, cx: &CodegenCx<'tcx>) -> Word {
306-
match *self {
307-
PointeeTy::Ty(ty) => ty.spirv_type_immediate(span, cx),
308-
PointeeTy::Fn(ty) => cx
309-
.fn_abi_of_fn_ptr(ty, ty::List::empty())
310-
.spirv_type_immediate(span, cx),
311-
}
312-
}
313298
}
314299

315300
impl<'tcx> ConvSpirvType<'tcx> for FnAbi<'tcx, Ty<'tcx>> {
@@ -318,9 +303,7 @@ impl<'tcx> ConvSpirvType<'tcx> for FnAbi<'tcx, Ty<'tcx>> {
318303

319304
let return_type = match self.ret.mode {
320305
PassMode::Ignore => SpirvType::Void.def(span, cx),
321-
PassMode::Direct(_) | PassMode::Pair(..) => {
322-
self.ret.layout.spirv_type_immediate(span, cx)
323-
}
306+
PassMode::Direct(_) | PassMode::Pair(..) => self.ret.layout.spirv_type(span, cx),
324307
PassMode::Cast(_) | PassMode::Indirect { .. } => span_bug!(
325308
span,
326309
"query hooks should've made this `PassMode` impossible: {:#?}",
@@ -331,14 +314,10 @@ impl<'tcx> ConvSpirvType<'tcx> for FnAbi<'tcx, Ty<'tcx>> {
331314
for arg in &self.args {
332315
let arg_type = match arg.mode {
333316
PassMode::Ignore => continue,
334-
PassMode::Direct(_) => arg.layout.spirv_type_immediate(span, cx),
317+
PassMode::Direct(_) => arg.layout.spirv_type(span, cx),
335318
PassMode::Pair(_, _) => {
336-
argument_types.push(scalar_pair_element_backend_type(
337-
cx, span, arg.layout, 0, true,
338-
));
339-
argument_types.push(scalar_pair_element_backend_type(
340-
cx, span, arg.layout, 1, true,
341-
));
319+
argument_types.push(scalar_pair_element_backend_type(cx, span, arg.layout, 0));
320+
argument_types.push(scalar_pair_element_backend_type(cx, span, arg.layout, 1));
342321
continue;
343322
}
344323
PassMode::Cast(_) | PassMode::Indirect { .. } => span_bug!(
@@ -359,77 +338,69 @@ impl<'tcx> ConvSpirvType<'tcx> for FnAbi<'tcx, Ty<'tcx>> {
359338
}
360339

361340
impl<'tcx> ConvSpirvType<'tcx> for TyAndLayout<'tcx> {
362-
fn spirv_type(&self, span: Span, cx: &CodegenCx<'tcx>) -> Word {
363-
trans_type_impl(cx, span, *self, false)
364-
}
365-
fn spirv_type_immediate(&self, span: Span, cx: &CodegenCx<'tcx>) -> Word {
366-
trans_type_impl(cx, span, *self, true)
367-
}
368-
}
369-
370-
fn trans_type_impl<'tcx>(
371-
cx: &CodegenCx<'tcx>,
372-
mut span: Span,
373-
ty: TyAndLayout<'tcx>,
374-
is_immediate: bool,
375-
) -> Word {
376-
if let TyKind::Adt(adt, substs) = *ty.ty.kind() {
377-
if span == DUMMY_SP {
378-
span = cx.tcx.def_span(adt.did);
379-
}
341+
fn spirv_type(&self, mut span: Span, cx: &CodegenCx<'tcx>) -> Word {
342+
if let TyKind::Adt(adt, substs) = *self.ty.kind() {
343+
if span == DUMMY_SP {
344+
span = cx.tcx.def_span(adt.did);
345+
}
380346

381-
let attrs = AggregatedSpirvAttributes::parse(cx, cx.tcx.get_attrs(adt.did));
347+
let attrs = AggregatedSpirvAttributes::parse(cx, cx.tcx.get_attrs(adt.did));
382348

383-
if let Some(intrinsic_type_attr) = attrs.intrinsic_type.map(|attr| attr.value) {
384-
if let Ok(spirv_type) = trans_intrinsic_type(cx, span, ty, substs, intrinsic_type_attr)
385-
{
386-
return spirv_type;
349+
if let Some(intrinsic_type_attr) = attrs.intrinsic_type.map(|attr| attr.value) {
350+
if let Ok(spirv_type) =
351+
trans_intrinsic_type(cx, span, *self, substs, intrinsic_type_attr)
352+
{
353+
return spirv_type;
354+
}
387355
}
388356
}
389-
}
390357

391-
// Note: ty.layout is orthogonal to ty.ty, e.g. `ManuallyDrop<Result<isize, isize>>` has abi
392-
// `ScalarPair`.
393-
// There's a few layers that we go through here. First we inspect layout.abi, then if relevant, layout.fields, etc.
394-
match ty.abi {
395-
Abi::Uninhabited => SpirvType::Adt {
396-
def_id: def_id_for_spirv_type_adt(ty),
397-
size: Some(Size::ZERO),
398-
align: Align::from_bytes(0).unwrap(),
399-
field_types: Vec::new(),
400-
field_offsets: Vec::new(),
401-
field_names: None,
402-
}
403-
.def_with_name(cx, span, TyLayoutNameKey::from(ty)),
404-
Abi::Scalar(ref scalar) => trans_scalar(cx, span, ty, scalar, Size::ZERO, is_immediate),
405-
Abi::ScalarPair(ref a, ref b) => {
406-
// Note: We can't use auto_struct_layout here because the spirv types here might be undefined due to
407-
// recursive pointer types.
408-
let a_offset = Size::ZERO;
409-
let b_offset = a.value.size(cx).align_to(b.value.align(cx).abi);
410-
// Note! Do not pass through is_immediate here - they're wrapped in a struct, hence, not immediate.
411-
let a = trans_scalar(cx, span, ty, a, a_offset, false);
412-
let b = trans_scalar(cx, span, ty, b, b_offset, false);
413-
let size = if ty.is_unsized() { None } else { Some(ty.size) };
414-
SpirvType::Adt {
415-
def_id: def_id_for_spirv_type_adt(ty),
416-
size,
417-
align: ty.align.abi,
418-
field_types: vec![a, b],
419-
field_offsets: vec![a_offset, b_offset],
358+
// Note: ty.layout is orthogonal to ty.ty, e.g. `ManuallyDrop<Result<isize, isize>>` has abi
359+
// `ScalarPair`.
360+
// There's a few layers that we go through here. First we inspect layout.abi, then if relevant, layout.fields, etc.
361+
match self.abi {
362+
Abi::Uninhabited => SpirvType::Adt {
363+
def_id: def_id_for_spirv_type_adt(*self),
364+
size: Some(Size::ZERO),
365+
align: Align::from_bytes(0).unwrap(),
366+
field_types: Vec::new(),
367+
field_offsets: Vec::new(),
420368
field_names: None,
421369
}
422-
.def_with_name(cx, span, TyLayoutNameKey::from(ty))
423-
}
424-
Abi::Vector { ref element, count } => {
425-
let elem_spirv = trans_scalar(cx, span, ty, element, Size::ZERO, false);
426-
SpirvType::Vector {
427-
element: elem_spirv,
428-
count: count as u32,
370+
.def_with_name(cx, span, TyLayoutNameKey::from(*self)),
371+
Abi::Scalar(ref scalar) => trans_scalar(cx, span, *self, scalar, Size::ZERO),
372+
Abi::ScalarPair(ref a, ref b) => {
373+
// Note: We can't use auto_struct_layout here because the spirv types here might be undefined due to
374+
// recursive pointer types.
375+
let a_offset = Size::ZERO;
376+
let b_offset = a.value.size(cx).align_to(b.value.align(cx).abi);
377+
let a = trans_scalar(cx, span, *self, a, a_offset);
378+
let b = trans_scalar(cx, span, *self, b, b_offset);
379+
let size = if self.is_unsized() {
380+
None
381+
} else {
382+
Some(self.size)
383+
};
384+
SpirvType::Adt {
385+
def_id: def_id_for_spirv_type_adt(*self),
386+
size,
387+
align: self.align.abi,
388+
field_types: vec![a, b],
389+
field_offsets: vec![a_offset, b_offset],
390+
field_names: None,
391+
}
392+
.def_with_name(cx, span, TyLayoutNameKey::from(*self))
393+
}
394+
Abi::Vector { ref element, count } => {
395+
let elem_spirv = trans_scalar(cx, span, *self, element, Size::ZERO);
396+
SpirvType::Vector {
397+
element: elem_spirv,
398+
count: count as u32,
399+
}
400+
.def(span, cx)
429401
}
430-
.def(span, cx)
402+
Abi::Aggregate { sized: _ } => trans_aggregate(cx, span, *self),
431403
}
432-
Abi::Aggregate { sized: _ } => trans_aggregate(cx, span, ty),
433404
}
434405
}
435406

@@ -440,7 +411,6 @@ pub fn scalar_pair_element_backend_type<'tcx>(
440411
span: Span,
441412
ty: TyAndLayout<'tcx>,
442413
index: usize,
443-
is_immediate: bool,
444414
) -> Word {
445415
let [a, b] = match &ty.layout.abi {
446416
Abi::ScalarPair(a, b) => [a, b],
@@ -455,7 +425,7 @@ pub fn scalar_pair_element_backend_type<'tcx>(
455425
1 => a.value.size(cx).align_to(b.value.align(cx).abi),
456426
_ => unreachable!(),
457427
};
458-
trans_scalar(cx, span, ty, [a, b][index], offset, is_immediate)
428+
trans_scalar(cx, span, ty, [a, b][index], offset)
459429
}
460430

461431
/// A "scalar" is a basic building block: bools, ints, floats, pointers. (i.e. not something complex like a struct)
@@ -471,9 +441,8 @@ fn trans_scalar<'tcx>(
471441
ty: TyAndLayout<'tcx>,
472442
scalar: &Scalar,
473443
offset: Size,
474-
is_immediate: bool,
475444
) -> Word {
476-
if is_immediate && scalar.is_bool() {
445+
if scalar.is_bool() {
477446
return SpirvType::Bool.def(span, cx);
478447
}
479448

@@ -608,7 +577,7 @@ fn trans_aggregate<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>
608577
}
609578
}
610579
FieldsShape::Array { stride, count } => {
611-
let element_type = trans_type_impl(cx, span, ty.field(cx, 0), false);
580+
let element_type = ty.field(cx, 0).spirv_type(span, cx);
612581
if ty.is_unsized() {
613582
// There's a potential for this array to be sized, but the element to be unsized, e.g. `[[u8]; 5]`.
614583
// However, I think rust disallows all these cases, so assert this here.
@@ -676,7 +645,7 @@ fn trans_struct<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>) -
676645
let mut field_names = Vec::new();
677646
for i in ty.fields.index_by_increasing_offset() {
678647
let field_ty = ty.field(cx, i);
679-
field_types.push(trans_type_impl(cx, span, field_ty, false));
648+
field_types.push(field_ty.spirv_type(span, cx));
680649
let offset = ty.fields.offset(i);
681650
field_offsets.push(offset);
682651
if let Variants::Single { index } = ty.variants {
@@ -887,7 +856,7 @@ fn trans_intrinsic_type<'tcx>(
887856
// The spirv type of it will be generated by querying the type of the first generic.
888857
if let Some(image_ty) = substs.types().next() {
889858
// TODO: enforce that the generic param is an image type?
890-
let image_type = trans_type_impl(cx, span, cx.layout_of(image_ty), false);
859+
let image_type = cx.layout_of(image_ty).spirv_type(span, cx);
891860
Ok(SpirvType::SampledImage { image_type }.def(span, cx))
892861
} else {
893862
cx.tcx
@@ -907,7 +876,7 @@ fn trans_intrinsic_type<'tcx>(
907876
// We use a generic to indicate the underlying element type.
908877
// The spirv type of it will be generated by querying the type of the first generic.
909878
if let Some(elem_ty) = substs.types().next() {
910-
let element = trans_type_impl(cx, span, cx.layout_of(elem_ty), false);
879+
let element = cx.layout_of(elem_ty).spirv_type(span, cx);
911880
Ok(SpirvType::RuntimeArray { element }.def(span, cx))
912881
} else {
913882
cx.tcx
@@ -922,7 +891,7 @@ fn trans_intrinsic_type<'tcx>(
922891
.expect("#[spirv(matrix)] must be added to a type which has DefId");
923892

924893
let field_types = (0..ty.fields.count())
925-
.map(|i| trans_type_impl(cx, span, ty.field(cx, i), false))
894+
.map(|i| ty.field(cx, i).spirv_type(span, cx))
926895
.collect::<Vec<_>>();
927896
if field_types.len() < 2 {
928897
cx.tcx

crates/rustc_codegen_spirv/src/builder/builder_methods.rs

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -782,22 +782,24 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
782782
result
783783
}
784784

785+
// rustc has the concept of an immediate vs. memory type - bools are compiled to LLVM bools as
786+
// immediates, but if they're behind a pointer, they're compiled to u8. The reason for this is
787+
// because LLVM is bad at bools behind pointers (something something u1 bitmasking on load).
788+
//
789+
// SPIR-V allows bools behind *some* pointers, and disallows others - specifically, it allows
790+
// bools behind the storage classes Workgroup, CrossWorkgroup, Private, Function, Input, and
791+
// Output. In other words, "For stuff the CPU can't see, bools are OK. For stuff the CPU *can*
792+
// see, no bools allowed". So, we always compile rust bools to SPIR-V bools instead of u8 as
793+
// rustc does, even if they're behind a pointer, and error if bools are in an interface (the
794+
// user should choose u8, u32, or something else instead). That means that immediate types and
795+
// memory types are the same, and no conversion needs to happen here.
785796
fn from_immediate(&mut self, val: Self::Value) -> Self::Value {
786-
if self.lookup_type(val.ty) == SpirvType::Bool {
787-
let i8 = SpirvType::Integer(8, false).def(self.span(), self);
788-
self.zext(val, i8)
789-
} else {
790-
val
791-
}
797+
val
792798
}
793799

794800
// silly clippy, we can't rename this!
795801
#[allow(clippy::wrong_self_convention)]
796-
fn to_immediate_scalar(&mut self, val: Self::Value, scalar: Scalar) -> Self::Value {
797-
if scalar.is_bool() {
798-
let bool = SpirvType::Bool.def(self.span(), self);
799-
return self.trunc(val, bool);
800-
}
802+
fn to_immediate_scalar(&mut self, val: Self::Value, _scalar: Scalar) -> Self::Value {
801803
val
802804
}
803805

crates/rustc_codegen_spirv/src/builder_spirv.rs

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -407,25 +407,15 @@ impl BuilderSpirv {
407407
}
408408

409409
pub fn dump_module_str(&self) -> String {
410-
let mut module = self.builder.borrow().module_ref().clone();
411-
let mut header = rspirv::dr::ModuleHeader::new(0);
412-
header.set_version(0, 0);
413-
module.header = Some(header);
414-
module.disassemble()
410+
self.builder.borrow().module_ref().disassemble()
415411
}
416412

417413
/// Helper function useful to place right before a crash, to debug the module state.
418414
pub fn dump_module(&self, path: impl AsRef<Path>) {
419-
let mut module = self.builder.borrow().module_ref().clone();
420-
let mut header = rspirv::dr::ModuleHeader::new(0);
421-
header.set_version(0, 0);
422-
module.header = Some(header);
423-
let disas = module.disassemble();
424-
println!("{}", disas);
425-
let spirv_module = module.assemble();
415+
let module = self.builder.borrow().module_ref().assemble();
426416
File::create(path)
427417
.unwrap()
428-
.write_all(spirv_tools::binary::from_binary(&spirv_module))
418+
.write_all(spirv_tools::binary::from_binary(&module))
429419
.unwrap();
430420
}
431421

0 commit comments

Comments
 (0)