-
Notifications
You must be signed in to change notification settings - Fork 86
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
wgsl: Add and/or/xor validation tests (#3476)
Test that the operators are only accepted for booleans or compatible integer types.
- Loading branch information
Showing
2 changed files
with
183 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
181 changes: 181 additions & 0 deletions
181
src/webgpu/shader/validation/expression/binary/and_or_xor.spec.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
export const description = ` | ||
Validation tests for logical and bitwise and/or/xor expressions. | ||
`; | ||
|
||
import { makeTestGroup } from '../../../../../common/framework/test_group.js'; | ||
import { keysOf, objectsToRecord } from '../../../../../common/util/data_tables.js'; | ||
import { | ||
kAllScalarsAndVectors, | ||
ScalarType, | ||
scalarTypeOf, | ||
Type, | ||
VectorType, | ||
} from '../../../../util/conversion.js'; | ||
import { ShaderValidationTest } from '../../shader_validation_test.js'; | ||
|
||
export const g = makeTestGroup(ShaderValidationTest); | ||
|
||
// A list of operators and a flag for whether they support boolean values or not. | ||
const kOperators = { | ||
and: { op: '&', supportsBool: true }, | ||
or: { op: '|', supportsBool: true }, | ||
xor: { op: '^', supportsBool: false }, | ||
}; | ||
|
||
// A list of scalar and vector types. | ||
const kScalarAndVectorTypes = objectsToRecord(kAllScalarsAndVectors); | ||
|
||
g.test('scalar_vector') | ||
.desc( | ||
` | ||
Validates that scalar and vector expressions are only accepted for bool or compatible integer types. | ||
` | ||
) | ||
.params(u => | ||
u | ||
.combine('op', keysOf(kOperators)) | ||
.combine('lhs', keysOf(kScalarAndVectorTypes)) | ||
.combine( | ||
'rhs', | ||
// Skip vec3 and vec4 on the RHS to keep the number of subcases down. | ||
keysOf(kScalarAndVectorTypes).filter( | ||
value => !(value.startsWith('vec3') || value.startsWith('vec4')) | ||
) | ||
) | ||
.beginSubcases() | ||
) | ||
.beforeAllSubcases(t => { | ||
if ( | ||
scalarTypeOf(kScalarAndVectorTypes[t.params.lhs]) === Type.f16 || | ||
scalarTypeOf(kScalarAndVectorTypes[t.params.rhs]) === Type.f16 | ||
) { | ||
t.selectDeviceOrSkipTestCase('shader-f16'); | ||
} | ||
}) | ||
.fn(t => { | ||
const op = kOperators[t.params.op]; | ||
const lhs = kScalarAndVectorTypes[t.params.lhs]; | ||
const rhs = kScalarAndVectorTypes[t.params.rhs]; | ||
const lhsElement = scalarTypeOf(lhs); | ||
const rhsElement = scalarTypeOf(rhs); | ||
const hasF16 = lhsElement === Type.f16 || rhsElement === Type.f16; | ||
const code = ` | ||
${hasF16 ? 'enable f16;' : ''} | ||
const lhs = ${lhs.create(lhsElement.kind === 'abstract-int' ? 0n : 0).wgsl()}; | ||
const rhs = ${rhs.create(rhsElement.kind === 'abstract-int' ? 0n : 0).wgsl()}; | ||
const foo = lhs ${op.op} rhs; | ||
`; | ||
|
||
// Determine if the element types are compatible. | ||
const kIntegerTypes = [Type.abstractInt, Type.i32, Type.u32]; | ||
let elementIsCompatible = false; | ||
if (lhsElement === Type.abstractInt) { | ||
// Abstract integers are compatible with any other integer type. | ||
elementIsCompatible = kIntegerTypes.includes(rhsElement); | ||
} else if (rhsElement === Type.abstractInt) { | ||
// Abstract integers are compatible with any other numeric type. | ||
elementIsCompatible = kIntegerTypes.includes(lhsElement); | ||
} else if (kIntegerTypes.includes(lhsElement)) { | ||
// Concrete integers are only compatible with values with the exact same type. | ||
elementIsCompatible = lhsElement === rhsElement; | ||
} else if (lhsElement === Type.bool) { | ||
// Booleans are only compatible with other booleans. | ||
elementIsCompatible = rhsElement === Type.bool; | ||
} | ||
|
||
// Determine if the full type is compatible. | ||
let valid = false; | ||
if (lhs instanceof ScalarType && rhs instanceof ScalarType) { | ||
valid = elementIsCompatible; | ||
} else if (lhs instanceof VectorType && rhs instanceof VectorType) { | ||
// Vectors are only compatible with if the vector widths match. | ||
valid = lhs.width === rhs.width && elementIsCompatible; | ||
} | ||
|
||
if (lhsElement.kind === 'bool') { | ||
valid &&= op.supportsBool; | ||
} | ||
|
||
t.expectCompileResult(valid, code); | ||
}); | ||
|
||
interface InvalidTypeConfig { | ||
// An expression that produces a value of the target type. | ||
expr: string; | ||
// A function that converts an expression of the target type into a valid integer operand. | ||
control: (x: string) => string; | ||
} | ||
const kInvalidTypes: Record<string, InvalidTypeConfig> = { | ||
mat2x2f: { | ||
expr: 'm', | ||
control: e => `i32(${e}[0][0])`, | ||
}, | ||
|
||
array: { | ||
expr: 'arr', | ||
control: e => `${e}[0]`, | ||
}, | ||
|
||
ptr: { | ||
expr: '(&u)', | ||
control: e => `*${e}`, | ||
}, | ||
|
||
atomic: { | ||
expr: 'a', | ||
control: e => `atomicLoad(&${e})`, | ||
}, | ||
|
||
texture: { | ||
expr: 't', | ||
control: e => `i32(textureLoad(${e}, vec2(), 0).x)`, | ||
}, | ||
|
||
sampler: { | ||
expr: 's', | ||
control: e => `i32(textureSampleLevel(t, ${e}, vec2(), 0).x)`, | ||
}, | ||
|
||
struct: { | ||
expr: 'str', | ||
control: e => `${e}.u`, | ||
}, | ||
}; | ||
|
||
g.test('invalid_types') | ||
.desc( | ||
` | ||
Validates that expressions are never accepted for non-scalar and non-vector types. | ||
` | ||
) | ||
.params(u => | ||
u | ||
.combine('op', keysOf(kOperators)) | ||
.combine('type', keysOf(kInvalidTypes)) | ||
.combine('control', [true, false]) | ||
.beginSubcases() | ||
) | ||
.fn(t => { | ||
const op = kOperators[t.params.op]; | ||
const type = kInvalidTypes[t.params.type]; | ||
const expr = t.params.control ? type.control(type.expr) : type.expr; | ||
const code = ` | ||
@group(0) @binding(0) var t : texture_2d<f32>; | ||
@group(0) @binding(1) var s : sampler; | ||
@group(0) @binding(2) var<storage, read_write> a : atomic<i32>; | ||
struct S { u : u32 } | ||
var<private> u : u32; | ||
var<private> m : mat2x2f; | ||
var<private> arr : array<i32, 4>; | ||
var<private> str : S; | ||
@compute @workgroup_size(1) | ||
fn main() { | ||
let foo = ${expr} ${op.op} ${expr}; | ||
} | ||
`; | ||
|
||
t.expectCompileResult(t.params.control, code); | ||
}); |