Skip to content

Commit

Permalink
Update WGSL grammar for pointer access. (gfx-rs#1312)
Browse files Browse the repository at this point in the history
* Update WGSL grammar for pointer access.

Comes with a small test, which revealed a number of issues in the backends.

* Validate pointer arguments to functions to only have function/private/workgroup classes.

Comes with a small test. Also, "pointer-access.spv" test is temporarily disabled.
  • Loading branch information
kvark authored Sep 27, 2021
1 parent 38d74a7 commit 21324b8
Show file tree
Hide file tree
Showing 19 changed files with 312 additions and 197 deletions.
15 changes: 12 additions & 3 deletions src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,9 @@ impl<'a, W: Write> Writer<'a, W> {
TypeInner::Pointer { .. }
| TypeInner::Struct { .. }
| TypeInner::Image { .. }
| TypeInner::Sampler { .. } => unreachable!(),
| TypeInner::Sampler { .. } => {
return Err(Error::Custom(format!("Unable to write type {:?}", inner)))
}
}

Ok(())
Expand Down Expand Up @@ -1332,15 +1334,22 @@ impl<'a, W: Write> Writer<'a, W> {
// This is where we can generate intermediate constants for some expression types.
Statement::Emit(ref range) => {
for handle in range.clone() {
let expr_name = if let Some(name) = ctx.named_expressions.get(&handle) {
let info = &ctx.info[handle];
let ptr_class = info.ty.inner_with(&self.module.types).pointer_class();
let expr_name = if ptr_class.is_some() {
// GLSL can't save a pointer-valued expression in a variable,
// but we shouldn't ever need to: they should never be named expressions,
// and none of the expression types flagged by bake_ref_count can be pointer-valued.
None
} else if let Some(name) = ctx.named_expressions.get(&handle) {
// Front end provides names for all variables at the start of writing.
// But we write them to step by step. We need to recache them
// Otherwise, we could accidentally write variable name instead of full expression.
// Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords.
Some(self.namer.call_unique(name))
} else {
let min_ref_count = ctx.expressions[handle].bake_ref_count();
if min_ref_count <= ctx.info[handle].ref_count {
if min_ref_count <= info.ref_count {
Some(format!("{}{}", super::BAKE_PREFIX, handle.index()))
} else {
None
Expand Down
11 changes: 9 additions & 2 deletions src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1057,15 +1057,22 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
match *stmt {
Statement::Emit(ref range) => {
for handle in range.clone() {
let expr_name = if let Some(name) = func_ctx.named_expressions.get(&handle) {
let info = &func_ctx.info[handle];
let ptr_class = info.ty.inner_with(&module.types).pointer_class();
let expr_name = if ptr_class.is_some() {
// HLSL can't save a pointer-valued expression in a variable,
// but we shouldn't ever need to: they should never be named expressions,
// and none of the expression types flagged by bake_ref_count can be pointer-valued.
None
} else if let Some(name) = func_ctx.named_expressions.get(&handle) {
// Front end provides names for all variables at the start of writing.
// But we write them to step by step. We need to recache them
// Otherwise, we could accidentally write variable name instead of full expression.
// Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords.
Some(self.namer.call_unique(name))
} else {
let min_ref_count = func_ctx.expressions[handle].bake_ref_count();
if min_ref_count <= func_ctx.info[handle].ref_count {
if min_ref_count <= info.ref_count {
Some(format!("_expr{}", handle.index()))
} else {
None
Expand Down
13 changes: 10 additions & 3 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1355,7 +1355,7 @@ impl<W: Write> Writer<W> {
)?;
}
TypeResolution::Value(ref other) => {
log::error!("Type {:?} isn't a known local", other);
log::warn!("Type {:?} isn't a known local", other); //TEMP!
return Err(Error::FeatureNotImplemented("weird local type".to_string()));
}
}
Expand Down Expand Up @@ -1383,7 +1383,14 @@ impl<W: Write> Writer<W> {
match *statement {
crate::Statement::Emit(ref range) => {
for handle in range.clone() {
let expr_name = if let Some(name) =
let info = &context.expression.info[handle];
let ptr_class = info
.ty
.inner_with(&context.expression.module.types)
.pointer_class();
let expr_name = if ptr_class.is_some() {
None // don't bake pointer expressions (just yet)
} else if let Some(name) =
context.expression.function.named_expressions.get(&handle)
{
// Front end provides names for all variables at the start of writing.
Expand All @@ -1394,7 +1401,7 @@ impl<W: Write> Writer<W> {
} else {
let min_ref_count =
context.expression.function.expressions[handle].bake_ref_count();
if min_ref_count <= context.expression.info[handle].ref_count {
if min_ref_count <= info.ref_count {
Some(format!("{}{}", back::BAKE_PREFIX, handle.index()))
} else {
None
Expand Down
20 changes: 11 additions & 9 deletions src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -525,10 +525,13 @@ impl<W: Write> Writer<W> {
"storage_",
"",
storage_format_str(format),
if access.contains(crate::StorageAccess::STORE) {
",write"
if access.contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
{
",read_write"
} else if access.contains(crate::StorageAccess::LOAD) {
",read"
} else {
""
",write"
},
),
};
Expand Down Expand Up @@ -639,6 +642,7 @@ impl<W: Write> Writer<W> {
inner
)));
}
write!(self.out, ">")?;
}
_ => {
return Err(Error::Unimplemented(format!(
Expand Down Expand Up @@ -666,6 +670,7 @@ impl<W: Write> Writer<W> {
match *stmt {
Statement::Emit(ref range) => {
for handle in range.clone() {
let info = &func_ctx.info[handle];
let expr_name = if let Some(name) = func_ctx.named_expressions.get(&handle) {
// Front end provides names for all variables at the start of writing.
// But we write them to step by step. We need to recache them
Expand All @@ -682,8 +687,7 @@ impl<W: Write> Writer<W> {
| Expression::ImageSample { .. } => true,
_ => false,
};
if min_ref_count <= func_ctx.info[handle].ref_count || required_baking_expr
{
if min_ref_count <= info.ref_count || required_baking_expr {
// If expression contains unsupported builtin we should skip it
if let Expression::Load { pointer } = func_ctx.expressions[handle] {
if let Expression::AccessIndex { base, index } =
Expand Down Expand Up @@ -809,8 +813,8 @@ impl<W: Write> Writer<W> {
}
let func_name = &self.names[&NameKey::Function(function)];
write!(self.out, "{}(", func_name)?;
for (index, argument) in arguments.iter().enumerate() {
self.write_expr(module, *argument, func_ctx)?;
for (index, &argument) in arguments.iter().enumerate() {
self.write_expr(module, argument, func_ctx)?;
// Only write a comma if isn't the last element
if index != arguments.len().saturating_sub(1) {
// The leading space is for readability only
Expand Down Expand Up @@ -1199,14 +1203,12 @@ impl<W: Write> Writer<W> {
self.write_expr(module, right, func_ctx)?;
write!(self.out, ")")?;
}
// TODO: copy-paste from glsl-out
Expression::Access { base, index } => {
self.write_expr_with_indirection(module, base, func_ctx, indirection)?;
write!(self.out, "[")?;
self.write_expr(module, index, func_ctx)?;
write!(self.out, "]")?
}
// TODO: copy-paste from glsl-out
Expression::AccessIndex { base, index } => {
let base_ty_res = &func_ctx.info[base].ty;
let mut resolved = base_ty_res.inner_with(&module.types);
Expand Down
24 changes: 12 additions & 12 deletions src/front/wgsl/lexer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -558,24 +558,24 @@ impl<'a> Lexer<'a> {
Ok(pair)
}

// TODO relocate storage texture specifics
pub(super) fn next_storage_access(&mut self) -> Result<crate::StorageAccess, Error<'a>> {
let (ident, span) = self.next_ident_with_span()?;
match ident {
"read" => Ok(crate::StorageAccess::LOAD),
"write" => Ok(crate::StorageAccess::STORE),
"read_write" => Ok(crate::StorageAccess::LOAD | crate::StorageAccess::STORE),
_ => Err(Error::UnknownAccess(span)),
}
}

pub(super) fn next_format_generic(
&mut self,
) -> Result<(crate::StorageFormat, crate::StorageAccess), Error<'a>> {
self.expect(Token::Paren('<'))?;
let (ident, ident_span) = self.next_ident_with_span()?;
let format = conv::map_storage_format(ident, ident_span)?;
let access = if self.skip(Token::Separator(',')) {
let (raw, span) = self.next_ident_with_span()?;
match raw {
"read" => crate::StorageAccess::LOAD,
"write" => crate::StorageAccess::STORE,
"read_write" => crate::StorageAccess::all(),
_ => return Err(Error::UnknownAccess(span)),
}
} else {
crate::StorageAccess::LOAD
};
self.expect(Token::Separator(','))?;
let access = self.next_storage_access()?;
self.expect(Token::Paren('>'))?;
Ok((format, access))
}
Expand Down
17 changes: 9 additions & 8 deletions src/front/wgsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2587,13 +2587,7 @@ impl Parser {
class = Some(match class_str {
"storage" => {
let access = if lexer.skip(Token::Separator(',')) {
let (ident, span) = lexer.next_ident_with_span()?;
match ident {
"read" => crate::StorageAccess::LOAD,
"write" => crate::StorageAccess::STORE,
"read_write" => crate::StorageAccess::all(),
_ => return Err(Error::UnknownAccess(span)),
}
lexer.next_storage_access()?
} else {
// defaulting to `read`
crate::StorageAccess::LOAD
Expand Down Expand Up @@ -2836,9 +2830,16 @@ impl Parser {
"ptr" => {
lexer.expect_generic_paren('<')?;
let (ident, span) = lexer.next_ident_with_span()?;
let class = conv::map_storage_class(ident, span)?;
let mut class = conv::map_storage_class(ident, span)?;
lexer.expect(Token::Separator(','))?;
let (base, _access) = self.parse_type_decl(lexer, None, type_arena, const_arena)?;
if let crate::StorageClass::Storage { ref mut access } = class {
*access = if lexer.skip(Token::Separator(',')) {
lexer.next_storage_access()?
} else {
crate::StorageAccess::LOAD
};
}
lexer.expect_generic_paren('>')?;
crate::TypeInner::Pointer { base, class }
}
Expand Down
4 changes: 2 additions & 2 deletions src/front/wgsl/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ fn parse_types() {
parse_str("var t: texture_cube_array<i32>;").unwrap();
parse_str("var t: texture_multisampled_2d<u32>;").unwrap();
parse_str("var t: texture_storage_1d<rgba8uint,write>;").unwrap();
parse_str("var t: texture_storage_3d<r32float>;").unwrap();
parse_str("var t: texture_storage_3d<r32float,read>;").unwrap();
}

#[test]
Expand Down Expand Up @@ -305,7 +305,7 @@ fn parse_texture_load() {
.unwrap();
parse_str(
"
var t: texture_storage_1d_array<r32float>;
var t: texture_storage_1d_array<r32float,read>;
fn foo() {
let r: vec4<f32> = textureLoad(t, 10, 2);
}
Expand Down
19 changes: 19 additions & 0 deletions src/valid/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ pub enum FunctionError {
},
#[error("Argument '{name}' at index {index} has a type that can't be passed into functions.")]
InvalidArgumentType { index: usize, name: String },
#[error("Argument '{name}' at index {index} is a pointer of class {class:?}, which can't be passed into functions.")]
InvalidArgumentPointerClass {
index: usize,
name: String,
class: crate::StorageClass,
},
#[error("There are instructions after `return`/`break`/`continue`")]
InstructionsAfterReturn,
#[error("The `break` is used outside of a `loop` or `switch` context")]
Expand Down Expand Up @@ -696,6 +702,19 @@ impl super::Validator {
name: argument.name.clone().unwrap_or_default(),
});
}
match module.types[argument.ty].inner.pointer_class() {
Some(crate::StorageClass::Private)
| Some(crate::StorageClass::Function)
| Some(crate::StorageClass::WorkGroup)
| None => {}
Some(other) => {
return Err(FunctionError::InvalidArgumentPointerClass {
index,
name: argument.name.clone().unwrap_or_default(),
class: other,
})
}
}
}

self.valid_expression_set.clear();
Expand Down
8 changes: 8 additions & 0 deletions tests/in/access.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ struct Bar {
[[group(0), binding(0)]]
var<storage,read_write> bar: Bar;

fn read_from_private(foo: ptr<function, f32>) -> f32 {
return *foo;
}

[[stage(vertex)]]
fn foo([[builtin(vertex_index)]] vi: u32) -> [[builtin(position)]] vec4<f32> {
var foo: f32 = 0.0;
Expand All @@ -25,6 +29,10 @@ fn foo([[builtin(vertex_index)]] vi: u32) -> [[builtin(position)]] vec4<f32> {
let b = bar.matrix[index].x;
let a = bar.data[arrayLength(&bar.data) - 2u];

// test pointer types
let pointer: ptr<storage, i32, read_write> = &bar.data[0];
let foo_value = read_from_private(&foo);

// test storage stores
bar.matrix[1].z = 1.0;
bar.matrix = mat4x4<f32>(vec4<f32>(0.0), vec4<f32>(1.0), vec4<f32>(2.0), vec4<f32>(3.0));
Expand Down
2 changes: 1 addition & 1 deletion tests/in/image.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ var image_multisampled_src: texture_multisampled_2d<u32>;
[[group(0), binding(4)]]
var image_depth_multisampled_src: texture_depth_multisampled_2d;
[[group(0), binding(1)]]
var image_storage_src: texture_storage_2d<rgba8uint>;
var image_storage_src: texture_storage_2d<rgba8uint, read>;
[[group(0), binding(5)]]
var image_array_src: texture_2d_array<u32>;
[[group(0), binding(6)]]
Expand Down
5 changes: 5 additions & 0 deletions tests/out/glsl/access.atomics.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ buffer Bar_block_0Cs {
} _group_0_binding_0;


float read_from_private(inout float foo2) {
float _e2 = foo2;
return _e2;
}

void main() {
int tmp = 0;
int value = _group_0_binding_0.atom;
Expand Down
6 changes: 6 additions & 0 deletions tests/out/glsl/access.foo.Vertex.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ buffer Bar_block_0Vs {
} _group_0_binding_0;


float read_from_private(inout float foo2) {
float _e2 = foo2;
return _e2;
}

void main() {
uint vi = uint(gl_VertexID);
float foo1 = 0.0;
Expand All @@ -21,6 +26,7 @@ void main() {
uvec2 arr[2] = _group_0_binding_0.arr;
float b = _group_0_binding_0.matrix[3][0];
int a = _group_0_binding_0.data[(uint(_group_0_binding_0.data.length()) - 2u)];
float _e25 = read_from_private(foo1);
_group_0_binding_0.matrix[1][2] = 1.0;
_group_0_binding_0.matrix = mat4x4(vec4(0.0), vec4(1.0), vec4(2.0), vec4(3.0));
_group_0_binding_0.arr = uvec2[2](uvec2(0u), uvec2(1u));
Expand Down
7 changes: 7 additions & 0 deletions tests/out/hlsl/access.hlsl
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@

RWByteAddressBuffer bar : register(u0);

float read_from_private(inout float foo2)
{
float _expr2 = foo2;
return _expr2;
}

uint NagaBufferLengthRW(RWByteAddressBuffer buffer)
{
uint ret;
Expand All @@ -19,6 +25,7 @@ float4 foo(uint vi : SV_VertexID) : SV_Position
uint2 arr[2] = {asuint(bar.Load2(72+0)), asuint(bar.Load2(72+8))};
float b = asfloat(bar.Load(0+48+0));
int a = asint(bar.Load((((NagaBufferLengthRW(bar) - 88) / 4) - 2u)*4+88));
const float _e25 = read_from_private(foo1);
bar.Store(8+16+0, asuint(1.0));
{
float4x4 _value2 = float4x4(float4(0.0.xxxx), float4(1.0.xxxx), float4(2.0.xxxx), float4(3.0.xxxx));
Expand Down
Loading

0 comments on commit 21324b8

Please sign in to comment.