Skip to content

Commit

Permalink
[glsl-in] Improve infix handling
Browse files Browse the repository at this point in the history
Add support for float, vector and matrices targets.

Fix prefix and postfix being inverted (one was returning the value
of the other).

Remove an unneeded local indirection for prefix handling.

Add tests.
  • Loading branch information
JCapucho authored and kvark committed Sep 17, 2021
1 parent 1cb4447 commit d7ca7d4
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 106 deletions.
8 changes: 6 additions & 2 deletions src/front/glsl/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,13 @@ pub enum HirExprKind {
tgt: Handle<HirExpr>,
value: Handle<HirExpr>,
},
IncDec {
increment: bool,
/// A prefix/postfix operator like `++`
PrePostfix {
/// The operation to be performed
op: BinaryOperator,
/// Whether this is a postfix or a prefix
postfix: bool,
/// The target expression
expr: Handle<HirExpr>,
},
}
Expand Down
148 changes: 77 additions & 71 deletions src/front/glsl/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -750,102 +750,108 @@ impl Context {

value
}
HirExprKind::IncDec {
increment,
postfix,
expr,
} => {
let op = match increment {
true => BinaryOperator::Add,
false => BinaryOperator::Subtract,
};

HirExprKind::PrePostfix { op, postfix, expr } if ExprPos::Lhs != pos => {
let pointer = self
.lower_expect_inner(stmt, parser, expr, ExprPos::Lhs, body)?
.0;
let left = self.add_expression(Expression::Load { pointer }, meta, body);

let uint = match parser.resolve_type(self, left, meta)?.scalar_kind() {
Some(ScalarKind::Sint) => false,
Some(ScalarKind::Uint) => true,
_ => {
let make_constant_inner = |kind, width| {
let value = match kind {
ScalarKind::Sint => crate::ScalarValue::Sint(1),
ScalarKind::Uint => crate::ScalarValue::Uint(1),
ScalarKind::Float => crate::ScalarValue::Float(1.0),
ScalarKind::Bool => return None,
};

Some(crate::ConstantInner::Scalar { width, value })
};
let res = match *parser.resolve_type(self, left, meta)? {
TypeInner::Scalar { kind, width } => {
let ty = TypeInner::Scalar { kind, width };
make_constant_inner(kind, width).map(|i| (ty, i, None, None))
}
TypeInner::Vector { size, kind, width } => {
let ty = TypeInner::Vector { size, kind, width };
make_constant_inner(kind, width).map(|i| (ty, i, Some(size), None))
}
TypeInner::Matrix {
columns,
rows,
width,
} => {
let ty = TypeInner::Matrix {
columns,
rows,
width,
};
make_constant_inner(ScalarKind::Float, width)
.map(|i| (ty, i, Some(rows), Some(columns)))
}
_ => None,
};
let (ty_inner, inner, rows, columns) = match res {
Some(res) => res,
None => {
parser.errors.push(Error {
kind: ErrorKind::SemanticError(
"Increment/decrement operations must operate in integers".into(),
"Increment/decrement only works on scalar/vector/matrix".into(),
),
meta,
});
true
return Ok((Some(left), meta));
}
};

let one = parser.module.constants.append(
let constant_1 = parser.module.constants.append(
Constant {
name: None,
specialization: None,
inner: crate::ConstantInner::Scalar {
width: 4,
value: match uint {
true => crate::ScalarValue::Uint(1),
false => crate::ScalarValue::Sint(1),
},
},
inner,
},
Default::default(),
);
let right = self.add_expression(Expression::Constant(one), meta, body);

let value = self.add_expression(Expression::Binary { op, left, right }, meta, body);

if postfix {
let local = self.locals.append(
LocalVariable {
name: None,
ty: parser.module.types.fetch_or_append(
Type {
name: None,
inner: TypeInner::Scalar {
kind: match uint {
true => ScalarKind::Uint,
false => ScalarKind::Sint,
},
width: 4,
},
},
meta.as_span(),
),
init: None,
},
meta.as_span(),
);

let expr = self.add_expression(Expression::LocalVariable(local), meta, body);
let load = self.add_expression(Expression::Load { pointer: expr }, meta, body);

self.emit_flush(body);
self.emit_start();

body.push(
Statement::Store {
pointer: expr,
value: left,
},
meta.as_span(),
);
let mut right = self.add_expression(Expression::Constant(constant_1), meta, body);

// Glsl allows pre/postfixes operations on vectors and matrices, so if the
// target is either of them change the right side of the addition to be splatted
// to the same size as the target, furthermore if the target is a matrix
// use a composed matrix using the splatted value.
if let Some(size) = rows {
right =
self.add_expression(Expression::Splat { size, value: right }, meta, body);

if let Some(cols) = columns {
let ty = parser.module.types.fetch_or_append(
Type {
name: None,
inner: ty_inner,
},
meta.as_span(),
);

self.emit_flush(body);
self.emit_start();
right = self.add_expression(
Expression::Compose {
ty,
components: std::iter::repeat(right).take(cols as usize).collect(),
},
meta,
body,
);
}
}

body.push(Statement::Store { pointer, value }, meta.as_span());
let value = self.add_expression(Expression::Binary { op, left, right }, meta, body);

load
} else {
self.emit_flush(body);
self.emit_start();
self.emit_flush(body);
self.emit_start();

body.push(Statement::Store { pointer, value }, meta.as_span());
body.push(Statement::Store { pointer, value }, meta.as_span());

if postfix {
left
} else {
value
}
}
_ => {
Expand Down
31 changes: 10 additions & 21 deletions src/front/glsl/parser/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,24 +257,14 @@ impl<'source> ParsingContext<'source> {
Default::default(),
)
}
TokenValue::Increment => {
TokenValue::Increment | TokenValue::Decrement => {
base = stmt.hir_exprs.append(
HirExpr {
kind: HirExprKind::IncDec {
increment: true,
postfix: true,
expr: base,
},
meta,
},
Default::default(),
)
}
TokenValue::Decrement => {
base = stmt.hir_exprs.append(
HirExpr {
kind: HirExprKind::IncDec {
increment: false,
kind: HirExprKind::PrePostfix {
op: match value {
TokenValue::Increment => crate::BinaryOperator::Add,
_ => crate::BinaryOperator::Subtract,
},
postfix: true,
expr: base,
},
Expand Down Expand Up @@ -331,11 +321,10 @@ impl<'source> ParsingContext<'source> {

stmt.hir_exprs.append(
HirExpr {
kind: HirExprKind::IncDec {
increment: match value {
TokenValue::Increment => true,
TokenValue::Decrement => false,
_ => unreachable!(),
kind: HirExprKind::PrePostfix {
op: match value {
TokenValue::Increment => crate::BinaryOperator::Add,
_ => crate::BinaryOperator::Subtract,
},
postfix: false,
expr,
Expand Down
18 changes: 18 additions & 0 deletions tests/in/glsl/prepostfix.frag
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#version 450 core

void main() {
int scalar_target;
int scalar = 1;
scalar_target = scalar++;
scalar_target = --scalar;

uvec2 vec_target;
uvec2 vec = uvec2(1);
vec_target = vec--;
vec_target = ++vec;

mat4x3 mat_target;
mat4x3 mat = mat4x3(1);
mat_target = mat++;
mat_target = --mat;
}
6 changes: 2 additions & 4 deletions tests/out/wgsl/246-collatz-comp.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ var<private> gl_GlobalInvocationID: vec3<u32>;
fn collatz_iterations(n: u32) -> u32 {
var n1: u32;
var i: u32 = 0u;
var local: u32;

n1 = n;
loop {
Expand All @@ -32,12 +31,11 @@ fn collatz_iterations(n: u32) -> u32 {
}
}
let _e33: u32 = i;
local = _e33;
i = (_e33 + 1u);
}
}
let _e38: u32 = i;
return _e38;
let _e36: u32 = i;
return _e36;
}

fn main1() {
Expand Down
14 changes: 6 additions & 8 deletions tests/out/wgsl/constant-array-size-vert.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,25 @@ var<uniform> global: Data;
fn function() -> vec4<f32> {
var sum: vec4<f32> = vec4<f32>(0.0, 0.0, 0.0, 0.0);
var i: i32 = 0;
var local: i32;

loop {
let _e9: i32 = i;
if (!((_e9 < 42))) {
break;
}
{
let _e17: vec4<f32> = sum;
let _e18: i32 = i;
let _e20: vec4<f32> = global.vecs[_e18];
sum = (_e17 + _e20);
let _e15: vec4<f32> = sum;
let _e16: i32 = i;
let _e18: vec4<f32> = global.vecs[_e16];
sum = (_e15 + _e18);
}
continuing {
let _e12: i32 = i;
local = _e12;
i = (_e12 + 1);
}
}
let _e22: vec4<f32> = sum;
return _e22;
let _e20: vec4<f32> = sum;
return _e20;
}

fn main1() {
Expand Down
40 changes: 40 additions & 0 deletions tests/out/wgsl/prepostfix-frag.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
fn main1() {
var scalar_target: i32;
var scalar: i32 = 1;
var vec_target: vec2<u32>;
var vec: vec2<u32> = vec2<u32>(1u, 1u);
var mat_target: mat4x3<f32>;
var mat: mat4x3<f32> = mat4x3<f32>(vec3<f32>(1.0, 0.0, 0.0), vec3<f32>(0.0, 1.0, 0.0), vec3<f32>(0.0, 0.0, 1.0), vec3<f32>(0.0, 0.0, 0.0));

let _e3: i32 = scalar;
scalar = (_e3 + 1);
scalar_target = _e3;
let _e6: i32 = scalar;
let _e8: i32 = (_e6 - 1);
scalar = _e8;
scalar_target = _e8;
let _e14: vec2<u32> = vec;
vec = (_e14 - vec2<u32>(1u));
vec_target = _e14;
let _e18: vec2<u32> = vec;
let _e21: vec2<u32> = (_e18 + vec2<u32>(1u));
vec = _e21;
vec_target = _e21;
let _e24: f32 = f32(1);
let _e32: mat4x3<f32> = mat;
let _e34: vec3<f32> = vec3<f32>(1.0);
mat = (_e32 + mat4x3<f32>(_e34, _e34, _e34, _e34));
mat_target = _e32;
let _e37: mat4x3<f32> = mat;
let _e39: vec3<f32> = vec3<f32>(1.0);
let _e41: mat4x3<f32> = (_e37 - mat4x3<f32>(_e39, _e39, _e39, _e39));
mat = _e41;
mat_target = _e41;
return;
}

[[stage(fragment)]]
fn main() {
main1();
return;
}

0 comments on commit d7ca7d4

Please sign in to comment.