diff --git a/.bazelrc b/.bazelrc index fac41f007..e91f3544e 100644 --- a/.bazelrc +++ b/.bazelrc @@ -126,7 +126,7 @@ build:arm --host_action_env TF_NEED_CUDA="0" build:arm --crosstool_top=@bazel_tools//tools/cpp:toolchain build:arm --host_crosstool_top=@bazel_tools//tools/cpp:toolchain build:arm --copt="-DUSING_CUDA=0" -build:arm --copt="-D_GLIBCXX_USE_CXX11_ABI=0" +build:arm --copt="-D_GLIBCXX_USE_CXX11_ABI=1" build:arm --define=xft_use_icx=false build:arm --copt=-Wno-tautological-compare build:arm --copt=-Wno-array-bounds # aios diff --git a/open_source/deps/BUILD b/open_source/deps/BUILD index 1c761d3e4..faf07449c 100644 --- a/open_source/deps/BUILD +++ b/open_source/deps/BUILD @@ -53,3 +53,12 @@ compile_pip_requirements( requirements_txt = "requirements_lock_rocm.txt", tags = ["manual"], ) + +compile_pip_requirements( + name = "requirements_cpu_arm", + src = "requirements_cpu_arm.txt", + extra_args = PIP_EXTRA_ARGS, + extra_data = ["//open_source/deps:requirements_base.txt"], + requirements_txt = "requirements_lock_torch_arm.txt", + tags = ["manual"], +) \ No newline at end of file diff --git a/open_source/deps/http.bzl b/open_source/deps/http.bzl index b71bca414..1ddcd0308 100644 --- a/open_source/deps/http.bzl +++ b/open_source/deps/http.bzl @@ -82,9 +82,9 @@ def http_deps(): http_archive( name = "torch_2.3_py310_cpu_aarch64", - sha256 = "bef6996c27d8f6e92ea4e13a772d89611da0e103b48790de78131e308cf73076", + sha256 = "90832f4d118c566b8652a2196ac695fc1f14cf420db27b5a1b41c7eaaf2141e9", urls = [ - "https://download.pytorch.org/whl/cpu/torch-2.1.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl#sha256=bef6996c27d8f6e92ea4e13a772d89611da0e103b48790de78131e308cf73076" + "https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp310-cp310-manylinux_2_28_aarch64.whl#sha256=90832f4d118c566b8652a2196ac695fc1f14cf420db27b5a1b41c7eaaf2141e9" ], type = "zip", build_file = clean_dep("//:BUILD.pytorch"), diff --git a/open_source/deps/requirements_cpu_arm.txt b/open_source/deps/requirements_cpu_arm.txt index d5cb56188..7606a3303 100644 --- a/open_source/deps/requirements_cpu_arm.txt +++ b/open_source/deps/requirements_cpu_arm.txt @@ -1,2 +1,3 @@ -https://download.pytorch.org/whl/cpu/torch-2.1.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl#sha256=bef6996c27d8f6e92ea4e13a772d89611da0e103b48790de78131e308cf73076 ; platform_machine == "aarch64" +https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp310-cp310-manylinux_2_28_aarch64.whl#sha256=90832f4d118c566b8652a2196ac695fc1f14cf420db27b5a1b41c7eaaf2141e9 ; platform_machine == "aarch64" +https://download.pytorch.org/whl/cpu/torchvision-0.21.0-cp310-cp310-linux_aarch64.whl#sha256=54815e0a56dde95cc6ec952577f67e0dc151eadd928e8d9f6a7f821d69a4a734 ; platform_machine == "aarch64" -r ../../open_source/deps/requirements_base.txt diff --git a/open_source/deps/requirements_lock_torch_arm.txt b/open_source/deps/requirements_lock_torch_arm.txt index cc5433de2..0e033e100 100644 --- a/open_source/deps/requirements_lock_torch_arm.txt +++ b/open_source/deps/requirements_lock_torch_arm.txt @@ -2,17 +2,14 @@ # This file is autogenerated by pip-compile with Python 3.10 # by the following command: # -# bazel run //internal_source/deps:requirements_cpu_arm.update +# bazel run //open_source/deps:requirements_cpu_arm.update # ---index-url http://artlab.alibaba-inc.com/1/pypi/rtp_diffusion ---extra-index-url https://artlab.alibaba-inc.com/1/pypi/huiwa_rtp_internal ---extra-index-url https://artlab.alibaba-inc.com/1/PYPI/simple/ ---trusted-host 
artlab.alibaba-inc.com +--extra-index-url https://mirrors.aliyun.com/pypi/simple/ accelerate==0.25.0 \ --hash=sha256:c7bb817eb974bba0ff3ea1ba0f24d55afb86d50e3d4fe98d6922dc69cf2ccff1 \ --hash=sha256:ecf55b0ab278a1dac8539dde0d276977aff04683f07ede73eaf02478538576a1 - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt aiohappyeyeballs==2.4.0 \ --hash=sha256:55a1714f084e63d49639800f95716da97a1f173d46a16dfcfda0016abb93b6b2 \ --hash=sha256:7ce92076e249169a13c2f49320d1967425eaf1f407522d707d59cac7628d62bd @@ -110,7 +107,7 @@ aiohttp==3.10.5 \ --hash=sha256:f8112fb501b1e0567a1251a2fd0747baae60a4ab325a871e975b7bb67e59221f \ --hash=sha256:fd31f176429cecbc1ba499d4aba31aaccfea488f418d60376b911269d3b883c5 # via - # -r internal_source/deps/../../open_source/deps/requirements_base.txt + # -r open_source/deps/../../open_source/deps/requirements_base.txt # dashscope aiosignal==1.3.1 \ --hash=sha256:54cd96e15e1649b75d6c87526a6ff0b6c1b0dd3459f43d9ca11d48c339b68cfc \ @@ -192,7 +189,11 @@ av==13.1.0 \ --hash=sha256:f422360f801a6f878d73aee4d404110ee6bb8f04846bf8815edb218da83bec49 \ --hash=sha256:fa398f0e0579bdeca4f0c31eb46e88c29562988e135e44972f73bb7525d1454e \ --hash=sha256:fc5118f78ee712b2c396f345e4c51e60e61e28f1f606adbd4060c4dc44b0b652 - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt +bitsandbytes==0.42.0 \ + --hash=sha256:63798680912cc63bb77b535a2d0860af024e290a52e157f777ad2a52e2585967 \ + --hash=sha256:fc1505f184f0d275766f2a6c663f1a43b734c1409b5c5a406f3a6073d9f329fd + # via -r open_source/deps/../../open_source/deps/requirements_base.txt certifi==2024.8.30 \ --hash=sha256:922820b53db7a7257ffbda3f597266d435245903d80737e34f8a45ff3e3230d8 \ --hash=sha256:bec941d2aa8195e248a60b31ff9f0558284cf01a52591ceda73ea9afffd69fd9 @@ -367,6 +368,10 @@ click==8.1.7 \ --hash=sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28 \ --hash=sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de # via uvicorn +concurrent-log-handler==0.9.25 \ + --hash=sha256:157bee12914aa2a72246d1d0641ce07c1aa7a55faa3322bed02f21e60395eb82 \ + --hash=sha256:1e2c6f021414e214d3dac66107894827a3e78db63018304a4f29e55ba549ac22 + # via -r open_source/deps/../../open_source/deps/requirements_base.txt contourpy==1.3.0 \ --hash=sha256:00ccd0dbaad6d804ab259820fa7cb0b8036bda0686ef844d24125d8287178ce0 \ --hash=sha256:0be4d8425bfa755e0fd76ee1e019636ccc7c29f77a7c86b4328a9eb6a26d0639 \ @@ -436,7 +441,7 @@ contourpy==1.3.0 \ # via matplotlib cpm-kernels==1.0.11 \ --hash=sha256:eab7f211f3b3f6a0686ded4c15cd7d9158393cdf69a931fa5b96a5fbcd366822 - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt crcmod==1.7 \ --hash=sha256:dc7051a0db5f2bd48665a990d3ec1cc305a466a77358ca4492826f41f283601e # via oss2 @@ -482,10 +487,10 @@ cycler==0.12.1 \ # via matplotlib dacite==1.8.1 \ --hash=sha256:cc31ad6fdea1f49962ea42db9421772afe01ac5442380d9a99fcf3d188c61afe - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt dashscope==1.20.10 \ --hash=sha256:174df27ea798a7cd01b7cea710e149f1e6ca43f3e1a86598440fab37bcef31b4 - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt 
decorator==5.1.1 \ --hash=sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330 \ --hash=sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186 @@ -497,7 +502,7 @@ distro==1.9.0 \ einops==0.8.0 \ --hash=sha256:63486517fed345712a8385c100cb279108d9d47e6ae59099b07657e983deae85 \ --hash=sha256:9572fb63046264a862693b0a87088af3bdc8c068fde03de63453cbbde245465f - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt exceptiongroup==1.2.2 \ --hash=sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b \ --hash=sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc @@ -505,12 +510,12 @@ exceptiongroup==1.2.2 \ fastapi==0.115.6 \ --hash=sha256:9ec46f7addc14ea472958a96aae5b5de65f39721a46aaf5705c480d9a8b76654 \ --hash=sha256:e9240b29e36fa8f4bb7290316988e90c381e5092e0cbe84e7818cc3713bcf305 - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt filelock==3.13.1 \ --hash=sha256:521f5f56c50f8426f5e03ad3b281b490a87ef15bc6c526f168290f0c7148d44e \ --hash=sha256:57dbda9b35157b05fb3e58ee91448612eb674172fab98ee235ccb0b5bee19a1c # via - # -r internal_source/deps/../../open_source/deps/requirements_base.txt + # -r open_source/deps/../../open_source/deps/requirements_base.txt # huggingface-hub # torch # transformers @@ -701,7 +706,7 @@ grpcio==1.62.0 \ --hash=sha256:fc2836cb829895ee190813446dce63df67e6ed7b9bf76060262c55fcd097d270 \ --hash=sha256:fcc98cff4084467839d0a20d16abc2a76005f3d1b38062464d088c07f500d170 # via - # -r internal_source/deps/../../open_source/deps/requirements_base.txt + # -r open_source/deps/../../open_source/deps/requirements_base.txt # grpcio-tools grpcio-tools==1.57.0 \ --hash=sha256:02d78c034109f46032c7217260066d49d41e6bcaf588fa28fa40fe2f83445347 \ @@ -749,7 +754,7 @@ grpcio-tools==1.57.0 \ --hash=sha256:f64f8ab22d27d4a5693310748d35a696061c3b5c7b8c4fb4ab3b4bc1068b6b56 \ --hash=sha256:f717cce5093e6b6049d9ea6d12fdf3658efdb1a80772f7737db1f8510b876df6 \ --hash=sha256:fb81ff861692111fa81bd85f64584e624cb4013bd66fbce8a209b8893f5ce398 - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt h11==0.14.0 \ --hash=sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d \ --hash=sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761 @@ -784,15 +789,15 @@ idna==3.10 \ importlib-metadata==8.5.0 \ --hash=sha256:45e54197d28b7a7f1559e60b95e7c567032b602131fbd588f1497f47880aa68b \ --hash=sha256:71522656f0abace1d072b9e5481a48f07c138e00f079c38c8f883823f9c26bd7 - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt jieba==0.42.1 \ --hash=sha256:055ca12f62674fafed09427f176506079bc135638a14e23e25be909131928db2 - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt jinja2==3.1.4 \ --hash=sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369 \ --hash=sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d # via - # -r internal_source/deps/../../open_source/deps/requirements_base.txt + # -r open_source/deps/../../open_source/deps/requirements_base.txt # torch jiter==0.5.0 \ 
--hash=sha256:044f2f1148b5248ad2c8c3afb43430dccf676c5a5834d2f5089a4e6c5bbd64df \ @@ -870,7 +875,7 @@ joblib==1.4.2 \ json5==0.9.25 \ --hash=sha256:34ed7d834b1341a86987ed52f3f76cd8ee184394906b6e22a1e0deb9ab294e8f \ --hash=sha256:548e41b9be043f9426776f05df8635a00fe06104ea51ed24b67f908856e151ae - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt kiwisolver==1.4.7 \ --hash=sha256:073a36c8273647592ea332e816e75ef8da5c303236ec0167196793eb1e34657a \ --hash=sha256:08471d4d86cbaec61f86b217dd938a83d85e03785f51121e791a6e6689a3be95 \ @@ -994,7 +999,7 @@ lazy-loader==0.4 \ librosa==0.10.2.post1 \ --hash=sha256:cd99f16717cbcd1e0983e37308d1db46a6f7dfc2e396e5a9e61e6821e44bd2e7 \ --hash=sha256:dc882750e8b577a63039f25661b7e39ec4cfbacc99c1cffba666cd664fb0a7a0 - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt llvmlite==0.43.0 \ --hash=sha256:14f0e4bf2fd2d9a75a3534111e8ebeb08eda2f33e9bdd6dfa13282afacdde0ed \ --hash=sha256:18e9953c748b105668487b7c81a3e97b046d8abf95c4ddc0cd3c94f4e4651ae8 \ @@ -1100,7 +1105,7 @@ lru-dict==1.3.0 \ --hash=sha256:f27c078b5d75989952acbf9b77e14c3dadc468a4aafe85174d548afbc5efc38b \ --hash=sha256:f5b88a7c39e307739a3701194993455968fcffe437d1facab93546b1b8a334c1 \ --hash=sha256:f8f7824db5a64581180ab9d09842e6dd9fcdc46aac9cb592a0807cd37ea55680 - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt markupsafe==2.1.5 \ --hash=sha256:00e046b6dd71aa03a41079792f8473dc494d564611a8f89bbbd7cb93295ebdcf \ --hash=sha256:075202fa5b72c86ad32dc7d0b56024ebdbcf2048c0ba09f1cde31bfdd57bcfff \ @@ -1204,7 +1209,7 @@ matplotlib==3.9.2 \ --hash=sha256:f32c7410c7f246838a77d6d1eff0c0f87f3cb0e7c4247aebea71a6d5a68cab49 \ --hash=sha256:f6ee45bc4245533111ced13f1f2cace1e7f89d1c793390392a80c139d6cf0e6c \ --hash=sha256:f7c0410f181a531ec4e93bbc27692f2c71a15c2da16766f5ba9761e7ae518413 - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt mpmath==1.3.0 \ --hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \ --hash=sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c @@ -1374,7 +1379,7 @@ multidict==6.1.0 \ nest-asyncio==1.6.0 \ --hash=sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe \ --hash=sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt networkx==3.3 \ --hash=sha256:0c127d8b2f4865f59ae9cb8aafcd60b5c70f3241ebd66f7defad7c4ab90126c9 \ --hash=sha256:28575580c6ebdaf4505b22c6256a2b9de86b316dc63ba9e93abde3d78dfdbcf2 @@ -1432,7 +1437,7 @@ numpy==1.24.1 \ --hash=sha256:ef85cf1f693c88c1fd229ccd1055570cb41cdf4875873b7728b6301f12cd05bf \ --hash=sha256:f1b739841821968798947d3afcefd386fa56da0caf97722a5de53e07c4ccedc7 # via - # -r internal_source/deps/../../open_source/deps/requirements_base.txt + # -r open_source/deps/../../open_source/deps/requirements_base.txt # accelerate # contourpy # librosa @@ -1478,11 +1483,11 @@ onnx==1.16.0 \ --hash=sha256:e5752bbbd5717304a7643643dba383a2fb31e8eb0682f4e7b7d141206328a73b \ --hash=sha256:ec22a43d74eb1f2303373e2fbe7fbcaa45fb225f4eb146edfed1356ada7a9aea \ 
--hash=sha256:f51179d4af3372b4f3800c558d204b592c61e4b4a18b8f61e0eea7f46211221a - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt openai==1.46.0 \ --hash=sha256:0c5a783530d7cd90e2370dbd52d9239d2d53dc7a0badf9ee1e2e23d3f148969b \ --hash=sha256:8e423690b121d0268c7bb83b552e14f339b0ba250e1d0f70d145c194e79c4e1b - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt orjson==3.10.7 \ --hash=sha256:084e537806b458911137f76097e53ce7bf5806dda33ddf6aaa66a028f8d43a23 \ --hash=sha256:09b2d92fd95ad2402188cf51573acde57eb269eddabaa60f69ea0d733e789fe9 \ @@ -1541,10 +1546,10 @@ orjson==3.10.7 \ --hash=sha256:eef44224729e9525d5261cc8d28d6b11cafc90e6bd0be2157bde69a52ec83024 \ --hash=sha256:f4db56635b58cd1a200b0a23744ff44206ee6aa428185e2b6c4a65b3197abdcd \ --hash=sha256:fdf5197a21dd660cf19dfd2a3ce79574588f8f5e2dbf21bda9ee2d2b46924d84 - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt oss2==2.19.0 \ --hash=sha256:9ca54a7921f32f32651a36f2a527bf45e03bb02f3a744877e30f1e842b0f2a0b - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt packaging==24.1 \ --hash=sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002 \ --hash=sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124 @@ -1637,7 +1642,7 @@ pillow==10.4.0 \ --hash=sha256:ff25afb18123cea58a591ea0244b92eb1e61a1fd497bf6d6384f09bc3262ec3e \ --hash=sha256:ff337c552345e95702c5fde3158acb0625111017d0e5f24bf3acdb9cc16b90d1 # via - # -r internal_source/deps/../../open_source/deps/requirements_base.txt + # -r open_source/deps/../../open_source/deps/requirements_base.txt # matplotlib # pillow-heif # sentence-transformers @@ -1718,7 +1723,7 @@ pillow-avif-plugin==1.4.6 \ --hash=sha256:fdd6ee615d948a2b68fd293f74a1a73d22e9d075f5d714b95a90ec2cb8da8de0 \ --hash=sha256:fe06bdb3ec104f5e1b8c03a0bfb22e3d23b4e94591ae73caec1940ce54eccc67 \ --hash=sha256:fe32db84ba0c9d9b364e2b36a55620d6112133107f82854f19a4fdaa93fce66b - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt pillow-heif==0.20.0 \ --hash=sha256:039f0c82ab3c0b364947979583d53ec9aad42d22159b9497e3c20ddde92c99bd \ --hash=sha256:0919f7738b886ed88367b9d0247132b1cbe5d40411bac5d7536d1876980af23e \ @@ -1766,7 +1771,7 @@ pillow-heif==0.20.0 \ --hash=sha256:ef2ad418f42adc9ef5d5e709547e799fb32141543856cb14f04fa4b22f83bfd7 \ --hash=sha256:f446a78a9d84ef75761638a7e72a477aadeffb282ac70ffe67360a98d54775b1 \ --hash=sha256:f9430a33f69965d067be7e5c15dc70f1e43d5e3c8b5e9dc16c8c8d52179ce1cc - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt platformdirs==4.3.6 \ --hash=sha256:357fb2acbc885b0419afd3ce3ed34564c13c9b95c89360cd9563f73aa5e2b907 \ --hash=sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb @@ -1775,10 +1780,16 @@ pooch==1.8.2 \ --hash=sha256:3529a57096f7198778a5ceefd5ac3ef0e4d06a6ddaf9fc2d609b806f25302c47 \ --hash=sha256:76561f0de68a01da4df6af38e9955c4c9d1a5c90da73f7e40276a5728ec83d10 # via librosa +portalocker==3.1.1 \ + --hash=sha256:80e984e24de292ff258a5bea0e4f3f778fff84c0ae1275dbaebc4658de4aacb3 \ + 
--hash=sha256:ec20f6dda2ad9ce89fa399a5f31f4f1495f515958f0cb7ca6543cef7bb5a749e + # via + # -r open_source/deps/../../open_source/deps/requirements_base.txt + # concurrent-log-handler prettytable==3.11.0 \ --hash=sha256:7e23ca1e68bbfd06ba8de98bf553bf3493264c96d5e8a615c0471025deeba722 \ --hash=sha256:aa17083feb6c71da11a68b2c213b04675c4af4ce9c541762632ca3f2cb3546dd - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt protobuf==4.25.0 \ --hash=sha256:1a3ba712877e6d37013cdc3476040ea1e313a6c2e1580836a94f76b3c176d575 \ --hash=sha256:1a53d6f64b00eecf53b65ff4a8c23dc95df1fa1e97bb06b8122e5a64f49fc90a \ @@ -1792,7 +1803,7 @@ protobuf==4.25.0 \ --hash=sha256:cf21faba64cd2c9a3ed92b7a67f226296b10159dbb8fbc5e854fc90657d908e4 \ --hash=sha256:d94a33db8b7ddbd0af7c467475fb9fde0c705fb315a8433c0e2020942b863a1f # via - # -r internal_source/deps/../../open_source/deps/requirements_base.txt + # -r open_source/deps/../../open_source/deps/requirements_base.txt # grpcio-tools # onnx psutil==6.0.0 \ @@ -1814,7 +1825,7 @@ psutil==6.0.0 \ --hash=sha256:fc8c9510cde0146432bbdb433322861ee8c3efbf8589865c8bf8d21cb30c4d14 \ --hash=sha256:ffe7fc9b6b36beadc8c322f84e1caff51e8703b88eee1da46d1e3a6ae11b4fd0 # via - # -r internal_source/deps/../../open_source/deps/requirements_base.txt + # -r open_source/deps/../../open_source/deps/requirements_base.txt # accelerate py-spy==0.3.14 \ --hash=sha256:3e8e48032e71c94c3dd51694c39e762e4bbfec250df5bf514adcdd64e79371e0 \ @@ -1824,7 +1835,7 @@ py-spy==0.3.14 \ --hash=sha256:f59b0b52e56ba9566305236375e6fc68888261d0d36b5addbe3cf85affbefc0e \ --hash=sha256:fd6211fe7f587b3532ba9d300784326d9a6f2b890af7bf6fff21a029ebbc812b \ --hash=sha256:fe7efe6c91f723442259d428bf1f9ddb9c1679828866b353d539345ca40d9dd2 - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt pyarrow==17.0.0 \ --hash=sha256:0071ce35788c6f9077ff9ecba4858108eebe2ea5a3f7cf2cf55ebc1dbc6ee24a \ --hash=sha256:02dae06ce212d8b3244dd3e7d12d9c4d3046945a5933d28026598e9dbbda1fca \ @@ -1905,7 +1916,7 @@ pydantic==2.7.0 \ --hash=sha256:9dee74a271705f14f9a1567671d144a851c675b072736f0a7b2608fd9e495352 \ --hash=sha256:b5ecdd42262ca2462e2624793551e80911a1e989f462910bb81aef974b4bb383 # via - # -r internal_source/deps/../../open_source/deps/requirements_base.txt + # -r open_source/deps/../../open_source/deps/requirements_base.txt # fastapi # openai pydantic-core==2.18.1 \ @@ -1992,7 +2003,7 @@ pydantic-core==2.18.1 \ pynvml==11.5.3 \ --hash=sha256:183d223ae487e5f00402d8da06c68c978ef8a9295793ee75559839c6ade7b229 \ --hash=sha256:a5fba3ab14febda50d19dbda012ef62ae0aed45b7ccc07af0bc5be79223e450c - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt pyodps==0.11.6.5 \ --hash=sha256:025714a3c8066df2a4b611dd2dc8278d390c05782886c1b2996ba92ef1f10e8c \ --hash=sha256:05786981c72625b59169d2c48eaa81f13c582f19d3313c6eef8fec23308b52a5 \ @@ -2039,11 +2050,11 @@ pyodps==0.11.6.5 \ --hash=sha256:db382098de7a148c3be87a8bdecf243603225edf0dcd413e84864d938dc85c0d \ --hash=sha256:e867818ce3951d8c662af8d05f07231f24d699a893b4c00f583ba08eda4e85bb \ --hash=sha256:eb1a5088a06d3953e39681b9b880cc02ff95113d53a14a86495bbf33c602b191 - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt pyopenssl==24.1.0 \ 
--hash=sha256:17ed5be5936449c5418d1cd269a1a9e9081bc54c17aed272b45856a3d3dc86ad \ --hash=sha256:cabed4bfaa5df9f1a16c0ef64a0cb65318b5cd077a7eda7d6970131ca2f41a6f - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt pyparsing==3.1.4 \ --hash=sha256:a6a7ee4235a3f944aa1fa2249307708f893fe5717dc603503c6c7969c070fb7c \ --hash=sha256:f86ec8d1a83f11977c9a6ea7598e8c27fc5cddfa5b07ea2241edbbde1d7bc032 @@ -2219,7 +2230,6 @@ requests==2.32.3 \ # pooch # pyodps # tiktoken - # torchvision # transformers safetensors==0.4.5 \ --hash=sha256:01c8f00da537af711979e1b42a69a8ec9e1d7112f208e0e9b8a35d2c381085ef \ @@ -2333,7 +2343,7 @@ safetensors==0.4.5 \ --hash=sha256:fd33da8e9407559f8779c82a0448e2133737f922d71f884da27184549416bfed \ --hash=sha256:fdadf66b5a22ceb645d5435a0be7a0292ce59648ca1d46b352f13cff3ea80410 # via - # -r internal_source/deps/../../open_source/deps/requirements_base.txt + # -r open_source/deps/../../open_source/deps/requirements_base.txt # accelerate # timm # transformers @@ -2397,13 +2407,14 @@ scipy==1.14.1 \ --hash=sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1 \ --hash=sha256:fef8c87f8abfb884dac04e97824b61299880c43f4ce675dd2cbeadd3c9b466d2 # via + # bitsandbytes # librosa # scikit-learn # sentence-transformers sentence-transformers==2.7.0 \ --hash=sha256:2f7df99d1c021dded471ed2d079e9d1e4fc8e30ecb06f957be060511b36f24ea \ --hash=sha256:6a7276b05a95931581bbfa4ba49d780b2cf6904fa4a171ec7fd66c343f761c98 - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt sentencepiece==0.2.0 \ --hash=sha256:0461324897735512a32d222e3d886e24ad6a499761952b6bda2a9ee6e4313ea5 \ --hash=sha256:0993dbc665f4113017892f1b87c3904a44d0640eda510abcacdfb07f74286d36 \ @@ -2458,7 +2469,94 @@ sentencepiece==0.2.0 \ --hash=sha256:f4d158189eb2ecffea3a51edf6d25e110b3678ec47f1a40f2d541eafbd8f6250 \ --hash=sha256:fb89f811e5efd18bab141afc3fea3de141c3f69f3fe9e898f710ae7fe3aab251 \ --hash=sha256:ff88712338b01031910e8e61e7239aff3ce8869ee31a47df63cb38aadd591bea - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt +setproctitle==1.3.5 \ + --hash=sha256:02870e0cb0de7f68a7a8a5b23c2bc0ce63821cab3d9b126f9be80bb6cd674c80 \ + --hash=sha256:0424e1d33232322541cb36fb279ea5242203cd6f20de7b4fb2a11973d8e8c2ce \ + --hash=sha256:0b91e68e6685998e6353f296100ecabc313a6cb3e413d66a03d74b988b61f5ff \ + --hash=sha256:1534d6cd3854d035e40bf4c091984cbdd4d555d7579676d406c53c8f187c006f \ + --hash=sha256:162fd76781f57f42ddf27c475e5fef6a8df4fdd69b28dd554e53e2eb2bfe0f95 \ + --hash=sha256:1b1d2628ac9868f960d7e87b3a9b2bb337104c3644b699e52e01efd7e106e4fe \ + --hash=sha256:1b58d49c32a46c48dcc2812635a89e6bee31139b03818da49a0bbaeaf01edef9 \ + --hash=sha256:1c8dcc250872385f2780a5ea58050b58cbc8b6a7e8444952a5a65c359886c593 \ + --hash=sha256:1e6eaeaf8a734d428a95d8c104643b39af7d247d604f40a7bebcf3960a853c5e \ + --hash=sha256:20b84de1780bbb0adc67560a113a0ea57e6ecfce2325680de8efe6c2a2f781ac \ + --hash=sha256:269d41cd4f085b69821d1ee6599124f02dbbc79962b256e260b6c9021d037994 \ + --hash=sha256:310c7f4ca4c8476a9840b2cd4b22ee602a49a3c902fdcd2dd8284685abd10a9a \ + --hash=sha256:31dc9b330e7cac7685bdef790747c07914081c11ee1066eb0c597303dfb52010 \ + --hash=sha256:322067ef1ffe70d297b00bee8a3862fed96021aa4318e3bce2d7c3bfa7a8d1e7 \ + 
--hash=sha256:36178b944019ec7fc52bb967ffeee296a11d373734a7be276755bedb3db5c141 \ + --hash=sha256:36b130cf8fe76dc05ad1d48cc9ff3699eb1f0d8edbf6f46a3ce46a7041e49d7b \ + --hash=sha256:3bb6ea3d6e690677619508050bc681d86223723bdf67e4e8a8dffc3d04ca3044 \ + --hash=sha256:4028639b511f5e641d116b3b54ad70c637ebd1b4baac0948283daf11b104119f \ + --hash=sha256:4272295721cf1fd2acf960b674d6dc09bec87f2a1e48995817b4ec4a3d483faf \ + --hash=sha256:4629de80c47155a26e8d87a0a92d9428aa8d79ccfe2c20fd18888580619704e1 \ + --hash=sha256:4969d996bdfbe23bbd023cd0bae6c73a27371615c4ec5296a60cecce268659ef \ + --hash=sha256:50cfbf86b9c63a2c2903f1231f0a58edeb775e651ae1af84eec8430b0571f29b \ + --hash=sha256:523424b9be4dea97d95b8a584b183f35c7bab2d0a3d995b01febf5b8a8de90e4 \ + --hash=sha256:53ce572cdbd43a0bed2aa24299cd823ebf233a7fa720cc7f8634728c213679c0 \ + --hash=sha256:53fc971f7bf7a674f571a23cdec70f2f0ac88152c59c06aa0808d0be6d834046 \ + --hash=sha256:55b278135be742b8901067479626d909f6613bd2d2c4fd0de6bb46f80e07a919 \ + --hash=sha256:5cefc2dbdc48121022c3c05644cd3706f08e0b3c0ce07814d3c04daba0617936 \ + --hash=sha256:62a01c76708daac78b9688ffb95268c57cb57fa90b543043cda01358912fe2db \ + --hash=sha256:6bddef4e27d0ed74e44b58bf050bc3108591bf17d20d461fc59cd141282f849c \ + --hash=sha256:6d8a411e752e794d052434139ca4234ffeceeb8d8d8ddc390a9051d7942b2726 \ + --hash=sha256:707c23d4a88f5e66f1005d93558bf84eb45fc0fb0c4f33480a0c7d0895e8e848 \ + --hash=sha256:755671c39a9e70834eeec6dc6b61e344399c49881d2e7ea3534a1c69669dd9cc \ + --hash=sha256:78288ff5f9c415c56595b2257ad218936dd9fa726b36341b373b31ca958590fe \ + --hash=sha256:7a887582bfdb6dcbc482db0ef9e630ad23ca95875806ef2b444bf6fbd7b7d7ca \ + --hash=sha256:7edd4fbb9fd17ed0e5a7f8bde9fa61c3987a34372084c45bab4eab6a2e554762 \ + --hash=sha256:81f2328ac34c9584e1e5f87eea916c0bc48476a06606a07debae07acdd7ab5ea \ + --hash=sha256:828727d220e46f048b82289018300a64547b46aaed96bf8810c05fe105426b41 \ + --hash=sha256:83b016221cf80028b2947be20630faa14e3e72a403e35f0ba29550b4e856767b \ + --hash=sha256:867af4a5c3d85484fbcc50ea88bcd375acf709cff88a3259575361849c0da351 \ + --hash=sha256:8915d69260ba6a6aaf9a48f6b53dbf9f8e4dc0cb4ae25bc5edb16a1666b6e47c \ + --hash=sha256:8995a1217b52d11d92bafd069961a47c5e13d8751ca976a32b3ecbbd471eaf9b \ + --hash=sha256:8a7fed67ab49f60bd51f3b4cffff3f8d754d1bb0a40e42869911301ec6519b65 \ + --hash=sha256:8ca56e39d10b6758046694a84950e5c5570a034c409ef3337595f64fc2cfa94d \ + --hash=sha256:8ec0a7fe9f1ba90900144489bc93ce7dd4dec3f3df1e7f188c9e58364fe4a4c5 \ + --hash=sha256:95913af603da5b4c7635bf1fb67ecc5df7c18360b6cfb6740fd743bb150a6e17 \ + --hash=sha256:995b3ac1b5fe510f4e1d1c19ebf19f4bceb448f2d6e8d99ea23f33cb6f1a277e \ + --hash=sha256:9996be1d1df399c3cdc6d72ce0064e46bc74fc6e29fe16a328511a303dd4d418 \ + --hash=sha256:9ab52b4c2ce056a1b60d439991a81ca90f019488d4b4f64b2779e6badd3677e6 \ + --hash=sha256:a58f00f35d6038ce1e8a9e5f87cb5ecce13ce118c5977a603566ad1fccc8d2cb \ + --hash=sha256:a5a05e2c3fdfbda32b9c9da72d0506398d1efb5bd2c5981b9e12d3622eb3d4f9 \ + --hash=sha256:a863296a31fb578726c570314cb78ff3a3fddb65963dc01ea33731760f20a92c \ + --hash=sha256:aaee7acba2733a14a886488b7495bfec4a8d6407124c04a0946dbde1684230a3 \ + --hash=sha256:ab3ae11e10d13d514d4a5a15b4f619341142ba3e18da48c40e8614c5a1b5e3c3 \ + --hash=sha256:ae2ce64ea87837c4e3e65a7a232ff80cf09aa7d916e74cb34a245c47fcd87981 \ + --hash=sha256:b63bda3cb4b6526720dc7c6940b891c593f41771d119aeb8763875801ce2296d \ + --hash=sha256:b6ec1d86c1b4d7b5f2bdceadf213310cf24696b82480a2a702194b8a0bfbcb47 \ + 
--hash=sha256:bc1fda208ae3a2285ad27aeab44c41daf2328abe58fa3270157a739866779199 \ + --hash=sha256:bd2cccd972e4282af4ce2c13cd9ebdf07be157eabafd8ce648fffdc8ae6fbe28 \ + --hash=sha256:bd70c95a94473216e7c7a7a1f7d8ecbaca5b16d4ba93ddbfd32050fc485a8451 \ + --hash=sha256:becc9f3f605936506d2bd63d9cf817b7ee66b10d204184c4a633064dbed579d6 \ + --hash=sha256:c4b299b5bbadf00034978b8d741c85af25173146747eb9dab22596ec805a52d6 \ + --hash=sha256:c64199a73d442a06d372b5286942229a43e86fa41bf36f317dcc60c036aff0bb \ + --hash=sha256:ca82fae9eb4800231dd20229f06e8919787135a5581da245b8b05e864f34cc8b \ + --hash=sha256:cef63879c79a570aabf7c158f453bf8d1285f0fda4b6b9b7a52d64b49c084d40 \ + --hash=sha256:cf4e3ded98027de2596c6cc5bbd3302adfb3ca315c848f56516bb0b7e88de1e9 \ + --hash=sha256:d0b19fd76d46b8096a463724739c3b09cf5ce38317f559f56f424f6ce7158de3 \ + --hash=sha256:d2c371550a2288901a0dcd84192691ebd3197a43c95f3e0b396ed6d1cedf5c6c \ + --hash=sha256:d57e7626329d4fb138da5ce15270b08a91326969956fb19c7a8fec2639066704 \ + --hash=sha256:d880630fd81d1b3bde121c352ca7ea2f2ff507ef40c3c011d0928ed491f912c9 \ + --hash=sha256:dc4f783e100f8b451cd92fcabd3b831edfb1f7cb02be4a79b972f138e0001885 \ + --hash=sha256:dc66b84beb0d5eb03abf0c3140c6d2cbe3d67ae9f0824a09dfa8c6ff164319a6 \ + --hash=sha256:e1d28eb98c91fbebd3e443a45c7da5d84974959851ef304c330eabd654a386f1 \ + --hash=sha256:e9c0d0cfcf715631b10d5950d04a9978f63bc46535724ef7c2eaf1dca9988642 \ + --hash=sha256:ea07f29735d839eaed985990a0ec42c8aecefe8050da89fec35533d146a7826d \ + --hash=sha256:ea6c505264275a43e9b2acd2acfc11ac33caf52bc3167c9fced4418a810f6b1c \ + --hash=sha256:eab441c89f181271ab749077dcc94045a423e51f2fb0b120a1463ef9820a08d0 \ + --hash=sha256:f1af1d310b5b6cda692da52bd862a9833086c0a3f8380fa92505dd23857dcf60 \ + --hash=sha256:f1f13a25fc46731acab518602bb1149bfd8b5fabedf8290a7c0926d61414769d \ + --hash=sha256:f3b5e2eacd572444770026c9dd3ddc7543ce427cdf452d40a408d1e95beefb30 \ + --hash=sha256:f7a8c01ffd013dda2bed6e7d5cb59fbb609e72f805abf3ee98360f38f7758d9b \ + --hash=sha256:f8305b6e6c203222c61318f338f1de08269ec66c247bf251593c215ff1fbeaf9 \ + --hash=sha256:fa912c4d08c66afda30dd5af8f2e9c59065dfc36a51edbd5419c3a7c962875aa \ + --hash=sha256:fb0500e1bc6f00b8ba696c3743ddff14c8679e3c2ca9d292c008ac51488d17cf \ + --hash=sha256:fe3bfd5e51c24349d022e062a96c316a1b8862ea9a0cf5ea2a8b2ae008b77cec \ + --hash=sha256:fec8340ab543144d04a9d805d80a0aad73fdeb54bea6ff94e70d39a676ea4ec0 + # via -r open_source/deps/../../open_source/deps/requirements_base.txt six==1.16.0 \ --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 @@ -2510,11 +2608,11 @@ starlette==0.41.3 \ --hash=sha256:0e4ab3d16522a255be6b28260b938eae2482f98ce5cc934cb08dce8dc3ba5835 \ --hash=sha256:44cedb2b7c77a9de33a8b74b2b90e9f50d11fcf25d8270ea525ad71a25374ff7 # via fastapi -sympy==1.13.3 \ - --hash=sha256:54612cf55a62755ee71824ce692986f23c88ffa77207b30c1368eda4a7060f73 \ - --hash=sha256:b27fd2c6530e0ab39e275fc9b683895367e51d5da91baa8d3d64db2565fec4d9 +sympy==1.13.1 \ + --hash=sha256:9cebf7e04ff162015ce31c9c6c9144daa34a93bd082f54fd8f12deca4f47515f \ + --hash=sha256:db36cdc64bf61b9b24578b6f7bab1ecdd2452cf008f34faa33776680c26d66f8 # via - # -r internal_source/deps/../../open_source/deps/requirements_base.txt + # -r open_source/deps/../../open_source/deps/requirements_base.txt # torch threadpoolctl==3.5.0 \ --hash=sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107 \ @@ -2522,7 +2620,7 @@ threadpoolctl==3.5.0 \ # via scikit-learn 
thrift==0.20.0 \ --hash=sha256:4dd662eadf6b8aebe8a41729527bd69adf6ceaa2a8681cbef64d1273b3e8feba - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt tiktoken==0.7.0 \ --hash=sha256:03c6c40ff1db0f48a7b4d2dafeae73a5607aacb472fa11f125e7baf9dce73704 \ --hash=sha256:084cec29713bc9d4189a937f8a35dbdfa785bd1235a34c1124fe2323821ee93f \ @@ -2560,11 +2658,11 @@ tiktoken==0.7.0 \ --hash=sha256:e215292e99cb41fbc96988ef62ea63bb0ce1e15f2c147a61acc319f8b4cbe5bf \ --hash=sha256:e54be9a2cd2f6d6ffa3517b064983fb695c9a9d8aa7d574d1ef3c3f931a99225 \ --hash=sha256:fffdcb319b614cf14f04d02a52e26b1d1ae14a570f90e9b55461a72672f7b13d - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt timm==0.9.12 \ --hash=sha256:2a828afac5b710a80ec66d0f85807e171e342faf5c0703b33102d8aa206f19dc \ --hash=sha256:9121d1cf320f7f32490d893340fd33117bda0a0270eb8282dfd52ae5fd3e1af6 - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt tokenizers==0.20.3 \ --hash=sha256:04627b7b502fa6a2a005e1bd446fa4247d89abcb1afaa1b81eb90e21aba9a60f \ --hash=sha256:07d7851a72717321022f3774e84aa9d595a041d643fafa2e87fbc9b18711dac0 \ @@ -2679,36 +2777,19 @@ tokenizers==0.20.3 \ --hash=sha256:fbaf3ea28fedfb2283da60e710aff25492e795a7397cad8a50f1e079b65a5a70 \ --hash=sha256:ff1ef8bd47a02b0dc191688ccb4da53600df5d4c9a05a4b68e1e3de4823e78eb # via transformers -torch @ https://download.pytorch.org/whl/cpu/torch-2.1.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl#sha256=bef6996c27d8f6e92ea4e13a772d89611da0e103b48790de78131e308cf73076 ; platform_machine == "aarch64" \ - --hash=sha256:bef6996c27d8f6e92ea4e13a772d89611da0e103b48790de78131e308cf73076 +torch @ https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp310-cp310-manylinux_2_28_aarch64.whl#sha256=90832f4d118c566b8652a2196ac695fc1f14cf420db27b5a1b41c7eaaf2141e9 \ + --hash=sha256:90832f4d118c566b8652a2196ac695fc1f14cf420db27b5a1b41c7eaaf2141e9 # via - # -r internal_source/deps/requirements_cpu_arm.txt + # -r open_source/deps/requirements_cpu_arm.txt # accelerate # sentence-transformers # timm # torchvision -torchvision==0.16.2 \ - --hash=sha256:335959c43b371c0474af34c1ef2a52efdc7603c45700d29e4475eeb02984170c \ - --hash=sha256:3f4bd5fcbc361476e2e78016636ac7d5509e59d9962521f06eb98e6803898182 \ - --hash=sha256:41dd4fa9f176d563fe9f1b9adef3b7e582cdfb60ce8c9bc51b094a025be687c9 \ - --hash=sha256:4b065143d1a720fe8a9077fd4be35d491f98819ec80b3dbbc3ec64d0b707a906 \ - --hash=sha256:56115268b37f0b75364e3654e47ad9abc66ac34c1f9e5e3dfa89a22d6a40017a \ - --hash=sha256:67b1aaf8b8cb02ce75dd445f291a27c8036a502f8c0aa76e28c37a0faac2e153 \ - --hash=sha256:7fd22d86e08eba321af70cad291020c2cdeac069b00ce88b923ca52e06174769 \ - --hash=sha256:8199acdf8ab066a28b84a5b6f4d97b58976d9e164b1acc3a9d14fccfaf74bb3a \ - --hash=sha256:82805f8445b094f9d1e770390ee6cc86855e89955e08ce34af2e2274fc0e5c45 \ - --hash=sha256:8692ab1e48807e9604046a6f4beeb67b523294cee1b00828654bb0df2cfce2b2 \ - --hash=sha256:96c7583700112a410bdc4e1e4f118c429dab49c29c9a31a2cc3579bc9b08b19d \ - --hash=sha256:9f4032ebb3277fb07ff6a9b818d50a547fb8fcd89d958cfd9e773322454bb688 \ - --hash=sha256:b024bd412df6d3a007dcebf311a894eb3c5c21e1af80d12be382bbcb097a7c3a \ - --hash=sha256:b82732dcf876a37c852772342aa6ee3480c03bb3e2a802ae109fc5f7e28d26e9 \ - 
--hash=sha256:bc5f274e4ecd1b86062063cdf4fd385a1d39d147a3a2685fbbde9ff08bb720b8 \ - --hash=sha256:bc86f2800cb2c0c1a09c581409cdd6bff66e62f103dc83fc63f73346264c3756 \ - --hash=sha256:bef30d03e1d1c629761f4dca51d3b7d8a0dc0acce6f4068ab2a1634e8e7b64e0 \ - --hash=sha256:e130b08cc9b3cc73a6c59d6edf032394a322f9579bfd21d14bc2e1d0999aa758 \ - --hash=sha256:e59cc7b2bd1ab5c0ce4ae382e4e37be8f1c174e8b5de2f6a23c170de9ae28495 \ - --hash=sha256:e89f10f3c8351972b6e3fda95bc3e479ea8dbfc9dfcfd2c32902dbad4ba5cfc5 - # via timm +torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.21.0-cp310-cp310-linux_aarch64.whl#sha256=54815e0a56dde95cc6ec952577f67e0dc151eadd928e8d9f6a7f821d69a4a734 \ + --hash=sha256:54815e0a56dde95cc6ec952577f67e0dc151eadd928e8d9f6a7f821d69a4a734 + # via + # -r open_source/deps/requirements_cpu_arm.txt + # timm tqdm==4.66.5 \ --hash=sha256:90279a3770753eafc9194a0364852159802111925aa30eb3f9d85b0e805ac7cd \ --hash=sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad @@ -2721,13 +2802,13 @@ transformers==4.46.2 \ --hash=sha256:3d85410881e1c074be767877bf33c83231ec11529f274a6044ecb20c157ba14e \ --hash=sha256:c921f4406b78e6518c97b618c5acd1cf8a4f2315b6b727f4bf9e01496eef849c # via - # -r internal_source/deps/../../open_source/deps/requirements_base.txt + # -r open_source/deps/../../open_source/deps/requirements_base.txt # sentence-transformers typing-extensions==4.12.2 \ --hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \ --hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8 # via - # -r internal_source/deps/../../open_source/deps/requirements_base.txt + # -r open_source/deps/../../open_source/deps/requirements_base.txt # anyio # fastapi # huggingface-hub @@ -2745,7 +2826,7 @@ urllib3==2.2.3 \ uvicorn==0.30.0 \ --hash=sha256:78fa0b5f56abb8562024a59041caeb555c86e48d0efdd23c3fe7de7a4075bdab \ --hash=sha256:f678dec4fa3a39706bbf49b9ec5fc40049d42418716cea52b53f07828a60aa37 - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt + # via -r open_source/deps/../../open_source/deps/requirements_base.txt wcwidth==0.2.13 \ --hash=sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859 \ --hash=sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5 @@ -2852,108 +2933,11 @@ zipp==3.20.2 \ --hash=sha256:a817ac80d6cf4b23bf7f2828b7cabf326f15a001bea8b1f9b49631780ba28350 \ --hash=sha256:bc9eb26f4506fda01b81bcde0ca78103b6e62f991b381fec825435c836edbc29 # via importlib-metadata -setproctitle==1.3.5 \ - --hash=sha256:02870e0cb0de7f68a7a8a5b23c2bc0ce63821cab3d9b126f9be80bb6cd674c80 \ - --hash=sha256:0424e1d33232322541cb36fb279ea5242203cd6f20de7b4fb2a11973d8e8c2ce \ - --hash=sha256:0b91e68e6685998e6353f296100ecabc313a6cb3e413d66a03d74b988b61f5ff \ - --hash=sha256:1534d6cd3854d035e40bf4c091984cbdd4d555d7579676d406c53c8f187c006f \ - --hash=sha256:162fd76781f57f42ddf27c475e5fef6a8df4fdd69b28dd554e53e2eb2bfe0f95 \ - --hash=sha256:1b1d2628ac9868f960d7e87b3a9b2bb337104c3644b699e52e01efd7e106e4fe \ - --hash=sha256:1b58d49c32a46c48dcc2812635a89e6bee31139b03818da49a0bbaeaf01edef9 \ - --hash=sha256:1c8dcc250872385f2780a5ea58050b58cbc8b6a7e8444952a5a65c359886c593 \ - --hash=sha256:1e6eaeaf8a734d428a95d8c104643b39af7d247d604f40a7bebcf3960a853c5e \ - --hash=sha256:20b84de1780bbb0adc67560a113a0ea57e6ecfce2325680de8efe6c2a2f781ac \ - --hash=sha256:269d41cd4f085b69821d1ee6599124f02dbbc79962b256e260b6c9021d037994 \ - --hash=sha256:310c7f4ca4c8476a9840b2cd4b22ee602a49a3c902fdcd2dd8284685abd10a9a \ - 
--hash=sha256:31dc9b330e7cac7685bdef790747c07914081c11ee1066eb0c597303dfb52010 \ - --hash=sha256:322067ef1ffe70d297b00bee8a3862fed96021aa4318e3bce2d7c3bfa7a8d1e7 \ - --hash=sha256:36178b944019ec7fc52bb967ffeee296a11d373734a7be276755bedb3db5c141 \ - --hash=sha256:36b130cf8fe76dc05ad1d48cc9ff3699eb1f0d8edbf6f46a3ce46a7041e49d7b \ - --hash=sha256:3bb6ea3d6e690677619508050bc681d86223723bdf67e4e8a8dffc3d04ca3044 \ - --hash=sha256:4028639b511f5e641d116b3b54ad70c637ebd1b4baac0948283daf11b104119f \ - --hash=sha256:4272295721cf1fd2acf960b674d6dc09bec87f2a1e48995817b4ec4a3d483faf \ - --hash=sha256:4629de80c47155a26e8d87a0a92d9428aa8d79ccfe2c20fd18888580619704e1 \ - --hash=sha256:4969d996bdfbe23bbd023cd0bae6c73a27371615c4ec5296a60cecce268659ef \ - --hash=sha256:50cfbf86b9c63a2c2903f1231f0a58edeb775e651ae1af84eec8430b0571f29b \ - --hash=sha256:523424b9be4dea97d95b8a584b183f35c7bab2d0a3d995b01febf5b8a8de90e4 \ - --hash=sha256:53ce572cdbd43a0bed2aa24299cd823ebf233a7fa720cc7f8634728c213679c0 \ - --hash=sha256:53fc971f7bf7a674f571a23cdec70f2f0ac88152c59c06aa0808d0be6d834046 \ - --hash=sha256:55b278135be742b8901067479626d909f6613bd2d2c4fd0de6bb46f80e07a919 \ - --hash=sha256:5cefc2dbdc48121022c3c05644cd3706f08e0b3c0ce07814d3c04daba0617936 \ - --hash=sha256:62a01c76708daac78b9688ffb95268c57cb57fa90b543043cda01358912fe2db \ - --hash=sha256:6bddef4e27d0ed74e44b58bf050bc3108591bf17d20d461fc59cd141282f849c \ - --hash=sha256:6d8a411e752e794d052434139ca4234ffeceeb8d8d8ddc390a9051d7942b2726 \ - --hash=sha256:707c23d4a88f5e66f1005d93558bf84eb45fc0fb0c4f33480a0c7d0895e8e848 \ - --hash=sha256:755671c39a9e70834eeec6dc6b61e344399c49881d2e7ea3534a1c69669dd9cc \ - --hash=sha256:78288ff5f9c415c56595b2257ad218936dd9fa726b36341b373b31ca958590fe \ - --hash=sha256:7a887582bfdb6dcbc482db0ef9e630ad23ca95875806ef2b444bf6fbd7b7d7ca \ - --hash=sha256:7edd4fbb9fd17ed0e5a7f8bde9fa61c3987a34372084c45bab4eab6a2e554762 \ - --hash=sha256:81f2328ac34c9584e1e5f87eea916c0bc48476a06606a07debae07acdd7ab5ea \ - --hash=sha256:828727d220e46f048b82289018300a64547b46aaed96bf8810c05fe105426b41 \ - --hash=sha256:83b016221cf80028b2947be20630faa14e3e72a403e35f0ba29550b4e856767b \ - --hash=sha256:867af4a5c3d85484fbcc50ea88bcd375acf709cff88a3259575361849c0da351 \ - --hash=sha256:8915d69260ba6a6aaf9a48f6b53dbf9f8e4dc0cb4ae25bc5edb16a1666b6e47c \ - --hash=sha256:8995a1217b52d11d92bafd069961a47c5e13d8751ca976a32b3ecbbd471eaf9b \ - --hash=sha256:8a7fed67ab49f60bd51f3b4cffff3f8d754d1bb0a40e42869911301ec6519b65 \ - --hash=sha256:8ca56e39d10b6758046694a84950e5c5570a034c409ef3337595f64fc2cfa94d \ - --hash=sha256:8ec0a7fe9f1ba90900144489bc93ce7dd4dec3f3df1e7f188c9e58364fe4a4c5 \ - --hash=sha256:95913af603da5b4c7635bf1fb67ecc5df7c18360b6cfb6740fd743bb150a6e17 \ - --hash=sha256:995b3ac1b5fe510f4e1d1c19ebf19f4bceb448f2d6e8d99ea23f33cb6f1a277e \ - --hash=sha256:9996be1d1df399c3cdc6d72ce0064e46bc74fc6e29fe16a328511a303dd4d418 \ - --hash=sha256:9ab52b4c2ce056a1b60d439991a81ca90f019488d4b4f64b2779e6badd3677e6 \ - --hash=sha256:a58f00f35d6038ce1e8a9e5f87cb5ecce13ce118c5977a603566ad1fccc8d2cb \ - --hash=sha256:a5a05e2c3fdfbda32b9c9da72d0506398d1efb5bd2c5981b9e12d3622eb3d4f9 \ - --hash=sha256:a863296a31fb578726c570314cb78ff3a3fddb65963dc01ea33731760f20a92c \ - --hash=sha256:aaee7acba2733a14a886488b7495bfec4a8d6407124c04a0946dbde1684230a3 \ - --hash=sha256:ab3ae11e10d13d514d4a5a15b4f619341142ba3e18da48c40e8614c5a1b5e3c3 \ - --hash=sha256:ae2ce64ea87837c4e3e65a7a232ff80cf09aa7d916e74cb34a245c47fcd87981 \ - 
--hash=sha256:b63bda3cb4b6526720dc7c6940b891c593f41771d119aeb8763875801ce2296d \ - --hash=sha256:b6ec1d86c1b4d7b5f2bdceadf213310cf24696b82480a2a702194b8a0bfbcb47 \ - --hash=sha256:bc1fda208ae3a2285ad27aeab44c41daf2328abe58fa3270157a739866779199 \ - --hash=sha256:bd2cccd972e4282af4ce2c13cd9ebdf07be157eabafd8ce648fffdc8ae6fbe28 \ - --hash=sha256:bd70c95a94473216e7c7a7a1f7d8ecbaca5b16d4ba93ddbfd32050fc485a8451 \ - --hash=sha256:becc9f3f605936506d2bd63d9cf817b7ee66b10d204184c4a633064dbed579d6 \ - --hash=sha256:c4b299b5bbadf00034978b8d741c85af25173146747eb9dab22596ec805a52d6 \ - --hash=sha256:c64199a73d442a06d372b5286942229a43e86fa41bf36f317dcc60c036aff0bb \ - --hash=sha256:ca82fae9eb4800231dd20229f06e8919787135a5581da245b8b05e864f34cc8b \ - --hash=sha256:cef63879c79a570aabf7c158f453bf8d1285f0fda4b6b9b7a52d64b49c084d40 \ - --hash=sha256:cf4e3ded98027de2596c6cc5bbd3302adfb3ca315c848f56516bb0b7e88de1e9 \ - --hash=sha256:d0b19fd76d46b8096a463724739c3b09cf5ce38317f559f56f424f6ce7158de3 \ - --hash=sha256:d2c371550a2288901a0dcd84192691ebd3197a43c95f3e0b396ed6d1cedf5c6c \ - --hash=sha256:d57e7626329d4fb138da5ce15270b08a91326969956fb19c7a8fec2639066704 \ - --hash=sha256:d880630fd81d1b3bde121c352ca7ea2f2ff507ef40c3c011d0928ed491f912c9 \ - --hash=sha256:dc4f783e100f8b451cd92fcabd3b831edfb1f7cb02be4a79b972f138e0001885 \ - --hash=sha256:dc66b84beb0d5eb03abf0c3140c6d2cbe3d67ae9f0824a09dfa8c6ff164319a6 \ - --hash=sha256:e1d28eb98c91fbebd3e443a45c7da5d84974959851ef304c330eabd654a386f1 \ - --hash=sha256:e9c0d0cfcf715631b10d5950d04a9978f63bc46535724ef7c2eaf1dca9988642 \ - --hash=sha256:ea07f29735d839eaed985990a0ec42c8aecefe8050da89fec35533d146a7826d \ - --hash=sha256:ea6c505264275a43e9b2acd2acfc11ac33caf52bc3167c9fced4418a810f6b1c \ - --hash=sha256:eab441c89f181271ab749077dcc94045a423e51f2fb0b120a1463ef9820a08d0 \ - --hash=sha256:f1af1d310b5b6cda692da52bd862a9833086c0a3f8380fa92505dd23857dcf60 \ - --hash=sha256:f1f13a25fc46731acab518602bb1149bfd8b5fabedf8290a7c0926d61414769d \ - --hash=sha256:f3b5e2eacd572444770026c9dd3ddc7543ce427cdf452d40a408d1e95beefb30 \ - --hash=sha256:f7a8c01ffd013dda2bed6e7d5cb59fbb609e72f805abf3ee98360f38f7758d9b \ - --hash=sha256:f8305b6e6c203222c61318f338f1de08269ec66c247bf251593c215ff1fbeaf9 \ - --hash=sha256:fa912c4d08c66afda30dd5af8f2e9c59065dfc36a51edbd5419c3a7c962875aa \ - --hash=sha256:fb0500e1bc6f00b8ba696c3743ddff14c8679e3c2ca9d292c008ac51488d17cf \ - --hash=sha256:fe3bfd5e51c24349d022e062a96c316a1b8862ea9a0cf5ea2a8b2ae008b77cec \ - --hash=sha256:fec8340ab543144d04a9d805d80a0aad73fdeb54bea6ff94e70d39a676ea4ec0 - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt -concurrent-log-handler==0.9.25 \ - --hash=sha256:157bee12914aa2a72246d1d0641ce07c1aa7a55faa3322bed02f21e60395eb82 \ - --hash=sha256:1e2c6f021414e214d3dac66107894827a3e78db63018304a4f29e55ba549ac22 - # via -r internal_source/deps/../../open_source/deps/requirements_base.txt -portalocker==3.1.1 \ - --hash=sha256:80e984e24de292ff258a5bea0e4f3f778fff84c0ae1275dbaebc4658de4aacb3 \ - --hash=sha256:ec20f6dda2ad9ce89fa399a5f31f4f1495f515958f0cb7ca6543cef7bb5a749e - # via - # -r internal_source/deps/../../open_source/deps/requirements_base.txt - # concurrent-log-handler # The following packages are considered to be unsafe in a requirements file: setuptools==60.5.0 \ --hash=sha256:2404879cda71495fc4d5cbc445ed52fdaddf352b36e40be8dcc63147cb4edabe \ --hash=sha256:68eb94073fc486091447fcb0501efd6560a0e5a1839ba249e5ff3c4c93f05f90 # via - # -r internal_source/deps/../../open_source/deps/requirements_base.txt + 
# -r open_source/deps/../../open_source/deps/requirements_base.txt # grpcio-tools diff --git a/rtp_llm/cpp/cache/CacheManager.cc b/rtp_llm/cpp/cache/CacheManager.cc index 1beea8438..214494618 100644 --- a/rtp_llm/cpp/cache/CacheManager.cc +++ b/rtp_llm/cpp/cache/CacheManager.cc @@ -208,16 +208,18 @@ void CacheManager::initKvCache() { } void CacheManager::initKVCacheScale() { + bool is_cpu = (this->device_->getDeviceProperties().type == DeviceType::ArmCpu); + rtp_llm::MemoryType memory_type = is_cpu ? rtp_llm::MemoryType::MEMORY_CPU : rtp_llm::MemoryType::MEMORY_GPU; if (config_.dtype == rtp_llm::DataType::TYPE_INT8) { kv_cache_.k_scale = - std::make_unique(rtp_llm::MemoryType::MEMORY_GPU, + std::make_unique(memory_type, rtp_llm::DataType::TYPE_FP32, std::vector{(size_t)config_.layer_num, (size_t)config_.block_nums, (size_t)config_.local_head_num_kv, (size_t)config_.seq_size_per_block}, (int8_t*)cache_base_ptr_ + kv_cache_.k_blocks->sizeBytes() * 2); - kv_cache_.v_scale = std::make_unique(rtp_llm::MemoryType::MEMORY_GPU, + kv_cache_.v_scale = std::make_unique(memory_type, rtp_llm::DataType::TYPE_FP32, std::vector{(size_t)config_.layer_num, (size_t)config_.block_nums, @@ -230,7 +232,7 @@ void CacheManager::initKVCacheScale() { #ifdef ENABLE_FP8 else if (config_.dtype == rtp_llm::DataType::TYPE_FP8_E4M3) { kv_cache_.k_scale = std::make_unique( - rtp_llm::MemoryType::MEMORY_GPU, + memory_type, rtp_llm::DataType::TYPE_FP32, std::vector{(size_t)config_.layer_num, (size_t)config_.block_nums, @@ -238,7 +240,7 @@ void CacheManager::initKVCacheScale() { (size_t)config_.seq_size_per_block}, (__nv_fp8_e4m3*)cache_base_ptr_ + kv_cache_.k_blocks->sizeBytes() * 2); kv_cache_.v_scale = std::make_unique( - rtp_llm::MemoryType::MEMORY_GPU, + memory_type, rtp_llm::DataType::TYPE_FP32, std::vector{(size_t)config_.layer_num, (size_t)config_.block_nums, @@ -251,14 +253,16 @@ void CacheManager::initKVCacheScale() { void CacheManager::initKvCacheMla() { RTP_LLM_LOG_INFO("init mla kv cache"); - kv_cache_.k_blocks = std::make_unique(rtp_llm::MemoryType::MEMORY_GPU, + bool is_cpu = (this->device_->getDeviceProperties().type == DeviceType::ArmCpu); + rtp_llm::MemoryType memory_type = is_cpu ? rtp_llm::MemoryType::MEMORY_CPU : rtp_llm::MemoryType::MEMORY_GPU; + kv_cache_.k_blocks = std::make_unique(memory_type, config_.dtype, std::vector{(size_t)config_.layer_num, (size_t)config_.block_nums, (size_t)config_.seq_size_per_block, (size_t)config_.kv_lora_rank}, cache_base_ptr_); - kv_cache_.v_blocks = std::make_unique(rtp_llm::MemoryType::MEMORY_GPU, + kv_cache_.v_blocks = std::make_unique(memory_type, config_.dtype, std::vector{(size_t)config_.layer_num, (size_t)config_.block_nums, @@ -269,7 +273,9 @@ void CacheManager::initKvCacheMla() { void CacheManager::initKvCacheNormal() { RTP_LLM_LOG_INFO("init normal kv cache"); - kv_cache_.k_blocks = std::make_unique(rtp_llm::MemoryType::MEMORY_GPU, + bool is_cpu = (this->device_->getDeviceProperties().type == DeviceType::ArmCpu); + rtp_llm::MemoryType memory_type = is_cpu ? 
rtp_llm::MemoryType::MEMORY_CPU : rtp_llm::MemoryType::MEMORY_GPU; + kv_cache_.k_blocks = std::make_unique(memory_type, config_.dtype, std::vector{(size_t)config_.layer_num, (size_t)config_.block_nums, @@ -277,7 +283,7 @@ void CacheManager::initKvCacheNormal() { (size_t)config_.seq_size_per_block, (size_t)config_.size_per_head}, cache_base_ptr_); - kv_cache_.v_blocks = std::make_unique(rtp_llm::MemoryType::MEMORY_GPU, + kv_cache_.v_blocks = std::make_unique(memory_type, config_.dtype, std::vector{(size_t)config_.layer_num, (size_t)config_.block_nums, diff --git a/rtp_llm/cpp/devices/DeviceExport.h b/rtp_llm/cpp/devices/DeviceExport.h index 3c234c02a..8a46e7afb 100644 --- a/rtp_llm/cpp/devices/DeviceExport.h +++ b/rtp_llm/cpp/devices/DeviceExport.h @@ -26,7 +26,7 @@ class DeviceExporter { virtual torch::Tensor packInt8TensorToPackedInt4(torch::Tensor weight) = 0; virtual torch::Tensor preprocessWeightsForMixedGemm(torch::Tensor weight, py::object quant_type, const std::string &arch) = 0; virtual std::vector symmetricQuantizeLastAxisOfBatchedMatrix(torch::Tensor weight, py::object quant_type, const std::string &arch) = 0; - virtual torch::Tensor preprocessWeightScale(torch::Tensor weight, torch::Tensor scale) = 0; + virtual torch::Tensor preprocessWeightScale(torch::Tensor weight, torch::Tensor scale, const std::string& key) = 0; protected: rtp_llm::DeviceInitParams device_params_; @@ -55,8 +55,8 @@ class DeviceExporterImpl : public DeviceExporter { const auto dtype = torch::python::detail::py_object_to_dtype(quant_type); return Device::symmetricQuantizeLastAxisOfBatchedMatrix(weight, dtype, arch); } - torch::Tensor preprocessWeightScale(torch::Tensor weight, torch::Tensor scale) { - return Device::preprocessWeightScale(weight, scale); + torch::Tensor preprocessWeightScale(torch::Tensor weight, torch::Tensor scale, const std::string& key) { + return Device::preprocessWeightScale(weight, scale, key); } }; diff --git a/rtp_llm/cpp/devices/DeviceFactory.cc b/rtp_llm/cpp/devices/DeviceFactory.cc index 4c5b8f81c..5f85519d2 100644 --- a/rtp_llm/cpp/devices/DeviceFactory.cc +++ b/rtp_llm/cpp/devices/DeviceFactory.cc @@ -216,7 +216,7 @@ void registerDeviceOps(py::module& m) { py::arg("weight"), py::arg("quant_type"), py::arg("arch")) - .def("preprocess_weight_scale", &DeviceExporter::preprocessWeightScale, py::arg("weight"), py::arg("scale")); + .def("preprocess_weight_scale", &DeviceExporter::preprocessWeightScale, py::arg("weight"), py::arg("scale"), py::arg("key")); m.def("get_device", &DeviceFactory::getDeviceExporter); } diff --git a/rtp_llm/cpp/devices/OpData.h b/rtp_llm/cpp/devices/OpData.h index e6d7716f3..afb4721fc 100644 --- a/rtp_llm/cpp/devices/OpData.h +++ b/rtp_llm/cpp/devices/OpData.h @@ -526,6 +526,7 @@ struct MlaRotaryWriteKVCacheParams { const AttentionLayerWeights& weights; const AttentionConfigs& configs; const QScheme qscheme; + bool is_prefill = false; }; struct MlaAttentionModuleParams { diff --git a/rtp_llm/cpp/devices/arm_impl/ArmActOp.cc b/rtp_llm/cpp/devices/arm_impl/ArmActOp.cc index 4d574084e..7d8689cd7 100644 --- a/rtp_llm/cpp/devices/arm_impl/ArmActOp.cc +++ b/rtp_llm/cpp/devices/arm_impl/ArmActOp.cc @@ -127,14 +127,30 @@ BufferPtr ArmCpuDevice::activation(const ActivationParams& params) { gate = params.gate.value().get().data(); printBufferData(params.gate.value().get(), "ffn activation gate"); if (states->type() == DataType::TYPE_FP16) { + #pragma omp parallel for if (m > 1) for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < n; j++) { + size_t j; + for (j = 0; 
j <= n - 8; j += 8) { + float16x8_t gate_vec = vld1q_f16((__fp16*)gate + i * n + j); + float16x8_t state_vec = vld1q_f16((__fp16*)states->dataWithOffset(i * n + j)); + state_vec = vmulq_f16(state_vec, gate_vec); + vst1q_f16((__fp16*)states->dataWithOffset(i * n + j), state_vec); + } + for (; j < n; j++) { *(__fp16*)(states->dataWithOffset(i * n + j)) *= ((__fp16*)gate)[i * n + j]; } } } else if (states->type() == DataType::TYPE_FP32) { + #pragma omp parallel for if (m > 1) for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < n; j++) { + size_t j; + for (j = 0; j <= n - 4; j += 4) { + float32x4_t gate_vec = vld1q_f32((float*)gate + i * n + j); + float32x4_t state_vec = vld1q_f32((float*)states->dataWithOffset(i * n + j)); + state_vec = vmulq_f32(state_vec, gate_vec); + vst1q_f32((float*)states->dataWithOffset(i * n + j), state_vec); + } + for (; j < n; j++) { *(float*)(states->dataWithOffset(i * n + j)) *= ((float*)gate)[i * n + j]; } } diff --git a/rtp_llm/cpp/devices/arm_impl/ArmAttentionOp.cc b/rtp_llm/cpp/devices/arm_impl/ArmAttentionOp.cc index 25b89b9e5..032792fb7 100644 --- a/rtp_llm/cpp/devices/arm_impl/ArmAttentionOp.cc +++ b/rtp_llm/cpp/devices/arm_impl/ArmAttentionOp.cc @@ -110,12 +110,156 @@ void updateKVCache(const AttentionModuleParams& params, int batch, size_t step, } } +struct Float2 { + float x; + float y; +}; + +inline Float2 rotary_embedding_transform(const Float2& v, const Float2& coef) { + return { + coef.x * v.x - coef.y * v.y, + coef.x * v.y + coef.y * v.x + }; +} + +inline float rope_inv_freq(const int zid, const int rot_embed_dim, const float base) { + return 1.0f / std::pow(base, zid / static_cast(rot_embed_dim)); +} + +struct YarnRope { + int dim; + int base; + int max_pos; + float beta_slow; + float beta_fast; + float scaling_factor; + float extrapolation_factor; + float mscale; + + static float find_correction_dim(float num_rot, int dim, int base, int max_pos = 2048) { + const float pi = 3.141592654f; + float t0 = dim * std::log(max_pos / (num_rot * 2 * pi)); + float t1 = 2 * std::log((float)base); + return t0 / t1; + } + + static std::pair find_correction_range(float low_rot, float high_rot, int dim, int base, int max_pos = 2048) { + int low = static_cast(std::floor(find_correction_dim(low_rot, dim, base, max_pos))); + int high = static_cast(std::ceil(find_correction_dim(high_rot, dim, base, max_pos))); + return {std::max(0, low), std::min(high, dim - 1)}; + } + + static float linear_ramp_mask(float min_, float max_, int tidx) { + if (min_ == max_) max_ += 0.001f; + float linear = (tidx / 2.0f - min_) / (max_ - min_); + return std::min(1.0f, std::max(0.0f, linear)); + } + + float operator()(float inv_freq, int zid) const { + auto [low, high] = find_correction_range(beta_fast, beta_slow, dim, base, max_pos); + float inv_freq_e = inv_freq; + float inv_freq_i = inv_freq_e / scaling_factor; + float mask = (1.0f - linear_ramp_mask(low, high, zid)) * extrapolation_factor; + return inv_freq_i * (1 - mask) + inv_freq_e * mask; + } + + float sin_cos_scale() const { + return mscale; + } +}; + +struct LinearScaleRope { + float scale = 1.0; + float operator()(float inv_freq, int zid) const { + return inv_freq / scale; + } + + float sin_cos_scale() const { + return 1.0; + } +}; + +template +Float2 rotary_embedding_coefficient( + const int zid, const int rot_embed_dim, const float t_step, + const float base, const RopeInit& rope_init) +{ + float inv_freq = rope_inv_freq(zid, rot_embed_dim, base); + inv_freq = rope_init(inv_freq, zid); + float angle = inv_freq * t_step; 
+ float scale = rope_init.sin_cos_scale(); + return {scale * std::cos(angle), scale * std::sin(angle)}; +} + +template +void apply_rotary_embedding(Float2& v, int zid, int rot_embed_dim, int t_step, + float base, const RopeInit& rope_init) +{ + Float2 coef = rotary_embedding_coefficient(zid, rot_embed_dim, t_step, base, rope_init); + v = rotary_embedding_transform(v, coef); +} + /* Input 'qkv' consists of q & k & v, and each with shape [batch, seq_len, num_heads, head_dim]. * Half RoPE is applied to q & k. * Retrieve pre-calculated Cos/Sin if exists. */ template -void ArmCpuDevice::halfRopeQK(void *qkv, int batch, int seq_len, int num_heads, int kv_num_heads, int head_size, size_t step, size_t base, size_t embed_dim) { +void ArmCpuDevice::halfRopeQK(void *qkv, int batch, int seq_len, int num_heads, int kv_num_heads, int head_size, size_t step, const RopeConfig *rope_config/* size_t base, size_t embed_dim */) { + auto base = rope_config->base; + auto embed_dim = rope_config->dim; + auto offset = rope_config->offset; + if (rope_config->style == RopeStyle::Yarn) { + YarnRope yarn; + yarn.dim = embed_dim; + yarn.base = base; + yarn.max_pos = rope_config->max_pos; + yarn.beta_slow = rope_config->factor1; + yarn.beta_fast = rope_config->factor2; + yarn.scaling_factor = rope_config->scale; + yarn.extrapolation_factor = rope_config->extrapolation_factor; + yarn.mscale = rope_config->mscale; + + size_t inv_freq_size = (embed_dim + 1) / 2; + const int N = batch * seq_len; + + parallel_for(N, [&](int tid) { + int j = tid % seq_len; + T* q_input = (T*)qkv + tid * (num_heads + 2 * kv_num_heads) * head_size + offset; + T* k_input = (T*)qkv + tid * (num_heads + 2 * kv_num_heads) * head_size + num_heads * head_size + offset; + + size_t seq = (j == 0) ? step : j; + + for (int h = 0; h < num_heads; h++) { + for (int d = 0; d < inv_freq_size; d++) { + int rope_idx = d; + if (rope_idx < 0 || rope_idx >= inv_freq_size) + continue; + + // Load q + Float2 q = { + q_input[h * head_size + d], + q_input[h * head_size + d + inv_freq_size] + }; + apply_rotary_embedding(q, rope_idx * 2, embed_dim, seq, base, yarn); + q_input[h * head_size + d] = q.x; + q_input[h * head_size + d + inv_freq_size] = q.y; + + // Load and apply RoPE to k + if (h < kv_num_heads) { + Float2 k = { + k_input[h * head_size + d], + k_input[h * head_size + d + inv_freq_size] + }; + apply_rotary_embedding(k, rope_idx * 2, embed_dim, seq, base, yarn); + k_input[h * head_size + d] = k.x; + k_input[h * head_size + d + inv_freq_size] = k.y; + } + } + } + }); + return; + } + size_t inv_freq_size = (embed_dim + 1) / 2; auto &value = ropeCosSin[base]; @@ -228,13 +372,14 @@ void ArmCpuDevice::runOneBatch(const AttentionModuleParams& params, size_t past_ tStart = std::chrono::steady_clock::now(); if (params.configs.rope_config.style != RopeStyle::No) { - if (params.configs.rope_config.style == RopeStyle::Base) { + if (params.configs.rope_config.style == RopeStyle::Base || + params.configs.rope_config.style == RopeStyle::Yarn) { if (datatype == DataType::TYPE_FP32) { halfRopeQK(qkv, 1, seq_len, head_num, kv_head_num, size_per_head, step, - params.configs.rope_config.base, params.configs.rope_config.dim); + ¶ms.configs.rope_config); } else if (datatype == DataType::TYPE_FP16) { halfRopeQK<__fp16>(qkv, 1, seq_len, head_num, kv_head_num, size_per_head, step, - params.configs.rope_config.base, params.configs.rope_config.dim); + ¶ms.configs.rope_config); } else { throw OpException(OpErrorType::ERROR_UNIMPLEMENTED); } @@ -478,33 +623,21 @@ void 
assemCacheArray(const AttentionModuleParams& params, BufferPtr k_out, Buffe } } -void ArmCpuDevice::runOneBatchStride(const AttentionModuleParams& params, size_t past_seq, int batch, size_t seq_len, size_t step) { - auto datatype = params.input.type(); +void ArmCpuDevice::biasAddRopeWriteKVCache(const AttentionModuleParams& params, size_t past_seq, int batch, size_t seq_len, size_t step) { auto head_num = params.configs.head_num; auto kv_head_num = params.configs.kv_head_num; auto size_per_head = params.configs.size_per_head; std::chrono::steady_clock::time_point tStart, tEnd; std::chrono::microseconds diff; - if (datatype != DataType::TYPE_FP32) { - throw OpException(OpErrorType::ERROR_UNIMPLEMENTED); - } - // if (!params.common.kv_cache.has_value()) { - // throw std::runtime_error("kv cache block pointers can not be null"); - // } - - // Retrieve q/k/v by stride and not to split. auto qkv = params.input.dataWithOffset(past_seq * (head_num + 2 * kv_head_num) * size_per_head); - printBufferData(params.input, "qkv"); tStart = std::chrono::steady_clock::now(); - void *q_array[head_num], *k_array[head_num], *v_array[head_num]; tEnd = std::chrono::steady_clock::now(); diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart); logTime(diff, 0); tStart = std::chrono::steady_clock::now(); - //if (params.weights.qkv_weight->bias) { if (params.configs.fuse_qkv_add_bias && params.weights.qkv_weight->bias) { auto bias_data_type = params.weights.qkv_weight->bias->type(); if (bias_data_type == DataType::TYPE_FP32) { @@ -522,12 +655,12 @@ void ArmCpuDevice::runOneBatchStride(const AttentionModuleParams& params, size_t tStart = std::chrono::steady_clock::now(); if (params.configs.rope_config.style != RopeStyle::No) { - if (params.configs.rope_config.style == RopeStyle::Base) { + if (params.configs.rope_config.style == RopeStyle::Base || + params.configs.rope_config.style == RopeStyle::Yarn) { halfRopeQK<float>(qkv, 1, seq_len, head_num, kv_head_num, size_per_head, step, - params.configs.rope_config.base, - params.configs.rope_config.dim); + &params.configs.rope_config); } else { - throw std::runtime_error("SelfAttention RoPE type is not supported"); + throw std::runtime_error("RoPE type is not supported"); } } tEnd = std::chrono::steady_clock::now(); @@ -542,6 +675,26 @@ void ArmCpuDevice::runOneBatchStride(const AttentionModuleParams& params, size_t tEnd = std::chrono::steady_clock::now(); diff = std::chrono::duration_cast<std::chrono::microseconds>(tEnd - tStart); logTime(diff, 8); +} + + +void ArmCpuDevice::runOneBatchStride(const AttentionModuleParams& params, size_t past_seq, int batch, size_t seq_len, size_t step) { + auto datatype = params.input.type(); + auto head_num = params.configs.head_num; + auto kv_head_num = params.configs.kv_head_num; + auto size_per_head = params.configs.size_per_head; + std::chrono::steady_clock::time_point tStart, tEnd; + std::chrono::microseconds diff; + + if (datatype != DataType::TYPE_FP32) { + throw OpException(OpErrorType::ERROR_UNIMPLEMENTED); + } + + // Retrieve q/k/v by stride and not to split. 
+ auto qkv = params.input.dataWithOffset(past_seq * (head_num + 2 * kv_head_num) * size_per_head); + printBufferData(params.input, "qkv"); + + void *q_array[head_num], *k_array[head_num], *v_array[head_num]; BufferPtr k_buffer, v_buffer; void *k_in, *v_in; @@ -595,7 +748,7 @@ void ArmCpuDevice::runOneBatchStride(const AttentionModuleParams& params, size_t logTime(diff, 4); tStart = std::chrono::steady_clock::now(); - float scale = (1.0f / sqrtf(size_per_head * 1.0f)); + float scale = (1.0f / sqrtf(size_per_head * 1.0f)) * params.configs.softmax_extra_scale; BufferPtr softmax_qk_output; if (seq_len == 1) { @@ -698,8 +851,500 @@ void ArmCpuDevice::runOneBatchStride(const AttentionModuleParams& params, size_t logTime(diff, 7); /* Print profile data at the end of operator unit test. */ - // if (a_cnt_[0] == 24) - // printStat(); + //if (a_cnt_[0] == 24) + // printStat(); +} + +template<typename MaskType> +static inline void vScaleMask(float* AB, float scale, const MaskType* attnMask, int m, int k, int attnMskStride) { + for (int i = 0; i < m; i++) { + float* buf = AB + i * k; + const MaskType* mbuf = attnMask + i * attnMskStride; + + float32x4_t vscale = vdupq_n_f32(scale); + float32x4_t mask_val = vdupq_n_f32(-10000.0f); + float32x4_t ones = vdupq_n_f32(1.f); + + int j; + for (j = 0; j <= k - 4; j += 4) { + float32x4_t vin = vld1q_f32(buf + j); + float32x4_t vmask; + if constexpr (std::is_same_v<MaskType, float>) { + vmask = vld1q_f32(mbuf + j); + } else { + vmask = vcvt_f32_f16(vld1_f16(mbuf + j)); + } + vmask = vsubq_f32(ones, vmask); + vmask = vmulq_f32(vmask, mask_val); + float32x4_t vout = vmlaq_f32(vmask, vin, vscale); + vst1q_f32(buf + j, vout); + } + + for (; j < k; j++) { + buf[j] = buf[j] * scale + (1.0f - (float)mbuf[j]) * -10000.0f; + } + } +} + +static inline void vScale(float* AB, float scale, int m, int k) { + for (int i = 0; i < m; i++) { + float* buf = AB + i * k; + float32x4_t vscale = vdupq_n_f32(scale); + int j; + for (j = 0; j <= k - 4; j += 4) { + float32x4_t vx = vld1q_f32(buf + j); + vx = vmulq_f32(vscale, vx); + vst1q_f32(buf + j, vx); + } + for (; j < k; j++) { + buf[j] = buf[j] * scale; + } + } +} + +static inline void vSoftmaxTile(float* AB, float* ABout, float* sum, float* max, int m, int k) { + for (int i = 0; i < m; ++i) { + float* buf = AB + i * k; + float* obuf = ABout + i * k; + + float cur_max = vMax(k, buf); + + cur_max = std::max(cur_max, max[i]); + float merr = std::exp(max[i] - cur_max); + max[i] = cur_max; + float cur_sum = 0; + + float32x4_t vsum = vdupq_n_f32(0.0f); + float32x4_t vmax = vdupq_n_f32(cur_max); + int j; + for (j = 0; j <= k - 4; j += 4) { + float32x4_t vx = vld1q_f32(buf + j); + vx = vexpq_f32(vsubq_f32(vx, vmax)); + vst1q_f32(obuf + j, vx); + vsum = vaddq_f32(vsum, vx); + } + for (; j < k; j++) { + obuf[j] = std::exp(buf[j] - cur_max); + cur_sum += obuf[j]; + } + for (j = 0; j < 4; j++) { + cur_sum += vsum[j]; + } + + sum[i] = sum[i] * merr + cur_sum; + + float sum_mul = 1.0f / sum[i]; + float32x4_t vsum_mul = vdupq_n_f32(sum_mul); + for (j = 0; j <= k - 4; j += 4) { + float32x4_t vx = vld1q_f32(obuf + j); + vx = vmulq_f32(vx, vsum_mul); + vst1q_f32(obuf + j, vx); + } + for (; j < k; j++) { + obuf[j] *= sum_mul; + } + } +} + +static inline void vUpdateOutTile(float* output, const float* expABC, float* preSum, + float* sum, float* preMax, float* max, int m, int n, + int stride) { + for (int i = 0; i < m; ++i) { + const float* buf = expABC + i * n; + float* outbuf = output + i * stride; + float32x4_t merr = vdupq_n_f32(preMax[i] - max[i]); + merr = vexpq_f32(merr); + 
float32x4_t vfac = vdupq_n_f32(preSum[i] / sum[i]); + for (int off = 0; off < n; off += 4) { + float32x4_t vout = vld1q_f32(outbuf + off); + float32x4_t vabc = vld1q_f32(buf + off); + float32x4_t vupt = vmlaq_f32(vabc, vout, vmulq_f32(merr, vfac)); + vst1q_f32(outbuf + off, vupt); + } + preSum[i] = sum[i]; + preMax[i] = max[i]; + } +} + +static inline void vReduceSumSplitKVOutput(float* output, const float* o_split_kv, float* lse, + float* lse_logsum, int m, int n, int stride) { + for (int i = 0; i < m; ++i) { + const float* buf = o_split_kv + i * n; + float* outbuf = output + i * stride; + + float lse_scale = std::exp(lse[i] - lse_logsum[i]); + float32x4_t vscale = vdupq_n_f32(lse_scale); + for (int off = 0; off < n; off += 4) { + float32x4_t vo = vld1q_f32(outbuf + off); + float32x4_t vx = vld1q_f32(buf + off); + vo = vmlaq_f32(vo, vx, vscale); + vst1q_f32(outbuf + off, vo); + } + } +} + +template<typename MaskType> +static inline void vIncrementalTileAttention( + const float* q, const float* k, const float* v, const MaskType* mask, int q_len, + int qk_dim, int v_dim, int kv_len, int mask_stride, float* pre_sum, float* sum, float* pre_max, + float* max, float scale, float* qk, float* qkv, float* output, + int q_stride, int k_stride, int v_stride, int o_stride) { + + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, q_len, kv_len, qk_dim, 1.0, q, + q_stride, k, k_stride, 0, qk, kv_len); + + if (mask) { + vScaleMask(qk, scale, mask, q_len, kv_len, mask_stride); + } else { + vScale(qk, scale, q_len, kv_len); + } + + vSoftmaxTile(qk, qk, sum, max, q_len, kv_len); + + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, q_len, v_dim, kv_len, 1.0, qk, kv_len, v, + v_stride, 0.0, qkv, v_dim); + + vUpdateOutTile(output, qkv, pre_sum, sum, pre_max, max, q_len, v_dim, o_stride); +} + +void ArmCpuDevice::runOneBatchFlash(const AttentionModuleParams& params, size_t past_seq, int batch, size_t seq_len, size_t step) { + auto datatype = params.input.type(); + auto head_num = params.configs.head_num; + auto kv_head_num = params.configs.kv_head_num; + auto size_per_head = params.configs.size_per_head; + std::chrono::steady_clock::time_point tStart, tEnd; + std::chrono::microseconds diff; + + if (datatype != DataType::TYPE_FP32) { + throw OpException(OpErrorType::ERROR_UNIMPLEMENTED); + } + if (step != 0) { + throw std::runtime_error("flash attention only used in context prefill"); + } + + // Retrieve q/k/v by stride and not to split. 
+ auto qkv = params.input.dataWithOffset(past_seq * (head_num + 2 * kv_head_num) * size_per_head); + printBufferData(params.input, "qkv"); + + tStart = std::chrono::steady_clock::now(); + + void* k_in = (float *)qkv + head_num * size_per_head; + void* v_in = (float *)qkv + (head_num + kv_head_num) * size_per_head; + int kv_stride = (head_num + 2 * kv_head_num) * size_per_head; + + const int q_stride = (head_num + 2 * kv_head_num) * size_per_head; + const int o_stride = head_num * size_per_head; + float scale = (1.0f / sqrtf(size_per_head * 1.0f)) * params.configs.softmax_extra_scale; + float *output = (float *)params.output.dataWithOffset(past_seq * head_num * size_per_head); + + int num_group = head_num / kv_head_num; + + int q_len = seq_len; + int kv_len = seq_len; + int q_blk = std::min(128, (int)std::pow(2, int(std::log2((q_len + 1) / 2)))); + int kv_blk = std::min(256, kv_len); + + size_t num_thread = omp_get_max_threads(); + + typedef struct { + float* pre_sum; + float* sum; + float* pre_max; + float* max; + float* qk_arr; + float* exp_qkv_arr; + } Ptrs; + + // 4: pre_sum, sum, pre_max, max; kv_blk: exp_qkT; PV_i + size_t arr_stride = (4 + kv_blk + size_per_head) * q_blk; + auto workspace = allocateBuffer({DataType::TYPE_FP32, {num_thread, arr_stride}}); + Ptrs* ptrs = new Ptrs[num_thread]; + + for (int i = 0; i < num_thread; ++i) { + ptrs[i].pre_sum = (float*)workspace->dataWithOffset(i * arr_stride); + ptrs[i].sum = (float*)workspace->dataWithOffset(i * arr_stride) + q_blk; + ptrs[i].pre_max = (float*)workspace->dataWithOffset(i * arr_stride) + q_blk * 2; + ptrs[i].max = (float*)workspace->dataWithOffset(i * arr_stride) + q_blk * 3; + ptrs[i].qk_arr = (float*)workspace->dataWithOffset(i * arr_stride) + q_blk * 4; + ptrs[i].exp_qkv_arr = (float*)workspace->dataWithOffset(i * arr_stride) + q_blk * (4 + kv_blk); + } + + const int q_blk_num = (q_len + q_blk - 1) / q_blk; + const int N = head_num * q_blk_num; + + parallel_for(N, [&](int idx) { + int h = idx / q_blk_num; + int m = (idx % q_blk_num) * q_blk; + int tid = omp_get_thread_num(); + Ptrs ptr = ptrs[tid]; + int q_real_blk = std::min(q_blk, q_len - m); + uint64_t src_off = h * size_per_head; + uint64_t out_off = h * size_per_head; + const float* q_buf = (float *)qkv + src_off + m * q_stride; + float* out = output + out_off + m * o_stride; + + // reset out + float32x4_t zero = vdupq_n_f32(0.0f); + for (int ii = 0; ii < q_real_blk; ++ii) { + for (int jj = 0; jj < size_per_head; jj += 4) { + vst1q_f32(out + ii * o_stride + jj, zero); + } + } + + // reset sum +#pragma omp simd + for (int ii = 0; ii < q_real_blk; ++ii) { + ptr.pre_sum[ii] = 0; + ptr.sum[ii] = 0; + ptr.pre_max[ii] = std::numeric_limits::lowest(); + ptr.max[ii] = std::numeric_limits::lowest(); + } + + uint64_t tgt_off = (h / num_group) * size_per_head; + const float* k = (float *)k_in + tgt_off; + const float* v = (float *)v_in + tgt_off; + + for (int n = 0; n < kv_len; n += kv_blk) { + int kv_real_blk = std::min(kv_blk, kv_len - n); + + // Mask out. 
Only works for causal mask + if (params.common.attention_mask && (m + q_real_blk - 1 < n)) { + break; + } + + const float* k_blk = k + n * kv_stride; + const float* v_blk = v + n * kv_stride; + + if (params.common.attention_mask && params.common.attention_mask->type() == DataType::TYPE_FP16) { + const __fp16* mask_blk = (__fp16 *)params.common.attention_mask->dataWithOffset(batch * q_len * kv_len) + m * kv_len + n; + vIncrementalTileAttention( + q_buf, k_blk, v_blk, mask_blk, q_real_blk, size_per_head, size_per_head, kv_real_blk, + kv_len, ptr.pre_sum, ptr.sum, ptr.pre_max, ptr.max, scale, + ptr.qk_arr, ptr.exp_qkv_arr, out, q_stride, kv_stride, + kv_stride, o_stride); + } else { + const float* mask_blk; + if (!params.common.attention_mask) { + mask_blk = nullptr; + } else if (params.common.attention_mask && params.common.attention_mask->type() == DataType::TYPE_FP32) { + mask_blk = (float *)params.common.attention_mask->dataWithOffset(batch * q_len * kv_len) + m * kv_len + n; + } else { + throw OpException(OpErrorType::ERROR_UNIMPLEMENTED); + } + vIncrementalTileAttention( + q_buf, k_blk, v_blk, mask_blk, q_real_blk, size_per_head, size_per_head, kv_real_blk, + kv_len, ptr.pre_sum, ptr.sum, ptr.pre_max, ptr.max, scale, + ptr.qk_arr, ptr.exp_qkv_arr, out, q_stride, kv_stride, + kv_stride, o_stride); + } + } + }); + + delete[] ptrs; + tEnd = std::chrono::steady_clock::now(); + diff = std::chrono::duration_cast(tEnd - tStart); + logTime(diff, 4); +} + +void ArmCpuDevice::runOneBatchFlashDecoding(const AttentionModuleParams& params, size_t past_seq, int batch, size_t seq_len, size_t step) { + bool mla = params.configs.use_mla; + auto datatype = params.input.type(); + auto head_num = params.configs.head_num; + auto kv_head_num = mla ? 1 : params.configs.kv_head_num; + auto size_per_head = mla ? params.configs.kv_lora_rank + params.configs.rope_head_dim : params.configs.size_per_head; + auto v_head_dim = mla ? params.configs.kv_lora_rank : params.configs.size_per_head; + + std::chrono::steady_clock::time_point tStart, tEnd; + std::chrono::microseconds diff; + + if (datatype != DataType::TYPE_FP32) { + throw OpException(OpErrorType::ERROR_UNIMPLEMENTED); + } + if (!params.common.kv_cache.has_value()) { + throw std::runtime_error("kv cache block pointers can not be null"); + } + if (seq_len != 1) { + throw std::runtime_error("flash decoding only used in decode phase"); + } + + // Retrieve q by stride + const int input_stride = mla ? 
head_num * size_per_head : (head_num + 2 * kv_head_num) * size_per_head; + auto q = params.input.dataWithOffset(past_seq * input_stride); + + tStart = std::chrono::steady_clock::now(); + + const int q_stride = size_per_head; + const int kv_stride = kv_head_num * size_per_head; + const int o_stride = v_head_dim; + float scale = (1.0f / sqrtf(size_per_head * 1.0f)) * params.configs.softmax_extra_scale; + float *output = (float *)params.output.dataWithOffset(past_seq * head_num * v_head_dim); + + int group_size = head_num / kv_head_num; + + int q_len = group_size; + int kv_len = step + 1; + int q_blk = 8; + int kv_blk = params.configs.tokens_per_block; + + // get kv blocks address + const int kv_blk_num = (kv_len + kv_blk - 1) / kv_blk; + float** k_blk_addrs = new float*[kv_blk_num]; + float** v_blk_addrs = new float*[kv_blk_num]; + const KvCacheInfo& kv_cache_info = params.common.kv_cache.value(); + for (int i = 0; i < kv_blk_num; i++) { + getCacheAddrFromIndex(kv_cache_info, batch, i, (void**)(k_blk_addrs + i), (void**)(v_blk_addrs + i)); + } + + size_t num_thread = omp_get_max_threads(); + + // split kv_len + int kv_split_blk = kv_blk * std::min(kv_blk_num, (int)((kv_blk_num * kv_head_num + num_thread - 1) / num_thread)); + int kv_split_num = (kv_len + kv_split_blk - 1) / kv_split_blk; + + const int q_blk_num = (q_len + q_blk - 1) / q_blk; + const size_t N = q_blk_num * kv_head_num * kv_split_num; + + typedef struct { + float* pre_sum; + float* sum; + float* pre_max; + float* max; + float* qk_arr; + float* exp_qkv_arr; + float* o_arr; // output of a kv split + float* lse; // log-sum-exp + } Ptrs; + + // 5: pre_sum, sum, pre_max, max, lse; kv_blk: exp_qkT; PV_i; kv_split_O + size_t arr_stride = (5 + kv_blk + v_head_dim * 2) * q_blk; + auto workspace = allocateBuffer({DataType::TYPE_FP32, {N, arr_stride}}); + Ptrs* ptrs = new Ptrs[N]; + + for (int i = 0; i < N; ++i) { + ptrs[i].pre_sum = (float*)workspace->dataWithOffset(i * arr_stride); + ptrs[i].sum = (float*)workspace->dataWithOffset(i * arr_stride) + q_blk; + ptrs[i].pre_max = (float*)workspace->dataWithOffset(i * arr_stride) + q_blk * 2; + ptrs[i].max = (float*)workspace->dataWithOffset(i * arr_stride) + q_blk * 3; + ptrs[i].qk_arr = (float*)workspace->dataWithOffset(i * arr_stride) + q_blk * 4; + ptrs[i].exp_qkv_arr = (float*)workspace->dataWithOffset(i * arr_stride) + q_blk * (4 + kv_blk); + ptrs[i].o_arr = (float*)workspace->dataWithOffset(i * arr_stride) + q_blk * (4 + kv_blk + v_head_dim); + ptrs[i].lse = (float*)workspace->dataWithOffset(i * arr_stride) + q_blk * (4 + kv_blk + v_head_dim * 2); + } + + parallel_for(N, [&](int idx) { + Ptrs ptr = ptrs[idx]; + int q_blk_idx = idx / (kv_head_num * kv_split_num); + idx = idx % (kv_head_num * kv_split_num); + int kv_h = idx / kv_split_num; + int kv_split_idx = idx % kv_split_num; + int h = kv_h * group_size; + int q_real_blk = std::min(q_blk, q_len - q_blk_idx * q_blk); + uint64_t src_off = h * size_per_head; + const float* q_buf = (float *)q + src_off + q_blk_idx * q_blk * q_stride; + float* out = ptr.o_arr; + + // reset out + float32x4_t zero = vdupq_n_f32(0.0f); + for (int ii = 0; ii < q_real_blk; ++ii) { + for (int jj = 0; jj < v_head_dim; jj += 4) { + vst1q_f32(out + ii * o_stride + jj, zero); + } + } + + // reset sum + for (int ii = 0; ii < q_real_blk; ++ii) { + ptr.pre_sum[ii] = 0.0f; + ptr.sum[ii] = 0.0f; + ptr.pre_max[ii] = std::numeric_limits::lowest(); + ptr.max[ii] = std::numeric_limits::lowest(); + } + + for (int i = 0; i < kv_split_blk; i += kv_blk) { + int n = 
kv_split_idx * kv_split_blk + i; + int kv_real_blk = std::min(kv_blk, kv_len - n); + if (kv_real_blk <= 0) { + break; + } + + size_t block_idx = n / kv_blk; + const float* k_blk = k_blk_addrs[block_idx] + kv_h * size_per_head; + const float* v_blk = (mla ? k_blk_addrs[block_idx] : v_blk_addrs[block_idx]) + kv_h * size_per_head; + + const float* mask_blk = nullptr; + vIncrementalTileAttention( + q_buf, k_blk, v_blk, mask_blk, q_real_blk, size_per_head, v_head_dim, kv_real_blk, + kv_len, ptr.pre_sum, ptr.sum, ptr.pre_max, ptr.max, scale, + ptr.qk_arr, ptr.exp_qkv_arr, out, q_stride, kv_stride, + kv_stride, o_stride); + } + }); + + tEnd = std::chrono::steady_clock::now(); + diff = std::chrono::duration_cast(tEnd - tStart); + logTime(diff, 4); + + + tStart = std::chrono::steady_clock::now(); + + // reduce sum kv_split_O to Output + parallel_for(q_blk_num * kv_head_num, [&](int idx) { + int q_blk_idx = idx / kv_head_num; + int kv_h = idx % kv_head_num; + int h = kv_h * group_size; + uint64_t out_off = h * v_head_dim; + float* out = output + out_off + q_blk_idx * q_blk * o_stride; + int q_real_blk = std::min(q_blk, q_len - q_blk_idx * q_blk); + + // reset out + float32x4_t zero = vdupq_n_f32(0.0f); + for (int ii = 0; ii < q_real_blk; ++ii) { + for (int jj = 0; jj < v_head_dim; jj += 4) { + vst1q_f32(out + ii * o_stride + jj, zero); + } + } + + float* lse_logsum = new float[q_real_blk]; + for (int ii = 0; ii < q_real_blk; ++ii) { + float lse_sum = 0.0f; + float lse_max = std::numeric_limits::lowest(); + + for (int j = 0; j < kv_split_num; j++) { + Ptrs ptr = ptrs[idx * kv_split_num + j]; + + // lse = max + log(sum) + // lse_max = max(lse) + ptr.lse[ii] = ptr.max[ii] + std::log(ptr.sum[ii]); + lse_max = std::max(lse_max, ptr.lse[ii]); + } + + for (int j = 0; j < kv_split_num; j++) { + Ptrs ptr = ptrs[idx * kv_split_num + j]; + + // lse_sum = sum(exp(lse - lse_max)) + lse_sum += std::exp(ptr.lse[ii] - lse_max); + } + + // lse_logsum = log(lse_sum) + lse_max + lse_logsum[ii] = std::log(lse_sum) + lse_max; + } + + for (int i = 0; i < kv_split_num; i++) { + Ptrs ptr = ptrs[idx * kv_split_num + i]; + + vReduceSumSplitKVOutput(out, ptr.o_arr, ptr.lse, lse_logsum, + q_real_blk, v_head_dim, o_stride); + + } + + delete[] lse_logsum; + }); + + delete[] ptrs; + delete[] k_blk_addrs; + delete[] v_blk_addrs; + tEnd = std::chrono::steady_clock::now(); + diff = std::chrono::duration_cast(tEnd - tStart); + logTime(diff, 5); } AttentionModuleOutput ArmCpuDevice::contextAttention(const AttentionModuleParams& params) { @@ -710,7 +1355,16 @@ AttentionModuleOutput ArmCpuDevice::contextAttention(const AttentionModuleParams if (params.input.type() == DataType::TYPE_FP32) { for (int batch = 0; batch < batch_size; batch++) { size_t context_len = *static_cast(params.common.input_lengths->dataWithOffset(decoder_batch + batch)); - runOneBatchStride(params, past_seq, batch, context_len, 0); + + if (!params.configs.use_mla) { + biasAddRopeWriteKVCache(params, past_seq, batch, context_len, 0); + } + + if (isFAenabled) { + runOneBatchFlash(params, past_seq, batch, context_len, 0); + } else { + runOneBatchStride(params, past_seq, batch, context_len, 0); + } past_seq += context_len; } } else if (params.input.type() == DataType::TYPE_FP16) { @@ -731,7 +1385,16 @@ AttentionModuleOutput ArmCpuDevice::decoderSelfAttention(const AttentionModulePa if (params.input.type() == DataType::TYPE_FP32) { for (int batch = 0; batch < batch_size; batch++) { size_t step = *static_cast(params.common.sequence_lengths->dataWithOffset(batch)); - 
runOneBatchStride(params, batch, batch, 1, step); + + if (!params.configs.use_mla) { + biasAddRopeWriteKVCache(params, batch, batch, 1, step); + } + + if (isFAenabled || params.configs.use_mla) { + runOneBatchFlashDecoding(params, batch, batch, 1, step); + } else { + runOneBatchStride(params, batch, batch, 1, step); + } } } else if (params.input.type() == DataType::TYPE_FP16) { for (int batch = 0; batch < batch_size; batch++) { diff --git a/rtp_llm/cpp/devices/arm_impl/ArmDevice.cc b/rtp_llm/cpp/devices/arm_impl/ArmDevice.cc index 64331c91a..ebc0ae514 100644 --- a/rtp_llm/cpp/devices/arm_impl/ArmDevice.cc +++ b/rtp_llm/cpp/devices/arm_impl/ArmDevice.cc @@ -3,6 +3,8 @@ #include "rtp_llm/cpp/core/allocator.h" #include "rtp_llm/cpp/core/cpu_allocator.h" #include "rtp_llm/cpp/core/TrackerAllocator.h" +#include "rtp_llm/cpp/core/torch_utils/BufferTorchUtils.h" + #include #include @@ -52,6 +54,11 @@ ArmCpuDevice::ArmCpuDevice(const DeviceInitParams& params): DeviceBase(params) { gemmFunc = &ArmCpuDevice::gemm_kai_bf16; } + if (std::getenv("ARM_FA") == nullptr) { + isFAenabled = false; + } else { + isFAenabled = true; + } } ArmCpuDevice::~ArmCpuDevice() {} @@ -157,6 +164,19 @@ DevicePrepOutput ArmCpuDevice::prepareModelRun(const DevicePrepParams& params) { return output; } +SliceOutput ArmCpuDevice::slice(const SliceParams& params) { + const auto& input = params.input; + const auto& starts = params.start; + const auto& step = params.step; + auto input_t = Buffer2torchTensor(params.input, false); + auto sliceTensor = input_t.slice(params.dim, starts, params.end, step); + auto buffer_shape = torchShapeToBufferShape(sliceTensor.sizes()); + auto out = allocateBuffer({input.type(), buffer_shape}); + auto out_t = Buffer2torchTensor(out, false); + out_t.copy_(sliceTensor, false); + return out; +} + RTP_LLM_REGISTER_DEVICE(ArmCpu); } // namespace rtp_llm diff --git a/rtp_llm/cpp/devices/arm_impl/ArmDevice.h b/rtp_llm/cpp/devices/arm_impl/ArmDevice.h index 2a1a83b94..6f81109a8 100644 --- a/rtp_llm/cpp/devices/arm_impl/ArmDevice.h +++ b/rtp_llm/cpp/devices/arm_impl/ArmDevice.h @@ -23,6 +23,7 @@ class ArmCpuDevice : public DeviceBase { public: void copy(const CopyParams& params) override; LayernormOutput layernorm(const LayernormParams& params) override; + LayernormOutput layernormWithStride(const LayernormWithStrideParams& params) override; BufferPtr gemm(const GemmParams& params) override; BufferPtr gemm_acl(const GemmParams& params); BufferPtr gemm_opt(const GemmParams& params); @@ -36,11 +37,18 @@ class ArmCpuDevice : public DeviceBase { AttentionModuleOutput decoderSelfAttention(const AttentionModuleParams& params) override; GreedyOutput sampleGreedy(const GreedyParams& params) override; void sampleBeamSearch(const BeamSearchParams& params) override; + BufferPtr mlaQKVGemm(const AttentionLayerParams& params) override; + void mlaRotaryWriteKVCache(const MlaRotaryWriteKVCacheParams& params) override; + void prepareMoEGate(const FfnLayerParams& params, BufferPtr gate); + void mlaAbsorbAttention(const MlaAttentionModuleParams& params) override; + void mlaContextAttention(const MlaAttentionModuleParams& params) override; + FfnLayerOutput moeFfnLayer(const FfnLayerParams& params) override; void broadcast(const BroadcastParams& params) override; void allReduceSum(const AllReduceParams& params); DevicePrepOutput prepareModelRun(const DevicePrepParams& params) override; void printStat(); MemoryStatus getDeviceMemoryStatus() override; + SliceOutput slice(const SliceParams& params) override; #ifdef GEMM_DEBUG 
static void print_time(); #endif @@ -48,16 +56,19 @@ class ArmCpuDevice : public DeviceBase { static torch::Tensor packInt8TensorToPackedInt4(torch::Tensor weight); static torch::Tensor preprocessWeightsForMixedGemm(torch::Tensor row_major_quantized_weight, torch::ScalarType quant_type, const std::string &arch); - static torch::Tensor preprocessWeightScale(torch::Tensor weight, torch::Tensor scale); + static torch::Tensor preprocessWeightScale(torch::Tensor weight, torch::Tensor scale, const std::string& key); private: std::unique_ptr allocator_; arm_compute::DataType getAclDataType(DataType type); void runOneBatch(const AttentionModuleParams& params, size_t past_seq, int batch, size_t seq_len, size_t step); void runOneBatchStride(const AttentionModuleParams& params, size_t past_seq, int batch, size_t seq_len, size_t step); + void runOneBatchFlash(const AttentionModuleParams& params, size_t past_seq, int batch, size_t seq_len, size_t step); + void runOneBatchFlashDecoding(const AttentionModuleParams& params, size_t past_seq, int batch, size_t seq_len, size_t step); std::unordered_map> ropeCosSin; template - void halfRopeQK(void *qkv, int batch, int seq_len, int num_heads, int kv_num_heads, int head_size, size_t step, size_t base, size_t embed_dim); + void halfRopeQK(void *qkv, int batch, int seq_len, int num_heads, int kv_num_heads, int head_size, size_t step, const RopeConfig* rope_config); + void biasAddRopeWriteKVCache(const AttentionModuleParams& params, size_t past_seq, int batch, size_t seq_len, size_t step); void logTime(std::chrono::microseconds diff, size_t index); uint64_t a_cnt_[16] = {0}; uint64_t a_tmin_[16] = {999999999, 999999999, 999999999, 999999999, 999999999, 999999999, 999999999, 999999999, @@ -66,8 +77,11 @@ class ArmCpuDevice : public DeviceBase { uint64_t a_tave_[16] = {0}; GemmKernel gemm_kernel_; + FfnLayerOutput moe_ffn_a8w4(const BufferPtr expert_indices, const BufferPtr expert_weights, const BufferPtr output, const FfnLayerParams& params); + BufferPtr (ArmCpuDevice::*gemmFunc)(const GemmParams& params); bool isKAIenabled; + bool isFAenabled; #ifdef GEMM_DEBUG static TimerRecorder timer_recorder_; @@ -75,5 +89,6 @@ class ArmCpuDevice : public DeviceBase { }; extern ConstBufferPtr (*armPrepareWeightFunc)(ConstBufferPtr input, bool isTranspose, bool isForceF32Out); - +extern float32x4_t vexpq_f32(float32x4_t x); +extern float vMax(int n, const float* a); } // namespace rtp_llm diff --git a/rtp_llm/cpp/devices/arm_impl/ArmDispatch.h b/rtp_llm/cpp/devices/arm_impl/ArmDispatch.h new file mode 100644 index 000000000..4e73b1e0a --- /dev/null +++ b/rtp_llm/cpp/devices/arm_impl/ArmDispatch.h @@ -0,0 +1,176 @@ + +#pragma once + +#include "rtp_llm/cpp/utils/AssertUtils.h" +#include "rtp_llm/cpp/utils/Logger.h" +#include "rtp_llm/cpp/core/Types.h" + +#include +#include +#include +#include +#include + +namespace rtp_llm { + +template +struct FunctionTraits; + +template +struct FunctionTraits> { +public: + static const size_t nargs = sizeof...(Args); + typedef std::tuple args; +}; + +template +constexpr bool IsCastingVoidPtrToWorkTPtr = + std::is_pointer_v && + std::is_void_v> && + std::is_pointer_v && + std::is_same_v>, WorkT>; + +template +constexpr bool IsCastingFloatToWorkT = + std::is_floating_point_v && + std::is_same_v && + std::is_convertible_v; + + + +template::value, bool> = 0> +inline DstT simpleCast(SrcT src) { + return src; +} + +template) && std::is_convertible_v, bool> = 0> +inline DstT simpleCast(SrcT src) { + return (DstT)src; +} + +template) && + 
(!std::is_convertible_v) && + (std::is_pointer_v && std::is_pointer_v), bool> = 0> +inline DstT simpleCast(SrcT src) { + return reinterpret_cast(src); +} + +template, bool> = 0> +inline DstT cast(SrcT src) { + return static_cast(src); +} + +template, bool> = 0> +inline DstT cast(SrcT src) { + return (DstT)src; +} + +template) && + (!IsCastingFloatToWorkT), bool> = 0> +inline DstT cast(SrcT src) { + return simpleCast(src); +} + +template +void castTuple(std::tuple &dst, const std::tuple &src, std::index_sequence) { + int unused_expander[] = { 0, + ((void)[&] { + using SrcT = std::tuple_element_t>; + using DstT = std::tuple_element_t>; + std::get(dst) = cast(std::get(src)); + }(), 0) ... }; + (void)unused_expander; +} + +template, bool> = 0> +CastedTuple castArgs(const std::tuple& args) { + auto ret = CastedTuple(); + castTuple(ret, args, std::make_index_sequence>()); + return ret; +} + +template, bool> = 0> +CastedTuple castArgs(const std::tuple& args) { + return move(args); +} + +#define ARG_CASTED_FUNC_CALL(T, func_name, ...) { \ + using target_args_type = FunctionTraits)>>::args; \ + auto typed_args = castArgs(std::make_tuple(__VA_ARGS__)); \ + std::apply(func_name, typed_args); \ +} + +#define ARG_CASTED_FUNC_CALL_TWO_TYPE(T1, T2, func_name, ...) { \ + using target_args_type = FunctionTraits)>>::args; \ + auto typed_args = castArgs>(std::make_tuple(__VA_ARGS__)); \ + std::apply(func_name, typed_args); \ +} + +#define DISPATCH_FOR_EACH_COMPUTE_TYPE(MACRO, ...) \ + MACRO(DataType::TYPE_FP32, float, __VA_ARGS__) \ + MACRO(DataType::TYPE_FP16, __fp16, __VA_ARGS__) \ + MACRO(DataType::TYPE_BF16, hie::bfloat16, __VA_ARGS__) \ + default: \ + RTP_LLM_CHECK(false); + +#define DISPATCH_FOR_EACH_NUMERIC_TYPE(MACRO, ...) \ + MACRO(DataType::TYPE_BYTES, uint8_t, __VA_ARGS__) \ + MACRO(DataType::TYPE_INT8, int8_t, __VA_ARGS__) \ + MACRO(DataType::TYPE_INT32, int32_t, __VA_ARGS__) \ + MACRO(DataType::TYPE_INT64, int64_t, __VA_ARGS__) \ + MACRO(DataType::TYPE_UINT8, uint8_t, __VA_ARGS__) \ + MACRO(DataType::TYPE_UINT32, uint32_t, __VA_ARGS__) \ + MACRO(DataType::TYPE_UINT64, uint64_t, __VA_ARGS__) \ + DISPATCH_FOR_EACH_COMPUTE_TYPE(MACRO, __VA_ARGS__) + +#define DISPATCH_FOR_EACH_QUANT_TYPE(MACRO, T1, ...) \ + MACRO(DataType::TYPE_INT8, T1, int8_t, __VA_ARGS__) \ + ENABLE_FP8_CASE_QUANT(MACRO, T1, __VA_ARGS__) \ + default: \ + RTP_LLM_CHECK_WITH_INFO(false, "unsupport quant type"); + +#define DP_FUNCTION_CALL_CASE(data_type, T, ...) \ + case data_type: { \ + ARG_CASTED_FUNC_CALL(T, __VA_ARGS__); \ + break; \ + } + +#define DISPATCH_ARM_FUNCTION_DATA_TYPE(data_type, function, ...) \ + do { \ + switch (data_type) { \ + DISPATCH_FOR_EACH_COMPUTE_TYPE(DP_FUNCTION_CALL_CASE, function, __VA_ARGS__) \ + } \ + } while (0) + +#define DP_TWO_TYPE_FUNCTION_CALL_CASE(data_type, T1, T2, ...) \ + case data_type: { \ + ARG_CASTED_FUNC_CALL_TWO_TYPE(T1, T2, __VA_ARGS__); \ + break; \ + } + +#define GENERAL_OUTER_TYPE_CASE(dtype1, T1, dtype2, function, ...) \ + case dtype1: { \ + switch (dtype2) { \ + DISPATCH_FOR_EACH_COMPUTE_TYPE(DP_TWO_TYPE_FUNCTION_CALL_CASE, T1, function, __VA_ARGS__); \ + } \ + break; \ + } + +#define DISPATCH_ARM_FUNCTION_TWO_DATA_TYPES(dtype1, dtype2, function, ...) 
\ + switch (dtype1) { \ + GENERAL_OUTER_TYPE_CASE(DataType::TYPE_FP16, __fp16, dtype2, function, __VA_ARGS__) \ + GENERAL_OUTER_TYPE_CASE(DataType::TYPE_BF16, hie::bfloat16, dtype2, function, __VA_ARGS__) \ + GENERAL_OUTER_TYPE_CASE(DataType::TYPE_FP32, float, dtype2, function, __VA_ARGS__) \ + default: \ + RTP_LLM_CHECK(false); \ + } + +} // namespace rtp_llm diff --git a/rtp_llm/cpp/devices/arm_impl/ArmFfnLayer.cc b/rtp_llm/cpp/devices/arm_impl/ArmFfnLayer.cc new file mode 100644 index 000000000..fbc701ba8 --- /dev/null +++ b/rtp_llm/cpp/devices/arm_impl/ArmFfnLayer.cc @@ -0,0 +1,448 @@ +#include "rtp_llm/cpp/devices/arm_impl/ArmDevice.h" +#include "rtp_llm/cpp/devices/utils/DebugUtils.h" +#include "rtp_llm/cpp/core/BufferHelper.h" + +using namespace std; + +namespace rtp_llm { + +extern size_t get_rhs_packed_size(int n, int k); + +void softmax2(const float* input, float* output, int rows, int cols) { + for (int i = 0; i < rows; ++i) { + float max_val = *std::max_element(input + i * cols, input + (i + 1) * cols); + float sum = 0.0f; + for (int j = 0; j < cols; ++j) { + output[i * cols + j] = std::exp(input[i * cols + j] - max_val); + sum += output[i * cols + j]; + } + for (int j = 0; j < cols; ++j) { + output[i * cols + j] /= sum; + } + } +} + +void sigmoid(const float* input, float* output, int rows, int cols) { + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < cols; ++j) { + output[i * cols + j] = 1.0f / (1.0f + std::exp(-input[i * cols + j])); + } + } +} + +void topKSelection(const float* input, int* output, float* values, int rows, int num_expert, int k) { + for (int i = 0; i < rows; ++i) { + std::vector> scores; + for (int j = 0; j < num_expert; ++j) { + scores.emplace_back(input[i * num_expert + j], j); + } + std::partial_sort(scores.begin(), scores.begin() + k, scores.end(), + [](const std::pair& a, const std::pair& b) { + return a.first > b.first; + }); + for (int j = 0; j < k; ++j) { + output[i * k + j] = scores[j].second; + values[i * k + j] = scores[j].first; + } + } +} + +void topKSelectionNoauxTc(const float* input, int* index, float* scores, const float* bias, int rows, + int num_expert, int k, int n_group, int topk_group) { + int group_size = num_expert / n_group; + for (int i = 0; i < rows; ++i) { + std::vector> scores_for_choice; + for (int j = 0; j < num_expert; ++j) { + float score = input[i * num_expert + j] + bias[j]; + scores_for_choice.emplace_back(score, j); + } + + if (n_group > 1) { + // compute group score (sum of top 2) + std::vector> group_scores; + for (int j = 0; j < n_group; ++j) { + float max1 = -std::numeric_limits::infinity(); + float max2 = -std::numeric_limits::infinity(); + for (int g = 0; g < group_size; ++g) { + float s = scores_for_choice[j * group_size + g].first; + if (s > max1) { + max2 = max1; + max1 = s; + } else if (s > max2) { + max2 = s; + } + } + group_scores.emplace_back(max1 + max2, j); + } + // find topk_group groups + std::partial_sort(group_scores.begin(), group_scores.begin() + topk_group, group_scores.end(), + [](const std::pair& a, const std::pair& b) { + return a.first > b.first; + }); + // mask out scores_for_choice for non topk groups + for (int j = topk_group; j < n_group; ++j) { + int group_idx = group_scores[j].second; + for (int g = 0; g < group_size; ++g) { + int idx = group_idx * group_size + g; + scores_for_choice[idx].first = -std::numeric_limits::infinity(); + } + } + } + + // Sort the scores_for_choice to get the top k experts and their scores + std::partial_sort(scores_for_choice.begin(), 
scores_for_choice.begin() + k, scores_for_choice.end(), + [](const std::pair& a, const std::pair& b) { + return a.first > b.first; + }); + for (int j = 0; j < k; ++j) { + int expert_idx = scores_for_choice[j].second; + index[i * k + j] = expert_idx; + scores[i * k + j] = input[i * num_expert + expert_idx]; + } + } +} + +void accumulate_output(size_t hidden_dim, float16_t* output, const float16_t* down_output, float16_t weight) { + int h; + for (h = 0; h <= hidden_dim - 8; h += 8) { + float16x8_t out_vec = vld1q_f16(output + h); + float16x8_t down_vec = vld1q_f16(down_output + h); + out_vec = vfmaq_n_f16(out_vec, down_vec, weight); + vst1q_f16(output + h, out_vec); + } + for (; h < hidden_dim; h++) { + output[h] += weight * down_output[h]; + } +} + +void accumulate_output(size_t hidden_dim, float* output, const float* down_output, float weight) { + int h; + for (h = 0; h <= hidden_dim - 4; h += 4) { + float32x4_t out_vec = vld1q_f32(output + h); + float32x4_t down_vec = vld1q_f32(down_output + h); + out_vec = vmlaq_n_f32(out_vec, down_vec, weight); + vst1q_f32(output + h, out_vec); + } + for (; h < hidden_dim; h++) { + output[h] += weight * down_output[h]; + } +} + +extern size_t get_lhs_packed_size_kai_a8w4(int* m_array, int k, int bs, size_t* offsets); +extern void batch_pack_lhs_kai_a8w4(const float16_t* input, const size_t* input_offsets, uint8_t* output, const size_t* output_offsets, int* m_array, int k, int bs); +extern void batch_matmul_kai_a8w4(const uint8_t* input, const size_t* input_offsets, + const uint8_t* weight, const size_t* weight_offsets, + float16_t* output, const size_t* output_offsets, + int* m_array, int k, int n, size_t output_stride, + int bs); + +FfnLayerOutput ArmCpuDevice::moe_ffn_a8w4(const BufferPtr expert_indices, const BufferPtr expert_weights, const BufferPtr output, const FfnLayerParams& params) { + const auto& hidden = params.input; + const auto token_num = hidden.shape()[0]; + const auto hidden_dim = hidden.shape()[1]; + const auto num_expert = params.weights.moe_gating_weight->kernel->shape()[1]; + + const auto& moe_conf = params.configs.moe_configs.value(); + const auto top_k = moe_conf.top_k; + const size_t moe_inter_size = moe_conf.moe_inter_padding_size; + + std::vector m_array; + std::vector activated_experts; + std::vector> expert_to_tokens(num_expert); + std::vector> expert_to_weights(num_expert); + std::vector> token_to_indices(token_num); + std::vector> token_to_weights(token_num); + + // Allocate input buffer for the experts + auto input_buffer = allocateBuffer({DataType::TYPE_FP16, {token_num + token_num * top_k, hidden_dim}}); + std::vector input_offsets; + + for (int i = 0; i < token_num; ++i) { + for (int j = 0; j < top_k; ++j) { + int expert_idx = *(int*)(expert_indices->dataWithOffset(i * top_k + j)); + float weight = *(float*)(expert_weights->dataWithOffset(i * top_k + j)); + expert_to_tokens[expert_idx].push_back(i); + expert_to_weights[expert_idx].push_back((float16_t)weight); + } + } + + // hack: shared expert + m_array.push_back(token_num); + input_offsets.push_back(0); + + int idx = token_num; + for (int i = 0; i < num_expert; ++i) { + if (expert_to_tokens[i].empty()) { + continue; + } + activated_experts.push_back(i); + m_array.push_back(expert_to_tokens[i].size()); + input_offsets.push_back(idx * hidden_dim); + for (int j = 0; j < expert_to_tokens[i].size(); ++j) { + int token_idx = expert_to_tokens[i][j]; + float16_t weight = expert_to_weights[i][j]; + token_to_indices[token_idx].push_back(idx); + 
token_to_weights[token_idx].push_back(weight); + idx++; + } + } + + #pragma omp parallel for if (token_num > 1) + for (int i = 0; i < token_num; ++i) { + // shared expert + memcpy(input_buffer->dataWithOffset(i * hidden_dim), hidden.dataWithOffset(i * hidden_dim), hidden_dim * sizeof(float16_t)); + + for (int j = 0; j < top_k; ++j) { + int idx = token_to_indices[i][j]; + memcpy(input_buffer->dataWithOffset(idx * hidden_dim), hidden.dataWithOffset(i * hidden_dim), hidden_dim * sizeof(float16_t)); + } + } + + int bs = activated_experts.size(); + + // Calculate offsets for weights and outputs + auto up_gate_weight_packed_size = get_rhs_packed_size(moe_inter_size * 2, hidden_dim); + auto down_weight_packed_size = get_rhs_packed_size(hidden_dim, moe_inter_size); + std::vector up_weight_offsets(1 + bs); + std::vector gate_weight_offsets(1 + bs); + std::vector down_weight_offsets(1 + bs); + std::vector up_gate_output_offsets(1 + bs); + + // hack: shared expert + up_weight_offsets[0] = (uint8_t*)params.weights.shared_expert->up_weight->kernel->data() - (uint8_t*)params.weights.moe_gate_weight->kernel->data(); + gate_weight_offsets[0] = (uint8_t*)params.weights.shared_expert->gate_weight->kernel->data() - (uint8_t*)params.weights.moe_gate_weight->kernel->data(); + down_weight_offsets[0] = (uint8_t*)params.weights.shared_expert->down_weight->kernel->data() - (uint8_t*)params.weights.moe_down_weight->kernel->data(); + up_gate_output_offsets[0] = 0; + size_t offset = m_array[0]; + + for (int i = 0; i < bs; ++i) { + up_weight_offsets[1 + i] = activated_experts[i] * up_gate_weight_packed_size; + gate_weight_offsets[1 + i] = activated_experts[i] * up_gate_weight_packed_size + up_gate_weight_packed_size / 2; + down_weight_offsets[1 + i] = activated_experts[i] * down_weight_packed_size; + up_gate_output_offsets[1 + i] = offset * moe_inter_size; + offset += m_array[1 + i]; + } + + // up gate projection + std::vector packed_input_offsets(1 + bs); + size_t packed_size = get_lhs_packed_size_kai_a8w4(m_array.data(), hidden_dim, 1 + bs, packed_input_offsets.data()); + auto packed_input = allocateBuffer({DataType::TYPE_UINT8, {packed_size}}); + batch_pack_lhs_kai_a8w4((const float16_t*)input_buffer->data(), input_offsets.data(), (uint8_t*)packed_input->data(), packed_input_offsets.data(), m_array.data(), hidden_dim, 1 + bs); + auto up_output = allocateBuffer({DataType::TYPE_FP16, {token_num + token_num * top_k, moe_inter_size}}); + auto gate_output = allocateBuffer({DataType::TYPE_FP16, {token_num + token_num * top_k, moe_inter_size}}); + batch_matmul_kai_a8w4((uint8_t*)packed_input->data(), packed_input_offsets.data(), + (uint8_t*)params.weights.moe_gate_weight->kernel->data(), up_weight_offsets.data(), + (float16_t*)up_output->data(), up_gate_output_offsets.data(), + m_array.data(), hidden_dim, moe_inter_size, moe_inter_size * sizeof(float16_t), 1 + bs); + + batch_matmul_kai_a8w4((uint8_t*)packed_input->data(), packed_input_offsets.data(), + (uint8_t*)params.weights.moe_gate_weight->kernel->data(), gate_weight_offsets.data(), + (float16_t*)gate_output->data(), up_gate_output_offsets.data(), + m_array.data(), hidden_dim, moe_inter_size, moe_inter_size * sizeof(float16_t), 1 + bs); + + // Activation + activation({params.configs.activation_type, + up_output, + mayGetRef(params.weights.moe_gate_weight->bias), + *gate_output, std::nullopt, mayGetRef(params.weights.act_scale)}); + + // Down projection + packed_size = get_lhs_packed_size_kai_a8w4(m_array.data(), moe_inter_size, 1 + bs, packed_input_offsets.data()); + 
packed_input = allocateBuffer({DataType::TYPE_UINT8, {packed_size}}); + auto down_output = input_buffer; // reuse buffer + batch_pack_lhs_kai_a8w4((float16_t*)up_output->data(), up_gate_output_offsets.data(), (uint8_t*)packed_input->data(), packed_input_offsets.data(), m_array.data(), moe_inter_size, bs + 1); + batch_matmul_kai_a8w4((uint8_t*)packed_input->data(), packed_input_offsets.data(), + (uint8_t*)params.weights.moe_down_weight->kernel->data(), down_weight_offsets.data(), + (float16_t*)down_output->data(), input_offsets.data(), + m_array.data(), moe_inter_size, hidden_dim, hidden_dim * sizeof(float16_t), 1 + bs); + + // Accumulate output + #pragma omp parallel for if (token_num > 1) + for (int i = 0; i < token_num; ++i) { + float16_t* output_ptr = (float16_t*)output->dataWithOffset(i * hidden_dim); + + // shared expert + accumulate_output(hidden_dim, output_ptr, (float16_t*)down_output->dataWithOffset(i * hidden_dim), 1.0f); + + for (int j = 0; j < top_k; ++j) { + int idx = token_to_indices[i][j]; + float16_t w = token_to_weights[i][j]; + accumulate_output(hidden_dim, output_ptr, (float16_t*)down_output->dataWithOffset(idx * hidden_dim), w); + } + } + + return FfnLayerOutput({move(output)}); +} + +FfnLayerOutput ArmCpuDevice::moeFfnLayer(const FfnLayerParams& params) { + RUNTIME_ASSERT_OP_ARG(params.configs.moe_configs, "moe configs not set"); + + const auto& moe_conf = params.configs.moe_configs.value(); + const auto& hidden = params.input; + const auto type = hidden.type(); + const auto weight_type = params.weights.moe_down_weight->kernel->type(); + + const auto token_num = hidden.shape()[0]; + const auto hidden_dim = hidden.shape()[1]; + const size_t moe_inter_size = moe_conf.moe_inter_padding_size; + const auto num_expert = params.weights.moe_gating_weight->kernel->shape()[1]; + const auto top_k = moe_conf.top_k; + // const auto normalize_expert_scale = moe_conf.normalize_expert_scale; + + BufferPtr output = nullptr; + if (params.output) { + output = params.output; + } else { + output = allocateBuffer({type, {token_num, hidden_dim}}); + } + memset(output->data(), 0, output->sizeBytes()); + + printBufferData(*(params.weights.moe_gating_weight->kernel), "moe_gating_weight"); + auto gate_logits = gemm(GemmParams(hidden, *(params.weights.moe_gating_weight->kernel), nullopt, nullptr, DataType::TYPE_FP32)); + printBufferData(*gate_logits, "gate_logits"); + + auto expert_weights = allocateBuffer({DataType::TYPE_FP32, {token_num, top_k}, AllocationType::HOST}); + auto expert_indices = allocateBuffer({DataType::TYPE_INT32, {token_num, top_k}, AllocationType::HOST}); + + std::vector gate_probs(token_num * num_expert); + + //scoring_func 0: softmax, 1: sigmoid + if (moe_conf.scoring_func == 0) { + softmax2(gate_logits->data(), gate_probs.data(), token_num, num_expert); + } else if (moe_conf.scoring_func == 1) { + sigmoid(gate_logits->data(), gate_probs.data(), token_num, num_expert); + } else { + throw std::runtime_error("Unsupported scoring function for moe"); + } + + if (params.weights.e_score_correction_bias) { + float* bias = (float*)params.weights.e_score_correction_bias->data(); + topKSelectionNoauxTc((const float *)gate_probs.data(), (int *)expert_indices->data(), (float *)expert_weights->data(), + bias, token_num, num_expert, top_k, moe_conf.n_group, moe_conf.topk_group); + } else { + topKSelection((const float *)gate_probs.data(), (int *)expert_indices->data(), (float *)expert_weights->data(), token_num, num_expert, top_k); + } + + if (moe_conf.has_moe_norm) { + for (int s = 0; s 
< token_num; ++s) { + float sum = 0.0f; + float* weight_ptr = (float *)expert_weights->dataWithOffset(s * top_k); + for (int k = 0; k < top_k; ++k) { + sum += weight_ptr[k]; + } + for (int k = 0; k < top_k; ++k) { + weight_ptr[k] /= sum; + } + } + } + + printBufferData(*expert_weights, "expert_weights"); + printBufferData(*expert_indices, "expert_indices"); + + if (type == DataType::TYPE_FP16 && weight_type == DataType::TYPE_QFP8_E4M3) { + return moe_ffn_a8w4(expert_indices, expert_weights, output, params); + } + + std::vector> expert_to_tokens(num_expert); + std::vector> expert_to_weights(num_expert); + // Build mapping + for (int s = 0; s < token_num; ++s) { + for (int k = 0; k < top_k; ++k) { + int expert_idx = *(int*)(expert_indices->dataWithOffset(s * top_k + k)); + float weight = *(float*)(expert_weights->dataWithOffset(s * top_k + k)); + expert_to_tokens[expert_idx].push_back(s); + expert_to_weights[expert_idx].push_back(weight); + } + } + + // Process each expert + for (int expert_idx = 0; expert_idx < num_expert; ++expert_idx) { + const auto& tokens = expert_to_tokens[expert_idx]; + const auto& weights = expert_to_weights[expert_idx]; + + if (tokens.empty()) continue; + + // Gather hidden for this expert + auto input_buffer = allocateBuffer({type, {tokens.size(), hidden_dim}}, {"input"}); + for (int i = 0; i < tokens.size(); ++i) { + // copy hidden[tokens[i]] to input_buffer[i] + memcpy(input_buffer->dataWithOffset(i * hidden_dim), hidden.dataWithOffset(tokens[i] * hidden_dim), hidden_dim * hidden.typeSize()); + } + + // Up projection + auto expert_up_gate_size = params.weights.moe_gate_weight->kernel->size() / num_expert; + + std::vector expert_up_shape = {hidden_dim, moe_inter_size}; + auto up_data = params.weights.moe_gate_weight->kernel->dataWithOffset(expert_idx * expert_up_gate_size); + if (weight_type == DataType::TYPE_QFP8_E4M3) { + // TYPE_QFP8_E4M3 dataWithOffset will return data with no offset + // Manually handle the offset + auto rhs_packed_size = get_rhs_packed_size(moe_inter_size * 2, hidden_dim); + up_data = (char*)params.weights.moe_gate_weight->kernel->data() + expert_idx * rhs_packed_size; + } + auto moe_up_buffer = BufferPtr(new Buffer(MemoryType::MEMORY_CPU, + weight_type, + expert_up_shape, + up_data)); + + auto up_output = gemm(GemmParams(*input_buffer, *moe_up_buffer)); + + // Gate projection + auto expert_up_size = params.weights.moe_gate_weight->kernel->size() / num_expert / 2 ; + auto gate_data = params.weights.moe_gate_weight->kernel->dataWithOffset(expert_idx * expert_up_gate_size + expert_up_size); + if (weight_type == DataType::TYPE_QFP8_E4M3) { + // TYPE_QFP8_E4M3 dataWithOffset will return data with no offset + // Manually handle the offset + auto rhs_packed_size = get_rhs_packed_size(moe_inter_size * 2, hidden_dim); + gate_data = (char*)params.weights.moe_gate_weight->kernel->data() + expert_idx * rhs_packed_size + rhs_packed_size / 2; + } + auto moe_gate_buffer = BufferPtr(new Buffer(MemoryType::MEMORY_CPU, + weight_type, + expert_up_shape, + gate_data)); + + auto gate_output = gemm(GemmParams(*input_buffer, *moe_gate_buffer)); + + // Activation + activation({params.configs.activation_type, + up_output, + mayGetRef(params.weights.moe_gate_weight->bias), + *gate_output, std::nullopt, mayGetRef(params.weights.act_scale)}); + + // Down projection + auto expert_down_size = params.weights.moe_down_weight->kernel->size() / num_expert; + std::vector expert_down_shape = {moe_inter_size, hidden_dim}; + auto down_data = 
params.weights.moe_down_weight->kernel->dataWithOffset(expert_idx * expert_down_size); + if (weight_type == DataType::TYPE_QFP8_E4M3) { + // TYPE_QFP8_E4M3 dataWithOffset will return data with no offset + // Manually handle the offset + auto rhs_packed_size = get_rhs_packed_size(hidden_dim, moe_inter_size); + down_data = (char*)params.weights.moe_down_weight->kernel->data() + expert_idx * rhs_packed_size; + } + auto moe_down_buffer = BufferPtr(new Buffer(MemoryType::MEMORY_CPU, + weight_type, + expert_down_shape, + down_data)); + + auto down_output = gemm(GemmParams(*up_output, *moe_down_buffer)); + + // Scatter back with weight + for (int i = 0; i < tokens.size(); ++i) { + int token_idx = tokens[i]; + float w = weights[i]; + // accumulate output[token_idx] += w * down_output[i] + if (type == DataType::TYPE_FP32) { + accumulate_output(hidden_dim, (float*)output->dataWithOffset(token_idx * hidden_dim), (float*)down_output->dataWithOffset(i * hidden_dim), w); + } else if (type == DataType::TYPE_FP16) { + accumulate_output(hidden_dim, (float16_t*)output->dataWithOffset(token_idx * hidden_dim), (float16_t*)down_output->dataWithOffset(i * hidden_dim), (float16_t)w); + } else { + throw std::runtime_error("Unsupported data type"); + } + } + } + printBufferData(*output, "moe_ffn_out"); + + return FfnLayerOutput({move(output)}); +} + +} // namespace rtp_llm diff --git a/rtp_llm/cpp/devices/arm_impl/ArmGemmKaiOp.cc b/rtp_llm/cpp/devices/arm_impl/ArmGemmKaiOp.cc index dc6f787a8..cee53ba64 100644 --- a/rtp_llm/cpp/devices/arm_impl/ArmGemmKaiOp.cc +++ b/rtp_llm/cpp/devices/arm_impl/ArmGemmKaiOp.cc @@ -296,11 +296,139 @@ BufferPtr ArmCpuDevice::gemm_kai_bf16(const GemmParams& params) { return output; } +inline size_t get_a8w4_variant(size_t m) { + return (m == 1) ? 0 : 1; +} + +size_t get_lhs_packed_size_kai_a8w4(int* m_array, int k, int bs, size_t* offsets) { + size_t packed_size = 0; + const size_t bl = 32; + for (int i = 0; i < bs; i++) { + int m = m_array[i]; + size_t idx_variant = get_a8w4_variant(m); + size_t mr = fp16_ukernel_variants[idx_variant].ukernel.get_mr(); + size_t kr = fp16_ukernel_variants[idx_variant].ukernel.get_kr(); + size_t sr = fp16_ukernel_variants[idx_variant].ukernel.get_sr(); + + offsets[i] = packed_size; + packed_size += kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f16(m, k, bl, mr, kr, sr); + } + return packed_size; +} + +// batch pack lhs with different m and same k +void batch_pack_lhs_kai_a8w4(const float16_t* input, const size_t* input_offsets, uint8_t* output, const size_t* output_offsets, int* m_array, int k, int bs) { + const size_t bl = 32; + const int m_step = 128; + const int max_m = m_array[0]; + #pragma omp parallel for collapse(2) schedule(dynamic, 1) + for (int b = 0; b < bs; b++) { + for (int m_start = 0; m_start < max_m; m_start += m_step) { + int m = m_array[b]; + if (m_start >= m) { + continue; + } + size_t idx_variant = get_a8w4_variant(m); + size_t mr = fp16_ukernel_variants[idx_variant].ukernel.get_mr(); + size_t kr = fp16_ukernel_variants[idx_variant].ukernel.get_kr(); + size_t sr = fp16_ukernel_variants[idx_variant].ukernel.get_sr(); + + const float16_t* input_ptr = input + input_offsets[b]; + uint8_t* output_ptr = output + output_offsets[b]; + + const size_t lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f16(m_start, k, bl, mr, kr, sr); + int tile_m = (m_start + m_step <= m) ? 
m_step : m - m_start; + kai_run_lhs_quant_pack_qsi8d32p_f16( + tile_m, k, bl, mr, kr, sr, 0, + input_ptr + m_start * k, + k * sizeof(float16_t), + output_ptr + lhs_packed_offset); + } + } +} + +// batch matmul with same shape packed weights +// input: qsi8d32p weight: qsi4c32p output: fp16 +void batch_matmul_kai_a8w4(const uint8_t* input, const size_t* input_offsets, + const uint8_t* weight, const size_t* weight_offsets, + float16_t* output, const size_t* output_offsets, + int* m_array, int k, int n, size_t output_stride, + int bs) { + const size_t bl = 32; + const int n_step = 256; + #pragma omp parallel for collapse(2) schedule(dynamic, 1) + for (int b = 0; b < bs; ++b) { + for (int n_start = 0; n_start < n; n_start += n_step) { + int m = m_array[b]; + size_t idx_variant = get_a8w4_variant(m); + + const size_t rhs_offset = fp16_ukernel_variants[idx_variant].ukernel.get_rhs_packed_offset(n_start, k, bl); + const size_t dst_offset = fp16_ukernel_variants[idx_variant].ukernel.get_dst_offset(0, n_start, output_stride); + + const void* lhs_ptr = input + input_offsets[b]; + const void* rhs_ptr = weight + weight_offsets[b] + rhs_offset; + float16_t* dst_ptr = output + output_offsets[b] + dst_offset / sizeof(float16_t); + + int tile_n = (n_start + n_step <= n) ? n_step : n - n_start; + + fp16_ukernel_variants[idx_variant].ukernel.run_matmul( + m, tile_n, k, bl, // Dimensions + lhs_ptr, // LHS packed + rhs_ptr, // RHS packed + dst_ptr, // DST + output_stride, // DST stride (row) + sizeof(float16_t), // DST stride (col) + -HALF_FLT_MAX, HALF_FLT_MAX // Min and max for the clamp operation + ); + } + } +} + +// batch matmul with same shape packed weights +// input: fp32 weight: bf16 output: fp32 +void batch_matmul_kai_bf16(const float* input, size_t input_batch_stride, size_t input_row_stride, + const bfloat16_t* weight, + float* output, size_t output_batch_stride, size_t output_row_stride, + int m, int k, int n, int bs) { + const size_t mr = kai_get_mr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(); + const size_t kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(); + const size_t sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(); + + const size_t lhs_packed_size = kai_get_lhs_packed_size_lhs_quant_pack_bf16p8x4_f32_neon(m, k, mr, kr, sr); + const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf32_f16_neon(n, k); + + #pragma omp parallel for + for (int i = 0; i < bs; i++) { + uint8_t* lhs_packed = new uint8_t[lhs_packed_size]; + const float* lhs_ptr = input + i * input_batch_stride; + const bfloat16_t* rhs_ptr = weight + i * rhs_packed_size / sizeof(bfloat16_t); + float* output_ptr = output + i * output_batch_stride; + kai_run_lhs_quant_pack_bf16p8x4_f32_neon( + m, k, mr, kr, sr, 0, + lhs_ptr, + input_row_stride * sizeof(float), + lhs_packed); + kai_run_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla( + m, n, k, // Dimensions + lhs_packed, // LHS + rhs_ptr, // RHS packed + output_ptr, // DST + output_row_stride * sizeof(float), // DST stride (row) + sizeof(float), // DST stride (col) + -FLT_MAX, FLT_MAX // Min and max for the clamp operation + ); + delete[] lhs_packed; + } +} + BufferPtr ArmCpuDevice::gemm_kai_a8w4(const GemmParams& params) { #ifdef GEMM_DEBUG auto start = std::chrono::high_resolution_clock::now(); #endif - params.check(); + + if (params.B.type() != DataType::TYPE_QFP8_E4M3) { + params.check(); + } std::vector Ashape; std::vector Bshape; diff --git a/rtp_llm/cpp/devices/arm_impl/ArmGemmOptOp.cc 
b/rtp_llm/cpp/devices/arm_impl/ArmGemmOptOp.cc index 5b925caa2..c99298130 100644 --- a/rtp_llm/cpp/devices/arm_impl/ArmGemmOptOp.cc +++ b/rtp_llm/cpp/devices/arm_impl/ArmGemmOptOp.cc @@ -17,7 +17,7 @@ namespace rtp_llm { /// B [b, ..., k, n] /// C [b, ..., m, n] BufferPtr ArmCpuDevice::gemm(const GemmParams& params) { - if (params.B.type() == DataType::TYPE_QINT4X2) + if (params.B.type() == DataType::TYPE_QINT4X2 || params.B.type() == DataType::TYPE_QFP8_E4M3) return gemm_kai_a8w4(params); return (this->*gemmFunc)(params); } diff --git a/rtp_llm/cpp/devices/arm_impl/ArmLayerNormOp.cc b/rtp_llm/cpp/devices/arm_impl/ArmLayerNormOp.cc index 106a4bf71..b85e0084f 100644 --- a/rtp_llm/cpp/devices/arm_impl/ArmLayerNormOp.cc +++ b/rtp_llm/cpp/devices/arm_impl/ArmLayerNormOp.cc @@ -3,6 +3,8 @@ #include "rtp_llm/cpp/core/allocator.h" #include "rtp_llm/cpp/core/cpu_allocator.h" #include "rtp_llm/cpp/devices/utils/DebugUtils.h" +#include "rtp_llm/cpp/devices/arm_impl/ArmDispatch.h" + #include #include #include //std::all_of @@ -1002,12 +1004,12 @@ LayernormOutput ArmCpuDevice::layernorm(const LayernormParams& params) { int m = input->shape()[0]; int n = input->shape()[1]; const auto data_type = input->type(); - if (!params.is_inplace && params.qscheme == QScheme::NoQuantize) { + if (!params.is_inplace && (params.qscheme == QScheme::NoQuantize || params.qscheme == QScheme::Qfp8PerTokenBlock)) { norm_output = allocateBufferLike(*params.input); } else if (params.qscheme == Qint8PerToken) { throw OpException(OpErrorType::ERROR_UNIMPLEMENTED); } - + int convert_gamma = 0; int convert_beta = 0; int convert_bias = 0; @@ -1463,4 +1465,178 @@ LayernormOutput ArmCpuDevice::layernorm(const LayernormParams& params) { else throw OpException(OpErrorType::ERROR_UNIMPLEMENTED); } +template +inline void layernorm(T* out, + const T* in, + const T* gamma, + const T* beta, + int norm_size, + float eps) { + float mean = 0.0f; + float var = 0.0f; + + for (int i = 0; i < norm_size; ++i) { + mean += static_cast(in[i]); + } + mean /= norm_size; + + for (int i = 0; i < norm_size; ++i) { + float diff = static_cast(in[i]) - mean; + var += diff * diff; + } + var /= norm_size; + float inv_std = 1.0f / std::sqrt(var + eps); + + for (int i = 0; i < norm_size; ++i) { + float x = (static_cast(in[i]) - mean) * inv_std; + float g = static_cast(gamma[i]); + float b = beta ? 
static_cast(beta[i]) : 0.0f; + out[i] = static_cast(x * g + b); + } +} + +template +inline void rmsnorm(T* out, + const T* in, + const GammaT* gamma, + const T* beta, + int norm_size, + float eps, + bool is_beta) { + float sum_sq = 0.0f; + + for (int i = 0; i < norm_size; ++i) { + float val = static_cast(in[i]); + sum_sq += val * val; + } + + float scale = 1.0f / std::sqrt(sum_sq / norm_size + eps); + for (int i = 0; i < norm_size; ++i) { + float val = static_cast(in[i]) * scale * static_cast(gamma[i]); + if (is_beta) + val += static_cast(beta[i]); + + out[i] = static_cast(val); + } +} + +template +void invokeLayerNormWithStride(T* output, + int out_stride, + const T* input, + int in_stride, + const T* gamma, + const T* beta, + float eps, + int m, + int n, + int norm_size) { + assert(n % norm_size == 0); + const int num_groups = n / norm_size; + + for (int row = 0; row < m; ++row) { + for (int group = 0; group < num_groups; ++group) { + int in_offset = row * in_stride + group * norm_size; + int out_offset = row * out_stride + group * norm_size; + + layernorm(&output[out_offset], + &input[in_offset], + gamma, + beta, + norm_size, + eps); + } + } +} + +template +void invokeRmsNormWithStride(T* output, + int out_stride, + const T* input, + int in_stride, + const GammaT* gamma, + const T* beta, + float eps, + int m, + int n, + int norm_size) { + assert(n % norm_size == 0); + const int num_groups = n / norm_size; + + for (int row = 0; row < m; ++row) { + for (int group = 0; group < num_groups; ++group) { + int in_offset = row * in_stride + group * norm_size; + int out_offset = row * out_stride + group * norm_size; + + rmsnorm(&output[out_offset], + &input[in_offset], + gamma, + beta, + norm_size, + eps, + beta != nullptr); + } + } +} + +LayernormOutput ArmCpuDevice::layernormWithStride(const LayernormWithStrideParams& params) { + RTP_LLM_CHECK_WITH_INFO(params.qscheme == QScheme::NoQuantize, "qscheme must be NoQuantize in layernormWithStride"); + const auto data_type = params.input->type(); + const auto m = params.input->shape()[0]; + const auto in_stride = params.input->shape()[1]; + const auto norm_weight = params.norm_weight; + const auto& gamma = norm_weight ? norm_weight->get().gamma.get()->data() : nullptr; + const auto& beta = (norm_weight && norm_weight->get().beta) ? norm_weight->get().beta.get()->data() : nullptr; + const auto eps = params.eps; + + auto gamma_type = gamma ? 
norm_weight->get().gamma->type() : DataType::TYPE_FP32; + + int out_stride; + int out_offset; + BufferPtr norm_output; + + // if not in_place, we hope that the output is contiguous + if (params.in_place) { + norm_output = params.input; + out_stride = in_stride; + out_offset = params.offset; + } else { + norm_output = allocateBuffer({data_type, {m, params.norm_group_size}, AllocationType::DEVICE}, {"norm_with_stride_output"}); + out_stride = params.norm_group_size; + out_offset = 0; + } + + if (params.norm_type == NormType::layernorm) { + DISPATCH_ARM_FUNCTION_DATA_TYPE(data_type, + invokeLayerNormWithStride, + norm_output->dataWithOffset(out_offset), + out_stride, + params.input->dataWithOffset(params.offset), + in_stride, + gamma, + beta, + eps, + m, + params.norm_group_size, + norm_weight->get().gamma.get()->shape()[0]); + return LayernormOutput({norm_output, nullptr}); + } else if (params.norm_type == NormType::rmsnorm) { + DISPATCH_ARM_FUNCTION_TWO_DATA_TYPES(gamma_type, data_type, + invokeRmsNormWithStride, + norm_output->dataWithOffset(out_offset), + out_stride, + params.input->dataWithOffset(params.offset), + in_stride, + gamma, + beta, + eps, + m, + params.norm_group_size, + norm_weight->get().gamma.get()->shape()[0]); + return LayernormOutput({norm_output, nullptr}); + } else { + throw std::runtime_error(autil::StringUtil::formatString("unsupported layernorm type for layernormWithStride: %d", int(params.norm_type))); + } +} + } diff --git a/rtp_llm/cpp/devices/arm_impl/ArmMlaAttentionOp.cc b/rtp_llm/cpp/devices/arm_impl/ArmMlaAttentionOp.cc new file mode 100644 index 000000000..9a52cf89d --- /dev/null +++ b/rtp_llm/cpp/devices/arm_impl/ArmMlaAttentionOp.cc @@ -0,0 +1,361 @@ +#include "rtp_llm/cpp/devices/utils/DebugUtils.h" +#include "rtp_llm/cpp/devices/OpData.h" +#include "rtp_llm/cpp/devices/arm_impl/ArmDevice.h" +#include "rtp_llm/cpp/devices/arm_impl/ArmDispatch.h" +#include "rtp_llm/cpp/devices/utils/DevicePerfWrapper.h" +#include + +using namespace std; +using namespace rtp_llm; + +namespace rtp_llm { + +extern void batch_matmul_kai_bf16(const float* input, size_t input_batch_stride, size_t input_row_stride, const bfloat16_t* weight, + float* output, size_t output_batch_stride, size_t output_row_stride, int m, int k, int n, int bs); + +void InputGemmWrapper(const Buffer& q, Buffer& fused_q_input, const MlaAttentionModuleParams& params) { + if (params.weights.kc_weight->kernel->type() != DataType::TYPE_BF16) { + throw std::runtime_error("InputGemmWrapper weight data type is not BF16"); + } + + // q: [token_num, head_num, nope_head_dim + rope_head_dim] + // kc_weight: [head_num, nope_head_dim, ckv_dim] + // fused_q_input: [token_num, head_num, ckv_dim + rope_dim] + + int token_num = q.shape()[0]; + int head_num = params.configs.head_num; + int nope_head_dim = params.configs.nope_head_dim; + int rope_head_dim = params.configs.rope_head_dim; + int ckv_dim = params.configs.kv_lora_rank; + + batch_matmul_kai_bf16( + (float*)q.data(), nope_head_dim + rope_head_dim, head_num * (nope_head_dim + rope_head_dim), + (bfloat16_t*)params.weights.kc_weight->kernel->data(), + (float*)fused_q_input.data(), ckv_dim + rope_head_dim, head_num * (ckv_dim + rope_head_dim), + token_num, nope_head_dim, ckv_dim, head_num); +} + + +void OutputGemmWrapper(const Buffer& attn_out, Buffer& qkv_output, const MlaAttentionModuleParams& params) { + if (params.weights.vc_weight->kernel->type() != DataType::TYPE_BF16) { + throw std::runtime_error("OutputGemmWrapper weight data type is not BF16"); + } + + // 
attn_out: [token_num, head_num, ckv_dim] + // vc_weight: [head_num, ckv_dim, v_head_dim] + // qkv_output_t: [token_num, head_num, v_head_dim] + + int token_num = params.qkv_output->shape()[0]; + int head_num = params.configs.head_num; + int ckv_dim = params.configs.kv_lora_rank; + int v_head_dim = params.configs.v_head_dim; + + batch_matmul_kai_bf16( + (float*)attn_out.data(), ckv_dim, head_num * ckv_dim, + (bfloat16_t*)params.weights.vc_weight->kernel->data(), + (float*)qkv_output.data(), v_head_dim, head_num * v_head_dim, + token_num, ckv_dim, v_head_dim, head_num); +} + +void ArmCpuDevice::mlaAbsorbAttention(const MlaAttentionModuleParams& params) { + DevicePerfWrapper wrapper(this, "mlaDecoder_layer_%d", params.layer_id); + + auto fused_q_input = allocateBuffer({params.q.type(), {params.q.shape()[0], params.configs.head_num, params.configs.kv_lora_rank + params.configs.rope_head_dim}, AllocationType::DEVICE}); + + mlaRotaryWriteKVCache({params.q, + fused_q_input, + params.fused_qkv, + params.kv_offset, + params.is_prefill ? params.common.prefill_flash_infer_attn : params.common.decode_flash_infer_attn, + params.common, + params.weights, + params.configs, + params.qscheme, + params.is_prefill}); + + if (params.is_prefill) { + writeCacheStore(params); + } + + computeInsertedMoE(); + + InputGemmWrapper(params.q, *fused_q_input, params); + printBufferData(*fused_q_input, "fused_q_input"); + + auto datatype = params.q.type(); + BufferPtr attn_out; + + attn_out = allocateBuffer({datatype, {params.q.shape()[0], params.configs.head_num, params.configs.kv_lora_rank}, AllocationType::DEVICE}); + + if (params.q.type() != DataType::TYPE_FP32) { + throw std::runtime_error("mla absorb attention data type is not supported"); + } + + AttentionModuleParams attn_params = AttentionModuleParams( + {params.layer_id, *fused_q_input, *attn_out, params.common, params.weights, params.configs, params.qscheme}); + decoderSelfAttention(attn_params); + + OutputGemmWrapper(*attn_out, *params.qkv_output, params); +} + +template +void invokeMlaQKVMerge(T* q, + T* k_nope, + T* k_rope, + T* v, + float* qkv, + int token_num, + int head_num, + int nope_head_dim, + int rope_head_dim, + int v_head_dim) { + int nope_rope_dim = nope_head_dim + rope_head_dim; + int hidden_size = head_num * nope_rope_dim; + + parallel_for(token_num, [&](int bs_idx) { + for (int head_idx = 0; head_idx < head_num; ++head_idx) { + int q_offset = bs_idx * head_num * nope_rope_dim + head_idx * nope_rope_dim; + int k_nope_offset = bs_idx * head_num * nope_head_dim + head_idx * nope_head_dim; + int k_rope_offset = bs_idx * rope_head_dim; // broadcast to head_num + int v_offset = bs_idx * head_num * v_head_dim + head_idx * v_head_dim; + int dst_base_offset = bs_idx * 3 * hidden_size + head_idx * nope_rope_dim; + + memcpy(qkv + dst_base_offset, q + q_offset, nope_rope_dim * sizeof(T)); + + memcpy(qkv + dst_base_offset + hidden_size, k_nope + k_nope_offset, nope_head_dim * sizeof(T)); + memcpy(qkv + dst_base_offset + hidden_size + nope_head_dim, k_rope + k_rope_offset, rope_head_dim * sizeof(T)); + + memcpy(qkv + dst_base_offset + 2 * hidden_size, v + v_offset, v_head_dim * sizeof(T)); + memset(qkv + dst_base_offset + 2 * hidden_size + v_head_dim, 0, (nope_rope_dim - v_head_dim) * sizeof(T)); + } + }); +} + +AttentionModuleOutput ArmCpuDevice::mlaContextAttention(const MlaAttentionModuleParams& params) { + DevicePerfWrapper wrapper(this, "mlaContext_layer_%d", params.layer_id); + auto& q = params.q; + auto& fused_qkv = params.fused_qkv; + + auto const 
token_num = q.shape()[0]; + auto const head_num = params.configs.head_num; + auto const nope_head_dim = params.configs.nope_head_dim; + auto const rope_head_dim = params.configs.rope_head_dim; + auto const v_head_dim = params.configs.v_head_dim; + auto const nope_rope_dim = nope_head_dim + rope_head_dim; + auto const size_per_head = params.configs.size_per_head; + mlaRotaryWriteKVCache({q, + nullptr, + fused_qkv, + params.kv_offset, + params.common.prefill_flash_infer_attn, + params.common, + params.weights, + params.configs, + params.qscheme, + params.is_prefill}); + writeCacheStore(params); + + computeInsertedMoE(); + + auto split_result = split({fused_qkv, {(size_t)params.kv_offset, (size_t)params.configs.kv_lora_rank, (size_t)params.configs.rope_head_dim}, 1}); + auto kv_a = split_result.outputs[1]; + auto k_rope = split_result.outputs[2]; + printBufferData(q, "q_after_rope"); + printBufferData(*kv_a, "kv_a"); + printBufferData(*k_rope, "k_rope"); + + auto datatype = fused_qkv.type(); + auto qkv = + allocateBuffer({datatype, {token_num, head_num * nope_rope_dim * 3}, AllocationType::DEVICE}, {"mla_qkv"}); + + auto k_nope = gemm(GemmParams(*kv_a, *(params.weights.k_nope_weight->kernel))); + auto v = gemm(GemmParams(*kv_a, *(params.weights.v_weight->kernel))); + + printBufferData(*k_nope, "k_nope"); + printBufferData(*v, "v"); + + DISPATCH_ARM_FUNCTION_DATA_TYPE(datatype, + invokeMlaQKVMerge, + q.data(), + k_nope->data(), + k_rope->data(), + v->data(), + qkv->data(), + token_num, + head_num, + nope_head_dim, + rope_head_dim, + v_head_dim); + + printBufferData(*qkv, "mla_qkv"); + auto padded_qkv_output_t = allocateBuffer({datatype, {token_num, head_num * size_per_head}, AllocationType::DEVICE}, {"padded_qkv_output"}); + AttentionModuleParams attn_params = AttentionModuleParams( + {params.layer_id, *qkv, *padded_qkv_output_t, params.common, params.weights, params.configs, params.qscheme}); + // only paged fmha use kv_block_array, mld not use paged fmha + contextAttention(attn_params); + + auto qkv_output_reshaped = padded_qkv_output_t->reshape({token_num, params.configs.head_num, size_per_head}); + auto sliced_buffer = slice({qkv_output_reshaped, -1, 0, (int64_t)v_head_dim}); + copy({*params.qkv_output, *sliced_buffer}); +} + +template +void ApplyRope(int token_num, int *position_ids, int num_heads, T *src, int src_stride, int src_head_size, int src_head_offset, T *dst, int dst_stride, int dst_head_size, int dst_head_offset, int rope_dim, T* rope_cos_sin_cache) { + if (src_head_size != src_head_offset + rope_dim) { + throw std::runtime_error("Rope src wrong dim"); + } + + if (dst_head_size != dst_head_offset + rope_dim) { + throw std::runtime_error("Rope dst wrong dim"); + } + + size_t inv_freq_size = (rope_dim + 1) / 2; + + auto rope = [&](int i) { + T* input = src + i * src_stride + src_head_offset; + T* output = dst + i * dst_stride + dst_head_offset; + + int pos = position_ids[i]; + + for (int h = 0; h < num_heads; h++) { + for (int d = 0; d < inv_freq_size; d++) { + float fcr = rope_cos_sin_cache[pos * rope_dim + d]; + float fci = rope_cos_sin_cache[pos * rope_dim + inv_freq_size + d]; + float x = input[h * src_head_size + d]; + float y = input[h * src_head_size + d + inv_freq_size]; + output[h * dst_head_size + d] = x * fcr - y * fci; + output[h * dst_head_size + d + inv_freq_size] = x * fci + y * fcr; + } + } + }; + + if (token_num == 1) { // fast path for decode + rope(0); + } else { + parallel_for(token_num, rope); + } +} + +// apply rope to a [token_num, num_head, head_size] 
buffer +template +void ApplyRope(int token_num, int *position_ids, int num_heads, T *buffer, int stride, int head_size, int head_offset, int rope_dim, T* rope_cos_sin_cache) { + ApplyRope(token_num, position_ids, num_heads, buffer, stride, head_size, head_offset, buffer, stride, head_size, head_offset, rope_dim, rope_cos_sin_cache); +} + +void getCacheAddrFromIndex(const KvCacheInfo& kv_cache, size_t batch, size_t block_idx, void **k_addr) { + const auto& kv_blocks_offset = *(kv_cache.kv_cache_block_id); + const auto& k_cache = *(kv_cache.k_cache_buffer); + const auto max_blocks_per_batch = kv_blocks_offset.shape()[1]; + size_t block_size = k_cache[0].sizeBytes(); + int *index = (int *)kv_blocks_offset.data(); + + *k_addr = (char*)k_cache.data() + index[batch * max_blocks_per_batch + block_idx] * block_size; +} + +// update KV cache, from fused_qkv to k_cache_buffer, [token_num, dim] +// dim = ckv_dim + rope_dim +template +void updateKVCacheStride(const MlaRotaryWriteKVCacheParams& params, T* fused_qkv, int stride, int dim, int batch, size_t seq_len, size_t step) { + auto block_tokens = params.configs.tokens_per_block; + size_t block_offset = step / block_tokens; + void *k_block_addr; + + // fast path for decode + if (seq_len == 1) { + getCacheAddrFromIndex(params.common.kv_cache.value(), batch, block_offset, &k_block_addr); + memcpy((T*)k_block_addr + step % block_tokens * dim, fused_qkv, dim * sizeof(T)); + return; + } + + size_t block_num = (seq_len + block_tokens - 1) / block_tokens; + size_t copied_len = 0; + + for (int i = 0; i < block_num; i++) { + size_t len = std::min(block_tokens, seq_len - copied_len); + getCacheAddrFromIndex(params.common.kv_cache.value(), batch, i + block_offset, &k_block_addr); + + T* input = fused_qkv + (i * block_tokens) * stride; + parallel_for(len, [&](int tid) { + memcpy((T*)k_block_addr + (step % block_tokens + tid) * dim, input + tid * stride, dim * sizeof(T)); + }); + + copied_len += len; + } +} + +void ArmCpuDevice::mlaRotaryWriteKVCache(const MlaRotaryWriteKVCacheParams& params) { + DevicePerfWrapper wrapper(this, "mlaRotaryWriteKVCache"); + float* q = (float*)params.q.data(); + float* fused_qkv = (float*)params.fused_qkv.data(); + float* rope_cos_sin_cache = (float*)params.weights.rope_cos_sin_cache->data(); + + // fused_qkv: q/cq + ckv + k_rope + // [token_num, kv_offset + ckv_dim + rope_dim] + // q: [token_num, head_num, nope_dim + rope_dim] + // fuse_dest_q: [token_num, head_num, ckv_dim + rope_dim] + + // rope_cos_sin_cache: [position_id, rope_dim] + + int batch = params.is_prefill ? 
params.common.context_batch_size : params.common.decoder_batch_size; + int token_num = params.q.shape()[0]; + + auto position_buf = allocateBuffer({DataType::TYPE_INT32, {(size_t)token_num}}); + int* position_ids = (int*)position_buf->data(); + if (params.is_prefill) { + int *inp_len = (int*)(params.common.input_lengths->dataWithOffset(params.common.decoder_batch_size)); + int offset = 0; + for (int b = 0; b < batch; b++) { + for (int i = 0; i < inp_len[b]; i++) { + position_ids[offset + i] = i; + } + offset += inp_len[b]; + } + } else { + int *seq_len = (int*)(params.common.sequence_lengths->data()); + for (int b = 0; b < batch; b++) { + position_ids[b] = seq_len[b]; + } + } + + int head_num = params.configs.head_num; + int nope_dim = params.configs.nope_head_dim; + int rope_dim = params.configs.rope_head_dim; + int ckv_dim = params.configs.kv_lora_rank; + int kv_offset = params.kv_offset; + // Q rope + if (params.fused_dest_q) { + float* dst_q = (float*)params.fused_dest_q->data(); + ApplyRope(token_num, position_ids, head_num, + q, head_num * (nope_dim + rope_dim), nope_dim + rope_dim, nope_dim, + dst_q, head_num * (ckv_dim + rope_dim), ckv_dim + rope_dim, ckv_dim, + rope_dim, rope_cos_sin_cache); + } else { + ApplyRope(token_num, position_ids, head_num, + q, head_num * (nope_dim + rope_dim), nope_dim + rope_dim, nope_dim, + rope_dim, rope_cos_sin_cache); + } + // K rope + ApplyRope(token_num, position_ids, 1, + fused_qkv + kv_offset + ckv_dim, kv_offset + ckv_dim + rope_dim, rope_dim, 0, + rope_dim, rope_cos_sin_cache); + + if (params.common.kv_cache.has_value()) { + int offset = 0; + for (int b = 0; b < batch; b++) { + int seq_len, step; + if (params.is_prefill) { + int *inp_len = (int*)(params.common.input_lengths->dataWithOffset(params.common.decoder_batch_size)); + step = 0; + seq_len = inp_len[b]; + } else { + int *steps = (int*)(params.common.sequence_lengths->data()); + seq_len = 1; + step = steps[b]; + } + updateKVCacheStride(params, fused_qkv + (kv_offset + ckv_dim + rope_dim) * offset + kv_offset, kv_offset + ckv_dim + rope_dim, ckv_dim + rope_dim, b, seq_len, step); + offset += seq_len; + } + } +} + +} // namespace rtp_llm diff --git a/rtp_llm/cpp/devices/arm_impl/ArmMlaQKVGemm.cc b/rtp_llm/cpp/devices/arm_impl/ArmMlaQKVGemm.cc new file mode 100644 index 000000000..75959a8f0 --- /dev/null +++ b/rtp_llm/cpp/devices/arm_impl/ArmMlaQKVGemm.cc @@ -0,0 +1,146 @@ +#include "rtp_llm/cpp/devices/arm_impl/ArmDevice.h" +#include "rtp_llm/cpp/devices/utils/DebugUtils.h" +#include "rtp_llm/cpp/core/BufferHelper.h" + +namespace rtp_llm { + +template +void mla_merge_transpose_cpu(T* q, + T* k_nope, + T* k_rope, + T* v, + float* qkv, + int token_num, + int head_num, + int nope_head_dim, + int rope_head_dim, + int v_head_dim) { + int nope_rope_dim = nope_head_dim + rope_head_dim; + int hidden_size = head_num * nope_rope_dim; + + for (int bs_idx = 0; bs_idx < token_num; ++bs_idx) { + for (int head_idx = 0; head_idx < head_num; ++head_idx) { + for (int tidx = 0; tidx < nope_rope_dim; ++tidx) { + int q_offset = bs_idx * head_num * nope_rope_dim + head_idx * nope_rope_dim + tidx; + int k_nope_offset = bs_idx * head_num * nope_head_dim + head_idx * nope_head_dim + tidx; + int rope_idx = tidx - nope_head_dim; + int k_rope_offset = bs_idx * rope_head_dim + rope_idx; // broadcast to head_num + int v_offset = bs_idx * head_num * v_head_dim + head_idx * v_head_dim + tidx; + int dst_base_offset = bs_idx * 3 * hidden_size + head_idx * nope_rope_dim + tidx; + + if (tidx < nope_head_dim) { + 
qkv[dst_base_offset] = q[q_offset]; + qkv[dst_base_offset + hidden_size] = k_nope[k_nope_offset]; + } else { + int trans_idx = rope_idx / 2; + int trans_offset = trans_idx + (rope_idx % 2 ? 1 : 0) * (rope_head_dim / 2) - tidx + nope_head_dim; + int q_dst = dst_base_offset + trans_offset; + int k_dst = q_dst + hidden_size; + qkv[q_dst] = q[q_offset]; + qkv[k_dst] = k_rope[k_rope_offset]; + } + + if (tidx < v_head_dim) { + qkv[dst_base_offset + 2 * hidden_size] = v[v_offset]; + } else { + qkv[dst_base_offset + 2 * hidden_size] = 0; + } + } + } + } +} +BufferPtr ArmCpuDevice::mlaQKVGemm(const AttentionLayerParams& params) { + auto datatype = params.input.type(); + const auto& input = params.input; + + auto token_num = params.input.shape()[0]; + auto head_num = params.configs.head_num; + auto nope_head_dim = params.configs.nope_head_dim; + auto rope_head_dim = params.configs.rope_head_dim; + auto v_head_dim = params.configs.v_head_dim; + auto nope_rope_dim = nope_head_dim + rope_head_dim; + + auto qkv = + allocateBuffer({DataType::TYPE_FP32, {token_num, head_num * nope_rope_dim * 3}, AllocationType::DEVICE}, {"mla_qkv"}); + + // Q_a = input * W_qa + // Q_a = normalize(Q_a) + // Q = Q_a * W_qb + BufferPtr fused_qkv = nullptr; + BufferPtr q = nullptr; + int64_t kv_offset = 0; + if (params.weights.fusedqkrope_weight != nullptr) { + fused_qkv = gemm(GemmParams(input, *(params.weights.fusedqkrope_weight->kernel))); + kv_offset = params.configs.q_lora_rank; + layernorm(LayernormParams(fused_qkv, + fused_qkv, + mayGetRef(params.weights.q_a_norm_weight), + std::nullopt, + std::nullopt, + std::nullopt, + 1.0f, + params.ln_params.eps, + true, + false, + params.ln_params.norm_type)); + q = gemm(GemmParams(*fused_qkv, *(params.weights.q_b_weight->kernel))); + } else { + fused_qkv = gemm(GemmParams(input, *(params.weights.fusedqkrope_no_lora_weight->kernel))); + kv_offset = params.configs.head_num * params.configs.size_per_head; + printf("@mlaQKVGemm kv_offset %ld\n", kv_offset); + q = slice(SliceParams({*fused_qkv, -1, 0, (int64_t)(params.configs.head_num * params.configs.size_per_head)})); + } + + // kv_a = input * W_kva + // kv_a = normalize(kv_a) + // knope = kv_a * W_knope + // v = kv_a * W_v + auto kv_a = gemm(GemmParams(input, *(params.weights.kv_a_weight->kernel))); + layernorm(LayernormParams(kv_a, + kv_a, + mayGetRef(params.weights.kv_a_norm_weight), + std::nullopt, + std::nullopt, + std::nullopt, + 1.0f, + params.ln_params.eps, + true, + false, + params.ln_params.norm_type)); + auto k_nope = gemm(GemmParams(*kv_a, *(params.weights.k_nope_weight->kernel))); + auto v = gemm(GemmParams(*kv_a, *(params.weights.v_weight->kernel))); + + // k_rope = input * W_krope + auto k_rope = gemm(GemmParams(input, *(params.weights.k_rope_weight->kernel))); + + if (datatype == DataType::TYPE_FP16) { + mla_merge_transpose_cpu<__fp16>( + (__fp16 *)q->data(), + (__fp16 *)k_nope->data(), + (__fp16 *)k_rope->data(), + (__fp16 *)v->data(), + (float *)qkv->data(), + token_num, + head_num, + nope_head_dim, + rope_head_dim, + v_head_dim); + } else if (datatype == DataType::TYPE_FP32) { + mla_merge_transpose_cpu( + (float *)q->data(), + (float *)k_nope->data(), + (float *)k_rope->data(), + (float *)v->data(), + (float *)qkv->data(), + token_num, + head_num, + nope_head_dim, + rope_head_dim, + v_head_dim); + } else { + throw std::runtime_error("mla_merge_transpose_cpu type is not supported"); + } + printBufferData(*qkv, "MLA QKV Gemm output"); + return qkv; +} +} // namespace rtp_llm \ No newline at end of file diff --git 
a/rtp_llm/cpp/devices/arm_impl/ArmWeights.cc b/rtp_llm/cpp/devices/arm_impl/ArmWeights.cc index 22df8b4be..023025dcb 100644 --- a/rtp_llm/cpp/devices/arm_impl/ArmWeights.cc +++ b/rtp_llm/cpp/devices/arm_impl/ArmWeights.cc @@ -48,10 +48,10 @@ torch::Tensor ArmCpuDevice::preprocessWeightsForMixedGemm(torch::Tensor row_majo return row_major_quantized_weight; } -torch::Tensor ArmCpuDevice::preprocessWeightScale(torch::Tensor qweight, torch::Tensor scales) { +torch::Tensor ArmCpuDevice::preprocessWeightScale(torch::Tensor qweight, torch::Tensor scales, const std::string& key) { auto qweightBuffer = torchTensor2Buffer(qweight); auto scaleBuffer = torchTensor2Buffer(scales); - auto retBuffer = prepareGemmOptForGPTQInt4(qweightBuffer, scaleBuffer, ""); + auto retBuffer = prepareGemmOptForGPTQInt4(qweightBuffer, scaleBuffer, key); return Buffer2torchTensor(*retBuffer, false); } diff --git a/rtp_llm/cpp/devices/arm_impl/gemm_opt/ArmGemmPacking.cc b/rtp_llm/cpp/devices/arm_impl/gemm_opt/ArmGemmPacking.cc index 679899d61..bbed91568 100644 --- a/rtp_llm/cpp/devices/arm_impl/gemm_opt/ArmGemmPacking.cc +++ b/rtp_llm/cpp/devices/arm_impl/gemm_opt/ArmGemmPacking.cc @@ -27,7 +27,7 @@ #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h" - +#include #define GPTQ_COMPUTE_AS_DI_BF16 0 namespace rtp_llm { @@ -121,34 +121,53 @@ static void quant_qs4c32_f32(size_t n, size_t k, size_t bl, const float* rhs_f32 } } -ConstBufferPtr prepareGemmWeight(const std::string& key, ConstBufferPtr input) { - if (armPrepareWeightFunc == nullptr) { - if (std::getenv("ARM_GEMM_USE_KAI") == nullptr) { - armPrepareWeightFunc = prepareGemmOptWeight; - } else { - RTP_LLM_LOG_INFO("KleidiAI enabled.\n"); - armPrepareWeightFunc = prepareKaiWeightBf16; - } - } - // Transpose and reorder - if (key == W::lm_head) { - return armPrepareWeightFunc(transposeWeight(input), true, true); - } +size_t get_rhs_packed_size(int n, int k) { + const size_t bl = 32; + const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod(); + const size_t kr = kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod(); + return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(n, k, nr, kr, bl); +} - // // Reorder RHS weight matrics for better GEMM performance - if (key == W::attn_qkv_w) { - return armPrepareWeightFunc(input, false, true); - } - if (key == W::attn_o_w || - key == W::ffn_w1 || - key == W::ffn_w2 || - key == W::ffn_w3) { - return armPrepareWeightFunc(input, false, false); - } +void transposeWeightByAcl(size_t k, size_t n, DataType data_type, void *input_data) { - return input; -} + arm_compute::NETranspose transB; + arm_compute::Tensor wei_tran_tensor; + arm_compute::TensorInfo wei_data_info; + arm_compute::TensorInfo wei_tran_info; + arm_compute::Tensor wei_tensor; + arm_compute::DataType acl_data_type; + if (data_type == DataType::TYPE_FP16) + acl_data_type = arm_compute::DataType::F16; + else if (data_type == DataType::TYPE_FP32) + acl_data_type = arm_compute::DataType::F32; + else if (data_type == DataType::TYPE_FP8_E4M3) + acl_data_type = arm_compute::DataType::U8; + else + throw std::runtime_error("transpose data type is not supported"); + + wei_data_info = arm_compute::TensorInfo(arm_compute::TensorShape(n, 
k), 1, acl_data_type); + wei_tran_info = arm_compute::TensorInfo(arm_compute::TensorShape(k, n), 1, acl_data_type); + + size_t element_num = k * n; + size_t data_size = data_type == DataType::TYPE_FP32 ? sizeof(float) : sizeof(float16_t); + + size_t transposed_size = element_num * data_size; + void *transposed_data = malloc(transposed_size); + + wei_tensor.allocator()->init(wei_data_info); + wei_tran_tensor.allocator()->init(wei_tran_info); + wei_tensor.allocator()->import_memory(input_data); + + wei_tran_tensor.allocator()->import_memory(transposed_data); + + transB.configure(&wei_tensor, &wei_tran_tensor); + transB.run(); + + // Update input buffer with transposed data, reduce memory usage + memcpy(input_data, transposed_data, transposed_size); + free(transposed_data); +} BufferPtr transposeWeight(ConstBufferPtr input) { @@ -185,33 +204,10 @@ BufferPtr transposeWeight(ConstBufferPtr input) { std::vector weight_workspace_shape = std::vector(Bshape.begin(), Bshape.end() - 2); weight_workspace_shape.insert(weight_workspace_shape.end(), {n, k}); - size_t element_num = k * n; - size_t data_size = data_type == DataType::TYPE_FP32 ? sizeof(float) : sizeof(float16_t); - //const void *data = malloc(element_num * data_size); - //output = BufferPtr(new Buffer(MemoryType::MEMORY_CPU, - // data_type, - // weight_workspace_shape, - // data)), - size_t transposed_size = element_num * data_size; - void *transposed_data = malloc(transposed_size); - - wei_tensor.allocator()->init(wei_data_info); - wei_tran_tensor.allocator()->init(wei_tran_info); - wei_tensor.allocator()->import_memory(input->data()); - //wei_tran_tensor.allocator()->import_memory(output->data()); - wei_tran_tensor.allocator()->import_memory(transposed_data); - - transB.configure(&wei_tensor, &wei_tran_tensor); - transB.run(); - - //return output; - // Update input buffer with transposed data, reduce memory usage - RTP_LLM_CHECK_WITH_INFO(input->sizeBytes() >= transposed_size, "transpose dst size < src size"); - memcpy(input->data(), transposed_data, transposed_size); - free(transposed_data); + transposeWeightByAcl(k, n, input->type(), input->data()); auto packedBuffer = BufferPtr(new Buffer(MemoryType::MEMORY_CPU, - data_type, + input->type(), weight_workspace_shape, input->data())); return packedBuffer; @@ -350,6 +346,87 @@ ConstBufferPtr prepareKaiWeightBf16(ConstBufferPtr input, bool isTranspose, bool return output; } +ConstBufferPtr prepareKaiWeightKcVc(ConstBufferPtr input) { + if (input->type() != DataType::TYPE_FP16) { + throw std::runtime_error("prepareKaiWeightKcVc only supports fp16 weight type"); + } + + ConstBufferPtr output = input; + + size_t bs = input->shape()[0]; + size_t k = input->shape()[1]; + size_t n = input->shape()[2]; + + const size_t nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(); + const size_t kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(); + const size_t sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla(); + + const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf32_f16_neon(n, k); + + uint8_t* rhs_packed = new uint8_t[bs * rhs_packed_size]; + + std::vector weight_workspace_shape = {bs, rhs_packed_size / sizeof(bfloat16_t)}; + + output = BufferPtr(new Buffer(MemoryType::MEMORY_CPU, + DataType::TYPE_BF16, + weight_workspace_shape, + rhs_packed)); + + const size_t rhs_stride = n * sizeof(float16_t); + float16_t* rhs = (float16_t* )input->data(); + + #pragma omp parallel for + for (int b = 0; b < bs; ++b) { + 
kai_run_rhs_pack_kxn_bf16p12x4biasf32_f16_neon( + 1, n, k, nr, kr, sr, // Packing arguments + rhs_stride, // RHS stride + rhs + b * k * n, // RHS + NULL, // Bias + NULL, // Scale + rhs_packed + b * rhs_packed_size, // RHS packed + 0, NULL); + } + return output; +} + +BufferPtr transposeWeightMoe(ConstBufferPtr input, bool isMerged) { + + std::vector Bshape = input->shape(); + auto data_type = input->type(); + auto dim = input->dim(); + size_t k; + size_t n; + + + k = Bshape[dim - 2]; + n = Bshape[dim - 1]; + + std::vector transposedShape = std::vector(Bshape.begin(), Bshape.end() - 2); + transposedShape.insert(transposedShape.end(), {n, k}); + + size_t experts_num = input->size() / (k * n); + + if (isMerged) { + k /= 2; + } + + for (size_t i = 0; i < experts_num; i++) { + // pack weight + if (!isMerged) { + transposeWeightByAcl(k, n, data_type, input->dataWithOffset(k * n * i)); + } else { + transposeWeightByAcl(k, n, data_type, input->dataWithOffset(2 * i * k * n)); + transposeWeightByAcl(k, n, data_type, input->dataWithOffset((2 * i + 1) * k * n)); + } + } + + auto packedBuffer = BufferPtr(new Buffer(MemoryType::MEMORY_CPU, + data_type, + transposedShape, + input->data())); + return packedBuffer; +} + ConstBufferPtr prepareGemmOptWeight(ConstBufferPtr input, bool isTranspose, bool unused) { ConstBufferPtr weight_workspace = input; @@ -397,7 +474,7 @@ ConstBufferPtr prepareGemmOptWeight(ConstBufferPtr input, bool isTranspose, bool } } - // Update original buffer with packed data to save memory usage + // Update original buffer with packed data to save memory usage //RTP_LLM_CHECK_WITH_INFO(input->sizeBytes() >= weight_workspace->sizeBytes(), "gemm pack dst size < src size"); //memcpy(input->data(), weight_workspace->data(), weight_workspace->sizeBytes()); //free(weight_workspace->data()); @@ -412,50 +489,238 @@ ConstBufferPtr prepareGemmOptWeight(ConstBufferPtr input, bool isTranspose, bool return weight_workspace; } -//ConstBufferPtr prepareGemmWeight(const std::string& key, ConstBufferPtr input) { -// // Transpose and reorder -// if (key == W::lm_head) { -// return prepareGemmOptWeight(transposeWeight(input), true); -// } -// -// // Reorder RHS weight matrics for better GEMM performance -// if (key == W::attn_qkv_w || +BufferPtr prepareGemmOptWeightMoe(ConstBufferPtr input, bool isMerged) { + if (input->type() != DataType::TYPE_FP32 && input->type() != DataType::TYPE_FP16) { + throw std::runtime_error("prepareGemmOptWeightMoe type is not supported"); + } + + BufferPtr weight_workspace; + GemmKernel gemm_kernel; + std::vector Bshape = input->shape(); + auto dim = input->dim(); + + size_t k; + size_t n; + + k = Bshape[dim - 2]; + n = Bshape[dim - 1]; + size_t experts_num = input->size() / (k * n); + + if (isMerged) { + n /= 2; + } + + size_t weight_k_pack = std::ceil(k / 8.0) * 8; + size_t width = weight_k_pack * 2; + size_t height = n / 2 + n % 2; + + void *data = malloc(input->sizeBytes()); + memset(data, 0, input->sizeBytes()); + + // allocate a temp workspace to pack weight fp32->bf16 + std::vector weight_workspace_shape = std::vector(Bshape.begin(), Bshape.end() - 2); + weight_workspace_shape.insert(weight_workspace_shape.end(), {n, k}); + + weight_workspace = BufferPtr(new Buffer(MemoryType::MEMORY_CPU, + DataType::TYPE_BF16, + weight_workspace_shape, + data)); + + for (size_t i = 0; i < experts_num; i++) { + // pack weight + if (!isMerged) { + hie::bfloat16* weight_workspace_cur_ptr = reinterpret_cast(weight_workspace->dataWithOffset(height * width * i)); + if (input->type() == 
DataType::TYPE_FP32) { + float* B_fp32_ptr = reinterpret_cast(input->dataWithOffset(k * n * i)); + gemm_kernel.gemm_pack_weight_FP32toBF16_arm(n, k, weight_k_pack, B_fp32_ptr, weight_workspace_cur_ptr); + } else { // if(params.B.type() == DataType::TYPE_FP16) + float16_t* B_fp16_ptr = reinterpret_cast(input->dataWithOffset(k * n * i)); + gemm_kernel.gemm_pack_weight_FP16toBF16_arm(n, k, weight_k_pack, B_fp16_ptr, weight_workspace_cur_ptr); + } + } else { + hie::bfloat16* weight_workspace_cur_ptr = reinterpret_cast(weight_workspace->dataWithOffset(2 * height * width * i)); + if (input->type() == DataType::TYPE_FP32) { + float* B_fp32_ptr = reinterpret_cast(input->dataWithOffset(2 * k * n * i)); + gemm_kernel.gemm_pack_weight_FP32toBF16_arm(n, k, weight_k_pack, B_fp32_ptr, weight_workspace_cur_ptr); + } else { // if(params.B.type() == DataType::TYPE_FP16) + float16_t* B_fp16_ptr = reinterpret_cast(input->dataWithOffset(2 * k * n * i)); + gemm_kernel.gemm_pack_weight_FP16toBF16_arm(n, k, weight_k_pack, B_fp16_ptr, weight_workspace_cur_ptr); + } + + weight_workspace_cur_ptr = reinterpret_cast(weight_workspace->dataWithOffset(2 * height * width * i + height * width)); + if (input->type() == DataType::TYPE_FP32) { + float* B_fp32_ptr = reinterpret_cast(input->dataWithOffset(2 * k * n * i + k * n)); + gemm_kernel.gemm_pack_weight_FP32toBF16_arm(n, k, weight_k_pack, B_fp32_ptr, weight_workspace_cur_ptr); + } else { // if(params.B.type() == DataType::TYPE_FP16) + float16_t* B_fp16_ptr = reinterpret_cast(input->dataWithOffset(2 * k * n * i + k * n)); + gemm_kernel.gemm_pack_weight_FP16toBF16_arm(n, k, weight_k_pack, B_fp16_ptr, weight_workspace_cur_ptr); + } + } + } + // Update original buffer with packed data to save memory usage + RTP_LLM_CHECK_WITH_INFO(input->sizeBytes() >= weight_workspace->sizeBytes(), "gemm pack dst size < src size"); + memcpy(input->data(), weight_workspace->data(), input->sizeBytes()); + free(weight_workspace->data()); + + auto packedBuffer = BufferPtr(new Buffer(MemoryType::MEMORY_CPU, + DataType::TYPE_BF16, + Bshape, + input->data())); + return packedBuffer; +} + +ConstBufferPtr prepareGemmWeight(const std::string& key, ConstBufferPtr input) { + if (armPrepareWeightFunc == nullptr) { + if (std::getenv("ARM_GEMM_USE_KAI") == nullptr) { + armPrepareWeightFunc = prepareGemmOptWeight; + } else { + RTP_LLM_LOG_INFO("KleidiAI enabled.\n"); + armPrepareWeightFunc = prepareKaiWeightBf16; + } + } + + // Transpose and reorder + if (key == W::lm_head) { + return armPrepareWeightFunc(transposeWeight(input), true, true); + } + + if (key == W::mla_kc || key == W::mla_vc) { + return prepareKaiWeightKcVc(input); + } + + if (key == W::moe_w1) { + if (std::getenv("ARM_GEMM_USE_KAI") != nullptr) { + throw std::runtime_error("prepareKAIWeightBf16Moe is not implemented."); + } + return prepareGemmOptWeightMoe(transposeWeightMoe(input, /*isMerged*/true), /*isMerged*/true); + } + + if (key == W::moe_w2) { + return prepareGemmOptWeightMoe(transposeWeightMoe(input, /*isMerged*/false), /*isMerged*/false); + } + + // Reorder RHS weight matrics for better GEMM performance + if (key == W::attn_qkv_w || + key == W::attn_q_b || + key == W::attn_k_nope || + key == W::attn_v || + key == W::mla_fusedqkrope_no_lora || + key == W::mla_fusedqkrope || + key == W::moe_gate) { + return armPrepareWeightFunc(input, false, /*isForceF32Out*/true); + } + if (key == W::attn_o_w || + key == W::ffn_w1 || + key == W::ffn_w2 || + key == W::ffn_w3) { + return armPrepareWeightFunc(input, false, /*isForceF32Out*/false); + } 
+ + return input; +} + torch::Tensor ArmCpuDevice::preprocessGemmWeightByKey(const std::string& key, torch::Tensor weight) { + if (c10::isFloat8Type(weight.dtype().toScalarType()) || key.find("weight_only_quant_scale") != std::string::npos) { + return weight; + } + auto buffer = torchTensor2Buffer(weight); auto retBuffer = prepareGemmWeight(key, buffer); // Repacked buffer size may not match with shape size * element size, // should use buffer pointer instead of copying data. if ((key == W::attn_qkv_w || + key == W::mla_fusedqkrope_no_lora || + key == W::mla_fusedqkrope || + key == W::mla_kc || + key == W::mla_vc || key == W::attn_o_w || key == W::ffn_w1 || key == W::ffn_w2 || - // key == W::ffn_w3) { - //return prepareGemmOptWeight(input, false); key == W::ffn_w3 || + key == W::moe_gate || key == W::lm_head) && retBuffer->type() == DataType::TYPE_BF16) { return Buffer2torchTensor(*retBuffer, false); } - if ((key == W::attn_qkv_w || - key == W::attn_o_w || - key == W::ffn_w1 || - key == W::ffn_w2 || - key == W::ffn_w3) && retBuffer->type() == DataType::TYPE_UINT8) { - return Buffer2torchTensor(*retBuffer, false); - } - return Buffer2torchTensor(*retBuffer); } +inline float fp8_to_fp32_e4m3(uint8_t fp8_value) { + // Extract sign, exponent, and mantissa + uint8_t sign = (fp8_value & 0x80) >> 7; // 1 bit for sign + uint8_t exponent = (fp8_value & 0x78) >> 3; // 4 bits for exponent + uint8_t mantissa = (fp8_value & 0x07); // 3 bits for mantissa + + if (exponent == 0) { + // Subnormal number + float subnormal = mantissa / 512.0f; + return sign ? -subnormal : subnormal; + } else if (exponent == 0xF && mantissa == 0x7) { + // NaN + return sign ? -NAN : NAN; + } else { + // Normalized number + float normalized = (8.0f + mantissa) * (1 << exponent) / (1 << 10); + return sign ? -normalized : normalized; + } +} + +static void quant_qs4c32_f8(size_t n, size_t k, size_t scale_n, size_t scale_k, size_t bl, const uint8_t* qweight, const float* qscales, uint8_t* rhs_qs4c32) { + if (bl != 32) { + throw std::runtime_error("bl should be 32"); + } + const size_t num_blocks_row = num_blocks_per_row(k, bl); + const size_t num_bytes_block = num_bytes_per_block_qs4c32(bl); + const size_t dst_stride = num_blocks_row * num_bytes_block; + #pragma omp parallel for + for (size_t row_idx = 0; row_idx < n; ++row_idx) { + uint8_t* dst_ptr = (uint8_t*)rhs_qs4c32 + row_idx * dst_stride; + + for (size_t block_idx = 0; block_idx < num_blocks_row; ++block_idx) { + float src[32]; + float max = 0.0f; + float amax = 0.0f; + float scale_0 = qscales[row_idx / 128 * scale_k + block_idx * bl / 128 ]; + for (size_t i = 0; i < bl; ++i) { + uint8_t qint8 = qweight[row_idx * k + block_idx * bl + i]; + const float x0 = scale_0 * fp8_to_fp32_e4m3(qint8); + src[i] = x0; + const float ax0 = fabsf(x0); + if (amax < ax0) { + amax = ax0; + max = x0; + } + } + + const float scale = max / -8.0; + const float recip_scale = scale ? 
1.0f / scale : 0.0f; + + // Store the scale at the beginning of the block + *((uint16_t*)dst_ptr) = kai_cast_f16_f32(scale); + dst_ptr += sizeof(uint16_t); + + for (size_t i = 0; i < bl / 2; ++i) { + float v0_f32 = src[i]; + float v1_f32 = src[i + bl / 2]; + + v0_f32 *= recip_scale; + v1_f32 *= recip_scale; + + const uint8_t v0_u8 = (uint8_t)std::min((int8_t)15, (int8_t)(v0_f32 + 8.5f)); + const uint8_t v1_u8 = (uint8_t)std::min((int8_t)15, (int8_t)(v1_f32 + 8.5f)); -//torch::Tensor ArmCpuDevice::preprocessGemmWeightByKey(const std::string& key, torch::Tensor weight) { -// auto buffer = torchTensor2Buffer(weight); -// auto retBuffer = prepareGemmWeight(key, buffer); -// return Buffer2torchTensor(*retBuffer); -//} + const uint8_t rhs_v0 = (v1_u8 << 4) | v0_u8; + + dst_ptr[0] = rhs_v0; + dst_ptr += sizeof(uint8_t); + } + } + } +} ConstBufferPtr prepareGemmOptForGPTQInt4(ConstBufferPtr kernel, ConstBufferPtr scales, const std::string& key) { + ConstBufferPtr weight_workspace = kernel; std::vector Bshape = kernel->shape(); @@ -517,6 +782,7 @@ ConstBufferPtr prepareGemmOptForGPTQInt4(ConstBufferPtr kernel, ConstBufferPtr s gemm_kernel.gemm_pack_weight_FP16toBF16_arm(n, k, weight_k_pack, unpacked_weight, weight_workspace_cur_ptr); free(unpacked_weight); return weight_workspace; + } #else if (kernel->type() == DataType::TYPE_INT8 && scales->type() == DataType::TYPE_FP16) { int8_t* qweight = (int8_t*)kernel->data(); @@ -572,7 +838,8 @@ ConstBufferPtr prepareGemmOptForGPTQInt4(ConstBufferPtr kernel, ConstBufferPtr s std::vector weight_workspace_shape = std::vector(Bshape.begin(), Bshape.end() - 2); - weight_workspace_shape.insert(weight_workspace_shape.end(), {k, n / 2}); + weight_workspace_shape.insert(weight_workspace_shape.end(), {k, n / 2}); + BufferPtr output = BufferPtr(new Buffer(MemoryType::MEMORY_CPU, DataType::TYPE_UINT8, weight_workspace_shape, @@ -580,8 +847,7 @@ ConstBufferPtr prepareGemmOptForGPTQInt4(ConstBufferPtr kernel, ConstBufferPtr s uint8_t* rhs_native_mtx_qs4c32 = new uint8_t[rhs_native_size_qs4c32]; - quant_qs4c32_f32( - n, k, bl, (const float*)transposedWeight->data(), (uint8_t*)rhs_native_mtx_qs4c32); + quant_qs4c32_f32(n, k, bl, (const float*)transposedWeight->data(), (uint8_t*)rhs_native_mtx_qs4c32); struct kai_rhs_pack_qs4cxs1s0_param kai_rhs_params; kai_rhs_params.lhs_zero_point = 1; @@ -611,8 +877,101 @@ ConstBufferPtr prepareGemmOptForGPTQInt4(ConstBufferPtr kernel, ConstBufferPtr s delete[] rhs_native_mtx_qs4c32; free(unpacked_weight); return output; -#endif + } else if (kernel->type() == DataType::TYPE_FP8_E4M3 && scales->type() == DataType::TYPE_FP32) { + uint8_t* qweight = (uint8_t*)kernel->data(); + auto qscales = (float*)scales->data(); + n /= 2; + + if (key == W::moe_w1 || key == W::moe_w2) { + size_t tmp = k; + k = n; + n = tmp; + } + size_t scale_k = k / 128; + size_t scale_n = n / 128; + + size_t batch_size = std::accumulate(Bshape.begin(), Bshape.end() - 2, (size_t)1, std::multiplies()); + + //float* unpacked_weight = (float*)malloc(k * n * sizeof(float)); + + const size_t bl = 32; + const size_t num_blocks = k / bl; + const size_t num_bytes_per_block_qs4c32 = (bl / 2) + sizeof(int16_t); + const size_t rhs_native_size_qs4c32 = n * num_blocks * num_bytes_per_block_qs4c32; + uint8_t* rhs_native_mtx_qs4c32 = new uint8_t[rhs_native_size_qs4c32]; + + const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod(); + const size_t kr = kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod(); + const size_t sr = 
kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod(); + + // In a single row, we pack nr bias values followed by K rows of nr RHS values + const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(n, k, nr, kr, bl); + + uint8_t* rhs_packed_mtx_qs4c32 = new uint8_t[rhs_packed_size * batch_size]; + + std::vector weight_workspace_shape = std::vector(Bshape.begin(), Bshape.end() - 2); + + // workaround to save/load converted weights + // set buffer size same as actual packed buffer size + // dim n is correct and is used in gemm compute + // dim k is wrong and should not be used + if (rhs_packed_size % n != 0) { + throw std::runtime_error("rhs_packed_size is not multiple of n"); + } + weight_workspace_shape.insert(weight_workspace_shape.end(), {rhs_packed_size / n, n}); + //weight_workspace_shape.insert(weight_workspace_shape.end(), {k, n}); + + BufferPtr output = BufferPtr(new Buffer(MemoryType::MEMORY_CPU, + DataType::TYPE_UINT8, + weight_workspace_shape, + rhs_packed_mtx_qs4c32)); + + for (int b = 0; b < batch_size; b++) { + // qweight/qscales are transposed [n, k] + // #pragma omp parallel for collapse(2) + // for (int i = 0; i < n; i++) { + // for (int j = 0; j < k; j++) { + // uint8_t qint8 = qweight[b * k * n + i * k + j]; + // float scale_0 = qscales[b * scale_k * scale_n + i / 128 * scale_k + j / 128 ]; + // auto x0 = scale_0 * fp8_to_fp32_e4m3(qint8); + // unpacked_weight[i * k + j ] = x0; + // } + // } + + //quant_qs4c32_f32(n, k, bl, unpacked_weight, rhs_native_mtx_qs4c32); + + quant_qs4c32_f8(n, k, scale_n, scale_k, bl, qweight + b * k * n, qscales + b * scale_k * scale_n, rhs_native_mtx_qs4c32); + + struct kai_rhs_pack_qs4cxs1s0_param kai_rhs_params; + kai_rhs_params.lhs_zero_point = 1; + kai_rhs_params.rhs_zero_point = 8; + + // Packing only needs to be performed once if the contents of the bias and RHS matrices are expected to be constant. + int n_step = 32; + size_t rhs_stride = kai_rhs_stride(k, bl); + + #pragma omp parallel for + for (int n_start = 0; n_start < n; n_start += n_step) { + const size_t rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(n_start, rhs_stride); + const size_t packed_offset = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(n_start, k, nr, kr, bl); + + int tile_n = (n_start + n_step <= n) ? 
n_step : n - n_start; + kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0( + 1, tile_n, k, // Dimensions + nr, kr, sr, // Packing arguments + bl, // Block length + (const uint8_t*)(rhs_native_mtx_qs4c32 + rhs_offset), // RHS + NULL, // Bias + ((uint8_t*)rhs_packed_mtx_qs4c32 + b * rhs_packed_size + packed_offset), // RHS packed + 0, &kai_rhs_params + ); + } + } + delete[] rhs_native_mtx_qs4c32; + //free(unpacked_weight); + return output; } +#endif return weight_workspace; } diff --git a/rtp_llm/cpp/devices/arm_impl/test/BUILD b/rtp_llm/cpp/devices/arm_impl/test/BUILD index 1cc3f686e..bf155653a 100644 --- a/rtp_llm/cpp/devices/arm_impl/test/BUILD +++ b/rtp_llm/cpp/devices/arm_impl/test/BUILD @@ -139,3 +139,16 @@ cc_test( env = test_envs, tags = test_tags, ) + +cc_test( + name = "arm_ffn_opt_op_test", + srcs = [ + "ops/ArmFfnOpTest.cc", + ], + data = [], + copts = test_copts, + linkopts = test_linkopts, + deps = test_deps, + env = test_envs, + tags = test_tags, +) diff --git a/rtp_llm/cpp/devices/arm_impl/test/ops/ArmFfnOpTest.cc b/rtp_llm/cpp/devices/arm_impl/test/ops/ArmFfnOpTest.cc new file mode 100644 index 000000000..0c1157ab6 --- /dev/null +++ b/rtp_llm/cpp/devices/arm_impl/test/ops/ArmFfnOpTest.cc @@ -0,0 +1,29 @@ +#include "maga_transformer/cpp/devices/base_tests/FfnLayerTest.hpp" +#include "maga_transformer/cpp/devices/arm_impl/ArmDevice.h" + + +class ArmFfnLayerTest: public FfnLayerTest {}; + +TEST_F(ArmFfnLayerTest, Gate_Fp16_FfnOpTest) { + FfnOpTest(4, 2048, 128, ActivationType::Swiglu, DataType::TYPE_FP32); + // FfnOpTest(4, 2048, 4096, ActivationType::Swiglu, DataType::TYPE_FP32); + // FfnOpTest(128, 2048, 128, ActivationType::Swiglu, DataType::TYPE_FP32); + // FfnOpTest(1000, 2048, 128, ActivationType::Swiglu, DataType::TYPE_FP32); + // FfnOpTest(1, 2, 4096, ActivationType::Swiglu, DataType::TYPE_FP32); + // FfnOpTest(1000, 2048, 128, ActivationType::Swiglu, DataType::TYPE_FP32); +} + +// TEST_F(ArmFfnLayerTest, NoGate_Fp16_FfnOpTest) { +// FfnOpTest(4, 2048, 128, ActivationType::Geglu, DataType::TYPE_FP32); +// FfnOpTest(4, 2048, 4096, ActivationType::Geglu, DataType::TYPE_FP32); +// FfnOpTest(128, 2048, 128, ActivationType::Geglu, DataType::TYPE_FP32); +// FfnOpTest(1000, 2048, 128, ActivationType::Geglu, DataType::TYPE_FP32); +// FfnOpTest(1, 2, 4096, ActivationType::Geglu, DataType::TYPE_FP32); +// FfnOpTest(1000, 2048, 128, ActivationType::Geglu, DataType::TYPE_FP32); +// } + +TEST_F(MoELayerTest, Gate_Fp16_MoEOpTest) { + // MoEOpTest(4, 3584, 2560, 64, 8, ActivationType::Silu, DataType::TYPE_FP16); + MoEOpTest(10, 448, 320, 8, 2, ActivationType::Swiglu, DataType::TYPE_FP32); // FUll divide 8 + // MoEOpTest(10, 448, 320, 8, 2, ActivationType::Silu, DataType::TYPE_FP16); // FUll divide 8 +} diff --git a/rtp_llm/cpp/devices/arm_impl/test/ops/LayerNormOpTest.cc b/rtp_llm/cpp/devices/arm_impl/test/ops/LayerNormOpTest.cc index 0745f9116..3986fab55 100644 --- a/rtp_llm/cpp/devices/arm_impl/test/ops/LayerNormOpTest.cc +++ b/rtp_llm/cpp/devices/arm_impl/test/ops/LayerNormOpTest.cc @@ -1,5 +1,6 @@ #include "rtp_llm/cpp/devices/testing/TestBase.h" #include "rtp_llm/cpp/devices/arm_impl/ArmDevice.h" +#include "rtp_llm/cpp/devices/base_tests/LayerNormTest.hpp" #include using namespace std; @@ -82,4 +83,26 @@ TEST_F(ArmLayerNormOpsTest, testSimpleLayernorm) { testGeneralLayernorm(DataType::TYPE_FP32, m, n); } } +} + +TEST_F(LayerNormTest, testSimpleLayernormStride) { + const auto test_m = vector({1, 2, 4, 8, 10, 20}); + const auto test_n = vector({128, 256, 1024}); + for (const auto& 
m: test_m) { + for (const auto& n: test_n) { + printf("testing m = %d, n = %d \n", m, n); + testGeneralLayernormStride(DataType::TYPE_FP16, NormType::layernorm, m, n, n); + testGeneralLayernormStride(DataType::TYPE_BF16, NormType::layernorm, m, n, n); + testGeneralLayernormStride(DataType::TYPE_FP32, NormType::layernorm, m, n, n); + testGeneralLayernormStride(DataType::TYPE_FP16, NormType::rmsnorm, m, n, n); + testGeneralLayernormStride(DataType::TYPE_BF16, NormType::rmsnorm, m, n, n); + testGeneralLayernormStride(DataType::TYPE_FP32, NormType::rmsnorm, m, n, n); + testGeneralLayernormStride(DataType::TYPE_FP16, NormType::layernorm, m, n, n / 2); + testGeneralLayernormStride(DataType::TYPE_BF16, NormType::layernorm, m, n, n / 2); + testGeneralLayernormStride(DataType::TYPE_FP32, NormType::layernorm, m, n, n / 2); + testGeneralLayernormStride(DataType::TYPE_FP16, NormType::rmsnorm, m, n, n / 2); + testGeneralLayernormStride(DataType::TYPE_BF16, NormType::rmsnorm, m, n, n / 2); + testGeneralLayernormStride(DataType::TYPE_FP32, NormType::rmsnorm, m, n, n / 2); + } + } } \ No newline at end of file diff --git a/rtp_llm/cpp/devices/arm_impl/type_bf16/bfloat16_impl.hpp b/rtp_llm/cpp/devices/arm_impl/type_bf16/bfloat16_impl.hpp index ec6306c7d..8938d8861 100644 --- a/rtp_llm/cpp/devices/arm_impl/type_bf16/bfloat16_impl.hpp +++ b/rtp_llm/cpp/devices/arm_impl/type_bf16/bfloat16_impl.hpp @@ -130,9 +130,11 @@ struct HIE_ALIGN(2) __Bf16Impl { // from bf16 to float static float bfloat162float(__Bf16Impl v) { - std::uint32_t val = static_cast(v.__x) << 16; - const float* vptr = reinterpret_cast(&val); - return *vptr; + std::uint32_t tmp = static_cast(v.__x); + std::uint32_t val = tmp << 16; + float result; + std::memcpy(&result, &val, sizeof(result)); + return result; } static double bfloat162double(__Bf16Impl v) { diff --git a/rtp_llm/cpp/devices/base_impl/FfnLayer.cc b/rtp_llm/cpp/devices/base_impl/FfnLayer.cc index d87b0934c..11c2488c9 100644 --- a/rtp_llm/cpp/devices/base_impl/FfnLayer.cc +++ b/rtp_llm/cpp/devices/base_impl/FfnLayer.cc @@ -20,6 +20,13 @@ FfnLayerOutput DeviceBase::ffnLayer(const FfnLayerParams& params) { auto moe_output = moeFfnLayer(params); output = moe_output.hidden_states; +#if defined(__aarch64__) + if (params.input.type() == DataType::TYPE_FP16 && + params.weights.moe_down_weight->kernel->type() == DataType::TYPE_QFP8_E4M3) { + return FfnLayerOutput({std::move(output)}); + } +#endif + auto shared_expert_output = moeSharedExpert(params).hidden_states; // for deep ep ll, the gather should be defered afater shared expert. 
@@ -32,9 +39,16 @@ FfnLayerOutput DeviceBase::ffnLayer(const FfnLayerParams& params) { printBufferData(*output, "moe_out_after_barrier"); if (shared_expert_output) { // just add bias to output - layernorm({ +#if defined(__aarch64__) + shared_expert_output = layernorm({ + output, nullptr, nullopt, mayGetRef(shared_expert_output), + nullopt, nullopt, 1.0f, 1e-5, true, false, NormType::rmsnorm + }).output; +#else + shared_expert_output = layernorm({ output, nullptr, nullopt, mayGetRef(shared_expert_output) }).output; +#endif } } else { BufferPtr up_output; diff --git a/rtp_llm/cpp/devices/base_impl/MlaAttentionLayer.cc b/rtp_llm/cpp/devices/base_impl/MlaAttentionLayer.cc index 460ab2636..94ec233b0 100644 --- a/rtp_llm/cpp/devices/base_impl/MlaAttentionLayer.cc +++ b/rtp_llm/cpp/devices/base_impl/MlaAttentionLayer.cc @@ -60,7 +60,11 @@ AttentionLayerOutput DeviceBase::mlaAttentionLayer(const AttentionLayerParams& p DevicePerfWrapper pre_mla_wrapper(this, "pre_mla_layer"); if (params.weights.fusedqkrope_weight != nullptr) { // auto q_output_size = params.configs.nope_head_dim; +#if defined(__aarch64__) + fused_qkv = gemm(GemmParams(input, *(params.weights.fusedqkrope_weight->kernel), std::nullopt, nullptr, DataType::TYPE_FP32)); +#else fused_qkv = gemm(GemmParams(input, *(params.weights.fusedqkrope_weight->kernel))); +#endif kv_offset = params.configs.q_lora_rank; auto norm_output = layernormWithStride(LayernormWithStrideParams( {fused_qkv, @@ -73,7 +77,11 @@ AttentionLayerOutput DeviceBase::mlaAttentionLayer(const AttentionLayerParams& p false})); q = gemm(GemmParams(*norm_output.output, *(params.weights.q_b_weight->kernel))); } else { +#if defined(__aarch64__) + fused_qkv = gemm(GemmParams(input, *(params.weights.fusedqkrope_no_lora_weight->kernel), std::nullopt, nullptr, DataType::TYPE_FP32)); +#else fused_qkv = gemm(GemmParams(input, *(params.weights.fusedqkrope_no_lora_weight->kernel))); +#endif kv_offset = params.configs.head_num * params.configs.size_per_head; q = slice(SliceParams({*fused_qkv, -1, 0, (int64_t)(params.configs.head_num * params.configs.size_per_head)})); } @@ -88,7 +96,11 @@ AttentionLayerOutput DeviceBase::mlaAttentionLayer(const AttentionLayerParams& p true})); pre_mla_wrapper.stop(); auto dtype = input.type(); +#if defined(__aarch64__) + auto qkv_output = allocateBuffer({DataType::TYPE_FP32, {h_token_num, params.configs.head_num * params.configs.v_head_dim}}, {"qkv_output"}); +#else auto qkv_output = allocateBuffer({dtype, {h_token_num, params.configs.head_num * params.configs.v_head_dim}}, {"qkv_output"}); +#endif if (generate_batch_size) { RTP_LLM_LOG_DEBUG("absorb decode mla attention"); RTP_LLM_CHECK_WITH_INFO(layer_kv_cache.has_value(), "kv cache can not be null for mla attention layer"); @@ -148,12 +160,17 @@ AttentionLayerOutput DeviceBase::mlaAttentionLayer(const AttentionLayerParams& p params.common, params.weights, params.configs, - params.qscheme}); + params.qscheme, + true}); } } printBufferData(*qkv_output, "attent_proj_input"); +#if defined(__aarch64__) + auto output_gemm_params = GemmParams(*qkv_output, *(params.weights.output_weight->kernel), std::nullopt, nullptr, dtype); +#else auto output_gemm_params = GemmParams(*qkv_output, *(params.weights.output_weight->kernel)); +#endif auto attention_out = loraLinear(LoraLinearParams(output_gemm_params, params.common.lora_input.out_lora_input)).output; printBufferData(*attention_out, "attention_out"); return {std::move(attention_out)}; diff --git a/rtp_llm/cpp/devices/base_tests/FfnLayerTest.hpp 
b/rtp_llm/cpp/devices/base_tests/FfnLayerTest.hpp index 0e3928eaf..6625646f0 100644 --- a/rtp_llm/cpp/devices/base_tests/FfnLayerTest.hpp +++ b/rtp_llm/cpp/devices/base_tests/FfnLayerTest.hpp @@ -1,6 +1,12 @@ #pragma once + #include "rtp_llm/cpp/devices/torch_impl/FfnLayer.h" #include "rtp_llm/cpp/devices/testing/TestBase.h" + +#if defined(__aarch64__) +#include "rtp_llm/cpp/devices/arm_impl/gemm_opt/ArmGemmKernel.h" +#endif + #include #include #include @@ -48,10 +54,20 @@ class FfnLayerTest : public DeviceTestBase { auto up_proj = tensorToBuffer(params.up_proj, alloc_type); auto down_proj = tensorToBuffer(params.down_proj, alloc_type); +#if defined(__aarch64__) + auto gate_packed = prepareGemmOptWeight(gate_proj); + auto up_packed = prepareGemmOptWeight(up_proj); + auto down_packed = prepareGemmOptWeight(down_proj); + FfnLayerWeights weights; + weights.up_weight = std::make_unique(DenseWeights(up_packed)); + weights.down_weight = std::make_unique(DenseWeights(down_packed)); + weights.gate_weight = std::make_unique(DenseWeights(gate_packed)); +#else FfnLayerWeights weights; weights.up_weight = std::make_unique(DenseWeights(up_proj)); weights.down_weight = std::make_unique(DenseWeights(down_proj)); weights.gate_weight = std::make_unique(DenseWeights(gate_proj)); +#endif FfnConfigs ffn_configs({Atype}); FfnLayerParams Opparams(*input, diff --git a/rtp_llm/device/device_impl.py b/rtp_llm/device/device_impl.py index b477d8b29..7efcb31a0 100644 --- a/rtp_llm/device/device_impl.py +++ b/rtp_llm/device/device_impl.py @@ -67,7 +67,7 @@ def preprocess_groupwise_weight_params(self, qweight_int32, qzeros_int32, scales qweight = qweight.to(torch.int8) if not is_int8: qweight = packer(qweight) - qweight_interleaved = preprocess_weight_scale(qweight, scales_fp16) + qweight_interleaved = preprocess_weight_scale(qweight, scales_fp16, "") # zero = 0 if qzeros_int32 = -2004318072 torch.int32 for awq # zero = 0 if qzeros_int32 = 2004318071 torch.int32 for gptq @@ -84,6 +84,9 @@ def preprocess_groupwise_weight_params(self, qweight_int32, qzeros_int32, scales # return processed interleaved weight, original scales and zeros * scales return qweight_interleaved.contiguous().to(device), zeros_x_scales_fp16.contiguous().to(device), scales_fp16.contiguous().to(device) + def shuffle_moe_weight(self, x: torch.Tensor, datatype: torch.dtype, name: str) -> torch.Tensor: + return x + class GpuImpl(DeviceBase): def __init__(self, exported_device: DeviceExporter): super().__init__(exported_device) diff --git a/rtp_llm/model_loader/per_block_fp8_quant_weight.py b/rtp_llm/model_loader/per_block_fp8_quant_weight.py index e06f8898a..0ef14ca0e 100644 --- a/rtp_llm/model_loader/per_block_fp8_quant_weight.py +++ b/rtp_llm/model_loader/per_block_fp8_quant_weight.py @@ -370,6 +370,11 @@ def _postprocess(self, tensor: Union[torch.Tensor, Dict[str, torch.Tensor]], dev scale_weight = processed_res[self.scale.name] scale_weight = scale_weight.reshape(scale_weight.shape[-1], -1) if scale_weight.dim() == 2 else scale_weight processed_res[self.scale.name] = scale_weight + from rtp_llm.ops import get_device + exported_device: DeviceExporter = get_device() + preprocess_weight_scale = exported_device.preprocess_weight_scale + qweight_interleaved = preprocess_weight_scale(kernel_weight, scale_weight, self.kernel.name) + processed_res[self.kernel.name] = qweight_interleaved return processed_res diff --git a/rtp_llm/models/deepseek_dequant.py b/rtp_llm/models/deepseek_dequant.py index 18b600fe1..e57233ed0 100644 --- 
a/rtp_llm/models/deepseek_dequant.py +++ b/rtp_llm/models/deepseek_dequant.py @@ -1,7 +1,4 @@ -import triton import torch -import triton.language as tl -from triton import Config from rtp_llm.utils.util import check_with_info diff --git a/rtp_llm/ops/libth_transformer.pyi b/rtp_llm/ops/libth_transformer.pyi index f3310b41d..5c7f76c56 100644 --- a/rtp_llm/ops/libth_transformer.pyi +++ b/rtp_llm/ops/libth_transformer.pyi @@ -11,7 +11,7 @@ class DeviceExporter: ... def preprocess_gemm_weight_by_key(self, key: str, weight: torch.Tensor) -> torch.Tensor: ... - def preprocess_weight_scale(self, weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + def preprocess_weight_scale(self, weight: torch.Tensor, scale: torch.Tensor, key: str) -> torch.Tensor: ... def preprocess_weights_for_mixed_gemm(self, weight: torch.Tensor, quant_type: typing.Any, arch: str) -> torch.Tensor: ...
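Reference sketch for the E4M3 arithmetic relied on by the fp8_to_fp32_e4m3 helper added in ArmGemmPacking.cc: with a 1/4/3 sign/exponent/mantissa split and bias 7, a normal value is (1 + m/8) * 2^(e-7) = (8 + m) * 2^e / 2^10, exponent 0 encodes subnormals m/8 * 2^-6, and the only non-finite encoding is S.1111.111 (NaN; E4M3 has no infinities). The snippet below is a standalone illustration, not part of the patch; the function name decode_e4m3 and the sample bytes in main() are made up for the sanity check.

#include <cmath>
#include <cstdint>
#include <cstdio>

// Illustrative decode of one E4M3 byte (sign:1, exponent:4, mantissa:3, bias 7).
static float decode_e4m3(uint8_t v) {
    const uint8_t sign     = (v & 0x80) >> 7;
    const uint8_t exponent = (v & 0x78) >> 3;
    const uint8_t mantissa =  v & 0x07;
    float magnitude;
    if (exponent == 0) {
        magnitude = mantissa / 512.0f;                                  // subnormal: (m / 8) * 2^-6
    } else if (exponent == 0xF && mantissa == 0x7) {
        magnitude = NAN;                                                // only S.1111.111 is reserved (NaN)
    } else {
        magnitude = (8.0f + mantissa) * (1u << exponent) / 1024.0f;     // (1 + m/8) * 2^(e-7)
    }
    return sign ? -magnitude : magnitude;
}

int main() {
    // Expected output: 1 2 -3 448 (0x7E is the largest finite E4M3 value)
    std::printf("%g %g %g %g\n",
                decode_e4m3(0x38), decode_e4m3(0x40), decode_e4m3(0xC4), decode_e4m3(0x7E));
    return 0;
}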