diff --git a/src/front/glsl/ast.rs b/src/front/glsl/ast.rs index f03a89a9da..779d172238 100644 --- a/src/front/glsl/ast.rs +++ b/src/front/glsl/ast.rs @@ -123,9 +123,13 @@ pub enum HirExprKind { tgt: Handle, value: Handle, }, - IncDec { - increment: bool, + /// A prefix/postfix operator like `++` + PrePostfix { + /// The operation to be performed + op: BinaryOperator, + /// Whether this is a postfix or a prefix postfix: bool, + /// The target expression expr: Handle, }, } diff --git a/src/front/glsl/context.rs b/src/front/glsl/context.rs index ef28821875..9d35bf1ce4 100644 --- a/src/front/glsl/context.rs +++ b/src/front/glsl/context.rs @@ -750,102 +750,108 @@ impl Context { value } - HirExprKind::IncDec { - increment, - postfix, - expr, - } => { - let op = match increment { - true => BinaryOperator::Add, - false => BinaryOperator::Subtract, - }; - + HirExprKind::PrePostfix { op, postfix, expr } if ExprPos::Lhs != pos => { let pointer = self .lower_expect_inner(stmt, parser, expr, ExprPos::Lhs, body)? .0; let left = self.add_expression(Expression::Load { pointer }, meta, body); - let uint = match parser.resolve_type(self, left, meta)?.scalar_kind() { - Some(ScalarKind::Sint) => false, - Some(ScalarKind::Uint) => true, - _ => { + let make_constant_inner = |kind, width| { + let value = match kind { + ScalarKind::Sint => crate::ScalarValue::Sint(1), + ScalarKind::Uint => crate::ScalarValue::Uint(1), + ScalarKind::Float => crate::ScalarValue::Float(1.0), + ScalarKind::Bool => return None, + }; + + Some(crate::ConstantInner::Scalar { width, value }) + }; + let res = match *parser.resolve_type(self, left, meta)? { + TypeInner::Scalar { kind, width } => { + let ty = TypeInner::Scalar { kind, width }; + make_constant_inner(kind, width).map(|i| (ty, i, None, None)) + } + TypeInner::Vector { size, kind, width } => { + let ty = TypeInner::Vector { size, kind, width }; + make_constant_inner(kind, width).map(|i| (ty, i, Some(size), None)) + } + TypeInner::Matrix { + columns, + rows, + width, + } => { + let ty = TypeInner::Matrix { + columns, + rows, + width, + }; + make_constant_inner(ScalarKind::Float, width) + .map(|i| (ty, i, Some(rows), Some(columns))) + } + _ => None, + }; + let (ty_inner, inner, rows, columns) = match res { + Some(res) => res, + None => { parser.errors.push(Error { kind: ErrorKind::SemanticError( - "Increment/decrement operations must operate in integers".into(), + "Increment/decrement only works on scalar/vector/matrix".into(), ), meta, }); - true + return Ok((Some(left), meta)); } }; - let one = parser.module.constants.append( + let constant_1 = parser.module.constants.append( Constant { name: None, specialization: None, - inner: crate::ConstantInner::Scalar { - width: 4, - value: match uint { - true => crate::ScalarValue::Uint(1), - false => crate::ScalarValue::Sint(1), - }, - }, + inner, }, Default::default(), ); - let right = self.add_expression(Expression::Constant(one), meta, body); - - let value = self.add_expression(Expression::Binary { op, left, right }, meta, body); - - if postfix { - let local = self.locals.append( - LocalVariable { - name: None, - ty: parser.module.types.fetch_or_append( - Type { - name: None, - inner: TypeInner::Scalar { - kind: match uint { - true => ScalarKind::Uint, - false => ScalarKind::Sint, - }, - width: 4, - }, - }, - meta.as_span(), - ), - init: None, - }, - meta.as_span(), - ); - - let expr = self.add_expression(Expression::LocalVariable(local), meta, body); - let load = self.add_expression(Expression::Load { pointer: expr }, meta, body); - - self.emit_flush(body); - self.emit_start(); - - body.push( - Statement::Store { - pointer: expr, - value: left, - }, - meta.as_span(), - ); + let mut right = self.add_expression(Expression::Constant(constant_1), meta, body); + + // Glsl allows pre/postfixes operations on vectors and matrices, so if the + // target is either of them change the right side of the addition to be splatted + // to the same size as the target, furthermore if the target is a matrix + // use a composed matrix using the splatted value. + if let Some(size) = rows { + right = + self.add_expression(Expression::Splat { size, value: right }, meta, body); + + if let Some(cols) = columns { + let ty = parser.module.types.fetch_or_append( + Type { + name: None, + inner: ty_inner, + }, + meta.as_span(), + ); - self.emit_flush(body); - self.emit_start(); + right = self.add_expression( + Expression::Compose { + ty, + components: std::iter::repeat(right).take(cols as usize).collect(), + }, + meta, + body, + ); + } + } - body.push(Statement::Store { pointer, value }, meta.as_span()); + let value = self.add_expression(Expression::Binary { op, left, right }, meta, body); - load - } else { - self.emit_flush(body); - self.emit_start(); + self.emit_flush(body); + self.emit_start(); - body.push(Statement::Store { pointer, value }, meta.as_span()); + body.push(Statement::Store { pointer, value }, meta.as_span()); + if postfix { left + } else { + value } } _ => { diff --git a/src/front/glsl/parser/expressions.rs b/src/front/glsl/parser/expressions.rs index 175c451dce..2eb77d572c 100644 --- a/src/front/glsl/parser/expressions.rs +++ b/src/front/glsl/parser/expressions.rs @@ -257,24 +257,14 @@ impl<'source> ParsingContext<'source> { Default::default(), ) } - TokenValue::Increment => { + TokenValue::Increment | TokenValue::Decrement => { base = stmt.hir_exprs.append( HirExpr { - kind: HirExprKind::IncDec { - increment: true, - postfix: true, - expr: base, - }, - meta, - }, - Default::default(), - ) - } - TokenValue::Decrement => { - base = stmt.hir_exprs.append( - HirExpr { - kind: HirExprKind::IncDec { - increment: false, + kind: HirExprKind::PrePostfix { + op: match value { + TokenValue::Increment => crate::BinaryOperator::Add, + _ => crate::BinaryOperator::Subtract, + }, postfix: true, expr: base, }, @@ -331,11 +321,10 @@ impl<'source> ParsingContext<'source> { stmt.hir_exprs.append( HirExpr { - kind: HirExprKind::IncDec { - increment: match value { - TokenValue::Increment => true, - TokenValue::Decrement => false, - _ => unreachable!(), + kind: HirExprKind::PrePostfix { + op: match value { + TokenValue::Increment => crate::BinaryOperator::Add, + _ => crate::BinaryOperator::Subtract, }, postfix: false, expr, diff --git a/tests/in/glsl/prepostfix.frag b/tests/in/glsl/prepostfix.frag new file mode 100644 index 0000000000..1b59428927 --- /dev/null +++ b/tests/in/glsl/prepostfix.frag @@ -0,0 +1,18 @@ +#version 450 core + +void main() { + int scalar_target; + int scalar = 1; + scalar_target = scalar++; + scalar_target = --scalar; + + uvec2 vec_target; + uvec2 vec = uvec2(1); + vec_target = vec--; + vec_target = ++vec; + + mat4x3 mat_target; + mat4x3 mat = mat4x3(1); + mat_target = mat++; + mat_target = --mat; +} diff --git a/tests/out/wgsl/246-collatz-comp.wgsl b/tests/out/wgsl/246-collatz-comp.wgsl index e052f4b486..fe75c08fb9 100644 --- a/tests/out/wgsl/246-collatz-comp.wgsl +++ b/tests/out/wgsl/246-collatz-comp.wgsl @@ -10,7 +10,6 @@ var gl_GlobalInvocationID: vec3; fn collatz_iterations(n: u32) -> u32 { var n1: u32; var i: u32 = 0u; - var local: u32; n1 = n; loop { @@ -32,12 +31,11 @@ fn collatz_iterations(n: u32) -> u32 { } } let _e33: u32 = i; - local = _e33; i = (_e33 + 1u); } } - let _e38: u32 = i; - return _e38; + let _e36: u32 = i; + return _e36; } fn main1() { diff --git a/tests/out/wgsl/constant-array-size-vert.wgsl b/tests/out/wgsl/constant-array-size-vert.wgsl index 2eaec30862..ac18b25398 100644 --- a/tests/out/wgsl/constant-array-size-vert.wgsl +++ b/tests/out/wgsl/constant-array-size-vert.wgsl @@ -9,7 +9,6 @@ var global: Data; fn function() -> vec4 { var sum: vec4 = vec4(0.0, 0.0, 0.0, 0.0); var i: i32 = 0; - var local: i32; loop { let _e9: i32 = i; @@ -17,19 +16,18 @@ fn function() -> vec4 { break; } { - let _e17: vec4 = sum; - let _e18: i32 = i; - let _e20: vec4 = global.vecs[_e18]; - sum = (_e17 + _e20); + let _e15: vec4 = sum; + let _e16: i32 = i; + let _e18: vec4 = global.vecs[_e16]; + sum = (_e15 + _e18); } continuing { let _e12: i32 = i; - local = _e12; i = (_e12 + 1); } } - let _e22: vec4 = sum; - return _e22; + let _e20: vec4 = sum; + return _e20; } fn main1() { diff --git a/tests/out/wgsl/prepostfix-frag.wgsl b/tests/out/wgsl/prepostfix-frag.wgsl new file mode 100644 index 0000000000..df723ccf75 --- /dev/null +++ b/tests/out/wgsl/prepostfix-frag.wgsl @@ -0,0 +1,40 @@ +fn main1() { + var scalar_target: i32; + var scalar: i32 = 1; + var vec_target: vec2; + var vec: vec2 = vec2(1u, 1u); + var mat_target: mat4x3; + var mat: mat4x3 = mat4x3(vec3(1.0, 0.0, 0.0), vec3(0.0, 1.0, 0.0), vec3(0.0, 0.0, 1.0), vec3(0.0, 0.0, 0.0)); + + let _e3: i32 = scalar; + scalar = (_e3 + 1); + scalar_target = _e3; + let _e6: i32 = scalar; + let _e8: i32 = (_e6 - 1); + scalar = _e8; + scalar_target = _e8; + let _e14: vec2 = vec; + vec = (_e14 - vec2(1u)); + vec_target = _e14; + let _e18: vec2 = vec; + let _e21: vec2 = (_e18 + vec2(1u)); + vec = _e21; + vec_target = _e21; + let _e24: f32 = f32(1); + let _e32: mat4x3 = mat; + let _e34: vec3 = vec3(1.0); + mat = (_e32 + mat4x3(_e34, _e34, _e34, _e34)); + mat_target = _e32; + let _e37: mat4x3 = mat; + let _e39: vec3 = vec3(1.0); + let _e41: mat4x3 = (_e37 - mat4x3(_e39, _e39, _e39, _e39)); + mat = _e41; + mat_target = _e41; + return; +} + +[[stage(fragment)]] +fn main() { + main1(); + return; +}