Replies: 1 comment 2 replies
-
Hey @alekhka, I don't know why this would happen in principle from Flax's perspective.
-
Hi!
I adapted the imagenet example code and I am trying to train a Conv-RNN with it. ResNet training works fine. My model looks something like this -
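Roughly along these lines (a simplified stand-in, not my exact model; SimpleConvRNNCell / ConvRNNClassifier are hypothetical names, and the cell is unrolled over timesteps with nn.scan so it only gets traced once):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class SimpleConvRNNCell(nn.Module):
    """Toy Conv-RNN cell: new_h = tanh(Conv(x) + Conv(h))."""
    features: int

    @nn.compact
    def __call__(self, carry, x):
        update = (nn.Conv(self.features, (3, 3))(x)
                  + nn.Conv(self.features, (3, 3), use_bias=False)(carry))
        new_carry = jnp.tanh(update)
        return new_carry, new_carry


class ConvRNNClassifier(nn.Module):
    """Feeds the same frame to the cell for `timesteps` steps, then classifies."""
    features: int = 64
    num_classes: int = 1000
    timesteps: int = 6

    @nn.compact
    def __call__(self, x):
        # x: (batch, H, W, C), NHWC as in the imagenet example.
        carry = jnp.zeros(x.shape[:3] + (self.features,), x.dtype)
        # nn.scan shares the cell's parameters across steps and traces the
        # cell body once, so unrolling does not cost one compilation per step.
        ScanCell = nn.scan(
            SimpleConvRNNCell,
            variable_broadcast="params",
            split_rngs={"params": False},
            in_axes=0,
            out_axes=0,
        )
        xs = jnp.broadcast_to(x[None], (self.timesteps,) + x.shape)
        carry, _ = ScanCell(self.features)(carry, xs)
        pooled = jnp.mean(carry, axis=(1, 2))  # global average pool
        return nn.Dense(self.num_classes)(pooled)


# Quick shape check:
model = ConvRNNClassifier(num_classes=10, timesteps=6)
x = jnp.ones((2, 32, 32, 3))
params = model.init(jax.random.PRNGKey(0), x)
print(model.apply(params, x).shape)  # (2, 10)
```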
I am applying the RNN cell only once during compilation to avoid triggering a compilation for every application of the cell. When I train this with, say, 6 timesteps, the loss goes to NaN quickly, but when I train with fewer timesteps it seems to do fine. Peculiarly, if I train with JAX_DEBUG_NANS=True it trains fine (but a lot slower), with no NaNs even at the larger timestep counts. How can I go about debugging what is causing these NaNs?
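Is something like the following the right direction for localizing them? This is only a toy loss, not my actual training step: jax.debug.print to watch per-timestep magnitudes under jit, and jax.experimental.checkify to turn the first non-finite intermediate into a readable error instead of setting JAX_DEBUG_NANS globally.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def unrolled_loss(w, x, timesteps=6):
    """Toy unrolled recurrence whose gradient goes NaN once `h` gets large."""
    h = x
    for _ in range(timesteps):
        h = w * h  # activations grow every timestep
        # jax.debug.print works under jit/grad (plain print does not), so it
        # can be used to watch per-timestep magnitudes.
        jax.debug.print("max |h| = {m}", m=jnp.max(jnp.abs(h)))
    # Unstabilized log-sum-exp: exp(h) overflows to inf for large h and the
    # backward pass produces 0 * inf = NaN.
    return jnp.log(jnp.sum(jnp.exp(h)))


# Option 1: flip the same switch as JAX_DEBUG_NANS from inside the script,
# e.g. only around a few suspect steps.
# jax.config.update("jax_debug_nans", True)

# Option 2: checkify reports the first NaN it sees, even inside jit.
grad_fn = jax.grad(unrolled_loss)
checked_grad = checkify.checkify(grad_fn, errors=checkify.float_checks)
err, g = jax.jit(checked_grad)(jnp.float32(20.0), jnp.ones((4,), jnp.float32))
err.throw()  # raises here, pointing at the primitive that produced the NaN
```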