Replace JSON case cache serialization with binary files #3094

Merged (1 commit) on Oct 26, 2023
119 changes: 98 additions & 21 deletions src/common/framework/data_cache.ts
@@ -3,15 +3,64 @@
 * expensive to build using a two-level cache (in-memory, pre-computed file).
 */

+import { assert } from '../util/util.js';
+
interface DataStore {
-  load(path: string): Promise<string>;
+  load(path: string): Promise<Uint8Array>;
}

/** Logger is a basic debug logger function */
export type Logger = (s: string) => void;

-/** DataCache is an interface to a data store used to hold cached data */
+/**
+ * DataCacheNode represents a single cache entry in the LRU DataCache.
+ * DataCacheNode is a doubly linked list, so that least-recently-used entries can be removed, and
+ * cache hits can move the node to the front of the list.
+ */
+class DataCacheNode {
+  public constructor(path: string, data: unknown) {
+    this.path = path;
+    this.data = data;
+  }
+
+  /** insertAfter() re-inserts this node in the doubly-linked list after @p prev */
+  public insertAfter(prev: DataCacheNode) {
+    this.unlink();
+    this.next = prev.next;
+    this.prev = prev;
+    prev.next = this;
+    if (this.next) {
+      this.next.prev = this;
+    }
+  }
+
+  /** unlink() removes this node from the doubly-linked list */
+  public unlink() {
+    const prev = this.prev;
+    const next = this.next;
+    if (prev) {
+      prev.next = next;
+    }
+    if (next) {
+      next.prev = prev;
+    }
+    this.prev = null;
+    this.next = null;
+  }
+
+  public readonly path: string; // The file path this node represents
+  public readonly data: unknown; // The deserialized data for this node
+  public prev: DataCacheNode | null = null; // The previous node in the doubly-linked list
+  public next: DataCacheNode | null = null; // The next node in the doubly-linked list
+}
+
+/** DataCache is an interface to a LRU-cached data store used to hold data cached by path */
export class DataCache {
+  public constructor() {
+    this.lruHeadNode.next = this.lruTailNode;
+    this.lruTailNode.prev = this.lruHeadNode;
+  }
+
  /** setDataStore() sets the backing data store used by the data cache */
  public setStore(dataStore: DataStore) {
    this.dataStore = dataStore;
@@ -28,17 +77,20 @@ export class DataCache {
   * building the data and storing it in the cache.
   */
  public async fetch<Data>(cacheable: Cacheable<Data>): Promise<Data> {
-    // First check the in-memory cache
-    let data = this.cache.get(cacheable.path);
-    if (data !== undefined) {
-      this.log('in-memory cache hit');
-      return Promise.resolve(data as Data);
+    {
+      // First check the in-memory cache
+      const node = this.cache.get(cacheable.path);
+      if (node !== undefined) {
+        this.log('in-memory cache hit');
+        node.insertAfter(this.lruHeadNode);
+        return Promise.resolve(node.data as Data);
+      }
    }
    this.log('in-memory cache miss');
    // In in-memory cache miss.
    // Next, try the data store.
    if (this.dataStore !== null && !this.unavailableFiles.has(cacheable.path)) {
-      let serialized: string | undefined;
+      let serialized: Uint8Array | undefined;
      try {
        serialized = await this.dataStore.load(cacheable.path);
        this.log('loaded serialized');
@@ -49,16 +101,37 @@
      }
      if (serialized !== undefined) {
        this.log(`deserializing`);
-        data = cacheable.deserialize(serialized);
-        this.cache.set(cacheable.path, data);
-        return data as Data;
+        const data = cacheable.deserialize(serialized);
+        this.addToCache(cacheable.path, data);
+        return data;
      }
    }
    // Not found anywhere. Build the data, and cache for future lookup.
    this.log(`cache: building (${cacheable.path})`);
-    data = await cacheable.build();
-    this.cache.set(cacheable.path, data);
-    return data as Data;
+    const data = await cacheable.build();
+    this.addToCache(cacheable.path, data);
+    return data;
  }

+  /**
+   * addToCache() creates a new node for @p path and @p data, inserting the new node at the front of
+   * the doubly-linked list. If the number of entries in the cache exceeds this.maxCount, then the
+   * least recently used entry is evicted
+   * @param path the file path for the data
+   * @param data the deserialized data
+   */
+  private addToCache(path: string, data: unknown) {
+    if (this.cache.size >= this.maxCount) {
+      const toEvict = this.lruTailNode.prev;
+      assert(toEvict !== null);
+      toEvict.unlink();
+      this.cache.delete(toEvict.path);
+      this.log(`evicting ${toEvict.path}`);
+    }
+    const node = new DataCacheNode(path, data);
+    node.insertAfter(this.lruHeadNode);
+    this.cache.set(path, node);
+    this.log(`added ${path}. new count: ${this.cache.size}`);
+  }
+
  private log(msg: string) {
@@ -67,7 +140,12 @@
    }
  }

-  private cache = new Map<string, unknown>();
+  // Max number of entries in the cache before LRU entries are evicted.
+  private readonly maxCount = 4;
+
+  private cache = new Map<string, DataCacheNode>();
+  private lruHeadNode = new DataCacheNode('', null); // placeholder node (no path or data)
+  private lruTailNode = new DataCacheNode('', null); // placeholder node (no path or data)
  private unavailableFiles = new Set<string>();
  private dataStore: DataStore | null = null;
  private debugLogger: Logger | null = null;
@@ -107,14 +185,13 @@ export interface Cacheable<Data> {
  build(): Promise<Data>;

  /**
-   * serialize() transforms `data` to a string (usually JSON encoded) so that it
-   * can be stored in a text cache file.
+   * serialize() encodes `data` to a binary representation so that it can be stored in a cache file.
   */
-  serialize(data: Data): string;
+  serialize(data: Data): Uint8Array;

  /**
-   * deserialize() is the inverse of serialize(), transforming the string back
-   * to the Data object.
+   * deserialize() is the inverse of serialize(), decoding the binary representation back to a Data
+   * object.
   */
-  deserialize(serialized: string): Data;
+  deserialize(binary: Uint8Array): Data;
}
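
Note (not part of the diff): with DataStore.load() and the Cacheable interface now trafficking in Uint8Array, each Cacheable is responsible for its own byte encoding. The following is a minimal sketch of what an implementation against the new interface could look like; the class name, cache path, and Float32Array payload are invented for illustration, and the import path is assumed (the dataCache singleton used by the runtimes below is presumed to be exported from this module).

// Hypothetical example only; not code from this PR.
import { dataCache, Cacheable } from '../common/framework/data_cache.js'; // path assumed

class F32ListCacheable implements Cacheable<Float32Array> {
  // Cache file path, relative to the data store root (invented for this sketch).
  readonly path = 'example/f32_list.bin';

  // Build the data from scratch when there is no cache entry.
  async build(): Promise<Float32Array> {
    return new Float32Array([1.5, -2.25, 3.75]);
  }

  // Encode the data as its raw bytes for the cache file.
  serialize(data: Float32Array): Uint8Array {
    return new Uint8Array(data.buffer.slice(data.byteOffset, data.byteOffset + data.byteLength));
  }

  // Decode the bytes produced by serialize(); the copy guarantees 4-byte alignment.
  deserialize(binary: Uint8Array): Float32Array {
    return new Float32Array(binary.slice().buffer);
  }
}

// fetch() checks the in-memory LRU first, then the data store, then falls back to build().
async function example(): Promise<Float32Array> {
  return dataCache.fetch(new F32ListCacheable());
}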
4 changes: 2 additions & 2 deletions src/common/runtime/cmdline.ts
@@ -135,8 +135,8 @@ Did you remember to build with code coverage instrumentation enabled?`
  if (dataPath !== undefined) {
    dataCache.setStore({
      load: (path: string) => {
-        return new Promise<string>((resolve, reject) => {
-          fs.readFile(`${dataPath}/${path}`, 'utf8', (err, data) => {
+        return new Promise<Uint8Array>((resolve, reject) => {
+          fs.readFile(`${dataPath}/${path}`, (err, data) => {
            if (err !== null) {
              reject(err.message);
            } else {
4 changes: 2 additions & 2 deletions src/common/runtime/server.ts
@@ -133,8 +133,8 @@ Did you remember to build with code coverage instrumentation enabled?`
  if (dataPath !== undefined) {
    dataCache.setStore({
      load: (path: string) => {
-        return new Promise<string>((resolve, reject) => {
-          fs.readFile(`${dataPath}/${path}`, 'utf8', (err, data) => {
+        return new Promise<Uint8Array>((resolve, reject) => {
+          fs.readFile(`${dataPath}/${path}`, (err, data) => {
            if (err !== null) {
              reject(err.message);
            } else {
2 changes: 1 addition & 1 deletion src/common/runtime/standalone.ts
@@ -84,7 +84,7 @@ dataCache.setStore({
    if (!response.ok) {
      return Promise.reject(response.statusText);
    }
-    return await response.text();
+    return new Uint8Array(await response.arrayBuffer());
  },
});

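In the browser runtime the same contract is met by reading the response body as an ArrayBuffer and wrapping it in a Uint8Array, as the standalone change above does. A small illustrative loader with the same shape (the URL prefix is a placeholder, not the CTS's actual cache path):

// Hypothetical loader, shown only to illustrate the Promise<Uint8Array> contract.
async function loadCacheFile(path: string): Promise<Uint8Array> {
  const response = await fetch(`/cache/${path}`); // placeholder URL prefix
  if (!response.ok) {
    return Promise.reject(response.statusText);
  }
  return new Uint8Array(await response.arrayBuffer());
}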
6 changes: 3 additions & 3 deletions src/common/tools/gen_cache.ts
@@ -87,8 +87,8 @@ const outRootDir = nonFlagsArgs[2];

dataCache.setStore({
  load: (path: string) => {
-    return new Promise<string>((resolve, reject) => {
-      fs.readFile(`data/${path}`, 'utf8', (err, data) => {
+    return new Promise<Uint8Array>((resolve, reject) => {
+      fs.readFile(`data/${path}`, (err, data) => {
        if (err !== null) {
          reject(err.message);
        } else {
@@ -180,7 +180,7 @@ and
        const data = await cacheable.build();
        const serialized = cacheable.serialize(data);
        fs.mkdirSync(path.dirname(outPath), { recursive: true });
-        fs.writeFileSync(outPath, serialized);
+        fs.writeFileSync(outPath, serialized, 'binary');
        break;
      }
      case 'list': {
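On the Node.js side the change amounts to not decoding: with no encoding argument, fs.readFile hands the callback a Buffer (a Uint8Array subclass) that can be resolved directly as the DataStore payload, and fs.writeFileSync accepts a Uint8Array as-is. A self-contained sketch of that round trip (file name and bytes invented for illustration):

import * as fs from 'fs';

// Write raw bytes, then read them back with no encoding so the callback
// receives a Buffer rather than a decoded string.
const bytes = new Uint8Array([0xde, 0xad, 0xbe, 0xef]);
fs.writeFileSync('example.bin', bytes);

fs.readFile('example.bin', (err, data) => {
  if (err !== null) {
    throw err;
  }
  console.log(data instanceof Uint8Array, data.byteLength); // true 4
});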
46 changes: 31 additions & 15 deletions src/unittests/serialization.spec.ts
@@ -7,6 +7,7 @@ import {
  deserializeExpectation,
  serializeExpectation,
} from '../webgpu/shader/execution/expression/case_cache.js';
+import BinaryStream from '../webgpu/util/binary_stream.js';
import {
  anyOf,
  deserializeComparator,
@@ -206,11 +207,14 @@ g.test('value').fn(t => {
      f32
    ),
  ]) {
-    const serialized = serializeValue(value);
-    const deserialized = deserializeValue(serialized);
+    const s = new BinaryStream(new Uint8Array(1024));
+    serializeValue(s, value);
+    const d = new BinaryStream(s.buffer());
+    const deserialized = deserializeValue(d);
    t.expect(
      objectEquals(value, deserialized),
-      `value ${value} -> serialize -> deserialize -> ${deserialized}`
+      `${value.type} ${value} -> serialize -> deserialize -> ${deserialized}
+buffer: ${s.buffer()}`
    );
  }
});
@@ -240,8 +244,10 @@ g.test('fpinterval_f32').fn(t => {
    FP.f32.toInterval([kValue.f32.negative.subnormal.min, kValue.f32.negative.subnormal.max]),
    FP.f32.toInterval([kValue.f32.negative.infinity, kValue.f32.positive.infinity]),
  ]) {
-    const serialized = serializeFPInterval(interval);
-    const deserialized = deserializeFPInterval(serialized);
+    const s = new BinaryStream(new Uint8Array(1024));
+    serializeFPInterval(s, interval);
+    const d = new BinaryStream(s.buffer());
+    const deserialized = deserializeFPInterval(d);
    t.expect(
      objectEquals(interval, deserialized),
      `interval ${interval} -> serialize -> deserialize -> ${deserialized}`
@@ -274,8 +280,10 @@ g.test('fpinterval_f16').fn(t => {
    FP.f16.toInterval([kValue.f16.negative.subnormal.min, kValue.f16.negative.subnormal.max]),
    FP.f16.toInterval([kValue.f16.negative.infinity, kValue.f16.positive.infinity]),
  ]) {
-    const serialized = serializeFPInterval(interval);
-    const deserialized = deserializeFPInterval(serialized);
+    const s = new BinaryStream(new Uint8Array(1024));
+    serializeFPInterval(s, interval);
+    const d = new BinaryStream(s.buffer());
+    const deserialized = deserializeFPInterval(d);
    t.expect(
      objectEquals(interval, deserialized),
      `interval ${interval} -> serialize -> deserialize -> ${deserialized}`
@@ -308,8 +316,10 @@ g.test('fpinterval_abstract').fn(t => {
    FP.abstract.toInterval([kValue.f64.negative.subnormal.min, kValue.f64.negative.subnormal.max]),
    FP.abstract.toInterval([kValue.f64.negative.infinity, kValue.f64.positive.infinity]),
  ]) {
-    const serialized = serializeFPInterval(interval);
-    const deserialized = deserializeFPInterval(serialized);
+    const s = new BinaryStream(new Uint8Array(1024));
+    serializeFPInterval(s, interval);
+    const d = new BinaryStream(s.buffer());
+    const deserialized = deserializeFPInterval(d);
    t.expect(
      objectEquals(interval, deserialized),
      `interval ${interval} -> serialize -> deserialize -> ${deserialized}`
@@ -328,8 +338,10 @@ g.test('expression_expectation').fn(t => {
    // Intervals
    [FP.f32.toInterval([-8.0, 0.5]), FP.f32.toInterval([2.0, 4.0])],
  ]) {
-    const serialized = serializeExpectation(expectation);
-    const deserialized = deserializeExpectation(serialized);
+    const s = new BinaryStream(new Uint8Array(1024));
+    serializeExpectation(s, expectation);
+    const d = new BinaryStream(s.buffer());
+    const deserialized = deserializeExpectation(d);
    t.expect(
      objectEquals(expectation, deserialized),
      `expectation ${expectation} -> serialize -> deserialize -> ${deserialized}`
@@ -356,8 +368,10 @@ g.test('anyOf').fn(t => {
      testCases: [f32(0), f32(10), f32(122), f32(123), f32(124), f32(200)],
    },
  ]) {
-    const serialized = serializeComparator(c.comparator);
-    const deserialized = deserializeComparator(serialized);
+    const s = new BinaryStream(new Uint8Array(1024));
+    serializeComparator(s, c.comparator);
+    const d = new BinaryStream(s.buffer());
+    const deserialized = deserializeComparator(d);
    for (const val of c.testCases) {
      const got = deserialized.compare(val);
      const expect = c.comparator.compare(val);
@@ -382,8 +396,10 @@ g.test('skipUndefined').fn(t => {
      testCases: [f32(0), f32(10), f32(122), f32(123), f32(124), f32(200)],
    },
  ]) {
-    const serialized = serializeComparator(c.comparator);
-    const deserialized = deserializeComparator(serialized);
+    const s = new BinaryStream(new Uint8Array(1024));
+    serializeComparator(s, c.comparator);
+    const d = new BinaryStream(s.buffer());
+    const deserialized = deserializeComparator(d);
    for (const val of c.testCases) {
      const got = deserialized.compare(val);
      const expect = c.comparator.compare(val);
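
All of the updated tests follow the same round-trip pattern: serialize into a BinaryStream backed by a scratch Uint8Array, take s.buffer() as the serialized bytes, then deserialize from a fresh stream over those bytes. The stand-in below is not the repo's src/webgpu/util/binary_stream.ts; it is a simplified sketch, assuming only that the real class is constructed over a Uint8Array and that buffer() returns the bytes written so far, to show why the write-then-read pattern round-trips.

// Simplified stand-in for the BinaryStream write/read contract used above (assumed API).
class MiniBinaryStream {
  private bytes: Uint8Array;
  private view: DataView;
  private offset = 0;

  constructor(data: Uint8Array) {
    this.bytes = data;
    this.view = new DataView(data.buffer, data.byteOffset, data.byteLength);
  }

  writeF32(value: number) {
    this.view.setFloat32(this.offset, value, /* littleEndian */ true);
    this.offset += 4;
  }

  readF32(): number {
    const value = this.view.getFloat32(this.offset, /* littleEndian */ true);
    this.offset += 4;
    return value;
  }

  // Only the written portion of the scratch buffer is the serialized payload.
  buffer(): Uint8Array {
    return this.bytes.slice(0, this.offset);
  }
}

// Round trip, mirroring the tests: write into a scratch buffer, re-read from buffer().
const s = new MiniBinaryStream(new Uint8Array(1024));
s.writeF32(1.5);
const d = new MiniBinaryStream(s.buffer());
console.log(d.readF32()); // 1.5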