Discrepancy of LayerNorm between frameworks #1797
-
I have been trying to port the Swin Transformer weights from PyTorch to Flax, and I spent over three days just getting the outputs of each layer to match. After checking every single module of mine, the only error I could find is that LayerNorm works slightly differently in PyTorch and Flax. The differences look small at first, but they blew up considerably for me once layers containing this norm were stacked.

Here is a Colab reproducing the issue: https://colab.research.google.com/drive/1d4O-3mvlXrqqoXEOH_mWe0r0unv8MyuO?usp=sharing

It turns out that all three frameworks give different outputs for the same input vector after the 2nd or 3rd decimal place, which is very odd. Notably, the TF and PyTorch outputs are closer to each other than either TF/Flax or PyTorch/Flax. If anyone knows what is causing this or how I can possibly fix it, please let me know. Thanks!
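A minimal sketch of the kind of comparison done in the Colab above (the notebook's exact shapes, inputs, and tolerances are not reproduced here; this is just illustrative):

```python
import numpy as np
import torch
import jax
import jax.numpy as jnp
import flax.linen as nn

# Random float32 input with an illustrative feature size of 96.
x = np.random.default_rng(0).normal(size=(1, 96)).astype(np.float32)

# PyTorch LayerNorm with default affine parameters (weight=1, bias=0).
torch_ln = torch.nn.LayerNorm(96, eps=1e-5)
torch_out = torch_ln(torch.from_numpy(x)).detach().numpy()

# Flax LayerNorm; initialization also gives scale=1, bias=0.
flax_ln = nn.LayerNorm(epsilon=1e-5)
variables = flax_ln.init(jax.random.PRNGKey(0), jnp.asarray(x))
flax_out = np.asarray(flax_ln.apply(variables, jnp.asarray(x)))

# Report the largest element-wise discrepancy between the two outputs.
print("max abs diff:", np.max(np.abs(torch_out - flax_out)))
```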
-
Very cool that you are porting torch models! 😄 We actually had a recent GitHub Discussion on exactly this issue; did you use the search function? See #1709. As I say there, we are using an approximation for efficiency, but I think we should consider adding a flag so the exact value can be computed.
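For reference, my understanding of the approximation discussed in #1709 is that the variance is computed in one pass as E[x²] − E[x]² rather than in two passes as E[(x − μ)²]. The plain-NumPy sketch below is illustrative only, not the actual library internals, and shows how the two formulas can diverge in float32:

```python
import numpy as np

# Feature vector with a large mean relative to its spread, where the
# one-pass formula loses the most precision in float32.
x = np.random.default_rng(0).normal(loc=100.0, scale=1.0, size=(96,)).astype(np.float32)
eps = 1e-5

mu = x.mean(dtype=np.float32)

# Two-pass variance, E[(x - mu)^2]: subtract the mean first, then average.
var_two_pass = np.mean((x - mu) ** 2, dtype=np.float32)

# One-pass variance, E[x^2] - E[x]^2: cheaper (a single pass over the data),
# but subject to cancellation error in low precision.
var_one_pass = np.mean(x ** 2, dtype=np.float32) - mu ** 2

# Normalize with each variance estimate and compare the results.
y_two_pass = (x - mu) / np.sqrt(var_two_pass + eps)
y_one_pass = (x - mu) / np.sqrt(var_one_pass + eps)

print("variances:", var_two_pass, var_one_pass)
print("max abs diff of normalized outputs:", np.max(np.abs(y_two_pass - y_one_pass)))
```

Small per-layer differences like this can compound once many normalized layers are stacked, which matches the blow-up described in the question.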