diff --git a/src/webgpu/util/math.ts b/src/webgpu/util/math.ts index 380832e1b857..018d350a984d 100644 --- a/src/webgpu/util/math.ts +++ b/src/webgpu/util/math.ts @@ -3,6 +3,7 @@ import { assert } from '../../common/util/util.js'; import { Float16Array, getFloat16, + hfround, setFloat16, } from '../../external/petamoriken/float16/float16.js'; @@ -2021,13 +2022,9 @@ export function quantizeToF32(num: number): number { return quantizeToF32Data[0]; } -/** Statically allocate working data, so it doesn't need per-call creation */ -const quantizeToF16Data = new Float16Array(new ArrayBuffer(2)); - /** @returns the closest 16-bit floating point value to the input */ export function quantizeToF16(num: number): number { - quantizeToF16Data[0] = num; - return quantizeToF16Data[0]; + return hfround(num); } /** Statically allocate working data, so it doesn't need per-call creation */