From bf48a5498844ac47820c764e0275df26d65924cf Mon Sep 17 00:00:00 2001
From: Ryan Harrison <rharrison@google.com>
Date: Thu, 19 Oct 2023 14:09:57 -0400
Subject: [PATCH] Rework toInterval to do repacking

This allows for removing wrapper utilities and directly use arrow
functions instead.
---
 src/webgpu/util/floating_point.ts | 62 +++++++++++++------------------
 src/webgpu/util/math.ts           | 24 ++++++++++++
 2 files changed, 49 insertions(+), 37 deletions(-)

diff --git a/src/webgpu/util/floating_point.ts b/src/webgpu/util/floating_point.ts
index 06f496c26457..efe4119fc4ba 100644
--- a/src/webgpu/util/floating_point.ts
+++ b/src/webgpu/util/floating_point.ts
@@ -43,6 +43,7 @@ import {
   quantizeToF32,
   quantizeToF16,
   unflatten2DArray,
+  every2DArray,
 } from './math.js';
 
 /** Indicate the kind of WGSL floating point numbers being operated on */
@@ -630,12 +631,19 @@ export abstract class FPTraits {
   public abstract constants(): FPConstants;
 
   // Utilities - Implemented
+
   /** @returns an interval containing the point or the original interval */
   public toInterval(n: number | IntervalBounds | FPInterval): FPInterval {
     if (n instanceof FPInterval) {
       if (n.kind === this.kind) {
         return n;
       }
+
+      // Preserve if the original interval was unbounded or bounded
+      if (!n.isFinite()) {
+        return this.constants().unboundedInterval;
+      }
+
       return new FPInterval(this.kind, ...n.bounds());
     }
 
@@ -699,7 +707,7 @@ export abstract class FPTraits {
 
   /** @returns an FPVector representation of an array of values if possible */
   public toVector(v: (number | IntervalBounds | FPInterval)[]): FPVector {
-    if (this.isVector(v)) {
+    if (this.isVector(v) && v.every(e => e.kind === this.kind)) {
       return v;
     }
 
@@ -763,7 +771,12 @@ export abstract class FPTraits {
 
   /** @returns an FPMatrix representation of an array of an array of values if possible */
   public toMatrix(m: Array2D<number | IntervalBounds | FPInterval> | FPVector[]): FPMatrix {
-    if (this.isMatrix(m)) {
+    if (
+      this.isMatrix(m) &&
+      every2DArray(m, (e: FPInterval) => {
+        return e.kind === this.kind;
+      })
+    ) {
       return m;
     }
 
@@ -4921,36 +4934,6 @@ class FPAbstractTraits extends FPTraits {
     return FPAbstractTraits._constants;
   }
 
-  // Utilities - Proxies
-  // Wrappers for forwarding ULP and absolute error interval calls to f32.
-  // AbstractFloat accuracies are technically unbounded for ULP and absolute
-  // error interval, but testing that implementations are at least as good as
-  // f32.
-
-  /** Forwarder for ULPInterval */
-  protected forwardUlpInterval(n: number, numULP: number): FPInterval {
-    const result = FP['f32'].ulpInterval(n, numULP);
-    if (!result.isFinite()) {
-      return this.constants().unboundedInterval;
-    }
-
-    return this.toInterval(result.bounds());
-  }
-
-  /** Forwarder for scalar pair to interval generator */
-  protected forwardScalarPairToInterval(
-    func: (x: number | FPInterval, y: number | FPInterval) => FPInterval,
-    x: number | FPInterval,
-    y: number | FPInterval
-  ): FPInterval {
-    const result = func(x, y);
-    if (!result.isFinite()) {
-      return this.constants().unboundedInterval;
-    }
-
-    return this.toInterval(result.bounds());
-  }
-
   // Utilities - Overrides
   // number is represented as a f64 internally, so all number values are already
   // quantized to f64
@@ -4970,7 +4953,9 @@ class FPAbstractTraits extends FPTraits {
   public readonly absoluteErrorInterval = this.unboundedAbsoluteErrorInterval.bind(this);
   public readonly correctlyRoundedInterval = this.correctlyRoundedIntervalImpl.bind(this);
   public readonly correctlyRoundedMatrix = this.correctlyRoundedMatrixImpl.bind(this);
-  public readonly ulpInterval = this.forwardUlpInterval.bind(this);
+  public readonly ulpInterval = (n: number, numULP: number): FPInterval => {
+    return this.toInterval(kF32Traits.ulpInterval(n, numULP));
+  };
 
   // Framework - API - Overrides
   public readonly absInterval = this.absIntervalImpl.bind(this);
@@ -5007,10 +4992,13 @@ class FPAbstractTraits extends FPTraits {
     'determinantInterval'
   );
   public readonly distanceInterval = this.unimplementedDistance.bind(this);
-  public readonly divisionInterval = this.forwardScalarPairToInterval.bind(
-    this,
-    kF32Traits.divisionInterval
-  );
+  public readonly divisionInterval = (
+    x: number | FPInterval,
+    y: number | FPInterval
+  ): FPInterval => {
+    return this.toInterval(kF32Traits.divisionInterval(x, y));
+  };
+
   public readonly dotInterval = this.unimplementedVectorPairToInterval.bind(this, 'dotInterval');
   public readonly expInterval = this.unimplementedScalarToInterval.bind(this, 'expInterval');
   public readonly exp2Interval = this.unimplementedScalarToInterval.bind(this, 'exp2Interval');
diff --git a/src/webgpu/util/math.ts b/src/webgpu/util/math.ts
index cc7b5e44a99f..8f76124ac951 100644
--- a/src/webgpu/util/math.ts
+++ b/src/webgpu/util/math.ts
@@ -2241,3 +2241,27 @@ export function map2DArray<T, S>(m: T[][], op: (input: T) => S): S[][] {
   }
   return result;
 }
+
+/**
+ * Performs a .every over a matrix and return the result
+ *
+ * @param m input matrix of type T
+ * @param op operation that performs a test on an element
+ * @returns a boolean indicating if the test passed for every element
+ */
+export function every2DArray<T>(m: T[][], op: (input: T) => boolean): boolean {
+  const c = m.length;
+  const r = m[0].length;
+  assert(
+    m.every(c => c.length === r),
+    `Unexpectedly received jagged array to map`
+  );
+  for (let i = 0; i < c; i++) {
+    for (let j = 0; j < r; j++) {
+      if (!op(m[i][j])) {
+        return false;
+      }
+    }
+  }
+  return true;
+}