Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 8ea6dd8

Browse files
authored
Compute min and max of numerical columns over a Dataset. (#714)
* Compute min and max of numerical columns over a Dataset.
* Add a few simple stream utilities, e.g. streamFromIncrementing, forEach, and resolveFully.
* Don't test with 'WEBGL_FLOAT_TEXTURE_ENABLED': false, due to a numerical bug.
1 parent 61bf615 commit 8ea6dd8

File tree

6 files changed

+275
-11
lines changed

6 files changed

+275
-11
lines changed

src/contrib/data/dataset.ts

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import * as seedrandom from 'seedrandom';
2020

2121
import {BatchDataset} from './batch_dataset';
22+
import {computeDatasetStatistics, DatasetStatistics} from './statistics';
2223
import {DataStream} from './streams/data_stream';
2324
import {streamFromConcatenated} from './streams/data_stream';
2425
import {streamFromFunction} from './streams/data_stream';
@@ -46,6 +47,43 @@ export abstract class Dataset {
4647
*/
4748
abstract async getStream(): Promise<DataStream<DatasetElement>>;
4849

50+
// TODO(soergel): Make Datasets report whether repeated getStream() calls
51+
// produce the same result (e.g., reading from a file) or different results
52+
// (e.g., from the webcam). Currently we don't make this distinction but it
53+
// could be important for the user to know.
54+
// abstract isDeterministic(): boolean;
55+
56+
// TODO(soergel): memoize computeStatistics()
57+
58+
/**
 * Gathers statistics from a Dataset (or optionally from a sample).
 *
 * This obtains a stream from the Dataset and, by default, does a full pass
 * to gather the statistics.
 *
 * Statistics may be computed over a sample. However: simply taking the first
 * n items from the stream may produce a poor estimate if the stream is
 * ordered in some way.
 *
 * A truly random shuffle of the stream would of course solve this
 * problem, but many streams do not allow for this, instead providing only a
 * sliding-window shuffle. A partially-randomized sample could be obtained by
 * shuffling over a window followed by taking the first n samples (where n is
 * smaller than the shuffle window size). However there is little point in
 * using that approach here, because the cost is likely dominated by obtaining
 * the data. Thus, once we have filled our shuffle buffer, we may as well use
 * all of that data instead of sampling from it.
 *
 * @param sampleSize The number of examples to take from the (possibly
 *   shuffled) stream.
 * @param shuffleWindowSize The size of the shuffle window to use, if any.
 *   (Not recommended, as described above).
 * @returns A Promise for a `DatasetStatistics` mapping each column name to
 *   its numeric statistics.
 */
async computeStatistics(sampleSize?: number, shuffleWindowSize?: number):
    Promise<DatasetStatistics> {
  // Delegates to the shared helper in statistics.ts.
  return computeDatasetStatistics(this, sampleSize, shuffleWindowSize);
}
86+
4987
/**
5088
* Filters this dataset according to `predicate`.
5189
*
@@ -163,6 +201,8 @@ export abstract class Dataset {
163201
});
164202
}
165203

204+
// TODO(soergel): deep sharded shuffle, where supported
205+
166206
/**
167207
* Randomly shuffles the elements of this dataset.
168208
*

src/contrib/data/statistics.ts

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
/**
2+
* @license
3+
* Copyright 2018 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*
16+
* =============================================================================
17+
*/
18+
19+
import {Scalar, Tensor} from '../../tensor';
20+
21+
import {Dataset} from './dataset';
22+
import {ElementArray} from './types';
23+
24+
// TODO(soergel): Flesh out collected statistics.
25+
// For numeric columns we should provide mean, stddev, histogram, etc.
26+
// For string columns we should provide a vocabulary (at least, top-k), maybe a
27+
// length histogram, etc.
28+
// Collecting only numeric min and max is just the bare minimum for now.
29+
30+
/** Summary statistics for a single numeric column: running min and max. */
export type NumericColumnStatistics = {
  min: number; max: number;
};

/** Maps each column name in a dataset to its numeric statistics. */
export type DatasetStatistics = {
  [key: string]: NumericColumnStatistics
};
37+
38+
/**
39+
* Provides a function that scales numeric values into the [0, 1] interval.
40+
*
41+
* @param min the lower bound of the inputs, which should be mapped to 0.
42+
* @param max the upper bound of the inputs, which should be mapped to 1,
43+
* @return A function that maps an input ElementArray to a scaled ElementArray.
44+
*/
45+
export function scaleTo01(min: number, max: number): (value: ElementArray) =>
46+
ElementArray {
47+
const range = max - min;
48+
const minTensor: Tensor = Scalar.new(min);
49+
const rangeTensor: Tensor = Scalar.new(range);
50+
return (value: ElementArray): ElementArray => {
51+
if (typeof (value) === 'string') {
52+
throw new Error('Can\'t scale a string.');
53+
} else {
54+
if (value instanceof Tensor) {
55+
const result = value.sub(minTensor).div(rangeTensor);
56+
return result;
57+
} else if (value instanceof Array) {
58+
return value.map(v => (v - min) / range);
59+
} else {
60+
return (value - min) / range;
61+
}
62+
}
63+
};
64+
}
65+
66+
export async function computeDatasetStatistics(
67+
dataset: Dataset, sampleSize?: number,
68+
shuffleWindowSize?: number): Promise<DatasetStatistics> {
69+
let stream = await dataset.getStream();
70+
// TODO(soergel): allow for deep shuffle where possible.
71+
if (shuffleWindowSize != null) {
72+
stream = stream.shuffle(shuffleWindowSize);
73+
}
74+
if (sampleSize != null) {
75+
stream = stream.take(sampleSize);
76+
}
77+
78+
// TODO(soergel): prepare the column objects based on a schema.q
79+
const result: DatasetStatistics = {};
80+
81+
await stream.forEach(e => {
82+
for (const key in e) {
83+
const value = e[key];
84+
if (typeof (value) === 'string') {
85+
} else {
86+
let recordMin: number;
87+
let recordMax: number;
88+
if (value instanceof Tensor) {
89+
recordMin = value.min().dataSync()[0];
90+
recordMax = value.max().dataSync()[0];
91+
} else if (value instanceof Array) {
92+
recordMin = value.reduce((a, b) => Math.min(a, b));
93+
recordMax = value.reduce((a, b) => Math.max(a, b));
94+
} else if (!isNaN(value) && isFinite(value)) {
95+
recordMin = value;
96+
recordMax = value;
97+
} else {
98+
// TODO(soergel): don't throw; instead record the stats as "unknown".
99+
throw new Error(`Cannot compute statistics: ${key} = ${value}`);
100+
}
101+
let columnStats: NumericColumnStatistics = result[key];
102+
if (columnStats == null) {
103+
columnStats = {
104+
min: Number.POSITIVE_INFINITY,
105+
max: Number.NEGATIVE_INFINITY
106+
};
107+
result[key] = columnStats;
108+
}
109+
columnStats.min = Math.min(columnStats.min, recordMin);
110+
columnStats.max = Math.max(columnStats.max, recordMax);
111+
}
112+
}
113+
// Returning undefined or null (i.e, type void) would indicate that the
114+
// stream is exhausted. So, we have to return *something* in order for
115+
// resolveFully() to operate.
116+
return {};
117+
});
118+
return result;
119+
}

src/contrib/data/statistics_test.ts

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/**
2+
* @license
3+
* Copyright 2018 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {ALL_FLOAT_ENVS, describeWithFlags} from '../../test_util';
19+
20+
import {TestDataset} from './dataset_test';
21+
import {scaleTo01} from './statistics';
22+
23+
describeWithFlags('makeDatasetStatistics', ALL_FLOAT_ENVS, () => {
  it('computes numeric min and max over numbers, arrays, and Tensors', done => {
    // Skipping the first 55 elements makes 55 the expected minimum.
    const ds = new TestDataset().skip(55);
    ds.computeStatistics()
        .then(stats => {
          expect(stats['number'].min).toEqual(55);
          expect(stats['number'].max).toEqual(99);
          // The TestDataset includes cubes of the indices, so the expected
          // maximum of the array and Tensor columns is 99^3.
          expect(stats['numberArray'].min).toEqual(55);
          expect(stats['numberArray'].max).toEqual(99 * 99 * 99);
          expect(stats['Tensor'].min).toEqual(55);
          expect(stats['Tensor'].max).toEqual(99 * 99 * 99);
        })
        .then(done)
        .catch(done.fail);
  });
});
40+
41+
describeWithFlags('scaleTo01', ALL_FLOAT_ENVS, () => {
  it('scales numeric data to the [0, 1] interval', done => {
    const ds = new TestDataset().skip(55);
    // 55 and 99^3 are the true min and max of the 'Tensor' column, so the
    // scaled column should span exactly [0, 1].
    const scaleFn = scaleTo01(55, 99 * 99 * 99);
    const scaledDataset = ds.map(x => ({'Tensor': scaleFn(x['Tensor'])}));

    scaledDataset.computeStatistics()
        .then(stats => {
          expect(stats['Tensor'].min).toBeCloseTo(0);
          expect(stats['Tensor'].max).toBeCloseTo(1);
        })
        .then(done)
        .catch(done.fail);
  });
});

src/contrib/data/streams/data_stream.ts

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@ export function streamFromItems<T>(items: T[]): DataStream<T> {
3434
return new ArrayStream(items);
3535
}
3636

37+
/**
38+
* Create a `DataStream` of incrementing integers.
39+
*/
40+
export function streamFromIncrementing(start: number): DataStream<number> {
41+
let i = start;
42+
return streamFromFunction(() => i++);
43+
}
44+
3745
/**
3846
* Create a `DataStream` from a function.
3947
*/
@@ -95,13 +103,29 @@ export abstract class DataStream<T> {
95103
async collectRemaining(): Promise<T[]> {
96104
const result: T[] = [];
97105
let x = await this.next();
98-
while (x !== undefined) {
106+
while (x != null) {
99107
result.push(x);
100108
x = await this.next();
101109
}
102110
return result;
103111
}
104112

113+
/**
114+
* Draw items from the stream until it is exhausted.
115+
*
116+
* This can be useful when the stream has side effects but no output. In
117+
* that case, calling this function guarantees that the stream will be fully
118+
* processed.
119+
*/
120+
async resolveFully(): Promise<void> {
121+
let x = await this.next();
122+
while (x != null) {
123+
x = await this.next();
124+
}
125+
}
126+
127+
// TODO(soergel): Implement reduce() etc.
128+
105129
/**
106130
* Filters this stream according to `predicate`.
107131
*
@@ -126,6 +150,15 @@ export abstract class DataStream<T> {
126150
return new MapStream(this, transform);
127151
}
128152

153+
/**
154+
* Apply a function to every element of the stream.
155+
*
156+
* @param f A function to apply to each stream element.
157+
*/
158+
async forEach(f: (value: T) => {}|Promise<{}>): Promise<void> {
159+
return this.map(f).resolveFully();
160+
}
161+
129162
/**
130163
* Groups elements into batches.
131164
*
@@ -185,6 +218,8 @@ export abstract class DataStream<T> {
185218
return new PrefetchStream(this, bufferSize);
186219
}
187220

221+
// TODO(soergel): deep sharded shuffle, where supported
222+
188223
/**
189224
* Randomly shuffles the elements of this stream.
190225
*
@@ -193,8 +228,8 @@ export abstract class DataStream<T> {
193228
* @param seed: (Optional.) An integer specifying the random seed that will
194229
* be used to create the distribution.
195230
*/
196-
shuffle(windowSize: number, seed?: string): DataStream<T> {
  // ShuffleStream draws items at random from a sliding window of
  // `windowSize` prefetched elements.
  return new ShuffleStream(this, windowSize, seed);
}
199234
}
200235

@@ -486,9 +521,9 @@ export class ShuffleStream<T> extends PrefetchStream<T> {
486521
private upstreamExhausted = false;
487522

488523
/**
 * @param upstream The stream whose elements are to be shuffled.
 * @param windowSize The size of the sliding-window buffer (passed to the
 *   PrefetchStream base class) from which items are drawn at random.
 * @param seed (Optional.) Seed for the seedrandom generator, making the
 *   shuffle order reproducible.
 */
constructor(
    protected upstream: DataStream<T>, protected windowSize: number,
    seed?: string) {
  super(upstream, windowSize);
  this.random = seedrandom(seed);
}
494529

src/contrib/data/streams/data_stream_test.ts

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
* =============================================================================
1717
*/
1818

19-
import {DataStream} from './data_stream';
19+
import {DataStream, streamFromIncrementing} from './data_stream';
2020
import {streamFromConcatenated} from './data_stream';
2121
import {streamFromConcatenatedFunction} from './data_stream';
2222
import {streamFromFunction, streamFromItems} from './data_stream';
@@ -193,6 +193,16 @@ describe('DataStream', () => {
193193
.catch(done.fail);
194194
});
195195

196+
it('can be created with incrementing integers', done => {
  // streamFromIncrementing(0) yields 0, 1, 2, ...; take(7) bounds it.
  const readStream = streamFromIncrementing(0).take(7);
  readStream.collectRemaining()
      .then(result => {
        expect(result).toEqual([0, 1, 2, 3, 4, 5, 6]);
      })
      .then(done)
      .catch(done.fail);
});
205+
196206
it('can be concatenated', done => {
197207
const a = streamFromItems([1, 2, 3]);
198208
const b = streamFromItems([4, 5, 6]);

src/test_util.ts

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,21 @@ import {Tensor} from './tensor';
2222
import {DataType, TypedArray} from './types';
2323
import * as util from './util';
2424

25-
// WebGL environments in which float textures are available.
const WEBGL_FLOAT_ENVS: Features[] = [
  {'BACKEND': 'webgl', 'WEBGL_FLOAT_TEXTURE_ENABLED': true, 'WEBGL_VERSION': 1},
  {
    'BACKEND': 'webgl',
    'WEBGL_FLOAT_TEXTURE_ENABLED': true,
    'WEBGL_VERSION': 2
  }
];
// All WebGL environments, including the non-float-texture fallback.
export const WEBGL_ENVS = WEBGL_FLOAT_ENVS.concat([{
  'BACKEND': 'webgl',
  'WEBGL_FLOAT_TEXTURE_ENABLED': false,
  'WEBGL_VERSION': 1
}]);
export const CPU_ENVS: Features[] = [{'BACKEND': 'cpu'}];
// Environments (WebGL-with-float plus CPU) for tests needing float precision.
export const ALL_FLOAT_ENVS = WEBGL_FLOAT_ENVS.concat(CPU_ENVS);
export const ALL_ENVS = WEBGL_ENVS.concat(CPU_ENVS);
3641

3742
/** Accuracy for tests. */

0 commit comments

Comments
 (0)