Use LoadN for short BF16 dot product inputs instead of scalar conversion
#2703
base: master
Conversation
Use LoadN for short BF16 dot product inputs instead of scalar conversion

This changes the prologue fallback logic for inputs with fewer elements than a full vector's lane count to use partial vector loads with implicit zeroing (via `LoadN`) instead of scalar loads and conversions.

Rationale: Falling back to scalar conversion of `__bf16` to `float` elements will, on GCC 13+, silently generate a call to the `__extendbfsf2` library function per conversion, rather than just zero-extending to 32 bits and shifting left by 16. This both bloats the generated code and has a substantial runtime cost; on x64 it takes more time to process 8 BF16 elements with the scalar fallback path than to process 1024 BF16 elements in the vector path!

Partial vector loads are introduced for the following dot product function overloads:

* `bfloat16_t` vs. `bfloat16_t`
* `float` vs. `bfloat16_t`
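As a rough scalar model of the approach (the helper names below are illustrative, not Highway's API): `LoadN` reads only the first `count` lanes and implicitly zeroes the rest, so the short-input prologue can run the full-width dot product over the padded vector; the zeroed tail lanes contribute nothing to the sum, and no per-element scalar BF16 conversion is needed.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hedged sketch, not Highway code: model LoadN as "copy `count` BF16 lanes
// into a full-width buffer, zero the remaining lanes".
std::vector<uint16_t> LoadNModel(const uint16_t* bf16, size_t count,
                                 size_t lanes) {
  std::vector<uint16_t> v(lanes, 0);  // tail lanes implicitly zero
  std::memcpy(v.data(), bf16, count * sizeof(uint16_t));
  return v;
}

// BF16 -> float: zero-extend to 32 bits, shift left by 16, bit-cast.
float F32FromBF16Bits(uint16_t bits) {
  const uint32_t u = static_cast<uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));  // bit-cast without UB
  return f;
}

// Dot product over the padded vectors; zero lanes convert to 0.0f and
// therefore do not affect the result.
float DotModel(const std::vector<uint16_t>& a,
               const std::vector<uint16_t>& b) {
  float sum = 0.0f;
  for (size_t i = 0; i < a.size(); ++i) {
    sum += F32FromBF16Bits(a[i]) * F32FromBF16Bits(b[i]);
  }
  return sum;
}
```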
Nice, thank you for updating this!
Re-importing so we can see what went wrong with internal CI.

Looks like a compiler crash, sigh. I raised an LLVM issue, linked above.
Changing

```cpp
V sum1 = Zero(df32);
V sum0 = ReorderWidenMulAccumulate(df32, a, b, Zero(df32), sum1);
return ReduceSum(df32, Add(sum0, sum1));
```

to

```cpp
const auto ab_lo = Mul(PromoteLowerTo(df32, a), PromoteLowerTo(df32, b));
const auto ab_hi = Mul(PromoteUpperTo(df32, a), PromoteUpperTo(df32, b));
return ReduceSum(df32, Add(ab_lo, ab_hi));
```

seems to avoid triggering the Clang crash (Godbolt), and should be semantically identical. Performance-wise it also looks very similar to me (after re-running the short-input BF16xBF16 benchmarks from the original comment). If this sounds good (please let me know if there are preferred alternate approaches), I can update this branch with these changes.
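The equivalence claim can be checked with a scalar model (hedged sketch; the function names are mine, not Highway's). With zero initial sums, `ReorderWidenMulAccumulate` + `Add` + `ReduceSum` and the `PromoteLowerTo`/`PromoteUpperTo` + `Mul` formulation both reduce to the same sum of pairwise products; only the grouping of partial sums differs.

```cpp
#include <cstddef>

// Models ReorderWidenMulAccumulate with zero initial sums: two interleaved
// accumulators (sum0/sum1) that are added and reduced at the end.
float ReduceViaAccumulate(const float* a, const float* b, size_t n) {
  float sum0 = 0.0f, sum1 = 0.0f;
  for (size_t i = 0; i + 1 < n; i += 2) {
    sum0 += a[i] * b[i];
    sum1 += a[i + 1] * b[i + 1];
  }
  return sum0 + sum1;
}

// Models Mul(PromoteLowerTo, ...) / Mul(PromoteUpperTo, ...) + Add +
// ReduceSum: products of the lower and upper halves, summed separately.
float ReduceViaHalves(const float* a, const float* b, size_t n) {
  float lo = 0.0f, hi = 0.0f;
  for (size_t i = 0; i < n / 2; ++i) lo += a[i] * b[i];
  for (size_t i = n / 2; i < n; ++i) hi += a[i] * b[i];
  return lo + hi;
}
```

Floating-point grouping differences mean the two can round differently for arbitrary inputs, but with zero initial sums they compute the same mathematical reduction.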
I agree that's semantically equivalent, and could be a viable workaround, but this should really be a two-instruction sequence: BFDOT+ADDV. Perhaps we can put your workaround inside an
Hmm, it looks like this crashes under

The slight irony of the situation is that this is all due to GCC BF16 scalar codegen, so the scalar fallback code under Clang is probably perfectly fine as it is... 🙂
;) This is BTW the third compiler issue we are seeing within a week. It seems LLVM's testing coverage is not yet there. I'd be fine to land your current change (with `ReorderWidenMulAccumulate`).
Clang does not need the `LoadN` fallback and can happily use the scalar code. Coincidentally, this also avoids triggering a Clang crash when `ReorderWidenMulAccumulate` is used alongside capped vectors.
Changes have been made.

Looks like some build failures are happening in the realm of

FWIW, we created an upstream bug report to GCC a few weeks ago about this issue (https://gcc.gnu.org/bugzilla/show_bug.cgi?id=121853). If this behavior remains unchanged, scalar conversion will continue to be a bit of a performance minefield. A sledgehammer option (that I'm also pretty sure would remove the need for this particular PR entirely) is to change

```diff
 HWY_API HWY_BF16_CONSTEXPR float F32FromBF16(bfloat16_t bf) {
-#if HWY_HAVE_SCALAR_BF16_OPERATORS
+#if HWY_HAVE_SCALAR_BF16_OPERATORS && !HWY_COMPILER_GCC_ACTUAL
   return static_cast<float>(bf);
 #else
   return BitCastScalar<float>(static_cast<uint32_t>(
```
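The `#else` branch above works because BF16 is simply the upper 16 bits of an IEEE-754 binary32, so widening is a zero-extend plus a 16-bit left shift. A standalone sketch of that bit manipulation (hedged: `std::memcpy` stands in for Highway's `BitCastScalar`, and the function names are illustrative):

```cpp
#include <cstdint>
#include <cstring>

// BF16 -> float: place the 16 BF16 bits in the top half of a 32-bit word
// and reinterpret as float.
float F32FromBF16Model(uint16_t bf_bits) {
  const uint32_t f_bits = static_cast<uint32_t>(bf_bits) << 16;
  float f;
  std::memcpy(&f, &f_bits, sizeof(f));
  return f;
}

// The (truncating) reverse direction, useful for round-trip checking:
// keep only the top 16 bits of the binary32 representation.
uint16_t BF16FromF32Model(float f) {
  uint32_t f_bits;
  std::memcpy(&f_bits, &f, sizeof(f_bits));
  return static_cast<uint16_t>(f_bits >> 16);
}
```

Values whose binary32 mantissa fits in 7 bits (like 1.0f, 1.5f, -2.0f) round-trip exactly through this pair.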
Thanks for making the change! Re-imported for internal review, hopefully landing soon.
ACK, triggered by clang warning fixes, thanks for the heads up.
Even if this is fixed tomorrow, we'd still have to live with current compilers for several years. I am open to this suggestion. Want to send a separate PR to change the default value of

I have mixed feelings. If it weren't for the clang bug, it would still be better to have vector code for the remainders. As is, it's indeed regrettable to have it enabled only for GCC. Maybe we can keep this open and un-merged until that's fixed, WDYT?
Sounds good to me 👍 |
Resolves issue #2699.
This changes the fallback logic for inputs with fewer elements than a full vector's lane count to use partial vector loads with implicit zeroing (via `LoadN`) instead of scalar loads and conversions.

I think these should be fairly conservative changes, as the main vector loops remain untouched. An alternative approach would be to generalize the final boundary handling code for < `N` remaining elements to use `LoadN` instead of `FirstN` + "adjusted" `LoadU`, which would remove the need for such an initial special case in the first place. Since I don't have access to enough exotic testing hardware to know how/if this would affect performance on non-x64/arm64 systems, I skipped doing that. 🙂

It's possible that the use of `ReorderWidenMulAccumulate` should be replaced with explicit `Promote(Upper|Lower)To` and multiplication, since the input sums will always be zero, but I am not sure how much it matters in practice. This just mirrors what's in the main loops.

Rationale/background:

Falling back to scalar conversion of `__bf16` to `float` elements will, on GCC 13+, silently generate a call to the `__extendbfsf2` library function per conversion, rather than just zero-extending to 32 bits and shifting left by 16. This both bloats the generated code and has a substantial runtime cost; on x64 it takes more time to process 8 BF16 elements with the scalar fallback path than to process 1024 BF16 elements in the vector path!

Partial vector loads are introduced for the following dot product function overloads:

* `bfloat16_t` vs. `bfloat16_t` (used `F32FromBF16`)
* `float` vs. `bfloat16_t` (used `ConvertScalarTo<float>`, which transitively calls `F32FromBF16`)

Some before/after graphs from a Sapphire Rapids system:
`bfloat16_t` vs. `bfloat16_t`:

Before:

After:

`float` vs. `bfloat16_t`:

Before:

After:
