diff --git a/src/webgpu/util/conversion.ts b/src/webgpu/util/conversion.ts index e1aa31566e38..5ee3e30c36ac 100644 --- a/src/webgpu/util/conversion.ts +++ b/src/webgpu/util/conversion.ts @@ -66,6 +66,24 @@ export function numbersApproximatelyEqual(a: number, b: number, maxDiff: number ); } +/** + * Once-allocated ArrayBuffer/views to avoid overhead of allocation when converting between numeric formats + * + * workingData* is shared between multiple functions in this file, so to avoid re-entrancy problems, make sure in + * functions that use it that they don't call themselves or other functions that use workingData*. + */ +const workingData = new ArrayBuffer(8); +const workingDataU32 = new Uint32Array(workingData); +const workingDataU16 = new Uint16Array(workingData); +const workingDataU8 = new Uint8Array(workingData); +const workingDataF32 = new Float32Array(workingData); +const workingDataF16 = new Float16Array(workingData); +const workingDataI16 = new Int16Array(workingData); +const workingDataI32 = new Int32Array(workingData); +const workingDataI8 = new Int8Array(workingData); +const workingDataF64 = new Float64Array(workingData); +const workingDataView = new DataView(workingData); + /** * Encodes a JS `number` into an IEEE754 floating point number with the specified number of * sign, exponent, mantissa bits, and exponent bias. @@ -91,9 +109,8 @@ export function float32ToFloatBits( return (((1 << exponentBits) - 1) << mantissaBits) | ((1 << mantissaBits) - 1); } - const buf = new DataView(new ArrayBuffer(Float32Array.BYTES_PER_ELEMENT)); - buf.setFloat32(0, n, true); - const bits = buf.getUint32(0, true); + workingDataView.setFloat32(0, n, true); + const bits = workingDataView.getUint32(0, true); // bits (32): seeeeeeeefffffffffffffffffffffff // 0 or 1 @@ -165,21 +182,6 @@ export const kFloat16Format = { signed: 1, exponentBits: 5, mantissaBits: 10, bi /** FloatFormat for 9 bit mantissa, 5 bit exponent unsigned float */ export const kUFloat9e5Format = { signed: 0, exponentBits: 5, mantissaBits: 9, bias: 15 } as const; -/** - * Once-allocated ArrayBuffer/views to avoid overhead of allocation when converting between numeric formats - * - * workingData* is shared between multiple functions in this file, so to avoid re-entrancy problems, make sure in - * functions that use it that they don't call themselves or other functions that use workingData*. - */ -const workingData = new ArrayBuffer(4); -const workingDataU32 = new Uint32Array(workingData); -const workingDataU16 = new Uint16Array(workingData); -const workingDataU8 = new Uint8Array(workingData); -const workingDataF32 = new Float32Array(workingData); -const workingDataF16 = new Float16Array(workingData); -const workingDataI16 = new Int16Array(workingData); -const workingDataI8 = new Int8Array(workingData); - /** Bitcast u32 (represented as integer Number) to f32 (represented as floating-point Number). */ export function float32BitsToNumber(bits: number): number { workingDataU32[0] = bits; @@ -531,58 +533,44 @@ export function gammaDecompress(n: number): number { /** Converts a 32-bit float value to a 32-bit unsigned integer value */ export function float32ToUint32(f32: number): number { - const f32Arr = new Float32Array(1); - f32Arr[0] = f32; - const u32Arr = new Uint32Array(f32Arr.buffer); - return u32Arr[0]; + workingDataF32[0] = f32; + return workingDataU32[0]; } /** Converts a 32-bit unsigned integer value to a 32-bit float value */ export function uint32ToFloat32(u32: number): number { - const u32Arr = new Uint32Array(1); - u32Arr[0] = u32; - const f32Arr = new Float32Array(u32Arr.buffer); - return f32Arr[0]; + workingDataU32[0] = u32; + return workingDataF32[0]; } /** Converts a 32-bit float value to a 32-bit signed integer value */ export function float32ToInt32(f32: number): number { - const f32Arr = new Float32Array(1); - f32Arr[0] = f32; - const i32Arr = new Int32Array(f32Arr.buffer); - return i32Arr[0]; + workingDataF32[0] = f32; + return workingDataI32[0]; } /** Converts a 32-bit unsigned integer value to a 32-bit signed integer value */ export function uint32ToInt32(u32: number): number { - const u32Arr = new Uint32Array(1); - u32Arr[0] = u32; - const i32Arr = new Int32Array(u32Arr.buffer); - return i32Arr[0]; + workingDataU32[0] = u32; + return workingDataI32[0]; } /** Converts a 16-bit float value to a 16-bit unsigned integer value */ export function float16ToUint16(f16: number): number { - const f16Arr = new Float16Array(1); - f16Arr[0] = f16; - const u16Arr = new Uint16Array(f16Arr.buffer); - return u16Arr[0]; + workingDataF16[0] = f16; + return workingDataU16[0]; } /** Converts a 16-bit unsigned integer value to a 16-bit float value */ export function uint16ToFloat16(u16: number): number { - const u16Arr = new Uint16Array(1); - u16Arr[0] = u16; - const f16Arr = new Float16Array(u16Arr.buffer); - return f16Arr[0]; + workingDataU16[0] = u16; + return workingDataF16[0]; } /** Converts a 16-bit float value to a 16-bit signed integer value */ export function float16ToInt16(f16: number): number { - const f16Arr = new Float16Array(1); - f16Arr[0] = f16; - const i16Arr = new Int16Array(f16Arr.buffer); - return i16Arr[0]; + workingDataF16[0] = f16; + return workingDataI16[0]; } /** A type of number representable by Scalar. */ @@ -764,40 +752,48 @@ export function TypeMat(cols: number, rows: number, elementType: ScalarType): Ma /** Type is a ScalarType, VectorType, or MatrixType. */ export type Type = ScalarType | VectorType | MatrixType; +/** Copy bytes from `buf` at `offset` into the working data, then read it out using `workingDataOut` */ +function valueFromBytes(workingDataOut: TypedArrayBufferView, buf: Uint8Array, offset: number) { + for (let i = 0; i < workingDataOut.BYTES_PER_ELEMENT; ++i) { + workingDataU8[i] = buf[offset + i]; + } + return workingDataOut[0]; +} + export const TypeI32 = new ScalarType('i32', 4, (buf: Uint8Array, offset: number) => - i32(new Int32Array(buf.buffer, offset)[0]) + i32(valueFromBytes(workingDataI32, buf, offset)) ); export const TypeU32 = new ScalarType('u32', 4, (buf: Uint8Array, offset: number) => - u32(new Uint32Array(buf.buffer, offset)[0]) + u32(valueFromBytes(workingDataU32, buf, offset)) ); export const TypeAbstractFloat = new ScalarType( 'abstract-float', 8, - (buf: Uint8Array, offset: number) => abstractFloat(new Float64Array(buf.buffer, offset)[0]) + (buf: Uint8Array, offset: number) => abstractFloat(valueFromBytes(workingDataF64, buf, offset)) ); export const TypeF64 = new ScalarType('f64', 8, (buf: Uint8Array, offset: number) => - f64(new Float64Array(buf.buffer, offset)[0]) + f64(valueFromBytes(workingDataF64, buf, offset)) ); export const TypeF32 = new ScalarType('f32', 4, (buf: Uint8Array, offset: number) => - f32(new Float32Array(buf.buffer, offset)[0]) + f32(valueFromBytes(workingDataF32, buf, offset)) ); export const TypeI16 = new ScalarType('i16', 2, (buf: Uint8Array, offset: number) => - i16(new Int16Array(buf.buffer, offset)[0]) + i16(valueFromBytes(workingDataI16, buf, offset)) ); export const TypeU16 = new ScalarType('u16', 2, (buf: Uint8Array, offset: number) => - u16(new Uint16Array(buf.buffer, offset)[0]) + u16(valueFromBytes(workingDataU16, buf, offset)) ); export const TypeF16 = new ScalarType('f16', 2, (buf: Uint8Array, offset: number) => - f16Bits(new Uint16Array(buf.buffer, offset)[0]) + f16Bits(valueFromBytes(workingDataU16, buf, offset)) ); export const TypeI8 = new ScalarType('i8', 1, (buf: Uint8Array, offset: number) => - i8(new Int8Array(buf.buffer, offset)[0]) + i8(valueFromBytes(workingDataI8, buf, offset)) ); export const TypeU8 = new ScalarType('u8', 1, (buf: Uint8Array, offset: number) => - u8(new Uint8Array(buf.buffer, offset)[0]) + u8(valueFromBytes(workingDataU8, buf, offset)) ); export const TypeBool = new ScalarType('bool', 4, (buf: Uint8Array, offset: number) => - bool(new Uint32Array(buf.buffer, offset)[0] !== 0) + bool(valueFromBytes(workingDataU32, buf, offset) !== 0) ); /** @returns the ScalarType from the ScalarKind */ @@ -877,12 +873,17 @@ type ScalarValue = boolean | number; export class Scalar { readonly value: ScalarValue; // The scalar value readonly type: ScalarType; // The type of the scalar - readonly bits: Uint8Array; // The scalar value packed in a Uint8Array - public constructor(type: ScalarType, value: ScalarValue, bits: TypedArrayBufferView) { + // The scalar value, packed in one or two 32-bit unsigned integers. + // Whether or not the bits1 is used depends on `this.type.size`. + readonly bits1: number; + readonly bits0: number; + + public constructor(type: ScalarType, value: ScalarValue, bits1: number, bits0: number) { this.value = value; this.type = type; - this.bits = new Uint8Array(bits.buffer); + this.bits1 = bits1; + this.bits0 = bits0; } /** @@ -892,8 +893,10 @@ export class Scalar { */ public copyTo(buffer: Uint8Array, offset: number) { assert(this.type.kind !== 'f64', `Copying f64 values to/from buffers is not defined`); - for (let i = 0; i < this.bits.length; i++) { - buffer[offset + i] = this.bits[i]; + workingDataU32[1] = this.bits1; + workingDataU32[0] = this.bits0; + for (let i = 0; i < this.type.size; i++) { + buffer[offset + i] = workingDataU8[i]; } } @@ -937,11 +940,12 @@ export class Scalar { case -Infinity: return Colors.bold(this.value.toString()); default: { - // Uint8Array.map returns a Uint8Array, so cannot use .map directly - const hex = Array.from(this.bits) - .reverse() - .map(x => x.toString(16).padStart(2, '0')) - .join(''); + workingDataU32[1] = this.bits1; + workingDataU32[0] = this.bits0; + let hex = ''; + for (let i = 0; i < this.type.size; ++i) { + hex = workingDataU8[i].toString(16).padStart(2, '0') + hex; + } const n = this.value as Number; if (n !== null && isFloatValue(this)) { let str = this.value.toString(); @@ -979,108 +983,109 @@ export interface ScalarBuilder { (value: number): Scalar; } -/** Create an AbstractFloat from a numeric value, a JS `number`. */ -export function abstractFloat(value: number): Scalar { - const arr = new Float64Array([value]); - return new Scalar(TypeAbstractFloat, arr[0], arr); +/** Create a Scalar of `type` by storing `value` as an element of `workingDataArray` and retrieving it. + * The working data array *must* be an alias of `workingData`. + */ +function scalarFromValue( + type: ScalarType, + workingDataArray: TypedArrayBufferView, + value: number +): Scalar { + // Clear all bits of the working data since `value` may be smaller; the upper bits should be 0. + workingDataU32[1] = 0; + workingDataU32[0] = 0; + workingDataArray[0] = value; + return new Scalar(type, workingDataArray[0], workingDataU32[1], workingDataU32[0]); +} + +/** Create a Scalar of `type` by storing `value` as an element of `workingDataStoreArray` and + * reinterpreting it as an element of `workingDataLoadArray`. + * Both working data arrays *must* be aliases of `workingData`. + */ +function scalarFromBits( + type: ScalarType, + workingDataStoreArray: TypedArrayBufferView, + workingDataLoadArray: TypedArrayBufferView, + bits: number +): Scalar { + // Clear all bits of the working data since `value` may be smaller; the upper bits should be 0. + workingDataU32[1] = 0; + workingDataU32[0] = 0; + workingDataStoreArray[0] = bits; + return new Scalar(type, workingDataLoadArray[0], workingDataU32[1], workingDataU32[0]); } + +/** Create an AbstractFloat from a numeric value, a JS `number`. */ +export const abstractFloat = (value: number): Scalar => + scalarFromValue(TypeAbstractFloat, workingDataF64, value); + /** Create an f64 from a numeric value, a JS `number`. */ -export function f64(value: number): Scalar { - const arr = new Float64Array([value]); - return new Scalar(TypeF64, arr[0], arr); -} +export const f64 = (value: number): Scalar => scalarFromValue(TypeF64, workingDataF64, value); + /** Create an f32 from a numeric value, a JS `number`. */ -export function f32(value: number): Scalar { - const arr = new Float32Array([value]); - return new Scalar(TypeF32, arr[0], arr); -} +export const f32 = (value: number): Scalar => scalarFromValue(TypeF32, workingDataF32, value); + /** Create an f16 from a numeric value, a JS `number`. */ -export function f16(value: number): Scalar { - const arr = new Float16Array([value]); - return new Scalar(TypeF16, arr[0], arr); -} +export const f16 = (value: number): Scalar => scalarFromValue(TypeF16, workingDataF16, value); + /** Create an f32 from a bit representation, a uint32 represented as a JS `number`. */ -export function f32Bits(bits: number): Scalar { - const arr = new Uint32Array([bits]); - return new Scalar(TypeF32, new Float32Array(arr.buffer)[0], arr); -} +export const f32Bits = (bits: number): Scalar => + scalarFromBits(TypeF32, workingDataU32, workingDataF32, bits); + /** Create an f16 from a bit representation, a uint16 represented as a JS `number`. */ -export function f16Bits(bits: number): Scalar { - const arr = new Uint16Array([bits]); - return new Scalar(TypeF16, new Float16Array(arr.buffer)[0], arr); -} +export const f16Bits = (bits: number): Scalar => + scalarFromBits(TypeF16, workingDataU16, workingDataF16, bits); /** Create an i32 from a numeric value, a JS `number`. */ -export function i32(value: number): Scalar { - const arr = new Int32Array([value]); - return new Scalar(TypeI32, arr[0], arr); -} +export const i32 = (value: number): Scalar => scalarFromValue(TypeI32, workingDataI32, value); + /** Create an i16 from a numeric value, a JS `number`. */ -export function i16(value: number): Scalar { - const arr = new Int16Array([value]); - return new Scalar(TypeI16, arr[0], arr); -} +export const i16 = (value: number): Scalar => scalarFromValue(TypeI16, workingDataI16, value); + /** Create an i8 from a numeric value, a JS `number`. */ -export function i8(value: number): Scalar { - const arr = new Int8Array([value]); - return new Scalar(TypeI8, arr[0], arr); -} +export const i8 = (value: number): Scalar => scalarFromValue(TypeI8, workingDataI8, value); /** Create an i32 from a bit representation, a uint32 represented as a JS `number`. */ -export function i32Bits(bits: number): Scalar { - const arr = new Uint32Array([bits]); - return new Scalar(TypeI32, new Int32Array(arr.buffer)[0], arr); -} +export const i32Bits = (bits: number): Scalar => + scalarFromBits(TypeI32, workingDataU32, workingDataI32, bits); + /** Create an i16 from a bit representation, a uint16 represented as a JS `number`. */ -export function i16Bits(bits: number): Scalar { - const arr = new Uint16Array([bits]); - return new Scalar(TypeI16, new Int16Array(arr.buffer)[0], arr); -} +export const i16Bits = (bits: number): Scalar => + scalarFromBits(TypeI16, workingDataU16, workingDataI16, bits); + /** Create an i8 from a bit representation, a uint8 represented as a JS `number`. */ -export function i8Bits(bits: number): Scalar { - const arr = new Uint8Array([bits]); - return new Scalar(TypeI8, new Int8Array(arr.buffer)[0], arr); -} +export const i8Bits = (bits: number): Scalar => + scalarFromBits(TypeI8, workingDataU8, workingDataI8, bits); /** Create a u32 from a numeric value, a JS `number`. */ -export function u32(value: number): Scalar { - const arr = new Uint32Array([value]); - return new Scalar(TypeU32, arr[0], arr); -} +export const u32 = (value: number): Scalar => scalarFromValue(TypeU32, workingDataU32, value); + /** Create a u16 from a numeric value, a JS `number`. */ -export function u16(value: number): Scalar { - const arr = new Uint16Array([value]); - return new Scalar(TypeU16, arr[0], arr); -} +export const u16 = (value: number): Scalar => scalarFromValue(TypeU16, workingDataU16, value); + /** Create a u8 from a numeric value, a JS `number`. */ -export function u8(value: number): Scalar { - const arr = new Uint8Array([value]); - return new Scalar(TypeU8, arr[0], arr); -} +export const u8 = (value: number): Scalar => scalarFromValue(TypeU8, workingDataU8, value); /** Create an u32 from a bit representation, a uint32 represented as a JS `number`. */ -export function u32Bits(bits: number): Scalar { - const arr = new Uint32Array([bits]); - return new Scalar(TypeU32, bits, arr); -} +export const u32Bits = (bits: number): Scalar => + scalarFromBits(TypeU32, workingDataU32, workingDataU32, bits); + /** Create an u16 from a bit representation, a uint16 represented as a JS `number`. */ -export function u16Bits(bits: number): Scalar { - const arr = new Uint16Array([bits]); - return new Scalar(TypeU16, bits, arr); -} +export const u16Bits = (bits: number): Scalar => + scalarFromBits(TypeU16, workingDataU16, workingDataU16, bits); + /** Create an u8 from a bit representation, a uint8 represented as a JS `number`. */ -export function u8Bits(bits: number): Scalar { - const arr = new Uint8Array([bits]); - return new Scalar(TypeU8, bits, arr); -} +export const u8Bits = (bits: number): Scalar => + scalarFromBits(TypeU8, workingDataU8, workingDataU8, bits); /** Create a boolean value. */ export function bool(value: boolean): Scalar { // WGSL does not support using 'bool' types directly in storage / uniform // buffers, so instead we pack booleans in a u32, where 'false' is zero and // 'true' is any non-zero value. - const arr = new Uint32Array([value ? 1 : 0]); - return new Scalar(TypeBool, value, arr); + workingDataU32[0] = value ? 1 : 0; + workingDataU32[1] = 0; + return new Scalar(TypeBool, value, workingDataU32[1], workingDataU32[0]); } /** A 'true' literal value */ @@ -1091,18 +1096,15 @@ export const False = bool(false); // Encoding to u32s, instead of BigInt, for serialization export function reinterpretF64AsU32s(f64: number): [number, number] { - const array = new Float64Array(1); - array[0] = f64; - const u32s = new Uint32Array(array.buffer); - return [u32s[0], u32s[1]]; + workingDataF64[0] = f64; + return [workingDataU32[0], workingDataU32[1]]; } // De-encoding from u32s, instead of BigInt, for serialization export function reinterpretU32sAsF64(u32s: [number, number]): number { - const array = new Uint32Array(2); - array[0] = u32s[0]; - array[1] = u32s[1]; - return new Float64Array(array.buffer)[0]; + workingDataU32[0] = u32s[0]; + workingDataU32[1] = u32s[1]; + return workingDataF64[0]; } /** @@ -1110,9 +1112,8 @@ export function reinterpretU32sAsF64(u32s: [number, number]): number { * of the bits of a number assumed to be an f32 value. */ export function reinterpretF32AsU32(f32: number): number { - const array = new Float32Array(1); - array[0] = f32; - return new Uint32Array(array.buffer)[0]; + workingDataF32[0] = f32; + return workingDataU32[0]; } /** @@ -1120,9 +1121,8 @@ export function reinterpretF32AsU32(f32: number): number { * of the bits of a number assumed to be an f32 value. */ export function reinterpretF32AsI32(f32: number): number { - const array = new Float32Array(1); - array[0] = f32; - return new Int32Array(array.buffer)[0]; + workingDataF32[0] = f32; + return workingDataI32[0]; } /** @@ -1130,9 +1130,8 @@ export function reinterpretF32AsI32(f32: number): number { * of the bits of a number assumed to be an u32 value. */ export function reinterpretU32AsF32(u32: number): number { - const array = new Uint32Array(1); - array[0] = u32; - return new Float32Array(array.buffer)[0]; + workingDataU32[0] = u32; + return workingDataF32[0]; } /** @@ -1140,9 +1139,8 @@ export function reinterpretU32AsF32(u32: number): number { * of the bits of a number assumed to be an u32 value. */ export function reinterpretU32AsI32(u32: number): number { - const array = new Uint32Array(1); - array[0] = u32; - return new Int32Array(array.buffer)[0]; + workingDataU32[0] = u32; + return workingDataI32[0]; } /** @@ -1150,9 +1148,8 @@ export function reinterpretU32AsI32(u32: number): number { * of the bits of a number assumed to be an i32 value. */ export function reinterpretI32AsU32(i32: number): number { - const array = new Int32Array(1); - array[0] = i32; - return new Uint32Array(array.buffer)[0]; + workingDataI32[0] = i32; + return workingDataU32[0]; } /** @@ -1160,9 +1157,8 @@ export function reinterpretI32AsU32(i32: number): number { * of the bits of a number assumed to be an i32 value. */ export function reinterpretI32AsF32(i32: number): number { - const array = new Int32Array(1); - array[0] = i32; - return new Float32Array(array.buffer)[0]; + workingDataI32[0] = i32; + return workingDataF32[0]; } /** @@ -1170,9 +1166,8 @@ export function reinterpretI32AsF32(i32: number): number { * of the bits of a number assumed to be an f16 value. */ export function reinterpretF16AsU16(f16: number): number { - const array = new Float16Array(1); - array[0] = f16; - return new Uint16Array(array.buffer)[0]; + workingDataF16[0] = f16; + return workingDataU16[0]; } /** @@ -1180,9 +1175,8 @@ export function reinterpretF16AsU16(f16: number): number { * of the bits of a number assumed to be an u16 value. */ export function reinterpretU16AsF16(u16: number): number { - const array = new Uint16Array(1); - array[0] = u16; - return new Float16Array(array.buffer)[0]; + workingDataU16[0] = u16; + return workingDataF16[0]; } /** @@ -1397,9 +1391,9 @@ export function serializeValue(v: Value): SerializedValue { const value = (kind: ScalarKind, s: Scalar) => { switch (kind) { case 'f32': - return new Uint32Array(s.bits.buffer)[0]; + return s.bits0; case 'f16': - return new Uint16Array(s.bits.buffer)[0]; + return s.bits0; default: return s.value; }