Skip to content

Commit

Permalink
Rework toInterval to do repacking
Browse files Browse the repository at this point in the history
This allows for removing wrapper utilities and directly use arrow
functions instead.
  • Loading branch information
zoddicus committed Oct 19, 2023
1 parent f58eab5 commit bf48a54
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 37 deletions.
62 changes: 25 additions & 37 deletions src/webgpu/util/floating_point.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import {
quantizeToF32,
quantizeToF16,
unflatten2DArray,
every2DArray,
} from './math.js';

/** Indicate the kind of WGSL floating point numbers being operated on */
Expand Down Expand Up @@ -630,12 +631,19 @@ export abstract class FPTraits {
public abstract constants(): FPConstants;

// Utilities - Implemented

/** @returns an interval containing the point or the original interval */
public toInterval(n: number | IntervalBounds | FPInterval): FPInterval {
if (n instanceof FPInterval) {
if (n.kind === this.kind) {
return n;
}

// Preserve if the original interval was unbounded or bounded
if (!n.isFinite()) {
return this.constants().unboundedInterval;
}

return new FPInterval(this.kind, ...n.bounds());
}

Expand Down Expand Up @@ -699,7 +707,7 @@ export abstract class FPTraits {

/** @returns an FPVector representation of an array of values if possible */
public toVector(v: (number | IntervalBounds | FPInterval)[]): FPVector {
if (this.isVector(v)) {
if (this.isVector(v) && v.every(e => e.kind === this.kind)) {
return v;
}

Expand Down Expand Up @@ -763,7 +771,12 @@ export abstract class FPTraits {

/** @returns an FPMatrix representation of an array of an array of values if possible */
public toMatrix(m: Array2D<number | IntervalBounds | FPInterval> | FPVector[]): FPMatrix {
if (this.isMatrix(m)) {
if (
this.isMatrix(m) &&
every2DArray(m, (e: FPInterval) => {
return e.kind === this.kind;
})
) {
return m;
}

Expand Down Expand Up @@ -4921,36 +4934,6 @@ class FPAbstractTraits extends FPTraits {
return FPAbstractTraits._constants;
}

// Utilities - Proxies
// Wrappers for forwarding ULP and absolute error interval calls to f32.
// AbstractFloat accuracies are technically unbounded for ULP and absolute
// error interval, but testing that implementations are at least as good as
// f32.

/** Forwarder for ULPInterval */
protected forwardUlpInterval(n: number, numULP: number): FPInterval {
const result = FP['f32'].ulpInterval(n, numULP);
if (!result.isFinite()) {
return this.constants().unboundedInterval;
}

return this.toInterval(result.bounds());
}

/** Forwarder for scalar pair to interval generator */
protected forwardScalarPairToInterval(
func: (x: number | FPInterval, y: number | FPInterval) => FPInterval,
x: number | FPInterval,
y: number | FPInterval
): FPInterval {
const result = func(x, y);
if (!result.isFinite()) {
return this.constants().unboundedInterval;
}

return this.toInterval(result.bounds());
}

// Utilities - Overrides
// number is represented as a f64 internally, so all number values are already
// quantized to f64
Expand All @@ -4970,7 +4953,9 @@ class FPAbstractTraits extends FPTraits {
public readonly absoluteErrorInterval = this.unboundedAbsoluteErrorInterval.bind(this);
public readonly correctlyRoundedInterval = this.correctlyRoundedIntervalImpl.bind(this);
public readonly correctlyRoundedMatrix = this.correctlyRoundedMatrixImpl.bind(this);
public readonly ulpInterval = this.forwardUlpInterval.bind(this);
public readonly ulpInterval = (n: number, numULP: number): FPInterval => {
return this.toInterval(kF32Traits.ulpInterval(n, numULP));
};

// Framework - API - Overrides
public readonly absInterval = this.absIntervalImpl.bind(this);
Expand Down Expand Up @@ -5007,10 +4992,13 @@ class FPAbstractTraits extends FPTraits {
'determinantInterval'
);
public readonly distanceInterval = this.unimplementedDistance.bind(this);
public readonly divisionInterval = this.forwardScalarPairToInterval.bind(
this,
kF32Traits.divisionInterval
);
public readonly divisionInterval = (
x: number | FPInterval,
y: number | FPInterval
): FPInterval => {
return this.toInterval(kF32Traits.divisionInterval(x, y));
};

public readonly dotInterval = this.unimplementedVectorPairToInterval.bind(this, 'dotInterval');
public readonly expInterval = this.unimplementedScalarToInterval.bind(this, 'expInterval');
public readonly exp2Interval = this.unimplementedScalarToInterval.bind(this, 'exp2Interval');
Expand Down
24 changes: 24 additions & 0 deletions src/webgpu/util/math.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2241,3 +2241,27 @@ export function map2DArray<T, S>(m: T[][], op: (input: T) => S): S[][] {
}
return result;
}

/**
* Performs a .every over a matrix and return the result
*
* @param m input matrix of type T
* @param op operation that performs a test on an element
* @returns a boolean indicating if the test passed for every element
*/
export function every2DArray<T>(m: T[][], op: (input: T) => boolean): boolean {
const c = m.length;
const r = m[0].length;
assert(
m.every(c => c.length === r),
`Unexpectedly received jagged array to map`
);
for (let i = 0; i < c; i++) {
for (let j = 0; j < r; j++) {
if (!op(m[i][j])) {
return false;
}
}
}
return true;
}

0 comments on commit bf48a54

Please sign in to comment.