Skip to content

Commit 704bde6

Browse files
committed
add f16 to webgpu/shader/types.ts utilities
Udpate affected tests; - webgpu:shader,execution,zero_init: skip f16 cases - webgpu:shader,execution,robust_access: generalize to cover most f16 cases Bug: #3405
1 parent 7ad3a97 commit 704bde6

File tree

4 files changed

+118
-33
lines changed

4 files changed

+118
-33
lines changed

src/webgpu/shader/execution/robust_access.spec.ts

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ TODO: add tests to check that textureLoad operations stay in-bounds.
77

88
import { makeTestGroup } from '../../../common/framework/test_group.js';
99
import { assert } from '../../../common/util/util.js';
10+
import { Float16Array } from '../../../external/petamoriken/float16/float16.js';
1011
import { GPUTest } from '../../gpu_test.js';
1112
import { align } from '../../util/math.js';
1213
import { generateTypes, supportedScalarTypes, supportsAtomics } from '../types.js';
@@ -25,6 +26,7 @@ const kMinI32 = -0x8000_0000;
2526
*/
2627
async function runShaderTest(
2728
t: GPUTest,
29+
enables: string,
2830
stage: GPUShaderStageFlags,
2931
testSource: string,
3032
layout: GPUPipelineLayout,
@@ -41,7 +43,7 @@ async function runShaderTest(
4143
usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.STORAGE,
4244
});
4345

44-
const source = `
46+
const source = `${enables}
4547
struct Constants {
4648
zero: u32
4749
};
@@ -96,10 +98,12 @@ fn main() {
9698
/** Fill an ArrayBuffer with sentinel values, except clear a region to zero. */
9799
function testFillArrayBuffer(
98100
array: ArrayBuffer,
99-
type: 'u32' | 'i32' | 'f32',
101+
type: 'u32' | 'i32' | 'f16' | 'f32',
100102
{ zeroByteStart, zeroByteCount }: { zeroByteStart: number; zeroByteCount: number }
101103
) {
102-
const constructor = { u32: Uint32Array, i32: Int32Array, f32: Float32Array }[type];
104+
const constructor = { u32: Uint32Array, i32: Int32Array, f16: Float16Array, f32: Float32Array }[
105+
type
106+
];
103107
assert(zeroByteCount % constructor.BYTES_PER_ELEMENT === 0);
104108
new constructor(array).fill(42);
105109
new constructor(array, zeroByteStart, zeroByteCount / constructor.BYTES_PER_ELEMENT).fill(0);
@@ -168,10 +172,15 @@ g.test('linear_memory')
168172
{ shadowingMode: 'function-scope' },
169173
])
170174
.expand('isAtomic', p => (supportsAtomics(p) ? [false, true] : [false]))
171-
.beginSubcases()
172175
.expand('baseType', supportedScalarTypes)
176+
.beginSubcases()
173177
.expandWithParams(generateTypes)
174178
)
179+
.beforeAllSubcases(t => {
180+
if (t.params.baseType === 'f16') {
181+
t.selectDeviceOrSkipTestCase('shader-f16');
182+
}
183+
})
175184
.fn(async t => {
176185
const {
177186
addressSpace,
@@ -189,6 +198,13 @@ g.test('linear_memory')
189198
assert(_kTypeInfo !== undefined, 'not an indexable type');
190199
assert('arrayLength' in _kTypeInfo);
191200

201+
if (baseType === 'f16' && addressSpace === 'uniform' && containerType === 'array') {
202+
// Array elements must be aligned to 16 bytes, but the logic in generateTypes
203+
// creates an array of vec4 of the baseType. But for f16 that's only 8 bytes.
204+
// We would need to write more complex logic for that.
205+
t.skip('TODO: Test logic does not handle array of f16 in the uniform address space');
206+
}
207+
192208
let usesCanary = false;
193209
let globalSource = '';
194210
let testFunctionSource = '';
@@ -429,6 +445,8 @@ fn runTest() -> u32 {
429445
],
430446
});
431447

448+
const enables = t.params.baseType === 'f16' ? 'enable f16;' : '';
449+
432450
// Run it.
433451
if (bufferBindingSize !== undefined && baseType !== 'bool') {
434452
const expectedData = new ArrayBuffer(testBufferSize);
@@ -450,6 +468,7 @@ fn runTest() -> u32 {
450468
// Run the shader, accessing the buffer.
451469
await runShaderTest(
452470
t,
471+
enables,
453472
GPUShaderStage.COMPUTE,
454473
testSource,
455474
layout,
@@ -475,6 +494,6 @@ fn runTest() -> u32 {
475494
bufferBindingEnd
476495
);
477496
} else {
478-
await runShaderTest(t, GPUShaderStage.COMPUTE, testSource, layout, []);
497+
await runShaderTest(t, enables, GPUShaderStage.COMPUTE, testSource, layout, []);
479498
}
480499
});

src/webgpu/shader/execution/zero_init.spec.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ g.test('compute,zero_init')
107107
? [true, false]
108108
: [false]) {
109109
for (const scalarType of supportedScalarTypes({ isAtomic, ...p })) {
110+
// Fewer subcases: supportedScalarTypes was expanded to include f16
111+
// but that may take too much time. It would require more complex code.
112+
if (scalarType === 'f16') continue;
113+
110114
// Fewer subcases: For nested types, skip atomic u32 and non-atomic i32.
111115
if (p._containerDepth > 0) {
112116
if (scalarType === 'u32' && isAtomic) continue;

src/webgpu/shader/types.ts

Lines changed: 86 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,20 @@ import { keysOf } from '../../common/util/data_tables.js';
22
import { assert } from '../../common/util/util.js';
33
import { align } from '../util/math.js';
44

5-
const kArrayLength = 3;
5+
const kDefaultArrayLength = 3;
66

77
export type Requirement = 'never' | 'may' | 'must'; // never is the same as "must not"
88
export type ContainerType = 'scalar' | 'vector' | 'matrix' | 'atomic' | 'array';
9-
export type ScalarType = 'i32' | 'u32' | 'f32' | 'bool';
9+
export type ScalarType = 'i32' | 'u32' | 'f16' | 'f32' | 'bool';
1010

11-
export const HostSharableTypes = ['i32', 'u32', 'f32'] as const;
11+
export const HostSharableTypes = ['i32', 'u32', 'f16', 'f32'] as const;
1212

1313
/** Info for each plain scalar type. */
1414
export const kScalarTypeInfo =
1515
/* prettier-ignore */ {
1616
'i32': { layout: { alignment: 4, size: 4 }, supportsAtomics: true, arrayLength: 1, innerLength: 0 },
1717
'u32': { layout: { alignment: 4, size: 4 }, supportsAtomics: true, arrayLength: 1, innerLength: 0 },
18+
'f16': { layout: { alignment: 2, size: 2 }, supportsAtomics: false, arrayLength: 1, innerLength: 0, feature: 'shader-f16' },
1819
'f32': { layout: { alignment: 4, size: 4 }, supportsAtomics: false, arrayLength: 1, innerLength: 0 },
1920
'bool': { layout: undefined, supportsAtomics: false, arrayLength: 1, innerLength: 0 },
2021
} as const;
@@ -24,29 +25,71 @@ export const kScalarTypes = keysOf(kScalarTypeInfo);
2425
/** Info for each vecN<> container type. */
2526
export const kVectorContainerTypeInfo =
2627
/* prettier-ignore */ {
27-
'vec2': { layout: { alignment: 8, size: 8 }, arrayLength: 2 , innerLength: 0 },
28-
'vec3': { layout: { alignment: 16, size: 12 }, arrayLength: 3 , innerLength: 0 },
29-
'vec4': { layout: { alignment: 16, size: 16 }, arrayLength: 4 , innerLength: 0 },
28+
'vec2': { arrayLength: 2 , innerLength: 0 },
29+
'vec3': { arrayLength: 3 , innerLength: 0 },
30+
'vec4': { arrayLength: 4 , innerLength: 0 },
3031
} as const;
3132
/** List of all vecN<> container types. */
3233
export const kVectorContainerTypes = keysOf(kVectorContainerTypeInfo);
3334

35+
/** Returns the vector layout for a given vector container and base type, or undefined if that base type has no layout */
36+
function vectorLayout(
37+
vectorContainer: 'vec2' | 'vec3' | 'vec4',
38+
baseType: ScalarType
39+
): { alignment: number; size: number } | undefined {
40+
const n = kVectorContainerTypeInfo[vectorContainer].arrayLength;
41+
const scalarLayout = kScalarTypeInfo[baseType].layout;
42+
if (scalarLayout === undefined) {
43+
return undefined;
44+
}
45+
if (n === 3) {
46+
return { alignment: scalarLayout.alignment * 4, size: scalarLayout.size * 3 };
47+
}
48+
return { alignment: scalarLayout.alignment * n, size: scalarLayout.size * n };
49+
}
50+
3451
/** Info for each matNxN<> container type. */
3552
export const kMatrixContainerTypeInfo =
3653
/* prettier-ignore */ {
37-
'mat2x2': { layout: { alignment: 8, size: 16 }, arrayLength: 2, innerLength: 2 },
38-
'mat3x2': { layout: { alignment: 8, size: 24 }, arrayLength: 3, innerLength: 2 },
39-
'mat4x2': { layout: { alignment: 8, size: 32 }, arrayLength: 4, innerLength: 2 },
40-
'mat2x3': { layout: { alignment: 16, size: 32 }, arrayLength: 2, innerLength: 3 },
41-
'mat3x3': { layout: { alignment: 16, size: 48 }, arrayLength: 3, innerLength: 3 },
42-
'mat4x3': { layout: { alignment: 16, size: 64 }, arrayLength: 4, innerLength: 3 },
43-
'mat2x4': { layout: { alignment: 16, size: 32 }, arrayLength: 2, innerLength: 4 },
44-
'mat3x4': { layout: { alignment: 16, size: 48 }, arrayLength: 3, innerLength: 4 },
45-
'mat4x4': { layout: { alignment: 16, size: 64 }, arrayLength: 4, innerLength: 4 },
54+
'mat2x2': { arrayLength: 2, innerLength: 2 },
55+
'mat3x2': { arrayLength: 3, innerLength: 2 },
56+
'mat4x2': { arrayLength: 4, innerLength: 2 },
57+
'mat2x3': { arrayLength: 2, innerLength: 3 },
58+
'mat3x3': { arrayLength: 3, innerLength: 3 },
59+
'mat4x3': { arrayLength: 4, innerLength: 3 },
60+
'mat2x4': { arrayLength: 2, innerLength: 4 },
61+
'mat3x4': { arrayLength: 3, innerLength: 4 },
62+
'mat4x4': { arrayLength: 4, innerLength: 4 },
4663
} as const;
4764
/** List of all matNxN<> container types. */
4865
export const kMatrixContainerTypes = keysOf(kMatrixContainerTypeInfo);
4966

67+
export const kMatrixContainerTypeLayoutInfo =
68+
/* prettier-ignore */ {
69+
'f16': {
70+
'mat2x2': { layout: { alignment: 4, size: 8 } },
71+
'mat3x2': { layout: { alignment: 4, size: 12 } },
72+
'mat4x2': { layout: { alignment: 4, size: 16 } },
73+
'mat2x3': { layout: { alignment: 8, size: 16 } },
74+
'mat3x3': { layout: { alignment: 8, size: 24 } },
75+
'mat4x3': { layout: { alignment: 8, size: 32 } },
76+
'mat2x4': { layout: { alignment: 8, size: 16 } },
77+
'mat3x4': { layout: { alignment: 8, size: 24 } },
78+
'mat4x4': { layout: { alignment: 8, size: 32 } },
79+
},
80+
'f32': {
81+
'mat2x2': { layout: { alignment: 8, size: 16 } },
82+
'mat3x2': { layout: { alignment: 8, size: 24 } },
83+
'mat4x2': { layout: { alignment: 8, size: 32 } },
84+
'mat2x3': { layout: { alignment: 16, size: 32 } },
85+
'mat3x3': { layout: { alignment: 16, size: 48 } },
86+
'mat4x3': { layout: { alignment: 16, size: 64 } },
87+
'mat2x4': { layout: { alignment: 16, size: 32 } },
88+
'mat3x4': { layout: { alignment: 16, size: 48 } },
89+
'mat4x4': { layout: { alignment: 16, size: 64 } },
90+
}
91+
} as const;
92+
5093
export type AddressSpace = 'storage' | 'uniform' | 'private' | 'function' | 'workgroup' | 'handle';
5194
export type AccessMode = 'read' | 'write' | 'read_write';
5295
export type Scope = 'module' | 'function';
@@ -189,21 +232,27 @@ export function* generateTypes({
189232
for (const vectorType of kVectorContainerTypes) {
190233
yield {
191234
type: `${vectorType}<${scalarType}>`,
192-
_kTypeInfo: { elementBaseType: baseType, ...kVectorContainerTypeInfo[vectorType] },
235+
_kTypeInfo: {
236+
elementBaseType: baseType,
237+
...kVectorContainerTypeInfo[vectorType],
238+
layout: vectorLayout(vectorType, scalarType as ScalarType),
239+
},
193240
};
194241
}
195242
}
196243

197244
if (containerType === 'matrix') {
198-
// Matrices can only be f32.
199-
if (baseType === 'f32') {
245+
// Matrices can only be f16 or f32.
246+
if (baseType === 'f16' || baseType === 'f32') {
200247
for (const matrixType of kMatrixContainerTypes) {
201-
const matrixInfo = kMatrixContainerTypeInfo[matrixType];
248+
const matrixDimInfo = kMatrixContainerTypeInfo[matrixType];
249+
const matrixLayoutInfo = kMatrixContainerTypeLayoutInfo[baseType][matrixType];
202250
yield {
203251
type: `${matrixType}<${scalarType}>`,
204252
_kTypeInfo: {
205-
elementBaseType: `vec${matrixInfo.innerLength}<${scalarType}>`,
206-
...matrixInfo,
253+
elementBaseType: `vec${matrixDimInfo.innerLength}<${scalarType}>`,
254+
...matrixDimInfo,
255+
...matrixLayoutInfo,
207256
},
208257
};
209258
}
@@ -212,33 +261,43 @@ export function* generateTypes({
212261

213262
// Array types
214263
if (containerType === 'array') {
264+
// Buffer affective binding size must be a multiple of 4. Adjust array length as needed.
265+
let arrayLength = kDefaultArrayLength;
266+
if (
267+
addressSpace === 'storage' &&
268+
scalarInfo.layout !== undefined &&
269+
scalarInfo.layout.alignment % 4 !== 0
270+
) {
271+
arrayLength = align(arrayLength, 4);
272+
}
273+
215274
const arrayTypeInfo = {
216275
elementBaseType: `${baseType}`,
217-
arrayLength: kArrayLength,
276+
arrayLength,
218277
layout: scalarInfo.layout
219278
? {
220279
alignment: scalarInfo.layout.alignment,
221280
size:
222281
addressSpace === 'uniform'
223282
? // Uniform storage class must have array elements aligned to 16.
224-
kArrayLength *
283+
arrayLength *
225284
arrayStride({
226285
...scalarInfo.layout,
227286
alignment: 16,
228287
})
229-
: kArrayLength * arrayStride(scalarInfo.layout),
288+
: arrayLength * arrayStride(scalarInfo.layout),
230289
}
231290
: undefined,
232291
};
233292

234293
// Sized
235294
if (addressSpace === 'uniform') {
236295
yield {
237-
type: `array<vec4<${scalarType}>,${kArrayLength}>`,
296+
type: `array<vec4<${scalarType}>,${arrayLength}>`,
238297
_kTypeInfo: arrayTypeInfo,
239298
};
240299
} else {
241-
yield { type: `array<${scalarType},${kArrayLength}>`, _kTypeInfo: arrayTypeInfo };
300+
yield { type: `array<${scalarType},${arrayLength}>`, _kTypeInfo: arrayTypeInfo };
242301
}
243302
// Unsized
244303
if (addressSpace === 'storage') {
@@ -272,7 +331,7 @@ export function supportsAtomics(p: {
272331
);
273332
}
274333

275-
/** Generates an iterator of supported base types (i32/u32/f32/bool) */
334+
/** Generates an iterator of supported base types (i32/u32/f16/f32/bool) */
276335
export function* supportedScalarTypes(p: { isAtomic: boolean; addressSpace: string }) {
277336
for (const scalarType of kScalarTypes) {
278337
const info = kScalarTypeInfo[scalarType];

src/webgpu/util/check_contents.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,10 @@ export function checkElementsEqual(
4343
expected: TypedArrayBufferView
4444
): ErrorWithExtra | undefined {
4545
assert(actual.constructor === expected.constructor, 'TypedArray type mismatch');
46-
assert(actual.length === expected.length, 'size mismatch');
46+
assert(
47+
actual.length === expected.length,
48+
`length mismatch: expected ${expected.length} got ${actual.length}`
49+
);
4750

4851
let failedElementsFirstMaybe: number | undefined = undefined;
4952
/** Sparse array with `true` for elements that failed. */

0 commit comments

Comments
 (0)