diff --git a/src/webgpu/listing_meta.json b/src/webgpu/listing_meta.json index 327ba644120d..7c4982ca1853 100644 --- a/src/webgpu/listing_meta.json +++ b/src/webgpu/listing_meta.json @@ -1813,6 +1813,8 @@ "webgpu:shader,validation,expression,binary,bitwise_shift:shift_left_vec_size_mismatch:*": { "subcaseMS": 1.367 }, "webgpu:shader,validation,expression,binary,bitwise_shift:shift_right_concrete:*": { "subcaseMS": 1.237 }, "webgpu:shader,validation,expression,binary,bitwise_shift:shift_right_vec_size_mismatch:*": { "subcaseMS": 1.334 }, + "webgpu:shader,validation,expression,binary,comparison:invalid_types:*": { "subcaseMS": 39.526 }, + "webgpu:shader,validation,expression,binary,comparison:scalar_vector:*": { "subcaseMS": 1598.064 }, "webgpu:shader,validation,expression,call,builtin,abs:parameters:*": { "subcaseMS": 10.133 }, "webgpu:shader,validation,expression,call,builtin,abs:values:*": { "subcaseMS": 0.391 }, "webgpu:shader,validation,expression,call,builtin,acos:integer_argument:*": { "subcaseMS": 1.512 }, diff --git a/src/webgpu/shader/validation/expression/binary/comparison.spec.ts b/src/webgpu/shader/validation/expression/binary/comparison.spec.ts new file mode 100644 index 000000000000..bfba7adaa6b4 --- /dev/null +++ b/src/webgpu/shader/validation/expression/binary/comparison.spec.ts @@ -0,0 +1,186 @@ +export const description = ` +Validation tests for comparison expressions. +`; + +import { makeTestGroup } from '../../../../../common/framework/test_group.js'; +import { keysOf, objectsToRecord } from '../../../../../common/util/data_tables.js'; +import { + isFloatType, + kAllScalarsAndVectors, + ScalarType, + scalarTypeOf, + Type, + VectorType, +} from '../../../../util/conversion.js'; +import { ShaderValidationTest } from '../../shader_validation_test.js'; + +export const g = makeTestGroup(ShaderValidationTest); + +// A list of scalar and vector types. +const kScalarAndVectorTypes = objectsToRecord(kAllScalarsAndVectors); + +// A list of comparison operators and a flag for whether they support boolean values or not. +const kComparisonOperators = { + eq: { op: '==', supportsBool: true }, + ne: { op: '!=', supportsBool: true }, + gt: { op: '>', supportsBool: false }, + ge: { op: '>=', supportsBool: false }, + lt: { op: '<', supportsBool: false }, + le: { op: '<=', supportsBool: false }, +}; + +g.test('scalar_vector') + .desc( + ` + Validates that scalar and vector comparison expressions are only accepted for compatible types. + ` + ) + .params(u => + u + .combine('op', keysOf(kComparisonOperators)) + .combine('lhs', keysOf(kScalarAndVectorTypes)) + .combine( + 'rhs', + // Skip vec3 and vec4 on the RHS to keep the number of subcases down. + keysOf(kScalarAndVectorTypes).filter( + value => !(value.startsWith('vec3') || value.startsWith('vec4')) + ) + ) + .beginSubcases() + ) + .beforeAllSubcases(t => { + if ( + scalarTypeOf(kScalarAndVectorTypes[t.params.lhs]) === Type.f16 || + scalarTypeOf(kScalarAndVectorTypes[t.params.rhs]) === Type.f16 + ) { + t.selectDeviceOrSkipTestCase('shader-f16'); + } + }) + .fn(t => { + const lhs = kScalarAndVectorTypes[t.params.lhs]; + const rhs = kScalarAndVectorTypes[t.params.rhs]; + const lhsElement = scalarTypeOf(lhs); + const rhsElement = scalarTypeOf(rhs); + const hasF16 = lhsElement === Type.f16 || rhsElement === Type.f16; + const code = ` +${hasF16 ? 'enable f16;' : ''} +const lhs = ${lhs.create(lhsElement.kind === 'abstract-int' ? 0n : 0).wgsl()}; +const rhs = ${rhs.create(rhsElement.kind === 'abstract-int' ? 0n : 0).wgsl()}; +const foo = lhs ${kComparisonOperators[t.params.op].op} rhs; +`; + + let valid = false; + + // Determine if the element types are comparable. + let elementIsCompatible = false; + if (lhsElement.kind === 'abstract-int') { + // Abstract integers are comparable to any other numeric type. + elementIsCompatible = rhsElement.kind !== 'bool'; + } else if (rhsElement.kind === 'abstract-int') { + // Abstract integers are comparable to any other numeric type. + elementIsCompatible = lhsElement.kind !== 'bool'; + } else if (lhsElement.kind === 'abstract-float') { + // Abstract floats are comparable to any other float type. + elementIsCompatible = isFloatType(rhsElement); + } else if (rhsElement.kind === 'abstract-float') { + // Abstract floats are comparable to any other float type. + elementIsCompatible = isFloatType(lhsElement); + } else { + // Non-abstract types are only comparable to values with the exact same type. + elementIsCompatible = lhsElement === rhsElement; + } + + // Determine if the full type is comparable. + if (lhs instanceof ScalarType && rhs instanceof ScalarType) { + valid = elementIsCompatible; + } else if (lhs instanceof VectorType && rhs instanceof VectorType) { + // Vectors are only comparable if the vector widths match. + valid = lhs.width === rhs.width && elementIsCompatible; + } + + if (lhsElement.kind === 'bool') { + valid &&= kComparisonOperators[t.params.op].supportsBool; + } + + t.expectCompileResult(valid, code); + }); + +interface InvalidTypeConfig { + // An expression that produces a value of the target type. + expr: string; + // A function that converts an expression of the target type into a valid comparison operand. + control: (x: string) => string; +} +const kInvalidTypes: Record = { + mat2x2f: { + expr: 'm', + control: e => `${e}[0]`, + }, + + array: { + expr: 'arr', + control: e => `${e}[0]`, + }, + + ptr: { + expr: '(&u)', + control: e => `*${e}`, + }, + + atomic: { + expr: 'a', + control: e => `atomicLoad(&${e})`, + }, + + texture: { + expr: 't', + control: e => `textureLoad(${e}, vec2(), 0)`, + }, + + sampler: { + expr: 's', + control: e => `textureSampleLevel(t, ${e}, vec2(), 0)`, + }, + + struct: { + expr: 'str', + control: e => `${e}.u`, + }, +}; + +g.test('invalid_types') + .desc( + ` + Validates that comparison expressions are never accepted for non-scalar and non-vector types. + ` + ) + .params(u => + u + .combine('op', keysOf(kComparisonOperators)) + .combine('type', keysOf(kInvalidTypes)) + .combine('control', [true, false]) + .beginSubcases() + ) + .fn(t => { + const type = kInvalidTypes[t.params.type]; + const expr = t.params.control ? type.control(type.expr) : type.expr; + const code = ` +@group(0) @binding(0) var t : texture_2d; +@group(0) @binding(1) var s : sampler; +@group(0) @binding(2) var a : atomic; + +struct S { u : u32 } + +var u : u32; +var m : mat2x2f; +var arr : array; +var str : S; + +@compute @workgroup_size(1) +fn main() { + let foo = ${expr} ${kComparisonOperators[t.params.op].op} ${expr}; +} +`; + + t.expectCompileResult(t.params.control, code); + });