From 55e144b978345f7e03ec7c897571b3265404949d Mon Sep 17 00:00:00 2001 From: Ryan Harrison Date: Mon, 23 Oct 2023 16:03:41 -0400 Subject: [PATCH] wgsl: Cleanup cruft related to quantizeToF16 This is only defined for f32, so doesn't really need to be defined in the common super class. This allows for removing the various stub references to it, that will never be implemented. --- src/unittests/floating_point.spec.ts | 2 +- src/webgpu/util/floating_point.ts | 41 ++++++++++------------------ 2 files changed, 16 insertions(+), 27 deletions(-) diff --git a/src/unittests/floating_point.spec.ts b/src/unittests/floating_point.spec.ts index 4c32e11459af..44b87f42e0ad 100644 --- a/src/unittests/floating_point.spec.ts +++ b/src/unittests/floating_point.spec.ts @@ -3423,7 +3423,7 @@ g.test('negationInterval') ); }); -g.test('quantizeToF16Interval_f32') +g.test('quantizeToF16Interval') .paramsSubcasesOnly( // prettier-ignore [ diff --git a/src/webgpu/util/floating_point.ts b/src/webgpu/util/floating_point.ts index e9f6271c323b..9ccfa3b3a660 100644 --- a/src/webgpu/util/floating_point.ts +++ b/src/webgpu/util/floating_point.ts @@ -4608,6 +4608,21 @@ class F32Traits extends FPTraits { // Framework - API + private readonly QuantizeToF16IntervalOp: ScalarToIntervalOp = { + impl: (n: number): FPInterval => { + const rounded = correctlyRoundedF16(n); + const flushed = addFlushedIfNeededF16(rounded); + return this.spanIntervals(...flushed.map(f => this.toInterval(f))); + }, + }; + + protected quantizeToF16IntervalImpl(n: number): FPInterval { + return this.runScalarToIntervalOp(this.toInterval(n), this.QuantizeToF16IntervalOp); + } + + /** Calculate an acceptance interval of quantizeToF16(x) */ + public readonly quantizeToF16Interval = this.quantizeToF16IntervalImpl.bind(this); + /** * Once-allocated ArrayBuffer/views to avoid overhead of allocation when * converting between numeric formats @@ -4720,21 +4735,6 @@ class F32Traits extends FPTraits { /** Calculate an acceptance interval vector for unpack4x8unorm(x) */ public readonly unpack4x8unormInterval = this.unpack4x8unormIntervalImpl.bind(this); - - private readonly QuantizeToF16IntervalOp: ScalarToIntervalOp = { - impl: (n: number): FPInterval => { - const rounded = correctlyRoundedF16(n); - const flushed = addFlushedIfNeededF16(rounded); - return this.spanIntervals(...flushed.map(f => this.toInterval(f))); - }, - }; - - protected quantizeToF16IntervalImpl(n: number): FPInterval { - return this.runScalarToIntervalOp(this.toInterval(n), this.QuantizeToF16IntervalOp); - } - - /** Calculate an acceptance interval of quantizeToF16(x) */ - public readonly quantizeToF16Interval = this.quantizeToF16IntervalImpl.bind(this); } // Need to separately allocate f32 traits, so they can be referenced by @@ -5055,10 +5055,6 @@ class FPAbstractTraits extends FPTraits { 'normalizeInterval' ); public readonly powInterval = this.unimplementedScalarPairToInterval.bind(this, 'powInterval'); - public readonly quantizeToF16Interval = this.unimplementedScalarToInterval.bind( - this, - 'quantizeToF16Interval' - ); public readonly radiansInterval = this.radiansIntervalImpl.bind(this); public readonly reflectInterval = this.unimplementedVectorPairToVector.bind( this, @@ -5353,7 +5349,6 @@ class F16Traits extends FPTraits { public readonly negationInterval = this.negationIntervalImpl.bind(this); public readonly normalizeInterval = this.normalizeIntervalImpl.bind(this); public readonly powInterval = this.powIntervalImpl.bind(this); - public readonly quantizeToF16Interval = this.quantizeToF16IntervalNotAvailable.bind(this); public readonly radiansInterval = this.radiansIntervalImpl.bind(this); public readonly reflectInterval = this.reflectIntervalImpl.bind(this); public readonly refractInterval = this.refractIntervalImpl.bind(this); @@ -5374,12 +5369,6 @@ class F16Traits extends FPTraits { public readonly tanhInterval = this.tanhIntervalImpl.bind(this); public readonly transposeInterval = this.transposeIntervalImpl.bind(this); public readonly truncInterval = this.truncIntervalImpl.bind(this); - - /** quantizeToF16 has no f16 overload. */ - private quantizeToF16IntervalNotAvailable(n: number): FPInterval { - unreachable("quantizeToF16 don't have f16 overload."); - return kF16UnboundedInterval; - } } export const FP = {