diff --git a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs index 3a8b4c1801..84d6f17f13 100644 --- a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs +++ b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs @@ -2033,6 +2033,31 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { if val.ty == dest_ty { val } else { + // If casting a constant, directly create a constant of the target type. + // This avoids creating intermediate types that might require additional + // capabilities. For example, casting a f16 constant to f32 will directly + // create a f32 constant, avoiding the need for Float16 capability if it is + // not used elsewhere. + if let Some(const_val) = self.builder.lookup_const_scalar(val) { + if let (SpirvType::Float(src_width), SpirvType::Float(dst_width)) = + (self.lookup_type(val.ty), self.lookup_type(dest_ty)) + { + if src_width < dst_width { + // Convert the bit representation to the actual float value + let float_val = match src_width { + 32 => Some(f32::from_bits(const_val as u32) as f64), + 64 => Some(f64::from_bits(const_val as u64)), + _ => None, + }; + + if let Some(val) = float_val { + return self.constant_float(dest_ty, val); + } + } + } + } + + // Regular conversion self.emit() .f_convert(dest_ty, None, val.def(self)) .unwrap() @@ -2198,6 +2223,46 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { // I guess? return val; } + + // If casting a constant, directly create a constant of the target type. This + // avoids creating intermediate types that might require additional + // capabilities. For example, casting a u8 constant to u32 will directly create + // a u32 constant, avoiding the need for Int8 capability if it is not used + // elsewhere. + if let Some(const_val) = self.builder.lookup_const_scalar(val) { + let src_ty = self.lookup_type(val.ty); + let dst_ty_spv = self.lookup_type(dest_ty); + + // Try to optimize the constant cast + let optimized_result = match (src_ty, dst_ty_spv) { + // Integer to integer cast + (SpirvType::Integer(src_width, _), SpirvType::Integer(dst_width, _)) => { + // Only optimize if we're widening. This avoids creating the source + // type when it's safe to do so. For narrowing casts (e.g., u32 as + // u8), we need the proper truncation behavior that the regular cast + // provides. + if src_width < dst_width { + Some(self.constant_int(dest_ty, const_val)) + } else { + None + } + } + // Bool to integer cast - const_val will be 0 or 1 + (SpirvType::Bool, SpirvType::Integer(_, _)) => { + Some(self.constant_int(dest_ty, const_val)) + } + // Integer to bool cast - compare with zero + (SpirvType::Integer(_, _), SpirvType::Bool) => { + Some(self.constant_bool(self.span(), const_val != 0)) + } + _ => None, + }; + + if let Some(result) = optimized_result { + return result; + } + } + match (self.lookup_type(val.ty), self.lookup_type(dest_ty)) { // sign change ( @@ -3128,6 +3193,8 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { .and_then(|def_id| self.buffer_store_intrinsics.borrow().get(&def_id).copied()); let is_panic_entry_point = instance_def_id .is_some_and(|def_id| self.panic_entry_points.borrow().contains(&def_id)); + let from_trait_impl = + instance_def_id.and_then(|def_id| self.from_trait_impls.borrow().get(&def_id).copied()); if let Some(libm_intrinsic) = libm_intrinsic { let result = self.call_libm_intrinsic(libm_intrinsic, result_type, args); @@ -3139,8 +3206,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { self.debug_type(result.ty), ); } - result - } else if is_panic_entry_point { + return result; + } + + if is_panic_entry_point { // HACK(eddyb) Rust 2021 `panic!` always uses `format_args!`, even // in the simple case that used to pass a `&str` constant, which // would not remain reachable in the SPIR-V - but `format_args!` is @@ -3613,24 +3682,59 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { // HACK(eddyb) redirect any possible panic call to an abort, to avoid // needing to materialize `&core::panic::Location` or `format_args!`. self.abort_with_kind_and_message_debug_printf("panic", message, debug_printf_args); - self.undef(result_type) - } else if let Some(mode) = buffer_load_intrinsic { - self.codegen_buffer_load_intrinsic(result_type, args, mode) - } else if let Some(mode) = buffer_store_intrinsic { + return self.undef(result_type); + } + + if let Some(mode) = buffer_load_intrinsic { + return self.codegen_buffer_load_intrinsic(result_type, args, mode); + } + + if let Some(mode) = buffer_store_intrinsic { self.codegen_buffer_store_intrinsic(args, mode); let void_ty = SpirvType::Void.def(rustc_span::DUMMY_SP, self); - SpirvValue { + return SpirvValue { kind: SpirvValueKind::IllegalTypeUsed(void_ty), ty: void_ty, + }; + } + + if let Some((source_ty, target_ty)) = from_trait_impl { + // Optimize From::from calls with constant arguments to avoid creating intermediate types. + // Since From is only implemented for safe conversions (widening conversions that preserve + // the numeric value), we can directly create a constant of the target type for primitive + // numeric types. + if let [arg] = args { + if let Some(const_val) = self.builder.lookup_const_scalar(*arg) { + use rustc_middle::ty::FloatTy; + let optimized_result = match (source_ty.kind(), target_ty.kind()) { + // Integer widening conversions + (ty::Uint(_), ty::Uint(_)) | (ty::Int(_), ty::Int(_)) => { + Some(self.constant_int(result_type, const_val)) + } + // Float widening conversions + // TODO(@LegNeato): Handle more float types + (ty::Float(FloatTy::F32), ty::Float(FloatTy::F64)) => { + let float_val = f32::from_bits(const_val as u32) as f64; + Some(self.constant_float(result_type, float_val)) + } + // No optimization for narrowing conversions or unsupported types + _ => None, + }; + + if let Some(result) = optimized_result { + return result; + } + } } - } else { - let args = args.iter().map(|arg| arg.def(self)).collect::>(); - self.emit() - .function_call(result_type, None, callee_val, args) - .unwrap() - .with_type(result_type) } + + // Default: emit a regular function call + let args = args.iter().map(|arg| arg.def(self)).collect::>(); + self.emit() + .function_call(result_type, None, callee_val, args) + .unwrap() + .with_type(result_type) } fn zext(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value { diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs b/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs index 93479d077f..9ce6f8fc78 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs @@ -172,6 +172,30 @@ impl<'tcx> CodegenCx<'tcx> { } } + // Check if this is a From trait implementation + if let Some(impl_def_id) = self.tcx.impl_of_method(def_id) { + if let Some(trait_ref) = self.tcx.impl_trait_ref(impl_def_id) { + let trait_def_id = trait_ref.skip_binder().def_id; + + // Check if this is the From trait. + let trait_path = self.tcx.def_path_str(trait_def_id); + if matches!( + trait_path.as_str(), + "core::convert::From" | "std::convert::From" + ) { + // Extract the source and target types from the trait substitutions + let trait_args = trait_ref.skip_binder().args; + if let (Some(target_ty), Some(source_ty)) = + (trait_args.types().nth(0), trait_args.types().nth(1)) + { + self.from_trait_impls + .borrow_mut() + .insert(def_id, (source_ty, target_ty)); + } + } + } + } + if [ self.tcx.lang_items().panic_fn(), self.tcx.lang_items().panic_fmt(), diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs b/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs index f3a3828018..9f575fba72 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs @@ -84,6 +84,10 @@ pub struct CodegenCx<'tcx> { /// Intrinsic for storing a `` into a `&[u32]`. The `PassMode` is the mode of the ``. pub buffer_store_intrinsics: RefCell>, + /// Maps `DefId`s of `From::from` method implementations to their source and target types. + /// Used to optimize constant conversions like `u32::from(42u8)` to avoid creating the source type. + pub from_trait_impls: RefCell, Ty<'tcx>)>>, + /// Some runtimes (e.g. intel-compute-runtime) disallow atomics on i8 and i16, even though it's allowed by the spec. /// This enables/disables them. pub i8_i16_atomics_allowed: bool, @@ -203,6 +207,7 @@ impl<'tcx> CodegenCx<'tcx> { fmt_rt_arg_new_fn_ids_to_ty_and_spec: Default::default(), buffer_load_intrinsics: Default::default(), buffer_store_intrinsics: Default::default(), + from_trait_impls: Default::default(), i8_i16_atomics_allowed: false, codegen_args, } diff --git a/crates/rustc_codegen_spirv/src/linker/mod.rs b/crates/rustc_codegen_spirv/src/linker/mod.rs index 6b5d7919b1..b4e91e1dcb 100644 --- a/crates/rustc_codegen_spirv/src/linker/mod.rs +++ b/crates/rustc_codegen_spirv/src/linker/mod.rs @@ -477,6 +477,16 @@ pub fn link( simple_passes::remove_non_uniform_decorations(sess, &mut output)?; } + { + let _timer = sess.timer("link_remove_unused_type_capabilities"); + simple_passes::remove_unused_type_capabilities(&mut output); + } + + { + let _timer = sess.timer("link_type_capability_check"); + simple_passes::check_type_capabilities(sess, &output)?; + } + // NOTE(eddyb) SPIR-T pipeline is entirely limited to this block. { let (spv_words, module_or_err, lower_from_spv_timer) = diff --git a/crates/rustc_codegen_spirv/src/linker/simple_passes.rs b/crates/rustc_codegen_spirv/src/linker/simple_passes.rs index 9dfeadad35..68e820ae2a 100644 --- a/crates/rustc_codegen_spirv/src/linker/simple_passes.rs +++ b/crates/rustc_codegen_spirv/src/linker/simple_passes.rs @@ -7,6 +7,25 @@ use rustc_session::Session; use std::iter::once; use std::mem::take; +/// Returns the capability required for an integer type of the given width, if any. +fn capability_for_int_width(width: u32) -> Option { + match width { + 8 => Some(rspirv::spirv::Capability::Int8), + 16 => Some(rspirv::spirv::Capability::Int16), + 64 => Some(rspirv::spirv::Capability::Int64), + _ => None, + } +} + +/// Returns the capability required for a float type of the given width, if any. +fn capability_for_float_width(width: u32) -> Option { + match width { + 16 => Some(rspirv::spirv::Capability::Float16), + 64 => Some(rspirv::spirv::Capability::Float64), + _ => None, + } +} + pub fn shift_ids(module: &mut Module, add: u32) { module.all_inst_iter_mut().for_each(|inst| { if let Some(ref mut result_id) = &mut inst.result_id { @@ -266,6 +285,111 @@ pub fn check_fragment_insts(sess: &Session, module: &Module) -> Result<()> { } } +/// Check that types requiring specific capabilities have those capabilities declared. +/// +/// This function validates that if a module uses types like u8/i8 (requiring Int8), +/// u16/i16 (requiring Int16), etc., the corresponding capabilities are declared. +pub fn check_type_capabilities(sess: &Session, module: &Module) -> Result<()> { + use rspirv::spirv::Capability; + + // Collect declared capabilities + let declared_capabilities: FxHashSet = module + .capabilities + .iter() + .map(|inst| inst.operands[0].unwrap_capability()) + .collect(); + + let mut errors = Vec::new(); + + for inst in &module.types_global_values { + match inst.class.opcode { + Op::TypeInt => { + let width = inst.operands[0].unwrap_literal_bit32(); + let signedness = inst.operands[1].unwrap_literal_bit32() != 0; + let type_name = if signedness { "i" } else { "u" }; + + if let Some(required_cap) = capability_for_int_width(width) { + if !declared_capabilities.contains(&required_cap) { + errors.push(format!( + "`{type_name}{width}` type used without `OpCapability {required_cap:?}`" + )); + } + } + } + Op::TypeFloat => { + let width = inst.operands[0].unwrap_literal_bit32(); + + if let Some(required_cap) = capability_for_float_width(width) { + if !declared_capabilities.contains(&required_cap) { + errors.push(format!( + "`f{width}` type used without `OpCapability {required_cap:?}`" + )); + } + } + } + _ => {} + } + } + + if !errors.is_empty() { + let mut err = sess + .dcx() + .struct_err("Missing required capabilities for types"); + for error in errors { + err = err.with_note(error); + } + Err(err.emit()) + } else { + Ok(()) + } +} + +/// Remove type-related capabilities that are not required by any types in the module. +/// +/// This function specifically targets Int8, Int16, Int64, Float16, and Float64 capabilities, +/// removing them if no types in the module require them. All other capabilities are preserved. +/// This is part of the fix for issue #300 where constant casts were creating unnecessary types. +pub fn remove_unused_type_capabilities(module: &mut Module) { + use rspirv::spirv::Capability; + + // Collect type-related capabilities that are actually needed + let mut needed_type_capabilities = FxHashSet::default(); + + // Scan all types to determine which type-related capabilities are needed + for inst in &module.types_global_values { + match inst.class.opcode { + Op::TypeInt => { + let width = inst.operands[0].unwrap_literal_bit32(); + if let Some(cap) = capability_for_int_width(width) { + needed_type_capabilities.insert(cap); + } + } + Op::TypeFloat => { + let width = inst.operands[0].unwrap_literal_bit32(); + if let Some(cap) = capability_for_float_width(width) { + needed_type_capabilities.insert(cap); + } + } + _ => {} + } + } + + // Remove only type-related capabilities that aren't needed + module.capabilities.retain(|inst| { + let cap = inst.operands[0].unwrap_capability(); + match cap { + // Only remove these type-related capabilities if they're not used + Capability::Int8 + | Capability::Int16 + | Capability::Int64 + | Capability::Float16 + | Capability::Float64 => needed_type_capabilities.contains(&cap), + // Keep all other capabilities + _ => true, + } + }); +} + /// Remove all [`Decoration::NonUniform`] if this module does *not* have [`Capability::ShaderNonUniform`]. /// This allows image asm to always declare `NonUniform` and not worry about conditional compilation. pub fn remove_non_uniform_decorations(_sess: &Session, module: &mut Module) -> Result<()> { diff --git a/crates/rustc_codegen_spirv/src/spirv_type.rs b/crates/rustc_codegen_spirv/src/spirv_type.rs index d674f2542f..69da8010e7 100644 --- a/crates/rustc_codegen_spirv/src/spirv_type.rs +++ b/crates/rustc_codegen_spirv/src/spirv_type.rs @@ -3,7 +3,7 @@ use crate::builder_spirv::SpirvValue; use crate::codegen_cx::CodegenCx; use indexmap::IndexSet; use rspirv::dr::Operand; -use rspirv::spirv::{Capability, Decoration, Dim, ImageFormat, StorageClass, Word}; +use rspirv::spirv::{Decoration, Dim, ImageFormat, StorageClass, Word}; use rustc_data_structures::fx::FxHashMap; use rustc_middle::span_bug; use rustc_span::def_id::DefId; @@ -105,21 +105,6 @@ impl SpirvType<'_> { let result = cx.emit_global().type_int_id(id, width, signedness as u32); let u_or_i = if signedness { "i" } else { "u" }; match width { - 8 if !cx.builder.has_capability(Capability::Int8) => cx.zombie_with_span( - result, - def_span, - &format!("`{u_or_i}8` without `OpCapability Int8`"), - ), - 16 if !cx.builder.has_capability(Capability::Int16) => cx.zombie_with_span( - result, - def_span, - &format!("`{u_or_i}16` without `OpCapability Int16`"), - ), - 64 if !cx.builder.has_capability(Capability::Int64) => cx.zombie_with_span( - result, - def_span, - &format!("`{u_or_i}64` without `OpCapability Int64`"), - ), 8 | 16 | 32 | 64 => {} w => cx.zombie_with_span( result, @@ -132,16 +117,6 @@ impl SpirvType<'_> { Self::Float(width) => { let result = cx.emit_global().type_float_id(id, width); match width { - 16 if !cx.builder.has_capability(Capability::Float16) => cx.zombie_with_span( - result, - def_span, - "`f16` without `OpCapability Float16`", - ), - 64 if !cx.builder.has_capability(Capability::Float64) => cx.zombie_with_span( - result, - def_span, - "`f64` without `OpCapability Float64`", - ), 16 | 32 | 64 => (), other => cx.zombie_with_span( result, diff --git a/tests/compiletests/ui/dis/asm_op_decorate.stderr b/tests/compiletests/ui/dis/asm_op_decorate.stderr index 7364346cab..79a494bbb1 100644 --- a/tests/compiletests/ui/dis/asm_op_decorate.stderr +++ b/tests/compiletests/ui/dis/asm_op_decorate.stderr @@ -1,8 +1,4 @@ OpCapability Shader -OpCapability Float64 -OpCapability Int64 -OpCapability Int16 -OpCapability Int8 OpCapability ShaderClockKHR OpCapability RuntimeDescriptorArray OpExtension "SPV_EXT_descriptor_indexing" diff --git a/tests/compiletests/ui/dis/const-float-cast-optimized.rs b/tests/compiletests/ui/dis/const-float-cast-optimized.rs new file mode 100644 index 0000000000..757d7d8eec --- /dev/null +++ b/tests/compiletests/ui/dis/const-float-cast-optimized.rs @@ -0,0 +1,19 @@ +// Test that constant float widening casts are optimized to avoid creating +// the smaller float type when not needed elsewhere. + +// build-pass +// compile-flags: -C llvm-args=--disassemble-globals +// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> "" +// normalize-stderr-test "OpSource .*\n" -> "" +// normalize-stderr-test "OpExtension .SPV_KHR_vulkan_memory_model.\n" -> "" +// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple" + +use spirv_std::spirv; + +#[spirv(fragment)] +pub fn main(output: &mut f64) { + // This should optimize away the f32 type since it's widening + const SMALL: f32 = 20.5; + let widened = SMALL as f64; + *output = widened; +} diff --git a/tests/compiletests/ui/dis/const-float-cast-optimized.stderr b/tests/compiletests/ui/dis/const-float-cast-optimized.stderr new file mode 100644 index 0000000000..cc72d04ed9 --- /dev/null +++ b/tests/compiletests/ui/dis/const-float-cast-optimized.stderr @@ -0,0 +1,16 @@ +OpCapability Shader +OpCapability Float64 +OpCapability ShaderClockKHR +OpExtension "SPV_KHR_shader_clock" +OpMemoryModel Logical Simple +OpEntryPoint Fragment %1 "main" %2 +OpExecutionMode %1 OriginUpperLeft +%3 = OpString "$OPSTRING_FILENAME/const-float-cast-optimized.rs" +OpName %2 "output" +OpDecorate %2 Location 0 +%4 = OpTypeFloat 64 +%5 = OpTypePointer Output %4 +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%2 = OpVariable %5 Output +%8 = OpConstant %4 4626463454704697344 diff --git a/tests/compiletests/ui/dis/const-float-cast.rs b/tests/compiletests/ui/dis/const-float-cast.rs new file mode 100644 index 0000000000..d801f4cad5 --- /dev/null +++ b/tests/compiletests/ui/dis/const-float-cast.rs @@ -0,0 +1,31 @@ +// Test whether float constant casts need optimization + +// build-pass +// compile-flags: -C llvm-args=--disassemble-globals +// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> "" +// normalize-stderr-test "OpSource .*\n" -> "" +// normalize-stderr-test "OpExtension .SPV_KHR_vulkan_memory_model.\n" -> "" +// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple" + +use spirv_std::spirv; + +#[spirv(fragment)] +pub fn main(output: &mut f32) { + // Test f64 to f32 (narrowing) + const BIG: f64 = 123.456; + let narrowed = BIG as f32; + *output = narrowed; + + // Test f32 to f64 (widening) - this might create f32 type unnecessarily + const SMALL: f32 = 20.5; + let widened = SMALL as f64; + *output += widened as f32; + + let kept: f32 = 1.0 + SMALL; + *output += kept; + + // Test integer to float + const INT: u32 = 42; + let as_float = INT as f32; + *output += as_float; +} diff --git a/tests/compiletests/ui/dis/const-float-cast.stderr b/tests/compiletests/ui/dis/const-float-cast.stderr new file mode 100644 index 0000000000..53d922f512 --- /dev/null +++ b/tests/compiletests/ui/dis/const-float-cast.stderr @@ -0,0 +1,22 @@ +OpCapability Shader +OpCapability Float64 +OpCapability ShaderClockKHR +OpExtension "SPV_KHR_shader_clock" +OpMemoryModel Logical Simple +OpEntryPoint Fragment %1 "main" %2 +OpExecutionMode %1 OriginUpperLeft +%3 = OpString "$OPSTRING_FILENAME/const-float-cast.rs" +OpName %2 "output" +OpDecorate %2 Location 0 +%4 = OpTypeFloat 32 +%5 = OpTypePointer Output %4 +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeFloat 64 +%9 = OpConstant %8 4638387860618067575 +%2 = OpVariable %5 Output +%10 = OpConstant %8 4626463454704697344 +%11 = OpConstant %4 1065353216 +%12 = OpConstant %4 1101266944 +%13 = OpTypeInt 32 0 +%14 = OpConstant %13 42 diff --git a/tests/compiletests/ui/dis/const-from-cast.rs b/tests/compiletests/ui/dis/const-from-cast.rs new file mode 100644 index 0000000000..207f9e6631 --- /dev/null +++ b/tests/compiletests/ui/dis/const-from-cast.rs @@ -0,0 +1,22 @@ +// Test that constant integer from casts are optimized to avoid creating intermediate +// types that would require additional capabilities (e.g., Int8 capability for u8). + +// build-pass +// compile-flags: -C llvm-args=--disassemble-globals +// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> "" +// normalize-stderr-test "OpSource .*\n" -> "" +// normalize-stderr-test "OpExtension .SPV_KHR_vulkan_memory_model.\n" -> "" +// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple" + +use spirv_std::spirv; + +const K: u8 = 42; + +#[spirv(fragment)] +pub fn main(output: &mut u32) { + let position = 2u32; + // This cast should be optimized to directly create a u32 constant with value 42, + // avoiding the creation of a u8 type that would require Int8 capability + let global_y_offset_bits = u32::from(K); + *output = global_y_offset_bits; +} diff --git a/tests/compiletests/ui/dis/const-from-cast.stderr b/tests/compiletests/ui/dis/const-from-cast.stderr new file mode 100644 index 0000000000..648e00c51a --- /dev/null +++ b/tests/compiletests/ui/dis/const-from-cast.stderr @@ -0,0 +1,15 @@ +OpCapability Shader +OpCapability ShaderClockKHR +OpExtension "SPV_KHR_shader_clock" +OpMemoryModel Logical Simple +OpEntryPoint Fragment %1 "main" %2 +OpExecutionMode %1 OriginUpperLeft +%3 = OpString "$OPSTRING_FILENAME/const-from-cast.rs" +OpName %2 "output" +OpDecorate %2 Location 0 +%4 = OpTypeInt 32 0 +%5 = OpTypePointer Output %4 +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%2 = OpVariable %5 Output +%8 = OpConstant %4 42 diff --git a/tests/compiletests/ui/dis/const-int-cast.rs b/tests/compiletests/ui/dis/const-int-cast.rs new file mode 100644 index 0000000000..db65713362 --- /dev/null +++ b/tests/compiletests/ui/dis/const-int-cast.rs @@ -0,0 +1,22 @@ +// Test that constant integer casts are optimized to avoid creating intermediate types +// that would require additional capabilities (e.g., Int8 capability for u8). + +// build-pass +// compile-flags: -C llvm-args=--disassemble-globals +// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> "" +// normalize-stderr-test "OpSource .*\n" -> "" +// normalize-stderr-test "OpExtension .SPV_KHR_vulkan_memory_model.\n" -> "" +// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple" + +use spirv_std::spirv; + +const K: u8 = 20; + +#[spirv(fragment)] +pub fn main(output: &mut u32) { + let position = 2u32; + // This cast should be optimized to directly create a u32 constant with value 20, + // avoiding the creation of a u8 type that would require Int8 capability + let global_y_offset_bits = position * K as u32; + *output = global_y_offset_bits; +} diff --git a/tests/compiletests/ui/dis/const-int-cast.stderr b/tests/compiletests/ui/dis/const-int-cast.stderr new file mode 100644 index 0000000000..c860b65583 --- /dev/null +++ b/tests/compiletests/ui/dis/const-int-cast.stderr @@ -0,0 +1,15 @@ +OpCapability Shader +OpCapability ShaderClockKHR +OpExtension "SPV_KHR_shader_clock" +OpMemoryModel Logical Simple +OpEntryPoint Fragment %1 "main" %2 +OpExecutionMode %1 OriginUpperLeft +%3 = OpString "$OPSTRING_FILENAME/const-int-cast.rs" +OpName %2 "output" +OpDecorate %2 Location 0 +%4 = OpTypeInt 32 0 +%5 = OpTypePointer Output %4 +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%2 = OpVariable %5 Output +%8 = OpConstant %4 40 diff --git a/tests/compiletests/ui/dis/const-narrowing-cast.rs b/tests/compiletests/ui/dis/const-narrowing-cast.rs new file mode 100644 index 0000000000..a5d6ce9d7a --- /dev/null +++ b/tests/compiletests/ui/dis/const-narrowing-cast.rs @@ -0,0 +1,24 @@ +// Test that constant narrowing casts (e.g., u32 to u8) still work correctly +// and produce the expected truncation behavior. + +// build-pass +// compile-flags: -C llvm-args=--disassemble-globals +// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> "" +// normalize-stderr-test "OpSource .*\n" -> "" +// normalize-stderr-test "OpExtension .SPV_KHR_vulkan_memory_model.\n" -> "" +// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple" + +use spirv_std::spirv; + +#[spirv(fragment)] +pub fn main(output: &mut u32) { + // This should create a u32 type and do proper truncation + const BIG: u32 = 300; // 0x12C + let truncated = BIG as u8; // Should be 0x2C = 44 + *output = truncated as u32; + + // This should optimize away the u8 type since it's widening + const SMALL: u8 = 20; + let widened = SMALL as u32; + *output += widened; +} diff --git a/tests/compiletests/ui/dis/const-narrowing-cast.stderr b/tests/compiletests/ui/dis/const-narrowing-cast.stderr new file mode 100644 index 0000000000..60a1ffb399 --- /dev/null +++ b/tests/compiletests/ui/dis/const-narrowing-cast.stderr @@ -0,0 +1,18 @@ +OpCapability Shader +OpCapability Int8 +OpCapability ShaderClockKHR +OpExtension "SPV_KHR_shader_clock" +OpMemoryModel Logical Simple +OpEntryPoint Fragment %1 "main" %2 +OpExecutionMode %1 OriginUpperLeft +%3 = OpString "$OPSTRING_FILENAME/const-narrowing-cast.rs" +OpName %2 "output" +OpDecorate %2 Location 0 +%4 = OpTypeInt 32 0 +%5 = OpTypePointer Output %4 +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 8 0 +%9 = OpConstant %4 300 +%2 = OpVariable %5 Output +%10 = OpConstant %4 20 diff --git a/tests/compiletests/ui/dis/custom_entry_point.stderr b/tests/compiletests/ui/dis/custom_entry_point.stderr index e243af37f7..d00d71e8bb 100644 --- a/tests/compiletests/ui/dis/custom_entry_point.stderr +++ b/tests/compiletests/ui/dis/custom_entry_point.stderr @@ -1,8 +1,4 @@ OpCapability Shader -OpCapability Float64 -OpCapability Int64 -OpCapability Int16 -OpCapability Int8 OpCapability ShaderClockKHR OpExtension "SPV_KHR_shader_clock" OpMemoryModel Logical Simple diff --git a/tests/compiletests/ui/dis/generic-fn-op-name.stderr b/tests/compiletests/ui/dis/generic-fn-op-name.stderr index 09d24a139e..048f1e2f34 100644 --- a/tests/compiletests/ui/dis/generic-fn-op-name.stderr +++ b/tests/compiletests/ui/dis/generic-fn-op-name.stderr @@ -1,8 +1,4 @@ OpCapability Shader -OpCapability Float64 -OpCapability Int64 -OpCapability Int16 -OpCapability Int8 OpCapability ShaderClockKHR OpExtension "SPV_KHR_shader_clock" OpMemoryModel Logical Simple diff --git a/tests/compiletests/ui/dis/issue-723-output.stderr b/tests/compiletests/ui/dis/issue-723-output.stderr index e6c4b35256..9adcd37970 100644 --- a/tests/compiletests/ui/dis/issue-723-output.stderr +++ b/tests/compiletests/ui/dis/issue-723-output.stderr @@ -1,8 +1,4 @@ OpCapability Shader -OpCapability Float64 -OpCapability Int64 -OpCapability Int16 -OpCapability Int8 OpCapability ShaderClockKHR OpExtension "SPV_KHR_shader_clock" OpMemoryModel Logical Simple diff --git a/tests/compiletests/ui/dis/non-writable-storage_buffer.stderr b/tests/compiletests/ui/dis/non-writable-storage_buffer.stderr index f36836e850..14349d6707 100644 --- a/tests/compiletests/ui/dis/non-writable-storage_buffer.stderr +++ b/tests/compiletests/ui/dis/non-writable-storage_buffer.stderr @@ -1,8 +1,4 @@ OpCapability Shader -OpCapability Float64 -OpCapability Int64 -OpCapability Int16 -OpCapability Int8 OpCapability ShaderClockKHR OpExtension "SPV_KHR_shader_clock" OpMemoryModel Logical Simple diff --git a/tests/compiletests/ui/dis/panic_builtin_bounds_check.stderr b/tests/compiletests/ui/dis/panic_builtin_bounds_check.stderr index e1c8f3b28c..342fe4755f 100644 --- a/tests/compiletests/ui/dis/panic_builtin_bounds_check.stderr +++ b/tests/compiletests/ui/dis/panic_builtin_bounds_check.stderr @@ -1,8 +1,4 @@ OpCapability Shader -OpCapability Float64 -OpCapability Int64 -OpCapability Int16 -OpCapability Int8 OpCapability ShaderClockKHR OpExtension "SPV_KHR_non_semantic_info" OpExtension "SPV_KHR_shader_clock" diff --git a/tests/compiletests/ui/dis/panic_sequential_many.stderr b/tests/compiletests/ui/dis/panic_sequential_many.stderr index d748d2fd49..3e0546a139 100644 --- a/tests/compiletests/ui/dis/panic_sequential_many.stderr +++ b/tests/compiletests/ui/dis/panic_sequential_many.stderr @@ -1,8 +1,4 @@ OpCapability Shader -OpCapability Float64 -OpCapability Int64 -OpCapability Int16 -OpCapability Int8 OpCapability ShaderClockKHR OpExtension "SPV_KHR_non_semantic_info" OpExtension "SPV_KHR_shader_clock" diff --git a/tests/compiletests/ui/dis/spec_constant-attr.stderr b/tests/compiletests/ui/dis/spec_constant-attr.stderr index 5cdc1f968d..4b26a5c08b 100644 --- a/tests/compiletests/ui/dis/spec_constant-attr.stderr +++ b/tests/compiletests/ui/dis/spec_constant-attr.stderr @@ -1,8 +1,4 @@ OpCapability Shader -OpCapability Float64 -OpCapability Int64 -OpCapability Int16 -OpCapability Int8 OpCapability ShaderClockKHR OpExtension "SPV_KHR_shader_clock" OpMemoryModel Logical Simple diff --git a/tests/compiletests/ui/lang/consts/u32-from-u64-fail.rs b/tests/compiletests/ui/lang/consts/u32-from-u64-fail.rs new file mode 100644 index 0000000000..b4e26d56f5 --- /dev/null +++ b/tests/compiletests/ui/lang/consts/u32-from-u64-fail.rs @@ -0,0 +1,16 @@ +// Test that u32::from(u64) fails to compile since From is not implemented for u32 +// This ensures our From trait optimization doesn't accidentally allow invalid conversions + +// build-fail + +use spirv_std::spirv; + +const K: u64 = 42; + +#[spirv(fragment)] +pub fn main(output: &mut u32) { + // This should fail to compile because From is not implemented for u32 + // (u64 to u32 is a narrowing conversion that could lose data) + let value = u32::from(K); + *output = value; +} diff --git a/tests/compiletests/ui/lang/consts/u32-from-u64-fail.stderr b/tests/compiletests/ui/lang/consts/u32-from-u64-fail.stderr new file mode 100644 index 0000000000..4b9e8f23ea --- /dev/null +++ b/tests/compiletests/ui/lang/consts/u32-from-u64-fail.stderr @@ -0,0 +1,17 @@ +error[E0277]: the trait bound `u32: From` is not satisfied + --> $DIR/u32-from-u64-fail.rs:14:17 + | +14 | let value = u32::from(K); + | ^^^ the trait `From` is not implemented for `u32` + | + = help: the following other types implement trait `From`: + `u32` implements `From` + `u32` implements `From` + `u32` implements `From` + `u32` implements `From` + `u32` implements `From` + `u32` implements `From` + +error: aborting due to 1 previous error + +For more information about this error, try `rustc --explain E0277`. diff --git a/tests/compiletests/ui/lang/consts/u8-const-cast-no-capability.rs b/tests/compiletests/ui/lang/consts/u8-const-cast-no-capability.rs new file mode 100644 index 0000000000..49b8f49dd9 --- /dev/null +++ b/tests/compiletests/ui/lang/consts/u8-const-cast-no-capability.rs @@ -0,0 +1,14 @@ +// build-pass +// Test that u8 constants cast to u32 don't require Int8 capability when optimized away + +#![no_std] +use spirv_std::spirv; + +const K: u8 = 20; + +#[spirv(fragment)] +pub fn main(output: &mut u32) { + // This should not require Int8 capability as K is only used as u32 + // and the optimization should fold the constant cast + *output = K as u32; +} diff --git a/tests/compiletests/ui/lang/consts/u8-const-cast.rs b/tests/compiletests/ui/lang/consts/u8-const-cast.rs new file mode 100644 index 0000000000..3dbf2543a3 --- /dev/null +++ b/tests/compiletests/ui/lang/consts/u8-const-cast.rs @@ -0,0 +1,37 @@ +// build-pass +// Test that u8 constants cast to u32 don't require Int8 capability + +use spirv_std::spirv; + +const K: u8 = 20; + +#[spirv(fragment)] +pub fn main(output: &mut u32) { + let position = 2u32; + // This should not require Int8 capability as K is only used as u32 + let global_y_offset_bits = position * K as u32; + *output = global_y_offset_bits; +} + +#[spirv(fragment)] +pub fn test_various_const_casts(output: &mut [u32; 5]) { + // Test u8 -> u32 + const U8_VAL: u8 = 255; + output[0] = U8_VAL as u32; + + // Test i8 -> i32 + const I8_VAL: i8 = -128; + output[1] = I8_VAL as i32 as u32; + + // Test u16 -> u32 + const U16_VAL: u16 = 65535; + output[2] = U16_VAL as u32; + + // Test bool -> u32 + const BOOL_VAL: bool = true; + output[3] = BOOL_VAL as u32; + + // Test u32 -> bool -> u32 + const U32_VAL: u32 = 42; + output[4] = (U32_VAL != 0) as u32; +}