diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4e3422e62bd..bf8f00b4cfd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -262,9 +262,9 @@ jobs: set -e # build for WebGPU - cargo clippy --target ${{ matrix.target }} ${{ matrix.extra-flags }} --tests --features glsl,spirv,fragile-send-sync-non-atomic-wasm - cargo clippy --target ${{ matrix.target }} ${{ matrix.extra-flags }} --tests --features glsl,spirv - cargo doc --target ${{ matrix.target }} ${{ matrix.extra-flags }} --no-deps --features glsl,spirv + cargo clippy --target ${{ matrix.target }} ${{ matrix.extra-flags }} --tests --features glsl,fragile-send-sync-non-atomic-wasm + cargo clippy --target ${{ matrix.target }} ${{ matrix.extra-flags }} --tests --features glsl + cargo doc --target ${{ matrix.target }} ${{ matrix.extra-flags }} --no-deps --features glsl # check with only the web feature cargo clippy --target ${{ matrix.target }} ${{ matrix.extra-flags }} --no-default-features --features=web diff --git a/CHANGELOG.md b/CHANGELOG.md index 78871071b2d..d828662c7b0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -79,6 +79,7 @@ By @Vecvec in [#7913](https://github.com/gfx-rs/wgpu/pull/7913). #### Naga - Naga now requires that no type be larger than 1 GB. This limit may be lowered in the future; feedback on an appropriate value for the limit is welcome. By @andyleiserson in [#7950](https://github.com/gfx-rs/wgpu/pull/7950). +- Added mesh shader support to naga with `WGSL` frontend and `SPIR-V` backend. By @SupaMaggie70Incorporated in [#7930](https://github.com/gfx-rs/wgpu/pull/7930). - If the shader source contains control characters, Naga now replaces them with U+FFFD ("replacement character") in diagnostic output. By @andyleiserson in [#8049](https://github.com/gfx-rs/wgpu/pull/8049). #### DX12 diff --git a/Cargo.lock b/Cargo.lock index 1f0c5b5ab22..a23867607aa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2522,9 +2522,9 @@ dependencies = [ "spirv", "strum 0.27.2", "thiserror 2.0.14", - "toml 0.9.5", "unicode-ident", "walkdir", + "wgpu-test", ] [[package]] @@ -4940,6 +4940,7 @@ dependencies = [ "rayon", "tracy-client", "wgpu", + "wgpu-test", ] [[package]] @@ -5161,15 +5162,20 @@ dependencies = [ "js-sys", "libtest-mimic", "log", + "naga", "nanorand 0.8.0", "nv-flip", "parking_lot", "png", "pollster", "profiling", + "ron", + "rspirv", "serde", "serde_json", + "spirv", "strum 0.27.2", + "toml 0.9.5", "trybuild", "wasm-bindgen", "wasm-bindgen-futures", diff --git a/benches/Cargo.toml b/benches/Cargo.toml index 9af4cf4ae7d..53c992ce86a 100644 --- a/benches/Cargo.toml +++ b/benches/Cargo.toml @@ -47,3 +47,13 @@ profiling.workspace = true rayon.workspace = true tracy-client = { workspace = true, optional = true } wgpu.workspace = true +wgpu-test = { workspace = true, features = [ + "wgsl-in", + "spv-in", + "glsl-in", + "spv-out", + "msl-out", + "hlsl-out", + "glsl-out", + "wgsl-out", +] } diff --git a/benches/benches/wgpu-benchmark/shader.rs b/benches/benches/wgpu-benchmark/shader.rs index b98cef01ae5..3ccc5b01ced 100644 --- a/benches/benches/wgpu-benchmark/shader.rs +++ b/benches/benches/wgpu-benchmark/shader.rs @@ -1,57 +1,56 @@ use criterion::*; -use std::{fs, path::PathBuf, process::Command}; +use std::{fs, process::Command}; -struct Input { - filename: String, - size: u64, +const DIR_IN: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/../naga/tests/in"); + +use wgpu_test::naga::*; + +struct InputWithInfo { + inner: Input, data: Vec, string: Option, + options: Parameters, module: Option, module_info: Option, } +impl From for InputWithInfo { + fn from(value: Input) -> Self { + let mut options = value.read_parameters(DIR_IN); + options.targets = Some(options.targets.unwrap_or(Targets::all())); + Self { + options, + inner: value, + data: Vec::new(), + string: None, + module: None, + module_info: None, + } + } +} +impl InputWithInfo { + fn filename(&self) -> &str { + self.inner.file_name.file_name().unwrap().to_str().unwrap() + } +} struct Inputs { - inner: Vec, + inner: Vec, } impl Inputs { #[track_caller] fn from_dir(folder: &str, extension: &str) -> Self { - let mut inputs = Vec::new(); - let read_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")) - .join(folder) - .read_dir() - .unwrap(); - - for file_entry in read_dir { - match file_entry { - Ok(entry) => match entry.path().extension() { - Some(ostr) if ostr == extension => { - let path = entry.path(); - - inputs.push(Input { - filename: path.to_string_lossy().into_owned(), - size: entry.metadata().unwrap().len(), - string: None, - data: vec![], - module: None, - module_info: None, - }); - } - _ => continue, - }, - Err(e) => { - eprintln!("Skipping file: {e:?}"); - continue; - } - } - } + let inputs: Vec = Input::files_in_dir(folder, &[extension], DIR_IN) + .map(|a| a.into()) + .collect(); Self { inner: inputs } } - fn bytes(&self) -> u64 { - self.inner.iter().map(|input| input.size).sum() + self.inner + .iter() + .map(|input| input.inner.bytes(DIR_IN)) + .sum() } fn load(&mut self) { @@ -60,7 +59,7 @@ impl Inputs { continue; } - input.data = fs::read(&input.filename).unwrap_or_default(); + input.data = fs::read(input.inner.input_path(DIR_IN)).unwrap_or_default(); } } @@ -85,6 +84,8 @@ impl Inputs { continue; } + parser.set_options((&input.options.wgsl_in).into()); + input.module = Some(parser.parse(input.string.as_ref().unwrap()).unwrap()); } } @@ -122,22 +123,22 @@ fn parse_glsl(stage: naga::ShaderStage, inputs: &Inputs) { }; for input in &inputs.inner { parser - .parse(&options, input.string.as_deref().unwrap()) + .parse(&options, &input.inner.read_source(DIR_IN, false)) .unwrap(); } } fn get_wgsl_inputs() -> Inputs { - let mut inputs = Inputs::from_dir("../naga/tests/in/wgsl", "wgsl"); + let mut inputs: Vec = Input::files_in_dir("wgsl", &["wgsl"], DIR_IN) + .map(|a| a.into()) + .collect(); // remove "large-source" tests, they skew the results - inputs - .inner - .retain(|input| !input.filename.contains("large-source")); + inputs.retain(|input| !input.filename().contains("large-source")); assert!(!inputs.is_empty()); - inputs + Inputs { inner: inputs } } fn frontends(c: &mut Criterion) { @@ -178,19 +179,20 @@ fn frontends(c: &mut Criterion) { let mut frontend = naga::front::wgsl::Frontend::new(); b.iter(|| { for input in &inputs_wgsl.inner { + frontend.set_options((&input.options.wgsl_in).into()); frontend.parse(input.string.as_ref().unwrap()).unwrap(); } }); }); - let inputs_spirv = Inputs::from_dir("../naga/tests/in/spv", "spvasm"); + let inputs_spirv = Inputs::from_dir("spv", "spvasm"); assert!(!inputs_spirv.is_empty()); // Assemble all the SPIR-V assembly. let mut assembled_spirv = Vec::>::new(); 'spirv: for input in &inputs_spirv.inner { let output = match Command::new("spirv-as") - .arg(&input.filename) + .arg(input.inner.input_path(DIR_IN)) .arg("-o") .arg("-") .output() @@ -220,19 +222,32 @@ fn frontends(c: &mut Criterion) { let total_bytes = assembled_spirv.iter().map(|spv| spv.len() as u64).sum(); + assert!(assembled_spirv.len() == inputs_spirv.inner.len() || assembled_spirv.is_empty()); + group.throughput(Throughput::Bytes(total_bytes)); group.bench_function("shader: spv-in", |b| { b.iter(|| { - let options = naga::front::spv::Options::default(); - for input in &assembled_spirv { - let parser = naga::front::spv::Frontend::new(input.iter().cloned(), &options); + for (i, input) in assembled_spirv.iter().enumerate() { + let params = &inputs_spirv.inner[i].options; + let SpirvInParameters { + adjust_coordinate_space, + } = params.spv_in; + + let parser = naga::front::spv::Frontend::new( + input.iter().cloned(), + &naga::front::spv::Options { + adjust_coordinate_space, + strict_capabilities: true, + ..Default::default() + }, + ); parser.parse().unwrap(); } }); }); - let mut inputs_vertex = Inputs::from_dir("../naga/tests/in/glsl", "vert"); - let mut inputs_fragment = Inputs::from_dir("../naga/tests/in/glsl", "frag"); + let mut inputs_vertex = Inputs::from_dir("glsl", "vert"); + let mut inputs_fragment = Inputs::from_dir("glsl", "frag"); assert!(!inputs_vertex.is_empty()); assert!(!inputs_fragment.is_empty()); // let mut inputs_compute = Inputs::from_dir("../naga/tests/in/glsl", "comp"); @@ -312,14 +327,16 @@ fn backends(c: &mut Criterion) { group.bench_function("shader: wgsl-out", |b| { b.iter(|| { let mut string = String::new(); - let flags = naga::back::wgsl::WriterFlags::empty(); for input in &inputs.inner { - let mut writer = naga::back::wgsl::Writer::new(&mut string, flags); - let _ = writer.write( - input.module.as_ref().unwrap(), - input.module_info.as_ref().unwrap(), - ); - string.clear(); + if input.options.targets.unwrap().contains(Targets::WGSL) { + let mut writer = + naga::back::wgsl::Writer::new(&mut string, (&input.options.wgsl).into()); + let _ = writer.write( + input.module.as_ref().unwrap(), + input.module_info.as_ref().unwrap(), + ); + string.clear(); + } } }); }); @@ -327,21 +344,28 @@ fn backends(c: &mut Criterion) { group.bench_function("shader: spv-out", |b| { b.iter(|| { let mut data = Vec::new(); - let options = naga::back::spv::Options::default(); + let mut writer = naga::back::spv::Writer::new(&Default::default()).unwrap(); for input in &inputs.inner { - if input.filename.contains("pointer-function-arg") { - // These fail due to https://github.com/gfx-rs/wgpu/issues/7315 - continue; + if input.options.targets.unwrap().contains(Targets::SPIRV) { + if input.filename().contains("pointer-function-arg") { + // These fail due to https://github.com/gfx-rs/wgpu/issues/7315 + continue; + } + let opt = input + .options + .spv + .to_options(input.options.bounds_check_policies, None); + if writer.set_options(&opt).is_ok() { + let _ = writer.write( + input.module.as_ref().unwrap(), + input.module_info.as_ref().unwrap(), + None, + &None, + &mut data, + ); + data.clear(); + } } - let mut writer = naga::back::spv::Writer::new(&options).unwrap(); - let _ = writer.write( - input.module.as_ref().unwrap(), - input.module_info.as_ref().unwrap(), - None, - &None, - &mut data, - ); - data.clear(); } }); }); @@ -350,25 +374,27 @@ fn backends(c: &mut Criterion) { let mut data = Vec::new(); let options = naga::back::spv::Options::default(); for input in &inputs.inner { - if input.filename.contains("pointer-function-arg") { - // These fail due to https://github.com/gfx-rs/wgpu/issues/7315 - continue; - } - let mut writer = naga::back::spv::Writer::new(&options).unwrap(); - let module = input.module.as_ref().unwrap(); - for ep in module.entry_points.iter() { - let pipeline_options = naga::back::spv::PipelineOptions { - shader_stage: ep.stage, - entry_point: ep.name.clone(), - }; - let _ = writer.write( - input.module.as_ref().unwrap(), - input.module_info.as_ref().unwrap(), - Some(&pipeline_options), - &None, - &mut data, - ); - data.clear(); + if input.options.targets.unwrap().contains(Targets::SPIRV) { + if input.filename().contains("pointer-function-arg") { + // These fail due to https://github.com/gfx-rs/wgpu/issues/7315 + continue; + } + let mut writer = naga::back::spv::Writer::new(&options).unwrap(); + let module = input.module.as_ref().unwrap(); + for ep in module.entry_points.iter() { + let pipeline_options = naga::back::spv::PipelineOptions { + shader_stage: ep.stage, + entry_point: ep.name.clone(), + }; + let _ = writer.write( + input.module.as_ref().unwrap(), + input.module_info.as_ref().unwrap(), + Some(&pipeline_options), + &None, + &mut data, + ); + data.clear(); + } } } }); @@ -379,15 +405,17 @@ fn backends(c: &mut Criterion) { let mut string = String::new(); let options = naga::back::msl::Options::default(); for input in &inputs.inner { - let pipeline_options = naga::back::msl::PipelineOptions::default(); - let mut writer = naga::back::msl::Writer::new(&mut string); - let _ = writer.write( - input.module.as_ref().unwrap(), - input.module_info.as_ref().unwrap(), - &options, - &pipeline_options, - ); - string.clear(); + if input.options.targets.unwrap().contains(Targets::METAL) { + let pipeline_options = naga::back::msl::PipelineOptions::default(); + let mut writer = naga::back::msl::Writer::new(&mut string); + let _ = writer.write( + input.module.as_ref().unwrap(), + input.module_info.as_ref().unwrap(), + &options, + &pipeline_options, + ); + string.clear(); + } } }); }); @@ -397,15 +425,17 @@ fn backends(c: &mut Criterion) { let options = naga::back::hlsl::Options::default(); let mut string = String::new(); for input in &inputs.inner { - let pipeline_options = Default::default(); - let mut writer = - naga::back::hlsl::Writer::new(&mut string, &options, &pipeline_options); - let _ = writer.write( - input.module.as_ref().unwrap(), - input.module_info.as_ref().unwrap(), - None, - ); // may fail on unimplemented things - string.clear(); + if input.options.targets.unwrap().contains(Targets::HLSL) { + let pipeline_options = Default::default(); + let mut writer = + naga::back::hlsl::Writer::new(&mut string, &options, &pipeline_options); + let _ = writer.write( + input.module.as_ref().unwrap(), + input.module_info.as_ref().unwrap(), + None, + ); // may fail on unimplemented things + string.clear(); + } } }); }); @@ -420,28 +450,30 @@ fn backends(c: &mut Criterion) { zero_initialize_workgroup_memory: true, }; for input in &inputs.inner { - let module = input.module.as_ref().unwrap(); - let info = input.module_info.as_ref().unwrap(); - for ep in module.entry_points.iter() { - let pipeline_options = naga::back::glsl::PipelineOptions { - shader_stage: ep.stage, - entry_point: ep.name.clone(), - multiview: None, - }; - - // might be `Err` if missing features - if let Ok(mut writer) = naga::back::glsl::Writer::new( - &mut string, - module, - info, - &options, - &pipeline_options, - naga::proc::BoundsCheckPolicies::default(), - ) { - let _ = writer.write(); // might be `Err` if unsupported + if input.options.targets.unwrap().contains(Targets::GLSL) { + let module = input.module.as_ref().unwrap(); + let info = input.module_info.as_ref().unwrap(); + for ep in module.entry_points.iter() { + let pipeline_options = naga::back::glsl::PipelineOptions { + shader_stage: ep.stage, + entry_point: ep.name.clone(), + multiview: None, + }; + + // might be `Err` if missing features + if let Ok(mut writer) = naga::back::glsl::Writer::new( + &mut string, + module, + info, + &options, + &pipeline_options, + naga::proc::BoundsCheckPolicies::default(), + ) { + let _ = writer.write(); // might be `Err` if unsupported + } + + string.clear(); } - - string.clear(); } } }); diff --git a/docs/api-specs/mesh_shading.md b/docs/api-specs/mesh_shading.md index 8c979890b78..ee14f99e757 100644 --- a/docs/api-specs/mesh_shading.md +++ b/docs/api-specs/mesh_shading.md @@ -80,32 +80,36 @@ This shader stage can be selected by marking a function with `@task`. Task shade The output of this determines how many workgroups of mesh shaders will be dispatched. Once dispatched, global id variables will be local to the task shader workgroup dispatch, and mesh shaders won't know the position of their dispatch among all mesh shader dispatches unless this is passed through the payload. The output may be zero to skip dispatching any mesh shader workgroups for the task shader workgroup. -If task shaders are marked with `@payload(someVar)`, where `someVar` is global variable declared like `var someVar: `, task shaders may write to `someVar`. This payload is passed to the mesh shader workgroup that is invoked. The mesh shader can skip declaring `@payload` to ignore this input. +If task shaders are marked with `@payload(someVar)`, where `someVar` is global variable declared like `var someVar: `, task shaders may use `someVar` as if it is a read-write workgroup storage variable. This payload is passed to the mesh shader workgroup that is invoked. The mesh shader can skip declaring `@payload` to ignore this input. ### Mesh shader This shader stage can be selected by marking a function with `@mesh`. Mesh shaders must not return anything. -Mesh shaders can be marked with `@payload(someVar)` similar to task shaders. Unlike task shaders, mesh shaders cannot write to this workgroup memory. Declaring `@payload` in a pipeline with no task shader, in a pipeline with a task shader that doesn't declare `@payload`, or in a task shader with an `@payload` that is statically sized and smaller than the mesh shader payload is illegal. +Mesh shaders can be marked with `@payload(someVar)` similar to task shaders. Unlike task shaders, mesh shaders cannot write to this memory. Declaring `@payload` in a pipeline with no task shader, in a pipeline with a task shader that doesn't declare `@payload`, or in a task shader with an `@payload` that is statically sized and smaller than the mesh shader payload is illegal. -Mesh shaders must be marked with `@vertex_output(OutputType, numOutputs)`, where `numOutputs` is the maximum number of vertices to be output by a mesh shader, and `OutputType` is the data associated with vertices, similar to a standard vertex shader output. +Mesh shaders must be marked with `@vertex_output(OutputType, numOutputs)`, where `numOutputs` is the maximum number of vertices to be output by a mesh shader, and `OutputType` is the data associated with vertices, similar to a standard vertex shader output, and must be a struct. Mesh shaders must also be marked with `@primitive_output(OutputType, numOutputs)`, which is similar to `@vertex_output` except it describes the primitive outputs. ### Mesh shader outputs -Primitive outputs from mesh shaders have some additional builtins they can set. These include `@builtin(cull_primitive)`, which must be a boolean value. If this is set to true, then the primitive is skipped during rendering. +Vertex outputs from mesh shaders function identically to outputs of vertex shaders, and as such must have a field with `@builtin(position)`. + +Primitive outputs from mesh shaders have some additional builtins they can set. These include `@builtin(cull_primitive)`, which must be a boolean value. If this is set to true, then the primitive is skipped during rendering. All non-builtin primitive outputs must be decorated with `@per_primitive`. Mesh shader primitive outputs must also specify exactly one of `@builtin(triangle_indices)`, `@builtin(line_indices)`, or `@builtin(point_index)`. This determines the output topology of the mesh shader, and must match the output topology of the pipeline descriptor the mesh shader is used with. These must be of type `vec3`, `vec2`, and `u32` respectively. When setting this, each of the indices must be less than the number of vertices declared in `setMeshOutputs`. Additionally, the `@location` attributes from the vertex and primitive outputs can't overlap. -Before setting any vertices or indices, or exiting, the mesh shader must call `setMeshOutputs(numVertices: u32, numIndices: u32)`, which declares the number of vertices and indices that will be written to. These must be less than the corresponding maximums set in `@vertex_output` and `@primitive_output`. The mesh shader must then write to exactly these numbers of vertices and primitives. +Before setting any vertices or indices, or exiting, the mesh shader must call `setMeshOutputs(numVertices: u32, numIndices: u32)`, which declares the number of vertices and indices that will be written to. These must be less than the corresponding maximums set in `@vertex_output` and `@primitive_output`. The mesh shader must then write to exactly these numbers of vertices and primitives. A varying member with `@per_primitive` cannot be used in function interfaces except as the primitive output for mesh shaders or as input for fragment shaders. The mesh shader can write to vertices using the `setVertex(idx: u32, vertex: VertexOutput)` where `VertexOutput` is replaced with the vertex type declared in `@vertex_output`, and `idx` is the index of the vertex to write. Similarly, the mesh shader can write to vertices using `setPrimitive(idx: u32, primitive: PrimitiveOutput)`. These can be written to multiple times, however unsynchronized writes are undefined behavior. The primitives and indices are shared across the entire mesh shader workgroup. ### Fragment shader -Fragment shaders may now be passed the primitive info from a mesh shader the same was as they are passed vertex inputs, for example `fn fs_main(vertex: VertexOutput, primitive: PrimitiveOutput)`. The primitive state is part of the fragment input and must match the output of the mesh shader in the pipeline. +Fragment shaders can access vertex output data as if it is from a vertex shader. They can also access primitive output data, provided the input is decorated with `@per_primitive`. The `@per_primitive` attribute can be applied to a value directly, such as `@per_primitive @location(1) value: vec4`, to a struct such as `@per_primitive primitive_input: PrimitiveInput` where `PrimitiveInput` is a struct containing fields decorated with `@location` and `@builtin`, or to members of a struct that are themselves decorated with `@location` or `@builtin`. + +The primitive state is part of the fragment input and must match the output of the mesh shader in the pipeline. Using `@per_primitive` also requires enabling the mesh shader extension. Additionally, the locations of vertex and primitive input cannot overlap. ### Full example @@ -115,9 +119,9 @@ The following is a full example of WGSL shaders that could be used to create a m enable mesh_shading; const positions = array( - vec4(0.,-1.,0.,1.), - vec4(-1.,1.,0.,1.), - vec4(1.,1.,0.,1.) + vec4(0.,1.,0.,1.), + vec4(-1.,-1.,0.,1.), + vec4(1.,-1.,0.,1.) ); const colors = array( vec4(0.,1.,0.,1.), @@ -128,7 +132,7 @@ struct TaskPayload { colorMask: vec4, visible: bool, } -var taskPayload: TaskPayload; +var taskPayload: TaskPayload; var workgroupData: f32; struct VertexOutput { @builtin(position) position: vec4, @@ -137,14 +141,12 @@ struct VertexOutput { struct PrimitiveOutput { @builtin(triangle_indices) index: vec3, @builtin(cull_primitive) cull: bool, - @location(1) colorMask: vec4, + @per_primitive @location(1) colorMask: vec4, } struct PrimitiveInput { - @location(1) colorMask: vec4, + @per_primitive @location(1) colorMask: vec4, } -fn test_function(input: u32) { -} @task @payload(taskPayload) @workgroup_size(1) @@ -163,8 +165,6 @@ fn ms_main(@builtin(local_invocation_index) index: u32, @builtin(global_invocati workgroupData = 2.0; var v: VertexOutput; - test_function(1); - v.position = positions[0]; v.color = colors[0] * taskPayload.colorMask; setVertex(0, v); diff --git a/examples/features/src/mesh_shader/mod.rs b/examples/features/src/mesh_shader/mod.rs index 956722a661d..b16dd7e2460 100644 --- a/examples/features/src/mesh_shader/mod.rs +++ b/examples/features/src/mesh_shader/mod.rs @@ -1,37 +1,4 @@ -use std::{io::Write, process::Stdio}; - -// Same as in mesh shader tests -fn compile_glsl( - device: &wgpu::Device, - data: &[u8], - shader_stage: &'static str, -) -> wgpu::ShaderModule { - let cmd = std::process::Command::new("glslc") - .args([ - &format!("-fshader-stage={shader_stage}"), - "-", - "-o", - "-", - "--target-env=vulkan1.2", - "--target-spv=spv1.4", - ]) - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .spawn() - .expect("Failed to call glslc"); - cmd.stdin.as_ref().unwrap().write_all(data).unwrap(); - println!("{shader_stage}"); - let output = cmd.wait_with_output().expect("Error waiting for glslc"); - assert!(output.status.success()); - unsafe { - device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough::SpirV( - wgpu::ShaderModuleDescriptorSpirV { - label: None, - source: wgpu::util::make_spirv_raw(&output.stdout), - }, - )) - } -} +use wgpu::include_wgsl; pub struct Example { pipeline: wgpu::RenderPipeline, @@ -48,27 +15,28 @@ impl crate::framework::Example for Example { bind_group_layouts: &[], push_constant_ranges: &[], }); - let (ts, ms, fs) = ( - compile_glsl(device, include_bytes!("shader.task"), "task"), - compile_glsl(device, include_bytes!("shader.mesh"), "mesh"), - compile_glsl(device, include_bytes!("shader.frag"), "frag"), - ); + let shader_module = unsafe { + device.create_shader_module_trusted( + include_wgsl!("shader.wgsl"), + wgpu::ShaderRuntimeChecks::unchecked(), + ) + }; let pipeline = device.create_mesh_pipeline(&wgpu::MeshPipelineDescriptor { label: None, layout: Some(&pipeline_layout), task: Some(wgpu::TaskState { - module: &ts, - entry_point: Some("main"), + module: &shader_module, + entry_point: Some("ts_main"), compilation_options: Default::default(), }), mesh: wgpu::MeshState { - module: &ms, - entry_point: Some("main"), + module: &shader_module, + entry_point: Some("ms_main"), compilation_options: Default::default(), }, fragment: Some(wgpu::FragmentState { - module: &fs, - entry_point: Some("main"), + module: &shader_module, + entry_point: Some("fs_main"), compilation_options: Default::default(), targets: &[Some(config.view_formats[0].into())], }), diff --git a/examples/features/src/mesh_shader/shader.wgsl b/examples/features/src/mesh_shader/shader.wgsl new file mode 100644 index 00000000000..70fc2aec333 --- /dev/null +++ b/examples/features/src/mesh_shader/shader.wgsl @@ -0,0 +1,71 @@ +enable mesh_shading; + +const positions = array( + vec4(0.,1.,0.,1.), + vec4(-1.,-1.,0.,1.), + vec4(1.,-1.,0.,1.) +); +const colors = array( + vec4(0.,1.,0.,1.), + vec4(0.,0.,1.,1.), + vec4(1.,0.,0.,1.) +); +struct TaskPayload { + colorMask: vec4, + visible: bool, +} +var taskPayload: TaskPayload; +var workgroupData: f32; +struct VertexOutput { + @builtin(position) position: vec4, + @location(0) color: vec4, +} +struct PrimitiveOutput { + @builtin(triangle_indices) index: vec3, + @builtin(cull_primitive) cull: bool, + @per_primitive @location(1) colorMask: vec4, +} +struct PrimitiveInput { + @per_primitive @location(1) colorMask: vec4, +} + +@task +@payload(taskPayload) +@workgroup_size(1) +fn ts_main() -> @builtin(mesh_task_size) vec3 { + workgroupData = 1.0; + taskPayload.colorMask = vec4(1.0, 1.0, 0.0, 1.0); + taskPayload.visible = true; + return vec3(3, 1, 1); +} +@mesh +@payload(taskPayload) +@vertex_output(VertexOutput, 3) @primitive_output(PrimitiveOutput, 1) +@workgroup_size(1) +fn ms_main(@builtin(local_invocation_index) index: u32, @builtin(global_invocation_id) id: vec3) { + setMeshOutputs(3, 1); + workgroupData = 2.0; + var v: VertexOutput; + + v.position = positions[0]; + v.color = colors[0] * taskPayload.colorMask; + setVertex(0, v); + + v.position = positions[1]; + v.color = colors[1] * taskPayload.colorMask; + setVertex(1, v); + + v.position = positions[2]; + v.color = colors[2] * taskPayload.colorMask; + setVertex(2, v); + + var p: PrimitiveOutput; + p.index = vec3(0, 1, 2); + p.cull = !taskPayload.visible; + p.colorMask = vec4(1.0, 0.0, 1.0, 1.0); + setPrimitive(0, p); +} +@fragment +fn fs_main(vertex: VertexOutput, primitive: PrimitiveInput) -> @location(0) vec4 { + return vertex.color * primitive.colorMask; +} diff --git a/naga-cli/src/bin/naga.rs b/naga-cli/src/bin/naga.rs index 44369e9df7d..171d970166e 100644 --- a/naga-cli/src/bin/naga.rs +++ b/naga-cli/src/bin/naga.rs @@ -64,6 +64,12 @@ struct Args { #[argh(option)] shader_model: Option, + /// the SPIR-V version to use if targeting SPIR-V + /// + /// For example, 1.0, 1.4, etc + #[argh(option)] + spirv_version: Option, + /// the shader stage, for example 'frag', 'vert', or 'compute'. /// if the shader stage is unspecified it will be derived from /// the file extension. @@ -189,6 +195,22 @@ impl FromStr for ShaderModelArg { } } +#[derive(Debug, Clone)] +struct SpirvVersionArg(u8, u8); + +impl FromStr for SpirvVersionArg { + type Err = String; + + fn from_str(s: &str) -> Result { + let dot = s + .find(".") + .ok_or_else(|| "Missing dot separator".to_owned())?; + let major = s[..dot].parse::().map_err(|e| e.to_string())?; + let minor = s[dot + 1..].parse::().map_err(|e| e.to_string())?; + Ok(Self(major, minor)) + } +} + /// Newtype so we can implement [`FromStr`] for `ShaderSource`. #[derive(Debug, Clone, Copy)] struct ShaderStage(naga::ShaderStage); @@ -465,6 +487,9 @@ fn run() -> anyhow::Result<()> { if let Some(ref version) = args.metal_version { params.msl.lang_version = version.0; } + if let Some(ref version) = args.spirv_version { + params.spv_out.lang_version = (version.0, version.1); + } params.keep_coordinate_space = args.keep_coordinate_space; params.dot.cfg_only = args.dot_cfg_only; diff --git a/naga/Cargo.toml b/naga/Cargo.toml index 02eda4c198a..824bc9d6d87 100644 --- a/naga/Cargo.toml +++ b/naga/Cargo.toml @@ -20,20 +20,20 @@ all-features = true [features] default = [] -dot-out = [] -glsl-in = ["dep:pp-rs"] -glsl-out = [] +dot-out = ["wgpu-test/dot-out"] +glsl-in = ["dep:pp-rs", "wgpu-test/glsl-in"] +glsl-out = ["wgpu-test/glsl-out"] ## Enables outputting to the Metal Shading Language (MSL). ## ## This enables MSL output regardless of the target platform. ## If you want to enable it only when targeting iOS/tvOS/watchOS/macOS, use `naga/msl-out-if-target-apple`. -msl-out = [] +msl-out = ["wgpu-test/msl-out"] ## Enables outputting to the Metal Shading Language (MSL) only if the target platform is iOS/tvOS/watchOS/macOS. ## ## If you want to enable MSL output it regardless of the target platform, use `naga/msl-out`. -msl-out-if-target-apple = [] +msl-out-if-target-apple = ["wgpu-test/msl-out"] serialize = [ "dep:serde", @@ -56,16 +56,16 @@ arbitrary = [ "half/arbitrary", "half/std", ] -spv-in = ["dep:petgraph", "petgraph/graphmap", "dep:spirv"] -spv-out = ["dep:spirv"] -wgsl-in = ["dep:hexf-parse", "dep:unicode-ident"] -wgsl-out = [] +spv-in = ["dep:petgraph", "petgraph/graphmap", "dep:spirv", "wgpu-test/spv-in"] +spv-out = ["dep:spirv", "wgpu-test/spv-out"] +wgsl-in = ["dep:hexf-parse", "dep:unicode-ident", "wgpu-test/wgsl-in"] +wgsl-out = ["wgpu-test/wgsl-out"] ## Enables outputting to HLSL (Microsoft's High-Level Shader Language). ## ## This enables HLSL output regardless of the target platform. ## If you want to enable it only when targeting Windows, use `hlsl-out-if-target-windows`. -hlsl-out = [] +hlsl-out = ["wgpu-test/hlsl-out"] ## Enables outputting to HLSL (Microsoft's High-Level Shader Language) only if the target platform is Windows. ## @@ -116,10 +116,9 @@ itertools.workspace = true ron.workspace = true rspirv.workspace = true serde = { workspace = true, features = ["default", "derive"] } -spirv = { workspace = true, features = ["deserialize"] } strum = { workspace = true } -toml.workspace = true walkdir.workspace = true +wgpu-test.workspace = true [lints.clippy] std_instead_of_alloc = "warn" diff --git a/naga/src/back/dot/mod.rs b/naga/src/back/dot/mod.rs index 826dad1c219..1f1396eccff 100644 --- a/naga/src/back/dot/mod.rs +++ b/naga/src/back/dot/mod.rs @@ -307,6 +307,25 @@ impl StatementGraph { crate::RayQueryFunction::Terminate => "RayQueryTerminate", } } + S::MeshFunction(crate::MeshFunction::SetMeshOutputs { + vertex_count, + primitive_count, + }) => { + self.dependencies.push((id, vertex_count, "vertex_count")); + self.dependencies + .push((id, primitive_count, "primitive_count")); + "SetMeshOutputs" + } + S::MeshFunction(crate::MeshFunction::SetVertex { index, value }) => { + self.dependencies.push((id, index, "index")); + self.dependencies.push((id, value, "value")); + "SetVertex" + } + S::MeshFunction(crate::MeshFunction::SetPrimitive { index, value }) => { + self.dependencies.push((id, index, "index")); + self.dependencies.push((id, value, "value")); + "SetPrimitive" + } S::SubgroupBallot { result, predicate } => { if let Some(predicate) = predicate { self.dependencies.push((id, predicate, "predicate")); diff --git a/naga/src/back/glsl/features.rs b/naga/src/back/glsl/features.rs index a6dfe4e3100..b884f08ac39 100644 --- a/naga/src/back/glsl/features.rs +++ b/naga/src/back/glsl/features.rs @@ -610,6 +610,7 @@ impl Writer<'_, W> { interpolation, sampling, blend_src, + per_primitive: _, } => { if interpolation == Some(Interpolation::Linear) { self.features.request(Features::NOPERSPECTIVE_QUALIFIER); diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index e78af74c844..1af18528944 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -139,7 +139,8 @@ impl crate::AddressSpace { | crate::AddressSpace::Uniform | crate::AddressSpace::Storage { .. } | crate::AddressSpace::Handle - | crate::AddressSpace::PushConstant => false, + | crate::AddressSpace::PushConstant + | crate::AddressSpace::TaskPayload => false, } } } @@ -1300,6 +1301,9 @@ impl<'a, W: Write> Writer<'a, W> { crate::AddressSpace::Storage { .. } => { self.write_interface_block(handle, global)?; } + crate::AddressSpace::TaskPayload => { + self.write_interface_block(handle, global)?; + } // A global variable in the `Function` address space is a // contradiction in terms. crate::AddressSpace::Function => unreachable!(), @@ -1614,6 +1618,7 @@ impl<'a, W: Write> Writer<'a, W> { interpolation, sampling, blend_src, + per_primitive: _, } => (location, interpolation, sampling, blend_src), crate::Binding::BuiltIn(built_in) => { match built_in { @@ -1732,6 +1737,7 @@ impl<'a, W: Write> Writer<'a, W> { interpolation: None, sampling: None, blend_src, + per_primitive: false, }, stage: self.entry_point.stage, options: VaryingOptions::from_writer_options(self.options, output), @@ -2669,6 +2675,11 @@ impl<'a, W: Write> Writer<'a, W> { self.write_image_atomic(ctx, image, coordinate, array_index, fun, value)? } Statement::RayQuery { .. } => unreachable!(), + Statement::MeshFunction( + crate::MeshFunction::SetMeshOutputs { .. } + | crate::MeshFunction::SetVertex { .. } + | crate::MeshFunction::SetPrimitive { .. }, + ) => unreachable!(), Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let res_name = Baked(result).to_string(); @@ -5247,6 +5258,15 @@ const fn glsl_built_in(built_in: crate::BuiltIn, options: VaryingOptions) -> &'s Bi::SubgroupId => "gl_SubgroupID", Bi::SubgroupSize => "gl_SubgroupSize", Bi::SubgroupInvocationId => "gl_SubgroupInvocationID", + // mesh + // TODO: figure out how to map these to glsl things as glsl treats them as arrays + Bi::CullPrimitive + | Bi::PointIndex + | Bi::LineIndices + | Bi::TriangleIndices + | Bi::MeshTaskSize => { + unimplemented!() + } } } @@ -5262,6 +5282,7 @@ const fn glsl_storage_qualifier(space: crate::AddressSpace) -> Option<&'static s As::Handle => Some("uniform"), As::WorkGroup => Some("shared"), As::PushConstant => Some("uniform"), + As::TaskPayload => unreachable!(), } } diff --git a/naga/src/back/hlsl/conv.rs b/naga/src/back/hlsl/conv.rs index ed40cbe5102..d6ccc5ec6e4 100644 --- a/naga/src/back/hlsl/conv.rs +++ b/naga/src/back/hlsl/conv.rs @@ -183,6 +183,9 @@ impl crate::BuiltIn { Self::PointSize | Self::ViewIndex | Self::PointCoord | Self::DrawID => { return Err(Error::Custom(format!("Unsupported builtin {self:?}"))) } + Self::CullPrimitive => "SV_CullPrimitive", + Self::PointIndex | Self::LineIndices | Self::TriangleIndices => unimplemented!(), + Self::MeshTaskSize => unreachable!(), }) } } diff --git a/naga/src/back/hlsl/mod.rs b/naga/src/back/hlsl/mod.rs index 8df06cf1323..f357c02bb3f 100644 --- a/naga/src/back/hlsl/mod.rs +++ b/naga/src/back/hlsl/mod.rs @@ -283,7 +283,8 @@ impl crate::ShaderStage { Self::Vertex => "vs", Self::Fragment => "ps", Self::Compute => "cs", - Self::Task | Self::Mesh => unreachable!(), + Self::Task => "ts", + Self::Mesh => "ms", } } } diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 357b8597521..9401766448f 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -507,7 +507,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_wrapped_functions(module, &ctx)?; - if ep.stage == ShaderStage::Compute { + if ep.stage.compute_like() { // HLSL is calling workgroup size "num threads" let num_threads = ep.workgroup_size; writeln!( @@ -967,6 +967,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_type(module, global.ty)?; "" } + crate::AddressSpace::TaskPayload => unimplemented!(), crate::AddressSpace::Uniform => { // constant buffer declarations are expected to be inlined, e.g. // `cbuffer foo: register(b0) { field1: type1; }` @@ -2599,6 +2600,19 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { writeln!(self.out, ".Abort();")?; } }, + Statement::MeshFunction(crate::MeshFunction::SetMeshOutputs { + vertex_count, + primitive_count, + }) => { + write!(self.out, "{level}SetMeshOutputCounts(")?; + self.write_expr(module, vertex_count, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, primitive_count, func_ctx)?; + write!(self.out, ");")?; + } + Statement::MeshFunction( + crate::MeshFunction::SetVertex { .. } | crate::MeshFunction::SetPrimitive { .. }, + ) => unimplemented!(), Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let name = Baked(result).to_string(); @@ -3076,7 +3090,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { crate::AddressSpace::Function | crate::AddressSpace::Private | crate::AddressSpace::WorkGroup - | crate::AddressSpace::PushConstant, + | crate::AddressSpace::PushConstant + | crate::AddressSpace::TaskPayload, ) | None => true, Some(crate::AddressSpace::Uniform) => { diff --git a/naga/src/back/msl/mod.rs b/naga/src/back/msl/mod.rs index 7bc8289b9b8..8a2e07635b8 100644 --- a/naga/src/back/msl/mod.rs +++ b/naga/src/back/msl/mod.rs @@ -494,6 +494,7 @@ impl Options { interpolation, sampling, blend_src, + per_primitive: _, } => match mode { LocationMode::VertexInput => Ok(ResolvedBinding::Attribute(location)), LocationMode::FragmentOutput => { @@ -651,6 +652,10 @@ impl ResolvedBinding { Bi::CullDistance | Bi::ViewIndex | Bi::DrawID => { return Err(Error::UnsupportedBuiltIn(built_in)) } + Bi::CullPrimitive => "primitive_culled", + // TODO: figure out how to make this written as a function call + Bi::PointIndex | Bi::LineIndices | Bi::TriangleIndices => unimplemented!(), + Bi::MeshTaskSize => unreachable!(), }; write!(out, "{name}")?; } diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 2525855cd70..a6b80a2dd27 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -578,7 +578,8 @@ impl crate::AddressSpace { | Self::Private | Self::WorkGroup | Self::PushConstant - | Self::Handle => true, + | Self::Handle + | Self::TaskPayload => true, Self::Function => false, } } @@ -591,6 +592,7 @@ impl crate::AddressSpace { // may end up with "const" even if the binding is read-write, // and that should be OK. Self::Storage { .. } => true, + Self::TaskPayload => unimplemented!(), // These should always be read-write. Self::Private | Self::WorkGroup => false, // These translate to `constant` address space, no need for qualifiers. @@ -607,6 +609,7 @@ impl crate::AddressSpace { Self::Storage { .. } => Some("device"), Self::Private | Self::Function => Some("thread"), Self::WorkGroup => Some("threadgroup"), + Self::TaskPayload => Some("object_data"), } } } @@ -4020,6 +4023,14 @@ impl Writer { } } } + // TODO: write emitters for these + crate::Statement::MeshFunction(crate::MeshFunction::SetMeshOutputs { .. }) => { + unimplemented!() + } + crate::Statement::MeshFunction( + crate::MeshFunction::SetVertex { .. } + | crate::MeshFunction::SetPrimitive { .. }, + ) => unimplemented!(), crate::Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let name = self.namer.call(""); @@ -6169,7 +6180,7 @@ template LocationMode::Uniform, false, ), - crate::ShaderStage::Task | crate::ShaderStage::Mesh => unreachable!(), + crate::ShaderStage::Task | crate::ShaderStage::Mesh => unimplemented!(), }; // Should this entry point be modified to do vertex pulling? @@ -6232,6 +6243,9 @@ template break; } } + crate::AddressSpace::TaskPayload => { + unimplemented!() + } crate::AddressSpace::Function | crate::AddressSpace::Private | crate::AddressSpace::WorkGroup => {} @@ -7159,7 +7173,7 @@ mod workgroup_mem_init { fun_info: &valid::FunctionInfo, ) -> bool { options.zero_initialize_workgroup_memory - && ep.stage == crate::ShaderStage::Compute + && ep.stage.compute_like() && module.global_variables.iter().any(|(handle, var)| { !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup }) diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index d2b3ed70eda..c009082a3c9 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -39,6 +39,8 @@ pub enum PipelineConstantError { ValidationError(#[from] WithSpan), #[error("workgroup_size override isn't strictly positive")] NegativeWorkgroupSize, + #[error("max vertices or max primitives is negative")] + NegativeMeshOutputMax, } /// Compact `module` and replace all overrides with constants. @@ -243,6 +245,7 @@ pub fn process_overrides<'a>( for ep in entry_points.iter_mut() { process_function(&mut module, &override_map, &mut layouter, &mut ep.function)?; process_workgroup_size_override(&mut module, &adjusted_global_expressions, ep)?; + process_mesh_shader_overrides(&mut module, &adjusted_global_expressions, ep)?; } module.entry_points = entry_points; module.overrides = overrides; @@ -296,6 +299,28 @@ fn process_workgroup_size_override( Ok(()) } +fn process_mesh_shader_overrides( + module: &mut Module, + adjusted_global_expressions: &HandleVec>, + ep: &mut crate::EntryPoint, +) -> Result<(), PipelineConstantError> { + if let Some(ref mut mesh_info) = ep.mesh_info { + if let Some(r#override) = mesh_info.max_vertices_override { + mesh_info.max_vertices = module + .to_ctx() + .eval_expr_to_u32(adjusted_global_expressions[r#override]) + .map_err(|_| PipelineConstantError::NegativeWorkgroupSize)?; + } + if let Some(r#override) = mesh_info.max_primitives_override { + mesh_info.max_primitives = module + .to_ctx() + .eval_expr_to_u32(adjusted_global_expressions[r#override]) + .map_err(|_| PipelineConstantError::NegativeWorkgroupSize)?; + } + } + Ok(()) +} + /// Add a [`Constant`] to `module` for the override `old_h`. /// /// Add the new `Constant` to `override_map` and `adjusted_constant_initializers`. @@ -835,6 +860,26 @@ fn adjust_stmt(new_pos: &HandleVec>, stmt: &mut S crate::RayQueryFunction::Terminate => {} } } + Statement::MeshFunction(crate::MeshFunction::SetMeshOutputs { + ref mut vertex_count, + ref mut primitive_count, + }) => { + adjust(vertex_count); + adjust(primitive_count); + } + Statement::MeshFunction( + crate::MeshFunction::SetVertex { + ref mut index, + ref mut value, + } + | crate::MeshFunction::SetPrimitive { + ref mut index, + ref mut value, + }, + ) => { + adjust(index); + adjust(value); + } Statement::Break | Statement::Continue | Statement::Kill diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index 0cd414bfbeb..88a8c0a17a1 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -221,8 +221,12 @@ impl Writer { ir_result: &crate::FunctionResult, result_members: &[ResultMember], body: &mut Vec, - ) -> Result<(), Error> { + task_payload: Option, + ) -> Result { for (index, res_member) in result_members.iter().enumerate() { + if res_member.built_in == Some(crate::BuiltIn::MeshTaskSize) { + continue; + } let member_value_id = match ir_result.binding { Some(_) => value_id, None => { @@ -253,6 +257,369 @@ impl Writer { _ => {} } } + // OpEmitMeshTasksEXT must be called right before exiting (after setting other + // output variables if there are any) + for (index, res_member) in result_members.iter().enumerate() { + if res_member.built_in == Some(crate::BuiltIn::MeshTaskSize) { + let member_value_id = match ir_result.binding { + Some(_) => value_id, + None => { + let member_value_id = self.id_gen.next(); + body.push(Instruction::composite_extract( + res_member.type_id, + member_value_id, + value_id, + &[index as Word], + )); + member_value_id + } + }; + + let values = [self.id_gen.next(), self.id_gen.next(), self.id_gen.next()]; + for (i, &value) in values.iter().enumerate() { + let instruction = Instruction::composite_extract( + self.get_u32_type_id(), + value, + member_value_id, + &[i as Word], + ); + body.push(instruction); + } + let mut instruction = Instruction::new(spirv::Op::EmitMeshTasksEXT); + for id in values { + instruction.add_operand(id); + } + if let Some(task_payload) = task_payload { + instruction.add_operand(task_payload); + } + return Ok(instruction); + } + } + Ok(Instruction::return_void()) + } + + /// Writes the return call for a mesh shader, which involves copying previously + /// written vertices/primitives into the actual output location. + fn write_mesh_shader_return( + &mut self, + return_info: &super::MeshReturnInfo, + body: &mut Vec, + ) -> Result<(), Error> { + // Gets the info about temporary buffers and such + let vert_info = self.mesh_shader_output_variable( + return_info.vertex_type, + false, + return_info.max_vertices, + )?; + let prim_info = self.mesh_shader_output_variable( + return_info.primitive_type, + true, + return_info.max_primitives, + )?; + // Load the actual vertex and primitive counts + let vert_count_id = self.id_gen.next(); + body.push(Instruction::load( + self.get_u32_type_id(), + vert_count_id, + self.mesh_state.num_vertices_var.unwrap(), + None, + )); + let prim_count_id = self.id_gen.next(); + body.push(Instruction::load( + self.get_u32_type_id(), + prim_count_id, + self.mesh_state.num_primitives_var.unwrap(), + None, + )); + + // Call this. It must be called exactly once, which the user shouldn't be assumed + // to have done correctly. + { + let mut ins = Instruction::new(spirv::Op::SetMeshOutputsEXT); + ins.add_operand(vert_count_id); + ins.add_operand(prim_count_id); + body.push(ins); + } + + // All this for a `for i in 0..num_vertices` lol + // This is basically just a memcpy but the result is split up to multiple places + let u32_type_id = self.get_u32_type_id(); + let zero_u32 = self.get_constant_scalar(crate::Literal::U32(0)); + let vertex_loop_header = self.id_gen.next(); + let prim_loop_header = self.id_gen.next(); + let in_between_loops = self.id_gen.next(); + let func_end = self.id_gen.next(); + let index_var = return_info.function_variable; + + body.push(Instruction::store( + index_var, + return_info.local_invocation_index_id, + None, + )); + body.push(Instruction::branch(vertex_loop_header)); + + // Vertex copies + let vertex_copy_body = { + let mut body = Vec::new(); + // Current index to copy + let val_i = self.id_gen.next(); + body.push(Instruction::load(u32_type_id, val_i, index_var, None)); + + let vert_to_copy_ptr = self.id_gen.next(); + body.push(Instruction::access_chain( + self.get_pointer_type_id(vert_info.inner_ty, spirv::StorageClass::Workgroup), + vert_to_copy_ptr, + vert_info.var_id, + &[val_i], + )); + + // Load the entire vertex value + let vert_to_copy = self.id_gen.next(); + body.push(Instruction::load( + vert_info.inner_ty, + vert_to_copy, + vert_to_copy_ptr, + None, + )); + + let mut builtin_index = 0; + let mut binding_index = 0; + // Write individual members of the vertex + for (member_id, member) in return_info.vertex_members.iter().enumerate() { + let val_to_copy = self.id_gen.next(); + let mut needs_y_flip = false; + body.push(Instruction::composite_extract( + member.ty_id, + val_to_copy, + vert_to_copy, + &[member_id as u32], + )); + let ptr_to_copy_to = self.id_gen.next(); + // Get the variable that holds it and indexed pointer, which points to + // the value and not a wrapper struct + match member.binding { + crate::Binding::BuiltIn(bi) => { + body.push(Instruction::access_chain( + self.get_pointer_type_id(member.ty_id, spirv::StorageClass::Output), + ptr_to_copy_to, + return_info.vertex_builtin_block.as_ref().unwrap().var_id, + &[ + val_i, + self.get_constant_scalar(crate::Literal::U32(builtin_index)), + ], + )); + needs_y_flip = matches!(bi, crate::BuiltIn::Position { .. }) + && self.flags.contains(WriterFlags::ADJUST_COORDINATE_SPACE); + builtin_index += 1; + } + crate::Binding::Location { .. } => { + body.push(Instruction::access_chain( + self.get_pointer_type_id(member.ty_id, spirv::StorageClass::Output), + ptr_to_copy_to, + return_info.vertex_bindings[binding_index].var_id, + &[val_i, zero_u32], + )); + binding_index += 1; + } + } + body.push(Instruction::store(ptr_to_copy_to, val_to_copy, None)); + // Can't use epilogue flip because can't read from this storage class I believe + if needs_y_flip { + let prev_y = self.id_gen.next(); + body.push(Instruction::composite_extract( + self.get_f32_type_id(), + prev_y, + val_to_copy, + &[1], + )); + let new_y = self.id_gen.next(); + body.push(Instruction::unary( + spirv::Op::FNegate, + self.get_f32_type_id(), + new_y, + prev_y, + )); + let new_ptr_to_copy_to = self.id_gen.next(); + body.push(Instruction::access_chain( + self.get_f32_pointer_type_id(spirv::StorageClass::Output), + new_ptr_to_copy_to, + ptr_to_copy_to, + &[self.get_constant_scalar(crate::Literal::U32(1))], + )); + body.push(Instruction::store(new_ptr_to_copy_to, new_y, None)); + } + } + body + }; + + // Primitive copies + let primitive_copy_body = { + // See comments in `vertex_copy_body` + let mut body = Vec::new(); + let val_i = self.id_gen.next(); + body.push(Instruction::load(u32_type_id, val_i, index_var, None)); + + let prim_to_copy_ptr = self.id_gen.next(); + body.push(Instruction::access_chain( + self.get_pointer_type_id(prim_info.inner_ty, spirv::StorageClass::Workgroup), + prim_to_copy_ptr, + prim_info.var_id, + &[val_i], + )); + let prim_to_copy = self.id_gen.next(); + body.push(Instruction::load( + prim_info.inner_ty, + prim_to_copy, + prim_to_copy_ptr, + None, + )); + + let mut builtin_index = 0; + let mut binding_index = 0; + for (member_id, member) in return_info.primitive_members.iter().enumerate() { + let val_to_copy = self.id_gen.next(); + body.push(Instruction::composite_extract( + member.ty_id, + val_to_copy, + prim_to_copy, + &[member_id as u32], + )); + let ptr_to_copy_to = self.id_gen.next(); + match member.binding { + crate::Binding::BuiltIn( + crate::BuiltIn::PointIndex + | crate::BuiltIn::LineIndices + | crate::BuiltIn::TriangleIndices, + ) => { + body.push(Instruction::access_chain( + self.get_pointer_type_id(member.ty_id, spirv::StorageClass::Output), + ptr_to_copy_to, + return_info.primitive_indices.as_ref().unwrap().var_id, + &[val_i], + )); + } + crate::Binding::BuiltIn(_) => { + body.push(Instruction::access_chain( + self.get_pointer_type_id(member.ty_id, spirv::StorageClass::Output), + ptr_to_copy_to, + return_info.primitive_builtin_block.as_ref().unwrap().var_id, + &[ + val_i, + self.get_constant_scalar(crate::Literal::U32(builtin_index)), + ], + )); + builtin_index += 1; + } + crate::Binding::Location { .. } => { + body.push(Instruction::access_chain( + self.get_pointer_type_id(member.ty_id, spirv::StorageClass::Output), + ptr_to_copy_to, + return_info.primitive_bindings[binding_index].var_id, + &[val_i, zero_u32], + )); + binding_index += 1; + } + } + body.push(Instruction::store(ptr_to_copy_to, val_to_copy, None)); + } + body + }; + + // This writes the actual loop + let mut get_loop_continue_id = |body: &mut Vec, + mut loop_body_block, + loop_header, + loop_merge, + count_id, + index_var| { + let condition_check = self.id_gen.next(); + let loop_continue = self.id_gen.next(); + let loop_body = self.id_gen.next(); + + // Loop header + { + body.push(Instruction::label(loop_header)); + body.push(Instruction::loop_merge( + loop_merge, + loop_continue, + spirv::SelectionControl::empty(), + )); + body.push(Instruction::branch(condition_check)); + } + // Condition check - check if i is less than num vertices to copy + { + body.push(Instruction::label(condition_check)); + + let val_i = self.id_gen.next(); + body.push(Instruction::load(u32_type_id, val_i, index_var, None)); + + let cond = self.id_gen.next(); + body.push(Instruction::binary( + spirv::Op::ULessThan, + self.get_bool_type_id(), + cond, + val_i, + count_id, + )); + body.push(Instruction::branch_conditional(cond, loop_body, loop_merge)); + } + // Loop body + { + body.push(Instruction::label(loop_body)); + body.append(&mut loop_body_block); + body.push(Instruction::branch(loop_continue)); + } + // Loop continue - increment i + { + body.push(Instruction::label(loop_continue)); + + let prev_val_i = self.id_gen.next(); + body.push(Instruction::load(u32_type_id, prev_val_i, index_var, None)); + let new_val_i = self.id_gen.next(); + body.push(Instruction::binary( + spirv::Op::IAdd, + u32_type_id, + new_val_i, + prev_val_i, + return_info.workgroup_size, + )); + body.push(Instruction::store(index_var, new_val_i, None)); + + body.push(Instruction::branch(loop_header)); + } + }; + // Write vertex copy loop + get_loop_continue_id( + body, + vertex_copy_body, + vertex_loop_header, + in_between_loops, + vert_count_id, + index_var, + ); + // In between loops, reset the initial index + { + body.push(Instruction::label(in_between_loops)); + + body.push(Instruction::store( + index_var, + return_info.local_invocation_index_id, + None, + )); + + body.push(Instruction::branch(prim_loop_header)); + } + // Write primitive copy loop + get_loop_continue_id( + body, + primitive_copy_body, + prim_loop_header, + func_end, + prim_count_id, + index_var, + ); + + body.push(Instruction::label(func_end)); Ok(()) } } @@ -3227,21 +3594,27 @@ impl BlockContext<'_> { let instruction = match self.function.entry_point_context { // If this is an entry point, and we need to return anything, // let's instead store the output variables and return `void`. - Some(ref context) => { - self.writer.write_entry_point_return( - value_id, - self.ir_function.result.as_ref().unwrap(), - &context.results, - &mut block.body, - )?; - Instruction::return_void() - } + Some(ref context) => self.writer.write_entry_point_return( + value_id, + self.ir_function.result.as_ref().unwrap(), + &context.results, + &mut block.body, + context.task_payload, + )?, None => Instruction::return_value(value_id), }; self.function.consume(block, instruction); return Ok(BlockExitDisposition::Discarded); } Statement::Return { value: None } => { + if let Some(super::EntryPointContext { + mesh_state: Some(ref mesh_state), + .. + }) = self.function.entry_point_context + { + self.writer + .write_mesh_shader_return(mesh_state, &mut block.body)?; + }; self.function.consume(block, Instruction::return_void()); return Ok(BlockExitDisposition::Discarded); } @@ -3633,6 +4006,50 @@ impl BlockContext<'_> { Statement::RayQuery { query, ref fun } => { self.write_ray_query_function(query, fun, &mut block); } + Statement::MeshFunction(crate::MeshFunction::SetMeshOutputs { + vertex_count, + primitive_count, + }) => { + self.writer.require_mesh_shaders()?; + block.body.push(Instruction::store( + self.writer.mesh_state.num_vertices_var.unwrap(), + self.cached[vertex_count], + None, + )); + block.body.push(Instruction::store( + self.writer.mesh_state.num_primitives_var.unwrap(), + self.cached[primitive_count], + None, + )); + } + Statement::MeshFunction( + crate::MeshFunction::SetVertex { index, value } + | crate::MeshFunction::SetPrimitive { index, value }, + ) => { + self.writer.require_mesh_shaders()?; + let is_prim = matches!( + *statement, + Statement::MeshFunction(crate::MeshFunction::SetPrimitive { .. }) + ); + let type_handle = if is_prim { + self.fun_info.mesh_shader_info.primitive_type.unwrap().0 + } else { + self.fun_info.mesh_shader_info.vertex_type.unwrap().0 + }; + let info = self + .writer + .mesh_shader_output_variable(type_handle, is_prim, 0)?; + let out_ptr_id = self.gen_id(); + block.body.push(Instruction::access_chain( + self.get_pointer_type_id(info.inner_ty, spirv::StorageClass::Workgroup), + out_ptr_id, + info.var_id, + &[self.cached[index]], + )); + block + .body + .push(Instruction::store(out_ptr_id, self.cached[value], None)); + } Statement::SubgroupBallot { result, ref predicate, diff --git a/naga/src/back/spv/helpers.rs b/naga/src/back/spv/helpers.rs index 84e130efaa3..48dc7550ec6 100644 --- a/naga/src/back/spv/helpers.rs +++ b/naga/src/back/spv/helpers.rs @@ -1,5 +1,6 @@ use alloc::{vec, vec::Vec}; +use arrayvec::ArrayVec; use spirv::Word; use crate::{Handle, UniqueArena}; @@ -54,6 +55,7 @@ pub(super) const fn map_storage_class(space: crate::AddressSpace) -> spirv::Stor crate::AddressSpace::Uniform => spirv::StorageClass::Uniform, crate::AddressSpace::WorkGroup => spirv::StorageClass::Workgroup, crate::AddressSpace::PushConstant => spirv::StorageClass::PushConstant, + crate::AddressSpace::TaskPayload => spirv::StorageClass::TaskPayloadWorkgroupEXT, } } @@ -153,3 +155,14 @@ impl StrUnstable for str { } } } + +pub enum BindingDecorations { + BuiltIn(spirv::BuiltIn, ArrayVec), + Location { + location: Word, + others: ArrayVec, + /// If this is `Some`, use Decoration::Index with blend_src as an operand + blend_src: Option, + }, + None, +} diff --git a/naga/src/back/spv/mod.rs b/naga/src/back/spv/mod.rs index 21c00015478..e5c1d4a692b 100644 --- a/naga/src/back/spv/mod.rs +++ b/naga/src/back/spv/mod.rs @@ -52,6 +52,7 @@ struct LogicalLayout { function_definitions: Vec, } +#[derive(Clone)] struct Instruction { op: spirv::Op, wc: u32, @@ -78,6 +79,8 @@ pub enum Error { Override, #[error(transparent)] ResolveArraySizeError(#[from] crate::proc::ResolveArraySizeError), + #[error("module requires SPIRV-{0}.{1}, which isn't supported")] + SpirvVersionTooLow(u8, u8), } #[derive(Default)] @@ -139,9 +142,43 @@ struct ResultMember { built_in: Option, } +struct MeshReturnGlobalVariable { + _inner_ty: u32, + var_id: u32, +} + +#[derive(Clone)] +struct MeshReturnMember { + ty_id: u32, + binding: crate::Binding, +} +struct MeshReturnInfo { + vertex_type: Handle, + vertex_members: Vec, + max_vertices: u32, + primitive_type: Handle, + primitive_members: Vec, + max_primitives: u32, + // In vulkan, all builtins must be in the same block. + // All bindings must be in their own unique block. + // Also, the primitive indices builtin family needs its own block. + // Also also, cull primitive doesn't care about having its own block. + vertex_builtin_block: Option, + vertex_bindings: Vec, + primitive_builtin_block: Option, + primitive_bindings: Vec, + primitive_indices: Option, + local_invocation_index_id: u32, + workgroup_size: u32, + /// The id of a function variable in the entry point for a u32 + function_variable: u32, +} + struct EntryPointContext { argument_ids: Vec, results: Vec, + task_payload: Option, + mesh_state: Option, } #[derive(Default)] @@ -771,6 +808,8 @@ pub struct Writer { ray_get_committed_intersection_function: Option, ray_get_candidate_intersection_function: Option, + + mesh_state: WriteMeshInfo, } bitflags::bitflags! { @@ -908,3 +947,30 @@ pub fn write_vec( )?; Ok(words) } + +/// The outputs of a mesh shader must be stored in global variables. These outputs are determined by the attributes on +/// the entry point, but other functions may also set these outputs. A single module might have multiple such global +/// variables, but each function will only end up using one and will have to look up the global variable for its type, +/// not its entry point. Therefore the variables must be associated with types and not entry points. +pub struct WriteMeshInfo { + pub vertex_outputs_by_type: crate::FastHashMap, MeshOutputInfo>, + pub primitive_outputs_by_type: crate::FastHashMap, MeshOutputInfo>, + /// The workgroup variable containing the number of vertices to write + pub num_vertices_var: Option, + /// The workgroup variable containing the number of primitives to write + pub num_primitives_var: Option, +} + +#[derive(Clone)] +pub struct MeshOutputInfo { + /// The index of the word that specifies the length of the global variable array. This is very hacky lol + /// We want to allow the same global variable to be used across entry points of the same output type + /// so that they can reuse functions that might set vertices/indices. So we make the array the largest + /// of the max output sizes among all the entry points! It will default to zero, so that if an unused function + /// tries to write to it, the function can still be valid if it is never called. + pub index_of_length_decl: usize, + pub inner_ty: Word, + pub array_ty: Word, + pub var_id: Word, + pub array_size_id: Word, +} diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index 0688eb6c975..2f5315c7443 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -1,5 +1,6 @@ use alloc::{string::String, vec, vec::Vec}; +use arrayvec::ArrayVec; use hashbrown::hash_map::Entry; use spirv::Word; @@ -13,7 +14,8 @@ use super::{ }; use crate::{ arena::{Handle, HandleVec, UniqueArena}, - back::spv::{BindingInfo, WrappedFunction}, + back::spv::{helpers::BindingDecorations, BindingInfo, WrappedFunction}, + non_max_u32::NonMaxU32, path_like::PathLike, proc::{Alignment, TypeResolution}, valid::{FunctionInfo, ModuleInfo}, @@ -22,6 +24,9 @@ use crate::{ struct FunctionInterface<'a> { varying_ids: &'a mut Vec, stage: crate::ShaderStage, + task_payload: Option>, + mesh_info: Option, + workgroup_size: [u32; 3], } impl Function { @@ -92,9 +97,30 @@ impl Writer { temp_list: Vec::new(), ray_get_committed_intersection_function: None, ray_get_candidate_intersection_function: None, + mesh_state: super::WriteMeshInfo { + vertex_outputs_by_type: crate::FastHashMap::default(), + primitive_outputs_by_type: crate::FastHashMap::default(), + num_vertices_var: None, + num_primitives_var: None, + }, }) } + pub fn set_options(&mut self, options: &Options) -> Result<(), Error> { + let (major, minor) = options.lang_version; + if major != 1 { + return Err(Error::UnsupportedVersion(major, minor)); + } + self.physical_layout = PhysicalLayout::new(major, minor); + self.capabilities_available = options.capabilities.clone(); + self.flags = options.flags; + self.bounds_check_policies = options.bounds_check_policies; + self.zero_initialize_workgroup_memory = options.zero_initialize_workgroup_memory; + self.force_loop_bounding = options.force_loop_bounding; + self.binding_map = options.binding_map.clone(); + Ok(()) + } + /// Returns `(major, minor)` of the SPIR-V language version. pub const fn lang_version(&self) -> (u8, u8) { self.physical_layout.lang_version() @@ -151,6 +177,14 @@ impl Writer { temp_list: take(&mut self.temp_list).recycle(), ray_get_candidate_intersection_function: None, ray_get_committed_intersection_function: None, + + mesh_state: super::WriteMeshInfo { + vertex_outputs_by_type: take(&mut self.mesh_state.vertex_outputs_by_type).recycle(), + primitive_outputs_by_type: take(&mut self.mesh_state.primitive_outputs_by_type) + .recycle(), + num_vertices_var: None, + num_primitives_var: None, + }, }; *self = fresh; @@ -681,6 +715,396 @@ impl Writer { Ok(()) } + fn write_mesh_return_global_variable( + &mut self, + ty: u32, + array_size_id: u32, + ) -> Result { + let array_ty = self.id_gen.next(); + Instruction::type_array(array_ty, ty, array_size_id) + .to_words(&mut self.logical_layout.declarations); + let ptr_ty = self.get_pointer_type_id(array_ty, spirv::StorageClass::Output); + let var_id = self.id_gen.next(); + Instruction::variable(ptr_ty, var_id, spirv::StorageClass::Output, None) + .to_words(&mut self.logical_layout.declarations); + Ok(super::MeshReturnGlobalVariable { + _inner_ty: ty, + var_id, + }) + } + + /// This does various setup things to allow mesh shader entry points + /// to be properly written, such as creating the output variables + fn write_entry_point_mesh_shader_info( + &mut self, + iface: &mut FunctionInterface, + local_invocation_index_id: Option, + ir_module: &crate::Module, + prelude: &mut Block, + ep_context: &mut EntryPointContext, + ) -> Result<(), Error> { + if let Some(ref mesh_info) = iface.mesh_info { + // Create the temporary output variables + let vert_info = self.mesh_shader_output_variable( + mesh_info.vertex_output_type, + false, + mesh_info.max_vertices, + )?; + let prim_info = self.mesh_shader_output_variable( + mesh_info.primitive_output_type, + true, + mesh_info.max_primitives, + )?; + iface.varying_ids.push(vert_info.var_id); + iface.varying_ids.push(prim_info.var_id); + + // These are guaranteed to be initialized after mesh_shader_output_variable + // is called + iface + .varying_ids + .push(self.mesh_state.num_vertices_var.unwrap()); + iface + .varying_ids + .push(self.mesh_state.num_primitives_var.unwrap()); + + // Maybe TODO: zero initialize num_vertices and num_primitives + + // Collect the members in the output structs + let vertex_members = match &ir_module.types[mesh_info.vertex_output_type] { + &crate::Type { + inner: crate::TypeInner::Struct { ref members, .. }, + .. + } => members + .iter() + .map(|a| super::MeshReturnMember { + ty_id: self.get_handle_type_id(a.ty), + binding: a.binding.clone().unwrap(), + }) + .collect(), + _ => unreachable!(), + }; + let primitive_members = match &ir_module.types[mesh_info.primitive_output_type] { + &crate::Type { + inner: crate::TypeInner::Struct { ref members, .. }, + .. + } => members + .iter() + .map(|a| super::MeshReturnMember { + ty_id: self.get_handle_type_id(a.ty), + binding: a.binding.clone().unwrap(), + }) + .collect(), + _ => unreachable!(), + }; + // In the final return, we do a giant memcpy, for which this is helpful + let local_invocation_index_id = match local_invocation_index_id { + Some(a) => a, + None => { + let u32_id = self.get_u32_type_id(); + let var = self.id_gen.next(); + Instruction::variable( + self.get_pointer_type_id(u32_id, spirv::StorageClass::Input), + var, + spirv::StorageClass::Input, + None, + ) + .to_words(&mut self.logical_layout.declarations); + Instruction::decorate( + var, + spirv::Decoration::BuiltIn, + &[spirv::BuiltIn::LocalInvocationIndex as u32], + ) + .to_words(&mut self.logical_layout.annotations); + + let loaded_value = self.id_gen.next(); + prelude + .body + .push(Instruction::load(u32_id, loaded_value, var, None)); + loaded_value + } + }; + let u32_id = self.get_u32_type_id(); + // A general function variable that we guarantee to allow in the final return. It must be + // declared at the top of the function. Currently it is used in the memcpy part to keep + // index to copy track of the current + let function_variable = self.id_gen.next(); + prelude.body.insert( + 0, + Instruction::variable( + self.get_pointer_type_id(u32_id, spirv::StorageClass::Function), + function_variable, + spirv::StorageClass::Function, + None, + ), + ); + // This is the information that is passed to the function writer + // so that it can write the final return logic + let mut mesh_return_info = super::MeshReturnInfo { + vertex_type: mesh_info.vertex_output_type, + vertex_members, + max_vertices: mesh_info.max_vertices, + primitive_type: mesh_info.primitive_output_type, + primitive_members, + max_primitives: mesh_info.max_primitives, + vertex_bindings: Vec::new(), + vertex_builtin_block: None, + primitive_bindings: Vec::new(), + primitive_builtin_block: None, + primitive_indices: None, + local_invocation_index_id, + workgroup_size: self.get_constant_scalar(crate::Literal::U32( + iface.workgroup_size.iter().product(), + )), + function_variable, + }; + // Create the actual output variables and types. + // According to SPIR-V, + // * All builtins must be in the same output `Block` + // * Each member with `location` must be in its own `Block`. + // * Some builtins like CullPrimitiveEXT don't care as much (older validation layers don't know this!) + // * Some builtins like the indices ones need to be in their + // own output variable without a struct wrapper + if mesh_return_info + .vertex_members + .iter() + .any(|a| matches!(a.binding, crate::Binding::BuiltIn(..))) + { + let builtin_block_ty_id = self.id_gen.next(); + let mut ins = Instruction::type_struct(builtin_block_ty_id, &[]); + let mut bi_index = 0; + let mut decorations = Vec::new(); + for member in &mesh_return_info.vertex_members { + if let crate::Binding::BuiltIn(_) = member.binding { + ins.add_operand(member.ty_id); + let binding = self.map_binding( + ir_module, + iface.stage, + spirv::StorageClass::Output, + // Unused except in fragment shaders with other conditions, so we can pass null + Handle::new(NonMaxU32::new(0).unwrap()), + &member.binding, + )?; + match binding { + BindingDecorations::BuiltIn(bi, others) => { + decorations.push(Instruction::member_decorate( + builtin_block_ty_id, + bi_index, + spirv::Decoration::BuiltIn, + &[bi as Word], + )); + for other in others { + decorations.push(Instruction::member_decorate( + builtin_block_ty_id, + bi_index, + other, + &[], + )); + } + } + _ => unreachable!(), + } + bi_index += 1; + } + } + ins.to_words(&mut self.logical_layout.declarations); + decorations.push(Instruction::decorate( + builtin_block_ty_id, + spirv::Decoration::Block, + &[], + )); + for dec in decorations { + dec.to_words(&mut self.logical_layout.annotations); + } + let v = self.write_mesh_return_global_variable( + builtin_block_ty_id, + vert_info.array_size_id, + )?; + iface.varying_ids.push(v.var_id); + if self.flags.contains(WriterFlags::DEBUG) { + Instruction::name(v.var_id, "naga_vertex_builtin_outputs") + .to_words(&mut self.logical_layout.debugs); + } + mesh_return_info.vertex_builtin_block = Some(v); + } + if mesh_return_info.primitive_members.iter().any(|a| { + !matches!( + a.binding, + crate::Binding::BuiltIn( + crate::BuiltIn::PointIndex + | crate::BuiltIn::LineIndices + | crate::BuiltIn::TriangleIndices + ) | crate::Binding::Location { .. } + ) + }) { + let builtin_block_ty_id = self.id_gen.next(); + let mut ins = Instruction::type_struct(builtin_block_ty_id, &[]); + let mut bi_index = 0; + let mut decorations = Vec::new(); + for member in &mesh_return_info.primitive_members { + if let crate::Binding::BuiltIn(bi) = member.binding { + if matches!( + bi, + crate::BuiltIn::PointIndex + | crate::BuiltIn::LineIndices + | crate::BuiltIn::TriangleIndices, + ) { + continue; + } + ins.add_operand(member.ty_id); + let binding = self.map_binding( + ir_module, + iface.stage, + spirv::StorageClass::Output, + // Unused except in fragment shaders with other conditions, so we can pass null + Handle::new(NonMaxU32::new(0).unwrap()), + &member.binding, + )?; + match binding { + BindingDecorations::BuiltIn(bi, others) => { + decorations.push(Instruction::member_decorate( + builtin_block_ty_id, + bi_index, + spirv::Decoration::BuiltIn, + &[bi as Word], + )); + for other in others { + decorations.push(Instruction::member_decorate( + builtin_block_ty_id, + bi_index, + other, + &[], + )); + } + } + _ => unreachable!(), + } + bi_index += 1; + } + } + ins.to_words(&mut self.logical_layout.declarations); + decorations.push(Instruction::decorate( + builtin_block_ty_id, + spirv::Decoration::Block, + &[], + )); + for dec in decorations { + dec.to_words(&mut self.logical_layout.annotations); + } + let v = self.write_mesh_return_global_variable( + builtin_block_ty_id, + prim_info.array_size_id, + )?; + Instruction::decorate(v.var_id, spirv::Decoration::PerPrimitiveEXT, &[]) + .to_words(&mut self.logical_layout.annotations); + iface.varying_ids.push(v.var_id); + if self.flags.contains(WriterFlags::DEBUG) { + Instruction::name(v.var_id, "naga_primitive_builtin_outputs") + .to_words(&mut self.logical_layout.debugs); + } + mesh_return_info.primitive_builtin_block = Some(v); + } + { + for member in &mesh_return_info.vertex_members { + match member.binding { + crate::Binding::Location { location, .. } => { + let s_type = self.id_gen.next(); + Instruction::type_struct(s_type, &[member.ty_id]) + .to_words(&mut self.logical_layout.declarations); + Instruction::decorate(s_type, spirv::Decoration::Block, &[]) + .to_words(&mut self.logical_layout.annotations); + Instruction::member_decorate( + s_type, + 0, + spirv::Decoration::Location, + &[location], + ) + .to_words(&mut self.logical_layout.annotations); + let v = self.write_mesh_return_global_variable( + s_type, + prim_info.array_size_id, + )?; + iface.varying_ids.push(v.var_id); + mesh_return_info.vertex_bindings.push(v); + } + crate::Binding::BuiltIn(_) => (), + } + } + for member in &mesh_return_info.primitive_members { + match member.binding { + crate::Binding::BuiltIn( + crate::BuiltIn::PointIndex + | crate::BuiltIn::LineIndices + | crate::BuiltIn::TriangleIndices, + ) => { + let v = self.write_mesh_return_global_variable( + member.ty_id, + prim_info.array_size_id, + )?; + Instruction::decorate( + v.var_id, + spirv::Decoration::PerPrimitiveEXT, + &[], + ) + .to_words(&mut self.logical_layout.annotations); + Instruction::decorate( + v.var_id, + spirv::Decoration::BuiltIn, + &[match member.binding.to_built_in().unwrap() { + crate::BuiltIn::PointIndex => { + spirv::BuiltIn::PrimitivePointIndicesEXT + } + crate::BuiltIn::LineIndices => { + spirv::BuiltIn::PrimitiveLineIndicesEXT + } + crate::BuiltIn::TriangleIndices => { + spirv::BuiltIn::PrimitiveTriangleIndicesEXT + } + _ => unreachable!(), + } as Word], + ) + .to_words(&mut self.logical_layout.annotations); + iface.varying_ids.push(v.var_id); + if self.flags.contains(WriterFlags::DEBUG) { + Instruction::name(v.var_id, "naga_primitive_indices_outputs") + .to_words(&mut self.logical_layout.debugs); + } + mesh_return_info.primitive_indices = Some(v); + } + crate::Binding::Location { location, .. } => { + let s_type = self.id_gen.next(); + Instruction::type_struct(s_type, &[member.ty_id]) + .to_words(&mut self.logical_layout.declarations); + Instruction::decorate(s_type, spirv::Decoration::Block, &[]) + .to_words(&mut self.logical_layout.annotations); + Instruction::member_decorate( + s_type, + 0, + spirv::Decoration::Location, + &[location], + ) + .to_words(&mut self.logical_layout.annotations); + let v = self.write_mesh_return_global_variable( + s_type, + prim_info.array_size_id, + )?; + Instruction::decorate( + v.var_id, + spirv::Decoration::PerPrimitiveEXT, + &[], + ) + .to_words(&mut self.logical_layout.annotations); + iface.varying_ids.push(v.var_id); + mesh_return_info.primitive_bindings.push(v); + } + crate::Binding::BuiltIn(_) => (), + } + } + } + ep_context.mesh_state = Some(mesh_return_info); + } + Ok(()) + } + fn write_function( &mut self, ir_function: &crate::Function, @@ -699,11 +1123,20 @@ impl Writer { let mut ep_context = EntryPointContext { argument_ids: Vec::new(), results: Vec::new(), + task_payload: if let Some(ref i) = interface { + i.task_payload.map(|a| self.global_variables[a].var_id) + } else { + None + }, + mesh_state: None, }; let mut local_invocation_id = None; let mut parameter_type_ids = Vec::with_capacity(ir_function.arguments.len()); + + let mut local_invocation_index_id = None; + for argument in ir_function.arguments.iter() { let class = spirv::StorageClass::Input; let handle_ty = ir_module.types[argument.ty].inner.is_handle(); @@ -733,6 +1166,10 @@ impl Writer { if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) { local_invocation_id = Some(id); + } else if binding + == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex) + { + local_invocation_index_id = Some(id); } id @@ -762,6 +1199,10 @@ impl Writer { if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) { local_invocation_id = Some(id); + } else if binding + == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationIndex) + { + local_invocation_index_id = Some(id); } } prelude.body.push(Instruction::composite_construct( @@ -810,15 +1251,21 @@ impl Writer { has_point_size |= *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize); let type_id = self.get_handle_type_id(result.ty); - let varying_id = self.write_varying( - ir_module, - iface.stage, - class, - None, - result.ty, - binding, - )?; - iface.varying_ids.push(varying_id); + let varying_id = + if *binding == crate::Binding::BuiltIn(crate::BuiltIn::MeshTaskSize) { + 0 + } else { + let varying_id = self.write_varying( + ir_module, + iface.stage, + class, + None, + result.ty, + binding, + )?; + iface.varying_ids.push(varying_id); + varying_id + }; ep_context.results.push(ResultMember { id: varying_id, type_id, @@ -833,15 +1280,25 @@ impl Writer { let binding = member.binding.as_ref().unwrap(); has_point_size |= *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize); - let varying_id = self.write_varying( - ir_module, - iface.stage, - class, - name, - member.ty, - binding, - )?; - iface.varying_ids.push(varying_id); + // This isn't an actual builtin in SPIR-V. It can only appear as the + // output of a task shader and the output is used when writing the + // entry point return, in which case the id is ignored anyway. + let varying_id = if *binding + == crate::Binding::BuiltIn(crate::BuiltIn::MeshTaskSize) + { + 0 + } else { + let varying_id = self.write_varying( + ir_module, + iface.stage, + class, + name, + member.ty, + binding, + )?; + iface.varying_ids.push(varying_id); + varying_id + }; ep_context.results.push(ResultMember { id: varying_id, type_id, @@ -881,6 +1338,21 @@ impl Writer { None => self.void_type, }; + if let Some(ref mut iface) = interface { + if let Some(task_payload) = iface.task_payload { + iface + .varying_ids + .push(self.global_variables[task_payload].var_id); + } + self.write_entry_point_mesh_shader_info( + iface, + local_invocation_index_id, + ir_module, + &mut prelude, + &mut ep_context, + )?; + } + let lookup_function_type = LookupFunctionType { parameter_type_ids, return_type_id, @@ -917,7 +1389,7 @@ impl Writer { let mut gv = self.global_variables[handle].clone(); if let Some(ref mut iface) = interface { // Have to include global variables in the interface - if self.physical_layout.version >= 0x10400 { + if self.physical_layout.version >= 0x10400 && iface.task_payload != Some(handle) { iface.varying_ids.push(gv.var_id); } } @@ -1048,20 +1520,21 @@ impl Writer { match (context.writer.zero_initialize_workgroup_memory, interface) { ( super::ZeroInitializeWorkgroupMemoryMode::Polyfill, - Some( - ref mut interface @ FunctionInterface { - stage: crate::ShaderStage::Compute, - .. - }, - ), - ) => context.writer.generate_workgroup_vars_init_block( - next_id, - ir_module, - info, - local_invocation_id, - interface, - context.function, - ), + Some(ref mut interface @ FunctionInterface { stage, .. }), + ) => { + if stage.compute_like() { + context.writer.generate_workgroup_vars_init_block( + next_id, + ir_module, + info, + local_invocation_id, + interface, + context.function, + ) + } else { + None + } + } _ => None, }; @@ -1113,6 +1586,9 @@ impl Writer { Some(FunctionInterface { varying_ids: &mut interface_ids, stage: entry_point.stage, + task_payload: entry_point.task_payload, + mesh_info: entry_point.mesh_info.clone(), + workgroup_size: entry_point.workgroup_size, }), debug_info, )?; @@ -1176,7 +1652,53 @@ impl Writer { .to_words(&mut self.logical_layout.execution_modes); spirv::ExecutionModel::GLCompute } - crate::ShaderStage::Task | crate::ShaderStage::Mesh => unreachable!(), + crate::ShaderStage::Task => { + let execution_mode = spirv::ExecutionMode::LocalSize; + //self.check(execution_mode.required_capabilities())?; + Instruction::execution_mode( + function_id, + execution_mode, + &entry_point.workgroup_size, + ) + .to_words(&mut self.logical_layout.execution_modes); + spirv::ExecutionModel::TaskEXT + } + crate::ShaderStage::Mesh => { + let execution_mode = spirv::ExecutionMode::LocalSize; + //self.check(execution_mode.required_capabilities())?; + Instruction::execution_mode( + function_id, + execution_mode, + &entry_point.workgroup_size, + ) + .to_words(&mut self.logical_layout.execution_modes); + let mesh_info = entry_point.mesh_info.as_ref().unwrap(); + Instruction::execution_mode( + function_id, + match mesh_info.topology { + crate::MeshOutputTopology::Points => spirv::ExecutionMode::OutputPoints, + crate::MeshOutputTopology::Lines => spirv::ExecutionMode::OutputLinesEXT, + crate::MeshOutputTopology::Triangles => { + spirv::ExecutionMode::OutputTrianglesEXT + } + }, + &[], + ) + .to_words(&mut self.logical_layout.execution_modes); + Instruction::execution_mode( + function_id, + spirv::ExecutionMode::OutputVertices, + core::slice::from_ref(&mesh_info.max_vertices), + ) + .to_words(&mut self.logical_layout.execution_modes); + Instruction::execution_mode( + function_id, + spirv::ExecutionMode::OutputPrimitivesEXT, + core::slice::from_ref(&mesh_info.max_primitives), + ) + .to_words(&mut self.logical_layout.execution_modes); + spirv::ExecutionModel::MeshEXT + } }; //self.check(exec_model.required_capabilities())?; @@ -1919,16 +2441,111 @@ impl Writer { } } - use spirv::{BuiltIn, Decoration}; + let binding = self.map_binding(ir_module, stage, class, ty, binding)?; + self.write_binding(id, binding); + + Ok(id) + } + + pub fn write_binding(&mut self, id: Word, binding: BindingDecorations) { + match binding { + BindingDecorations::None => (), + BindingDecorations::BuiltIn(bi, others) => { + self.decorate(id, spirv::Decoration::BuiltIn, &[bi as u32]); + for other in others { + self.decorate(id, other, &[]); + } + } + BindingDecorations::Location { + location, + others, + blend_src, + } => { + self.decorate(id, spirv::Decoration::Location, &[location]); + for other in others { + self.decorate(id, other, &[]); + } + if let Some(blend_src) = blend_src { + self.decorate(id, spirv::Decoration::Index, &[blend_src]); + } + } + } + } + + pub fn write_binding_struct_member( + &mut self, + struct_id: Word, + member_idx: Word, + binding_info: BindingDecorations, + ) { + match binding_info { + BindingDecorations::None => (), + BindingDecorations::BuiltIn(bi, others) => { + self.annotations.push(Instruction::member_decorate( + struct_id, + member_idx, + spirv::Decoration::BuiltIn, + &[bi as Word], + )); + for other in others { + self.annotations.push(Instruction::member_decorate( + struct_id, + member_idx, + other, + &[], + )); + } + } + BindingDecorations::Location { + location, + others, + blend_src, + } => { + self.annotations.push(Instruction::member_decorate( + struct_id, + member_idx, + spirv::Decoration::Location, + &[location], + )); + for other in others { + self.annotations.push(Instruction::member_decorate( + struct_id, + member_idx, + other, + &[], + )); + } + if let Some(blend_src) = blend_src { + self.annotations.push(Instruction::member_decorate( + struct_id, + member_idx, + spirv::Decoration::Index, + &[blend_src], + )); + } + } + } + } + pub fn map_binding( + &mut self, + ir_module: &crate::Module, + stage: crate::ShaderStage, + class: spirv::StorageClass, + ty: Handle, + binding: &crate::Binding, + ) -> Result { + use spirv::BuiltIn; + use spirv::Decoration; match *binding { crate::Binding::Location { location, interpolation, sampling, blend_src, + per_primitive, } => { - self.decorate(id, Decoration::Location, &[location]); + let mut others = ArrayVec::new(); let no_decorations = // VUID-StandaloneSpirv-Flat-06202 @@ -1945,10 +2562,10 @@ impl Writer { // Perspective-correct interpolation is the default in SPIR-V. None | Some(crate::Interpolation::Perspective) => (), Some(crate::Interpolation::Flat) => { - self.decorate(id, Decoration::Flat, &[]); + others.push(Decoration::Flat); } Some(crate::Interpolation::Linear) => { - self.decorate(id, Decoration::NoPerspective, &[]); + others.push(Decoration::NoPerspective); } } match sampling { @@ -1960,27 +2577,34 @@ impl Writer { | crate::Sampling::Either, ) => (), Some(crate::Sampling::Centroid) => { - self.decorate(id, Decoration::Centroid, &[]); + others.push(Decoration::Centroid); } Some(crate::Sampling::Sample) => { self.require_any( "per-sample interpolation", &[spirv::Capability::SampleRateShading], )?; - self.decorate(id, Decoration::Sample, &[]); + others.push(Decoration::Sample); } } } - if let Some(blend_src) = blend_src { - self.decorate(id, Decoration::Index, &[blend_src]); + if per_primitive && stage == crate::ShaderStage::Fragment { + others.push(Decoration::PerPrimitiveEXT); + self.require_mesh_shaders()?; } + Ok(BindingDecorations::Location { + location, + others, + blend_src, + }) } crate::Binding::BuiltIn(built_in) => { use crate::BuiltIn as Bi; + let mut others = ArrayVec::new(); let built_in = match built_in { Bi::Position { invariant } => { if invariant { - self.decorate(id, Decoration::Invariant, &[]); + others.push(Decoration::Invariant); } if class == spirv::StorageClass::Output { @@ -2076,10 +2700,14 @@ impl Writer { )?; BuiltIn::SubgroupLocalInvocationId } + Bi::CullPrimitive => BuiltIn::CullPrimitiveEXT, + Bi::PointIndex => BuiltIn::PrimitivePointIndicesEXT, + Bi::LineIndices => BuiltIn::PrimitiveLineIndicesEXT, + Bi::TriangleIndices => BuiltIn::PrimitiveTriangleIndicesEXT, + // No decoration, this EmitMeshTasksEXT is called at function return + Bi::MeshTaskSize => return Ok(BindingDecorations::None), }; - self.decorate(id, Decoration::BuiltIn, &[built_in as u32]); - use crate::ScalarKind as Sk; // Per the Vulkan spec, `VUID-StandaloneSpirv-Flat-04744`: @@ -2103,13 +2731,107 @@ impl Writer { }; if is_flat { - self.decorate(id, Decoration::Flat, &[]); + others.push(Decoration::Flat); } } + Ok(BindingDecorations::BuiltIn(built_in, others)) } } + } - Ok(id) + /// Sets up the temporary mesh shader output buffer for the given output type, + /// and ensures it is long enough. + pub fn mesh_shader_output_variable( + &mut self, + output_type: Handle, + is_primitive: bool, + array_len: Word, + ) -> Result { + // We only want one temporary buffer per (type, is_primitive) combo, + // as functions that can be used by multiple entry points should be + // able to write to the output for all of them. However, the actual + // output buffers for mesh shaders must have an exact size, so we use + // a temporary buffer with size the largest of any entry points'. + let u32_ty = self.get_u32_type_id(); + let u32_ptr = self.get_pointer_type_id(u32_ty, spirv::StorageClass::Workgroup); + if self.mesh_state.num_vertices_var.is_none() { + let var_id = self.id_gen.next(); + Instruction::variable(u32_ptr, var_id, spirv::StorageClass::Workgroup, None) + .to_words(&mut self.logical_layout.declarations); + self.mesh_state.num_vertices_var = Some(var_id); + if self.flags.contains(WriterFlags::DEBUG) { + Instruction::name(var_id, "naga_num_vertices") + .to_words(&mut self.logical_layout.debugs); + } + } + if self.mesh_state.num_primitives_var.is_none() { + let var_id = self.id_gen.next(); + Instruction::variable(u32_ptr, var_id, spirv::StorageClass::Workgroup, None) + .to_words(&mut self.logical_layout.declarations); + self.mesh_state.num_primitives_var = Some(var_id); + if self.flags.contains(WriterFlags::DEBUG) { + Instruction::name(var_id, "naga_num_primitives") + .to_words(&mut self.logical_layout.debugs); + } + } + let entry = if is_primitive { + self.mesh_state.primitive_outputs_by_type.get(&output_type) + } else { + self.mesh_state.vertex_outputs_by_type.get(&output_type) + }; + // We need mutable access to `self` + let out = match entry { + Some(value) => { + // This is the hacky part lol. We want to make sure the array size is the largest max output + // of any of the entry points.get + let val = value.clone(); + let len_ref = &mut self.logical_layout.declarations[val.index_of_length_decl]; + *len_ref = (*len_ref).max(array_len); + val + } + None => { + // We write the literal, and avoid caching as it might change lol + // (no `get_constant_scalar`) + let len_value_id = self.id_gen.next(); + let main_type_id = self.get_handle_type_id(output_type); + Instruction::constant_32bit(self.get_u32_type_id(), len_value_id, array_len) + .to_words(&mut self.logical_layout.declarations); + // This is the best part. We store the word index so we can change it later as needed + let len_literal_idx = self.logical_layout.declarations.len() - 1; + + let array_ty = self.id_gen.next(); + Instruction::type_array(array_ty, main_type_id, len_value_id) + .to_words(&mut self.logical_layout.declarations); + let var_id = self.id_gen.next(); + Instruction::variable( + self.get_pointer_type_id(array_ty, spirv::StorageClass::Workgroup), + var_id, + spirv::StorageClass::Workgroup, + None, + ) + .to_words(&mut self.logical_layout.declarations); + + let info = super::MeshOutputInfo { + inner_ty: main_type_id, + index_of_length_decl: len_literal_idx, + array_ty, + var_id, + array_size_id: len_value_id, + }; + if is_primitive { + self.mesh_state + .primitive_outputs_by_type + .insert(output_type, info.clone()); + } else { + self.mesh_state + .vertex_outputs_by_type + .insert(output_type, info.clone()); + }; + info + } + }; + + Ok(out) } fn write_global_variable( @@ -2341,6 +3063,16 @@ impl Writer { self.physical_layout.bound = self.id_gen.0 + 1; } + pub(super) fn require_mesh_shaders(&mut self) -> Result<(), Error> { + self.use_extension("SPV_EXT_mesh_shader"); + self.require_any("Mesh Shaders", &[spirv::Capability::MeshShadingEXT])?; + let lang_version = self.lang_version(); + if lang_version.0 <= 1 && lang_version.1 < 4 { + return Err(Error::SpirvVersionTooLow(1, 4)); + } + Ok(()) + } + fn write_logical_layout( &mut self, ir_module: &crate::Module, @@ -2378,6 +3110,17 @@ impl Writer { | ir_module.special_types.ray_intersection.is_some(); let has_vertex_return = ir_module.special_types.ray_vertex_return.is_some(); + // Ways mesh shaders are required: + // * Mesh entry point used - checked for + // * Mesh function like setVertex used outside mesh entry point, this is handled when those are written + // * Fragment shader with per primitive data - handled in `map_binding` + let has_mesh_shaders = ir_module.entry_points.iter().any(|entry| { + entry.stage == crate::ShaderStage::Mesh || entry.stage == crate::ShaderStage::Task + }) || ir_module + .global_variables + .iter() + .any(|gvar| gvar.1.space == crate::AddressSpace::TaskPayload); + for (_, &crate::Type { ref inner, .. }) in ir_module.types.iter() { // spirv does not know whether these have vertex return - that is done by us if let &crate::TypeInner::AccelerationStructure { .. } @@ -2404,6 +3147,9 @@ impl Writer { Instruction::extension("SPV_KHR_ray_tracing_position_fetch") .to_words(&mut self.logical_layout.extensions); } + if has_mesh_shaders { + self.require_mesh_shaders()?; + } Instruction::type_void(self.void_type).to_words(&mut self.logical_layout.declarations); Instruction::ext_inst_import(self.gl450_ext_inst_id, "GLSL.std.450") .to_words(&mut self.logical_layout.ext_inst_imports); @@ -2412,20 +3158,18 @@ impl Writer { if self.flags.contains(WriterFlags::DEBUG) { if let Some(debug_info) = debug_info.as_ref() { let source_file_id = self.id_gen.next(); - self.debugs.push(Instruction::string( - &debug_info.file_name.to_string_lossy(), - source_file_id, - )); + Instruction::string(&debug_info.file_name.to_string_lossy(), source_file_id) + .to_words(&mut self.logical_layout.debugs); debug_info_inner = Some(DebugInfoInner { source_code: debug_info.source_code, source_file_id, }); - self.debugs.append(&mut Instruction::source_auto_continued( - debug_info.language, - 0, - &debug_info_inner, - )); + for ins in + Instruction::source_auto_continued(debug_info.language, 0, &debug_info_inner) + { + ins.to_words(&mut self.logical_layout.debugs); + } } } diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 8982242daca..245bc40dd5d 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -207,7 +207,7 @@ impl Writer { Attribute::Stage(ShaderStage::Compute), Attribute::WorkGroupSize(ep.workgroup_size), ], - ShaderStage::Task | ShaderStage::Mesh => unreachable!(), + ShaderStage::Mesh | ShaderStage::Task => unreachable!(), }; self.write_attributes(&attributes)?; @@ -856,6 +856,7 @@ impl Writer { } } Statement::RayQuery { .. } => unreachable!(), + Statement::MeshFunction(..) => unreachable!(), Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let res_name = Baked(result).to_string(); @@ -1822,6 +1823,7 @@ fn map_binding_to_attribute(binding: &crate::Binding) -> Vec { interpolation, sampling, blend_src: None, + per_primitive: _, } => vec![ Attribute::Location(location), Attribute::Interpolate(interpolation, sampling), @@ -1831,6 +1833,7 @@ fn map_binding_to_attribute(binding: &crate::Binding) -> Vec { interpolation, sampling, blend_src: Some(blend_src), + per_primitive: _, } => vec![ Attribute::Location(location), Attribute::BlendSrc(blend_src), diff --git a/naga/src/common/wgsl/to_wgsl.rs b/naga/src/common/wgsl/to_wgsl.rs index 035c4eafb32..dc891aa5a3f 100644 --- a/naga/src/common/wgsl/to_wgsl.rs +++ b/naga/src/common/wgsl/to_wgsl.rs @@ -188,7 +188,12 @@ impl TryToWgsl for crate::BuiltIn { | Bi::PointSize | Bi::DrawID | Bi::PointCoord - | Bi::WorkGroupSize => return None, + | Bi::WorkGroupSize + | Bi::CullPrimitive + | Bi::TriangleIndices + | Bi::LineIndices + | Bi::MeshTaskSize + | Bi::PointIndex => return None, }) } } @@ -352,6 +357,7 @@ pub const fn address_space_str( As::WorkGroup => "workgroup", As::Handle => return (None, None), As::Function => "function", + As::TaskPayload => return (None, None), }), None, ) diff --git a/naga/src/compact/mod.rs b/naga/src/compact/mod.rs index d059ba21e4f..a7d3d463f11 100644 --- a/naga/src/compact/mod.rs +++ b/naga/src/compact/mod.rs @@ -221,6 +221,45 @@ pub fn compact(module: &mut crate::Module, keep_unused: KeepUnused) { } } + for entry in &module.entry_points { + if let Some(task_payload) = entry.task_payload { + module_tracer.global_variables_used.insert(task_payload); + } + if let Some(ref mesh_info) = entry.mesh_info { + module_tracer + .types_used + .insert(mesh_info.vertex_output_type); + module_tracer + .types_used + .insert(mesh_info.primitive_output_type); + if let Some(max_vertices_override) = mesh_info.max_vertices_override { + module_tracer + .global_expressions_used + .insert(max_vertices_override); + } + if let Some(max_primitives_override) = mesh_info.max_primitives_override { + module_tracer + .global_expressions_used + .insert(max_primitives_override); + } + } + if entry.stage == crate::ShaderStage::Task || entry.stage == crate::ShaderStage::Mesh { + // u32 should always be there if the module is valid, as it is e.g. the type of some expressions + let u32_type = module + .types + .iter() + .find_map(|tuple| { + if tuple.1.inner == crate::TypeInner::Scalar(crate::Scalar::U32) { + Some(tuple.0) + } else { + None + } + }) + .unwrap(); + module_tracer.types_used.insert(u32_type); + } + } + module_tracer.type_expression_tandem(); // Now that we know what is used and what is never touched, @@ -342,6 +381,23 @@ pub fn compact(module: &mut crate::Module, keep_unused: KeepUnused) { &module_map, &mut reused_named_expressions, ); + if let Some(ref mut task_payload) = entry.task_payload { + module_map.globals.adjust(task_payload); + } + if let Some(ref mut mesh_info) = entry.mesh_info { + module_map.types.adjust(&mut mesh_info.vertex_output_type); + module_map + .types + .adjust(&mut mesh_info.primitive_output_type); + if let Some(ref mut max_vertices_override) = mesh_info.max_vertices_override { + module_map.global_expressions.adjust(max_vertices_override); + } + if let Some(ref mut max_primitives_override) = mesh_info.max_primitives_override { + module_map + .global_expressions + .adjust(max_primitives_override); + } + } } } diff --git a/naga/src/compact/statements.rs b/naga/src/compact/statements.rs index 39d6065f5f0..b370501baca 100644 --- a/naga/src/compact/statements.rs +++ b/naga/src/compact/statements.rs @@ -117,6 +117,20 @@ impl FunctionTracer<'_> { self.expressions_used.insert(query); self.trace_ray_query_function(fun); } + St::MeshFunction(crate::MeshFunction::SetMeshOutputs { + vertex_count, + primitive_count, + }) => { + self.expressions_used.insert(vertex_count); + self.expressions_used.insert(primitive_count); + } + St::MeshFunction( + crate::MeshFunction::SetPrimitive { index, value } + | crate::MeshFunction::SetVertex { index, value }, + ) => { + self.expressions_used.insert(index); + self.expressions_used.insert(value); + } St::SubgroupBallot { result, predicate } => { if let Some(predicate) = predicate { self.expressions_used.insert(predicate); @@ -335,6 +349,26 @@ impl FunctionMap { adjust(query); self.adjust_ray_query_function(fun); } + St::MeshFunction(crate::MeshFunction::SetMeshOutputs { + ref mut vertex_count, + ref mut primitive_count, + }) => { + adjust(vertex_count); + adjust(primitive_count); + } + St::MeshFunction( + crate::MeshFunction::SetVertex { + ref mut index, + ref mut value, + } + | crate::MeshFunction::SetPrimitive { + ref mut index, + ref mut value, + }, + ) => { + adjust(index); + adjust(value); + } St::SubgroupBallot { ref mut result, ref mut predicate, diff --git a/naga/src/front/glsl/functions.rs b/naga/src/front/glsl/functions.rs index 7de7364cd40..ba096a82b3b 100644 --- a/naga/src/front/glsl/functions.rs +++ b/naga/src/front/glsl/functions.rs @@ -1377,6 +1377,8 @@ impl Frontend { result: ty.map(|ty| FunctionResult { ty, binding: None }), ..Default::default() }, + mesh_info: None, + task_payload: None, }); Ok(()) @@ -1446,6 +1448,7 @@ impl Context<'_> { interpolation, sampling: None, blend_src: None, + per_primitive: false, }; location += 1; @@ -1482,6 +1485,7 @@ impl Context<'_> { interpolation, sampling: None, blend_src: None, + per_primitive: false, }; location += 1; binding diff --git a/naga/src/front/glsl/mod.rs b/naga/src/front/glsl/mod.rs index 876add46a1c..e5eda6b3ad9 100644 --- a/naga/src/front/glsl/mod.rs +++ b/naga/src/front/glsl/mod.rs @@ -107,7 +107,7 @@ impl ShaderMetadata { self.version = 0; self.profile = Profile::Core; self.stage = stage; - self.workgroup_size = [u32::from(stage == ShaderStage::Compute); 3]; + self.workgroup_size = [u32::from(stage.compute_like()); 3]; self.early_fragment_tests = false; self.extensions.clear(); } diff --git a/naga/src/front/glsl/variables.rs b/naga/src/front/glsl/variables.rs index ef98143b769..98871bd2f81 100644 --- a/naga/src/front/glsl/variables.rs +++ b/naga/src/front/glsl/variables.rs @@ -465,6 +465,7 @@ impl Frontend { interpolation, sampling, blend_src, + per_primitive: false, }, handle, storage, diff --git a/naga/src/front/interpolator.rs b/naga/src/front/interpolator.rs index e23cae0e7c2..126e860426c 100644 --- a/naga/src/front/interpolator.rs +++ b/naga/src/front/interpolator.rs @@ -44,6 +44,7 @@ impl crate::Binding { interpolation: ref mut interpolation @ None, ref mut sampling, blend_src: _, + per_primitive: _, } = *self { match ty.scalar_kind() { diff --git a/naga/src/front/spv/function.rs b/naga/src/front/spv/function.rs index 67cbf05f04f..48b23e7c4c4 100644 --- a/naga/src/front/spv/function.rs +++ b/naga/src/front/spv/function.rs @@ -596,6 +596,8 @@ impl> super::Frontend { workgroup_size: ep.workgroup_size, workgroup_size_overrides: None, function, + mesh_info: None, + task_payload: None, }); Ok(()) diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index 960437ece58..396318f14dc 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -263,6 +263,7 @@ impl Decoration { interpolation, sampling, blend_src: None, + per_primitive: false, }), _ => Err(Error::MissingDecoration(spirv::Decoration::Location)), } @@ -4613,6 +4614,7 @@ impl> Frontend { | S::Atomic { .. } | S::ImageAtomic { .. } | S::RayQuery { .. } + | S::MeshFunction(..) | S::SubgroupBallot { .. } | S::SubgroupCollectiveOperation { .. } | S::SubgroupGather { .. } => {} @@ -4894,6 +4896,8 @@ impl> Frontend { spirv::ExecutionModel::Vertex => crate::ShaderStage::Vertex, spirv::ExecutionModel::Fragment => crate::ShaderStage::Fragment, spirv::ExecutionModel::GLCompute => crate::ShaderStage::Compute, + spirv::ExecutionModel::TaskEXT => crate::ShaderStage::Task, + spirv::ExecutionModel::MeshEXT => crate::ShaderStage::Mesh, _ => return Err(Error::UnsupportedExecutionModel(exec_model as u32)), }, name, diff --git a/naga/src/front/wgsl/error.rs b/naga/src/front/wgsl/error.rs index 1cdf53f37dc..a6386bb473c 100644 --- a/naga/src/front/wgsl/error.rs +++ b/naga/src/front/wgsl/error.rs @@ -410,6 +410,19 @@ pub(crate) enum Error<'a> { accept_span: Span, accept_type: String, }, + MissingMeshShaderInfo { + mesh_attribute_span: Span, + }, + OneMeshShaderAttribute { + attribute_span: Span, + }, + ExpectedGlobalVariable { + name_span: Span, + }, + MeshPrimitiveNoDefinedTopology { + attribute_span: Span, + struct_span: Span, + }, StructMemberTooLarge { member_name_span: Span, }, @@ -1374,6 +1387,27 @@ impl<'a> Error<'a> { ], notes: vec![], }, + Error::MissingMeshShaderInfo { mesh_attribute_span} => ParseError { + message: "mesh shader entry point is missing @vertex_output or @primitive_output".into(), + labels: vec![(mesh_attribute_span, "mesh shader entry declared here".into())], + notes: vec![], + }, + Error::OneMeshShaderAttribute { attribute_span } => ParseError { + message: "only one of @vertex_output or @primitive_output was given".into(), + labels: vec![(attribute_span, "only one of @vertex_output or @primitive_output is provided".into())], + notes: vec![], + }, + Error::ExpectedGlobalVariable { name_span } => ParseError { + message: "expected global variable".to_string(), + // TODO: I would like to also include the global declaration span + labels: vec![(name_span, "variable used here".into())], + notes: vec![], + }, + Error::MeshPrimitiveNoDefinedTopology { struct_span, attribute_span } => ParseError { + message: "mesh primitive struct must have exactly one of point indices, line indices, or triangle indices".to_string(), + labels: vec![(attribute_span, "primitive type declared here".into()), (struct_span, "primitive struct declared here".into())], + notes: vec![] + }, Error::StructMemberTooLarge { member_name_span } => ParseError { message: "struct member is too large".into(), labels: vec![(member_name_span, "this member exceeds the maximum size".into())], diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index e90d7eab0a8..ef63e6aaea7 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -1479,47 +1479,147 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .collect(); if let Some(ref entry) = f.entry_point { - let workgroup_size_info = if let Some(workgroup_size) = entry.workgroup_size { - // TODO: replace with try_map once stabilized - let mut workgroup_size_out = [1; 3]; - let mut workgroup_size_overrides_out = [None; 3]; - for (i, size) in workgroup_size.into_iter().enumerate() { - if let Some(size_expr) = size { - match self.const_u32(size_expr, &mut ctx.as_const()) { - Ok(value) => { - workgroup_size_out[i] = value.0; - } - Err(err) => { - if let Error::ConstantEvaluatorError(ref ty, _) = *err { - match **ty { - proc::ConstantEvaluatorError::OverrideExpr => { - workgroup_size_overrides_out[i] = - Some(self.workgroup_size_override( - size_expr, - &mut ctx.as_override(), - )?); - } - _ => { - return Err(err); + let (workgroup_size, workgroup_size_overrides) = + if let Some(workgroup_size) = entry.workgroup_size { + // TODO: replace with try_map once stabilized + let mut workgroup_size_out = [1; 3]; + let mut workgroup_size_overrides_out = [None; 3]; + for (i, size) in workgroup_size.into_iter().enumerate() { + if let Some(size_expr) = size { + match self.const_u32(size_expr, &mut ctx.as_const()) { + Ok(value) => { + workgroup_size_out[i] = value.0; + } + Err(err) => { + if let Error::ConstantEvaluatorError(ref ty, _) = *err { + match **ty { + proc::ConstantEvaluatorError::OverrideExpr => { + workgroup_size_overrides_out[i] = + Some(self.workgroup_size_override( + size_expr, + &mut ctx.as_override(), + )?); + } + _ => { + return Err(err); + } } + } else { + return Err(err); } - } else { - return Err(err); } } } } - } - if workgroup_size_overrides_out.iter().all(|x| x.is_none()) { - (workgroup_size_out, None) + if workgroup_size_overrides_out.iter().all(|x| x.is_none()) { + (workgroup_size_out, None) + } else { + (workgroup_size_out, Some(workgroup_size_overrides_out)) + } } else { - (workgroup_size_out, Some(workgroup_size_overrides_out)) + ([0; 3], None) + }; + + let mesh_info = if let Some(mesh_info) = entry.mesh_shader_info { + let mut const_u32 = |expr| match self.const_u32(expr, &mut ctx.as_const()) { + Ok(value) => Ok((value.0, None)), + Err(err) => { + if let Error::ConstantEvaluatorError(ref ty, _) = *err { + match **ty { + proc::ConstantEvaluatorError::OverrideExpr => Ok(( + 0, + Some( + // This is dubious but it seems the code isn't workgroup size specific + self.workgroup_size_override(expr, &mut ctx.as_override())?, + ), + )), + _ => Err(err), + } + } else { + Err(err) + } + } + }; + let (max_vertices, max_vertices_override) = const_u32(mesh_info.vertex_count)?; + let (max_primitives, max_primitives_override) = + const_u32(mesh_info.primitive_count)?; + let vertex_output_type = + self.resolve_ast_type(mesh_info.vertex_type.0, &mut ctx.as_const())?; + let primitive_output_type = + self.resolve_ast_type(mesh_info.primitive_type.0, &mut ctx.as_const())?; + + let mut topology = None; + let struct_span = ctx.module.types.get_span(primitive_output_type); + match &ctx.module.types[primitive_output_type].inner { + &ir::TypeInner::Struct { + ref members, + span: _, + } => { + for member in members { + let out_topology = match member.binding { + Some(ir::Binding::BuiltIn(ir::BuiltIn::TriangleIndices)) => { + Some(ir::MeshOutputTopology::Triangles) + } + Some(ir::Binding::BuiltIn(ir::BuiltIn::LineIndices)) => { + Some(ir::MeshOutputTopology::Lines) + } + _ => None, + }; + if out_topology.is_some() { + if topology.is_some() { + return Err(Box::new(Error::MeshPrimitiveNoDefinedTopology { + attribute_span: mesh_info.primitive_type.1, + struct_span, + })); + } + topology = out_topology; + } + } + } + _ => { + return Err(Box::new(Error::MeshPrimitiveNoDefinedTopology { + attribute_span: mesh_info.primitive_type.1, + struct_span, + })) + } } + let topology = if let Some(t) = topology { + t + } else { + return Err(Box::new(Error::MeshPrimitiveNoDefinedTopology { + attribute_span: mesh_info.primitive_type.1, + struct_span, + })); + }; + + Some(ir::MeshStageInfo { + max_vertices, + max_vertices_override, + max_primitives, + max_primitives_override, + + vertex_output_type, + primitive_output_type, + topology, + }) + } else { + None + }; + + let task_payload = if let Some((var_name, var_span)) = entry.task_payload { + Some(match ctx.globals.get(var_name) { + Some(&LoweredGlobalDecl::Var(handle)) => handle, + Some(_) => { + return Err(Box::new(Error::ExpectedGlobalVariable { + name_span: var_span, + })) + } + None => return Err(Box::new(Error::UnknownIdent(var_span, var_name))), + }) } else { - ([0; 3], None) + None }; - let (workgroup_size, workgroup_size_overrides) = workgroup_size_info; ctx.module.entry_points.push(ir::EntryPoint { name: f.name.name.to_string(), stage: entry.stage, @@ -1527,6 +1627,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { workgroup_size, workgroup_size_overrides, function, + mesh_info, + task_payload, }); Ok(LoweredGlobalDecl::EntryPoint( ctx.module.entry_points.len() - 1, @@ -3130,6 +3232,59 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ); return Ok(Some(result)); } + + "setMeshOutputs" | "setVertex" | "setPrimitive" => { + let mut args = ctx.prepare_args(arguments, 2, span); + let arg1 = args.next()?; + let arg2 = args.next()?; + args.finish()?; + + let mut cast_u32 = |arg| { + // Try to convert abstract values to the known argument types + let expr = self.expression_for_abstract(arg, ctx)?; + let goal_ty = + ctx.ensure_type_exists(ir::TypeInner::Scalar(ir::Scalar::U32)); + ctx.try_automatic_conversions( + expr, + &proc::TypeResolution::Handle(goal_ty), + ctx.ast_expressions.get_span(arg), + ) + }; + + let arg1 = cast_u32(arg1)?; + let arg2 = if function.name == "setMeshOutputs" { + cast_u32(arg2)? + } else { + self.expression(arg2, ctx)? + }; + + let rctx = ctx.runtime_expression_ctx(span)?; + + // Emit all previous expressions, even if not used directly + rctx.block + .extend(rctx.emitter.finish(&rctx.function.expressions)); + rctx.block.push( + crate::Statement::MeshFunction(match function.name { + "setMeshOutputs" => crate::MeshFunction::SetMeshOutputs { + vertex_count: arg1, + primitive_count: arg2, + }, + "setVertex" => crate::MeshFunction::SetVertex { + index: arg1, + value: arg2, + }, + "setPrimitive" => crate::MeshFunction::SetPrimitive { + index: arg1, + value: arg2, + }, + _ => unreachable!(), + }), + span, + ); + rctx.emitter.start(&rctx.function.expressions); + + return Ok(None); + } _ => { return Err(Box::new(Error::UnknownIdent(function.span, function.name))) } @@ -4057,6 +4212,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { interpolation, sampling, blend_src, + per_primitive, }) => { let blend_src = if let Some(blend_src) = blend_src { Some(self.const_u32(blend_src, &mut ctx.as_const())?.0) @@ -4069,6 +4225,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { interpolation, sampling, blend_src, + per_primitive, }; binding.apply_default_interpolation(&ctx.module.types[ty].inner); Some(binding) diff --git a/naga/src/front/wgsl/mod.rs b/naga/src/front/wgsl/mod.rs index 1080392cc61..dfacc7d975a 100644 --- a/naga/src/front/wgsl/mod.rs +++ b/naga/src/front/wgsl/mod.rs @@ -48,6 +48,9 @@ impl Frontend { options, } } + pub fn set_options(&mut self, options: Options) { + self.options = options; + } pub fn parse(&mut self, source: &str) -> core::result::Result { self.inner(source).map_err(|x| x.as_parse_error(source)) diff --git a/naga/src/front/wgsl/parse/ast.rs b/naga/src/front/wgsl/parse/ast.rs index 345e9c4c486..49ecddfdee5 100644 --- a/naga/src/front/wgsl/parse/ast.rs +++ b/naga/src/front/wgsl/parse/ast.rs @@ -128,6 +128,16 @@ pub struct EntryPoint<'a> { pub stage: crate::ShaderStage, pub early_depth_test: Option, pub workgroup_size: Option<[Option>>; 3]>, + pub mesh_shader_info: Option>, + pub task_payload: Option<(&'a str, Span)>, +} + +#[derive(Debug, Clone, Copy)] +pub struct EntryPointMeshShaderInfo<'a> { + pub vertex_count: Handle>, + pub primitive_count: Handle>, + pub vertex_type: (Handle>, Span), + pub primitive_type: (Handle>, Span), } #[cfg(doc)] @@ -152,6 +162,7 @@ pub enum Binding<'a> { interpolation: Option, sampling: Option, blend_src: Option>>, + per_primitive: bool, }, } diff --git a/naga/src/front/wgsl/parse/conv.rs b/naga/src/front/wgsl/parse/conv.rs index cbc485fb24a..b75d104afbd 100644 --- a/naga/src/front/wgsl/parse/conv.rs +++ b/naga/src/front/wgsl/parse/conv.rs @@ -16,6 +16,7 @@ pub fn map_address_space(word: &str, span: Span) -> Result<'_, crate::AddressSpa }), "push_constant" => Ok(crate::AddressSpace::PushConstant), "function" => Ok(crate::AddressSpace::Function), + "task_payload" => Ok(crate::AddressSpace::TaskPayload), _ => Err(Box::new(Error::UnknownAddressSpace(span))), } } @@ -49,6 +50,12 @@ pub fn map_built_in( "subgroup_id" => crate::BuiltIn::SubgroupId, "subgroup_size" => crate::BuiltIn::SubgroupSize, "subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId, + // mesh + "cull_primitive" => crate::BuiltIn::CullPrimitive, + "point_index" => crate::BuiltIn::PointIndex, + "line_indices" => crate::BuiltIn::LineIndices, + "triangle_indices" => crate::BuiltIn::TriangleIndices, + "mesh_task_size" => crate::BuiltIn::MeshTaskSize, _ => return Err(Box::new(Error::UnknownBuiltin(span))), }; match built_in { diff --git a/naga/src/front/wgsl/parse/directive/enable_extension.rs b/naga/src/front/wgsl/parse/directive/enable_extension.rs index e614d19b6fe..3b3c3f3d70d 100644 --- a/naga/src/front/wgsl/parse/directive/enable_extension.rs +++ b/naga/src/front/wgsl/parse/directive/enable_extension.rs @@ -10,6 +10,7 @@ use alloc::boxed::Box; /// Tracks the status of every enable-extension known to Naga. #[derive(Clone, Debug, Eq, PartialEq)] pub struct EnableExtensions { + mesh_shader: bool, dual_source_blending: bool, /// Whether `enable f16;` was written earlier in the shader module. f16: bool, @@ -19,6 +20,7 @@ pub struct EnableExtensions { impl EnableExtensions { pub(crate) const fn empty() -> Self { Self { + mesh_shader: false, f16: false, dual_source_blending: false, clip_distances: false, @@ -28,6 +30,7 @@ impl EnableExtensions { /// Add an enable-extension to the set requested by a module. pub(crate) fn add(&mut self, ext: ImplementedEnableExtension) { let field = match ext { + ImplementedEnableExtension::MeshShader => &mut self.mesh_shader, ImplementedEnableExtension::DualSourceBlending => &mut self.dual_source_blending, ImplementedEnableExtension::F16 => &mut self.f16, ImplementedEnableExtension::ClipDistances => &mut self.clip_distances, @@ -38,6 +41,7 @@ impl EnableExtensions { /// Query whether an enable-extension tracked here has been requested. pub(crate) const fn contains(&self, ext: ImplementedEnableExtension) -> bool { match ext { + ImplementedEnableExtension::MeshShader => self.mesh_shader, ImplementedEnableExtension::DualSourceBlending => self.dual_source_blending, ImplementedEnableExtension::F16 => self.f16, ImplementedEnableExtension::ClipDistances => self.clip_distances, @@ -70,6 +74,7 @@ impl EnableExtension { const F16: &'static str = "f16"; const CLIP_DISTANCES: &'static str = "clip_distances"; const DUAL_SOURCE_BLENDING: &'static str = "dual_source_blending"; + const MESH_SHADER: &'static str = "mesh_shading"; const SUBGROUPS: &'static str = "subgroups"; /// Convert from a sentinel word in WGSL into its associated [`EnableExtension`], if possible. @@ -80,6 +85,7 @@ impl EnableExtension { Self::DUAL_SOURCE_BLENDING => { Self::Implemented(ImplementedEnableExtension::DualSourceBlending) } + Self::MESH_SHADER => Self::Implemented(ImplementedEnableExtension::MeshShader), Self::SUBGROUPS => Self::Unimplemented(UnimplementedEnableExtension::Subgroups), _ => return Err(Box::new(Error::UnknownEnableExtension(span, word))), }) @@ -89,6 +95,7 @@ impl EnableExtension { pub const fn to_ident(self) -> &'static str { match self { Self::Implemented(kind) => match kind { + ImplementedEnableExtension::MeshShader => Self::MESH_SHADER, ImplementedEnableExtension::DualSourceBlending => Self::DUAL_SOURCE_BLENDING, ImplementedEnableExtension::F16 => Self::F16, ImplementedEnableExtension::ClipDistances => Self::CLIP_DISTANCES, @@ -121,6 +128,8 @@ pub enum ImplementedEnableExtension { /// /// [`enable clip_distances;`]: https://www.w3.org/TR/WGSL/#extension-clip_distances ClipDistances, + /// Enables the `mesh_shader` extension, native only + MeshShader, } /// A variant of [`EnableExtension::Unimplemented`]. diff --git a/naga/src/front/wgsl/parse/mod.rs b/naga/src/front/wgsl/parse/mod.rs index cf4dd4d4bb6..bcefc01cb3d 100644 --- a/naga/src/front/wgsl/parse/mod.rs +++ b/naga/src/front/wgsl/parse/mod.rs @@ -178,6 +178,7 @@ struct BindingParser<'a> { sampling: ParsedAttribute, invariant: ParsedAttribute, blend_src: ParsedAttribute>>, + per_primitive: ParsedAttribute<()>, } impl<'a> BindingParser<'a> { @@ -237,6 +238,9 @@ impl<'a> BindingParser<'a> { .set(parser.general_expression(lexer, ctx)?, name_span)?; lexer.expect(Token::Paren(')'))?; } + "per_primitive" => { + self.per_primitive.set((), name_span)?; + } _ => return Err(Box::new(Error::UnknownAttribute(name_span))), } Ok(()) @@ -250,9 +254,10 @@ impl<'a> BindingParser<'a> { self.sampling.value, self.invariant.value.unwrap_or_default(), self.blend_src.value, + self.per_primitive.value, ) { - (None, None, None, None, false, None) => Ok(None), - (Some(location), None, interpolation, sampling, false, blend_src) => { + (None, None, None, None, false, None, None) => Ok(None), + (Some(location), None, interpolation, sampling, false, blend_src, per_primitive) => { // Before handing over the completed `Module`, we call // `apply_default_interpolation` to ensure that the interpolation and // sampling have been explicitly specified on all vertex shader output and fragment @@ -262,17 +267,18 @@ impl<'a> BindingParser<'a> { interpolation, sampling, blend_src, + per_primitive: per_primitive.is_some(), })) } - (None, Some(crate::BuiltIn::Position { .. }), None, None, invariant, None) => { + (None, Some(crate::BuiltIn::Position { .. }), None, None, invariant, None, None) => { Ok(Some(ast::Binding::BuiltIn(crate::BuiltIn::Position { invariant, }))) } - (None, Some(built_in), None, None, false, None) => { + (None, Some(built_in), None, None, false, None, None) => { Ok(Some(ast::Binding::BuiltIn(built_in))) } - (_, _, _, _, _, _) => Err(Box::new(Error::InconsistentBinding(span))), + (_, _, _, _, _, _, _) => Err(Box::new(Error::InconsistentBinding(span))), } } } @@ -2784,12 +2790,15 @@ impl Parser { // read attributes let mut binding = None; let mut stage = ParsedAttribute::default(); - let mut compute_span = Span::new(0, 0); + let mut compute_like_span = Span::new(0, 0); let mut workgroup_size = ParsedAttribute::default(); let mut early_depth_test = ParsedAttribute::default(); let (mut bind_index, mut bind_group) = (ParsedAttribute::default(), ParsedAttribute::default()); let mut id = ParsedAttribute::default(); + let mut payload = ParsedAttribute::default(); + let mut vertex_output = ParsedAttribute::default(); + let mut primitive_output = ParsedAttribute::default(); let mut must_use: ParsedAttribute = ParsedAttribute::default(); @@ -2848,7 +2857,35 @@ impl Parser { } "compute" => { stage.set(ShaderStage::Compute, name_span)?; - compute_span = name_span; + compute_like_span = name_span; + } + "task" => { + stage.set(ShaderStage::Task, name_span)?; + compute_like_span = name_span; + } + "mesh" => { + stage.set(ShaderStage::Mesh, name_span)?; + compute_like_span = name_span; + } + "payload" => { + lexer.expect(Token::Paren('('))?; + payload.set(lexer.next_ident_with_span()?, name_span)?; + lexer.expect(Token::Paren(')'))?; + } + "vertex_output" | "primitive_output" => { + lexer.expect(Token::Paren('('))?; + let type_span = lexer.peek().1; + let r#type = self.type_decl(lexer, &mut ctx)?; + let type_span = lexer.span_from(type_span.to_range().unwrap().start); + lexer.expect(Token::Separator(','))?; + let max_output = self.general_expression(lexer, &mut ctx)?; + let end_span = lexer.expect_span(Token::Paren(')'))?; + let total_span = name_span.until(&end_span); + if name == "vertex_output" { + vertex_output.set((r#type, type_span, max_output), total_span)?; + } else if name == "primitive_output" { + primitive_output.set((r#type, type_span, max_output), total_span)?; + } } "workgroup_size" => { lexer.expect(Token::Paren('('))?; @@ -3014,13 +3051,39 @@ impl Parser { )?; Some(ast::GlobalDeclKind::Fn(ast::Function { entry_point: if let Some(stage) = stage.value { - if stage == ShaderStage::Compute && workgroup_size.value.is_none() { - return Err(Box::new(Error::MissingWorkgroupSize(compute_span))); + if stage.compute_like() && workgroup_size.value.is_none() { + return Err(Box::new(Error::MissingWorkgroupSize(compute_like_span))); } + if stage == ShaderStage::Mesh + && (vertex_output.value.is_none() || primitive_output.value.is_none()) + { + return Err(Box::new(Error::MissingMeshShaderInfo { + mesh_attribute_span: compute_like_span, + })); + } + let mesh_shader_info = match (vertex_output.value, primitive_output.value) { + (Some(vertex_output), Some(primitive_output)) => { + Some(ast::EntryPointMeshShaderInfo { + vertex_count: vertex_output.2, + primitive_count: primitive_output.2, + vertex_type: (vertex_output.0, vertex_output.1), + primitive_type: (primitive_output.0, primitive_output.1), + }) + } + (None, None) => None, + (Some(v), None) | (None, Some(v)) => { + return Err(Box::new(Error::OneMeshShaderAttribute { + attribute_span: v.1, + })) + } + }; + Some(ast::EntryPoint { stage, early_depth_test: early_depth_test.value, workgroup_size: workgroup_size.value, + mesh_shader_info, + task_payload: payload.value, }) } else { None diff --git a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs index 257445952b8..a182bf0e064 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -329,6 +329,16 @@ pub enum ShaderStage { Mesh, } +impl ShaderStage { + // TODO: make more things respect this + pub const fn compute_like(self) -> bool { + match self { + Self::Vertex | Self::Fragment => false, + Self::Compute | Self::Task | Self::Mesh => true, + } + } +} + /// Addressing space of variables. #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] @@ -363,6 +373,8 @@ pub enum AddressSpace { /// /// [`SHADER_FLOAT16`]: crate::valid::Capabilities::SHADER_FLOAT16 PushConstant, + /// Task shader to mesh shader payload + TaskPayload, } /// Built-in inputs and outputs. @@ -373,7 +385,7 @@ pub enum AddressSpace { pub enum BuiltIn { Position { invariant: bool }, ViewIndex, - // vertex + // vertex (and often mesh) BaseInstance, BaseVertex, ClipDistance, @@ -386,10 +398,10 @@ pub enum BuiltIn { FragDepth, PointCoord, FrontFacing, - PrimitiveIndex, + PrimitiveIndex, // Also for mesh output SampleIndex, SampleMask, - // compute + // compute (and task/mesh) GlobalInvocationId, LocalInvocationId, LocalInvocationIndex, @@ -401,6 +413,12 @@ pub enum BuiltIn { SubgroupId, SubgroupSize, SubgroupInvocationId, + // mesh + MeshTaskSize, + CullPrimitive, + PointIndex, + LineIndices, + TriangleIndices, } /// Number of bytes per scalar. @@ -966,6 +984,7 @@ pub enum Binding { /// Optional `blend_src` index used for dual source blending. /// See blend_src: Option, + per_primitive: bool, }, } @@ -1935,7 +1954,9 @@ pub enum Statement { /// [`Loop`] statement. /// /// [`Loop`]: Statement::Loop - Return { value: Option> }, + Return { + value: Option>, + }, /// Aborts the current shader execution. /// @@ -2141,6 +2162,7 @@ pub enum Statement { /// The specific operation we're performing on `query`. fun: RayQueryFunction, }, + MeshFunction(MeshFunction), /// Calculate a bitmask using a boolean from each active thread in the subgroup SubgroupBallot { /// The [`SubgroupBallotResult`] expression representing this load's result. @@ -2314,6 +2336,9 @@ pub struct EntryPoint { pub workgroup_size_overrides: Option<[Option>; 3]>, /// The entrance function. pub function: Function, + /// The information relating to a mesh shader + pub mesh_info: Option, + pub task_payload: Option>, } /// Return types predeclared for the frexp, modf, and atomicCompareExchangeWeak built-in functions. @@ -2578,3 +2603,46 @@ pub struct Module { /// Doc comments. pub doc_comments: Option>, } + +#[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum MeshOutputTopology { + Points, + Lines, + Triangles, +} +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +#[allow(dead_code)] +pub struct MeshStageInfo { + pub topology: MeshOutputTopology, + pub max_vertices: u32, + pub max_vertices_override: Option>, + pub max_primitives: u32, + pub max_primitives_override: Option>, + pub vertex_output_type: Handle, + pub primitive_output_type: Handle, +} + +#[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum MeshFunction { + SetMeshOutputs { + vertex_count: Handle, + primitive_count: Handle, + }, + SetVertex { + index: Handle, + value: Handle, + }, + SetPrimitive { + index: Handle, + value: Handle, + }, +} diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index 413e49c1eed..434c6e3f724 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -177,6 +177,9 @@ impl super::AddressSpace { crate::AddressSpace::Storage { access } => access, crate::AddressSpace::Handle => Sa::LOAD, crate::AddressSpace::PushConstant => Sa::LOAD, + // TaskPayload isn't always writable, but this is checked for elsewhere, + // when not using multiple payloads and matching the entry payload is checked. + crate::AddressSpace::TaskPayload => Sa::LOAD | Sa::STORE, } } } diff --git a/naga/src/proc/terminator.rs b/naga/src/proc/terminator.rs index b29ccb054a3..f76d4c06a3b 100644 --- a/naga/src/proc/terminator.rs +++ b/naga/src/proc/terminator.rs @@ -36,6 +36,7 @@ pub fn ensure_block_returns(block: &mut crate::Block) { | S::ImageStore { .. } | S::Call { .. } | S::RayQuery { .. } + | S::MeshFunction(..) | S::Atomic { .. } | S::ImageAtomic { .. } | S::WorkGroupUniformLoad { .. } diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index 95ae40dcdb4..101ea046487 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -85,6 +85,16 @@ struct FunctionUniformity { exit: ExitFlags, } +/// Mesh shader related characteristics of a function. +#[derive(Debug, Clone, Default)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +#[cfg_attr(test, derive(PartialEq))] +pub struct FunctionMeshShaderInfo { + pub vertex_type: Option<(Handle, Handle)>, + pub primitive_type: Option<(Handle, Handle)>, +} + impl ops::BitOr for FunctionUniformity { type Output = Self; fn bitor(self, other: Self) -> Self { @@ -302,6 +312,8 @@ pub struct FunctionInfo { /// See [`DiagnosticFilterNode`] for details on how the tree is represented and used in /// validation. diagnostic_filter_leaf: Option>, + + pub mesh_shader_info: FunctionMeshShaderInfo, } impl FunctionInfo { @@ -372,6 +384,14 @@ impl FunctionInfo { info.uniformity.non_uniform_result } + pub fn insert_global_use( + &mut self, + global_use: GlobalUse, + global: Handle, + ) { + self.global_uses[global.index()] |= global_use; + } + /// Record a use of `expr` for its value. /// /// This is used for almost all expression references. Anything @@ -482,6 +502,8 @@ impl FunctionInfo { *mine |= *other; } + self.try_update_mesh_info(&callee.mesh_shader_info)?; + Ok(FunctionUniformity { result: callee.uniformity.clone(), exit: if callee.may_kill { @@ -635,7 +657,8 @@ impl FunctionInfo { // local data is non-uniform As::Function | As::Private => false, // workgroup memory is exclusively accessed by the group - As::WorkGroup => true, + // task payload memory is very similar to workgroup memory + As::WorkGroup | As::TaskPayload => true, // uniform data As::Uniform | As::PushConstant => true, // storage data is only uniform when read-only @@ -1113,6 +1136,34 @@ impl FunctionInfo { } FunctionUniformity::new() } + S::MeshFunction(func) => match &func { + // TODO: double check all of this uniformity stuff. I frankly don't fully understand all of it. + &crate::MeshFunction::SetMeshOutputs { + vertex_count, + primitive_count, + } => { + let _ = self.add_ref(vertex_count); + let _ = self.add_ref(primitive_count); + FunctionUniformity::new() + } + &crate::MeshFunction::SetVertex { index, value } + | &crate::MeshFunction::SetPrimitive { index, value } => { + let _ = self.add_ref(index); + let _ = self.add_ref(value); + let ty = + self.expressions[value.index()].ty.clone().handle().ok_or( + FunctionError::InvalidMeshShaderOutputType(value).with_span(), + )?; + + if matches!(func, crate::MeshFunction::SetVertex { .. }) { + self.try_update_mesh_vertex_type(ty, value)?; + } else { + self.try_update_mesh_primitive_type(ty, value)?; + }; + + FunctionUniformity::new() + } + }, S::SubgroupBallot { result: _, predicate, @@ -1158,6 +1209,53 @@ impl FunctionInfo { } Ok(combined_uniformity) } + + fn try_update_mesh_vertex_type( + &mut self, + ty: Handle, + value: Handle, + ) -> Result<(), WithSpan> { + if let &Some(ref existing) = &self.mesh_shader_info.vertex_type { + if existing.0 != ty { + return Err( + FunctionError::ConflictingMeshOutputTypes(existing.1, value).with_span() + ); + } + } else { + self.mesh_shader_info.vertex_type = Some((ty, value)); + } + Ok(()) + } + + fn try_update_mesh_primitive_type( + &mut self, + ty: Handle, + value: Handle, + ) -> Result<(), WithSpan> { + if let &Some(ref existing) = &self.mesh_shader_info.primitive_type { + if existing.0 != ty { + return Err( + FunctionError::ConflictingMeshOutputTypes(existing.1, value).with_span() + ); + } + } else { + self.mesh_shader_info.primitive_type = Some((ty, value)); + } + Ok(()) + } + + fn try_update_mesh_info( + &mut self, + other: &FunctionMeshShaderInfo, + ) -> Result<(), WithSpan> { + if let &Some(ref other_vertex) = &other.vertex_type { + self.try_update_mesh_vertex_type(other_vertex.0, other_vertex.1)?; + } + if let &Some(ref other_primitive) = &other.vertex_type { + self.try_update_mesh_primitive_type(other_primitive.0, other_primitive.1)?; + } + Ok(()) + } } impl ModuleInfo { @@ -1193,6 +1291,7 @@ impl ModuleInfo { sampling: crate::FastHashSet::default(), dual_source_blending: false, diagnostic_filter_leaf: fun.diagnostic_filter_leaf, + mesh_shader_info: FunctionMeshShaderInfo::default(), }; let resolve_context = ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments); @@ -1326,6 +1425,7 @@ fn uniform_control_flow() { sampling: crate::FastHashSet::default(), dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: FunctionMeshShaderInfo::default(), }; let resolve_context = ResolveContext { constants: &Arena::new(), diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index dc19e191764..0ae2ffdb54f 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -217,6 +217,14 @@ pub enum FunctionError { EmitResult(Handle), #[error("Expression not visited by the appropriate statement")] UnvisitedExpression(Handle), + #[error("Expression {0:?} should be u32, but isn't")] + InvalidMeshFunctionCall(Handle), + #[error("Mesh output types differ from {0:?} to {1:?}")] + ConflictingMeshOutputTypes(Handle, Handle), + #[error("Task payload variables differ from {0:?} to {1:?}")] + ConflictingTaskPayloadVariables(Handle, Handle), + #[error("Mesh shader output at {0:?} is not a user-defined struct")] + InvalidMeshShaderOutputType(Handle), } bitflags::bitflags! { @@ -1539,6 +1547,40 @@ impl super::Validator { crate::RayQueryFunction::Terminate => {} } } + S::MeshFunction(func) => { + let ensure_u32 = + |expr: Handle| -> Result<(), WithSpan> { + let u32_ty = TypeResolution::Value(Ti::Scalar(crate::Scalar::U32)); + let ty = context + .resolve_type_impl(expr, &self.valid_expression_set) + .map_err_inner(|source| { + FunctionError::Expression { + source, + handle: expr, + } + .with_span_handle(expr, context.expressions) + })?; + if !context.compare_types(&u32_ty, ty) { + return Err(FunctionError::InvalidMeshFunctionCall(expr) + .with_span_handle(expr, context.expressions)); + } + Ok(()) + }; + match func { + crate::MeshFunction::SetMeshOutputs { + vertex_count, + primitive_count, + } => { + ensure_u32(vertex_count)?; + ensure_u32(primitive_count)?; + } + crate::MeshFunction::SetVertex { index, value: _ } + | crate::MeshFunction::SetPrimitive { index, value: _ } => { + ensure_u32(index)?; + // TODO: ensure it is correct for the value + } + } + } S::SubgroupBallot { result, predicate } => { stages &= self.subgroup_stages; if !self.capabilities.contains(super::Capabilities::SUBGROUP) { diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index e8a69013434..a0153e9398c 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -801,6 +801,22 @@ impl super::Validator { } Ok(()) } + crate::Statement::MeshFunction(func) => match func { + crate::MeshFunction::SetMeshOutputs { + vertex_count, + primitive_count, + } => { + validate_expr(vertex_count)?; + validate_expr(primitive_count)?; + Ok(()) + } + crate::MeshFunction::SetVertex { index, value } + | crate::MeshFunction::SetPrimitive { index, value } => { + validate_expr(index)?; + validate_expr(value)?; + Ok(()) + } + }, crate::Statement::SubgroupBallot { result, predicate } => { validate_expr_opt(predicate)?; validate_expr(result)?; diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index 7c8cc903139..51167a4810d 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -92,6 +92,10 @@ pub enum VaryingError { }, #[error("Workgroup size is multi dimensional, `@builtin(subgroup_id)` and `@builtin(subgroup_invocation_id)` are not supported.")] InvalidMultiDimensionalSubgroupBuiltIn, + #[error("The `@per_primitive` attribute can only be used in fragment shader inputs or mesh shader primitive outputs")] + InvalidPerPrimitive, + #[error("Non-builtin members of a mesh primitive output struct must be decorated with `@per_primitive`")] + MissingPerPrimitive, } #[derive(Clone, Debug, thiserror::Error)] @@ -123,6 +127,26 @@ pub enum EntryPointError { InvalidIntegerInterpolation { location: u32 }, #[error(transparent)] Function(#[from] FunctionError), + #[error("Non mesh shader entry point cannot have mesh shader attributes")] + UnexpectedMeshShaderAttributes, + #[error("Non mesh/task shader entry point cannot have task payload attribute")] + UnexpectedTaskPayload, + #[error("Task payload must be declared with `var`")] + TaskPayloadWrongAddressSpace, + #[error("For a task payload to be used, it must be declared with @payload")] + WrongTaskPayloadUsed, + #[error("A function can only set vertex and primitive types that correspond to the mesh shader attributes")] + WrongMeshOutputType, + #[error("Only mesh shader entry points can write to mesh output vertices and primitives")] + UnexpectedMeshShaderOutput, + #[error("Mesh shader entry point cannot have a return type")] + UnexpectedMeshShaderEntryResult, + #[error("Task shader entry point must return @builtin(mesh_task_size) vec3")] + WrongTaskShaderEntryResult, + #[error("Mesh output type must be a user-defined struct.")] + InvalidMeshOutputType, + #[error("Mesh primitive outputs must have exactly one of `@builtin(triangle_indices)`, `@builtin(line_indices)`, or `@builtin(point_index)`")] + InvalidMeshPrimitiveOutputType, } fn storage_usage(access: crate::StorageAccess) -> GlobalUse { @@ -139,6 +163,13 @@ fn storage_usage(access: crate::StorageAccess) -> GlobalUse { storage_usage } +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum MeshOutputType { + None, + VertexOutput, + PrimitiveOutput, +} + struct VaryingContext<'a> { stage: crate::ShaderStage, output: bool, @@ -149,6 +180,7 @@ struct VaryingContext<'a> { built_ins: &'a mut crate::FastHashSet, capabilities: Capabilities, flags: super::ValidationFlags, + mesh_output_type: MeshOutputType, } impl VaryingContext<'_> { @@ -236,10 +268,9 @@ impl VaryingContext<'_> { ), Bi::Position { .. } => ( match self.stage { - St::Vertex => self.output, + St::Vertex | St::Mesh => self.output, St::Fragment => !self.output, - St::Compute => false, - St::Task | St::Mesh => unreachable!(), + St::Compute | St::Task => false, }, *ty_inner == Ti::Vector { @@ -276,7 +307,7 @@ impl VaryingContext<'_> { *ty_inner == Ti::Scalar(crate::Scalar::U32), ), Bi::LocalInvocationIndex => ( - self.stage == St::Compute && !self.output, + self.stage.compute_like() && !self.output, *ty_inner == Ti::Scalar(crate::Scalar::U32), ), Bi::GlobalInvocationId @@ -284,7 +315,7 @@ impl VaryingContext<'_> { | Bi::WorkGroupId | Bi::WorkGroupSize | Bi::NumWorkGroups => ( - self.stage == St::Compute && !self.output, + self.stage.compute_like() && !self.output, *ty_inner == Ti::Vector { size: Vs::Tri, @@ -292,17 +323,48 @@ impl VaryingContext<'_> { }, ), Bi::NumSubgroups | Bi::SubgroupId => ( - self.stage == St::Compute && !self.output, + self.stage.compute_like() && !self.output, *ty_inner == Ti::Scalar(crate::Scalar::U32), ), Bi::SubgroupSize | Bi::SubgroupInvocationId => ( match self.stage { - St::Compute | St::Fragment => !self.output, + St::Compute | St::Fragment | St::Task | St::Mesh => !self.output, St::Vertex => false, - St::Task | St::Mesh => unreachable!(), }, *ty_inner == Ti::Scalar(crate::Scalar::U32), ), + Bi::CullPrimitive => ( + self.mesh_output_type == MeshOutputType::PrimitiveOutput, + *ty_inner == Ti::Scalar(crate::Scalar::BOOL), + ), + Bi::PointIndex => ( + self.mesh_output_type == MeshOutputType::PrimitiveOutput, + *ty_inner == Ti::Scalar(crate::Scalar::U32), + ), + Bi::LineIndices => ( + self.mesh_output_type == MeshOutputType::PrimitiveOutput, + *ty_inner + == Ti::Vector { + size: Vs::Bi, + scalar: crate::Scalar::U32, + }, + ), + Bi::TriangleIndices => ( + self.mesh_output_type == MeshOutputType::PrimitiveOutput, + *ty_inner + == Ti::Vector { + size: Vs::Tri, + scalar: crate::Scalar::U32, + }, + ), + Bi::MeshTaskSize => ( + self.stage == St::Task && self.output, + *ty_inner + == Ti::Vector { + size: Vs::Tri, + scalar: crate::Scalar::U32, + }, + ), }; if !visible { @@ -318,6 +380,7 @@ impl VaryingContext<'_> { interpolation, sampling, blend_src, + per_primitive, } => { // Only IO-shareable types may be stored in locations. if !self.type_info[ty.index()] @@ -326,6 +389,14 @@ impl VaryingContext<'_> { { return Err(VaryingError::NotIOShareableType(ty)); } + if !per_primitive && self.mesh_output_type == MeshOutputType::PrimitiveOutput { + return Err(VaryingError::MissingPerPrimitive); + } else if per_primitive + && ((self.stage != crate::ShaderStage::Fragment || self.output) + && self.mesh_output_type != MeshOutputType::PrimitiveOutput) + { + return Err(VaryingError::InvalidPerPrimitive); + } if let Some(blend_src) = blend_src { // `blend_src` is only valid if dual source blending was explicitly enabled, @@ -390,11 +461,12 @@ impl VaryingContext<'_> { } } + // TODO: update this to reflect the fact that per-primitive outputs aren't interpolated for fragment and mesh stages let needs_interpolation = match self.stage { crate::ShaderStage::Vertex => self.output, crate::ShaderStage::Fragment => !self.output, - crate::ShaderStage::Compute => false, - crate::ShaderStage::Task | crate::ShaderStage::Mesh => unreachable!(), + crate::ShaderStage::Compute | crate::ShaderStage::Task => false, + crate::ShaderStage::Mesh => self.output, }; // It doesn't make sense to specify a sampling when `interpolation` is `Flat`, but @@ -595,7 +667,9 @@ impl super::Validator { TypeFlags::CONSTRUCTIBLE | TypeFlags::CREATION_RESOLVED, false, ), - crate::AddressSpace::WorkGroup => (TypeFlags::DATA | TypeFlags::SIZED, false), + crate::AddressSpace::WorkGroup | crate::AddressSpace::TaskPayload => { + (TypeFlags::DATA | TypeFlags::SIZED, false) + } crate::AddressSpace::PushConstant => { if !self.capabilities.contains(Capabilities::PUSH_CONSTANT) { return Err(GlobalVariableError::UnsupportedCapability( @@ -671,7 +745,7 @@ impl super::Validator { } } - if ep.stage == crate::ShaderStage::Compute { + if ep.stage.compute_like() { if ep .workgroup_size .iter() @@ -683,10 +757,30 @@ impl super::Validator { return Err(EntryPointError::UnexpectedWorkgroupSize.with_span()); } + if ep.stage != crate::ShaderStage::Mesh && ep.mesh_info.is_some() { + return Err(EntryPointError::UnexpectedMeshShaderAttributes.with_span()); + } + let mut info = self .validate_function(&ep.function, module, mod_info, true) .map_err(WithSpan::into_other)?; + if let Some(handle) = ep.task_payload { + if ep.stage != crate::ShaderStage::Task && ep.stage != crate::ShaderStage::Mesh { + return Err(EntryPointError::UnexpectedTaskPayload.with_span()); + } + if module.global_variables[handle].space != crate::AddressSpace::TaskPayload { + return Err(EntryPointError::TaskPayloadWrongAddressSpace.with_span()); + } + // Make sure that this is always present in the outputted shader + let uses = if ep.stage == crate::ShaderStage::Mesh { + GlobalUse::READ + } else { + GlobalUse::READ | GlobalUse::WRITE + }; + info.insert_global_use(uses, handle); + } + { use super::ShaderStages; @@ -694,7 +788,8 @@ impl super::Validator { crate::ShaderStage::Vertex => ShaderStages::VERTEX, crate::ShaderStage::Fragment => ShaderStages::FRAGMENT, crate::ShaderStage::Compute => ShaderStages::COMPUTE, - crate::ShaderStage::Task | crate::ShaderStage::Mesh => unreachable!(), + crate::ShaderStage::Mesh => ShaderStages::MESH, + crate::ShaderStage::Task => ShaderStages::TASK, }; if !info.available_stages.contains(stage_bit) { @@ -716,6 +811,7 @@ impl super::Validator { built_ins: &mut argument_built_ins, capabilities: self.capabilities, flags: self.flags, + mesh_output_type: MeshOutputType::None, }; ctx.validate(ep, fa.ty, fa.binding.as_ref()) .map_err_inner(|e| EntryPointError::Argument(index as u32, e).with_span())?; @@ -734,6 +830,7 @@ impl super::Validator { built_ins: &mut result_built_ins, capabilities: self.capabilities, flags: self.flags, + mesh_output_type: MeshOutputType::None, }; ctx.validate(ep, fr.ty, fr.binding.as_ref()) .map_err_inner(|e| EntryPointError::Result(e).with_span())?; @@ -742,11 +839,26 @@ impl super::Validator { { return Err(EntryPointError::MissingVertexOutputPosition.with_span()); } + if ep.stage == crate::ShaderStage::Mesh + && (!result_built_ins.is_empty() || !self.location_mask.is_empty()) + { + return Err(EntryPointError::UnexpectedMeshShaderEntryResult.with_span()); + } + // Cannot have any other built-ins or @location outputs as those are per-vertex or per-primitive + if ep.stage == crate::ShaderStage::Task + && (!result_built_ins.contains(&crate::BuiltIn::MeshTaskSize) + || result_built_ins.len() != 1 + || !self.location_mask.is_empty()) + { + return Err(EntryPointError::WrongTaskShaderEntryResult.with_span()); + } if !self.blend_src_mask.is_empty() { info.dual_source_blending = true; } } else if ep.stage == crate::ShaderStage::Vertex { return Err(EntryPointError::MissingVertexOutputPosition.with_span()); + } else if ep.stage == crate::ShaderStage::Task { + return Err(EntryPointError::WrongTaskShaderEntryResult.with_span()); } { @@ -764,6 +876,13 @@ impl super::Validator { } } + if let Some(task_payload) = ep.task_payload { + if module.global_variables[task_payload].space != crate::AddressSpace::TaskPayload { + return Err(EntryPointError::TaskPayloadWrongAddressSpace + .with_span_handle(task_payload, &module.global_variables)); + } + } + self.ep_resource_bindings.clear(); for (var_handle, var) in module.global_variables.iter() { let usage = info[var_handle]; @@ -771,6 +890,13 @@ impl super::Validator { continue; } + if var.space == crate::AddressSpace::TaskPayload { + if ep.task_payload != Some(var_handle) { + return Err(EntryPointError::WrongTaskPayloadUsed + .with_span_handle(var_handle, &module.global_variables)); + } + } + let allowed_usage = match var.space { crate::AddressSpace::Function => unreachable!(), crate::AddressSpace::Uniform => GlobalUse::READ | GlobalUse::QUERY, @@ -792,6 +918,15 @@ impl super::Validator { crate::AddressSpace::Private | crate::AddressSpace::WorkGroup => { GlobalUse::READ | GlobalUse::WRITE | GlobalUse::QUERY } + crate::AddressSpace::TaskPayload => { + GlobalUse::READ + | GlobalUse::QUERY + | if ep.stage == crate::ShaderStage::Task { + GlobalUse::WRITE + } else { + GlobalUse::empty() + } + } crate::AddressSpace::PushConstant => GlobalUse::READ, }; if !allowed_usage.contains(usage) { @@ -811,6 +946,77 @@ impl super::Validator { } } + if let &Some(ref mesh_info) = &ep.mesh_info { + // Technically it is allowed to not output anything + // TODO: check that only the allowed builtins are used here + if let Some(used_vertex_type) = info.mesh_shader_info.vertex_type { + if used_vertex_type.0 != mesh_info.vertex_output_type { + return Err(EntryPointError::WrongMeshOutputType + .with_span_handle(mesh_info.vertex_output_type, &module.types)); + } + } + if let Some(used_primitive_type) = info.mesh_shader_info.primitive_type { + if used_primitive_type.0 != mesh_info.primitive_output_type { + return Err(EntryPointError::WrongMeshOutputType + .with_span_handle(mesh_info.primitive_output_type, &module.types)); + } + } + + for (ty, mesh_output_type) in [ + (mesh_info.vertex_output_type, MeshOutputType::VertexOutput), + ( + mesh_info.primitive_output_type, + MeshOutputType::PrimitiveOutput, + ), + ] { + if !matches!(module.types[ty].inner, crate::TypeInner::Struct { .. }) { + return Err( + EntryPointError::InvalidMeshOutputType.with_span_handle(ty, &module.types) + ); + } + let mut result_built_ins = crate::FastHashSet::default(); + let mut ctx = VaryingContext { + stage: ep.stage, + output: true, + types: &module.types, + type_info: &self.types, + location_mask: &mut self.location_mask, + blend_src_mask: &mut self.blend_src_mask, + built_ins: &mut result_built_ins, + capabilities: self.capabilities, + flags: self.flags, + mesh_output_type, + }; + ctx.validate(ep, ty, None) + .map_err_inner(|e| EntryPointError::Result(e).with_span())?; + if mesh_output_type == MeshOutputType::PrimitiveOutput { + let mut num_indices_builtins = 0; + if result_built_ins.contains(&crate::BuiltIn::PointIndex) { + num_indices_builtins += 1; + } + if result_built_ins.contains(&crate::BuiltIn::LineIndices) { + num_indices_builtins += 1; + } + if result_built_ins.contains(&crate::BuiltIn::TriangleIndices) { + num_indices_builtins += 1; + } + if num_indices_builtins != 1 { + return Err(EntryPointError::InvalidMeshPrimitiveOutputType + .with_span_handle(ty, &module.types)); + } + } else if mesh_output_type == MeshOutputType::VertexOutput + && !result_built_ins.contains(&crate::BuiltIn::Position { invariant: false }) + { + return Err(EntryPointError::MissingVertexOutputPosition + .with_span_handle(ty, &module.types)); + } + } + } else if info.mesh_shader_info.vertex_type.is_some() + || info.mesh_shader_info.primitive_type.is_some() + { + return Err(EntryPointError::UnexpectedMeshShaderOutput.with_span()); + } + Ok(info) } } diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index fe45d3bfb07..babea985244 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -240,6 +240,8 @@ bitflags::bitflags! { const VERTEX = 0x1; const FRAGMENT = 0x2; const COMPUTE = 0x4; + const MESH = 0x8; + const TASK = 0x10; } } diff --git a/naga/src/valid/type.rs b/naga/src/valid/type.rs index e8b83ff08f3..aa0633e1852 100644 --- a/naga/src/valid/type.rs +++ b/naga/src/valid/type.rs @@ -220,9 +220,12 @@ const fn ptr_space_argument_flag(space: crate::AddressSpace) -> TypeFlags { use crate::AddressSpace as As; match space { As::Function | As::Private => TypeFlags::ARGUMENT, - As::Uniform | As::Storage { .. } | As::Handle | As::PushConstant | As::WorkGroup => { - TypeFlags::empty() - } + As::Uniform + | As::Storage { .. } + | As::Handle + | As::PushConstant + | As::WorkGroup + | As::TaskPayload => TypeFlags::empty(), } } diff --git a/naga/tests/in/wgsl/mesh-shader.toml b/naga/tests/in/wgsl/mesh-shader.toml new file mode 100644 index 00000000000..1f02c781b5d --- /dev/null +++ b/naga/tests/in/wgsl/mesh-shader.toml @@ -0,0 +1,19 @@ +# Stolen from ray-query.toml + +god_mode = true +targets = "SPIRV | IR | ANALYSIS" + +[msl] +fake_missing_bindings = true +lang_version = [2, 4] +spirv_cross_compatibility = false +zero_initialize_workgroup_memory = false + +[hlsl] +shader_model = "V6_5" +fake_missing_bindings = true +zero_initialize_workgroup_memory = true + +[spv] +version = [1, 4] +capabilities = ["MeshShadingEXT"] diff --git a/naga/tests/in/wgsl/mesh-shader.wgsl b/naga/tests/in/wgsl/mesh-shader.wgsl new file mode 100644 index 00000000000..70fc2aec333 --- /dev/null +++ b/naga/tests/in/wgsl/mesh-shader.wgsl @@ -0,0 +1,71 @@ +enable mesh_shading; + +const positions = array( + vec4(0.,1.,0.,1.), + vec4(-1.,-1.,0.,1.), + vec4(1.,-1.,0.,1.) +); +const colors = array( + vec4(0.,1.,0.,1.), + vec4(0.,0.,1.,1.), + vec4(1.,0.,0.,1.) +); +struct TaskPayload { + colorMask: vec4, + visible: bool, +} +var taskPayload: TaskPayload; +var workgroupData: f32; +struct VertexOutput { + @builtin(position) position: vec4, + @location(0) color: vec4, +} +struct PrimitiveOutput { + @builtin(triangle_indices) index: vec3, + @builtin(cull_primitive) cull: bool, + @per_primitive @location(1) colorMask: vec4, +} +struct PrimitiveInput { + @per_primitive @location(1) colorMask: vec4, +} + +@task +@payload(taskPayload) +@workgroup_size(1) +fn ts_main() -> @builtin(mesh_task_size) vec3 { + workgroupData = 1.0; + taskPayload.colorMask = vec4(1.0, 1.0, 0.0, 1.0); + taskPayload.visible = true; + return vec3(3, 1, 1); +} +@mesh +@payload(taskPayload) +@vertex_output(VertexOutput, 3) @primitive_output(PrimitiveOutput, 1) +@workgroup_size(1) +fn ms_main(@builtin(local_invocation_index) index: u32, @builtin(global_invocation_id) id: vec3) { + setMeshOutputs(3, 1); + workgroupData = 2.0; + var v: VertexOutput; + + v.position = positions[0]; + v.color = colors[0] * taskPayload.colorMask; + setVertex(0, v); + + v.position = positions[1]; + v.color = colors[1] * taskPayload.colorMask; + setVertex(1, v); + + v.position = positions[2]; + v.color = colors[2] * taskPayload.colorMask; + setVertex(2, v); + + var p: PrimitiveOutput; + p.index = vec3(0, 1, 2); + p.cull = !taskPayload.visible; + p.colorMask = vec4(1.0, 0.0, 1.0, 1.0); + setPrimitive(0, p); +} +@fragment +fn fs_main(vertex: VertexOutput, primitive: PrimitiveInput) -> @location(0) vec4 { + return vertex.color * primitive.colorMask; +} diff --git a/naga/tests/naga/snapshots.rs b/naga/tests/naga/snapshots.rs index e2288eee918..f08dbcb59dd 100644 --- a/naga/tests/naga/snapshots.rs +++ b/naga/tests/naga/snapshots.rs @@ -1,389 +1,12 @@ -// A lot of the code can be unused based on configuration flags, -// the corresponding warnings aren't helpful. -#![allow(dead_code, unused_imports)] - -use core::fmt::Write; - -use std::{ - fs, - path::{Path, PathBuf}, -}; - use naga::compact::KeepUnused; -use ron::de; - -const CRATE_ROOT: &str = env!("CARGO_MANIFEST_DIR"); -const BASE_DIR_IN: &str = "tests/in"; -const BASE_DIR_OUT: &str = "tests/out"; - -bitflags::bitflags! { - #[derive(Clone, Copy, serde::Deserialize)] - #[serde(transparent)] - #[derive(Debug, Eq, PartialEq)] - struct Targets: u32 { - /// A serialization of the `naga::Module`, in RON format. - const IR = 1; - - /// A serialization of the `naga::valid::ModuleInfo`, in RON format. - const ANALYSIS = 1 << 1; - - const SPIRV = 1 << 2; - const METAL = 1 << 3; - const GLSL = 1 << 4; - const DOT = 1 << 5; - const HLSL = 1 << 6; - const WGSL = 1 << 7; - const NO_VALIDATION = 1 << 8; - } -} - -impl Targets { - /// Defaults for `spv` and `glsl` snapshots. - fn non_wgsl_default() -> Self { - Targets::WGSL - } - - /// Defaults for `wgsl` snapshots. - fn wgsl_default() -> Self { - Targets::HLSL | Targets::SPIRV | Targets::GLSL | Targets::METAL | Targets::WGSL - } -} - -#[derive(serde::Deserialize)] -struct SpvOutVersion(u8, u8); -impl Default for SpvOutVersion { - fn default() -> Self { - SpvOutVersion(1, 1) - } -} - -#[cfg(all(feature = "deserialize", spv_out))] -#[derive(serde::Deserialize)] -struct BindingMapSerialization { - resource_binding: naga::ResourceBinding, - bind_target: naga::back::spv::BindingInfo, -} - -#[cfg(all(feature = "deserialize", spv_out))] -fn deserialize_binding_map<'de, D>(deserializer: D) -> Result -where - D: serde::Deserializer<'de>, -{ - use serde::Deserialize; - - let vec = Vec::::deserialize(deserializer)?; - let mut map = naga::back::spv::BindingMap::default(); - for item in vec { - map.insert(item.resource_binding, item.bind_target); - } - Ok(map) -} - -#[derive(Default, serde::Deserialize)] -#[serde(default)] -struct WgslInParameters { - parse_doc_comments: bool, -} - -#[derive(Default, serde::Deserialize)] -#[serde(default)] -struct SpirvInParameters { - adjust_coordinate_space: bool, -} - -#[derive(Default, serde::Deserialize)] -#[serde(default)] -struct SpirvOutParameters { - version: SpvOutVersion, - capabilities: naga::FastHashSet, - debug: bool, - adjust_coordinate_space: bool, - force_point_size: bool, - clamp_frag_depth: bool, - separate_entry_points: bool, - #[cfg(all(feature = "deserialize", spv_out))] - #[serde(deserialize_with = "deserialize_binding_map")] - binding_map: naga::back::spv::BindingMap, -} - -#[derive(Default, serde::Deserialize)] -#[serde(default)] -struct WgslOutParameters { - explicit_types: bool, -} - -#[derive(Default, serde::Deserialize)] -struct FragmentModule { - path: String, - entry_point: String, -} - -#[derive(Default, serde::Deserialize)] -#[serde(default)] -struct Parameters { - // -- GOD MODE -- - god_mode: bool, - - // -- wgsl-in options -- - #[serde(rename = "wgsl-in")] - wgsl_in: WgslInParameters, - - // -- spirv-in options -- - #[serde(rename = "spv-in")] - spv_in: SpirvInParameters, - - // -- SPIR-V options -- - spv: SpirvOutParameters, - - /// Defaults to [`Targets::non_wgsl_default()`] for `spv` and `glsl` snapshots, - /// and [`Targets::wgsl_default()`] for `wgsl` snapshots. - targets: Option, - - // -- MSL options -- - #[cfg(all(feature = "deserialize", msl_out))] - msl: naga::back::msl::Options, - #[cfg(all(feature = "deserialize", msl_out))] - #[serde(default)] - msl_pipeline: naga::back::msl::PipelineOptions, - - // -- GLSL options -- - #[cfg(all(feature = "deserialize", glsl_out))] - glsl: naga::back::glsl::Options, - glsl_exclude_list: naga::FastHashSet, - #[cfg(all(feature = "deserialize", glsl_out))] - glsl_multiview: Option, - - // -- HLSL options -- - #[cfg(all(feature = "deserialize", hlsl_out))] - hlsl: naga::back::hlsl::Options, - - // -- WGSL options -- - wgsl: WgslOutParameters, - - // -- General options -- - - // Allow backends to be aware of the fragment module. - // Is the name of a WGSL file in the same directory as the test file. - fragment_module: Option, - - #[cfg(feature = "deserialize")] - bounds_check_policies: naga::proc::BoundsCheckPolicies, - - #[cfg(all(feature = "deserialize", any(hlsl_out, msl_out, spv_out, glsl_out)))] - pipeline_constants: naga::back::PipelineConstants, -} - -/// Information about a shader input file. -#[derive(Debug)] -struct Input { - /// The subdirectory of `tests/in` to which this input belongs, if any. - /// - /// If the subdirectory is omitted, we assume that the output goes - /// to "wgsl". - subdirectory: PathBuf, - - /// The input filename name, without a directory. - file_name: PathBuf, - - /// True if output filenames should add the output extension on top of - /// `file_name`'s existing extension, rather than replacing it. - /// - /// This is used by `convert_snapshots_glsl`, which wants to take input files - /// like `210-bevy-2d-shader.frag` and just add `.wgsl` to it, producing - /// `210-bevy-2d-shader.frag.wgsl`. - keep_input_extension: bool, -} - -impl Input { - /// Read an input file and its corresponding parameters file. - /// - /// Given `input`, the relative path of a shader input file, return - /// a `Source` value containing its path, code, and parameters. - /// - /// The `input` path is interpreted relative to the `BASE_DIR_IN` - /// subdirectory of the directory given by the `CARGO_MANIFEST_DIR` - /// environment variable. - fn new(subdirectory: &str, name: &str, extension: &str) -> Input { - Input { - subdirectory: PathBuf::from(subdirectory), - // Don't wipe out any extensions on `name`, as - // `with_extension` would do. - file_name: PathBuf::from(format!("{name}.{extension}")), - keep_input_extension: false, - } - } - - /// Return an iterator that produces an `Input` for each entry in `subdirectory`. - fn files_in_dir( - subdirectory: &'static str, - file_extensions: &'static [&'static str], - ) -> impl Iterator + 'static { - let input_directory = Path::new(CRATE_ROOT).join(BASE_DIR_IN).join(subdirectory); - - let entries = match std::fs::read_dir(&input_directory) { - Ok(entries) => entries, - Err(err) => panic!( - "Error opening directory '{}': {}", - input_directory.display(), - err - ), - }; - - entries.filter_map(move |result| { - let entry = result.expect("error reading directory"); - if !entry.file_type().unwrap().is_file() { - return None; - } +use wgpu_test::naga::*; - let file_name = PathBuf::from(entry.file_name()); - let extension = file_name - .extension() - .expect("all files in snapshot input directory should have extensions"); - - if !file_extensions.contains(&extension.to_str().unwrap()) { - return None; - } - - if let Ok(pat) = std::env::var("NAGA_SNAPSHOT") { - if !file_name.to_string_lossy().contains(&pat) { - return None; - } - } - - let input = Input::new( - subdirectory, - file_name.file_stem().unwrap().to_str().unwrap(), - extension.to_str().unwrap(), - ); - Some(input) - }) - } - - /// Return the path to the input directory. - fn input_directory(&self) -> PathBuf { - let mut dir = Path::new(CRATE_ROOT).join(BASE_DIR_IN); - dir.push(&self.subdirectory); - dir - } - - /// Return the path to the output directory. - fn output_directory(subdirectory: &str) -> PathBuf { - let mut dir = Path::new(CRATE_ROOT).join(BASE_DIR_OUT); - dir.push(subdirectory); - dir - } - - /// Return the path to the input file. - fn input_path(&self) -> PathBuf { - let mut input = self.input_directory(); - input.push(&self.file_name); - input - } - - fn output_path(&self, subdirectory: &str, extension: &str) -> PathBuf { - let mut output = Self::output_directory(subdirectory); - if self.keep_input_extension { - let file_name = format!( - "{}-{}.{}", - self.subdirectory.display(), - self.file_name.display(), - extension - ); - - output.push(&file_name); - } else { - let file_name = format!( - "{}-{}", - self.subdirectory.display(), - self.file_name.display() - ); - - output.push(&file_name); - output.set_extension(extension); - } - output - } - - /// Return the contents of the input file as a string. - fn read_source(&self) -> String { - println!("Processing '{}'", self.file_name.display()); - let input_path = self.input_path(); - match fs::read_to_string(&input_path) { - Ok(source) => source, - Err(err) => { - panic!( - "Couldn't read shader input file `{}`: {}", - input_path.display(), - err - ); - } - } - } - - /// Return the contents of the input file as a vector of bytes. - fn read_bytes(&self) -> Vec { - println!("Processing '{}'", self.file_name.display()); - let input_path = self.input_path(); - match fs::read(&input_path) { - Ok(bytes) => bytes, - Err(err) => { - panic!( - "Couldn't read shader input file `{}`: {}", - input_path.display(), - err - ); - } - } - } - - /// Return this input's parameter file, parsed. - fn read_parameters(&self) -> Parameters { - let mut param_path = self.input_path(); - param_path.set_extension("toml"); - let mut params = match fs::read_to_string(¶m_path) { - Ok(string) => match toml::de::from_str(&string) { - Ok(params) => params, - Err(e) => panic!( - "Couldn't parse param file: {} due to: {e}", - param_path.display() - ), - }, - Err(_) => Parameters::default(), - }; - - if params.targets.is_none() { - match self.input_path().extension().unwrap().to_str().unwrap() { - "wgsl" => params.targets = Some(Targets::wgsl_default()), - "spvasm" => params.targets = Some(Targets::non_wgsl_default()), - "vert" | "frag" | "comp" => params.targets = Some(Targets::non_wgsl_default()), - e => { - panic!("Unknown extension: {e}"); - } - } - } - - params - } - - /// Write `data` to a file corresponding to this input file in - /// `subdirectory`, with `extension`. - fn write_output_file(&self, subdirectory: &str, extension: &str, data: impl AsRef<[u8]>) { - let output_path = self.output_path(subdirectory, extension); - fs::create_dir_all(output_path.parent().unwrap()).unwrap(); - if let Err(err) = fs::write(&output_path, data) { - panic!("Error writing {}: {}", output_path.display(), err); - } - } -} - -#[cfg(hlsl_out)] -type FragmentEntryPoint<'a> = naga::back::hlsl::FragmentEntryPoint<'a>; -#[cfg(not(hlsl_out))] -type FragmentEntryPoint<'a> = (); +const DIR_IN: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/in"); +const DIR_OUT: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/tests/out"); #[allow(unused_variables)] fn check_targets(input: &Input, module: &mut naga::Module, source_code: Option<&str>) { - let params = input.read_parameters(); + let params = input.read_parameters(DIR_IN); let name = &input.file_name; let targets = params.targets.unwrap(); @@ -402,12 +25,11 @@ fn check_targets(input: &Input, module: &mut naga::Module, source_code: Option<& ) }; - #[cfg(feature = "serialize")] { if targets.contains(Targets::IR) { let config = ron::ser::PrettyConfig::default().new_line("\n".to_string()); let string = ron::ser::to_string_pretty(module, config).unwrap(); - input.write_output_file("ir", "ron", string); + input.write_output_file("ir", "ron", string, DIR_OUT); } } @@ -438,12 +60,11 @@ fn check_targets(input: &Input, module: &mut naga::Module, source_code: Option<& // snapshots makes the output independent of unused arena entries. naga::compact::compact(module, KeepUnused::No); - #[cfg(feature = "serialize")] { if targets.contains(Targets::IR) { let config = ron::ser::PrettyConfig::default().new_line("\n".to_string()); let string = ron::ser::to_string_pretty(module, config).unwrap(); - input.write_output_file("ir", "compact.ron", string); + input.write_output_file("ir", "compact.ron", string, DIR_OUT); } } @@ -460,16 +81,15 @@ fn check_targets(input: &Input, module: &mut naga::Module, source_code: Option<& }) }; - #[cfg(feature = "serialize")] { if targets.contains(Targets::ANALYSIS) { let config = ron::ser::PrettyConfig::default().new_line("\n".to_string()); let string = ron::ser::to_string_pretty(&info, config).unwrap(); - input.write_output_file("analysis", "info.ron", string); + input.write_output_file("analysis", "info.ron", string, DIR_OUT); } } - #[cfg(all(feature = "deserialize", spv_out))] + #[cfg(feature = "spv-out")] { if targets.contains(Targets::SPIRV) { let mut debug_info = None; @@ -495,7 +115,7 @@ fn check_targets(input: &Input, module: &mut naga::Module, source_code: Option<& ); } } - #[cfg(all(feature = "deserialize", msl_out))] + #[cfg(feature = "msl-out")] { if targets.contains(Targets::METAL) { write_output_msl( @@ -509,7 +129,7 @@ fn check_targets(input: &Input, module: &mut naga::Module, source_code: Option<& ); } } - #[cfg(all(feature = "deserialize", glsl_out))] + #[cfg(feature = "glsl-out")] { if targets.contains(Targets::GLSL) { for ep in module.entry_points.iter() { @@ -530,20 +150,20 @@ fn check_targets(input: &Input, module: &mut naga::Module, source_code: Option<& } } } - #[cfg(dot_out)] + #[cfg(feature = "dot-out")] { if targets.contains(Targets::DOT) { let string = naga::back::dot::write(module, Some(&info), Default::default()).unwrap(); - input.write_output_file("dot", "dot", string); + input.write_output_file("dot", "dot", string, DIR_OUT); } } - #[cfg(all(feature = "deserialize", hlsl_out))] + #[cfg(feature = "hlsl-out")] { if targets.contains(Targets::HLSL) { let frag_module; let mut frag_ep = None; if let Some(ref module_spec) = params.fragment_module { - let full_path = input.input_directory().join(&module_spec.path); + let full_path = input.input_directory(DIR_IN).join(&module_spec.path); assert_eq!( full_path.extension().unwrap().to_string_lossy(), @@ -551,7 +171,7 @@ fn check_targets(input: &Input, module: &mut naga::Module, source_code: Option<& "Currently all fragment modules must be in WGSL" ); - let frag_src = fs::read_to_string(full_path).unwrap(); + let frag_src = std::fs::read_to_string(full_path).unwrap(); frag_module = naga::front::wgsl::parse_str(&frag_src) .expect("Failed to parse fragment module"); @@ -575,7 +195,7 @@ fn check_targets(input: &Input, module: &mut naga::Module, source_code: Option<& ); } } - #[cfg(all(feature = "deserialize", wgsl_out))] + #[cfg(feature = "wgsl-out")] { if targets.contains(Targets::WGSL) { write_output_wgsl(input, module, &info, ¶ms.wgsl); @@ -583,7 +203,7 @@ fn check_targets(input: &Input, module: &mut naga::Module, source_code: Option<& } } -#[cfg(spv_out)] +#[cfg(feature = "spv-out")] fn write_output_spv( input: &Input, module: &naga::Module, @@ -594,31 +214,8 @@ fn write_output_spv( pipeline_constants: &naga::back::PipelineConstants, ) { use naga::back::spv; - use rspirv::binary::Disassemble; - let mut flags = spv::WriterFlags::LABEL_VARYINGS; - flags.set(spv::WriterFlags::DEBUG, params.debug); - flags.set( - spv::WriterFlags::ADJUST_COORDINATE_SPACE, - params.adjust_coordinate_space, - ); - flags.set(spv::WriterFlags::FORCE_POINT_SIZE, params.force_point_size); - flags.set(spv::WriterFlags::CLAMP_FRAG_DEPTH, params.clamp_frag_depth); - - let options = spv::Options { - lang_version: (params.version.0, params.version.1), - flags, - capabilities: if params.capabilities.is_empty() { - None - } else { - Some(params.capabilities.clone()) - }, - bounds_check_policies, - binding_map: params.binding_map.clone(), - zero_initialize_workgroup_memory: spv::ZeroInitializeWorkgroupMemoryMode::Polyfill, - force_loop_bounding: true, - debug_info, - }; + let options = params.to_options(bounds_check_policies, debug_info); let (module, info) = naga::back::pipeline_constants::process_overrides(module, info, None, pipeline_constants) @@ -644,7 +241,7 @@ fn write_output_spv( } } -#[cfg(spv_out)] +#[cfg(feature = "spv-out")] fn write_output_spv_inner( input: &Input, module: &naga::Module, @@ -667,10 +264,10 @@ fn write_output_spv_inner( } else { dis }; - input.write_output_file("spv", extension, dis); + input.write_output_file("spv", extension, dis, DIR_OUT); } -#[cfg(msl_out)] +#[cfg(feature = "msl-out")] fn write_output_msl( input: &Input, module: &naga::Module, @@ -699,10 +296,10 @@ fn write_output_msl( } } - input.write_output_file("msl", "msl", string); + input.write_output_file("msl", "msl", string, DIR_OUT); } -#[cfg(glsl_out)] +#[cfg(feature = "glsl-out")] #[allow(clippy::too_many_arguments)] fn write_output_glsl( input: &Input, @@ -741,10 +338,10 @@ fn write_output_glsl( writer.write().expect("GLSL write failed"); let extension = format!("{ep_name}.{stage:?}.glsl"); - input.write_output_file("glsl", &extension, buffer); + input.write_output_file("glsl", &extension, buffer, DIR_OUT); } -#[cfg(hlsl_out)] +#[cfg(feature = "hlsl-out")] fn write_output_hlsl( input: &Input, module: &naga::Module, @@ -753,7 +350,6 @@ fn write_output_hlsl( pipeline_constants: &naga::back::PipelineConstants, frag_ep: Option, ) { - use core::fmt::Write as _; use naga::back::hlsl; println!("generating HLSL"); @@ -769,7 +365,7 @@ fn write_output_hlsl( .write(&module, &info, frag_ep.as_ref()) .expect("HLSL write failed"); - input.write_output_file("hlsl", "hlsl", buffer); + input.write_output_file("hlsl", "hlsl", buffer, DIR_OUT); // We need a config file for validation script // This file contains an info about profiles (shader stages) contains inside generated shader @@ -796,10 +392,12 @@ fn write_output_hlsl( }); } - config.to_file(input.output_path("hlsl", "ron")).unwrap(); + config + .to_file(input.output_path("hlsl", "ron", DIR_OUT)) + .unwrap(); } -#[cfg(wgsl_out)] +#[cfg(feature = "wgsl-out")] fn write_output_wgsl( input: &Input, module: &naga::Module, @@ -810,12 +408,9 @@ fn write_output_wgsl( println!("generating WGSL"); - let mut flags = wgsl::WriterFlags::empty(); - flags.set(wgsl::WriterFlags::EXPLICIT_TYPES, params.explicit_types); - - let string = wgsl::write_string(module, info, flags).expect("WGSL write failed"); + let string = wgsl::write_string(module, info, params.into()).expect("WGSL write failed"); - input.write_output_file("wgsl", "wgsl", string); + input.write_output_file("wgsl", "wgsl", string, DIR_OUT); } // While we _can_ run this test under miri, it is extremely slow (>5 minutes), @@ -826,21 +421,19 @@ fn write_output_wgsl( fn convert_snapshots_wgsl() { let _ = env_logger::try_init(); - for input in Input::files_in_dir("wgsl", &["wgsl"]) { - let source = input.read_source(); + for input in Input::files_in_dir("wgsl", &["wgsl"], DIR_IN) { + let source = input.read_source(DIR_IN, true); // crlf will make the large split output different on different platform let source = source.replace('\r', ""); - let params = input.read_parameters(); - let WgslInParameters { parse_doc_comments } = params.wgsl_in; + let params = input.read_parameters(DIR_IN); - let options = naga::front::wgsl::Options { parse_doc_comments }; - let mut frontend = naga::front::wgsl::Frontend::new_with_options(options); + let mut frontend = naga::front::wgsl::Frontend::new_with_options((¶ms.wgsl_in).into()); match frontend.parse(&source) { Ok(mut module) => check_targets(&input, &mut module, Some(&source)), Err(e) => panic!( "{}", - e.emit_to_string_with_path(&source, input.input_path()) + e.emit_to_string_with_path(&source, input.input_path(DIR_IN)) ), } } @@ -855,11 +448,11 @@ fn convert_snapshots_spv() { let _ = env_logger::try_init(); - for input in Input::files_in_dir("spv", &["spvasm"]) { + for input in Input::files_in_dir("spv", &["spvasm"], DIR_IN) { println!("Assembling '{}'", input.file_name.display()); let command = Command::new("spirv-as") - .arg(input.input_path()) + .arg(input.input_path(DIR_IN)) .arg("-o") .arg("-") .output() @@ -878,20 +471,10 @@ fn convert_snapshots_spv() { ); } - let params = input.read_parameters(); - let SpirvInParameters { - adjust_coordinate_space, - } = params.spv_in; - - let mut module = naga::front::spv::parse_u8_slice( - &command.stdout, - &naga::front::spv::Options { - adjust_coordinate_space, - strict_capabilities: true, - ..Default::default() - }, - ) - .unwrap(); + let params = input.read_parameters(DIR_IN); + + let mut module = + naga::front::spv::parse_u8_slice(&command.stdout, &(¶ms.spv_in).into()).unwrap(); check_targets(&input, &mut module, None); } @@ -906,7 +489,7 @@ fn convert_snapshots_spv() { fn convert_snapshots_glsl() { let _ = env_logger::try_init(); - for input in Input::files_in_dir("glsl", &["vert", "frag", "comp"]) { + for input in Input::files_in_dir("glsl", &["vert", "frag", "comp"], DIR_IN) { let input = Input { keep_input_extension: true, ..input @@ -927,7 +510,7 @@ fn convert_snapshots_glsl() { stage, defines: Default::default(), }, - &input.read_source(), + &input.read_source(DIR_IN, true), ) .unwrap(); diff --git a/naga/tests/out/analysis/spv-shadow.info.ron b/naga/tests/out/analysis/spv-shadow.info.ron index 6ddda61f5c6..b08a28438ed 100644 --- a/naga/tests/out/analysis/spv-shadow.info.ron +++ b/naga/tests/out/analysis/spv-shadow.info.ron @@ -18,7 +18,7 @@ functions: [ ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(1), requirements: (""), @@ -413,10 +413,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(1), requirements: (""), @@ -1591,12 +1595,16 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ], entry_points: [ ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(1), requirements: (""), @@ -1685,6 +1693,10 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ], const_expression_types: [ diff --git a/naga/tests/out/analysis/wgsl-access.info.ron b/naga/tests/out/analysis/wgsl-access.info.ron index 319f62bdf13..d297b09a404 100644 --- a/naga/tests/out/analysis/wgsl-access.info.ron +++ b/naga/tests/out/analysis/wgsl-access.info.ron @@ -42,7 +42,7 @@ functions: [ ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -1197,10 +1197,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -2523,10 +2527,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(0), requirements: (""), @@ -2563,10 +2571,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(0), requirements: (""), @@ -2612,10 +2624,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -2655,10 +2671,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -2749,10 +2769,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -2870,10 +2894,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(0), requirements: (""), @@ -2922,10 +2950,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -2977,10 +3009,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(0), requirements: (""), @@ -3029,10 +3065,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -3084,10 +3124,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(0), requirements: (""), @@ -3148,10 +3192,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(2), requirements: (""), @@ -3221,10 +3269,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(2), requirements: (""), @@ -3297,10 +3349,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -3397,10 +3453,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(1), requirements: (""), @@ -3593,12 +3653,16 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ], entry_points: [ ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(0), requirements: (""), @@ -4290,10 +4354,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -4742,10 +4810,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(0), requirements: (""), @@ -4812,6 +4884,10 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ], const_expression_types: [ diff --git a/naga/tests/out/analysis/wgsl-collatz.info.ron b/naga/tests/out/analysis/wgsl-collatz.info.ron index 7ec5799d758..2796f544510 100644 --- a/naga/tests/out/analysis/wgsl-collatz.info.ron +++ b/naga/tests/out/analysis/wgsl-collatz.info.ron @@ -8,7 +8,7 @@ functions: [ ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(3), requirements: (""), @@ -275,12 +275,16 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ], entry_points: [ ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(3), requirements: (""), @@ -430,6 +434,10 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ], const_expression_types: [], diff --git a/naga/tests/out/analysis/wgsl-mesh-shader.info.ron b/naga/tests/out/analysis/wgsl-mesh-shader.info.ron new file mode 100644 index 00000000000..208e0aac84e --- /dev/null +++ b/naga/tests/out/analysis/wgsl-mesh-shader.info.ron @@ -0,0 +1,1211 @@ +( + type_flags: [ + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ], + functions: [], + entry_points: [ + ( + flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + may_kill: false, + sampling_set: [], + global_uses: [ + ("READ | WRITE"), + ("WRITE"), + ], + expressions: [ + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(1), + ty: Value(Pointer( + base: 0, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 3, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 1, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 3, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 2, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Bool, + width: 1, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(6), + ), + ], + sampling: [], + dual_source_blending: false, + diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), + ), + ( + flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + may_kill: false, + sampling_set: [], + global_uses: [ + ("READ"), + ("WRITE"), + ], + expressions: [ + ( + uniformity: ( + non_uniform_result: Some(0), + requirements: (""), + ), + ref_count: 0, + assignable_global: None, + ty: Handle(5), + ), + ( + uniformity: ( + non_uniform_result: Some(1), + requirements: (""), + ), + ref_count: 0, + assignable_global: None, + ty: Handle(6), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(1), + ty: Value(Pointer( + base: 0, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: Some(6), + requirements: (""), + ), + ref_count: 9, + assignable_global: None, + ty: Value(Pointer( + base: 4, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: Some(6), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 1, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: Some(6), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 1, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 3, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 1, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: Some(6), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(4), + ), + ( + uniformity: ( + non_uniform_result: Some(6), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 1, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: Some(6), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 1, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 3, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 1, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: Some(6), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(4), + ), + ( + uniformity: ( + non_uniform_result: Some(6), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 1, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: Some(6), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 1, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 3, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 1, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: Some(6), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(4), + ), + ( + uniformity: ( + non_uniform_result: Some(61), + requirements: (""), + ), + ref_count: 4, + assignable_global: None, + ty: Value(Pointer( + base: 7, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: Some(61), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 6, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(6), + ), + ( + uniformity: ( + non_uniform_result: Some(61), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 2, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 3, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 2, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(2), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(2), + ), + ( + uniformity: ( + non_uniform_result: Some(61), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 1, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: Some(61), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(7), + ), + ], + sampling: [], + dual_source_blending: false, + diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: Some((4, 24)), + primitive_type: Some((7, 79)), + ), + ), + ( + flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), + uniformity: ( + non_uniform_result: Some(0), + requirements: (""), + ), + may_kill: false, + sampling_set: [], + global_uses: [ + (""), + (""), + ], + expressions: [ + ( + uniformity: ( + non_uniform_result: Some(0), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(4), + ), + ( + uniformity: ( + non_uniform_result: Some(1), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(8), + ), + ( + uniformity: ( + non_uniform_result: Some(0), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: Some(1), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: Some(0), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ], + sampling: [], + dual_source_blending: false, + diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), + ), + ], + const_expression_types: [], +) \ No newline at end of file diff --git a/naga/tests/out/analysis/wgsl-overrides.info.ron b/naga/tests/out/analysis/wgsl-overrides.info.ron index 0e0ae318042..a76c9c89c9b 100644 --- a/naga/tests/out/analysis/wgsl-overrides.info.ron +++ b/naga/tests/out/analysis/wgsl-overrides.info.ron @@ -8,7 +8,7 @@ entry_points: [ ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -201,6 +201,10 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ], const_expression_types: [ diff --git a/naga/tests/out/analysis/wgsl-storage-textures.info.ron b/naga/tests/out/analysis/wgsl-storage-textures.info.ron index fbbf7206c33..35b5a7e320c 100644 --- a/naga/tests/out/analysis/wgsl-storage-textures.info.ron +++ b/naga/tests/out/analysis/wgsl-storage-textures.info.ron @@ -11,7 +11,7 @@ entry_points: [ ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -184,10 +184,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -396,6 +400,10 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ], const_expression_types: [], diff --git a/naga/tests/out/ir/spv-fetch_depth.compact.ron b/naga/tests/out/ir/spv-fetch_depth.compact.ron index 1fbee2deb35..98f4426c3eb 100644 --- a/naga/tests/out/ir/spv-fetch_depth.compact.ron +++ b/naga/tests/out/ir/spv-fetch_depth.compact.ron @@ -196,6 +196,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/spv-fetch_depth.ron b/naga/tests/out/ir/spv-fetch_depth.ron index 186f78354ad..104de852c17 100644 --- a/naga/tests/out/ir/spv-fetch_depth.ron +++ b/naga/tests/out/ir/spv-fetch_depth.ron @@ -266,6 +266,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/spv-shadow.compact.ron b/naga/tests/out/ir/spv-shadow.compact.ron index b49cd9b55be..bed86a5334d 100644 --- a/naga/tests/out/ir/spv-shadow.compact.ron +++ b/naga/tests/out/ir/spv-shadow.compact.ron @@ -974,6 +974,7 @@ interpolation: Some(Perspective), sampling: Some(Center), blend_src: None, + per_primitive: false, )), ), ( @@ -984,6 +985,7 @@ interpolation: Some(Perspective), sampling: Some(Center), blend_src: None, + per_primitive: false, )), ), ], @@ -994,6 +996,7 @@ interpolation: None, sampling: None, blend_src: None, + per_primitive: false, )), )), local_variables: [], @@ -1032,6 +1035,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/spv-shadow.ron b/naga/tests/out/ir/spv-shadow.ron index e1f0f60b6bb..bdda1d18566 100644 --- a/naga/tests/out/ir/spv-shadow.ron +++ b/naga/tests/out/ir/spv-shadow.ron @@ -1252,6 +1252,7 @@ interpolation: Some(Perspective), sampling: Some(Center), blend_src: None, + per_primitive: false, )), ), ( @@ -1262,6 +1263,7 @@ interpolation: Some(Perspective), sampling: Some(Center), blend_src: None, + per_primitive: false, )), ), ], @@ -1272,6 +1274,7 @@ interpolation: None, sampling: None, blend_src: None, + per_primitive: false, )), )), local_variables: [], @@ -1310,6 +1313,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/spv-spec-constants.compact.ron b/naga/tests/out/ir/spv-spec-constants.compact.ron index 3fa6ffef4ff..67eb29c2475 100644 --- a/naga/tests/out/ir/spv-spec-constants.compact.ron +++ b/naga/tests/out/ir/spv-spec-constants.compact.ron @@ -151,6 +151,7 @@ interpolation: Some(Perspective), sampling: Some(Center), blend_src: None, + per_primitive: false, )), offset: 0, ), @@ -510,6 +511,7 @@ interpolation: None, sampling: None, blend_src: None, + per_primitive: false, )), ), ( @@ -520,6 +522,7 @@ interpolation: None, sampling: None, blend_src: None, + per_primitive: false, )), ), ( @@ -530,6 +533,7 @@ interpolation: None, sampling: None, blend_src: None, + per_primitive: false, )), ), ], @@ -613,6 +617,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/spv-spec-constants.ron b/naga/tests/out/ir/spv-spec-constants.ron index 94c90aa78f9..51686aa20eb 100644 --- a/naga/tests/out/ir/spv-spec-constants.ron +++ b/naga/tests/out/ir/spv-spec-constants.ron @@ -242,6 +242,7 @@ interpolation: Some(Perspective), sampling: Some(Center), blend_src: None, + per_primitive: false, )), offset: 0, ), @@ -616,6 +617,7 @@ interpolation: None, sampling: None, blend_src: None, + per_primitive: false, )), ), ( @@ -626,6 +628,7 @@ interpolation: None, sampling: None, blend_src: None, + per_primitive: false, )), ), ( @@ -636,6 +639,7 @@ interpolation: None, sampling: None, blend_src: None, + per_primitive: false, )), ), ], @@ -719,6 +723,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-access.compact.ron b/naga/tests/out/ir/wgsl-access.compact.ron index 30e88984f3c..c3df0c8c500 100644 --- a/naga/tests/out/ir/wgsl-access.compact.ron +++ b/naga/tests/out/ir/wgsl-access.compact.ron @@ -2655,6 +2655,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ( name: "foo_frag", @@ -2672,6 +2674,7 @@ interpolation: Some(Perspective), sampling: Some(Center), blend_src: None, + per_primitive: false, )), )), local_variables: [], @@ -2848,6 +2851,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ( name: "foo_compute", @@ -2907,6 +2912,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-access.ron b/naga/tests/out/ir/wgsl-access.ron index 30e88984f3c..c3df0c8c500 100644 --- a/naga/tests/out/ir/wgsl-access.ron +++ b/naga/tests/out/ir/wgsl-access.ron @@ -2655,6 +2655,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ( name: "foo_frag", @@ -2672,6 +2674,7 @@ interpolation: Some(Perspective), sampling: Some(Center), blend_src: None, + per_primitive: false, )), )), local_variables: [], @@ -2848,6 +2851,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ( name: "foo_compute", @@ -2907,6 +2912,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-collatz.compact.ron b/naga/tests/out/ir/wgsl-collatz.compact.ron index f72c652d032..fc4daaa1296 100644 --- a/naga/tests/out/ir/wgsl-collatz.compact.ron +++ b/naga/tests/out/ir/wgsl-collatz.compact.ron @@ -334,6 +334,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-collatz.ron b/naga/tests/out/ir/wgsl-collatz.ron index f72c652d032..fc4daaa1296 100644 --- a/naga/tests/out/ir/wgsl-collatz.ron +++ b/naga/tests/out/ir/wgsl-collatz.ron @@ -334,6 +334,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-const_assert.compact.ron b/naga/tests/out/ir/wgsl-const_assert.compact.ron index 2816364f88b..648f4ff9bc9 100644 --- a/naga/tests/out/ir/wgsl-const_assert.compact.ron +++ b/naga/tests/out/ir/wgsl-const_assert.compact.ron @@ -34,6 +34,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-const_assert.ron b/naga/tests/out/ir/wgsl-const_assert.ron index 2816364f88b..648f4ff9bc9 100644 --- a/naga/tests/out/ir/wgsl-const_assert.ron +++ b/naga/tests/out/ir/wgsl-const_assert.ron @@ -34,6 +34,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-diagnostic-filter.compact.ron b/naga/tests/out/ir/wgsl-diagnostic-filter.compact.ron index c5746696d52..9a2bf193f30 100644 --- a/naga/tests/out/ir/wgsl-diagnostic-filter.compact.ron +++ b/naga/tests/out/ir/wgsl-diagnostic-filter.compact.ron @@ -73,6 +73,8 @@ ], diagnostic_filter_leaf: Some(0), ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [ diff --git a/naga/tests/out/ir/wgsl-diagnostic-filter.ron b/naga/tests/out/ir/wgsl-diagnostic-filter.ron index c5746696d52..9a2bf193f30 100644 --- a/naga/tests/out/ir/wgsl-diagnostic-filter.ron +++ b/naga/tests/out/ir/wgsl-diagnostic-filter.ron @@ -73,6 +73,8 @@ ], diagnostic_filter_leaf: Some(0), ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [ diff --git a/naga/tests/out/ir/wgsl-index-by-value.compact.ron b/naga/tests/out/ir/wgsl-index-by-value.compact.ron index a4f84a7a6b2..addd0e5871c 100644 --- a/naga/tests/out/ir/wgsl-index-by-value.compact.ron +++ b/naga/tests/out/ir/wgsl-index-by-value.compact.ron @@ -465,6 +465,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-index-by-value.ron b/naga/tests/out/ir/wgsl-index-by-value.ron index a4f84a7a6b2..addd0e5871c 100644 --- a/naga/tests/out/ir/wgsl-index-by-value.ron +++ b/naga/tests/out/ir/wgsl-index-by-value.ron @@ -465,6 +465,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-local-const.compact.ron b/naga/tests/out/ir/wgsl-local-const.compact.ron index 512972657ed..0e4e2e4d40e 100644 --- a/naga/tests/out/ir/wgsl-local-const.compact.ron +++ b/naga/tests/out/ir/wgsl-local-const.compact.ron @@ -100,6 +100,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-local-const.ron b/naga/tests/out/ir/wgsl-local-const.ron index 512972657ed..0e4e2e4d40e 100644 --- a/naga/tests/out/ir/wgsl-local-const.ron +++ b/naga/tests/out/ir/wgsl-local-const.ron @@ -100,6 +100,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-mesh-shader.compact.ron b/naga/tests/out/ir/wgsl-mesh-shader.compact.ron new file mode 100644 index 00000000000..38c79cba451 --- /dev/null +++ b/naga/tests/out/ir/wgsl-mesh-shader.compact.ron @@ -0,0 +1,846 @@ +( + types: [ + ( + name: None, + inner: Scalar(( + kind: Float, + width: 4, + )), + ), + ( + name: None, + inner: Vector( + size: Quad, + scalar: ( + kind: Float, + width: 4, + ), + ), + ), + ( + name: None, + inner: Scalar(( + kind: Bool, + width: 1, + )), + ), + ( + name: Some("TaskPayload"), + inner: Struct( + members: [ + ( + name: Some("colorMask"), + ty: 1, + binding: None, + offset: 0, + ), + ( + name: Some("visible"), + ty: 2, + binding: None, + offset: 16, + ), + ], + span: 32, + ), + ), + ( + name: Some("VertexOutput"), + inner: Struct( + members: [ + ( + name: Some("position"), + ty: 1, + binding: Some(BuiltIn(Position( + invariant: false, + ))), + offset: 0, + ), + ( + name: Some("color"), + ty: 1, + binding: Some(Location( + location: 0, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: false, + )), + offset: 16, + ), + ], + span: 32, + ), + ), + ( + name: None, + inner: Scalar(( + kind: Uint, + width: 4, + )), + ), + ( + name: None, + inner: Vector( + size: Tri, + scalar: ( + kind: Uint, + width: 4, + ), + ), + ), + ( + name: Some("PrimitiveOutput"), + inner: Struct( + members: [ + ( + name: Some("index"), + ty: 6, + binding: Some(BuiltIn(TriangleIndices)), + offset: 0, + ), + ( + name: Some("cull"), + ty: 2, + binding: Some(BuiltIn(CullPrimitive)), + offset: 12, + ), + ( + name: Some("colorMask"), + ty: 1, + binding: Some(Location( + location: 1, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: true, + )), + offset: 16, + ), + ], + span: 32, + ), + ), + ( + name: Some("PrimitiveInput"), + inner: Struct( + members: [ + ( + name: Some("colorMask"), + ty: 1, + binding: Some(Location( + location: 1, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: true, + )), + offset: 0, + ), + ], + span: 16, + ), + ), + ], + special_types: ( + ray_desc: None, + ray_intersection: None, + ray_vertex_return: None, + external_texture_params: None, + external_texture_transfer_function: None, + predeclared_types: {}, + ), + constants: [], + overrides: [], + global_variables: [ + ( + name: Some("taskPayload"), + space: TaskPayload, + binding: None, + ty: 3, + init: None, + ), + ( + name: Some("workgroupData"), + space: WorkGroup, + binding: None, + ty: 0, + init: None, + ), + ], + global_expressions: [], + functions: [], + entry_points: [ + ( + name: "ts_main", + stage: Task, + early_depth_test: None, + workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, + function: ( + name: Some("ts_main"), + arguments: [], + result: Some(( + ty: 6, + binding: Some(BuiltIn(MeshTaskSize)), + )), + local_variables: [], + expressions: [ + GlobalVariable(1), + Literal(F32(1.0)), + GlobalVariable(0), + AccessIndex( + base: 2, + index: 0, + ), + Literal(F32(1.0)), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 4, + 5, + 6, + 7, + ], + ), + GlobalVariable(0), + AccessIndex( + base: 9, + index: 1, + ), + Literal(Bool(true)), + Literal(U32(3)), + Literal(U32(1)), + Literal(U32(1)), + Compose( + ty: 6, + components: [ + 12, + 13, + 14, + ], + ), + ], + named_expressions: {}, + body: [ + Store( + pointer: 0, + value: 1, + ), + Emit(( + start: 3, + end: 4, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 8, + end: 9, + )), + Store( + pointer: 3, + value: 8, + ), + Emit(( + start: 10, + end: 11, + )), + Store( + pointer: 10, + value: 11, + ), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 15, + end: 16, + )), + Return( + value: Some(15), + ), + ], + diagnostic_filter_leaf: None, + ), + mesh_info: None, + task_payload: Some(0), + ), + ( + name: "ms_main", + stage: Mesh, + early_depth_test: None, + workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, + function: ( + name: Some("ms_main"), + arguments: [ + ( + name: Some("index"), + ty: 5, + binding: Some(BuiltIn(LocalInvocationIndex)), + ), + ( + name: Some("id"), + ty: 6, + binding: Some(BuiltIn(GlobalInvocationId)), + ), + ], + result: None, + local_variables: [ + ( + name: Some("v"), + ty: 4, + init: None, + ), + ( + name: Some("p"), + ty: 7, + init: None, + ), + ], + expressions: [ + FunctionArgument(0), + FunctionArgument(1), + Literal(U32(3)), + Literal(U32(1)), + GlobalVariable(1), + Literal(F32(2.0)), + LocalVariable(0), + AccessIndex( + base: 6, + index: 0, + ), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 8, + 9, + 10, + 11, + ], + ), + AccessIndex( + base: 6, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 14, + index: 0, + ), + Load( + pointer: 15, + ), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 17, + 18, + 19, + 20, + ], + ), + Binary( + op: Multiply, + left: 21, + right: 16, + ), + Literal(U32(0)), + Load( + pointer: 6, + ), + AccessIndex( + base: 6, + index: 0, + ), + Literal(F32(-1.0)), + Literal(F32(-1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 26, + 27, + 28, + 29, + ], + ), + AccessIndex( + base: 6, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 32, + index: 0, + ), + Load( + pointer: 33, + ), + Literal(F32(0.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 35, + 36, + 37, + 38, + ], + ), + Binary( + op: Multiply, + left: 39, + right: 34, + ), + Literal(U32(1)), + Load( + pointer: 6, + ), + AccessIndex( + base: 6, + index: 0, + ), + Literal(F32(1.0)), + Literal(F32(-1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 44, + 45, + 46, + 47, + ], + ), + AccessIndex( + base: 6, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 50, + index: 0, + ), + Load( + pointer: 51, + ), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 53, + 54, + 55, + 56, + ], + ), + Binary( + op: Multiply, + left: 57, + right: 52, + ), + Literal(U32(2)), + Load( + pointer: 6, + ), + LocalVariable(1), + AccessIndex( + base: 61, + index: 0, + ), + Literal(U32(0)), + Literal(U32(1)), + Literal(U32(2)), + Compose( + ty: 6, + components: [ + 63, + 64, + 65, + ], + ), + AccessIndex( + base: 61, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 68, + index: 1, + ), + Load( + pointer: 69, + ), + Unary( + op: LogicalNot, + expr: 70, + ), + AccessIndex( + base: 61, + index: 2, + ), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 73, + 74, + 75, + 76, + ], + ), + Literal(U32(0)), + Load( + pointer: 61, + ), + ], + named_expressions: { + 0: "index", + 1: "id", + }, + body: [ + MeshFunction(SetMeshOutputs( + vertex_count: 2, + primitive_count: 3, + )), + Store( + pointer: 4, + value: 5, + ), + Emit(( + start: 7, + end: 8, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 12, + end: 13, + )), + Store( + pointer: 7, + value: 12, + ), + Emit(( + start: 13, + end: 14, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 15, + end: 17, + )), + Emit(( + start: 21, + end: 23, + )), + Store( + pointer: 13, + value: 22, + ), + Emit(( + start: 24, + end: 25, + )), + MeshFunction(SetVertex( + index: 23, + value: 24, + )), + Emit(( + start: 25, + end: 26, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 30, + end: 31, + )), + Store( + pointer: 25, + value: 30, + ), + Emit(( + start: 31, + end: 32, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 33, + end: 35, + )), + Emit(( + start: 39, + end: 41, + )), + Store( + pointer: 31, + value: 40, + ), + Emit(( + start: 42, + end: 43, + )), + MeshFunction(SetVertex( + index: 41, + value: 42, + )), + Emit(( + start: 43, + end: 44, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 48, + end: 49, + )), + Store( + pointer: 43, + value: 48, + ), + Emit(( + start: 49, + end: 50, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 51, + end: 53, + )), + Emit(( + start: 57, + end: 59, + )), + Store( + pointer: 49, + value: 58, + ), + Emit(( + start: 60, + end: 61, + )), + MeshFunction(SetVertex( + index: 59, + value: 60, + )), + Emit(( + start: 62, + end: 63, + )), + Emit(( + start: 66, + end: 67, + )), + Store( + pointer: 62, + value: 66, + ), + Emit(( + start: 67, + end: 68, + )), + Emit(( + start: 69, + end: 72, + )), + Store( + pointer: 67, + value: 71, + ), + Emit(( + start: 72, + end: 73, + )), + Emit(( + start: 77, + end: 78, + )), + Store( + pointer: 72, + value: 77, + ), + Emit(( + start: 79, + end: 80, + )), + MeshFunction(SetPrimitive( + index: 78, + value: 79, + )), + Return( + value: None, + ), + ], + diagnostic_filter_leaf: None, + ), + mesh_info: Some(( + topology: Triangles, + max_vertices: 3, + max_vertices_override: None, + max_primitives: 1, + max_primitives_override: None, + vertex_output_type: 4, + primitive_output_type: 7, + )), + task_payload: Some(0), + ), + ( + name: "fs_main", + stage: Fragment, + early_depth_test: None, + workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, + function: ( + name: Some("fs_main"), + arguments: [ + ( + name: Some("vertex"), + ty: 4, + binding: None, + ), + ( + name: Some("primitive"), + ty: 8, + binding: None, + ), + ], + result: Some(( + ty: 1, + binding: Some(Location( + location: 0, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: false, + )), + )), + local_variables: [], + expressions: [ + FunctionArgument(0), + FunctionArgument(1), + AccessIndex( + base: 0, + index: 1, + ), + AccessIndex( + base: 1, + index: 0, + ), + Binary( + op: Multiply, + left: 2, + right: 3, + ), + ], + named_expressions: { + 0: "vertex", + 1: "primitive", + }, + body: [ + Emit(( + start: 2, + end: 5, + )), + Return( + value: Some(4), + ), + ], + diagnostic_filter_leaf: None, + ), + mesh_info: None, + task_payload: None, + ), + ], + diagnostic_filters: [], + diagnostic_filter_leaf: None, + doc_comments: None, +) \ No newline at end of file diff --git a/naga/tests/out/ir/wgsl-mesh-shader.ron b/naga/tests/out/ir/wgsl-mesh-shader.ron new file mode 100644 index 00000000000..38c79cba451 --- /dev/null +++ b/naga/tests/out/ir/wgsl-mesh-shader.ron @@ -0,0 +1,846 @@ +( + types: [ + ( + name: None, + inner: Scalar(( + kind: Float, + width: 4, + )), + ), + ( + name: None, + inner: Vector( + size: Quad, + scalar: ( + kind: Float, + width: 4, + ), + ), + ), + ( + name: None, + inner: Scalar(( + kind: Bool, + width: 1, + )), + ), + ( + name: Some("TaskPayload"), + inner: Struct( + members: [ + ( + name: Some("colorMask"), + ty: 1, + binding: None, + offset: 0, + ), + ( + name: Some("visible"), + ty: 2, + binding: None, + offset: 16, + ), + ], + span: 32, + ), + ), + ( + name: Some("VertexOutput"), + inner: Struct( + members: [ + ( + name: Some("position"), + ty: 1, + binding: Some(BuiltIn(Position( + invariant: false, + ))), + offset: 0, + ), + ( + name: Some("color"), + ty: 1, + binding: Some(Location( + location: 0, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: false, + )), + offset: 16, + ), + ], + span: 32, + ), + ), + ( + name: None, + inner: Scalar(( + kind: Uint, + width: 4, + )), + ), + ( + name: None, + inner: Vector( + size: Tri, + scalar: ( + kind: Uint, + width: 4, + ), + ), + ), + ( + name: Some("PrimitiveOutput"), + inner: Struct( + members: [ + ( + name: Some("index"), + ty: 6, + binding: Some(BuiltIn(TriangleIndices)), + offset: 0, + ), + ( + name: Some("cull"), + ty: 2, + binding: Some(BuiltIn(CullPrimitive)), + offset: 12, + ), + ( + name: Some("colorMask"), + ty: 1, + binding: Some(Location( + location: 1, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: true, + )), + offset: 16, + ), + ], + span: 32, + ), + ), + ( + name: Some("PrimitiveInput"), + inner: Struct( + members: [ + ( + name: Some("colorMask"), + ty: 1, + binding: Some(Location( + location: 1, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: true, + )), + offset: 0, + ), + ], + span: 16, + ), + ), + ], + special_types: ( + ray_desc: None, + ray_intersection: None, + ray_vertex_return: None, + external_texture_params: None, + external_texture_transfer_function: None, + predeclared_types: {}, + ), + constants: [], + overrides: [], + global_variables: [ + ( + name: Some("taskPayload"), + space: TaskPayload, + binding: None, + ty: 3, + init: None, + ), + ( + name: Some("workgroupData"), + space: WorkGroup, + binding: None, + ty: 0, + init: None, + ), + ], + global_expressions: [], + functions: [], + entry_points: [ + ( + name: "ts_main", + stage: Task, + early_depth_test: None, + workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, + function: ( + name: Some("ts_main"), + arguments: [], + result: Some(( + ty: 6, + binding: Some(BuiltIn(MeshTaskSize)), + )), + local_variables: [], + expressions: [ + GlobalVariable(1), + Literal(F32(1.0)), + GlobalVariable(0), + AccessIndex( + base: 2, + index: 0, + ), + Literal(F32(1.0)), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 4, + 5, + 6, + 7, + ], + ), + GlobalVariable(0), + AccessIndex( + base: 9, + index: 1, + ), + Literal(Bool(true)), + Literal(U32(3)), + Literal(U32(1)), + Literal(U32(1)), + Compose( + ty: 6, + components: [ + 12, + 13, + 14, + ], + ), + ], + named_expressions: {}, + body: [ + Store( + pointer: 0, + value: 1, + ), + Emit(( + start: 3, + end: 4, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 8, + end: 9, + )), + Store( + pointer: 3, + value: 8, + ), + Emit(( + start: 10, + end: 11, + )), + Store( + pointer: 10, + value: 11, + ), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 15, + end: 16, + )), + Return( + value: Some(15), + ), + ], + diagnostic_filter_leaf: None, + ), + mesh_info: None, + task_payload: Some(0), + ), + ( + name: "ms_main", + stage: Mesh, + early_depth_test: None, + workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, + function: ( + name: Some("ms_main"), + arguments: [ + ( + name: Some("index"), + ty: 5, + binding: Some(BuiltIn(LocalInvocationIndex)), + ), + ( + name: Some("id"), + ty: 6, + binding: Some(BuiltIn(GlobalInvocationId)), + ), + ], + result: None, + local_variables: [ + ( + name: Some("v"), + ty: 4, + init: None, + ), + ( + name: Some("p"), + ty: 7, + init: None, + ), + ], + expressions: [ + FunctionArgument(0), + FunctionArgument(1), + Literal(U32(3)), + Literal(U32(1)), + GlobalVariable(1), + Literal(F32(2.0)), + LocalVariable(0), + AccessIndex( + base: 6, + index: 0, + ), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 8, + 9, + 10, + 11, + ], + ), + AccessIndex( + base: 6, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 14, + index: 0, + ), + Load( + pointer: 15, + ), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 17, + 18, + 19, + 20, + ], + ), + Binary( + op: Multiply, + left: 21, + right: 16, + ), + Literal(U32(0)), + Load( + pointer: 6, + ), + AccessIndex( + base: 6, + index: 0, + ), + Literal(F32(-1.0)), + Literal(F32(-1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 26, + 27, + 28, + 29, + ], + ), + AccessIndex( + base: 6, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 32, + index: 0, + ), + Load( + pointer: 33, + ), + Literal(F32(0.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 35, + 36, + 37, + 38, + ], + ), + Binary( + op: Multiply, + left: 39, + right: 34, + ), + Literal(U32(1)), + Load( + pointer: 6, + ), + AccessIndex( + base: 6, + index: 0, + ), + Literal(F32(1.0)), + Literal(F32(-1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 44, + 45, + 46, + 47, + ], + ), + AccessIndex( + base: 6, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 50, + index: 0, + ), + Load( + pointer: 51, + ), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 53, + 54, + 55, + 56, + ], + ), + Binary( + op: Multiply, + left: 57, + right: 52, + ), + Literal(U32(2)), + Load( + pointer: 6, + ), + LocalVariable(1), + AccessIndex( + base: 61, + index: 0, + ), + Literal(U32(0)), + Literal(U32(1)), + Literal(U32(2)), + Compose( + ty: 6, + components: [ + 63, + 64, + 65, + ], + ), + AccessIndex( + base: 61, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 68, + index: 1, + ), + Load( + pointer: 69, + ), + Unary( + op: LogicalNot, + expr: 70, + ), + AccessIndex( + base: 61, + index: 2, + ), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 73, + 74, + 75, + 76, + ], + ), + Literal(U32(0)), + Load( + pointer: 61, + ), + ], + named_expressions: { + 0: "index", + 1: "id", + }, + body: [ + MeshFunction(SetMeshOutputs( + vertex_count: 2, + primitive_count: 3, + )), + Store( + pointer: 4, + value: 5, + ), + Emit(( + start: 7, + end: 8, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 12, + end: 13, + )), + Store( + pointer: 7, + value: 12, + ), + Emit(( + start: 13, + end: 14, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 15, + end: 17, + )), + Emit(( + start: 21, + end: 23, + )), + Store( + pointer: 13, + value: 22, + ), + Emit(( + start: 24, + end: 25, + )), + MeshFunction(SetVertex( + index: 23, + value: 24, + )), + Emit(( + start: 25, + end: 26, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 30, + end: 31, + )), + Store( + pointer: 25, + value: 30, + ), + Emit(( + start: 31, + end: 32, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 33, + end: 35, + )), + Emit(( + start: 39, + end: 41, + )), + Store( + pointer: 31, + value: 40, + ), + Emit(( + start: 42, + end: 43, + )), + MeshFunction(SetVertex( + index: 41, + value: 42, + )), + Emit(( + start: 43, + end: 44, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 48, + end: 49, + )), + Store( + pointer: 43, + value: 48, + ), + Emit(( + start: 49, + end: 50, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 51, + end: 53, + )), + Emit(( + start: 57, + end: 59, + )), + Store( + pointer: 49, + value: 58, + ), + Emit(( + start: 60, + end: 61, + )), + MeshFunction(SetVertex( + index: 59, + value: 60, + )), + Emit(( + start: 62, + end: 63, + )), + Emit(( + start: 66, + end: 67, + )), + Store( + pointer: 62, + value: 66, + ), + Emit(( + start: 67, + end: 68, + )), + Emit(( + start: 69, + end: 72, + )), + Store( + pointer: 67, + value: 71, + ), + Emit(( + start: 72, + end: 73, + )), + Emit(( + start: 77, + end: 78, + )), + Store( + pointer: 72, + value: 77, + ), + Emit(( + start: 79, + end: 80, + )), + MeshFunction(SetPrimitive( + index: 78, + value: 79, + )), + Return( + value: None, + ), + ], + diagnostic_filter_leaf: None, + ), + mesh_info: Some(( + topology: Triangles, + max_vertices: 3, + max_vertices_override: None, + max_primitives: 1, + max_primitives_override: None, + vertex_output_type: 4, + primitive_output_type: 7, + )), + task_payload: Some(0), + ), + ( + name: "fs_main", + stage: Fragment, + early_depth_test: None, + workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, + function: ( + name: Some("fs_main"), + arguments: [ + ( + name: Some("vertex"), + ty: 4, + binding: None, + ), + ( + name: Some("primitive"), + ty: 8, + binding: None, + ), + ], + result: Some(( + ty: 1, + binding: Some(Location( + location: 0, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: false, + )), + )), + local_variables: [], + expressions: [ + FunctionArgument(0), + FunctionArgument(1), + AccessIndex( + base: 0, + index: 1, + ), + AccessIndex( + base: 1, + index: 0, + ), + Binary( + op: Multiply, + left: 2, + right: 3, + ), + ], + named_expressions: { + 0: "vertex", + 1: "primitive", + }, + body: [ + Emit(( + start: 2, + end: 5, + )), + Return( + value: Some(4), + ), + ], + diagnostic_filter_leaf: None, + ), + mesh_info: None, + task_payload: None, + ), + ], + diagnostic_filters: [], + diagnostic_filter_leaf: None, + doc_comments: None, +) \ No newline at end of file diff --git a/naga/tests/out/ir/wgsl-must-use.compact.ron b/naga/tests/out/ir/wgsl-must-use.compact.ron index a701a6805da..16e925f2fb8 100644 --- a/naga/tests/out/ir/wgsl-must-use.compact.ron +++ b/naga/tests/out/ir/wgsl-must-use.compact.ron @@ -201,6 +201,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-must-use.ron b/naga/tests/out/ir/wgsl-must-use.ron index a701a6805da..16e925f2fb8 100644 --- a/naga/tests/out/ir/wgsl-must-use.ron +++ b/naga/tests/out/ir/wgsl-must-use.ron @@ -201,6 +201,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-overrides-atomicCompareExchangeWeak.compact.ron b/naga/tests/out/ir/wgsl-overrides-atomicCompareExchangeWeak.compact.ron index 640ee25ca49..28a824bb035 100644 --- a/naga/tests/out/ir/wgsl-overrides-atomicCompareExchangeWeak.compact.ron +++ b/naga/tests/out/ir/wgsl-overrides-atomicCompareExchangeWeak.compact.ron @@ -128,6 +128,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-overrides-atomicCompareExchangeWeak.ron b/naga/tests/out/ir/wgsl-overrides-atomicCompareExchangeWeak.ron index 640ee25ca49..28a824bb035 100644 --- a/naga/tests/out/ir/wgsl-overrides-atomicCompareExchangeWeak.ron +++ b/naga/tests/out/ir/wgsl-overrides-atomicCompareExchangeWeak.ron @@ -128,6 +128,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-overrides-ray-query.compact.ron b/naga/tests/out/ir/wgsl-overrides-ray-query.compact.ron index f65e8f186db..152a45008c5 100644 --- a/naga/tests/out/ir/wgsl-overrides-ray-query.compact.ron +++ b/naga/tests/out/ir/wgsl-overrides-ray-query.compact.ron @@ -263,6 +263,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-overrides-ray-query.ron b/naga/tests/out/ir/wgsl-overrides-ray-query.ron index f65e8f186db..152a45008c5 100644 --- a/naga/tests/out/ir/wgsl-overrides-ray-query.ron +++ b/naga/tests/out/ir/wgsl-overrides-ray-query.ron @@ -263,6 +263,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-overrides.compact.ron b/naga/tests/out/ir/wgsl-overrides.compact.ron index 81221ff7941..fe136e71e4d 100644 --- a/naga/tests/out/ir/wgsl-overrides.compact.ron +++ b/naga/tests/out/ir/wgsl-overrides.compact.ron @@ -221,6 +221,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-overrides.ron b/naga/tests/out/ir/wgsl-overrides.ron index 81221ff7941..fe136e71e4d 100644 --- a/naga/tests/out/ir/wgsl-overrides.ron +++ b/naga/tests/out/ir/wgsl-overrides.ron @@ -221,6 +221,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-storage-textures.compact.ron b/naga/tests/out/ir/wgsl-storage-textures.compact.ron index ec63fecac27..68c867a19e2 100644 --- a/naga/tests/out/ir/wgsl-storage-textures.compact.ron +++ b/naga/tests/out/ir/wgsl-storage-textures.compact.ron @@ -218,6 +218,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ( name: "csStore", @@ -315,6 +317,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-storage-textures.ron b/naga/tests/out/ir/wgsl-storage-textures.ron index ec63fecac27..68c867a19e2 100644 --- a/naga/tests/out/ir/wgsl-storage-textures.ron +++ b/naga/tests/out/ir/wgsl-storage-textures.ron @@ -218,6 +218,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ( name: "csStore", @@ -315,6 +317,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-template-list-trailing-comma.compact.ron b/naga/tests/out/ir/wgsl-template-list-trailing-comma.compact.ron index a8208c09b86..db619dff836 100644 --- a/naga/tests/out/ir/wgsl-template-list-trailing-comma.compact.ron +++ b/naga/tests/out/ir/wgsl-template-list-trailing-comma.compact.ron @@ -190,6 +190,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-template-list-trailing-comma.ron b/naga/tests/out/ir/wgsl-template-list-trailing-comma.ron index a8208c09b86..db619dff836 100644 --- a/naga/tests/out/ir/wgsl-template-list-trailing-comma.ron +++ b/naga/tests/out/ir/wgsl-template-list-trailing-comma.ron @@ -190,6 +190,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-texture-external.compact.ron b/naga/tests/out/ir/wgsl-texture-external.compact.ron index dbffbddcdc7..379e76566c5 100644 --- a/naga/tests/out/ir/wgsl-texture-external.compact.ron +++ b/naga/tests/out/ir/wgsl-texture-external.compact.ron @@ -360,6 +360,7 @@ interpolation: Some(Perspective), sampling: Some(Center), blend_src: None, + per_primitive: false, )), )), local_variables: [], @@ -382,6 +383,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ( name: "vertex_main", @@ -418,6 +421,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ( name: "compute_main", @@ -449,6 +454,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-texture-external.ron b/naga/tests/out/ir/wgsl-texture-external.ron index dbffbddcdc7..379e76566c5 100644 --- a/naga/tests/out/ir/wgsl-texture-external.ron +++ b/naga/tests/out/ir/wgsl-texture-external.ron @@ -360,6 +360,7 @@ interpolation: Some(Perspective), sampling: Some(Center), blend_src: None, + per_primitive: false, )), )), local_variables: [], @@ -382,6 +383,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ( name: "vertex_main", @@ -418,6 +421,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ( name: "compute_main", @@ -449,6 +454,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-types_with_comments.compact.ron b/naga/tests/out/ir/wgsl-types_with_comments.compact.ron index 7186209f00e..7c0d856946f 100644 --- a/naga/tests/out/ir/wgsl-types_with_comments.compact.ron +++ b/naga/tests/out/ir/wgsl-types_with_comments.compact.ron @@ -116,6 +116,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-types_with_comments.ron b/naga/tests/out/ir/wgsl-types_with_comments.ron index 480b0d2337f..34e44cb9653 100644 --- a/naga/tests/out/ir/wgsl-types_with_comments.ron +++ b/naga/tests/out/ir/wgsl-types_with_comments.ron @@ -172,6 +172,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/spv/wgsl-interface.fragment.spvasm b/naga/tests/out/spv/wgsl-interface.fragment.spvasm index e78d69121f3..ab1dcbd0101 100644 --- a/naga/tests/out/spv/wgsl-interface.fragment.spvasm +++ b/naga/tests/out/spv/wgsl-interface.fragment.spvasm @@ -17,8 +17,8 @@ OpMemberDecorate %7 2 Offset 8 OpDecorate %9 ArrayStride 4 OpMemberDecorate %12 0 Offset 0 OpMemberDecorate %13 0 Offset 0 -OpDecorate %16 Invariant OpDecorate %16 BuiltIn FragCoord +OpDecorate %16 Invariant OpDecorate %19 Location 1 OpDecorate %22 BuiltIn FrontFacing OpDecorate %22 Flat diff --git a/naga/tests/out/spv/wgsl-interface.vertex.spvasm b/naga/tests/out/spv/wgsl-interface.vertex.spvasm index fa11c5b89f7..b72f38e7938 100644 --- a/naga/tests/out/spv/wgsl-interface.vertex.spvasm +++ b/naga/tests/out/spv/wgsl-interface.vertex.spvasm @@ -17,8 +17,8 @@ OpMemberDecorate %13 0 Offset 0 OpDecorate %15 BuiltIn VertexIndex OpDecorate %18 BuiltIn InstanceIndex OpDecorate %20 Location 10 -OpDecorate %22 Invariant OpDecorate %22 BuiltIn Position +OpDecorate %22 Invariant OpDecorate %24 Location 1 OpDecorate %26 BuiltIn PointSize %2 = OpTypeVoid diff --git a/naga/tests/out/spv/wgsl-interface.vertex_two_structs.spvasm b/naga/tests/out/spv/wgsl-interface.vertex_two_structs.spvasm index f83a5a624b3..9706a904dbe 100644 --- a/naga/tests/out/spv/wgsl-interface.vertex_two_structs.spvasm +++ b/naga/tests/out/spv/wgsl-interface.vertex_two_structs.spvasm @@ -16,8 +16,8 @@ OpMemberDecorate %12 0 Offset 0 OpMemberDecorate %13 0 Offset 0 OpDecorate %16 BuiltIn VertexIndex OpDecorate %20 BuiltIn InstanceIndex -OpDecorate %22 Invariant OpDecorate %22 BuiltIn Position +OpDecorate %22 Invariant OpDecorate %24 BuiltIn PointSize %2 = OpTypeVoid %3 = OpTypeFloat 32 diff --git a/naga/tests/out/spv/wgsl-mesh-shader.spvasm b/naga/tests/out/spv/wgsl-mesh-shader.spvasm new file mode 100644 index 00000000000..28dfb64740c --- /dev/null +++ b/naga/tests/out/spv/wgsl-mesh-shader.spvasm @@ -0,0 +1,307 @@ +; SPIR-V +; Version: 1.4 +; Generator: rspirv +; Bound: 199 +OpCapability Shader +OpCapability MeshShadingEXT +OpExtension "SPV_EXT_mesh_shader" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint TaskEXT %17 "ts_main" %12 %14 %28 +OpEntryPoint MeshEXT %86 "ms_main" %49 %52 %12 %59 %63 %55 %56 %70 %74 %78 %81 %85 %14 %103 +OpEntryPoint Fragment %194 "fs_main" %185 %188 %191 %193 +OpExecutionMode %17 LocalSize 1 1 1 +OpExecutionMode %86 LocalSize 1 1 1 +OpExecutionMode %86 OutputTrianglesNV +OpExecutionMode %86 OutputVertices 3 +OpExecutionMode %86 OutputPrimitivesNV 1 +OpExecutionMode %194 OriginUpperLeft +OpMemberDecorate %67 0 BuiltIn Position +OpDecorate %67 Block +OpMemberDecorate %71 0 BuiltIn CullPrimitiveEXT +OpDecorate %71 Block +OpDecorate %74 PerPrimitiveNV +OpDecorate %75 Block +OpMemberDecorate %75 0 Location 0 +OpDecorate %81 PerPrimitiveNV +OpDecorate %81 BuiltIn PrimitiveTriangleIndicesEXT +OpDecorate %82 Block +OpMemberDecorate %82 0 Location 1 +OpDecorate %85 PerPrimitiveNV +OpMemberDecorate %6 0 Offset 0 +OpMemberDecorate %6 1 Offset 16 +OpMemberDecorate %7 0 Offset 0 +OpMemberDecorate %7 1 Offset 16 +OpMemberDecorate %10 0 Offset 0 +OpMemberDecorate %10 1 Offset 12 +OpMemberDecorate %10 2 Offset 16 +OpMemberDecorate %11 0 Offset 0 +OpDecorate %28 BuiltIn LocalInvocationId +OpDecorate %49 BuiltIn LocalInvocationIndex +OpDecorate %52 BuiltIn GlobalInvocationId +OpDecorate %103 BuiltIn LocalInvocationId +OpDecorate %185 BuiltIn FragCoord +OpDecorate %188 Location 0 +OpDecorate %191 Location 1 +OpDecorate %191 PerPrimitiveNV +OpDecorate %193 Location 0 +%2 = OpTypeVoid +%3 = OpTypeFloat 32 +%4 = OpTypeVector %3 4 +%5 = OpTypeBool +%6 = OpTypeStruct %4 %5 +%7 = OpTypeStruct %4 %4 +%8 = OpTypeInt 32 0 +%9 = OpTypeVector %8 3 +%10 = OpTypeStruct %9 %5 %4 +%11 = OpTypeStruct %4 +%13 = OpTypePointer TaskPayloadWorkgroupEXT %6 +%12 = OpVariable %13 TaskPayloadWorkgroupEXT +%15 = OpTypePointer Workgroup %3 +%14 = OpVariable %15 Workgroup +%18 = OpTypeFunction %2 +%19 = OpConstant %3 1 +%20 = OpConstant %3 0 +%21 = OpConstantComposite %4 %19 %19 %20 %19 +%22 = OpConstantTrue %5 +%23 = OpConstant %8 3 +%24 = OpConstant %8 1 +%25 = OpConstantComposite %9 %23 %24 %24 +%27 = OpConstantNull %3 +%29 = OpTypePointer Input %9 +%28 = OpVariable %29 Input +%31 = OpConstantNull %9 +%32 = OpTypeVector %5 3 +%37 = OpConstant %8 2 +%38 = OpConstant %8 264 +%40 = OpTypePointer TaskPayloadWorkgroupEXT %4 +%41 = OpConstant %8 0 +%43 = OpTypePointer TaskPayloadWorkgroupEXT %5 +%50 = OpTypePointer Input %8 +%49 = OpVariable %50 Input +%52 = OpVariable %29 Input +%54 = OpTypePointer Workgroup %8 +%55 = OpVariable %54 Workgroup +%56 = OpVariable %54 Workgroup +%57 = OpConstant %8 3 +%58 = OpTypeArray %7 %57 +%60 = OpTypePointer Workgroup %58 +%59 = OpVariable %60 Workgroup +%61 = OpConstant %8 1 +%62 = OpTypeArray %10 %61 +%64 = OpTypePointer Workgroup %62 +%63 = OpVariable %64 Workgroup +%66 = OpTypePointer Function %8 +%67 = OpTypeStruct %4 +%68 = OpTypeArray %67 %57 +%69 = OpTypePointer Output %68 +%70 = OpVariable %69 Output +%71 = OpTypeStruct %5 +%72 = OpTypeArray %71 %61 +%73 = OpTypePointer Output %72 +%74 = OpVariable %73 Output +%75 = OpTypeStruct %4 +%76 = OpTypeArray %75 %61 +%77 = OpTypePointer Output %76 +%78 = OpVariable %77 Output +%79 = OpTypeArray %9 %61 +%80 = OpTypePointer Output %79 +%81 = OpVariable %80 Output +%82 = OpTypeStruct %4 +%83 = OpTypeArray %82 %61 +%84 = OpTypePointer Output %83 +%85 = OpVariable %84 Output +%87 = OpConstant %3 2 +%88 = OpConstantComposite %4 %20 %19 %20 %19 +%89 = OpConstant %3 -1 +%90 = OpConstantComposite %4 %89 %89 %20 %19 +%91 = OpConstantComposite %4 %20 %20 %19 %19 +%92 = OpConstantComposite %4 %19 %89 %20 %19 +%93 = OpConstantComposite %4 %19 %20 %20 %19 +%94 = OpConstantComposite %9 %41 %24 %37 +%95 = OpConstantComposite %4 %19 %20 %19 %19 +%97 = OpTypePointer Function %7 +%98 = OpConstantNull %7 +%100 = OpTypePointer Function %10 +%101 = OpConstantNull %10 +%103 = OpVariable %29 Input +%110 = OpTypePointer Function %4 +%118 = OpTypePointer Workgroup %7 +%133 = OpTypePointer Function %9 +%135 = OpTypePointer Function %5 +%143 = OpTypePointer Workgroup %10 +%155 = OpTypePointer Output %4 +%163 = OpTypePointer Output %9 +%166 = OpTypePointer Output %5 +%186 = OpTypePointer Input %4 +%185 = OpVariable %186 Input +%188 = OpVariable %186 Input +%191 = OpVariable %186 Input +%193 = OpVariable %155 Output +%17 = OpFunction %2 None %18 +%16 = OpLabel +OpBranch %26 +%26 = OpLabel +%30 = OpLoad %9 %28 +%33 = OpIEqual %32 %30 %31 +%34 = OpAll %5 %33 +OpSelectionMerge %35 None +OpBranchConditional %34 %36 %35 +%36 = OpLabel +OpStore %14 %27 +OpBranch %35 +%35 = OpLabel +OpControlBarrier %37 %37 %38 +OpBranch %39 +%39 = OpLabel +OpStore %14 %19 +%42 = OpAccessChain %40 %12 %41 +OpStore %42 %21 +%44 = OpAccessChain %43 %12 %24 +OpStore %44 %22 +%45 = OpCompositeExtract %8 %25 0 +%46 = OpCompositeExtract %8 %25 1 +%47 = OpCompositeExtract %8 %25 2 +OpEmitMeshTasksEXT %45 %46 %47 %12 +OpFunctionEnd +%86 = OpFunction %2 None %18 +%48 = OpLabel +%96 = OpVariable %97 Function %98 +%99 = OpVariable %100 Function %101 +%65 = OpVariable %66 Function +%51 = OpLoad %8 %49 +%53 = OpLoad %9 %52 +OpBranch %102 +%102 = OpLabel +%104 = OpLoad %9 %103 +%105 = OpIEqual %32 %104 %31 +%106 = OpAll %5 %105 +OpSelectionMerge %107 None +OpBranchConditional %106 %108 %107 +%108 = OpLabel +OpStore %14 %27 +OpBranch %107 +%107 = OpLabel +OpControlBarrier %37 %37 %38 +OpBranch %109 +%109 = OpLabel +OpStore %55 %23 +OpStore %56 %24 +OpStore %14 %87 +%111 = OpAccessChain %110 %96 %41 +OpStore %111 %88 +%112 = OpAccessChain %40 %12 %41 +%113 = OpLoad %4 %112 +%114 = OpFMul %4 %88 %113 +%115 = OpAccessChain %110 %96 %24 +OpStore %115 %114 +%116 = OpLoad %7 %96 +%117 = OpAccessChain %118 %59 %41 +OpStore %117 %116 +%119 = OpAccessChain %110 %96 %41 +OpStore %119 %90 +%120 = OpAccessChain %40 %12 %41 +%121 = OpLoad %4 %120 +%122 = OpFMul %4 %91 %121 +%123 = OpAccessChain %110 %96 %24 +OpStore %123 %122 +%124 = OpLoad %7 %96 +%125 = OpAccessChain %118 %59 %24 +OpStore %125 %124 +%126 = OpAccessChain %110 %96 %41 +OpStore %126 %92 +%127 = OpAccessChain %40 %12 %41 +%128 = OpLoad %4 %127 +%129 = OpFMul %4 %93 %128 +%130 = OpAccessChain %110 %96 %24 +OpStore %130 %129 +%131 = OpLoad %7 %96 +%132 = OpAccessChain %118 %59 %37 +OpStore %132 %131 +%134 = OpAccessChain %133 %99 %41 +OpStore %134 %94 +%136 = OpAccessChain %43 %12 %24 +%137 = OpLoad %5 %136 +%138 = OpLogicalNot %5 %137 +%139 = OpAccessChain %135 %99 %24 +OpStore %139 %138 +%140 = OpAccessChain %110 %99 %37 +OpStore %140 %95 +%141 = OpLoad %10 %99 +%142 = OpAccessChain %143 %63 %41 +OpStore %142 %141 +%144 = OpLoad %8 %55 +%145 = OpLoad %8 %56 +OpSetMeshOutputsEXT %144 %145 +OpStore %65 %51 +OpBranch %146 +%146 = OpLabel +OpLoopMerge %148 %170 None +OpBranch %169 +%169 = OpLabel +%172 = OpLoad %8 %65 +%173 = OpULessThan %5 %172 %144 +OpBranchConditional %173 %171 %148 +%171 = OpLabel +%150 = OpLoad %8 %65 +%151 = OpAccessChain %118 %59 %150 +%152 = OpLoad %7 %151 +%153 = OpCompositeExtract %4 %152 0 +%154 = OpAccessChain %155 %70 %150 %41 +OpStore %154 %153 +%156 = OpCompositeExtract %4 %152 1 +%157 = OpAccessChain %155 %78 %150 %41 +OpStore %157 %156 +OpBranch %170 +%170 = OpLabel +%174 = OpLoad %8 %65 +%175 = OpIAdd %8 %174 %24 +OpStore %65 %175 +OpBranch %146 +%148 = OpLabel +OpStore %65 %51 +OpBranch %147 +%147 = OpLabel +OpLoopMerge %149 %177 None +OpBranch %176 +%176 = OpLabel +%179 = OpLoad %8 %65 +%180 = OpULessThan %5 %179 %145 +OpBranchConditional %180 %178 %149 +%178 = OpLabel +%158 = OpLoad %8 %65 +%159 = OpAccessChain %143 %63 %158 +%160 = OpLoad %10 %159 +%161 = OpCompositeExtract %9 %160 0 +%162 = OpAccessChain %163 %81 %158 +OpStore %162 %161 +%164 = OpCompositeExtract %5 %160 1 +%165 = OpAccessChain %166 %74 %158 %41 +OpStore %165 %164 +%167 = OpCompositeExtract %4 %160 2 +%168 = OpAccessChain %155 %85 %158 %41 +OpStore %168 %167 +OpBranch %177 +%177 = OpLabel +%181 = OpLoad %8 %65 +%182 = OpIAdd %8 %181 %24 +OpStore %65 %182 +OpBranch %147 +%149 = OpLabel +OpReturn +OpFunctionEnd +%194 = OpFunction %2 None %18 +%183 = OpLabel +%187 = OpLoad %4 %185 +%189 = OpLoad %4 %188 +%184 = OpCompositeConstruct %7 %187 %189 +%192 = OpLoad %4 %191 +%190 = OpCompositeConstruct %11 %192 +OpBranch %195 +%195 = OpLabel +%196 = OpCompositeExtract %4 %184 1 +%197 = OpCompositeExtract %4 %190 0 +%198 = OpFMul %4 %196 %197 +OpStore %193 %198 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/tests/Cargo.toml b/tests/Cargo.toml index 95301df9488..581af3e879c 100644 --- a/tests/Cargo.toml +++ b/tests/Cargo.toml @@ -33,11 +33,29 @@ webgl = ["wgpu/webgl"] # allows us to force the build to have profiling code enabled so we can test that configuration. test-build-with-profiling = ["profiling/type-check"] +# Naga forwarded features +glsl-in = ["naga/glsl-in"] +glsl-out = ["naga/glsl-out"] +spv-in = ["naga/spv-in"] +spv-out = ["naga/spv-out"] +wgsl-in = ["naga/wgsl-in"] +wgsl-out = ["naga/wgsl-out"] +msl-out = ["naga/msl-out"] +dot-out = ["naga/dot-out"] +hlsl-out = ["naga/hlsl-out"] + [dependencies] wgpu = { workspace = true, features = ["noop"] } wgpu-hal = { workspace = true, features = ["validation_canary"] } wgpu-macros.workspace = true +# Naga stuff that lives here due to sharing logic with benchmarks +naga = { workspace = true, features = ["serialize", "deserialize"] } +spirv = { workspace = true, features = ["deserialize"] } +rspirv.workspace = true +ron.workspace = true +toml.workspace = true + anyhow.workspace = true arrayvec.workspace = true approx.workspace = true diff --git a/tests/src/lib.rs b/tests/src/lib.rs index 22afd7ecf77..775a3fdfc2b 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -7,6 +7,7 @@ mod expectations; pub mod image; mod init; mod isolation; +pub mod naga; pub mod native; mod params; mod poll; diff --git a/tests/src/naga.rs b/tests/src/naga.rs new file mode 100644 index 00000000000..075b26e46c0 --- /dev/null +++ b/tests/src/naga.rs @@ -0,0 +1,454 @@ +// A lot of the code can be unused based on configuration flags, +// the corresponding warnings aren't helpful. +#![allow(dead_code, unused_imports)] + +use core::fmt::Write; + +use std::{ + fs, + path::{Path, PathBuf}, +}; + +use naga::compact::KeepUnused; +use ron::de; + +bitflags::bitflags! { + #[derive(Clone, Copy, serde::Deserialize)] + #[serde(transparent)] + #[derive(Debug, Eq, PartialEq)] + pub struct Targets: u32 { + /// A serialization of the `naga::Module`, in RON format. + const IR = 1; + + /// A serialization of the `naga::valid::ModuleInfo`, in RON format. + const ANALYSIS = 1 << 1; + + const SPIRV = 1 << 2; + const METAL = 1 << 3; + const GLSL = 1 << 4; + const DOT = 1 << 5; + const HLSL = 1 << 6; + const WGSL = 1 << 7; + const NO_VALIDATION = 1 << 8; + } +} + +impl Targets { + /// Defaults for `spv` and `glsl` snapshots. + pub fn non_wgsl_default() -> Self { + Targets::WGSL + } + + /// Defaults for `wgsl` snapshots. + pub fn wgsl_default() -> Self { + Targets::HLSL | Targets::SPIRV | Targets::GLSL | Targets::METAL | Targets::WGSL + } +} + +#[derive(serde::Deserialize)] +pub struct SpvOutVersion(pub u8, pub u8); +impl Default for SpvOutVersion { + fn default() -> Self { + SpvOutVersion(1, 1) + } +} + +#[cfg(feature = "spv-out")] +#[derive(serde::Deserialize)] +pub struct BindingMapSerialization { + pub resource_binding: naga::ResourceBinding, + pub bind_target: naga::back::spv::BindingInfo, +} + +#[cfg(feature = "spv-out")] +pub fn deserialize_binding_map<'de, D>( + deserializer: D, +) -> Result +where + D: serde::Deserializer<'de>, +{ + use serde::Deserialize; + + let vec = Vec::::deserialize(deserializer)?; + let mut map = naga::back::spv::BindingMap::default(); + for item in vec { + map.insert(item.resource_binding, item.bind_target); + } + Ok(map) +} + +#[derive(Default, serde::Deserialize)] +#[serde(default)] +pub struct WgslInParameters { + pub parse_doc_comments: bool, +} +#[cfg(feature = "wgsl-in")] +impl From<&WgslInParameters> for naga::front::wgsl::Options { + fn from(value: &WgslInParameters) -> Self { + Self { + parse_doc_comments: value.parse_doc_comments, + } + } +} + +#[derive(Default, serde::Deserialize)] +#[serde(default)] +pub struct SpirvInParameters { + pub adjust_coordinate_space: bool, +} +#[cfg(feature = "spv-in")] +impl From<&SpirvInParameters> for naga::front::spv::Options { + fn from(value: &SpirvInParameters) -> Self { + Self { + adjust_coordinate_space: value.adjust_coordinate_space, + ..Default::default() + } + } +} + +#[derive(Default, serde::Deserialize)] +#[serde(default)] +pub struct SpirvOutParameters { + pub version: SpvOutVersion, + pub capabilities: naga::FastHashSet, + pub debug: bool, + pub adjust_coordinate_space: bool, + pub force_point_size: bool, + pub clamp_frag_depth: bool, + pub separate_entry_points: bool, + #[cfg(feature = "spv-out")] + #[serde(deserialize_with = "deserialize_binding_map")] + pub binding_map: naga::back::spv::BindingMap, +} +#[cfg(feature = "spv-out")] +impl SpirvOutParameters { + pub fn to_options<'a>( + &'a self, + bounds_check_policies: naga::proc::BoundsCheckPolicies, + debug_info: Option>, + ) -> naga::back::spv::Options<'a> { + use naga::back::spv; + let mut flags = spv::WriterFlags::LABEL_VARYINGS; + flags.set(spv::WriterFlags::DEBUG, self.debug); + flags.set( + spv::WriterFlags::ADJUST_COORDINATE_SPACE, + self.adjust_coordinate_space, + ); + flags.set(spv::WriterFlags::FORCE_POINT_SIZE, self.force_point_size); + flags.set(spv::WriterFlags::CLAMP_FRAG_DEPTH, self.clamp_frag_depth); + naga::back::spv::Options { + lang_version: (self.version.0, self.version.1), + flags, + capabilities: if self.capabilities.is_empty() { + None + } else { + Some(self.capabilities.clone()) + }, + bounds_check_policies, + binding_map: self.binding_map.clone(), + zero_initialize_workgroup_memory: spv::ZeroInitializeWorkgroupMemoryMode::Polyfill, + force_loop_bounding: true, + debug_info, + } + } +} + +#[derive(Default, serde::Deserialize)] +#[serde(default)] +pub struct WgslOutParameters { + pub explicit_types: bool, +} +#[cfg(feature = "wgsl-out")] +impl From<&WgslOutParameters> for naga::back::wgsl::WriterFlags { + fn from(value: &WgslOutParameters) -> Self { + let mut flags = Self::empty(); + flags.set(Self::EXPLICIT_TYPES, value.explicit_types); + flags + } +} + +#[derive(Default, serde::Deserialize)] +pub struct FragmentModule { + pub path: String, + pub entry_point: String, +} + +#[derive(Default, serde::Deserialize)] +#[serde(default)] +pub struct Parameters { + // -- GOD MODE -- + pub god_mode: bool, + + // -- wgsl-in options -- + #[serde(rename = "wgsl-in")] + pub wgsl_in: WgslInParameters, + + // -- spirv-in options -- + #[serde(rename = "spv-in")] + pub spv_in: SpirvInParameters, + + // -- SPIR-V options -- + pub spv: SpirvOutParameters, + + /// Defaults to [`Targets::non_wgsl_default()`] for `spv` and `glsl` snapshots, + /// and [`Targets::wgsl_default()`] for `wgsl` snapshots. + pub targets: Option, + + // -- MSL options -- + #[cfg(feature = "msl-out")] + pub msl: naga::back::msl::Options, + #[cfg(feature = "msl-out")] + #[serde(default)] + pub msl_pipeline: naga::back::msl::PipelineOptions, + + // -- GLSL options -- + #[cfg(feature = "glsl-out")] + pub glsl: naga::back::glsl::Options, + pub glsl_exclude_list: naga::FastHashSet, + #[cfg(feature = "glsl-out")] + pub glsl_multiview: Option, + + // -- HLSL options -- + #[cfg(feature = "hlsl-out")] + pub hlsl: naga::back::hlsl::Options, + + // -- WGSL options -- + pub wgsl: WgslOutParameters, + + // -- General options -- + + // Allow backends to be aware of the fragment module. + // Is the name of a WGSL file in the same directory as the test file. + pub fragment_module: Option, + + pub bounds_check_policies: naga::proc::BoundsCheckPolicies, + + #[cfg(any( + feature = "hlsl-out", + feature = "msl-out", + feature = "spv-out", + feature = "glsl-out" + ))] + pub pipeline_constants: naga::back::PipelineConstants, +} + +/// Information about a shader input file. +#[derive(Debug)] +pub struct Input { + /// The subdirectory of `tests/in` to which this input belongs, if any. + /// + /// If the subdirectory is omitted, we assume that the output goes + /// to "wgsl". + pub subdirectory: PathBuf, + + /// The input filename name, without a directory. + pub file_name: PathBuf, + + /// True if output filenames should add the output extension on top of + /// `file_name`'s existing extension, rather than replacing it. + /// + /// This is used by `convert_snapshots_glsl`, which wants to take input files + /// like `210-bevy-2d-shader.frag` and just add `.wgsl` to it, producing + /// `210-bevy-2d-shader.frag.wgsl`. + pub keep_input_extension: bool, +} + +impl Input { + /// Read an input file and its corresponding parameters file. + /// + /// Given `input`, the relative path of a shader input file, return + /// a `Source` value containing its path, code, and parameters. + /// + /// The `input` path is interpreted relative to the `BASE_DIR_IN` + /// subdirectory of the directory given by the `CARGO_MANIFEST_DIR` + /// environment variable. + pub fn new(subdirectory: &str, name: &str, extension: &str) -> Input { + Input { + subdirectory: PathBuf::from(subdirectory), + // Don't wipe out any extensions on `name`, as + // `with_extension` would do. + file_name: PathBuf::from(format!("{name}.{extension}")), + keep_input_extension: false, + } + } + + /// Return an iterator that produces an `Input` for each entry in `subdirectory`. + pub fn files_in_dir<'a>( + subdirectory: &'a str, + file_extensions: &'a [&'a str], + dir_in: &str, + ) -> impl Iterator + 'a { + let input_directory = Path::new(dir_in).join(subdirectory); + + let entries = match std::fs::read_dir(&input_directory) { + Ok(entries) => entries, + Err(err) => panic!( + "Error opening directory '{}': {}", + input_directory.display(), + err + ), + }; + + entries.filter_map(move |result| { + let entry = result.expect("error reading directory"); + if !entry.file_type().unwrap().is_file() { + return None; + } + + let file_name = PathBuf::from(entry.file_name()); + let extension = file_name + .extension() + .expect("all files in snapshot input directory should have extensions"); + + if !file_extensions.contains(&extension.to_str().unwrap()) { + return None; + } + + if let Ok(pat) = std::env::var("NAGA_SNAPSHOT") { + if !file_name.to_string_lossy().contains(&pat) { + return None; + } + } + + let input = Input::new( + subdirectory, + file_name.file_stem().unwrap().to_str().unwrap(), + extension.to_str().unwrap(), + ); + Some(input) + }) + } + + /// Return the path to the input directory. + pub fn input_directory(&self, dir_in: &str) -> PathBuf { + Path::new(dir_in).join(&self.subdirectory) + } + + /// Return the path to the output directory. + pub fn output_directory(subdirectory: &str, dir_out: &str) -> PathBuf { + Path::new(dir_out).join(subdirectory) + } + + /// Return the path to the input file. + pub fn input_path(&self, dir_in: &str) -> PathBuf { + let mut input = self.input_directory(dir_in); + input.push(&self.file_name); + input + } + + pub fn output_path(&self, subdirectory: &str, extension: &str, dir_out: &str) -> PathBuf { + let mut output = Self::output_directory(subdirectory, dir_out); + if self.keep_input_extension { + let file_name = format!( + "{}-{}.{}", + self.subdirectory.display(), + self.file_name.display(), + extension + ); + + output.push(&file_name); + } else { + let file_name = format!( + "{}-{}", + self.subdirectory.display(), + self.file_name.display() + ); + + output.push(&file_name); + output.set_extension(extension); + } + output + } + + /// Return the contents of the input file as a string. + pub fn read_source(&self, dir_in: &str, print: bool) -> String { + if print { + println!("Processing '{}'", self.file_name.display()); + } + let input_path = self.input_path(dir_in); + match fs::read_to_string(&input_path) { + Ok(source) => source, + Err(err) => { + panic!( + "Couldn't read shader input file `{}`: {}", + input_path.display(), + err + ); + } + } + } + + /// Return the contents of the input file as a vector of bytes. + pub fn read_bytes(&self, dir_in: &str, print: bool) -> Vec { + if print { + println!("Processing '{}'", self.file_name.display()); + } + let input_path = self.input_path(dir_in); + match fs::read(&input_path) { + Ok(bytes) => bytes, + Err(err) => { + panic!( + "Couldn't read shader input file `{}`: {}", + input_path.display(), + err + ); + } + } + } + + pub fn bytes(&self, dir_in: &str) -> u64 { + let input_path = self.input_path(dir_in); + std::fs::metadata(input_path).unwrap().len() + } + + /// Return this input's parameter file, parsed. + pub fn read_parameters(&self, dir_in: &str) -> Parameters { + let mut param_path = self.input_path(dir_in); + param_path.set_extension("toml"); + let mut params = match fs::read_to_string(¶m_path) { + Ok(string) => match toml::de::from_str(&string) { + Ok(params) => params, + Err(e) => panic!( + "Couldn't parse param file: {} due to: {e}", + param_path.display() + ), + }, + Err(_) => Parameters::default(), + }; + + if params.targets.is_none() { + match self + .input_path(dir_in) + .extension() + .unwrap() + .to_str() + .unwrap() + { + "wgsl" => params.targets = Some(Targets::wgsl_default()), + "spvasm" => params.targets = Some(Targets::non_wgsl_default()), + "vert" | "frag" | "comp" => params.targets = Some(Targets::non_wgsl_default()), + e => { + panic!("Unknown extension: {e}"); + } + } + } + + params + } + + /// Write `data` to a file corresponding to this input file in + /// `subdirectory`, with `extension`. + pub fn write_output_file( + &self, + subdirectory: &str, + extension: &str, + data: impl AsRef<[u8]>, + dir_out: &str, + ) { + let output_path = self.output_path(subdirectory, extension, dir_out); + fs::create_dir_all(output_path.parent().unwrap()).unwrap(); + if let Err(err) = fs::write(&output_path, data) { + panic!("Error writing {}: {}", output_path.display(), err); + } + } +} diff --git a/wgpu-core/src/validation.rs b/wgpu-core/src/validation.rs index 2c2f4b36c44..ae199f2c703 100644 --- a/wgpu-core/src/validation.rs +++ b/wgpu-core/src/validation.rs @@ -1085,6 +1085,8 @@ impl Interface { wgt::ShaderStages::VERTEX => naga::ShaderStage::Vertex, wgt::ShaderStages::FRAGMENT => naga::ShaderStage::Fragment, wgt::ShaderStages::COMPUTE => naga::ShaderStage::Compute, + wgt::ShaderStages::MESH => naga::ShaderStage::Mesh, + wgt::ShaderStages::TASK => naga::ShaderStage::Task, _ => unreachable!(), } } @@ -1229,7 +1231,7 @@ impl Interface { } // check workgroup size limits - if shader_stage == naga::ShaderStage::Compute { + if shader_stage.compute_like() { let max_workgroup_size_limits = [ self.limits.max_compute_workgroup_size_x, self.limits.max_compute_workgroup_size_y, diff --git a/wgpu-hal/src/vulkan/adapter.rs b/wgpu-hal/src/vulkan/adapter.rs index bb4e2a9d4ae..51381ce4f75 100644 --- a/wgpu-hal/src/vulkan/adapter.rs +++ b/wgpu-hal/src/vulkan/adapter.rs @@ -2099,6 +2099,9 @@ impl super::Adapter { if features.contains(wgt::Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN) { capabilities.push(spv::Capability::RayQueryPositionFetchKHR) } + if features.contains(wgt::Features::EXPERIMENTAL_MESH_SHADER) { + capabilities.push(spv::Capability::MeshShadingEXT); + } if self.private_caps.shader_integer_dot_product { // See . capabilities.extend(&[