diff --git a/src/webgpu/listing_meta.json b/src/webgpu/listing_meta.json index 669d171f34a4..1ad427ac7668 100644 --- a/src/webgpu/listing_meta.json +++ b/src/webgpu/listing_meta.json @@ -2015,6 +2015,11 @@ "webgpu:shader,validation,expression,binary,bitwise_shift:shift_right_concrete:*": { "subcaseMS": 1.237 }, "webgpu:shader,validation,expression,binary,comparison:invalid_types:*": { "subcaseMS": 39.526 }, "webgpu:shader,validation,expression,binary,comparison:scalar_vector:*": { "subcaseMS": 1598.064 }, + "webgpu:shader,validation,expression,binary,short_circuiting_and_or:invalid_array_count_on_rhs:*": { "subcaseMS": 4.309 }, + "webgpu:shader,validation,expression,binary,short_circuiting_and_or:invalid_rhs_const:*": { "subcaseMS": 4.341 }, + "webgpu:shader,validation,expression,binary,short_circuiting_and_or:invalid_rhs_override:*": { "subcaseMS": 27.490 }, + "webgpu:shader,validation,expression,binary,short_circuiting_and_or:invalid_types:*": { "subcaseMS": 13.409 }, + "webgpu:shader,validation,expression,binary,short_circuiting_and_or:scalar_vector:*": { "subcaseMS": 397.769 }, "webgpu:shader,validation,expression,binary,div_rem:invalid_type_with_itself:*": { "subcaseMS": 38.059 }, "webgpu:shader,validation,expression,binary,div_rem:scalar_vector:*": { "subcaseMS": 743.721 }, "webgpu:shader,validation,expression,binary,div_rem:scalar_vector_out_of_range:*": { "subcaseMS": 650.727 }, diff --git a/src/webgpu/shader/validation/expression/binary/short_circuiting_and_or.spec.ts b/src/webgpu/shader/validation/expression/binary/short_circuiting_and_or.spec.ts new file mode 100644 index 000000000000..30f521e54944 --- /dev/null +++ b/src/webgpu/shader/validation/expression/binary/short_circuiting_and_or.spec.ts @@ -0,0 +1,264 @@ +export const description = ` +Validation tests for short-circuiting && and || expressions. +`; + +import { makeTestGroup } from '../../../../../common/framework/test_group.js'; +import { keysOf, objectsToRecord } from '../../../../../common/util/data_tables.js'; +import { + kAllScalarsAndVectors, + ScalarType, + scalarTypeOf, + Type, +} from '../../../../util/conversion.js'; +import { ShaderValidationTest } from '../../shader_validation_test.js'; + +export const g = makeTestGroup(ShaderValidationTest); + +// A list of scalar and vector types. +const kScalarAndVectorTypes = objectsToRecord(kAllScalarsAndVectors); + +g.test('scalar_vector') + .desc( + ` + Validates that scalar and vector short-circuiting operators are only accepted for scalar booleans. + ` + ) + .params(u => + u + .combine('op', ['&&', '||']) + .combine('lhs', keysOf(kScalarAndVectorTypes)) + .combine( + 'rhs', + // Skip vec3 and vec4 on the RHS to keep the number of subcases down. + keysOf(kScalarAndVectorTypes).filter( + value => !(value.startsWith('vec3') || value.startsWith('vec4')) + ) + ) + .beginSubcases() + ) + .beforeAllSubcases(t => { + if ( + scalarTypeOf(kScalarAndVectorTypes[t.params.lhs]) === Type.f16 || + scalarTypeOf(kScalarAndVectorTypes[t.params.rhs]) === Type.f16 + ) { + t.selectDeviceOrSkipTestCase('shader-f16'); + } + }) + .fn(t => { + const lhs = kScalarAndVectorTypes[t.params.lhs]; + const rhs = kScalarAndVectorTypes[t.params.rhs]; + const lhsElement = scalarTypeOf(lhs); + const rhsElement = scalarTypeOf(rhs); + const hasF16 = lhsElement === Type.f16 || rhsElement === Type.f16; + const code = ` +${hasF16 ? 'enable f16;' : ''} +const lhs = ${lhs.create(0).wgsl()}; +const rhs = ${rhs.create(0).wgsl()}; +const foo = lhs ${t.params.op} rhs; +`; + + // Determine if the types are compatible. + let valid = false; + if (lhs instanceof ScalarType && rhs instanceof ScalarType) { + valid = lhsElement === Type.bool && rhsElement === Type.bool; + } + + t.expectCompileResult(valid, code); + }); + +interface InvalidTypeConfig { + // An expression that produces a value of the target type. + expr: string; + // A function that converts an expression of the target type into a valid boolean operand. + control: (x: string) => string; +} +const kInvalidTypes: Record = { + mat2x2f: { + expr: 'm', + control: e => `bool(${e}[0][0])`, + }, + + array: { + expr: 'arr', + control: e => `${e}[0]`, + }, + + ptr: { + expr: '(&b)', + control: e => `*${e}`, + }, + + atomic: { + expr: 'a', + control: e => `bool(atomicLoad(&${e}))`, + }, + + texture: { + expr: 't', + control: e => `bool(textureLoad(${e}, vec2(), 0).x)`, + }, + + sampler: { + expr: 's', + control: e => `bool(textureSampleLevel(t, ${e}, vec2(), 0).x)`, + }, + + struct: { + expr: 'str', + control: e => `${e}.b`, + }, +}; + +g.test('invalid_types') + .desc( + ` + Validates that short-circuiting expressions are never accepted for non-scalar and non-vector types. + ` + ) + .params(u => + u + .combine('op', ['&&', '||']) + .combine('type', keysOf(kInvalidTypes)) + .combine('control', [true, false]) + .beginSubcases() + ) + .fn(t => { + const type = kInvalidTypes[t.params.type]; + const expr = t.params.control ? type.control(type.expr) : type.expr; + const code = ` +@group(0) @binding(0) var t : texture_2d; +@group(0) @binding(1) var s : sampler; +@group(0) @binding(2) var a : atomic; + +struct S { b : bool } + +var b : bool; +var m : mat2x2f; +var arr : array; +var str : S; + +@compute @workgroup_size(1) +fn main() { + let foo = ${expr} ${t.params.op} ${expr}; +} +`; + + t.expectCompileResult(t.params.control, code); + }); + +// A map from operator to the value of the LHS that will cause short-circuiting. +const kLhsForShortCircuit: Record = { + '&&': false, + '||': true, +}; + +// A list of expressions that are invalid unless guarded by a short-circuiting expression. +const kInvalidRhsExpressions: Record = { + overflow: 'i32(1< + u + .combine('op', ['&&', '||']) + .combine('rhs', keysOf(kInvalidRhsExpressions)) + .combine('short_circuit', [true, false]) + .beginSubcases() + ) + .fn(t => { + let lhs = kLhsForShortCircuit[t.params.op]; + if (!t.params.short_circuit) { + lhs = !lhs; + } + const code = ` +const thirty_one = 31u; +const zero_i32 = 0i; +const one_f32 = 1.0f; + +@compute @workgroup_size(1) +fn main() { + let foo = ${lhs} ${t.params.op} ${kInvalidRhsExpressions[t.params.rhs]}; +} +`; + + t.expectCompileResult(t.params.short_circuit, code); + }); + +g.test('invalid_rhs_override') + .desc( + ` + Validates that a short-circuiting expression with an override-expression LHS guards the evaluation of its RHS expression. + ` + ) + .params(u => + u + .combine('op', ['&&', '||']) + .combine('rhs', keysOf(kInvalidRhsExpressions)) + .combine('short_circuit', [true, false]) + .beginSubcases() + ) + .fn(t => { + let lhs = kLhsForShortCircuit[t.params.op]; + if (!t.params.short_circuit) { + lhs = !lhs; + } + const code = ` +override cond : bool; +override zero_i32 = 0i; +override one_f32 = 1.0f; +override thirty_one = 31u; +override foo = cond ${t.params.op} ${kInvalidRhsExpressions[t.params.rhs]}; +`; + + const constants: Record = {}; + constants['cond'] = lhs ? 1 : 0; + t.expectPipelineResult({ + expectedResult: t.params.short_circuit, + code, + constants, + reference: ['foo'], + }); + }); + +// A list of expressions that are invalid unless guarded by a short-circuiting expression. +// The control case will use `value = 10`, the failure case will use `value = 1`. +const kInvalidArrayCounts: Record = { + negative: 'value - 2', + sqrt_neg1: 'u32(sqrt(value - 2))', + nested: '10 + array()[0]', +}; + +g.test('invalid_array_count_on_rhs') + .desc( + ` + Validates that an invalid array count expression is not guarded by a short-circuiting expression. + ` + ) + .params(u => + u + .combine('op', ['&&', '||']) + .combine('rhs', keysOf(kInvalidArrayCounts)) + .combine('control', [true, false]) + .beginSubcases() + ) + .fn(t => { + const lhs = t.params.op === '&&' ? 'false' : 'true'; + const code = ` +const value = ${t.params.control ? '10' : '1'}; + +@compute @workgroup_size(1) +fn main() { + let foo = ${lhs} ${t.params.op} array()[0]; +} +`; + + t.expectCompileResult(t.params.control, code); + });