diff --git a/src/unittests/floating_point.spec.ts b/src/unittests/floating_point.spec.ts index 3d73d4254577..da9110bc2d70 100644 --- a/src/unittests/floating_point.spec.ts +++ b/src/unittests/floating_point.spec.ts @@ -3431,7 +3431,7 @@ g.test('negationInterval') ); }); -g.test('quantizeToF16Interval_f32') +g.test('quantizeToF16Interval') .paramsSubcasesOnly( // prettier-ignore [ diff --git a/src/webgpu/util/floating_point.ts b/src/webgpu/util/floating_point.ts index 20f74483df31..fc76f85260fa 100644 --- a/src/webgpu/util/floating_point.ts +++ b/src/webgpu/util/floating_point.ts @@ -4608,6 +4608,21 @@ class F32Traits extends FPTraits { // Framework - API + private readonly QuantizeToF16IntervalOp: ScalarToIntervalOp = { + impl: (n: number): FPInterval => { + const rounded = correctlyRoundedF16(n); + const flushed = addFlushedIfNeededF16(rounded); + return this.spanIntervals(...flushed.map(f => this.toInterval(f))); + }, + }; + + protected quantizeToF16IntervalImpl(n: number): FPInterval { + return this.runScalarToIntervalOp(this.toInterval(n), this.QuantizeToF16IntervalOp); + } + + /** Calculate an acceptance interval of quantizeToF16(x) */ + public readonly quantizeToF16Interval = this.quantizeToF16IntervalImpl.bind(this); + /** * Once-allocated ArrayBuffer/views to avoid overhead of allocation when * converting between numeric formats @@ -4720,21 +4735,6 @@ class F32Traits extends FPTraits { /** Calculate an acceptance interval vector for unpack4x8unorm(x) */ public readonly unpack4x8unormInterval = this.unpack4x8unormIntervalImpl.bind(this); - - private readonly QuantizeToF16IntervalOp: ScalarToIntervalOp = { - impl: (n: number): FPInterval => { - const rounded = correctlyRoundedF16(n); - const flushed = addFlushedIfNeededF16(rounded); - return this.spanIntervals(...flushed.map(f => this.toInterval(f))); - }, - }; - - protected quantizeToF16IntervalImpl(n: number): FPInterval { - return this.runScalarToIntervalOp(this.toInterval(n), this.QuantizeToF16IntervalOp); - } - - /** Calculate an acceptance interval of quantizeToF16(x) */ - public readonly quantizeToF16Interval = this.quantizeToF16IntervalImpl.bind(this); } // Need to separately allocate f32 traits, so they can be referenced by @@ -5055,10 +5055,6 @@ class FPAbstractTraits extends FPTraits { 'normalizeInterval' ); public readonly powInterval = this.unimplementedScalarPairToInterval.bind(this, 'powInterval'); - public readonly quantizeToF16Interval = this.unimplementedScalarToInterval.bind( - this, - 'quantizeToF16Interval' - ); public readonly radiansInterval = this.radiansIntervalImpl.bind(this); public readonly reflectInterval = this.unimplementedVectorPairToVector.bind( this, @@ -5353,7 +5349,6 @@ class F16Traits extends FPTraits { public readonly negationInterval = this.negationIntervalImpl.bind(this); public readonly normalizeInterval = this.normalizeIntervalImpl.bind(this); public readonly powInterval = this.powIntervalImpl.bind(this); - public readonly quantizeToF16Interval = this.quantizeToF16IntervalNotAvailable.bind(this); public readonly radiansInterval = this.radiansIntervalImpl.bind(this); public readonly reflectInterval = this.reflectIntervalImpl.bind(this); public readonly refractInterval = this.refractIntervalImpl.bind(this); @@ -5374,12 +5369,6 @@ class F16Traits extends FPTraits { public readonly tanhInterval = this.tanhIntervalImpl.bind(this); public readonly transposeInterval = this.transposeIntervalImpl.bind(this); public readonly truncInterval = this.truncIntervalImpl.bind(this); - - /** quantizeToF16 has no f16 overload. */ - private quantizeToF16IntervalNotAvailable(n: number): FPInterval { - unreachable("quantizeToF16 don't have f16 overload."); - return kF16UnboundedInterval; - } } export const FP = {