Understanding TPU efficiency in the examples #1408
-
Hello, this is more a question about TPUs than about Flax, but this seemed like the place to ask it. I'm trying to understand why TPUs seem to be extremely fast in some cases and not in others. In particular, I'm looking at the benchmark numbers reported for the examples. For ImageNet, it looks like 8x TPU v3 are much faster than 8x V100 (and about the same compared to the GPUs running with mixed precision). For the PixelCNN, it looks like 8x TPU v3 are 3x slower than 8x V100! Is this correct, and if so, why is there such a large difference? Thanks!
-
The runtime comparisons for ImageNet are what we usually see when comparing GPUs and TPUs. However, since TPU uses `bfloat16` for matrix multiplications, this can in some extreme cases affect training stability, which is probably what happens in PixelCNN++. In that example we expect the test loss to be below 2.92, which requires a very precise setup, and using TPU here actually slows down training. Please note the slowness in PixelCNN++ is a known issue (#458). Copying @j-towns's latest response on that issue: "Based on some other generative modelling work which I've been doing on TPU lately, it seems the precision parameter to layers like Conv makes a small but noticeable difference to training stability and to test performance. It might be worth finding out whether this affects PixelCNN++ and perhaps adding a command line argument to enable higher precisions." We should probably add a link to the issue in the PixelCNN++ readme!
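For reference, here is a minimal sketch (not taken from the PixelCNN++ example itself) of what raising the matmul precision looks like in Flax. The `precision` argument on `nn.Conv` and the `jax.lax.Precision` enum are the real knobs being discussed above; the `SmallConvNet` module is just a hypothetical illustration.

```python
# Sketch: raising matmul precision on TPU via the `precision` argument.
import jax
import jax.numpy as jnp
from flax import linen as nn


class SmallConvNet(nn.Module):
    """Toy module showing how per-layer precision can be raised."""
    precision: jax.lax.Precision = jax.lax.Precision.DEFAULT

    @nn.compact
    def __call__(self, x):
        # On TPU, DEFAULT lowers matmul inputs to bfloat16; HIGHEST keeps float32.
        x = nn.Conv(features=32, kernel_size=(3, 3), precision=self.precision)(x)
        x = nn.relu(x)
        x = nn.Conv(features=32, kernel_size=(3, 3), precision=self.precision)(x)
        return x


# Example: instantiate with full precision, e.g. gated by a command-line flag.
model = SmallConvNet(precision=jax.lax.Precision.HIGHEST)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 28, 28, 1)))
```

Note that `Precision.HIGHEST` trades speed for numerical accuracy, so it would likely narrow (or erase) the TPU's usual throughput advantage while improving stability.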
-
Great, thank you for the very comprehensive response!