Skip to content

Commit

Permalink
Replace JSON case cache serialization with binary files
Browse files Browse the repository at this point in the history
This removes the need to create a bunch of temporary JSON objects,
reducing the amount of garbage collection we need to do.

This change also changes the DataCache from being unbounded to a
4-element LRU cache, capping the amount of memory used.
  • Loading branch information
ben-clayton committed Oct 26, 2023
1 parent f3196f8 commit 6e21caa
Show file tree
Hide file tree
Showing 11 changed files with 865 additions and 370 deletions.
119 changes: 98 additions & 21 deletions src/common/framework/data_cache.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,64 @@
* expensive to build using a two-level cache (in-memory, pre-computed file).
*/

import { assert } from '../util/util.js';

interface DataStore {
load(path: string): Promise<string>;
load(path: string): Promise<Uint8Array>;
}

/** Logger is a basic debug logger function */
export type Logger = (s: string) => void;

/** DataCache is an interface to a data store used to hold cached data */
/**
 * A single entry in the DataCache's LRU list.
 * Entries form a doubly-linked list: a node is moved to the front of the list
 * on a cache hit, so the node nearest the tail is always the
 * least-recently-used candidate for eviction.
 */
class DataCacheNode {
  /** The previous node in the doubly-linked list */
  public prev: DataCacheNode | null = null;
  /** The next node in the doubly-linked list */
  public next: DataCacheNode | null = null;

  public constructor(
    public readonly path: string, // the file path this node represents
    public readonly data: unknown // the deserialized data for this node
  ) {}

  /**
   * Re-links this node into the doubly-linked list immediately after @p prev,
   * first detaching it from wherever it currently sits.
   */
  public insertAfter(prev: DataCacheNode) {
    this.unlink();
    const following = prev.next;
    prev.next = this;
    this.prev = prev;
    this.next = following;
    if (following !== null) {
      following.prev = this;
    }
  }

  /** Detaches this node from the doubly-linked list, leaving it isolated. */
  public unlink() {
    if (this.prev !== null) {
      this.prev.next = this.next;
    }
    if (this.next !== null) {
      this.next.prev = this.prev;
    }
    this.prev = null;
    this.next = null;
  }
}

/** DataCache is an interface to a LRU-cached data store used to hold data cached by path */
export class DataCache {
public constructor() {
  // Start with an empty LRU list: the two sentinel nodes point at each other,
  // so real entries can always be spliced in between them.
  this.lruTailNode.prev = this.lruHeadNode;
  this.lruHeadNode.next = this.lruTailNode;
}

/** setDataStore() sets the backing data store used by the data cache */
public setStore(dataStore: DataStore) {
this.dataStore = dataStore;
Expand All @@ -28,17 +77,20 @@ export class DataCache {
* building the data and storing it in the cache.
*/
public async fetch<Data>(cacheable: Cacheable<Data>): Promise<Data> {
// First check the in-memory cache
let data = this.cache.get(cacheable.path);
if (data !== undefined) {
this.log('in-memory cache hit');
return Promise.resolve(data as Data);
{
// First check the in-memory cache
const node = this.cache.get(cacheable.path);
if (node !== undefined) {
this.log('in-memory cache hit');
node.insertAfter(this.lruHeadNode);
return Promise.resolve(node.data as Data);
}
}
this.log('in-memory cache miss');
// In in-memory cache miss.
// Next, try the data store.
if (this.dataStore !== null && !this.unavailableFiles.has(cacheable.path)) {
let serialized: string | undefined;
let serialized: Uint8Array | undefined;
try {
serialized = await this.dataStore.load(cacheable.path);
this.log('loaded serialized');
Expand All @@ -49,16 +101,37 @@ export class DataCache {
}
if (serialized !== undefined) {
this.log(`deserializing`);
data = cacheable.deserialize(serialized);
this.cache.set(cacheable.path, data);
return data as Data;
const data = cacheable.deserialize(serialized);
this.addToCache(cacheable.path, data);
return data;
}
}
// Not found anywhere. Build the data, and cache for future lookup.
this.log(`cache: building (${cacheable.path})`);
data = await cacheable.build();
this.cache.set(cacheable.path, data);
return data as Data;
const data = await cacheable.build();
this.addToCache(cacheable.path, data);
return data;
}

/**
 * addToCache() inserts @p data into the cache under @p path, placing it at the
 * front of the LRU list. If the cache already holds this.maxCount entries, the
 * least-recently-used entry is evicted first.
 * @param path the file path for the data
 * @param data the deserialized data
 */
private addToCache(path: string, data: unknown) {
  if (this.cache.size >= this.maxCount) {
    // The node just before the tail sentinel is the least-recently-used entry.
    const victim = this.lruTailNode.prev;
    assert(victim !== null);
    victim.unlink();
    this.cache.delete(victim.path);
    this.log(`evicting ${victim.path}`);
  }
  // Newest entries live at the front, immediately after the head sentinel.
  const newNode = new DataCacheNode(path, data);
  newNode.insertAfter(this.lruHeadNode);
  this.cache.set(path, newNode);
  this.log(`added ${path}. new count: ${this.cache.size}`);
}

private log(msg: string) {
Expand All @@ -67,7 +140,12 @@ export class DataCache {
}
}

private cache = new Map<string, unknown>();
// Max number of entries in the cache before LRU entries are evicted.
private readonly maxCount = 4;

private cache = new Map<string, DataCacheNode>();
private lruHeadNode = new DataCacheNode('', null); // placeholder node (no path or data)
private lruTailNode = new DataCacheNode('', null); // placeholder node (no path or data)
private unavailableFiles = new Set<string>();
private dataStore: DataStore | null = null;
private debugLogger: Logger | null = null;
Expand Down Expand Up @@ -107,14 +185,13 @@ export interface Cacheable<Data> {
build(): Promise<Data>;

/**
* serialize() transforms `data` to a string (usually JSON encoded) so that it
* can be stored in a text cache file.
* serialize() encodes `data` to a binary representation so that it can be stored in a cache file.
*/
serialize(data: Data): string;
serialize(data: Data): Uint8Array;

/**
* deserialize() is the inverse of serialize(), transforming the string back
* to the Data object.
* deserialize() is the inverse of serialize(), decoding the binary representation back to a Data
* object.
*/
deserialize(serialized: string): Data;
deserialize(binary: Uint8Array): Data;
}
4 changes: 2 additions & 2 deletions src/common/runtime/cmdline.ts
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ Did you remember to build with code coverage instrumentation enabled?`
if (dataPath !== undefined) {
dataCache.setStore({
load: (path: string) => {
return new Promise<string>((resolve, reject) => {
fs.readFile(`${dataPath}/${path}`, 'utf8', (err, data) => {
return new Promise<Uint8Array>((resolve, reject) => {
fs.readFile(`${dataPath}/${path}`, (err, data) => {
if (err !== null) {
reject(err.message);
} else {
Expand Down
4 changes: 2 additions & 2 deletions src/common/runtime/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ Did you remember to build with code coverage instrumentation enabled?`
if (dataPath !== undefined) {
dataCache.setStore({
load: (path: string) => {
return new Promise<string>((resolve, reject) => {
fs.readFile(`${dataPath}/${path}`, 'utf8', (err, data) => {
return new Promise<Uint8Array>((resolve, reject) => {
fs.readFile(`${dataPath}/${path}`, (err, data) => {
if (err !== null) {
reject(err.message);
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/common/runtime/standalone.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ dataCache.setStore({
if (!response.ok) {
return Promise.reject(response.statusText);
}
return await response.text();
return new Uint8Array(await response.arrayBuffer());
},
});

Expand Down
6 changes: 3 additions & 3 deletions src/common/tools/gen_cache.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ const outRootDir = nonFlagsArgs[2];

dataCache.setStore({
load: (path: string) => {
return new Promise<string>((resolve, reject) => {
fs.readFile(`data/${path}`, 'utf8', (err, data) => {
return new Promise<Uint8Array>((resolve, reject) => {
fs.readFile(`data/${path}`, (err, data) => {
if (err !== null) {
reject(err.message);
} else {
Expand Down Expand Up @@ -180,7 +180,7 @@ and
const data = await cacheable.build();
const serialized = cacheable.serialize(data);
fs.mkdirSync(path.dirname(outPath), { recursive: true });
fs.writeFileSync(outPath, serialized);
fs.writeFileSync(outPath, serialized, 'binary');
break;
}
case 'list': {
Expand Down
46 changes: 31 additions & 15 deletions src/unittests/serialization.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {
deserializeExpectation,
serializeExpectation,
} from '../webgpu/shader/execution/expression/case_cache.js';
import BinaryStream from '../webgpu/util/binary_stream.js';
import {
anyOf,
deserializeComparator,
Expand Down Expand Up @@ -206,11 +207,14 @@ g.test('value').fn(t => {
f32
),
]) {
const serialized = serializeValue(value);
const deserialized = deserializeValue(serialized);
const s = new BinaryStream(new Uint8Array(1024));
serializeValue(s, value);
const d = new BinaryStream(s.buffer());
const deserialized = deserializeValue(d);
t.expect(
objectEquals(value, deserialized),
`value ${value} -> serialize -> deserialize -> ${deserialized}`
`${value.type} ${value} -> serialize -> deserialize -> ${deserialized}
buffer: ${s.buffer()}`
);
}
});
Expand Down Expand Up @@ -240,8 +244,10 @@ g.test('fpinterval_f32').fn(t => {
FP.f32.toInterval([kValue.f32.negative.subnormal.min, kValue.f32.negative.subnormal.max]),
FP.f32.toInterval([kValue.f32.negative.infinity, kValue.f32.positive.infinity]),
]) {
const serialized = serializeFPInterval(interval);
const deserialized = deserializeFPInterval(serialized);
const s = new BinaryStream(new Uint8Array(1024));
serializeFPInterval(s, interval);
const d = new BinaryStream(s.buffer());
const deserialized = deserializeFPInterval(d);
t.expect(
objectEquals(interval, deserialized),
`interval ${interval} -> serialize -> deserialize -> ${deserialized}`
Expand Down Expand Up @@ -274,8 +280,10 @@ g.test('fpinterval_f16').fn(t => {
FP.f16.toInterval([kValue.f16.negative.subnormal.min, kValue.f16.negative.subnormal.max]),
FP.f16.toInterval([kValue.f16.negative.infinity, kValue.f16.positive.infinity]),
]) {
const serialized = serializeFPInterval(interval);
const deserialized = deserializeFPInterval(serialized);
const s = new BinaryStream(new Uint8Array(1024));
serializeFPInterval(s, interval);
const d = new BinaryStream(s.buffer());
const deserialized = deserializeFPInterval(d);
t.expect(
objectEquals(interval, deserialized),
`interval ${interval} -> serialize -> deserialize -> ${deserialized}`
Expand Down Expand Up @@ -308,8 +316,10 @@ g.test('fpinterval_abstract').fn(t => {
FP.abstract.toInterval([kValue.f64.negative.subnormal.min, kValue.f64.negative.subnormal.max]),
FP.abstract.toInterval([kValue.f64.negative.infinity, kValue.f64.positive.infinity]),
]) {
const serialized = serializeFPInterval(interval);
const deserialized = deserializeFPInterval(serialized);
const s = new BinaryStream(new Uint8Array(1024));
serializeFPInterval(s, interval);
const d = new BinaryStream(s.buffer());
const deserialized = deserializeFPInterval(d);
t.expect(
objectEquals(interval, deserialized),
`interval ${interval} -> serialize -> deserialize -> ${deserialized}`
Expand All @@ -328,8 +338,10 @@ g.test('expression_expectation').fn(t => {
// Intervals
[FP.f32.toInterval([-8.0, 0.5]), FP.f32.toInterval([2.0, 4.0])],
]) {
const serialized = serializeExpectation(expectation);
const deserialized = deserializeExpectation(serialized);
const s = new BinaryStream(new Uint8Array(1024));
serializeExpectation(s, expectation);
const d = new BinaryStream(s.buffer());
const deserialized = deserializeExpectation(d);
t.expect(
objectEquals(expectation, deserialized),
`expectation ${expectation} -> serialize -> deserialize -> ${deserialized}`
Expand All @@ -356,8 +368,10 @@ g.test('anyOf').fn(t => {
testCases: [f32(0), f32(10), f32(122), f32(123), f32(124), f32(200)],
},
]) {
const serialized = serializeComparator(c.comparator);
const deserialized = deserializeComparator(serialized);
const s = new BinaryStream(new Uint8Array(1024));
serializeComparator(s, c.comparator);
const d = new BinaryStream(s.buffer());
const deserialized = deserializeComparator(d);
for (const val of c.testCases) {
const got = deserialized.compare(val);
const expect = c.comparator.compare(val);
Expand All @@ -382,8 +396,10 @@ g.test('skipUndefined').fn(t => {
testCases: [f32(0), f32(10), f32(122), f32(123), f32(124), f32(200)],
},
]) {
const serialized = serializeComparator(c.comparator);
const deserialized = deserializeComparator(serialized);
const s = new BinaryStream(new Uint8Array(1024));
serializeComparator(s, c.comparator);
const d = new BinaryStream(s.buffer());
const deserialized = deserializeComparator(d);
for (const val of c.testCases) {
const got = deserialized.compare(val);
const expect = c.comparator.compare(val);
Expand Down
Loading

0 comments on commit 6e21caa

Please sign in to comment.