Skip to content

Commit d9a43a8

Browse files
authored
Add Quantized GEMM kernel for Arm NEON on macOS ARM (#249)
1 parent 4e8805a commit d9a43a8

File tree

6 files changed

+465
-60
lines changed

6 files changed

+465
-60
lines changed

src/ArchAvailable.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ namespace kiwi
5959
static_cast<std::ptrdiff_t>(ArchType::sse4_1)
6060
#endif
6161
#if CPUINFO_ARCH_ARM64
62-
//static_cast<std::ptrdiff_t>(ArchType::neon)
62+
static_cast<std::ptrdiff_t>(ArchType::neon)
6363
#endif
6464
#else
6565
#ifdef KIWI_ARCH_X86_64
@@ -72,7 +72,7 @@ namespace kiwi
7272
static_cast<std::ptrdiff_t>(ArchType::sse4_1)
7373
#endif
7474
#ifdef KIWI_ARCH_ARM64
75-
//static_cast<std::ptrdiff_t>(ArchType::neon)
75+
static_cast<std::ptrdiff_t>(ArchType::neon)
7676
#endif
7777
#endif
7878
>;

src/CoNgramModel.cpp

Lines changed: 122 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
#include <iostream>
1+
#include <iostream>
22
#include <fstream>
3+
#include <cstring>
4+
#include <limits>
35
#include "PathEvaluator.hpp"
46
#include "Joiner.hpp"
57
#include "Kiwi.hpp"
@@ -626,7 +628,8 @@ namespace kiwi
626628
if constexpr (quantized)
627629
{
628630
float scale;
629-
eptr += requantizePackedInts<arch>(optr, scale, eptr, header.dim, header.qbit, header.qgroup, true);
631+
const bool toUint8 = arch != ArchType::neon;
632+
eptr += requantizePackedInts<arch>(optr, scale, eptr, header.dim, header.qbit, header.qgroup, toUint8);
630633
optr += header.dim;
631634
*reinterpret_cast<float*>(optr) = scale;
632635
optr += sizeof(float);
@@ -678,11 +681,22 @@ namespace kiwi
678681

679682
if constexpr (quantized)
680683
{
681-
qgemm::invNormU8<arch>(
682-
header.contextSize, header.dim,
683-
getContextQuantEmb(0), contextEmbStride(),
684-
const_cast<float*>(invNormContextPtr)
685-
);
684+
if constexpr (arch == ArchType::neon)
685+
{
686+
qgemm::invNormS8<arch>(
687+
header.contextSize, header.dim,
688+
getContextQuantEmbS8(0), contextEmbStride(),
689+
const_cast<float*>(invNormContextPtr)
690+
);
691+
}
692+
else
693+
{
694+
qgemm::invNormU8<arch>(
695+
header.contextSize, header.dim,
696+
getContextQuantEmb(0), contextEmbStride(),
697+
const_cast<float*>(invNormContextPtr)
698+
);
699+
}
686700
qgemm::invNormS8<arch>(
687701
header.vocabSize, header.dim,
688702
getOutputQuantEmb(0), outputEmbStride(),
@@ -711,7 +725,8 @@ namespace kiwi
711725
if (quantized)
712726
{
713727
float scale;
714-
eptr += requantizePackedInts<arch>(optr, scale, eptr, header.dim, header.qbit, header.qgroup, true);
728+
const bool toUint8 = arch != ArchType::neon;
729+
eptr += requantizePackedInts<arch>(optr, scale, eptr, header.dim, header.qbit, header.qgroup, toUint8);
715730
optr += header.dim;
716731
*reinterpret_cast<float*>(optr) = scale;
717732
optr += sizeof(float);
@@ -771,6 +786,7 @@ namespace kiwi
771786
eptr += sizeof(uint16_t);
772787
}
773788
}
789+
774790
}
775791

776792
template<ArchType arch, class KeyType, class VlKeyType, size_t windowSize, bool quantized>
@@ -856,13 +872,25 @@ namespace kiwi
856872
{
857873
const auto* contextPtr = getContextQuantEmb(unpackedContextId);
858874
const auto* outputPtr = getOutputQuantEmb(next);
859-
int32_t acc = qgemm::dotprod<arch>(contextPtr, outputPtr, header.dim);
860-
const float contextScale = *reinterpret_cast<const float*>(contextPtr + header.dim),
861-
outputScale = *reinterpret_cast<const float*>(outputPtr + header.dim),
875+
float contextBias;
876+
if constexpr (arch == ArchType::neon)
877+
{
878+
const auto* contextPtrS8 = getContextQuantEmbS8(unpackedContextId);
879+
const auto* contextRaw = reinterpret_cast<const uint8_t*>(contextPtrS8);
880+
const float score = qgemm::dotS8S8<arch>(header.dim, contextPtrS8, outputPtr);
881+
contextBias = *reinterpret_cast<const float*>(contextRaw + header.dim + sizeof(float));
882+
ll = score + contextBias;
883+
}
884+
else
885+
{
886+
int32_t acc = qgemm::dotprod<arch>(contextPtr, outputPtr, header.dim);
887+
const float contextScale = *reinterpret_cast<const float*>(contextPtr + header.dim);
888+
const float outputScale = *reinterpret_cast<const float*>(outputPtr + header.dim);
862889
contextBias = *reinterpret_cast<const float*>(contextPtr + header.dim + sizeof(float));
863-
const int32_t hsum = *reinterpret_cast<const int32_t*>(outputPtr + header.dim + sizeof(float));
864-
acc -= hsum;
865-
ll = acc * contextScale * outputScale + contextBias;
890+
const int32_t hsum = *reinterpret_cast<const int32_t*>(outputPtr + header.dim + sizeof(float));
891+
acc -= hsum;
892+
ll = acc * contextScale * outputScale + contextBias;
893+
}
866894
if (outputEmbBiasPtr) ll += outputEmbBiasPtr[next];
867895
}
868896
else
@@ -2474,11 +2502,24 @@ namespace kiwi
24742502

24752503
if constexpr (quantized)
24762504
{
2477-
qgemm::gemvU8U8<arch>(
2478-
header.contextSize, header.dim,
2479-
getContextQuantEmb(contextId),
2480-
getContextQuantEmb(0), contextEmbStride(),
2481-
scores);
2505+
if constexpr (arch == ArchType::neon)
2506+
{
2507+
qgemm::gemvS8S8<arch>(
2508+
header.contextSize, header.dim,
2509+
getContextQuantEmbS8(contextId),
2510+
getContextQuantEmbS8(0), contextEmbStride(),
2511+
scores
2512+
);
2513+
}
2514+
else
2515+
{
2516+
qgemm::gemvU8U8<arch>(
2517+
header.contextSize, header.dim,
2518+
getContextQuantEmb(contextId),
2519+
getContextQuantEmb(0), contextEmbStride(),
2520+
scores
2521+
);
2522+
}
24822523
}
24832524
else
24842525
{
@@ -2525,10 +2566,20 @@ namespace kiwi
25252566
float result = 0;
25262567
if constexpr (quantized)
25272568
{
2528-
result = qgemm::dotU8U8<arch>(
2529-
header.dim,
2530-
getContextQuantEmb(contextId1), getContextQuantEmb(contextId2)
2531-
);
2569+
if constexpr (arch == ArchType::neon)
2570+
{
2571+
result = qgemm::dotS8S8<arch>(
2572+
header.dim,
2573+
getContextQuantEmbS8(contextId1), getContextQuantEmbS8(contextId2)
2574+
);
2575+
}
2576+
else
2577+
{
2578+
result = qgemm::dotU8U8<arch>(
2579+
header.dim,
2580+
getContextQuantEmb(contextId1), getContextQuantEmb(contextId2)
2581+
);
2582+
}
25322583
}
25332584
else
25342585
{
@@ -2554,12 +2605,24 @@ namespace kiwi
25542605
float* scores = resultBuf.data() + header.vocabSize;
25552606
if constexpr (quantized)
25562607
{
2557-
qgemm::gemv<arch>(
2558-
header.vocabSize, header.dim,
2559-
getContextQuantEmb(contextId),
2560-
getOutputQuantEmb(0), outputEmbStride(),
2561-
scores
2562-
);
2608+
if constexpr (arch == ArchType::neon)
2609+
{
2610+
qgemm::gemvS8S8<arch>(
2611+
header.vocabSize, header.dim,
2612+
getContextQuantEmbS8(contextId),
2613+
getOutputQuantEmb(0), outputEmbStride(),
2614+
scores
2615+
);
2616+
}
2617+
else
2618+
{
2619+
qgemm::gemv<arch>(
2620+
header.vocabSize, header.dim,
2621+
getContextQuantEmb(contextId),
2622+
getOutputQuantEmb(0), outputEmbStride(),
2623+
scores
2624+
);
2625+
}
25632626
}
25642627
else
25652628
{
@@ -2606,18 +2669,36 @@ namespace kiwi
26062669
float* scores = resultBuf.data() + header.vocabSize;
26072670
if constexpr (quantized)
26082671
{
2609-
qgemm::gemv<arch>(
2610-
header.vocabSize, header.dim,
2611-
getContextQuantEmb(bgContextId),
2612-
getOutputQuantEmb(0), outputEmbStride(),
2613-
resultBuf.data()
2614-
);
2615-
qgemm::gemv<arch>(
2616-
header.vocabSize, header.dim,
2617-
getContextQuantEmb(contextId),
2618-
getOutputQuantEmb(0), outputEmbStride(),
2619-
scores
2620-
);
2672+
if constexpr (arch == ArchType::neon)
2673+
{
2674+
qgemm::gemvS8S8<arch>(
2675+
header.vocabSize, header.dim,
2676+
getContextQuantEmbS8(bgContextId),
2677+
getOutputQuantEmb(0), outputEmbStride(),
2678+
resultBuf.data()
2679+
);
2680+
qgemm::gemvS8S8<arch>(
2681+
header.vocabSize, header.dim,
2682+
getContextQuantEmbS8(contextId),
2683+
getOutputQuantEmb(0), outputEmbStride(),
2684+
scores
2685+
);
2686+
}
2687+
else
2688+
{
2689+
qgemm::gemv<arch>(
2690+
header.vocabSize, header.dim,
2691+
getContextQuantEmb(bgContextId),
2692+
getOutputQuantEmb(0), outputEmbStride(),
2693+
resultBuf.data()
2694+
);
2695+
qgemm::gemv<arch>(
2696+
header.vocabSize, header.dim,
2697+
getContextQuantEmb(contextId),
2698+
getOutputQuantEmb(0), outputEmbStride(),
2699+
scores
2700+
);
2701+
}
26212702
}
26222703
else
26232704
{

src/CoNgramModel.hpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ namespace kiwi
5151
const uint8_t* alignedKeyValueData = nullptr;
5252
std::unique_ptr<int32_t[]> allRootValueData;
5353
std::unique_ptr<uint8_t[]> allEmbs;
54-
const uint8_t* contextEmbPtr = nullptr; // [numContexts, (dim + scale? + bias + confid + vts)]
54+
const uint8_t* contextEmbPtr = nullptr; // [numContexts, (dim + scale? + bias + confid + vts)] (quantized NEON: dim stores S8 values)
5555
const uint8_t* outputEmbPtr = nullptr; // [numOutputs, (dim + scale? + sum?)]
5656
const uint8_t* distantEmbPtr = nullptr; // [numOutputs, (dim + scale? + bias + confid + pad?)]
5757
const float* positionConfidPtr = nullptr;
@@ -109,11 +109,16 @@ namespace kiwi
109109
return reinterpret_cast<const float*>(contextEmbPtr + idx * contextEmbStride());
110110
}
111111

112-
inline const uint8_t* getContextQuantEmb(uint32_t idx) const
112+
inline const uint8_t* getContextQuantEmb(size_t idx) const
113113
{
114114
return contextEmbPtr + idx * contextEmbStride();
115115
}
116116

117+
inline const int8_t* getContextQuantEmbS8(size_t idx) const
118+
{
119+
return reinterpret_cast<const int8_t*>(contextEmbPtr + idx * contextEmbStride());
120+
}
121+
117122
inline float getContextBias(uint32_t idx) const
118123
{
119124
const size_t offset = quantized ?

src/SIMD.hpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -896,11 +896,18 @@ namespace kiwi
896896
static STRONG_INLINE int32_t dotprod(const uint8_t* a, const int8_t* b, size_t size)
897897
{
898898
int32x4_t sum = vdupq_n_s32(0);
899-
uint16x8_t pa;
900-
int8x16_t pb;
901899
for (size_t i = 0; i < size; i += 16)
902900
{
903-
//
901+
uint8x16_t pa = vld1q_u8(a + i);
902+
int8x16_t pb = vld1q_s8(b + i);
903+
// Extend a (uint8, 0-255) to int16 via zero-extend, b (int8) via sign-extend
904+
// Product range: 255*(-128) to 255*127 = [-32640, 32385], fits in int16
905+
int16x8_t pa_lo = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(pa)));
906+
int16x8_t pa_hi = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(pa)));
907+
int16x8_t pb_lo = vmovl_s8(vget_low_s8(pb));
908+
int16x8_t pb_hi = vmovl_s8(vget_high_s8(pb));
909+
sum = vpadalq_s16(sum, vmulq_s16(pa_lo, pb_lo));
910+
sum = vpadalq_s16(sum, vmulq_s16(pa_hi, pb_hi));
904911
}
905912
sum = vpaddq_s32(sum, sum);
906913
sum = vpaddq_s32(sum, sum);

0 commit comments

Comments
 (0)