1- #include < iostream>
1+ #include < iostream>
22#include < fstream>
3+ #include < cstring>
4+ #include < limits>
35#include " PathEvaluator.hpp"
46#include " Joiner.hpp"
57#include " Kiwi.hpp"
@@ -626,7 +628,8 @@ namespace kiwi
626628 if constexpr (quantized)
627629 {
628630 float scale;
629- eptr += requantizePackedInts<arch>(optr, scale, eptr, header.dim , header.qbit , header.qgroup , true );
631+ const bool toUint8 = arch != ArchType::neon;
632+ eptr += requantizePackedInts<arch>(optr, scale, eptr, header.dim , header.qbit , header.qgroup , toUint8);
630633 optr += header.dim ;
631634 *reinterpret_cast <float *>(optr) = scale;
632635 optr += sizeof (float );
@@ -678,11 +681,22 @@ namespace kiwi
678681
679682 if constexpr (quantized)
680683 {
681- qgemm::invNormU8<arch>(
682- header.contextSize , header.dim ,
683- getContextQuantEmb (0 ), contextEmbStride (),
684- const_cast <float *>(invNormContextPtr)
685- );
684+ if constexpr (arch == ArchType::neon)
685+ {
686+ qgemm::invNormS8<arch>(
687+ header.contextSize , header.dim ,
688+ getContextQuantEmbS8 (0 ), contextEmbStride (),
689+ const_cast <float *>(invNormContextPtr)
690+ );
691+ }
692+ else
693+ {
694+ qgemm::invNormU8<arch>(
695+ header.contextSize , header.dim ,
696+ getContextQuantEmb (0 ), contextEmbStride (),
697+ const_cast <float *>(invNormContextPtr)
698+ );
699+ }
686700 qgemm::invNormS8<arch>(
687701 header.vocabSize , header.dim ,
688702 getOutputQuantEmb (0 ), outputEmbStride (),
@@ -711,7 +725,8 @@ namespace kiwi
711725 if (quantized)
712726 {
713727 float scale;
714- eptr += requantizePackedInts<arch>(optr, scale, eptr, header.dim , header.qbit , header.qgroup , true );
728+ const bool toUint8 = arch != ArchType::neon;
729+ eptr += requantizePackedInts<arch>(optr, scale, eptr, header.dim , header.qbit , header.qgroup , toUint8);
715730 optr += header.dim ;
716731 *reinterpret_cast <float *>(optr) = scale;
717732 optr += sizeof (float );
@@ -771,6 +786,7 @@ namespace kiwi
771786 eptr += sizeof (uint16_t );
772787 }
773788 }
789+
774790 }
775791
776792 template <ArchType arch, class KeyType , class VlKeyType , size_t windowSize, bool quantized>
@@ -856,13 +872,25 @@ namespace kiwi
856872 {
857873 const auto * contextPtr = getContextQuantEmb (unpackedContextId);
858874 const auto * outputPtr = getOutputQuantEmb (next);
859- int32_t acc = qgemm::dotprod<arch>(contextPtr, outputPtr, header.dim );
860- const float contextScale = *reinterpret_cast <const float *>(contextPtr + header.dim ),
861- outputScale = *reinterpret_cast <const float *>(outputPtr + header.dim ),
875+ float contextBias;
876+ if constexpr (arch == ArchType::neon)
877+ {
878+ const auto * contextPtrS8 = getContextQuantEmbS8 (unpackedContextId);
879+ const auto * contextRaw = reinterpret_cast <const uint8_t *>(contextPtrS8);
880+ const float score = qgemm::dotS8S8<arch>(header.dim , contextPtrS8, outputPtr);
881+ contextBias = *reinterpret_cast <const float *>(contextRaw + header.dim + sizeof (float ));
882+ ll = score + contextBias;
883+ }
884+ else
885+ {
886+ int32_t acc = qgemm::dotprod<arch>(contextPtr, outputPtr, header.dim );
887+ const float contextScale = *reinterpret_cast <const float *>(contextPtr + header.dim );
888+ const float outputScale = *reinterpret_cast <const float *>(outputPtr + header.dim );
862889 contextBias = *reinterpret_cast <const float *>(contextPtr + header.dim + sizeof (float ));
863- const int32_t hsum = *reinterpret_cast <const int32_t *>(outputPtr + header.dim + sizeof (float ));
864- acc -= hsum;
865- ll = acc * contextScale * outputScale + contextBias;
890+ const int32_t hsum = *reinterpret_cast <const int32_t *>(outputPtr + header.dim + sizeof (float ));
891+ acc -= hsum;
892+ ll = acc * contextScale * outputScale + contextBias;
893+ }
866894 if (outputEmbBiasPtr) ll += outputEmbBiasPtr[next];
867895 }
868896 else
@@ -2474,11 +2502,24 @@ namespace kiwi
24742502
24752503 if constexpr (quantized)
24762504 {
2477- qgemm::gemvU8U8<arch>(
2478- header.contextSize , header.dim ,
2479- getContextQuantEmb (contextId),
2480- getContextQuantEmb (0 ), contextEmbStride (),
2481- scores);
2505+ if constexpr (arch == ArchType::neon)
2506+ {
2507+ qgemm::gemvS8S8<arch>(
2508+ header.contextSize , header.dim ,
2509+ getContextQuantEmbS8 (contextId),
2510+ getContextQuantEmbS8 (0 ), contextEmbStride (),
2511+ scores
2512+ );
2513+ }
2514+ else
2515+ {
2516+ qgemm::gemvU8U8<arch>(
2517+ header.contextSize , header.dim ,
2518+ getContextQuantEmb (contextId),
2519+ getContextQuantEmb (0 ), contextEmbStride (),
2520+ scores
2521+ );
2522+ }
24822523 }
24832524 else
24842525 {
@@ -2525,10 +2566,20 @@ namespace kiwi
25252566 float result = 0 ;
25262567 if constexpr (quantized)
25272568 {
2528- result = qgemm::dotU8U8<arch>(
2529- header.dim ,
2530- getContextQuantEmb (contextId1), getContextQuantEmb (contextId2)
2531- );
2569+ if constexpr (arch == ArchType::neon)
2570+ {
2571+ result = qgemm::dotS8S8<arch>(
2572+ header.dim ,
2573+ getContextQuantEmbS8 (contextId1), getContextQuantEmbS8 (contextId2)
2574+ );
2575+ }
2576+ else
2577+ {
2578+ result = qgemm::dotU8U8<arch>(
2579+ header.dim ,
2580+ getContextQuantEmb (contextId1), getContextQuantEmb (contextId2)
2581+ );
2582+ }
25322583 }
25332584 else
25342585 {
@@ -2554,12 +2605,24 @@ namespace kiwi
25542605 float * scores = resultBuf.data () + header.vocabSize ;
25552606 if constexpr (quantized)
25562607 {
2557- qgemm::gemv<arch>(
2558- header.vocabSize , header.dim ,
2559- getContextQuantEmb (contextId),
2560- getOutputQuantEmb (0 ), outputEmbStride (),
2561- scores
2562- );
2608+ if constexpr (arch == ArchType::neon)
2609+ {
2610+ qgemm::gemvS8S8<arch>(
2611+ header.vocabSize , header.dim ,
2612+ getContextQuantEmbS8 (contextId),
2613+ getOutputQuantEmb (0 ), outputEmbStride (),
2614+ scores
2615+ );
2616+ }
2617+ else
2618+ {
2619+ qgemm::gemv<arch>(
2620+ header.vocabSize , header.dim ,
2621+ getContextQuantEmb (contextId),
2622+ getOutputQuantEmb (0 ), outputEmbStride (),
2623+ scores
2624+ );
2625+ }
25632626 }
25642627 else
25652628 {
@@ -2606,18 +2669,36 @@ namespace kiwi
26062669 float * scores = resultBuf.data () + header.vocabSize ;
26072670 if constexpr (quantized)
26082671 {
2609- qgemm::gemv<arch>(
2610- header.vocabSize , header.dim ,
2611- getContextQuantEmb (bgContextId),
2612- getOutputQuantEmb (0 ), outputEmbStride (),
2613- resultBuf.data ()
2614- );
2615- qgemm::gemv<arch>(
2616- header.vocabSize , header.dim ,
2617- getContextQuantEmb (contextId),
2618- getOutputQuantEmb (0 ), outputEmbStride (),
2619- scores
2620- );
2672+ if constexpr (arch == ArchType::neon)
2673+ {
2674+ qgemm::gemvS8S8<arch>(
2675+ header.vocabSize , header.dim ,
2676+ getContextQuantEmbS8 (bgContextId),
2677+ getOutputQuantEmb (0 ), outputEmbStride (),
2678+ resultBuf.data ()
2679+ );
2680+ qgemm::gemvS8S8<arch>(
2681+ header.vocabSize , header.dim ,
2682+ getContextQuantEmbS8 (contextId),
2683+ getOutputQuantEmb (0 ), outputEmbStride (),
2684+ scores
2685+ );
2686+ }
2687+ else
2688+ {
2689+ qgemm::gemv<arch>(
2690+ header.vocabSize , header.dim ,
2691+ getContextQuantEmb (bgContextId),
2692+ getOutputQuantEmb (0 ), outputEmbStride (),
2693+ resultBuf.data ()
2694+ );
2695+ qgemm::gemv<arch>(
2696+ header.vocabSize , header.dim ,
2697+ getContextQuantEmb (contextId),
2698+ getOutputQuantEmb (0 ), outputEmbStride (),
2699+ scores
2700+ );
2701+ }
26212702 }
26222703 else
26232704 {
0 commit comments