Skip to content

Commit

Permalink
wgsl: Add AF remainder (%) execution tests
Browse files Browse the repository at this point in the history
Issue #1626
  • Loading branch information
zoddicus committed Oct 19, 2023
1 parent 424e345 commit f9dcf2b
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 10 deletions.
14 changes: 9 additions & 5 deletions src/unittests/floating_point.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4844,14 +4844,18 @@ const kRemainderCases = {
g.test('remainderInterval')
.params(u =>
u
.combine('trait', ['f32', 'f16'] as const)
.combine('trait', ['abstract', 'f32', 'f16'] as const)
.beginSubcases()
.expandWithParams<ScalarPairToIntervalCase>(p => {
const trait = FP[p.trait];
const constants = trait.constants();
// This is a ULP based interval, so abstract should behave like f32, so
// swizzling the trait as needed.
const trait = p.trait === 'abstract' ? 'f32' : p.trait;
const fp = FP[trait];
const constants = fp.constants();

// prettier-ignore
return [
...kRemainderCases[p.trait],
...kRemainderCases[trait],
// Normals
{ input: [0, 1], expected: 0 },
{ input: [0, -1], expected: 0 },
Expand Down Expand Up @@ -4892,7 +4896,7 @@ g.test('remainderInterval')
const got = trait.remainderInterval(x, y);
t.expect(
objectEquals(expected, got),
`FP.${t.params.trait}.remainderInterval(${x}, ${y}) returned ${got}. Expected ${expected}`
`${t.params.trait}.remainderInterval(${x}, ${y}) returned ${got}. Expected ${expected}`
);
});

Expand Down
4 changes: 4 additions & 0 deletions src/webgpu/listing_meta.json
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,10 @@
"webgpu:shader,execution,expression,binary,af_multiplication:scalar_vector:*": { "subcaseMS": 2025.534 },
"webgpu:shader,execution,expression,binary,af_multiplication:vector:*": { "subcaseMS": 710.667 },
"webgpu:shader,execution,expression,binary,af_multiplication:vector_scalar:*": { "subcaseMS": 2085.300 },
"webgpu:shader,execution,expression,binary,af_remainder:scalar:*": { "subcaseMS": 1103.701 },
"webgpu:shader,execution,expression,binary,af_remainder:scalar_vector:*": { "subcaseMS": 756.800 },
"webgpu:shader,execution,expression,binary,af_remainder:vector:*": { "subcaseMS": 299.701 },
"webgpu:shader,execution,expression,binary,af_remainder:vector_scalar:*": { "subcaseMS": 777.701 },
"webgpu:shader,execution,expression,binary,af_subtraction:scalar:*": { "subcaseMS": 854.100 },
"webgpu:shader,execution,expression,binary,af_subtraction:scalar_vector:*": { "subcaseMS": 2336.534 },
"webgpu:shader,execution,expression,binary,af_subtraction:vector:*": { "subcaseMS": 764.201 },
Expand Down
157 changes: 157 additions & 0 deletions src/webgpu/shader/execution/expression/binary/af_remainder.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
export const description = `
Execution Tests for non-matrix abstract float remainder expression
`;

import { makeTestGroup } from '../../../../../common/framework/test_group.js';
import { GPUTest } from '../../../../gpu_test.js';
import { TypeAbstractFloat, TypeVec } from '../../../../util/conversion.js';
import { FP, FPVector } from '../../../../util/floating_point.js';
import {

Check failure on line 9 in src/webgpu/shader/execution/expression/binary/af_remainder.spec.ts

View workflow job for this annotation

GitHub Actions / build

Replace `⏎··sparseF64Range,⏎··sparseVectorF64Range,⏎` with `·sparseF64Range,·sparseVectorF64Range·`
sparseF64Range,
sparseVectorF64Range,
} from '../../../../util/math.js';
import { makeCaseCache } from '../case_cache.js';
import { onlyConstInputSource, run } from '../expression.js';

import { abstractBinary } from './binary.js';

const remainderVectorScalarInterval = (v: number[], s: number): FPVector => {
return FP.abstract.toVector(v.map(e => FP.abstract.remainderInterval(e, s)));
};

const remainderScalarVectorInterval = (s: number, v: number[]): FPVector => {
return FP.abstract.toVector(v.map(e => FP.abstract.remainderInterval(s, e)));
};

export const g = makeTestGroup(GPUTest);

const scalar_cases = {
['scalar']: () => {
return FP.abstract.generateScalarPairToIntervalCases(
sparseF64Range(),
sparseF64Range(),
'finite',
FP.abstract.remainderInterval
);
},
};

const vector_scalar_cases = ([2, 3, 4] as const)
.map(dim => ({
[`vec${dim}_scalar`]: () => {
return FP.abstract.generateVectorScalarToVectorCases(
sparseVectorF64Range(dim),
sparseF64Range(),
'finite',
remainderVectorScalarInterval
);
},
}))
.reduce((a, b) => ({ ...a, ...b }), {});

const scalar_vector_cases = ([2, 3, 4] as const)
.map(dim => ({
[`scalar_vec${dim}`]: () => {
return FP.abstract.generateScalarVectorToVectorCases(
sparseF64Range(),
sparseVectorF64Range(dim),
'finite',
remainderScalarVectorInterval
);
},
}))
.reduce((a, b) => ({ ...a, ...b }), {});

export const d = makeCaseCache('binary/af_remainder', {
...scalar_cases,
...vector_scalar_cases,
...scalar_vector_cases,
});

g.test('scalar')
.specURL('https://www.w3.org/TR/WGSL/#floating-point-evaluation')
.desc(
`
Expression: x % y, where x and y are scalars
Accuracy: Derived from x - y * trunc(x/y)
`
)
.params(u => u.combine('inputSource', onlyConstInputSource))
.fn(async t => {
const cases = await d.get('scalar');
await run(
t,
abstractBinary('%'),
[TypeAbstractFloat, TypeAbstractFloat],
TypeAbstractFloat,
t.params,
cases
);
});

g.test('vector')
.specURL('https://www.w3.org/TR/WGSL/#floating-point-evaluation')
.desc(
`
Expression: x % y, where x and y are vectors
Accuracy: Derived from x - y * trunc(x/y)
`
)
.params(u =>
u.combine('inputSource', onlyConstInputSource).combine('vectorize', [2, 3, 4] as const)
)
.fn(async t => {
const cases = await d.get('scalar'); // Using vectorize to generate vector cases based on scalar cases
await run(
t,
abstractBinary('%'),
[TypeAbstractFloat, TypeAbstractFloat],
TypeAbstractFloat,
t.params,
cases
);
});

g.test('vector_scalar')
.specURL('https://www.w3.org/TR/WGSL/#floating-point-evaluation')
.desc(
`
Expression: x % y, where x is a vector and y is a scalar
Accuracy: Correctly rounded
`
)
.params(u => u.combine('inputSource', onlyConstInputSource).combine('dim', [2, 3, 4] as const))
.fn(async t => {
const dim = t.params.dim;
const cases = await d.get(`vec${dim}_scalar`);
await run(
t,
abstractBinary('%'),
[TypeVec(dim, TypeAbstractFloat), TypeAbstractFloat],
TypeVec(dim, TypeAbstractFloat),
t.params,
cases
);
});

g.test('scalar_vector')
.specURL('https://www.w3.org/TR/WGSL/#floating-point-evaluation')
.desc(
`
Expression: x % y, where x is a scalar and y is a vector
Accuracy: Correctly rounded
`
)
.params(u => u.combine('inputSource', onlyConstInputSource).combine('dim', [2, 3, 4] as const))
.fn(async t => {
const dim = t.params.dim;
const cases = await d.get(`scalar_vec${dim}`);
await run(
t,
abstractBinary('%'),
[TypeAbstractFloat, TypeVec(dim, TypeAbstractFloat)],
TypeVec(dim, TypeAbstractFloat),
t.params,
cases
);
});
8 changes: 3 additions & 5 deletions src/webgpu/util/floating_point.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4998,7 +4998,6 @@ class FPAbstractTraits extends FPTraits {
): FPInterval => {
return this.toInterval(kF32Traits.divisionInterval(x, y));
};

public readonly dotInterval = this.unimplementedVectorPairToInterval.bind(this, 'dotInterval');
public readonly expInterval = this.unimplementedScalarToInterval.bind(this, 'expInterval');
public readonly exp2Interval = this.unimplementedScalarToInterval.bind(this, 'exp2Interval');
Expand Down Expand Up @@ -5066,10 +5065,9 @@ class FPAbstractTraits extends FPTraits {
'reflectInterval'
);
public readonly refractInterval = this.unimplementedRefract.bind(this);
public readonly remainderInterval = this.unimplementedScalarPairToInterval.bind(
this,
'remainderInterval'
);
public readonly remainderInterval = (x: number, y: number): FPInterval => {
return this.toInterval(kF32Traits.remainderInterval(x, y));
};
public readonly roundInterval = this.unimplementedScalarToInterval.bind(this, 'roundInterval');
public readonly saturateInterval = this.saturateIntervalImpl.bind(this);
public readonly signInterval = this.unimplementedScalarToInterval.bind(this, 'signInterval');
Expand Down

0 comments on commit f9dcf2b

Please sign in to comment.