Skip to content

Commit

Permalink
Add tests for subgroupShuffle variants (#4061)
Browse files Browse the repository at this point in the history
* Refactor some duplicated code into subgroup_util.ts
  * typed input generation
  * framebuffer parameter generation
* Add tests for subgroup shuffle variants
  * All variants: subgroupShuffle, subgroupShuffleUp,
    subgroupShuffleDown, subgroupShuffleXor
  * Fragment tests
  * Data type tests
  * specific variant tests
  * randomized compute tests
  * predicated compute tests
  • Loading branch information
alan-baker authored Nov 29, 2024
1 parent 68633ee commit a0713ec
Show file tree
Hide file tree
Showing 8 changed files with 1,089 additions and 262 deletions.
7 changes: 7 additions & 0 deletions src/webgpu/listing_meta.json
Original file line number Diff line number Diff line change
Expand Up @@ -1575,6 +1575,13 @@
"webgpu:shader,execution,expression,call,builtin,subgroupMul:data_types:*": { "subcaseMS": 11861.865 },
"webgpu:shader,execution,expression,call,builtin,subgroupMul:fp_accuracy:*": { "subcaseMS": 35606.717 },
"webgpu:shader,execution,expression,call,builtin,subgroupMul:fragment:*": { "subcaseMS": 0.263 },
"webgpu:shader,execution,expression,call,builtin,subgroupShuffle:compute,all_active:*": { "subcaseMS": 39.191 },
"webgpu:shader,execution,expression,call,builtin,subgroupShuffle:compute,split:*": { "subcaseMS": 3074.451 },
"webgpu:shader,execution,expression,call,builtin,subgroupShuffle:data_types:*": { "subcaseMS": 5767.334 },
"webgpu:shader,execution,expression,call,builtin,subgroupShuffle:fragment:*": { "subcaseMS": 49.537 },
"webgpu:shader,execution,expression,call,builtin,subgroupShuffle:shuffle,id:*": { "subcaseMS": 924.078 },
"webgpu:shader,execution,expression,call,builtin,subgroupShuffle:shuffleUpDown,delta:*": { "subcaseMS": 81.870 },
"webgpu:shader,execution,expression,call,builtin,subgroupShuffle:shuffleXor,mask:*": { "subcaseMS": 62.127 },
"webgpu:shader,execution,expression,call,builtin,tan:abstract_float:*": { "subcaseMS": 17043.428 },
"webgpu:shader,execution,expression,call,builtin,tan:f16:*": { "subcaseMS": 116.157 },
"webgpu:shader,execution,expression,call,builtin,tan:f32:*": { "subcaseMS": 13.532 },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,12 @@ local_invocation_index. Tests should avoid assuming there is.

import { makeTestGroup } from '../../../../../../common/framework/test_group.js';
import { keysOf, objectsToRecord } from '../../../../../../common/util/data_tables.js';
import { assert, unreachable } from '../../../../../../common/util/util.js';
import { kTextureFormatInfo } from '../../../../../format_info.js';
import { kBit } from '../../../../../util/constants.js';
import { assert } from '../../../../../../common/util/util.js';
import {
kConcreteNumericScalarsAndVectors,
Type,
VectorType,
scalarTypeOf,
} from '../../../../../util/conversion.js';
import { align } from '../../../../../util/math.js';

import {
kWGSizes,
Expand All @@ -28,120 +24,14 @@ import {
SubgroupTest,
kFramebufferSizes,
runFragmentTest,
generateTypedInputs,
getUintsPerFramebuffer,
} from './subgroup_util.js';

export const g = makeTestGroup(SubgroupTest);

const kTypes = objectsToRecord(kConcreteNumericScalarsAndVectors);

/**
* Generates scalar values for type
*
* Generates 4 32-bit values whose bit patterns represent
* interesting values of the data type.
* @param type The data type
*/
function generateScalarValues(type: Type): number[] {
const scalarTy = scalarTypeOf(type);
switch (scalarTy) {
case Type.u32:
return [kBit.u32.min, kBit.u32.max, 1111, 2222];
case Type.i32:
return [
kBit.i32.positive.min,
kBit.i32.positive.max,
kBit.i32.negative.min,
0xffffffff, // -1
];
case Type.f32:
return [
kBit.f32.positive.zero,
kBit.f32.positive.nearest_max,
kBit.f32.negative.nearest_min,
0xbf800000, // -1
];
case Type.f16:
return [
kBit.f16.positive.zero,
kBit.f16.positive.nearest_max,
kBit.f16.negative.nearest_min,
0xbc00, // -1
];
default:
unreachable(`Unsupported type: ${type.toString()}`);
}
return [0, 0, 0, 0];
}

/**
* Generates input bit patterns for the input type
*
* Generates 4 values of type in a Uint32Array.
* 16-bit types are appropriately packed.
* @param type The data type
*/
function generateTypedInputs(type: Type): Uint32Array {
const scalarValues = generateScalarValues(type);
let elements = 1;
if (type instanceof VectorType) {
elements = type.width;
}
if (type.requiresF16()) {
switch (elements) {
case 1:
return new Uint32Array([
scalarValues[0] | (scalarValues[1] << 16),
scalarValues[2] | (scalarValues[3] << 16),
]);
case 2:
return new Uint32Array([
scalarValues[0] | (scalarValues[0] << 16),
scalarValues[1] | (scalarValues[1] << 16),
scalarValues[2] | (scalarValues[2] << 16),
scalarValues[3] | (scalarValues[3] << 16),
]);
case 3:
return new Uint32Array([
scalarValues[0] | (scalarValues[0] << 16),
scalarValues[0] | (kDataSentinel << 16),
scalarValues[1] | (scalarValues[1] << 16),
scalarValues[1] | (kDataSentinel << 16),
scalarValues[2] | (scalarValues[2] << 16),
scalarValues[2] | (kDataSentinel << 16),
scalarValues[3] | (scalarValues[3] << 16),
scalarValues[3] | (kDataSentinel << 16),
]);
case 4:
return new Uint32Array([
scalarValues[0] | (scalarValues[0] << 16),
scalarValues[0] | (scalarValues[0] << 16),
scalarValues[1] | (scalarValues[1] << 16),
scalarValues[1] | (scalarValues[1] << 16),
scalarValues[2] | (scalarValues[2] << 16),
scalarValues[2] | (scalarValues[2] << 16),
scalarValues[3] | (scalarValues[3] << 16),
scalarValues[3] | (scalarValues[3] << 16),
]);
default:
unreachable(`Unsupported type: ${type.toString()}`);
}
return new Uint32Array([0]);
} else {
const bound = elements === 3 ? 4 : elements;
const values: number[] = [];
for (let i = 0; i < 4; i++) {
for (let j = 0; j < bound; j++) {
if (j < elements) {
values.push(scalarValues[i]);
} else {
values.push(kDataSentinel);
}
}
}
return new Uint32Array(values);
}
}

/**
* Checks results from data types test
*
Expand Down Expand Up @@ -528,12 +418,7 @@ function checkFragment(
);
}

const { blockWidth, blockHeight, bytesPerBlock } = kTextureFormatInfo[format];
const blocksPerRow = width / blockWidth;
// 256 minimum comes from image copy requirements.
const bytesPerRow = align(blocksPerRow * (bytesPerBlock ?? 1), 256);
const uintsPerRow = bytesPerRow / 4;
const uintsPerTexel = (bytesPerBlock ?? 1) / blockWidth / blockHeight / 4;
const { uintsPerRow, uintsPerTexel } = getUintsPerFramebuffer(format, width, height);

const coordToIndex = (row: number, col: number) => {
return uintsPerRow * row + col * uintsPerTexel;
Expand Down
121 changes: 3 additions & 118 deletions src/webgpu/shader/execution/expression/call/builtin/quadSwap.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,11 @@ local_invocation_index. Tests should avoid assuming there is.
import { makeTestGroup } from '../../../../../../common/framework/test_group.js';
import { keysOf, objectsToRecord } from '../../../../../../common/util/data_tables.js';
import { assert, unreachable } from '../../../../../../common/util/util.js';
import { kTextureFormatInfo } from '../../../../../format_info.js';
import { kBit } from '../../../../../util/constants.js';
import {
kConcreteNumericScalarsAndVectors,
Type,
VectorType,
scalarTypeOf,
} from '../../../../../util/conversion.js';
import { align } from '../../../../../util/math.js';

import {
kWGSizes,
Expand All @@ -28,6 +24,8 @@ import {
SubgroupTest,
kFramebufferSizes,
runFragmentTest,
generateTypedInputs,
getUintsPerFramebuffer,
} from './subgroup_util.js';

export const g = makeTestGroup(SubgroupTest);
Expand All @@ -38,114 +36,6 @@ type SwapOp = 'quadSwapX' | 'quadSwapY' | 'quadSwapDiagonal';

const kOps: SwapOp[] = ['quadSwapX', 'quadSwapY', 'quadSwapDiagonal'];

/**
* Generates scalar values for type
*
* Generates 4 32-bit values whose bit patterns represent
* interesting values of the data type.
* @param type The data type
*/
function generateScalarValues(type: Type): number[] {
const scalarTy = scalarTypeOf(type);
switch (scalarTy) {
case Type.u32:
return [kBit.u32.min, kBit.u32.max, 1111, 2222];
case Type.i32:
return [
kBit.i32.positive.min,
kBit.i32.positive.max,
kBit.i32.negative.min,
0xffffffff, // -1
];
case Type.f32:
return [
kBit.f32.positive.zero,
kBit.f32.positive.nearest_max,
kBit.f32.negative.nearest_min,
0xbf800000, // -1
];
case Type.f16:
return [
kBit.f16.positive.zero,
kBit.f16.positive.nearest_max,
kBit.f16.negative.nearest_min,
0xbc00, // -1
];
default:
unreachable(`Unsupported type: ${type.toString()}`);
}
return [0, 0, 0, 0];
}

/**
* Generates input bit patterns for the input type
*
* Generates 4 values of type in a Uint32Array.
* 16-bit types are appropriately packed.
* @param type The data type
*/
function generateTypedInputs(type: Type): Uint32Array {
const scalarValues = generateScalarValues(type);
let elements = 1;
if (type instanceof VectorType) {
elements = type.width;
}
if (type.requiresF16()) {
switch (elements) {
case 1:
return new Uint32Array([
scalarValues[0] | (scalarValues[1] << 16),
scalarValues[2] | (scalarValues[3] << 16),
]);
case 2:
return new Uint32Array([
scalarValues[0] | (scalarValues[0] << 16),
scalarValues[1] | (scalarValues[1] << 16),
scalarValues[2] | (scalarValues[2] << 16),
scalarValues[3] | (scalarValues[3] << 16),
]);
case 3:
return new Uint32Array([
scalarValues[0] | (scalarValues[0] << 16),
scalarValues[0] | (kDataSentinel << 16),
scalarValues[1] | (scalarValues[1] << 16),
scalarValues[1] | (kDataSentinel << 16),
scalarValues[2] | (scalarValues[2] << 16),
scalarValues[2] | (kDataSentinel << 16),
scalarValues[3] | (scalarValues[3] << 16),
scalarValues[3] | (kDataSentinel << 16),
]);
case 4:
return new Uint32Array([
scalarValues[0] | (scalarValues[0] << 16),
scalarValues[0] | (scalarValues[0] << 16),
scalarValues[1] | (scalarValues[1] << 16),
scalarValues[1] | (scalarValues[1] << 16),
scalarValues[2] | (scalarValues[2] << 16),
scalarValues[2] | (scalarValues[2] << 16),
scalarValues[3] | (scalarValues[3] << 16),
scalarValues[3] | (scalarValues[3] << 16),
]);
default:
unreachable(`Unsupported type: ${type.toString()}`);
}
return new Uint32Array([0]);
} else {
const bound = elements === 3 ? 4 : elements;
const values: number[] = [];
for (let i = 0; i < 4; i++) {
for (let j = 0; j < bound; j++) {
if (j < elements) {
values.push(scalarValues[i]);
} else {
values.push(kDataSentinel);
}
}
}
return new Uint32Array(values);
}
}

/**
* Returns the swapped quad invocation id for the given op
*
Expand Down Expand Up @@ -544,12 +434,7 @@ function checkFragment(
);
}

const { blockWidth, blockHeight, bytesPerBlock } = kTextureFormatInfo[format];
const blocksPerRow = width / blockWidth;
// 256 minimum comes from image copy requirements.
const bytesPerRow = align(blocksPerRow * (bytesPerBlock ?? 1), 256);
const uintsPerRow = bytesPerRow / 4;
const uintsPerTexel = (bytesPerBlock ?? 1) / blockWidth / blockHeight / 4;
const { uintsPerRow, uintsPerTexel } = getUintsPerFramebuffer(format, width, height);

const coordToIndex = (row: number, col: number) => {
return uintsPerRow * row + col * uintsPerTexel;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ local_invocation_index. Tests should avoid assuming there is.
import { makeTestGroup } from '../../../../../../common/framework/test_group.js';
import { keysOf } from '../../../../../../common/util/data_tables.js';
import { iterRange } from '../../../../../../common/util/util.js';
import { kTextureFormatInfo } from '../../../../../format_info.js';
import { align } from '../../../../../util/math.js';
import { PRNG } from '../../../../../util/prng.js';

import {
Expand All @@ -22,6 +20,7 @@ import {
kFramebufferSizes,
runComputeTest,
runFragmentTest,
getUintsPerFramebuffer,
} from './subgroup_util.js';

export const g = makeTestGroup(SubgroupTest);
Expand Down Expand Up @@ -279,12 +278,7 @@ function checkFragmentAll(
width: number,
height: number
): Error | undefined {
const { blockWidth, blockHeight, bytesPerBlock } = kTextureFormatInfo[format];
const blocksPerRow = width / blockWidth;
// 256 minimum comes from image copy requirements.
const bytesPerRow = align(blocksPerRow * (bytesPerBlock ?? 1), 256);
const uintsPerRow = bytesPerRow / 4;
const uintsPerTexel = (bytesPerBlock ?? 1) / blockWidth / blockHeight / 4;
const { uintsPerRow, uintsPerTexel } = getUintsPerFramebuffer(format, width, height);

// Iteration skips last row and column to avoid helper invocations because it is not
// guaranteed whether or not they participate in the subgroup operation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ local_invocation_index. Tests should avoid assuming there is.
import { makeTestGroup } from '../../../../../../common/framework/test_group.js';
import { keysOf } from '../../../../../../common/util/data_tables.js';
import { iterRange } from '../../../../../../common/util/util.js';
import { kTextureFormatInfo } from '../../../../../format_info.js';
import { align } from '../../../../../util/math.js';
import { PRNG } from '../../../../../util/prng.js';

import {
Expand All @@ -22,6 +20,7 @@ import {
runComputeTest,
runFragmentTest,
kFramebufferSizes,
getUintsPerFramebuffer,
} from './subgroup_util.js';

export const g = makeTestGroup(SubgroupTest);
Expand Down Expand Up @@ -279,12 +278,7 @@ function checkFragmentAny(
width: number,
height: number
): Error | undefined {
const { blockWidth, blockHeight, bytesPerBlock } = kTextureFormatInfo[format];
const blocksPerRow = width / blockWidth;
// 256 minimum comes from image copy requirements.
const bytesPerRow = align(blocksPerRow * (bytesPerBlock ?? 1), 256);
const uintsPerRow = bytesPerRow / 4;
const uintsPerTexel = (bytesPerBlock ?? 1) / blockWidth / blockHeight / 4;
const { uintsPerRow, uintsPerTexel } = getUintsPerFramebuffer(format, width, height);

// Iteration skips last row and column to avoid helper invocations because it is not
// guaranteed whether or not they participate in the subgroup operation.
Expand Down
Loading

0 comments on commit a0713ec

Please sign in to comment.