Commit 0478c13

Use LoadN for short BF16 dot product inputs instead of scalar conversion
This changes the prologue fallback logic for inputs with fewer elements than a full vector's lane count to use partial vector loads with implicit zeroing (via `LoadN`) instead of scalar loads and conversions.

Rationale: falling back to scalar conversion of `__bf16` to `float` elements will, on GCC 13+, silently generate a call to the `__extendbfsf2` library function per conversion, rather than just zero-extending to 32 bits and shifting left by 16. This both bloats the generated code and has a substantial runtime cost; on x64 it takes more time to process 8 BF16 elements with the scalar fallback path than it takes to process 1024 BF16 elements in the vector path.

Partial vector loads are introduced for the following dot product function overloads:

* `bfloat16_t` vs. `bfloat16_t`
* `float` vs. `bfloat16_t`
1 parent 7f59ca4 commit 0478c13
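For reference, the bit-level conversion the commit message describes (zero-extend the BF16 bits to 32 bits, shift left by 16) is also cheap in scalar code; the cost comes from GCC 13+ lowering the `__bf16` cast to a per-element `__extendbfsf2` library call instead. A minimal scalar sketch of the equivalent bit manipulation (hypothetical helper name, not code from this commit):

#include <cstdint>
#include <cstring>

// BF16 is the upper 16 bits of an IEEE-754 binary32, so BF16 -> float is
// exact: zero-extend to 32 bits and shift the payload into the high half.
static inline float F32FromBF16Bits(uint16_t bits) {
  const uint32_t u32 = static_cast<uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &u32, sizeof(f));  // type-pun via memcpy to avoid UB
  return f;
}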

File tree

1 file changed: +10 −24 lines changed
hwy/contrib/dot/dot-inl.h

Lines changed: 10 additions & 24 deletions
@@ -174,21 +174,12 @@ struct Dot {
         (kAssumptions & kMultipleOfVector) != 0;
     constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0;
 
-    // Won't be able to do a full vector load without padding => scalar loop.
+    // Won't be able to do a full vector load without padding => partial load.
     if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
         HWY_UNLIKELY(num_elements < NF)) {
-      // Only 2x unroll to avoid excessive code size.
-      float sum0 = 0.0f;
-      float sum1 = 0.0f;
-      size_t i = 0;
-      for (; i + 2 <= num_elements; i += 2) {
-        sum0 += pa[i + 0] * ConvertScalarTo<float>(pb[i + 0]);
-        sum1 += pa[i + 1] * ConvertScalarTo<float>(pb[i + 1]);
-      }
-      for (; i < num_elements; ++i) {
-        sum1 += pa[i] * ConvertScalarTo<float>(pb[i]);
-      }
-      return sum0 + sum1;
+      const VF a = LoadN(df, pa, num_elements);
+      const VF b = PromoteTo(df, LoadN(dbfh, pb, num_elements));
+      return ReduceSum(df, Mul(a, b));
     }
 
     // Compiler doesn't make independent sum* accumulators, so unroll manually.
@@ -279,19 +270,14 @@ struct Dot {
         (kAssumptions & kMultipleOfVector) != 0;
     constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0;
 
-    // Won't be able to do a full vector load without padding => scalar loop.
+    // Won't be able to do a full vector load without padding => partial load.
     if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
         HWY_UNLIKELY(num_elements < N)) {
-      float sum0 = 0.0f;  // Only 2x unroll to avoid excessive code size for..
-      float sum1 = 0.0f;  // this unlikely(?) case.
-      for (; i + 2 <= num_elements; i += 2) {
-        sum0 += F32FromBF16(pa[i + 0]) * F32FromBF16(pb[i + 0]);
-        sum1 += F32FromBF16(pa[i + 1]) * F32FromBF16(pb[i + 1]);
-      }
-      if (i < num_elements) {
-        sum1 += F32FromBF16(pa[i]) * F32FromBF16(pb[i]);
-      }
-      return sum0 + sum1;
+      const auto a = LoadN(d, pa, num_elements);
+      const auto b = LoadN(d, pb, num_elements);
+      V sum1 = Zero(df32);
+      V sum0 = ReorderWidenMulAccumulate(df32, a, b, Zero(df32), sum1);
+      return ReduceSum(df32, Add(sum0, sum1));
     }
 
     // See comment in the other Compute() overload. Unroll 2x, but we need
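For readers unfamiliar with the new path: `LoadN` reads only the first `num_elements` lanes and zero-fills the remaining lanes, so the padded lanes contribute nothing to the dot product and no scalar tail loop is needed. A minimal static-dispatch sketch of the `float` vs. `bfloat16_t` case, assuming only the Highway ops used in the diff above (`LoadN`, `PromoteTo`, `Mul`, `ReduceSum`); the function name is hypothetical and this is not the library's exact code:

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Dot product for inputs shorter than one vector: lanes past num_elements
// are zeroed by LoadN, so they add nothing to the sum.
float ShortDotF32BF16(const float* HWY_RESTRICT pa,
                      const hwy::bfloat16_t* HWY_RESTRICT pb,
                      size_t num_elements) {
  const hn::ScalableTag<float> df;
  const hn::Rebind<hwy::bfloat16_t, decltype(df)> dbf;  // same lane count as df
  const auto a = hn::LoadN(df, pa, num_elements);
  const auto b = hn::PromoteTo(df, hn::LoadN(dbf, pb, num_elements));
  return hn::ReduceSum(df, hn::Mul(a, b));
}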
