Skip to content

Commit 2f451ee

Browse files
authored
Add support for using target-features for extensions and capabilities (#610)
1 parent 1431c18 commit 2f451ee

27 files changed

+87
-45
lines changed

crates/rustc_codegen_spirv/src/builder_spirv.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use crate::builder;
22
use crate::codegen_cx::CodegenCx;
33
use crate::spirv_type::SpirvType;
44
use crate::target::SpirvTarget;
5+
use crate::target_feature::TargetFeature;
56
use rspirv::dr::{Block, Builder, Module, Operand};
67
use rspirv::spirv::{AddressingModel, Capability, MemoryModel, Op, StorageClass, Word};
78
use rspirv::{binary::Assemble, binary::Disassemble};
@@ -303,13 +304,20 @@ pub struct BuilderSpirv {
303304
}
304305

305306
impl BuilderSpirv {
306-
pub fn new(target: &SpirvTarget) -> Self {
307+
pub fn new(target: &SpirvTarget, features: &[TargetFeature]) -> Self {
307308
let version = target.spirv_version();
308309
let memory_model = target.memory_model();
309310

310311
let mut builder = Builder::new();
311312
builder.set_version(version.0, version.1);
312313

314+
for feature in features {
315+
match feature {
316+
TargetFeature::Capability(cap) => builder.capability(*cap),
317+
TargetFeature::Extension(ext) => builder.extension(&*ext.as_str()),
318+
}
319+
}
320+
313321
if target.is_kernel() {
314322
builder.capability(Capability::Kernel);
315323
} else {

crates/rustc_codegen_spirv/src/codegen_cx/mod.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,16 +80,24 @@ pub struct CodegenCx<'tcx> {
8080
impl<'tcx> CodegenCx<'tcx> {
8181
pub fn new(tcx: TyCtxt<'tcx>, codegen_unit: &'tcx CodegenUnit<'tcx>) -> Self {
8282
let sym = Symbols::get();
83-
for &feature in &tcx.sess.target_features {
84-
tcx.sess.err(&format!("Unknown feature {}", feature));
85-
}
83+
let features = tcx
84+
.sess
85+
.target_features
86+
.iter()
87+
.map(|s| s.as_str().parse())
88+
.collect::<Result<_, String>>()
89+
.unwrap_or_else(|error| {
90+
tcx.sess.err(&error);
91+
Vec::new()
92+
});
93+
8694
let codegen_args = CodegenArgs::from_session(tcx.sess);
8795
let target = tcx.sess.target.llvm_target.parse().unwrap();
8896

8997
Self {
9098
tcx,
9199
codegen_unit,
92-
builder: BuilderSpirv::new(&target),
100+
builder: BuilderSpirv::new(&target, &features),
93101
instances: Default::default(),
94102
function_parameter_values: Default::default(),
95103
type_cache: Default::default(),

crates/rustc_codegen_spirv/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ mod spirv_type;
117117
mod spirv_type_constraints;
118118
mod symbols;
119119
mod target;
120+
mod target_feature;
120121

121122
use builder::Builder;
122123
use codegen_cx::{CodegenArgs, CodegenCx, ModuleOutputType};
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
use rustc_span::symbol::Symbol;
2+
3+
#[derive(Clone, Debug, Eq, PartialEq)]
4+
pub enum TargetFeature {
5+
Extension(Symbol),
6+
Capability(rspirv::spirv::Capability),
7+
}
8+
9+
impl std::str::FromStr for TargetFeature {
10+
type Err = String;
11+
12+
fn from_str(input: &str) -> Result<Self, Self::Err> {
13+
const EXT_PREFIX: &str = "ext:";
14+
15+
if let Some(input) = input.strip_prefix(EXT_PREFIX) {
16+
Ok(Self::Extension(Symbol::intern(input)))
17+
} else {
18+
Ok(Self::Capability(input.parse().map_err(|_err| {
19+
format!("Invalid Capability: `{}`", input)
20+
})?))
21+
}
22+
}
23+
}

tests/ui/arch/convert_u_to_acceleration_structure_khr.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
// build-pass
2+
// compile-flags: -Ctarget-feature=+RayTracingKHR,+ext:SPV_KHR_ray_tracing
23

34
#[spirv(ray_generation)]
45
pub fn main(#[spirv(ray_payload)] payload: &mut glam::Vec3) {
56
unsafe {
6-
asm!(r#"OpExtension "SPV_KHR_ray_tracing""#);
7-
asm!("OpCapability RayTracingKHR");
8-
97
let handle = spirv_std::ray_tracing::AccelerationStructure::from_u64(0xffff_ffff);
108
let handle2 =
119
spirv_std::ray_tracing::AccelerationStructure::from_vec(glam::UVec2::new(0, 0));
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
// build-pass
2+
// compile-flags: -Ctarget-feature=+RayTracingKHR,+ext:SPV_KHR_ray_tracing
23

34
#[spirv(any_hit)]
45
pub fn main() {
56
unsafe {
6-
asm!(r#"OpExtension "SPV_KHR_ray_tracing""#);
7-
asm!("OpCapability RayTracingKHR");
87
spirv_std::arch::ignore_intersection();
98
}
109
}

tests/ui/arch/ray_query_get_intersection_barycentrics_khr.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
// build-pass
2+
// compile-flags: -Ctarget-feature=+RayQueryKHR,+ext:SPV_KHR_ray_query
23

34
use glam::Vec3;
45
use spirv_std::ray_tracing::{AccelerationStructure, RayFlags, RayQuery};
56

67
#[spirv(fragment)]
78
pub fn main(#[spirv(descriptor_set = 0, binding = 0)] accel: &AccelerationStructure) {
89
unsafe {
9-
asm!(r#"OpExtension "SPV_KHR_ray_query""#);
10-
asm!("OpCapability RayQueryKHR");
1110
spirv_std::ray_query!(let mut handle);
1211
handle.initialize(accel, RayFlags::NONE, 0, Vec3::ZERO, 0.0, Vec3::ZERO, 0.0);
1312
let barycentric_coords: glam::Vec2 = handle.get_intersection_barycentrics::<_, 5>();

tests/ui/arch/ray_query_get_intersection_front_face_khr.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
// build-pass
2+
// compile-flags: -Ctarget-feature=+RayQueryKHR,+ext:SPV_KHR_ray_query
23

34
use glam::Vec3;
45
use spirv_std::ray_tracing::{AccelerationStructure, RayFlags, RayQuery};
56

67
#[spirv(fragment)]
78
pub fn main(#[spirv(descriptor_set = 0, binding = 0)] accel: &AccelerationStructure) {
89
unsafe {
9-
asm!(r#"OpExtension "SPV_KHR_ray_query""#);
10-
asm!("OpCapability RayQueryKHR");
1110
spirv_std::ray_query!(let mut handle);
1211
handle.initialize(accel, RayFlags::NONE, 0, Vec3::ZERO, 0.0, Vec3::ZERO, 0.0);
1312
assert!(handle.get_intersection_front_face::<5>());

tests/ui/arch/ray_query_get_intersection_geometry_index_khr.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
// build-pass
2+
// compile-flags: -Ctarget-feature=+RayQueryKHR,+ext:SPV_KHR_ray_query
23

34
use glam::Vec3;
45
use spirv_std::ray_tracing::{AccelerationStructure, RayFlags, RayQuery};
56

67
#[spirv(fragment)]
78
pub fn main(#[spirv(descriptor_set = 0, binding = 0)] accel: &AccelerationStructure) {
89
unsafe {
9-
asm!(r#"OpExtension "SPV_KHR_ray_query""#);
10-
asm!("OpCapability RayQueryKHR");
1110
spirv_std::ray_query!(let mut handle);
1211
handle.initialize(accel, RayFlags::NONE, 0, Vec3::ZERO, 0.0, Vec3::ZERO, 0.0);
1312
let t = handle.get_intersection_geometry_index::<5>();

tests/ui/arch/ray_query_get_intersection_instance_custom_index_khr.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
// build-pass
2+
// compile-flags: -Ctarget-feature=+RayQueryKHR,+ext:SPV_KHR_ray_query
23

34
use glam::Vec3;
45
use spirv_std::ray_tracing::{AccelerationStructure, RayFlags, RayQuery};
56

67
#[spirv(fragment)]
78
pub fn main(#[spirv(descriptor_set = 0, binding = 0)] accel: &AccelerationStructure) {
89
unsafe {
9-
asm!(r#"OpExtension "SPV_KHR_ray_query""#);
10-
asm!("OpCapability RayQueryKHR");
1110
spirv_std::ray_query!(let mut handle);
1211
handle.initialize(accel, RayFlags::NONE, 0, Vec3::ZERO, 0.0, Vec3::ZERO, 0.0);
1312
let index = handle.get_intersection_instance_custom_index::<5>();

tests/ui/arch/ray_query_get_intersection_instance_id_khr.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
// build-pass
2+
// compile-flags: -Ctarget-feature=+RayQueryKHR,+ext:SPV_KHR_ray_query
23

34
use glam::Vec3;
45
use spirv_std::ray_tracing::{AccelerationStructure, RayFlags, RayQuery};
56

67
#[spirv(fragment)]
78
pub fn main(#[spirv(descriptor_set = 0, binding = 0)] accel: &AccelerationStructure) {
89
unsafe {
9-
asm!(r#"OpExtension "SPV_KHR_ray_query""#);
10-
asm!("OpCapability RayQueryKHR");
1110
spirv_std::ray_query!(let mut handle);
1211
handle.initialize(accel, RayFlags::NONE, 0, Vec3::ZERO, 0.0, Vec3::ZERO, 0.0);
1312
let id = handle.get_intersection_instance_id::<5>();

tests/ui/arch/ray_query_get_intersection_shader_binding_table_record_offset_khr.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
// build-pass
2+
// compile-flags: -Ctarget-feature=+RayQueryKHR,+ext:SPV_KHR_ray_query
23

34
use glam::Vec3;
45
use spirv_std::ray_tracing::{AccelerationStructure, RayFlags, RayQuery};
56

67
#[spirv(fragment)]
78
pub fn main(#[spirv(descriptor_set = 0, binding = 0)] accel: &AccelerationStructure) {
89
unsafe {
9-
asm!(r#"OpExtension "SPV_KHR_ray_query""#);
10-
asm!("OpCapability RayQueryKHR");
1110
spirv_std::ray_query!(let mut handle);
1211
handle.initialize(accel, RayFlags::NONE, 0, Vec3::ZERO, 0.0, Vec3::ZERO, 0.0);
1312
let offset = handle.get_intersection_shader_binding_table_record_offset::<5>();

tests/ui/arch/ray_query_get_intersection_t_khr.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
// build-pass
2+
// compile-flags: -Ctarget-feature=+RayQueryKHR,+ext:SPV_KHR_ray_query
23

34
use glam::Vec3;
45
use spirv_std::ray_tracing::{AccelerationStructure, RayFlags, RayQuery};
56

67
#[spirv(fragment)]
78
pub fn main(#[spirv(descriptor_set = 0, binding = 0)] accel: &AccelerationStructure) {
89
unsafe {
9-
asm!(r#"OpExtension "SPV_KHR_ray_query""#);
10-
asm!("OpCapability RayQueryKHR");
1110
spirv_std::ray_query!(let mut handle);
1211
handle.initialize(accel, RayFlags::NONE, 0, Vec3::ZERO, 0.0, Vec3::ZERO, 0.0);
1312
let t = handle.get_intersection_t::<5>();

tests/ui/arch/ray_query_get_intersection_type_khr.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
// build-pass
2+
// compile-flags: -Ctarget-feature=+RayQueryKHR,+ext:SPV_KHR_ray_query
23

34
use glam::Vec3;
45
use spirv_std::ray_tracing::{AccelerationStructure, RayFlags, RayQuery};
56

67
#[spirv(fragment)]
78
pub fn main(#[spirv(descriptor_set = 0, binding = 0)] accel: &AccelerationStructure) {
89
unsafe {
9-
asm!(r#"OpExtension "SPV_KHR_ray_query""#);
10-
asm!("OpCapability RayQueryKHR");
1110
spirv_std::ray_query!(let mut handle);
1211
handle.initialize(accel, RayFlags::NONE, 0, Vec3::ZERO, 0.0, Vec3::ZERO, 0.0);
1312
handle.get_intersection_type::<5>();

tests/ui/arch/ray_query_get_ray_t_min_khr.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
// build-pass
2+
// compile-flags: -Ctarget-feature=+RayQueryKHR,+ext:SPV_KHR_ray_query
23

34
use glam::Vec3;
45
use spirv_std::ray_tracing::{AccelerationStructure, RayFlags, RayQuery};
56

67
#[spirv(fragment)]
78
pub fn main(#[spirv(descriptor_set = 0, binding = 0)] accel: &AccelerationStructure) {
89
unsafe {
9-
asm!(r#"OpExtension "SPV_KHR_ray_query""#);
10-
asm!("OpCapability RayQueryKHR");
1110
spirv_std::ray_query!(let mut handle);
1211
handle.initialize(accel, RayFlags::NONE, 0, Vec3::ZERO, 0.0, Vec3::ZERO, 0.0);
1312
let tmin = handle.get_ray_t_min();

tests/ui/arch/ray_query_initialize_khr.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// build-pass
2+
// compile-flags: -Ctarget-feature=+RayQueryKHR,+ext:SPV_KHR_ray_query
23

34
use glam::Vec3;
45
use spirv_std::ray_tracing::{AccelerationStructure, RayFlags, RayQuery};
@@ -11,8 +12,6 @@ pub fn main(
1112
#[spirv(ray_payload)] payload: &mut Vec3,
1213
) {
1314
unsafe {
14-
asm!(r#"OpExtension "SPV_KHR_ray_query""#);
15-
asm!("OpCapability RayQueryKHR");
1615
spirv_std::ray_query!(let mut ray_query);
1716

1817
ray_query.initialize(

tests/ui/arch/ray_query_proceed_khr.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
// build-pass
2+
// compile-flags: -Ctarget-feature=+RayQueryKHR,+ext:SPV_KHR_ray_query
23

34
use glam::Vec3;
45
use spirv_std::ray_tracing::{AccelerationStructure, RayFlags, RayQuery};
56

67
#[spirv(fragment)]
78
pub fn main(#[spirv(descriptor_set = 0, binding = 0)] accel: &AccelerationStructure) {
89
unsafe {
9-
asm!(r#"OpExtension "SPV_KHR_ray_query""#);
10-
asm!("OpCapability RayQueryKHR");
1110
spirv_std::ray_query!(let mut handle);
1211
handle.initialize(accel, RayFlags::NONE, 0, Vec3::ZERO, 0.0, Vec3::ZERO, 0.0);
1312
assert!(handle.proceed());

tests/ui/arch/ray_query_terminate_khr.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
// build-pass
2+
// compile-flags: -Ctarget-feature=+RayQueryKHR,+ext:SPV_KHR_ray_query
23

34
use glam::Vec3;
45
use spirv_std::ray_tracing::{AccelerationStructure, RayFlags, RayQuery};
56

67
#[spirv(fragment)]
78
pub fn main(#[spirv(descriptor_set = 0, binding = 0)] accel: &AccelerationStructure) {
89
unsafe {
9-
asm!(r#"OpExtension "SPV_KHR_ray_query""#);
10-
asm!("OpCapability RayQueryKHR");
1110
spirv_std::ray_query!(let mut handle);
1211
handle.initialize(accel, RayFlags::NONE, 0, Vec3::ZERO, 0.0, Vec3::ZERO, 0.0);
1312
handle.terminate();
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
// build-pass
2+
// compile-flags: -Ctarget-feature=+RayTracingKHR,+ext:SPV_KHR_ray_tracing
23

34
#[spirv(intersection)]
45
pub fn main() {
56
unsafe {
6-
asm!(r#"OpExtension "SPV_KHR_ray_tracing""#);
7-
asm!("OpCapability RayTracingKHR");
87
spirv_std::arch::report_intersection(2.0, 4);
98
}
109
}

tests/ui/arch/terminate_ray_khr.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
// build-pass
2+
// compile-flags: -Ctarget-feature=+RayTracingKHR,+ext:SPV_KHR_ray_tracing
23

34
#[spirv(any_hit)]
45
pub fn main() {
56
unsafe {
6-
asm!(r#"OpExtension "SPV_KHR_ray_tracing""#);
7-
asm!("OpCapability RayTracingKHR");
87
spirv_std::arch::terminate_ray();
98
}
109
}

tests/ui/arch/trace_ray_khr.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// build-pass
2+
// compile-flags: -Ctarget-feature=+RayTracingKHR,+ext:SPV_KHR_ray_tracing
23

34
#[spirv(ray_generation)]
45
// Rustfmt will eat long attributes (https://github.com/rust-lang/rustfmt/issues/4579)
@@ -9,8 +10,6 @@ pub fn main(
910
#[spirv(ray_payload)] payload: &mut glam::Vec3,
1011
) {
1112
unsafe {
12-
asm!(r#"OpExtension "SPV_KHR_ray_tracing""#);
13-
asm!("OpCapability RayTracingKHR");
1413
acceleration_structure.trace_ray(
1514
spirv_std::ray_tracing::RayFlags::NONE,
1615
0,

tests/ui/dis/asm_op_decorate.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// build-pass
2+
// compile-flags: -C target-feature=+RuntimeDescriptorArray,+ext:SPV_EXT_descriptor_indexing
23
// compile-flags: -C llvm-args=--disassemble-globals
34
// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> ""
45
// normalize-stderr-test "OpExtension .SPV_KHR_vulkan_memory_model.\n" -> ""
@@ -10,8 +11,6 @@ fn add_decorate() {
1011
unsafe {
1112
let offset = 1u32;
1213
asm!(
13-
"OpExtension \"SPV_EXT_descriptor_indexing\"",
14-
"OpCapability RuntimeDescriptorArray",
1514
"OpDecorate %image_2d_var DescriptorSet 0",
1615
"OpDecorate %image_2d_var Binding 0",
1716
"%uint = OpTypeInt 32 0",

tests/ui/dis/asm_op_decorate.stderr

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
OpCapability Shader
21
OpCapability RuntimeDescriptorArray
2+
OpCapability Shader
33
OpExtension "SPV_EXT_descriptor_indexing"
44
OpMemoryModel Logical Simple
55
OpEntryPoint Fragment %1 "main"

tests/ui/dis/complex_image_sample_inst.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// build-pass
2+
// compile-flags: -Ctarget-feature=+RuntimeDescriptorArray,+ext:SPV_EXT_descriptor_indexing
23
// compile-flags: -C llvm-args=--disassemble-fn=complex_image_sample_inst::sample_proj_lod
34

45
use spirv_std as _;
@@ -14,8 +15,6 @@ fn sample_proj_lod(
1415
let mut result = glam::Vec4::default();
1516
let index = 0u32;
1617
asm!(
17-
"OpExtension \"SPV_EXT_descriptor_indexing\"",
18-
"OpCapability RuntimeDescriptorArray",
1918
"OpDecorate %image_2d_var DescriptorSet 0",
2019
"OpDecorate %image_2d_var Binding 0",
2120
"%uint = OpTypeInt 32 0",

tests/ui/dis/target_features.stderr

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
OpCapability RayTracingKHR
2+
OpCapability Shader
3+
OpExtension "SPV_KHR_ray_tracing"
4+
OpMemoryModel Logical Simple
5+
OpEntryPoint AnyHitNV %1 "main"
6+
OpName %2 "target_features::main"
7+
%3 = OpTypeVoid
8+
%4 = OpTypeFunction %3

tests/ui/target_features_err.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// build-fail
2+
// compile-flags: -Ctarget-feature=+rayTracingKHR,+ext:SPV_KHR_ray_tracing
3+
4+
use spirv_std as _;
5+
6+
#[spirv(any_hit)]
7+
pub fn main() {
8+
unsafe { spirv_std::arch::terminate_ray() }
9+
}
10+

tests/ui/target_features_err.stderr

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
error: Invalid Capability: `rayTracingKHR`
2+
3+
error: aborting due to previous error
4+

0 commit comments

Comments
 (0)