Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wgsl: Cleanup cruft related to quantizeToF16 #3082

Merged
merged 1 commit into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/unittests/floating_point.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3423,7 +3423,7 @@ g.test('negationInterval')
);
});

g.test('quantizeToF16Interval_f32')
g.test('quantizeToF16Interval')
.paramsSubcasesOnly<ScalarToIntervalCase>(
// prettier-ignore
[
Expand Down
41 changes: 15 additions & 26 deletions src/webgpu/util/floating_point.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4608,6 +4608,21 @@ class F32Traits extends FPTraits {

// Framework - API

private readonly QuantizeToF16IntervalOp: ScalarToIntervalOp = {
impl: (n: number): FPInterval => {
const rounded = correctlyRoundedF16(n);
const flushed = addFlushedIfNeededF16(rounded);
return this.spanIntervals(...flushed.map(f => this.toInterval(f)));
},
};

protected quantizeToF16IntervalImpl(n: number): FPInterval {
return this.runScalarToIntervalOp(this.toInterval(n), this.QuantizeToF16IntervalOp);
}

/** Calculate an acceptance interval of quantizeToF16(x) */
public readonly quantizeToF16Interval = this.quantizeToF16IntervalImpl.bind(this);

/**
* Once-allocated ArrayBuffer/views to avoid overhead of allocation when
* converting between numeric formats
Expand Down Expand Up @@ -4720,21 +4735,6 @@ class F32Traits extends FPTraits {

/** Calculate an acceptance interval vector for unpack4x8unorm(x) */
public readonly unpack4x8unormInterval = this.unpack4x8unormIntervalImpl.bind(this);

private readonly QuantizeToF16IntervalOp: ScalarToIntervalOp = {
impl: (n: number): FPInterval => {
const rounded = correctlyRoundedF16(n);
const flushed = addFlushedIfNeededF16(rounded);
return this.spanIntervals(...flushed.map(f => this.toInterval(f)));
},
};

protected quantizeToF16IntervalImpl(n: number): FPInterval {
return this.runScalarToIntervalOp(this.toInterval(n), this.QuantizeToF16IntervalOp);
}

/** Calculate an acceptance interval of quantizeToF16(x) */
public readonly quantizeToF16Interval = this.quantizeToF16IntervalImpl.bind(this);
}

// Need to separately allocate f32 traits, so they can be referenced by
Expand Down Expand Up @@ -5055,10 +5055,6 @@ class FPAbstractTraits extends FPTraits {
'normalizeInterval'
);
public readonly powInterval = this.unimplementedScalarPairToInterval.bind(this, 'powInterval');
public readonly quantizeToF16Interval = this.unimplementedScalarToInterval.bind(
this,
'quantizeToF16Interval'
);
public readonly radiansInterval = this.radiansIntervalImpl.bind(this);
public readonly reflectInterval = this.unimplementedVectorPairToVector.bind(
this,
Expand Down Expand Up @@ -5353,7 +5349,6 @@ class F16Traits extends FPTraits {
public readonly negationInterval = this.negationIntervalImpl.bind(this);
public readonly normalizeInterval = this.normalizeIntervalImpl.bind(this);
public readonly powInterval = this.powIntervalImpl.bind(this);
public readonly quantizeToF16Interval = this.quantizeToF16IntervalNotAvailable.bind(this);
public readonly radiansInterval = this.radiansIntervalImpl.bind(this);
public readonly reflectInterval = this.reflectIntervalImpl.bind(this);
public readonly refractInterval = this.refractIntervalImpl.bind(this);
Expand All @@ -5374,12 +5369,6 @@ class F16Traits extends FPTraits {
public readonly tanhInterval = this.tanhIntervalImpl.bind(this);
public readonly transposeInterval = this.transposeIntervalImpl.bind(this);
public readonly truncInterval = this.truncIntervalImpl.bind(this);

/** quantizeToF16 has no f16 overload. */
private quantizeToF16IntervalNotAvailable(n: number): FPInterval {
unreachable("quantizeToF16 don't have f16 overload.");
return kF16UnboundedInterval;
}
}

export const FP = {
Expand Down
Loading