Skip to content

Commit 636c616

Browse files
authored
[tfjs-data] support async generator (#8408)
1 parent 3daf152 commit 636c616

File tree

2 files changed

+24
-10
lines changed

2 files changed

+24
-10
lines changed

tfjs-data/src/readers.ts

+9-10
Original file line numberDiff line numberDiff line change
@@ -140,14 +140,12 @@ export function func<T extends TensorContainer>(
140140

141141
/**
142142
* Create a `Dataset` that produces each element from provided JavaScript
143-
* generator, which is a function*
144-
* (https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Iterators_and_Generators#Generator_functions),
145-
* or a function that returns an
146-
* iterator
147-
* (https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Iterators_and_Generators#Generator_functions).
143+
* generator, which is a function that returns a (potentially async) iterator.
148144
*
149-
* The returned iterator should have `.next()` function that returns element in
150-
* format of `{value: TensorContainer, done:boolean}`.
145+
* For more information on iterators and generators, see
146+
* https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Iterators_and_Generators .
147+
* For the iterator protocol, see
148+
* https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Iteration_protocols .
151149
*
152150
* Example of creating a dataset from an iterator factory:
153151
* ```js
@@ -188,8 +186,8 @@ export function func<T extends TensorContainer>(
188186
* await ds.forEachAsync(e => console.log(e));
189187
* ```
190188
*
191-
* @param generator A JavaScript generator function that returns a JavaScript
192-
* iterator.
189+
* @param generator A JavaScript function that returns
190+
* a (potentially async) JavaScript iterator.
193191
*
194192
* @doc {
195193
* heading: 'Data',
@@ -199,7 +197,8 @@ export function func<T extends TensorContainer>(
199197
* }
200198
*/
201199
export function generator<T extends TensorContainer>(
202-
generator: () => Iterator<T>| Promise<Iterator<T>>): Dataset<T> {
200+
generator: () => Iterator<T> | Promise<Iterator<T>> | AsyncIterator<T>,
201+
): Dataset<T> {
203202
return datasetFromIteratorFn(async () => {
204203
const gen = await generator();
205204
return iteratorFromFunction(() => gen.next());

tfjs-data/src/readers_test.ts

+15
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,21 @@ describeAllEnvs('readers', () => {
4545
expect(result).toEqual([0, 1, 2, 3, 4]);
4646
});
4747

48+
it('generate dataset from async generator', async () => {
49+
async function* dataGenerator() {
50+
const numElements = 5;
51+
let index = 0;
52+
while (index < numElements) {
53+
const x = index;
54+
index++;
55+
yield x;
56+
}
57+
}
58+
const ds = tfd.generator(dataGenerator);
59+
const result = await ds.toArrayForTest();
60+
expect(result).toEqual([0, 1, 2, 3, 4]);
61+
});
62+
4863
it('generate multiple datasets from JavaScript generator', async () => {
4964
function* dataGenerator() {
5065
const numElements = 5;

0 commit comments

Comments
 (0)