Example for flax.jax_utils.prefetch_to_device with TFDS dataset #3869
Replies: 2 comments
-
Have you tried calling
-
It took me a while to understand how the batch dimensions were changed. I found that it did not help training performance, at least not in the single-GPU case.
-
Hi,

I want to use `flax.jax_utils.prefetch_to_device` to preload data onto my GPU, but I could not figure out how to make it work with a TFDS dataset. The dataset is simple.
The above code works without `prefetch_to_device`. But I cannot simply call `prefetch_to_device(train_ds.as_numpy_iterator())`, because `prefetch_to_device` requires the first dimension of each element yielded by the iterator to be the number of devices. I only have one GPU, so it expects `as_numpy_iterator` to yield arrays of shape `(1, batch_size, ...)`.
I cannot find a way to make `as_numpy_iterator` return one more dimension. Maybe I need a different way to construct the training dataset? Anyway, I would really appreciate a code snippet or an example. Thank you.