Skip to content

Commit

Permalink
wgsl: Add more validation tests for bitwise shift (#3494)
Browse files Browse the repository at this point in the history
Add tests that cover various invalid types.
  • Loading branch information
jrprice authored Mar 14, 2024
1 parent 3a56acd commit 45f88b5
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 46 deletions.
4 changes: 2 additions & 2 deletions src/webgpu/listing_meta.json
Original file line number Diff line number Diff line change
Expand Up @@ -1818,10 +1818,10 @@
"webgpu:shader,validation,expression,access,vector:vector:*": { "subcaseMS": 1.407 },
"webgpu:shader,validation,expression,binary,and_or_xor:invalid_types:*": { "subcaseMS": 24.069 },
"webgpu:shader,validation,expression,binary,and_or_xor:scalar_vector:*": { "subcaseMS": 666.807 },
"webgpu:shader,validation,expression,binary,bitwise_shift:invalid_types:*": { "subcaseMS": 22.058 },
"webgpu:shader,validation,expression,binary,bitwise_shift:scalar_vector:*": { "subcaseMS": 525.052 },
"webgpu:shader,validation,expression,binary,bitwise_shift:shift_left_concrete:*": { "subcaseMS": 1.216 },
"webgpu:shader,validation,expression,binary,bitwise_shift:shift_left_vec_size_mismatch:*": { "subcaseMS": 1.367 },
"webgpu:shader,validation,expression,binary,bitwise_shift:shift_right_concrete:*": { "subcaseMS": 1.237 },
"webgpu:shader,validation,expression,binary,bitwise_shift:shift_right_vec_size_mismatch:*": { "subcaseMS": 1.334 },
"webgpu:shader,validation,expression,binary,comparison:invalid_types:*": { "subcaseMS": 39.526 },
"webgpu:shader,validation,expression,binary,comparison:scalar_vector:*": { "subcaseMS": 1598.064 },
"webgpu:shader,validation,expression,call,builtin,abs:parameters:*": { "subcaseMS": 10.133 },
Expand Down
182 changes: 138 additions & 44 deletions src/webgpu/shader/validation/expression/binary/bitwise_shift.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@ Validation tests for the bitwise shift binary expression operations
`;

import { makeTestGroup } from '../../../../../common/framework/test_group.js';
import { keysOf, objectsToRecord } from '../../../../../common/util/data_tables.js';
import {
Type,
kAllScalarsAndVectors,
numElementsOf,
scalarTypeOf,
} from '../../../../util/conversion.js';
import { ShaderValidationTest } from '../../shader_validation_test.js';

export const g = makeTestGroup(ShaderValidationTest);
Expand All @@ -21,6 +28,137 @@ function vectorize(v: string, size: number | undefined): string {
return v;
}

// A list of scalar and vector types.
const kScalarAndVectorTypes = objectsToRecord(kAllScalarsAndVectors);

g.test('scalar_vector')
.desc(
`
Validates that scalar and vector expressions are only accepted when the LHS is an integer and the RHS is abstract or unsigned.
`
)
.params(u =>
u
.combine('op', ['<<', '>>'])
.combine('lhs', keysOf(kScalarAndVectorTypes))
.combine(
'rhs',
// Skip vec3 and vec4 on the RHS to keep the number of subcases down.
keysOf(kScalarAndVectorTypes).filter(
value => !(value.startsWith('vec3') || value.startsWith('vec4'))
)
)
.beginSubcases()
)
.beforeAllSubcases(t => {
if (
scalarTypeOf(kScalarAndVectorTypes[t.params.lhs]) === Type.f16 ||
scalarTypeOf(kScalarAndVectorTypes[t.params.rhs]) === Type.f16
) {
t.selectDeviceOrSkipTestCase('shader-f16');
}
})
.fn(t => {
const lhs = kScalarAndVectorTypes[t.params.lhs];
const rhs = kScalarAndVectorTypes[t.params.rhs];
const lhsElement = scalarTypeOf(lhs);
const rhsElement = scalarTypeOf(rhs);
const hasF16 = lhsElement === Type.f16 || rhsElement === Type.f16;
const code = `
${hasF16 ? 'enable f16;' : ''}
const lhs = ${lhs.create(0).wgsl()};
const rhs = ${rhs.create(0).wgsl()};
const foo = lhs ${t.params.op} rhs;
`;

// The LHS must be an integer, and the RHS must be an abstract/unsigned integer.
// The vector widths must also match.
const lhs_valid = [Type.abstractInt, Type.i32, Type.u32].includes(lhsElement);
const rhs_valid = [Type.abstractInt, Type.u32].includes(rhsElement);
const valid = lhs_valid && rhs_valid && numElementsOf(lhs) === numElementsOf(rhs);
t.expectCompileResult(valid, code);
});

interface InvalidTypeConfig {
// An expression that produces a value of the target type.
expr: string;
// A function that converts an expression of the target type into a valid u32 operand.
control: (x: string) => string;
}
const kInvalidTypes: Record<string, InvalidTypeConfig> = {
mat2x2f: {
expr: 'm',
control: e => `u32(${e}[0][0])`,
},

array: {
expr: 'arr',
control: e => `${e}[0]`,
},

ptr: {
expr: '(&u)',
control: e => `*${e}`,
},

atomic: {
expr: 'a',
control: e => `atomicLoad(&${e})`,
},

texture: {
expr: 't',
control: e => `u32(textureLoad(${e}, vec2(), 0).x)`,
},

sampler: {
expr: 's',
control: e => `u32(textureSampleLevel(t, ${e}, vec2(), 0).x)`,
},

struct: {
expr: 'str',
control: e => `${e}.u`,
},
};

g.test('invalid_types')
.desc(
`
Validates that expressions are never accepted for non-scalar and non-vector types.
`
)
.params(u =>
u
.combine('op', ['<<', '>>'])
.combine('type', keysOf(kInvalidTypes))
.combine('control', [true, false])
.beginSubcases()
)
.fn(t => {
const type = kInvalidTypes[t.params.type];
const expr = t.params.control ? type.control(type.expr) : type.expr;
const code = `
@group(0) @binding(0) var t : texture_2d<f32>;
@group(0) @binding(1) var s : sampler;
@group(0) @binding(2) var<storage, read_write> a : atomic<u32>;
struct S { u : u32 }
var<private> u : u32;
var<private> m : mat2x2f;
var<private> arr : array<u32, 4>;
var<private> str : S;
@compute @workgroup_size(1)
fn main() {
let foo = ${expr} ${t.params.op} ${expr};
}
`;

t.expectCompileResult(t.params.control, code);
});

const kLeftShiftCases = [
// rhs >= bitwidth fails
{ lhs: `0u`, rhs: `31u`, pass: true },
Expand Down Expand Up @@ -80,28 +218,6 @@ fn main() {
t.expectCompileResult(t.params.case.pass, code);
});

g.test('shift_left_vec_size_mismatch')
.desc('Tests validation of binary left shift of vectors with mismatched sizes')
.params(u =>
u
.combine('vectorize_lhs', [2, 3, 4]) //
.combine('vectorize_rhs', [2, 3, 4])
)
.fn(t => {
const lhs = `1`;
const rhs = `1`;
const lhs_vec_size = t.params.vectorize_lhs;
const rhs_vec_size = t.params.vectorize_rhs;
const code = `
@compute @workgroup_size(1)
fn main() {
const r = ${vectorize(lhs, lhs_vec_size)} << ${vectorize(rhs, rhs_vec_size)};
}
`;
const pass = lhs_vec_size === rhs_vec_size;
t.expectCompileResult(pass, code);
});

const kRightShiftCases = [
// rhs >= bitwidth fails
{ lhs: `0u`, rhs: `31u`, pass: true },
Expand Down Expand Up @@ -142,25 +258,3 @@ fn main() {
`;
t.expectCompileResult(t.params.case.pass, code);
});

g.test('shift_right_vec_size_mismatch')
.desc('Tests validation of binary right shift of vectors with mismatched sizes')
.params(u =>
u
.combine('vectorize_lhs', [2, 3, 4]) //
.combine('vectorize_rhs', [2, 3, 4])
)
.fn(t => {
const lhs = `1`;
const rhs = `1`;
const lhs_vec_size = t.params.vectorize_lhs;
const rhs_vec_size = t.params.vectorize_rhs;
const code = `
@compute @workgroup_size(1)
fn main() {
const r = ${vectorize(lhs, lhs_vec_size)} >> ${vectorize(rhs, rhs_vec_size)};
}
`;
const pass = lhs_vec_size === rhs_vec_size;
t.expectCompileResult(pass, code);
});

0 comments on commit 45f88b5

Please sign in to comment.