diff --git a/src/webgpu/util/floating_point.ts b/src/webgpu/util/floating_point.ts index 06f496c26457..efe4119fc4ba 100644 --- a/src/webgpu/util/floating_point.ts +++ b/src/webgpu/util/floating_point.ts @@ -43,6 +43,7 @@ import { quantizeToF32, quantizeToF16, unflatten2DArray, + every2DArray, } from './math.js'; /** Indicate the kind of WGSL floating point numbers being operated on */ @@ -630,12 +631,19 @@ export abstract class FPTraits { public abstract constants(): FPConstants; // Utilities - Implemented + /** @returns an interval containing the point or the original interval */ public toInterval(n: number | IntervalBounds | FPInterval): FPInterval { if (n instanceof FPInterval) { if (n.kind === this.kind) { return n; } + + // Preserve if the original interval was unbounded or bounded + if (!n.isFinite()) { + return this.constants().unboundedInterval; + } + return new FPInterval(this.kind, ...n.bounds()); } @@ -699,7 +707,7 @@ export abstract class FPTraits { /** @returns an FPVector representation of an array of values if possible */ public toVector(v: (number | IntervalBounds | FPInterval)[]): FPVector { - if (this.isVector(v)) { + if (this.isVector(v) && v.every(e => e.kind === this.kind)) { return v; } @@ -763,7 +771,12 @@ export abstract class FPTraits { /** @returns an FPMatrix representation of an array of an array of values if possible */ public toMatrix(m: Array2D | FPVector[]): FPMatrix { - if (this.isMatrix(m)) { + if ( + this.isMatrix(m) && + every2DArray(m, (e: FPInterval) => { + return e.kind === this.kind; + }) + ) { return m; } @@ -4921,36 +4934,6 @@ class FPAbstractTraits extends FPTraits { return FPAbstractTraits._constants; } - // Utilities - Proxies - // Wrappers for forwarding ULP and absolute error interval calls to f32. - // AbstractFloat accuracies are technically unbounded for ULP and absolute - // error interval, but testing that implementations are at least as good as - // f32. - - /** Forwarder for ULPInterval */ - protected forwardUlpInterval(n: number, numULP: number): FPInterval { - const result = FP['f32'].ulpInterval(n, numULP); - if (!result.isFinite()) { - return this.constants().unboundedInterval; - } - - return this.toInterval(result.bounds()); - } - - /** Forwarder for scalar pair to interval generator */ - protected forwardScalarPairToInterval( - func: (x: number | FPInterval, y: number | FPInterval) => FPInterval, - x: number | FPInterval, - y: number | FPInterval - ): FPInterval { - const result = func(x, y); - if (!result.isFinite()) { - return this.constants().unboundedInterval; - } - - return this.toInterval(result.bounds()); - } - // Utilities - Overrides // number is represented as a f64 internally, so all number values are already // quantized to f64 @@ -4970,7 +4953,9 @@ class FPAbstractTraits extends FPTraits { public readonly absoluteErrorInterval = this.unboundedAbsoluteErrorInterval.bind(this); public readonly correctlyRoundedInterval = this.correctlyRoundedIntervalImpl.bind(this); public readonly correctlyRoundedMatrix = this.correctlyRoundedMatrixImpl.bind(this); - public readonly ulpInterval = this.forwardUlpInterval.bind(this); + public readonly ulpInterval = (n: number, numULP: number): FPInterval => { + return this.toInterval(kF32Traits.ulpInterval(n, numULP)); + }; // Framework - API - Overrides public readonly absInterval = this.absIntervalImpl.bind(this); @@ -5007,10 +4992,13 @@ class FPAbstractTraits extends FPTraits { 'determinantInterval' ); public readonly distanceInterval = this.unimplementedDistance.bind(this); - public readonly divisionInterval = this.forwardScalarPairToInterval.bind( - this, - kF32Traits.divisionInterval - ); + public readonly divisionInterval = ( + x: number | FPInterval, + y: number | FPInterval + ): FPInterval => { + return this.toInterval(kF32Traits.divisionInterval(x, y)); + }; + public readonly dotInterval = this.unimplementedVectorPairToInterval.bind(this, 'dotInterval'); public readonly expInterval = this.unimplementedScalarToInterval.bind(this, 'expInterval'); public readonly exp2Interval = this.unimplementedScalarToInterval.bind(this, 'exp2Interval'); diff --git a/src/webgpu/util/math.ts b/src/webgpu/util/math.ts index cc7b5e44a99f..8f76124ac951 100644 --- a/src/webgpu/util/math.ts +++ b/src/webgpu/util/math.ts @@ -2241,3 +2241,27 @@ export function map2DArray(m: T[][], op: (input: T) => S): S[][] { } return result; } + +/** + * Performs a .every over a matrix and return the result + * + * @param m input matrix of type T + * @param op operation that performs a test on an element + * @returns a boolean indicating if the test passed for every element + */ +export function every2DArray(m: T[][], op: (input: T) => boolean): boolean { + const c = m.length; + const r = m[0].length; + assert( + m.every(c => c.length === r), + `Unexpectedly received jagged array to map` + ); + for (let i = 0; i < c; i++) { + for (let j = 0; j < r; j++) { + if (!op(m[i][j])) { + return false; + } + } + } + return true; +}