From d7ca7d43b987c17f553db2823d833646ff1d48e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Capucho?= Date: Fri, 17 Sep 2021 17:01:55 +0100 Subject: [PATCH] [glsl-in] Improve infix handling Add support for float, vector and matrices targets. Fix prefix and postfix being inverted (one was returning the value of the other). Remove an unneeded local indirection for prefix handling. Add tests. --- src/front/glsl/ast.rs | 8 +- src/front/glsl/context.rs | 148 ++++++++++--------- src/front/glsl/parser/expressions.rs | 31 ++-- tests/in/glsl/prepostfix.frag | 18 +++ tests/out/wgsl/246-collatz-comp.wgsl | 6 +- tests/out/wgsl/constant-array-size-vert.wgsl | 14 +- tests/out/wgsl/prepostfix-frag.wgsl | 40 +++++ 7 files changed, 159 insertions(+), 106 deletions(-) create mode 100644 tests/in/glsl/prepostfix.frag create mode 100644 tests/out/wgsl/prepostfix-frag.wgsl diff --git a/src/front/glsl/ast.rs b/src/front/glsl/ast.rs index f03a89a9da..779d172238 100644 --- a/src/front/glsl/ast.rs +++ b/src/front/glsl/ast.rs @@ -123,9 +123,13 @@ pub enum HirExprKind { tgt: Handle, value: Handle, }, - IncDec { - increment: bool, + /// A prefix/postfix operator like `++` + PrePostfix { + /// The operation to be performed + op: BinaryOperator, + /// Whether this is a postfix or a prefix postfix: bool, + /// The target expression expr: Handle, }, } diff --git a/src/front/glsl/context.rs b/src/front/glsl/context.rs index ef28821875..9d35bf1ce4 100644 --- a/src/front/glsl/context.rs +++ b/src/front/glsl/context.rs @@ -750,102 +750,108 @@ impl Context { value } - HirExprKind::IncDec { - increment, - postfix, - expr, - } => { - let op = match increment { - true => BinaryOperator::Add, - false => BinaryOperator::Subtract, - }; - + HirExprKind::PrePostfix { op, postfix, expr } if ExprPos::Lhs != pos => { let pointer = self .lower_expect_inner(stmt, parser, expr, ExprPos::Lhs, body)? .0; let left = self.add_expression(Expression::Load { pointer }, meta, body); - let uint = match parser.resolve_type(self, left, meta)?.scalar_kind() { - Some(ScalarKind::Sint) => false, - Some(ScalarKind::Uint) => true, - _ => { + let make_constant_inner = |kind, width| { + let value = match kind { + ScalarKind::Sint => crate::ScalarValue::Sint(1), + ScalarKind::Uint => crate::ScalarValue::Uint(1), + ScalarKind::Float => crate::ScalarValue::Float(1.0), + ScalarKind::Bool => return None, + }; + + Some(crate::ConstantInner::Scalar { width, value }) + }; + let res = match *parser.resolve_type(self, left, meta)? { + TypeInner::Scalar { kind, width } => { + let ty = TypeInner::Scalar { kind, width }; + make_constant_inner(kind, width).map(|i| (ty, i, None, None)) + } + TypeInner::Vector { size, kind, width } => { + let ty = TypeInner::Vector { size, kind, width }; + make_constant_inner(kind, width).map(|i| (ty, i, Some(size), None)) + } + TypeInner::Matrix { + columns, + rows, + width, + } => { + let ty = TypeInner::Matrix { + columns, + rows, + width, + }; + make_constant_inner(ScalarKind::Float, width) + .map(|i| (ty, i, Some(rows), Some(columns))) + } + _ => None, + }; + let (ty_inner, inner, rows, columns) = match res { + Some(res) => res, + None => { parser.errors.push(Error { kind: ErrorKind::SemanticError( - "Increment/decrement operations must operate in integers".into(), + "Increment/decrement only works on scalar/vector/matrix".into(), ), meta, }); - true + return Ok((Some(left), meta)); } }; - let one = parser.module.constants.append( + let constant_1 = parser.module.constants.append( Constant { name: None, specialization: None, - inner: crate::ConstantInner::Scalar { - width: 4, - value: match uint { - true => crate::ScalarValue::Uint(1), - false => crate::ScalarValue::Sint(1), - }, - }, + inner, }, Default::default(), ); - let right = self.add_expression(Expression::Constant(one), meta, body); - - let value = self.add_expression(Expression::Binary { op, left, right }, meta, body); - - if postfix { - let local = self.locals.append( - LocalVariable { - name: None, - ty: parser.module.types.fetch_or_append( - Type { - name: None, - inner: TypeInner::Scalar { - kind: match uint { - true => ScalarKind::Uint, - false => ScalarKind::Sint, - }, - width: 4, - }, - }, - meta.as_span(), - ), - init: None, - }, - meta.as_span(), - ); - - let expr = self.add_expression(Expression::LocalVariable(local), meta, body); - let load = self.add_expression(Expression::Load { pointer: expr }, meta, body); - - self.emit_flush(body); - self.emit_start(); - - body.push( - Statement::Store { - pointer: expr, - value: left, - }, - meta.as_span(), - ); + let mut right = self.add_expression(Expression::Constant(constant_1), meta, body); + + // Glsl allows pre/postfixes operations on vectors and matrices, so if the + // target is either of them change the right side of the addition to be splatted + // to the same size as the target, furthermore if the target is a matrix + // use a composed matrix using the splatted value. + if let Some(size) = rows { + right = + self.add_expression(Expression::Splat { size, value: right }, meta, body); + + if let Some(cols) = columns { + let ty = parser.module.types.fetch_or_append( + Type { + name: None, + inner: ty_inner, + }, + meta.as_span(), + ); - self.emit_flush(body); - self.emit_start(); + right = self.add_expression( + Expression::Compose { + ty, + components: std::iter::repeat(right).take(cols as usize).collect(), + }, + meta, + body, + ); + } + } - body.push(Statement::Store { pointer, value }, meta.as_span()); + let value = self.add_expression(Expression::Binary { op, left, right }, meta, body); - load - } else { - self.emit_flush(body); - self.emit_start(); + self.emit_flush(body); + self.emit_start(); - body.push(Statement::Store { pointer, value }, meta.as_span()); + body.push(Statement::Store { pointer, value }, meta.as_span()); + if postfix { left + } else { + value } } _ => { diff --git a/src/front/glsl/parser/expressions.rs b/src/front/glsl/parser/expressions.rs index 175c451dce..2eb77d572c 100644 --- a/src/front/glsl/parser/expressions.rs +++ b/src/front/glsl/parser/expressions.rs @@ -257,24 +257,14 @@ impl<'source> ParsingContext<'source> { Default::default(), ) } - TokenValue::Increment => { + TokenValue::Increment | TokenValue::Decrement => { base = stmt.hir_exprs.append( HirExpr { - kind: HirExprKind::IncDec { - increment: true, - postfix: true, - expr: base, - }, - meta, - }, - Default::default(), - ) - } - TokenValue::Decrement => { - base = stmt.hir_exprs.append( - HirExpr { - kind: HirExprKind::IncDec { - increment: false, + kind: HirExprKind::PrePostfix { + op: match value { + TokenValue::Increment => crate::BinaryOperator::Add, + _ => crate::BinaryOperator::Subtract, + }, postfix: true, expr: base, }, @@ -331,11 +321,10 @@ impl<'source> ParsingContext<'source> { stmt.hir_exprs.append( HirExpr { - kind: HirExprKind::IncDec { - increment: match value { - TokenValue::Increment => true, - TokenValue::Decrement => false, - _ => unreachable!(), + kind: HirExprKind::PrePostfix { + op: match value { + TokenValue::Increment => crate::BinaryOperator::Add, + _ => crate::BinaryOperator::Subtract, }, postfix: false, expr, diff --git a/tests/in/glsl/prepostfix.frag b/tests/in/glsl/prepostfix.frag new file mode 100644 index 0000000000..1b59428927 --- /dev/null +++ b/tests/in/glsl/prepostfix.frag @@ -0,0 +1,18 @@ +#version 450 core + +void main() { + int scalar_target; + int scalar = 1; + scalar_target = scalar++; + scalar_target = --scalar; + + uvec2 vec_target; + uvec2 vec = uvec2(1); + vec_target = vec--; + vec_target = ++vec; + + mat4x3 mat_target; + mat4x3 mat = mat4x3(1); + mat_target = mat++; + mat_target = --mat; +} diff --git a/tests/out/wgsl/246-collatz-comp.wgsl b/tests/out/wgsl/246-collatz-comp.wgsl index e052f4b486..fe75c08fb9 100644 --- a/tests/out/wgsl/246-collatz-comp.wgsl +++ b/tests/out/wgsl/246-collatz-comp.wgsl @@ -10,7 +10,6 @@ var gl_GlobalInvocationID: vec3; fn collatz_iterations(n: u32) -> u32 { var n1: u32; var i: u32 = 0u; - var local: u32; n1 = n; loop { @@ -32,12 +31,11 @@ fn collatz_iterations(n: u32) -> u32 { } } let _e33: u32 = i; - local = _e33; i = (_e33 + 1u); } } - let _e38: u32 = i; - return _e38; + let _e36: u32 = i; + return _e36; } fn main1() { diff --git a/tests/out/wgsl/constant-array-size-vert.wgsl b/tests/out/wgsl/constant-array-size-vert.wgsl index 2eaec30862..ac18b25398 100644 --- a/tests/out/wgsl/constant-array-size-vert.wgsl +++ b/tests/out/wgsl/constant-array-size-vert.wgsl @@ -9,7 +9,6 @@ var global: Data; fn function() -> vec4 { var sum: vec4 = vec4(0.0, 0.0, 0.0, 0.0); var i: i32 = 0; - var local: i32; loop { let _e9: i32 = i; @@ -17,19 +16,18 @@ fn function() -> vec4 { break; } { - let _e17: vec4 = sum; - let _e18: i32 = i; - let _e20: vec4 = global.vecs[_e18]; - sum = (_e17 + _e20); + let _e15: vec4 = sum; + let _e16: i32 = i; + let _e18: vec4 = global.vecs[_e16]; + sum = (_e15 + _e18); } continuing { let _e12: i32 = i; - local = _e12; i = (_e12 + 1); } } - let _e22: vec4 = sum; - return _e22; + let _e20: vec4 = sum; + return _e20; } fn main1() { diff --git a/tests/out/wgsl/prepostfix-frag.wgsl b/tests/out/wgsl/prepostfix-frag.wgsl new file mode 100644 index 0000000000..df723ccf75 --- /dev/null +++ b/tests/out/wgsl/prepostfix-frag.wgsl @@ -0,0 +1,40 @@ +fn main1() { + var scalar_target: i32; + var scalar: i32 = 1; + var vec_target: vec2; + var vec: vec2 = vec2(1u, 1u); + var mat_target: mat4x3; + var mat: mat4x3 = mat4x3(vec3(1.0, 0.0, 0.0), vec3(0.0, 1.0, 0.0), vec3(0.0, 0.0, 1.0), vec3(0.0, 0.0, 0.0)); + + let _e3: i32 = scalar; + scalar = (_e3 + 1); + scalar_target = _e3; + let _e6: i32 = scalar; + let _e8: i32 = (_e6 - 1); + scalar = _e8; + scalar_target = _e8; + let _e14: vec2 = vec; + vec = (_e14 - vec2(1u)); + vec_target = _e14; + let _e18: vec2 = vec; + let _e21: vec2 = (_e18 + vec2(1u)); + vec = _e21; + vec_target = _e21; + let _e24: f32 = f32(1); + let _e32: mat4x3 = mat; + let _e34: vec3 = vec3(1.0); + mat = (_e32 + mat4x3(_e34, _e34, _e34, _e34)); + mat_target = _e32; + let _e37: mat4x3 = mat; + let _e39: vec3 = vec3(1.0); + let _e41: mat4x3 = (_e37 - mat4x3(_e39, _e39, _e39, _e39)); + mat = _e41; + mat_target = _e41; + return; +} + +[[stage(fragment)]] +fn main() { + main1(); + return; +}