diff --git a/src/webgpu/listing_meta.json b/src/webgpu/listing_meta.json index 6e30934e818c..99efc128226d 100644 --- a/src/webgpu/listing_meta.json +++ b/src/webgpu/listing_meta.json @@ -1818,10 +1818,10 @@ "webgpu:shader,validation,expression,access,vector:vector:*": { "subcaseMS": 1.407 }, "webgpu:shader,validation,expression,binary,and_or_xor:invalid_types:*": { "subcaseMS": 24.069 }, "webgpu:shader,validation,expression,binary,and_or_xor:scalar_vector:*": { "subcaseMS": 666.807 }, + "webgpu:shader,validation,expression,binary,bitwise_shift:invalid_types:*": { "subcaseMS": 22.058 }, + "webgpu:shader,validation,expression,binary,bitwise_shift:scalar_vector:*": { "subcaseMS": 525.052 }, "webgpu:shader,validation,expression,binary,bitwise_shift:shift_left_concrete:*": { "subcaseMS": 1.216 }, - "webgpu:shader,validation,expression,binary,bitwise_shift:shift_left_vec_size_mismatch:*": { "subcaseMS": 1.367 }, "webgpu:shader,validation,expression,binary,bitwise_shift:shift_right_concrete:*": { "subcaseMS": 1.237 }, - "webgpu:shader,validation,expression,binary,bitwise_shift:shift_right_vec_size_mismatch:*": { "subcaseMS": 1.334 }, "webgpu:shader,validation,expression,binary,comparison:invalid_types:*": { "subcaseMS": 39.526 }, "webgpu:shader,validation,expression,binary,comparison:scalar_vector:*": { "subcaseMS": 1598.064 }, "webgpu:shader,validation,expression,call,builtin,abs:parameters:*": { "subcaseMS": 10.133 }, diff --git a/src/webgpu/shader/validation/expression/binary/bitwise_shift.spec.ts b/src/webgpu/shader/validation/expression/binary/bitwise_shift.spec.ts index 5f7b995ded8e..6a654471b924 100644 --- a/src/webgpu/shader/validation/expression/binary/bitwise_shift.spec.ts +++ b/src/webgpu/shader/validation/expression/binary/bitwise_shift.spec.ts @@ -3,6 +3,13 @@ Validation tests for the bitwise shift binary expression operations `; import { makeTestGroup } from '../../../../../common/framework/test_group.js'; +import { keysOf, objectsToRecord } from '../../../../../common/util/data_tables.js'; +import { + Type, + kAllScalarsAndVectors, + numElementsOf, + scalarTypeOf, +} from '../../../../util/conversion.js'; import { ShaderValidationTest } from '../../shader_validation_test.js'; export const g = makeTestGroup(ShaderValidationTest); @@ -21,6 +28,137 @@ function vectorize(v: string, size: number | undefined): string { return v; } +// A list of scalar and vector types. +const kScalarAndVectorTypes = objectsToRecord(kAllScalarsAndVectors); + +g.test('scalar_vector') + .desc( + ` + Validates that scalar and vector expressions are only accepted when the LHS is an integer and the RHS is abstract or unsigned. + ` + ) + .params(u => + u + .combine('op', ['<<', '>>']) + .combine('lhs', keysOf(kScalarAndVectorTypes)) + .combine( + 'rhs', + // Skip vec3 and vec4 on the RHS to keep the number of subcases down. + keysOf(kScalarAndVectorTypes).filter( + value => !(value.startsWith('vec3') || value.startsWith('vec4')) + ) + ) + .beginSubcases() + ) + .beforeAllSubcases(t => { + if ( + scalarTypeOf(kScalarAndVectorTypes[t.params.lhs]) === Type.f16 || + scalarTypeOf(kScalarAndVectorTypes[t.params.rhs]) === Type.f16 + ) { + t.selectDeviceOrSkipTestCase('shader-f16'); + } + }) + .fn(t => { + const lhs = kScalarAndVectorTypes[t.params.lhs]; + const rhs = kScalarAndVectorTypes[t.params.rhs]; + const lhsElement = scalarTypeOf(lhs); + const rhsElement = scalarTypeOf(rhs); + const hasF16 = lhsElement === Type.f16 || rhsElement === Type.f16; + const code = ` +${hasF16 ? 'enable f16;' : ''} +const lhs = ${lhs.create(0).wgsl()}; +const rhs = ${rhs.create(0).wgsl()}; +const foo = lhs ${t.params.op} rhs; +`; + + // The LHS must be an integer, and the RHS must be an abstract/unsigned integer. + // The vector widths must also match. + const lhs_valid = [Type.abstractInt, Type.i32, Type.u32].includes(lhsElement); + const rhs_valid = [Type.abstractInt, Type.u32].includes(rhsElement); + const valid = lhs_valid && rhs_valid && numElementsOf(lhs) === numElementsOf(rhs); + t.expectCompileResult(valid, code); + }); + +interface InvalidTypeConfig { + // An expression that produces a value of the target type. + expr: string; + // A function that converts an expression of the target type into a valid u32 operand. + control: (x: string) => string; +} +const kInvalidTypes: Record = { + mat2x2f: { + expr: 'm', + control: e => `u32(${e}[0][0])`, + }, + + array: { + expr: 'arr', + control: e => `${e}[0]`, + }, + + ptr: { + expr: '(&u)', + control: e => `*${e}`, + }, + + atomic: { + expr: 'a', + control: e => `atomicLoad(&${e})`, + }, + + texture: { + expr: 't', + control: e => `u32(textureLoad(${e}, vec2(), 0).x)`, + }, + + sampler: { + expr: 's', + control: e => `u32(textureSampleLevel(t, ${e}, vec2(), 0).x)`, + }, + + struct: { + expr: 'str', + control: e => `${e}.u`, + }, +}; + +g.test('invalid_types') + .desc( + ` + Validates that expressions are never accepted for non-scalar and non-vector types. + ` + ) + .params(u => + u + .combine('op', ['<<', '>>']) + .combine('type', keysOf(kInvalidTypes)) + .combine('control', [true, false]) + .beginSubcases() + ) + .fn(t => { + const type = kInvalidTypes[t.params.type]; + const expr = t.params.control ? type.control(type.expr) : type.expr; + const code = ` +@group(0) @binding(0) var t : texture_2d; +@group(0) @binding(1) var s : sampler; +@group(0) @binding(2) var a : atomic; + +struct S { u : u32 } + +var u : u32; +var m : mat2x2f; +var arr : array; +var str : S; + +@compute @workgroup_size(1) +fn main() { + let foo = ${expr} ${t.params.op} ${expr}; +} +`; + + t.expectCompileResult(t.params.control, code); + }); + const kLeftShiftCases = [ // rhs >= bitwidth fails { lhs: `0u`, rhs: `31u`, pass: true }, @@ -80,28 +218,6 @@ fn main() { t.expectCompileResult(t.params.case.pass, code); }); -g.test('shift_left_vec_size_mismatch') - .desc('Tests validation of binary left shift of vectors with mismatched sizes') - .params(u => - u - .combine('vectorize_lhs', [2, 3, 4]) // - .combine('vectorize_rhs', [2, 3, 4]) - ) - .fn(t => { - const lhs = `1`; - const rhs = `1`; - const lhs_vec_size = t.params.vectorize_lhs; - const rhs_vec_size = t.params.vectorize_rhs; - const code = ` -@compute @workgroup_size(1) -fn main() { - const r = ${vectorize(lhs, lhs_vec_size)} << ${vectorize(rhs, rhs_vec_size)}; -} - `; - const pass = lhs_vec_size === rhs_vec_size; - t.expectCompileResult(pass, code); - }); - const kRightShiftCases = [ // rhs >= bitwidth fails { lhs: `0u`, rhs: `31u`, pass: true }, @@ -142,25 +258,3 @@ fn main() { `; t.expectCompileResult(t.params.case.pass, code); }); - -g.test('shift_right_vec_size_mismatch') - .desc('Tests validation of binary right shift of vectors with mismatched sizes') - .params(u => - u - .combine('vectorize_lhs', [2, 3, 4]) // - .combine('vectorize_rhs', [2, 3, 4]) - ) - .fn(t => { - const lhs = `1`; - const rhs = `1`; - const lhs_vec_size = t.params.vectorize_lhs; - const rhs_vec_size = t.params.vectorize_rhs; - const code = ` -@compute @workgroup_size(1) -fn main() { - const r = ${vectorize(lhs, lhs_vec_size)} >> ${vectorize(rhs, rhs_vec_size)}; -} - `; - const pass = lhs_vec_size === rhs_vec_size; - t.expectCompileResult(pass, code); - });