diff --git a/hwy/contrib/dot/dot-inl.h b/hwy/contrib/dot/dot-inl.h index e48796aaa7..10fd897918 100644 --- a/hwy/contrib/dot/dot-inl.h +++ b/hwy/contrib/dot/dot-inl.h @@ -174,9 +174,17 @@ struct Dot { (kAssumptions & kMultipleOfVector) != 0; constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0; - // Won't be able to do a full vector load without padding => scalar loop. + // Won't be able to do a full vector load without padding. Use a scalar + // loop under Clang. GCC has very suboptimal codegen for scalar BF16->float + // conversions, so use vector ops with LoadN instead. + // TODO: https://github.com/google/highway/pull/2703 if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector && HWY_UNLIKELY(num_elements < NF)) { +#if HWY_COMPILER_GCC_ACTUAL + const VF a = LoadN(df, pa, num_elements); + const VF b = PromoteTo(df, LoadN(dbfh, pb, num_elements)); + return ReduceSum(df, Mul(a, b)); +#else // Only 2x unroll to avoid excessive code size. float sum0 = 0.0f; float sum1 = 0.0f; @@ -189,6 +197,7 @@ struct Dot { sum1 += pa[i] * ConvertScalarTo(pb[i]); } return sum0 + sum1; +#endif } // Compiler doesn't make independent sum* accumulators, so unroll manually. @@ -279,9 +288,19 @@ struct Dot { (kAssumptions & kMultipleOfVector) != 0; constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0; - // Won't be able to do a full vector load without padding => scalar loop. + // Won't be able to do a full vector load without padding. Use a scalar + // loop under Clang. GCC has very suboptimal codegen for scalar BF16->float + // conversions, so use vector ops with LoadN instead. + // TODO: https://github.com/google/highway/pull/2703 if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector && HWY_UNLIKELY(num_elements < N)) { +#if HWY_COMPILER_GCC_ACTUAL + const auto a = LoadN(d, pa, num_elements); + const auto b = LoadN(d, pb, num_elements); + V sum1 = Zero(df32); + V sum0 = ReorderWidenMulAccumulate(df32, a, b, Zero(df32), sum1); + return ReduceSum(df32, Add(sum0, sum1)); +#else float sum0 = 0.0f; // Only 2x unroll to avoid excessive code size for.. float sum1 = 0.0f; // this unlikely(?) case. for (; i + 2 <= num_elements; i += 2) { @@ -292,6 +311,7 @@ struct Dot { sum1 += F32FromBF16(pa[i]) * F32FromBF16(pb[i]); } return sum0 + sum1; +#endif } // See comment in the other Compute() overload. Unroll 2x, but we need