Skip to content

Commit

Permalink
More cache deserialization micro-optimisations
Browse files Browse the repository at this point in the history
* Use DataView instead of a bunch of separate typed arrays.
* Avoid small allocations where it's trivial to do so.

Speeds up deserialization around ~10% based on profiling in Chrome.
  • Loading branch information
ben-clayton committed Oct 26, 2023
1 parent 250e583 commit ba9e5d6
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 117 deletions.
14 changes: 7 additions & 7 deletions src/unittests/serialization.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ g.test('value').fn(t => {
f32
),
]) {
const s = new BinaryStream(new Uint8Array(1024));
const s = new BinaryStream(new Uint8Array(1024).buffer);
serializeValue(s, value);
const d = new BinaryStream(s.buffer());
const deserialized = deserializeValue(d);
Expand Down Expand Up @@ -244,7 +244,7 @@ g.test('fpinterval_f32').fn(t => {
FP.f32.toInterval([kValue.f32.negative.subnormal.min, kValue.f32.negative.subnormal.max]),
FP.f32.toInterval([kValue.f32.negative.infinity, kValue.f32.positive.infinity]),
]) {
const s = new BinaryStream(new Uint8Array(1024));
const s = new BinaryStream(new Uint8Array(1024).buffer);
serializeFPInterval(s, interval);
const d = new BinaryStream(s.buffer());
const deserialized = deserializeFPInterval(d);
Expand Down Expand Up @@ -280,7 +280,7 @@ g.test('fpinterval_f16').fn(t => {
FP.f16.toInterval([kValue.f16.negative.subnormal.min, kValue.f16.negative.subnormal.max]),
FP.f16.toInterval([kValue.f16.negative.infinity, kValue.f16.positive.infinity]),
]) {
const s = new BinaryStream(new Uint8Array(1024));
const s = new BinaryStream(new Uint8Array(1024).buffer);
serializeFPInterval(s, interval);
const d = new BinaryStream(s.buffer());
const deserialized = deserializeFPInterval(d);
Expand Down Expand Up @@ -316,7 +316,7 @@ g.test('fpinterval_abstract').fn(t => {
FP.abstract.toInterval([kValue.f64.negative.subnormal.min, kValue.f64.negative.subnormal.max]),
FP.abstract.toInterval([kValue.f64.negative.infinity, kValue.f64.positive.infinity]),
]) {
const s = new BinaryStream(new Uint8Array(1024));
const s = new BinaryStream(new Uint8Array(1024).buffer);
serializeFPInterval(s, interval);
const d = new BinaryStream(s.buffer());
const deserialized = deserializeFPInterval(d);
Expand All @@ -338,7 +338,7 @@ g.test('expression_expectation').fn(t => {
// Intervals
[FP.f32.toInterval([-8.0, 0.5]), FP.f32.toInterval([2.0, 4.0])],
]) {
const s = new BinaryStream(new Uint8Array(1024));
const s = new BinaryStream(new Uint8Array(1024).buffer);
serializeExpectation(s, expectation);
const d = new BinaryStream(s.buffer());
const deserialized = deserializeExpectation(d);
Expand Down Expand Up @@ -368,7 +368,7 @@ g.test('anyOf').fn(t => {
testCases: [f32(0), f32(10), f32(122), f32(123), f32(124), f32(200)],
},
]) {
const s = new BinaryStream(new Uint8Array(1024));
const s = new BinaryStream(new Uint8Array(1024).buffer);
serializeComparator(s, c.comparator);
const d = new BinaryStream(s.buffer());
const deserialized = deserializeComparator(d);
Expand Down Expand Up @@ -396,7 +396,7 @@ g.test('skipUndefined').fn(t => {
testCases: [f32(0), f32(10), f32(122), f32(123), f32(124), f32(200)],
},
]) {
const s = new BinaryStream(new Uint8Array(1024));
const s = new BinaryStream(new Uint8Array(1024).buffer);
serializeComparator(s, c.comparator);
const d = new BinaryStream(s.buffer());
const deserialized = deserializeComparator(d);
Expand Down
8 changes: 4 additions & 4 deletions src/webgpu/shader/execution/expression/case_cache.ts
Original file line number Diff line number Diff line change
Expand Up @@ -166,21 +166,21 @@ export class CaseCache implements Cacheable<Record<string, CaseList>> {
*/
serialize(data: Record<string, CaseList>): Uint8Array {
const maxSize = 32 << 20; // 32MB - max size for a file
const s = new BinaryStream(new Uint8Array(maxSize));
const s = new BinaryStream(new Uint8Array(maxSize).buffer);
s.writeU32(Object.keys(data).length);
for (const name in data) {
s.writeString(name);
s.writeArray(data[name], serializeCase);
}
return s.buffer();
return new Uint8Array(s.buffer());
}

/**
* deserialize() implements the Cacheable.deserialize interface.
* @returns the deserialize data.
*/
deserialize(buffer: Uint8Array): Record<string, CaseList> {
const s = new BinaryStream(buffer);
deserialize(array: Uint8Array): Record<string, CaseList> {
const s = new BinaryStream(array.buffer);
const casesByName: Record<string, CaseList> = {};
const numRecords = s.readU32();
for (let i = 0; i < numRecords; i++) {
Expand Down
134 changes: 32 additions & 102 deletions src/webgpu/util/binary_stream.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { assert } from '../../common/util/util.js';
import { Float16Array } from '../../external/petamoriken/float16/float16.js';

import { float16ToUint16, uint16ToFloat16 } from './conversion.js';
import { align } from './math.js';

/**
Expand All @@ -13,178 +13,116 @@ export default class BinaryStream {
* Constructor
* @param buffer the buffer to read from / write to. Array length must be a multiple of 8 bytes.
*/
constructor(buffer: Uint8Array) {
constructor(buffer: ArrayBufferLike) {
this.offset = 0;
this.u8 = buffer;
this.u16 = new Uint16Array(this.u8.buffer);
this.u32 = new Uint32Array(this.u8.buffer);
this.i8 = new Int8Array(this.u8.buffer);
this.i16 = new Int16Array(this.u8.buffer);
this.i32 = new Int32Array(this.u8.buffer);
this.f16 = new Float16Array(this.u8.buffer);
this.f32 = new Float32Array(this.u8.buffer);
this.f64 = new Float64Array(this.u8.buffer);
this.view = new DataView(buffer);
}

/** buffer() returns the stream's buffer sliced to the 8-byte rounded read or write offset */
buffer(): Uint8Array {
return this.u8.slice(0, align(this.offset, 8));
buffer(): ArrayBufferLike {
return new Uint8Array(this.view.buffer, align(this.offset, 8)).buffer;
}

/** writeBool() writes a boolean as 255 or 0 to the buffer at the next byte offset */
writeBool(value: boolean) {
this.u8[this.offset++] = value ? 255 : 0;
this.view.setUint8(this.offset++, value ? 255 : 0);
}

/** readBool() reads a boolean from the buffer at the next byte offset */
readBool(): boolean {
const val = this.u8[this.offset++];
const val = this.view.getUint8(this.offset++);
assert(val === 0 || val === 255);
return val !== 0;
}

/** writeU8() writes a uint8 to the buffer at the next byte offset */
writeU8(value: number) {
this.u8[this.offset++] = value;
this.view.setUint8(this.offset++, value);
}

/** readU8() reads a uint8 from the buffer at the next byte offset */
readU8(): number {
return this.u8[this.offset++];
}

/** u8View() returns a Uint8Array view of the uint8 at the next byte offset */
u8View(): Uint8Array {
const at = this.offset++;
return new Uint8Array(this.u8.buffer, at, 1);
return this.view.getUint8(this.offset++);
}

/** writeU16() writes a uint16 to the buffer at the next 16-bit aligned offset */
writeU16(value: number) {
this.u16[this.bumpWord(2)] = value;
this.view.setUint16(this.alignedOffset(2), value, /* littleEndian */ true);
}

/** readU16() reads a uint16 from the buffer at the next 16-bit aligned offset */
readU16(): number {
return this.u16[this.bumpWord(2)];
}

/** u16View() returns a Uint16Array view of the uint16 at the next 16-bit aligned offset */
u16View(): Uint16Array {
const at = this.bumpWord(2);
return new Uint16Array(this.u16.buffer, at * 2, 1);
return this.view.getUint16(this.alignedOffset(2), /* littleEndian */ true);
}

/** writeU32() writes a uint32 to the buffer at the next 32-bit aligned offset */
writeU32(value: number) {
this.u32[this.bumpWord(4)] = value;
this.view.setUint32(this.alignedOffset(4), value, /* littleEndian */ true);
}

/** readU32() reads a uint32 from the buffer at the next 32-bit aligned offset */
readU32(): number {
return this.u32[this.bumpWord(4)];
}

/** u32View() returns a Uint32Array view of the uint32 at the next 32-bit aligned offset */
u32View(): Uint32Array {
const at = this.bumpWord(4);
return new Uint32Array(this.u32.buffer, at * 4, 1);
return this.view.getUint32(this.alignedOffset(4), /* littleEndian */ true);
}

/** writeI8() writes a int8 to the buffer at the next byte offset */
writeI8(value: number) {
this.i8[this.offset++] = value;
this.view.setInt8(this.offset++, value);
}

/** readI8() reads a int8 from the buffer at the next byte offset */
readI8(): number {
return this.i8[this.offset++];
}

/** i8View() returns a Uint8Array view of the uint8 at the next byte offset */
i8View(): Int8Array {
const at = this.offset++;
return new Int8Array(this.i8.buffer, at, 1);
return this.view.getInt8(this.offset++);
}

/** writeI16() writes a int16 to the buffer at the next 16-bit aligned offset */
writeI16(value: number) {
this.i16[this.bumpWord(2)] = value;
this.view.setInt16(this.alignedOffset(2), value, /* littleEndian */ true);
}

/** readI16() reads a int16 from the buffer at the next 16-bit aligned offset */
readI16(): number {
return this.i16[this.bumpWord(2)];
}

/** i16View() returns a Int16Array view of the uint16 at the next 16-bit aligned offset */
i16View(): Int16Array {
const at = this.bumpWord(2);
return new Int16Array(this.i16.buffer, at * 2, 1);
return this.view.getInt16(this.alignedOffset(2), /* littleEndian */ true);
}

/** writeI32() writes a int32 to the buffer at the next 32-bit aligned offset */
writeI32(value: number) {
this.i32[this.bumpWord(4)] = value;
this.view.setInt32(this.alignedOffset(4), value, /* littleEndian */ true);
}

/** readI32() reads a int32 from the buffer at the next 32-bit aligned offset */
readI32(): number {
return this.i32[this.bumpWord(4)];
}

/** i32View() returns a Int32Array view of the uint32 at the next 32-bit aligned offset */
i32View(): Int32Array {
const at = this.bumpWord(4);
return new Int32Array(this.i32.buffer, at * 4, 1);
return this.view.getInt32(this.alignedOffset(4), /* littleEndian */ true);
}

/** writeF16() writes a float16 to the buffer at the next 16-bit aligned offset */
writeF16(value: number) {
this.f16[this.bumpWord(2)] = value;
this.writeU16(float16ToUint16(value));
}

/** readF16() reads a float16 from the buffer at the next 16-bit aligned offset */
readF16(): number {
return this.f16[this.bumpWord(2)];
}

/** f16View() returns a Float16Array view of the uint16 at the next 16-bit aligned offset */
f16View(): Float16Array {
const at = this.bumpWord(2);
return new Float16Array(this.f16.buffer, at * 2, 1);
return uint16ToFloat16(this.readU16());
}

/** writeF32() writes a float32 to the buffer at the next 32-bit aligned offset */
writeF32(value: number) {
this.f32[this.bumpWord(4)] = value;
this.view.setFloat32(this.alignedOffset(4), value, /* littleEndian */ true);
}

/** readF32() reads a float32 from the buffer at the next 32-bit aligned offset */
readF32(): number {
return this.f32[this.bumpWord(4)];
}

/** f32View() returns a Float32Array view of the uint32 at the next 32-bit aligned offset */
f32View(): Float32Array {
const at = this.bumpWord(4);
return new Float32Array(this.f32.buffer, at * 4, 1);
return this.view.getFloat32(this.alignedOffset(4), /* littleEndian */ true);
}

/** writeF64() writes a float64 to the buffer at the next 64-bit aligned offset */
writeF64(value: number) {
this.f64[this.bumpWord(8)] = value;
this.view.setFloat64(this.alignedOffset(8), value);
}

/** readF64() reads a float64 from the buffer at the next 64-bit aligned offset */
readF64(): number {
return this.f64[this.bumpWord(8)];
}

/** f64View() returns a Float64Array view of the uint64 at the next 64-bit aligned offset */
f64View(): Float64Array {
const at = this.bumpWord(8);
return new Float64Array(this.f64.buffer, at * 8, 1);
return this.view.getFloat64(this.alignedOffset(8));
}

/**
Expand Down Expand Up @@ -261,23 +199,15 @@ export default class BinaryStream {
}

/**
* bumpWord() increments this.offset by `bytes`, after first aligning this.offset to `bytes`.
* @returns the old offset aligned to the next multiple of `bytes`, divided by `bytes`.
* alignedOffset() aligns this.offset to `bytes`, then increments this.offset by `bytes`.
* @returns the old offset aligned to the next multiple of `bytes`.
*/
private bumpWord(bytes: number) {
const multiple = Math.floor((this.offset + bytes - 1) / bytes);
this.offset = (multiple + 1) * bytes;
return multiple;
private alignedOffset(bytes: number) {
const aligned = align(this.offset, bytes);
this.offset = aligned + bytes;
return aligned;
}

private offset: number;
private u8: Uint8Array;
private u16: Uint16Array;
private u32: Uint32Array;
private i8: Int8Array;
private i16: Int16Array;
private i32: Int32Array;
private f16: Float16Array;
private f32: Float32Array;
private f64: Float64Array;
private view: DataView;
}
9 changes: 5 additions & 4 deletions src/webgpu/util/floating_point.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ export class FPInterval {
public constructor(kind: FPKind, ...bounds: IntervalBounds) {
this.kind = kind;

const [begin, end] = bounds.length === 2 ? bounds : [bounds[0], bounds[0]];
const begin = bounds[0];
const end = bounds.length === 2 ? bounds[1] : bounds[0];
assert(!Number.isNaN(begin) && !Number.isNaN(end), `bounds need to be non-NaN`);
assert(begin <= end, `bounds[0] (${begin}) must be less than or equal to bounds[1] (${end})`);

Expand Down Expand Up @@ -206,11 +207,11 @@ export function deserializeFPInterval(s: BinaryStream): FPInterval {
// Bounded
switch (kind) {
case 'abstract':
return traits.toInterval([s.readF64(), s.readF64()]);
return new FPInterval(traits.kind, s.readF64(), s.readF64());
case 'f32':
return traits.toInterval([s.readF32(), s.readF32()]);
return new FPInterval(traits.kind, s.readF32(), s.readF32());
case 'f16':
return traits.toInterval([s.readF16(), s.readF16()]);
return new FPInterval(traits.kind, s.readF16(), s.readF16());
}
unreachable(`Unable to deserialize FPInterval with kind ${kind}`);
},
Expand Down

0 comments on commit ba9e5d6

Please sign in to comment.