diff --git a/src/unittests/serialization.spec.ts b/src/unittests/serialization.spec.ts index 76ac1f715545..7f5384ad9cb5 100644 --- a/src/unittests/serialization.spec.ts +++ b/src/unittests/serialization.spec.ts @@ -207,7 +207,7 @@ g.test('value').fn(t => { f32 ), ]) { - const s = new BinaryStream(new Uint8Array(1024)); + const s = new BinaryStream(new Uint8Array(1024).buffer); serializeValue(s, value); const d = new BinaryStream(s.buffer()); const deserialized = deserializeValue(d); @@ -244,7 +244,7 @@ g.test('fpinterval_f32').fn(t => { FP.f32.toInterval([kValue.f32.negative.subnormal.min, kValue.f32.negative.subnormal.max]), FP.f32.toInterval([kValue.f32.negative.infinity, kValue.f32.positive.infinity]), ]) { - const s = new BinaryStream(new Uint8Array(1024)); + const s = new BinaryStream(new Uint8Array(1024).buffer); serializeFPInterval(s, interval); const d = new BinaryStream(s.buffer()); const deserialized = deserializeFPInterval(d); @@ -280,7 +280,7 @@ g.test('fpinterval_f16').fn(t => { FP.f16.toInterval([kValue.f16.negative.subnormal.min, kValue.f16.negative.subnormal.max]), FP.f16.toInterval([kValue.f16.negative.infinity, kValue.f16.positive.infinity]), ]) { - const s = new BinaryStream(new Uint8Array(1024)); + const s = new BinaryStream(new Uint8Array(1024).buffer); serializeFPInterval(s, interval); const d = new BinaryStream(s.buffer()); const deserialized = deserializeFPInterval(d); @@ -316,7 +316,7 @@ g.test('fpinterval_abstract').fn(t => { FP.abstract.toInterval([kValue.f64.negative.subnormal.min, kValue.f64.negative.subnormal.max]), FP.abstract.toInterval([kValue.f64.negative.infinity, kValue.f64.positive.infinity]), ]) { - const s = new BinaryStream(new Uint8Array(1024)); + const s = new BinaryStream(new Uint8Array(1024).buffer); serializeFPInterval(s, interval); const d = new BinaryStream(s.buffer()); const deserialized = deserializeFPInterval(d); @@ -338,7 +338,7 @@ g.test('expression_expectation').fn(t => { // Intervals [FP.f32.toInterval([-8.0, 0.5]), FP.f32.toInterval([2.0, 4.0])], ]) { - const s = new BinaryStream(new Uint8Array(1024)); + const s = new BinaryStream(new Uint8Array(1024).buffer); serializeExpectation(s, expectation); const d = new BinaryStream(s.buffer()); const deserialized = deserializeExpectation(d); @@ -368,7 +368,7 @@ g.test('anyOf').fn(t => { testCases: [f32(0), f32(10), f32(122), f32(123), f32(124), f32(200)], }, ]) { - const s = new BinaryStream(new Uint8Array(1024)); + const s = new BinaryStream(new Uint8Array(1024).buffer); serializeComparator(s, c.comparator); const d = new BinaryStream(s.buffer()); const deserialized = deserializeComparator(d); @@ -396,7 +396,7 @@ g.test('skipUndefined').fn(t => { testCases: [f32(0), f32(10), f32(122), f32(123), f32(124), f32(200)], }, ]) { - const s = new BinaryStream(new Uint8Array(1024)); + const s = new BinaryStream(new Uint8Array(1024).buffer); serializeComparator(s, c.comparator); const d = new BinaryStream(s.buffer()); const deserialized = deserializeComparator(d); diff --git a/src/webgpu/shader/execution/expression/case_cache.ts b/src/webgpu/shader/execution/expression/case_cache.ts index 88f4a48df4c8..daee31993161 100644 --- a/src/webgpu/shader/execution/expression/case_cache.ts +++ b/src/webgpu/shader/execution/expression/case_cache.ts @@ -166,21 +166,21 @@ export class CaseCache implements Cacheable> { */ serialize(data: Record): Uint8Array { const maxSize = 32 << 20; // 32MB - max size for a file - const s = new BinaryStream(new Uint8Array(maxSize)); + const s = new BinaryStream(new Uint8Array(maxSize).buffer); s.writeU32(Object.keys(data).length); for (const name in data) { s.writeString(name); s.writeArray(data[name], serializeCase); } - return s.buffer(); + return new Uint8Array(s.buffer()); } /** * deserialize() implements the Cacheable.deserialize interface. * @returns the deserialize data. */ - deserialize(buffer: Uint8Array): Record { - const s = new BinaryStream(buffer); + deserialize(array: Uint8Array): Record { + const s = new BinaryStream(array.buffer); const casesByName: Record = {}; const numRecords = s.readU32(); for (let i = 0; i < numRecords; i++) { diff --git a/src/webgpu/util/binary_stream.ts b/src/webgpu/util/binary_stream.ts index 575973afbe0b..4941b9a4dab9 100644 --- a/src/webgpu/util/binary_stream.ts +++ b/src/webgpu/util/binary_stream.ts @@ -1,6 +1,6 @@ import { assert } from '../../common/util/util.js'; -import { Float16Array } from '../../external/petamoriken/float16/float16.js'; +import { float16ToUint16, uint16ToFloat16 } from './conversion.js'; import { align } from './math.js'; /** @@ -13,178 +13,116 @@ export default class BinaryStream { * Constructor * @param buffer the buffer to read from / write to. Array length must be a multiple of 8 bytes. */ - constructor(buffer: Uint8Array) { + constructor(buffer: ArrayBufferLike) { this.offset = 0; - this.u8 = buffer; - this.u16 = new Uint16Array(this.u8.buffer); - this.u32 = new Uint32Array(this.u8.buffer); - this.i8 = new Int8Array(this.u8.buffer); - this.i16 = new Int16Array(this.u8.buffer); - this.i32 = new Int32Array(this.u8.buffer); - this.f16 = new Float16Array(this.u8.buffer); - this.f32 = new Float32Array(this.u8.buffer); - this.f64 = new Float64Array(this.u8.buffer); + this.view = new DataView(buffer); } /** buffer() returns the stream's buffer sliced to the 8-byte rounded read or write offset */ - buffer(): Uint8Array { - return this.u8.slice(0, align(this.offset, 8)); + buffer(): ArrayBufferLike { + return new Uint8Array(this.view.buffer, align(this.offset, 8)).buffer; } /** writeBool() writes a boolean as 255 or 0 to the buffer at the next byte offset */ writeBool(value: boolean) { - this.u8[this.offset++] = value ? 255 : 0; + this.view.setUint8(this.offset++, value ? 255 : 0); } /** readBool() reads a boolean from the buffer at the next byte offset */ readBool(): boolean { - const val = this.u8[this.offset++]; + const val = this.view.getUint8(this.offset++); assert(val === 0 || val === 255); return val !== 0; } /** writeU8() writes a uint8 to the buffer at the next byte offset */ writeU8(value: number) { - this.u8[this.offset++] = value; + this.view.setUint8(this.offset++, value); } /** readU8() reads a uint8 from the buffer at the next byte offset */ readU8(): number { - return this.u8[this.offset++]; - } - - /** u8View() returns a Uint8Array view of the uint8 at the next byte offset */ - u8View(): Uint8Array { - const at = this.offset++; - return new Uint8Array(this.u8.buffer, at, 1); + return this.view.getUint8(this.offset++); } /** writeU16() writes a uint16 to the buffer at the next 16-bit aligned offset */ writeU16(value: number) { - this.u16[this.bumpWord(2)] = value; + this.view.setUint16(this.alignedOffset(2), value, /* littleEndian */ true); } /** readU16() reads a uint16 from the buffer at the next 16-bit aligned offset */ readU16(): number { - return this.u16[this.bumpWord(2)]; - } - - /** u16View() returns a Uint16Array view of the uint16 at the next 16-bit aligned offset */ - u16View(): Uint16Array { - const at = this.bumpWord(2); - return new Uint16Array(this.u16.buffer, at * 2, 1); + return this.view.getUint16(this.alignedOffset(2), /* littleEndian */ true); } /** writeU32() writes a uint32 to the buffer at the next 32-bit aligned offset */ writeU32(value: number) { - this.u32[this.bumpWord(4)] = value; + this.view.setUint32(this.alignedOffset(4), value, /* littleEndian */ true); } /** readU32() reads a uint32 from the buffer at the next 32-bit aligned offset */ readU32(): number { - return this.u32[this.bumpWord(4)]; - } - - /** u32View() returns a Uint32Array view of the uint32 at the next 32-bit aligned offset */ - u32View(): Uint32Array { - const at = this.bumpWord(4); - return new Uint32Array(this.u32.buffer, at * 4, 1); + return this.view.getUint32(this.alignedOffset(4), /* littleEndian */ true); } /** writeI8() writes a int8 to the buffer at the next byte offset */ writeI8(value: number) { - this.i8[this.offset++] = value; + this.view.setInt8(this.offset++, value); } /** readI8() reads a int8 from the buffer at the next byte offset */ readI8(): number { - return this.i8[this.offset++]; - } - - /** i8View() returns a Uint8Array view of the uint8 at the next byte offset */ - i8View(): Int8Array { - const at = this.offset++; - return new Int8Array(this.i8.buffer, at, 1); + return this.view.getInt8(this.offset++); } /** writeI16() writes a int16 to the buffer at the next 16-bit aligned offset */ writeI16(value: number) { - this.i16[this.bumpWord(2)] = value; + this.view.setInt16(this.alignedOffset(2), value, /* littleEndian */ true); } /** readI16() reads a int16 from the buffer at the next 16-bit aligned offset */ readI16(): number { - return this.i16[this.bumpWord(2)]; - } - - /** i16View() returns a Int16Array view of the uint16 at the next 16-bit aligned offset */ - i16View(): Int16Array { - const at = this.bumpWord(2); - return new Int16Array(this.i16.buffer, at * 2, 1); + return this.view.getInt16(this.alignedOffset(2), /* littleEndian */ true); } /** writeI32() writes a int32 to the buffer at the next 32-bit aligned offset */ writeI32(value: number) { - this.i32[this.bumpWord(4)] = value; + this.view.setInt32(this.alignedOffset(4), value, /* littleEndian */ true); } /** readI32() reads a int32 from the buffer at the next 32-bit aligned offset */ readI32(): number { - return this.i32[this.bumpWord(4)]; - } - - /** i32View() returns a Int32Array view of the uint32 at the next 32-bit aligned offset */ - i32View(): Int32Array { - const at = this.bumpWord(4); - return new Int32Array(this.i32.buffer, at * 4, 1); + return this.view.getInt32(this.alignedOffset(4), /* littleEndian */ true); } /** writeF16() writes a float16 to the buffer at the next 16-bit aligned offset */ writeF16(value: number) { - this.f16[this.bumpWord(2)] = value; + this.writeU16(float16ToUint16(value)); } /** readF16() reads a float16 from the buffer at the next 16-bit aligned offset */ readF16(): number { - return this.f16[this.bumpWord(2)]; - } - - /** f16View() returns a Float16Array view of the uint16 at the next 16-bit aligned offset */ - f16View(): Float16Array { - const at = this.bumpWord(2); - return new Float16Array(this.f16.buffer, at * 2, 1); + return uint16ToFloat16(this.readU16()); } /** writeF32() writes a float32 to the buffer at the next 32-bit aligned offset */ writeF32(value: number) { - this.f32[this.bumpWord(4)] = value; + this.view.setFloat32(this.alignedOffset(4), value, /* littleEndian */ true); } /** readF32() reads a float32 from the buffer at the next 32-bit aligned offset */ readF32(): number { - return this.f32[this.bumpWord(4)]; - } - - /** f32View() returns a Float32Array view of the uint32 at the next 32-bit aligned offset */ - f32View(): Float32Array { - const at = this.bumpWord(4); - return new Float32Array(this.f32.buffer, at * 4, 1); + return this.view.getFloat32(this.alignedOffset(4), /* littleEndian */ true); } /** writeF64() writes a float64 to the buffer at the next 64-bit aligned offset */ writeF64(value: number) { - this.f64[this.bumpWord(8)] = value; + this.view.setFloat64(this.alignedOffset(8), value); } /** readF64() reads a float64 from the buffer at the next 64-bit aligned offset */ readF64(): number { - return this.f64[this.bumpWord(8)]; - } - - /** f64View() returns a Float64Array view of the uint64 at the next 64-bit aligned offset */ - f64View(): Float64Array { - const at = this.bumpWord(8); - return new Float64Array(this.f64.buffer, at * 8, 1); + return this.view.getFloat64(this.alignedOffset(8)); } /** @@ -261,23 +199,15 @@ export default class BinaryStream { } /** - * bumpWord() increments this.offset by `bytes`, after first aligning this.offset to `bytes`. - * @returns the old offset aligned to the next multiple of `bytes`, divided by `bytes`. + * alignedOffset() aligns this.offset to `bytes`, then increments this.offset by `bytes`. + * @returns the old offset aligned to the next multiple of `bytes`. */ - private bumpWord(bytes: number) { - const multiple = Math.floor((this.offset + bytes - 1) / bytes); - this.offset = (multiple + 1) * bytes; - return multiple; + private alignedOffset(bytes: number) { + const aligned = align(this.offset, bytes); + this.offset = aligned + bytes; + return aligned; } private offset: number; - private u8: Uint8Array; - private u16: Uint16Array; - private u32: Uint32Array; - private i8: Int8Array; - private i16: Int16Array; - private i32: Int32Array; - private f16: Float16Array; - private f32: Float32Array; - private f64: Float64Array; + private view: DataView; } diff --git a/src/webgpu/util/floating_point.ts b/src/webgpu/util/floating_point.ts index f9b9d2ca44ce..82d6403ce3ae 100644 --- a/src/webgpu/util/floating_point.ts +++ b/src/webgpu/util/floating_point.ts @@ -107,7 +107,8 @@ export class FPInterval { public constructor(kind: FPKind, ...bounds: IntervalBounds) { this.kind = kind; - const [begin, end] = bounds.length === 2 ? bounds : [bounds[0], bounds[0]]; + const begin = bounds[0]; + const end = bounds.length === 2 ? bounds[1] : bounds[0]; assert(!Number.isNaN(begin) && !Number.isNaN(end), `bounds need to be non-NaN`); assert(begin <= end, `bounds[0] (${begin}) must be less than or equal to bounds[1] (${end})`); @@ -206,11 +207,11 @@ export function deserializeFPInterval(s: BinaryStream): FPInterval { // Bounded switch (kind) { case 'abstract': - return traits.toInterval([s.readF64(), s.readF64()]); + return new FPInterval(traits.kind, s.readF64(), s.readF64()); case 'f32': - return traits.toInterval([s.readF32(), s.readF32()]); + return new FPInterval(traits.kind, s.readF32(), s.readF32()); case 'f16': - return traits.toInterval([s.readF16(), s.readF16()]); + return new FPInterval(traits.kind, s.readF16(), s.readF16()); } unreachable(`Unable to deserialize FPInterval with kind ${kind}`); },