-
To demonstrate how to participate in a Kaggle competition with Flax/JAX, I have made a Kaggle notebook where I apply transfer learning with a pre-trained ResNet. Currently I am running the notebook on GPU because TPUs are always in short supply. I found that the training speed on GPU is surprisingly slow (~2 sec/iteration). I have been debugging it for a while but haven't had any luck, so I would like to share my work here and see if any experienced developers could help identify the issues in my notebook. There is a lack of Flax use cases in the Kaggle community, so I believe that, if the issue gets fixed, this notebook could serve as a great reference for Flax users who want to take part in Kaggle. Any help would be appreciated!
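For context, here is a minimal sketch of the kind of transfer-learning setup the notebook uses: a pre-trained backbone kept frozen, with a fresh classification head trained on top, the freezing done via `optax.multi_transform`. The `TinyBackbone` module below is a hypothetical stand-in for the pre-trained ResNet (the real notebook loads actual pre-trained parameters), and the shapes and hyperparameters are assumptions, not the notebook's actual values:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax import traverse_util
from flax.core import unfreeze


class TinyBackbone(nn.Module):
    """Hypothetical stand-in for the pre-trained ResNet; in the real notebook
    the backbone and its parameters come from a pre-trained checkpoint."""

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        return jnp.mean(x, axis=(1, 2))  # global average pooling


class TransferModel(nn.Module):
    backbone: nn.Module
    num_classes: int

    @nn.compact
    def __call__(self, x):
        feats = self.backbone(x)                  # pre-trained feature extractor
        return nn.Dense(self.num_classes)(feats)  # new head, trained from scratch


model = TransferModel(backbone=TinyBackbone(), num_classes=5)
params = unfreeze(
    model.init(jax.random.PRNGKey(0), jnp.ones((1, 224, 224, 3)))["params"]
)

# Freeze the backbone by routing its parameters to a zero-update transform,
# so only the classification head actually trains.
flat = traverse_util.flatten_dict(params)
labels = traverse_util.unflatten_dict(
    {k: "frozen" if k[0] == "backbone" else "trainable" for k in flat}
)
tx = optax.multi_transform(
    {"trainable": optax.adam(1e-3), "frozen": optax.set_to_zero()}, labels
)
opt_state = tx.init(params)
```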
-
Hi @riven314, great that you are experimenting with this! Some questions:
Thanks in advance!
-
Hi @marcvanzee
Thanks for your response! After some debugging, I found the main bottleneck to be on-the-fly resizing during data loading.
Replacing it with a pre-resized dataset significantly improved the training speed; a sketch of the fix is below.
As the bottleneck is not related to Flax, I will close this discussion for now.
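In case it helps other Flax users on Kaggle, here is a minimal sketch of the fix: resize every image once, up front, then point the data loader at the pre-resized copies instead of resizing on every epoch. The directory names and the 224x224 target size are assumptions for illustration, not the notebook's actual values:

```python
from pathlib import Path
from PIL import Image

SRC_DIR = Path("train_images")       # hypothetical: original competition images
DST_DIR = Path("train_images_224")   # hypothetical: pre-resized copies
TARGET_SIZE = (224, 224)             # assumed ResNet input size

DST_DIR.mkdir(exist_ok=True)
for src in SRC_DIR.glob("*.jpg"):
    # Resize once, offline, so the training loop never touches full-size images.
    img = Image.open(src).convert("RGB")
    img.resize(TARGET_SIZE, Image.BILINEAR).save(DST_DIR / src.name, quality=95)
```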