Skip to content

Commit

Permalink
Support num_workgroups builtin
Browse files Browse the repository at this point in the history
  • Loading branch information
kvark committed Aug 18, 2021
1 parent 73be8c7 commit 79d899f
Show file tree
Hide file tree
Showing 20 changed files with 206 additions and 111 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## TBD
- API:
- atomic types and functions
- `num_workgroups` built-in
- WGSL `select()` order of true/false is swapped

## v0.5 (2021-06-18)
Expand Down
1 change: 1 addition & 0 deletions src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2643,6 +2643,7 @@ fn glsl_built_in(built_in: crate::BuiltIn, output: bool) -> &'static str {
Bi::LocalInvocationIndex => "gl_LocalInvocationIndex",
Bi::WorkGroupId => "gl_WorkGroupID",
Bi::WorkGroupSize => "gl_WorkGroupSize",
Bi::NumWorkGroups => "gl_NumWorkGroups",
}
}

Expand Down
8 changes: 7 additions & 1 deletion src/back/hlsl/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,13 @@ impl crate::BuiltIn {
Self::LocalInvocationId => "SV_GroupThreadID",
Self::LocalInvocationIndex => "SV_GroupIndex",
Self::WorkGroupId => "SV_GroupID",
_ => return Err(Error::Unimplemented(format!("builtin {:?}", self))),
// The specific semantic we use here doesn't matter, because references
// to this field will get replaced with references to `SPECIAL_CBUF_VAR`
// in `Writer::write_expr`.
Self::NumWorkGroups => "SV_GroupID",
Self::BaseInstance | Self::BaseVertex | Self::WorkGroupSize => {
return Err(Error::Unimplemented(format!("builtin {:?}", self)))
}
})
}
}
Expand Down
20 changes: 19 additions & 1 deletion src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ const SPECIAL_CBUF_TYPE: &str = "NagaConstants";
const SPECIAL_CBUF_VAR: &str = "_NagaConstants";
const SPECIAL_BASE_VERTEX: &str = "base_vertex";
const SPECIAL_BASE_INSTANCE: &str = "base_instance";
const SPECIAL_OTHER: &str = "other";

/// Structure contains information required for generating
/// wrapped structure of all entry points arguments
Expand Down Expand Up @@ -105,6 +106,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
writeln!(self.out, "struct {} {{", SPECIAL_CBUF_TYPE)?;
writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_BASE_VERTEX)?;
writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_BASE_INSTANCE)?;
writeln!(self.out, "{}uint {};", back::INDENT, SPECIAL_OTHER)?;
writeln!(self.out, "}};")?;
write!(
self.out,
Expand Down Expand Up @@ -1234,10 +1236,26 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
write!(
self.out,
"({}.{} + ",
SPECIAL_CBUF_VAR, SPECIAL_BASE_INSTANCE
SPECIAL_CBUF_VAR, SPECIAL_BASE_INSTANCE,
)?;
")"
}
Some(crate::BuiltIn::NumWorkGroups) => {
//Note: despite their names (`BASE_VERTEX` and `BASE_INSTANCE`),
// in compute shaders the special constants contain the number
// of workgroups, which we are using here.
write!(
self.out,
"uint3({}.{}, {}.{}, {}.{})",
SPECIAL_CBUF_VAR,
SPECIAL_BASE_VERTEX,
SPECIAL_CBUF_VAR,
SPECIAL_BASE_INSTANCE,
SPECIAL_CBUF_VAR,
SPECIAL_OTHER,
)?;
return Ok(());
}
_ => "",
};

Expand Down
3 changes: 2 additions & 1 deletion src/back/msl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,8 @@ impl ResolvedBinding {
Bi::LocalInvocationIndex => "thread_index_in_threadgroup",
Bi::WorkGroupId => "threadgroup_position_in_grid",
Bi::WorkGroupSize => "dispatch_threads_per_threadgroup",
_ => return Err(Error::UnsupportedBuiltIn(built_in)),
Bi::NumWorkGroups => "threadgroups_per_grid",
Bi::CullDistance => return Err(Error::UnsupportedBuiltIn(built_in)),
};
write!(out, "{}", name)?;
}
Expand Down
1 change: 1 addition & 0 deletions src/back/spv/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1043,6 +1043,7 @@ impl Writer {
Bi::LocalInvocationIndex => BuiltIn::LocalInvocationIndex,
Bi::WorkGroupId => BuiltIn::WorkgroupId,
Bi::WorkGroupSize => BuiltIn::WorkgroupSize,
Bi::NumWorkGroups => BuiltIn::NumWorkgroups,
};

self.decorate(id, Decoration::BuiltIn, &[built_in as u32]);
Expand Down
1 change: 1 addition & 0 deletions src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1596,6 +1596,7 @@ fn builtin_str(built_in: crate::BuiltIn) -> Option<&'static str> {
Bi::GlobalInvocationId => Some("global_invocation_id"),
Bi::WorkGroupId => Some("workgroup_id"),
Bi::WorkGroupSize => Some("workgroup_size"),
Bi::NumWorkGroups => Some("num_workgroups"),
Bi::SampleIndex => Some("sample_index"),
Bi::SampleMask => Some("sample_mask"),
Bi::PrimitiveIndex => Some("primitive_index"),
Expand Down
10 changes: 10 additions & 0 deletions src/front/glsl/variables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,16 @@ impl Parser {
false,
StorageQualifier::Input,
),
"gl_NumWorkGroups" => add_builtin(
TypeInner::Vector {
size: VectorSize::Tri,
kind: ScalarKind::Uint,
width: 4,
},
BuiltIn::NumWorkGroups,
false,
StorageQualifier::Input,
),
"gl_FrontFacing" => add_builtin(
TypeInner::Scalar {
kind: ScalarKind::Bool,
Expand Down
1 change: 1 addition & 0 deletions src/front/spv/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ pub(super) fn map_builtin(word: spirv::Word) -> Result<crate::BuiltIn, Error> {
Some(Bi::LocalInvocationIndex) => crate::BuiltIn::LocalInvocationIndex,
Some(Bi::WorkgroupId) => crate::BuiltIn::WorkGroupId,
Some(Bi::WorkgroupSize) => crate::BuiltIn::WorkGroupSize,
Some(Bi::NumWorkgroups) => crate::BuiltIn::NumWorkGroups,
_ => return Err(Error::UnsupportedBuiltIn(word)),
})
}
Expand Down
1 change: 1 addition & 0 deletions src/front/wgsl/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ pub fn map_built_in(word: &str, span: Span) -> Result<crate::BuiltIn, Error<'_>>
"local_invocation_index" => crate::BuiltIn::LocalInvocationIndex,
"workgroup_id" => crate::BuiltIn::WorkGroupId,
"workgroup_size" => crate::BuiltIn::WorkGroupSize,
"num_workgroups" => crate::BuiltIn::NumWorkGroups,
_ => return Err(Error::UnknownBuiltin(span)),
})
}
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ pub enum BuiltIn {
LocalInvocationIndex,
WorkGroupId,
WorkGroupSize,
NumWorkGroups,
}

/// Number of bytes per scalar.
Expand Down
3 changes: 2 additions & 1 deletion src/valid/analyzer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,8 @@ impl FunctionInfo {
crate::BuiltIn::FrontFacing
// per-work-group built-ins are uniform
| crate::BuiltIn::WorkGroupId
| crate::BuiltIn::WorkGroupSize => true,
| crate::BuiltIn::WorkGroupSize
| crate::BuiltIn::NumWorkGroups => true,
_ => false,
},
// only flat inputs are uniform
Expand Down
3 changes: 2 additions & 1 deletion src/valid/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@ impl VaryingContext<'_> {
Bi::GlobalInvocationId
| Bi::LocalInvocationId
| Bi::WorkGroupId
| Bi::WorkGroupSize => (
| Bi::WorkGroupSize
| Bi::NumWorkGroups => (
self.stage == St::Compute && !self.output,
*ty_inner
== Ti::Vector {
Expand Down
7 changes: 7 additions & 0 deletions tests/in/interface.param.ron
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,11 @@
spv_version: (1, 0),
spv_capabilities: [ Shader, SampleRateShading ],
spv_adjust_coordinate_space: false,
hlsl_custom: true,
hlsl: (
shader_model: V5_1,
binding_map: {},
fake_missing_bindings: false,
special_constants_binding: Some((space: 1, register: 0)),
),
)
6 changes: 4 additions & 2 deletions tests/in/interface.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,15 @@ fn fragment(
return FragmentOutput(in.varying, mask, color);
}

var<workgroup> output: array<u32, 1>;

[[stage(compute), workgroup_size(1)]]
fn compute(
[[builtin(global_invocation_id)]] global_id: vec3<u32>,
[[builtin(local_invocation_id)]] local_id: vec3<u32>,
[[builtin(local_invocation_index)]] local_index: u32,
[[builtin(workgroup_id)]] wg_id: vec3<u32>,
//TODO: https://github.com/gpuweb/gpuweb/issues/1590
//[[builtin(workgroup_size)]] wg_size: vec3<u32>,
[[builtin(num_workgroups)]] num_wgs: vec3<u32>,
) {
output[0] = global_id.x + local_id.x + local_index + wg_id.x + num_wgs.x;
}
12 changes: 11 additions & 1 deletion tests/out/hlsl/interface.hlsl
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
struct NagaConstants {
int base_vertex;
int base_instance;
uint other;
};
ConstantBuffer<NagaConstants> _NagaConstants: register(b0, space1);

struct VertexOutput {
float4 position : SV_Position;
Expand All @@ -10,6 +16,8 @@ struct FragmentOutput {
float color : SV_Target0;
};

groupshared uint output[1];

struct VertexInput_vertex {
uint color1 : LOC10;
uint instance_index1 : SV_InstanceID;
Expand All @@ -28,11 +36,12 @@ struct ComputeInput_compute {
uint3 local_id1 : SV_GroupThreadID;
uint local_index1 : SV_GroupIndex;
uint3 wg_id1 : SV_GroupID;
uint3 num_wgs1 : SV_GroupID;
};

VertexOutput vertex(VertexInput_vertex vertexinput_vertex)
{
uint tmp = ((vertexinput_vertex.vertex_index1 + vertexinput_vertex.instance_index1) + vertexinput_vertex.color1);
uint tmp = (((_NagaConstants.base_vertex + vertexinput_vertex.vertex_index1) + (_NagaConstants.base_instance + vertexinput_vertex.instance_index1)) + vertexinput_vertex.color1);
const VertexOutput vertexoutput1 = { float4(1.0.xxxx), float(tmp) };
return vertexoutput1;
}
Expand All @@ -48,5 +57,6 @@ FragmentOutput fragment(FragmentInput_fragment fragmentinput_fragment)
[numthreads(1, 1, 1)]
void compute(ComputeInput_compute computeinput_compute)
{
output[0] = ((((computeinput_compute.global_id1.x + computeinput_compute.local_id1.x) + computeinput_compute.local_index1) + computeinput_compute.wg_id1.x) + uint3(_NagaConstants.base_vertex, _NagaConstants.base_instance, _NagaConstants.other).x);
return;
}
1 change: 1 addition & 0 deletions tests/out/hlsl/skybox.hlsl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
struct NagaConstants {
int base_vertex;
int base_instance;
uint other;
};
ConstantBuffer<NagaConstants> _NagaConstants: register(b1);

Expand Down
6 changes: 6 additions & 0 deletions tests/out/msl/interface.msl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ struct FragmentOutput {
metal::uint sample_mask;
float color;
};
struct type4 {
metal::uint inner[1];
};

struct vertex1Input {
metal::uint color [[attribute(10)]];
Expand Down Expand Up @@ -61,6 +64,9 @@ kernel void compute1(
, metal::uint3 local_id [[thread_position_in_threadgroup]]
, metal::uint local_index [[thread_index_in_threadgroup]]
, metal::uint3 wg_id [[threadgroup_position_in_grid]]
, metal::uint3 num_wgs [[threadgroups_per_grid]]
, threadgroup type4& output
) {
output.inner[0] = (((global_id.x + local_id.x) + local_index) + wg_id.x) + num_wgs.x;
return;
}
Loading

0 comments on commit 79d899f

Please sign in to comment.