-
Notifications
You must be signed in to change notification settings - Fork 86
/
Copy pathcase_cache.ts
201 lines (187 loc) · 6 KB
/
case_cache.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import { Cacheable, dataCache } from '../../../../common/framework/data_cache.js';
import { unreachable } from '../../../../common/util/util.js';
import BinaryStream from '../../../util/binary_stream.js';
import { deserializeComparator, serializeComparator } from '../../../util/compare.js';
import {
MatrixValue,
Value,
VectorValue,
deserializeValue,
isScalarValue,
serializeValue,
} from '../../../util/conversion.js';
import {
FPInterval,
deserializeFPInterval,
serializeFPInterval,
} from '../../../util/floating_point.js';
import { flatten2DArray, unflatten2DArray } from '../../../util/math.js';
import { Case } from './case.js';
import { Expectation, isComparator } from './expectation.js';
/**
 * SerializedExpectationKind is the one-byte tag that serializeExpectation()
 * writes before each Expectation payload and deserializeExpectation() reads
 * back. The numeric values end up in the cached binary files, so they are
 * spelled out explicitly rather than left to declaration order.
 */
enum SerializedExpectationKind {
  Value = 0,
  Interval = 1,
  Interval1DArray = 2,
  Interval2DArray = 3,
  /** Not written or read by serializeExpectation()/deserializeExpectation(). */
  Array = 4,
  Comparator = 5,
}
/**
 * serializeExpectation() writes an Expectation to a BinaryStream.
 *
 * The encoding is a one-byte SerializedExpectationKind tag followed by a
 * kind-specific payload:
 *  - Value:           the value, via serializeValue()
 *  - Interval:        the interval, via serializeFPInterval()
 *  - Interval2DArray: u16 outer length, u16 inner length (taken from the first
 *                     element), then the flattened intervals
 *  - Interval1DArray: the array of intervals
 *  - Comparator:      the comparator, via serializeComparator()
 * Anything else is reported with unreachable().
 *
 * @param s the stream to write to
 * @param e the expectation to serialize
 */
export function serializeExpectation(s: BinaryStream, e: Expectation) {
  if (isScalarValue(e) || e instanceof VectorValue || e instanceof MatrixValue) {
    s.writeU8(SerializedExpectationKind.Value);
    serializeValue(s, e);
  } else if (e instanceof FPInterval) {
    s.writeU8(SerializedExpectationKind.Interval);
    serializeFPInterval(s, e);
  } else if (e instanceof Array) {
    // Nested arrays are treated as 2D interval arrays; everything else
    // (including an empty array) is treated as a 1D interval array.
    if (e[0] instanceof Array) {
      const grid = e as FPInterval[][];
      s.writeU8(SerializedExpectationKind.Interval2DArray);
      s.writeU16(grid.length);
      s.writeU16(grid[0].length);
      s.writeArray(flatten2DArray(grid), serializeFPInterval);
    } else {
      const intervals = e as FPInterval[];
      s.writeU8(SerializedExpectationKind.Interval1DArray);
      s.writeArray(intervals, serializeFPInterval);
    }
  } else if (isComparator(e)) {
    s.writeU8(SerializedExpectationKind.Comparator);
    serializeComparator(s, e);
  } else {
    unreachable(`cannot serialize Expectation ${e}`);
  }
}
/**
 * deserializeExpectation() reads an Expectation from a BinaryStream.
 *
 * Reads the one-byte SerializedExpectationKind tag written by
 * serializeExpectation(), then decodes the payload for that kind.
 * An unrecognized tag is reported with unreachable().
 *
 * @param s the stream to read from
 * @returns the decoded Expectation
 */
export function deserializeExpectation(s: BinaryStream): Expectation {
  const tag = s.readU8();
  if (tag === SerializedExpectationKind.Value) {
    return deserializeValue(s);
  }
  if (tag === SerializedExpectationKind.Interval) {
    return deserializeFPInterval(s);
  }
  if (tag === SerializedExpectationKind.Interval1DArray) {
    return s.readArray(deserializeFPInterval);
  }
  if (tag === SerializedExpectationKind.Interval2DArray) {
    // The two u16 dimensions precede the flattened interval list.
    const numCols = s.readU16();
    const numRows = s.readU16();
    const flat = s.readArray(deserializeFPInterval);
    return unflatten2DArray(flat, numCols, numRows);
  }
  if (tag === SerializedExpectationKind.Comparator) {
    return deserializeComparator(s);
  }
  unreachable(`invalid serialized expectation kind: ${tag}`);
}
/**
 * serializeCase() writes a Case to a BinaryStream.
 *
 * The input is written through writeCond(): a list of values when the input
 * is an array, otherwise a single value. The expectation follows it.
 *
 * @param s the stream to write to
 * @param c the case to serialize
 */
export function serializeCase(s: BinaryStream, c: Case) {
  const { input } = c;
  const inputIsList = input instanceof Array;
  s.writeCond(inputIsList, {
    if_true: () => {
      s.writeArray(input as Value[], serializeValue);
    },
    if_false: () => {
      serializeValue(s, input as Value);
    },
  });
  serializeExpectation(s, c.expected);
}
/**
 * deserializeCase() reads a Case from a BinaryStream.
 *
 * This mirrors serializeCase(). It first reads the input through readCond()
 * (a list of values, or a single value), then the expectation.
 *
 * @param s the stream to read from
 * @returns the decoded Case
 */
export function deserializeCase(s: BinaryStream): Case {
  const readInputList = () => s.readArray(deserializeValue);
  const readSingleInput = () => deserializeValue(s);
  const input = s.readCond({ if_true: readInputList, if_false: readSingleInput });
  return { input, expected: deserializeExpectation(s) };
}
/**
 * CaseListBuilder is a zero-argument function that produces a list of cases.
 * CaseCache invokes one per case-list name when building its data.
 */
export type CaseListBuilder = () => Array<Case>;
/**
 * CaseCache is a cache of Case[].
 * CaseCache implements the Cacheable interface, so the cases can be pre-built
 * and stored in the data cache, reducing computation costs at CTS runtime.
 */
export class CaseCache implements Cacheable<Record<string, Case[]>> {
  /**
   * Constructor
   * @param name the name of the cache. This must be globally unique.
   * @param builders a Record of case-list name to case-list builder.
   */
  constructor(name: string, builders: Record<string, CaseListBuilder>) {
    this.path = `webgpu/shader/execution/${name}.bin`;
    this.builders = builders;
  }

  /**
   * get() returns the list of cases with the given name.
   * @param name the case-list name (a key of the builders record).
   * @throws Error if the cache data has no case list with that name.
   */
  public async get(name: string): Promise<Case[]> {
    const data = await dataCache.fetch(this);
    // Without this check an unknown or misspelled name yields `undefined`
    // typed as Case[], which only fails later with a confusing error.
    if (!Object.prototype.hasOwnProperty.call(data, name)) {
      throw new Error(`CaseCache '${this.path}' has no case list named '${name}'`);
    }
    return data[name];
  }

  /**
   * build() implements the Cacheable.build interface.
   * Invokes every builder and collects the results by case-list name.
   * @returns the data.
   */
  build(): Promise<Record<string, Case[]>> {
    const built: Record<string, Case[]> = {};
    // Object.keys() visits only own enumerable keys; for...in would also pick
    // up inherited enumerable properties of the builders record.
    for (const name of Object.keys(this.builders)) {
      built[name] = this.builders[name]();
    }
    return Promise.resolve(built);
  }

  /**
   * serialize() implements the Cacheable.serialize interface.
   * Layout: u32 record count, then for each record its name string followed
   * by its serialized Case array.
   * @returns the serialized data.
   */
  serialize(data: Record<string, Case[]>): Uint8Array {
    const maxSize = 32 << 20; // 32MB - max size for a file
    const stream = new BinaryStream(new ArrayBuffer(maxSize));
    // Use a single key list for both the count and the loop, so the record
    // count written always matches the number of records that follow.
    const names = Object.keys(data);
    stream.writeU32(names.length);
    for (const name of names) {
      stream.writeString(name);
      stream.writeArray(data[name], serializeCase);
    }
    return stream.buffer();
  }

  /**
   * deserialize() implements the Cacheable.deserialize interface.
   * Reads the layout written by serialize().
   * @returns the deserialized data.
   */
  deserialize(array: Uint8Array): Record<string, Case[]> {
    // BinaryStream is given only an ArrayBuffer, so it cannot know about the
    // view's byteOffset/byteLength. If `array` views a sub-range of a larger
    // buffer (e.g. a pooled Node.js Buffer), copy exactly the viewed bytes
    // into a fresh buffer rather than decoding unrelated bytes.
    const buffer =
      array.byteOffset === 0 && array.byteLength === array.buffer.byteLength
        ? array.buffer
        : array.slice().buffer;
    const s = new BinaryStream(buffer);
    const casesByName: Record<string, Case[]> = {};
    const numRecords = s.readU32();
    for (let i = 0; i < numRecords; i++) {
      const name = s.readString();
      const cases = s.readArray(deserializeCase);
      casesByName[name] = cases;
    }
    return casesByName;
  }

  public readonly path: string;
  private readonly builders: Record<string, CaseListBuilder>;
}
/**
 * makeCaseCache() is a factory for CaseCache.
 * @param name the globally-unique cache name, forwarded to the constructor.
 * @param builders a Record of case-list name to case-list builder.
 * @returns a new CaseCache for `name` and `builders`.
 */
export function makeCaseCache(name: string, builders: Record<string, CaseListBuilder>): CaseCache {
  const cache = new CaseCache(name, builders);
  return cache;
}