
[Perf] Tunings for SM100 FP8 CUTLASS kernel #18778


Open · wants to merge 1 commit into main
Conversation

@mgoin (Member) commented May 27, 2025

I noticed that the FP8 CUTLASS kernel for Blackwell had only one default set of configs. This PR adds new configs for small M (M < 128).

For Llama 8B on B200, these tunings offer a:

  • 1.7 to 2.5x speedup at M<64
  • 1.1 to 1.3x speedup at 64<=M<128
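
For context, the kind of M-based config selection this adds can be sketched as follows. This is a hypothetical illustration only: the names, thresholds, and tile values below are assumptions for exposition, not the actual vLLM identifiers or configs.

#include <cstdint>

// Hypothetical sketch: dispatch to a smaller CTA tile when M is small,
// so low-batch GEMMs don't pay for a 128-row tile.
struct Sm100Fp8Config {
  int tile_m, tile_n, tile_k;  // CTA tile shape (M, N, K)
};

inline Sm100Fp8Config select_sm100_fp8_config(int64_t m) {
  if (m < 64) {
    return {64, 64, 128};   // assumed small-M config
  } else if (m < 128) {
    return {64, 128, 128};  // assumed mid-M config
  }
  return {128, 128, 128};   // assumed default for M >= 128
}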

Kernel benchmarks below use the script from #17126.

# B200 original tunings
python benchmarks/kernels/bench_fp8_gemm.py --model meta-llama/Llama-3.1-8B-Instruct --tp-sizes 1
meta-llama/Llama-3.1-8B-Instruct, N=6144 K=4096, BF16 vs FP8 GEMMs GB/s:
    batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0          1.0    11.842750               4.384695               4.141685                       5.059773                       4.960429
1         16.0   192.371772              85.477309              79.608508                      96.735811                      96.406114
2         64.0   807.719071             589.708519             516.921588                     758.622920                     752.988363
3        128.0  1493.427964            1164.192322            1022.805103                    1505.913815                    1492.575158
4        256.0  2330.405171            2220.442815            1813.278596                    3008.296875                    2981.101615
5        512.0  2907.547764            2793.742472            2203.694550                    3672.868551                    3653.372327
6       1024.0  3185.960190            3869.274803            2936.256100                    5253.748472                    5213.397951
7       2048.0  3372.689911            4142.692217            3109.779229                    5727.176215                    5689.243846
8       4096.0  3462.188289            4291.145308            3204.745041                    5911.353402                    5899.836283
9       8192.0  3529.249912            4397.600739            3282.537443                    6098.495194                    6085.388654
10     16384.0  3566.596828            4546.730027            3373.615805                    6298.621054                    6287.915519
meta-llama/Llama-3.1-8B-Instruct, N=4096 K=4096, BF16 vs FP8 GEMMs GB/s:
    batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0          1.0     5.481086               3.994605               3.679018                       4.808255                       4.673826
1         16.0   145.731275              75.527900              68.386568                      91.504401                      90.064710
2         64.0   575.803168             394.026762             345.690693                     507.879320                     504.012236
3        128.0  1139.410454             779.473527             685.449889                    1007.821323                     998.688924
4        256.0  1875.656947            1486.448134            1213.644060                    2014.891023                    1996.817344
5        512.0  2763.397677            2596.513320            1876.577101                    3946.713561                    3910.224984
6       1024.0  3142.721171            3239.147685            2321.719735                    4832.091875                    4807.995344
7       2048.0  3318.816545            3651.365389            2562.171046                    5485.452742                    5446.915454
8       4096.0  3414.255510            3752.133901            2590.631646                    5823.098124                    5807.590771
9       8192.0  3504.274870            3834.227574            2653.123452                    6003.879744                    5993.921400
10     16384.0  4048.244939            3953.554872            2722.017045                    6171.108894                    6164.274055
meta-llama/Llama-3.1-8B-Instruct, N=28672 K=4096, BF16 vs FP8 GEMMs GB/s:
    batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0          1.0    12.157825               5.612388               5.539948                       6.181624                       6.000656
1         16.0   172.970068             114.972518             112.622180                     120.732156                     121.298006
2         64.0   674.595646             943.557386             905.512283                    1050.581139                    1062.964463
3        128.0  1300.962379            1858.124579            1761.151736                    2047.242680                    2042.379529
4        256.0  2216.287867            3554.804884            3329.506996                    4014.308414                    3998.976858
5        512.0  2726.096815            4795.019635            4364.503526                    5341.365889                    5324.928601
6       1024.0  2959.186345            5294.689937            4836.956583                    5821.938932                    5814.327508
7       2048.0  3761.815356            5591.323764            5110.608734                    6088.998704                    6083.975578
8       4096.0  3574.236814            5806.798749            5289.014726                    6308.693639                    6300.039740
9       8192.0  4019.744188            5913.920427            5381.996320                    6414.258701                    6403.837270
10     16384.0  4106.402587            5836.563139            5324.548319                    6329.917220                    6323.172853
meta-llama/Llama-3.1-8B-Instruct, N=4096 K=14336, BF16 vs FP8 GEMMs GB/s:
    batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0          1.0     9.777438               6.095286               5.844354                       6.928244                       6.744669
1         16.0   162.025163             154.025715             144.935378                     185.241624                     184.836429
2         64.0   616.911342             611.287938             576.028417                     740.686773                     738.542393
3        128.0  1204.949288            1216.822830            1146.637308                    1476.530711                    1471.302346
4        256.0  2111.100423            2293.131110            2051.262051                    2952.089686                    2942.423441
5        512.0  2937.260107            3796.200098            3114.371919                    5730.980802                    5704.233235
6       1024.0  3488.413658            4239.281076            3446.170900                    6216.238000                    6212.696169
7       2048.0  3685.025884            4490.287887            3635.015494                    6539.948924                    6529.935247
8       4096.0  3734.385998            4532.309092            3662.371663                    6636.056913                    6629.569313
9       8192.0  3758.236603            4645.041087            3762.776234                    6743.351536                    6718.151235
10     16384.0  4306.204970            4738.751066            3832.668281                    6813.248826                    6805.834591

# B200 new tunings
python benchmarks/kernels/bench_fp8_gemm.py --model meta-llama/Llama-3.1-8B-Instruct --tp-sizes 1
meta-llama/Llama-3.1-8B-Instruct, N=6144 K=4096, BF16 vs FP8 GEMMs GB/s:
    batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0          1.0    11.846480               8.452990               7.572770                      10.641609                      10.534006
1         16.0   192.336576             141.348902             125.896661                     177.382610                     176.327304
2         64.0   809.054397             676.540533             583.294353                     904.338200                     901.570760
3        128.0  1493.297679            1353.990985            1177.501355                    1843.712204                    1838.021954
4        256.0  2350.962645            2220.517499            1815.399353                    3008.279996                    2980.624284
5        512.0  2903.143544            2787.914646            2203.210421                    3672.087705                    3653.324502
6       1024.0  3206.446347            3863.193810            2935.647985                    5250.389648                    5210.453726
7       2048.0  3371.772052            4139.752282            3097.937554                    5695.430751                    5665.540960
8       4096.0  3462.810635            4292.139714            3208.626955                    5912.431026                    5900.400438
9       8192.0  3530.033816            4397.610400            3283.561232                    6095.620584                    6086.099140
10     16384.0  3566.388107            4547.710555            3374.347531                    6297.712858                    6287.613786
meta-llama/Llama-3.1-8B-Instruct, N=4096 K=4096, BF16 vs FP8 GEMMs GB/s:
    batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0          1.0     5.481237               7.194318               6.265508                       9.530527                       9.511873
1         16.0   145.717362             114.108143              99.566867                     152.901809                     152.985032
2         64.0   575.814872             449.980812             389.597206                     602.248063                     600.669616
3        128.0  1139.590749             905.278908             787.789674                    1234.155324                    1229.686850
4        256.0  1876.291317            1486.630754            1213.623961                    2015.033919                    1996.965312
5        512.0  2764.009823            2591.524744            1879.343830                    3946.809025                    3909.465672
6       1024.0  3142.383890            3230.337927            2321.113332                    4830.144592                    4804.234807
7       2048.0  3318.579790            3651.632276            2557.730213                    5485.786728                    5447.503083
8       4096.0  3414.370536            3749.333637            2590.511600                    5821.179268                    5806.526663
9       8192.0  3504.070655            3834.735907            2654.077070                    6004.740320                    5992.704228
10     16384.0  4047.938574            3953.740013            2722.086526                    6170.869834                    6163.802150
meta-llama/Llama-3.1-8B-Instruct, N=28672 K=4096, BF16 vs FP8 GEMMs GB/s:
    batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0          1.0    12.139843              13.499578              13.086876                      14.939551                      14.847452
1         16.0   172.540520             219.474892             210.204189                     239.697254                     230.664293
2         64.0   674.577865            1030.361791             981.054381                    1149.836193                    1148.147840
3        128.0  1301.165223            1872.693659            1806.220627                    2081.156220                    2123.675357
4        256.0  2216.100916            3552.279693            3336.364808                    4014.755649                    3998.710144
5        512.0  2725.949811            4792.585207            4363.568374                    5341.086466                    5326.349901
6       1024.0  2959.125081            5293.626173            4835.906799                    5821.708743                    5813.863106
7       2048.0  3761.956557            5591.420874            5110.979048                    6089.098791                    6083.292157
8       4096.0  3574.307512            5807.307181            5288.928725                    6308.724765                    6300.209702
9       8192.0  4019.744388            5914.312745            5382.379939                    6414.279709                    6404.333069
10     16384.0  4106.374801            5836.666966            5324.534539                    6329.588982                    6323.079665
meta-llama/Llama-3.1-8B-Instruct, N=4096 K=14336, BF16 vs FP8 GEMMs GB/s:
    batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0          1.0     9.733832              12.054116              11.104702                      15.095716                      15.057972
1         16.0   161.979220             189.099701             175.757232                     238.928259                     238.785025
2         64.0   616.771351             751.830865             698.306492                     956.291017                     954.205153
3        128.0  1204.865625            1483.958057            1379.201876                    1886.236952                    1883.095434
4        256.0  2110.901491            2295.331478            2051.189524                    2952.193198                    2942.234970
5        512.0  2937.621666            3799.227788            3114.414028                    5730.719186                    5704.042751
6       1024.0  3488.194848            4240.428960            3442.777987                    6216.468683                    6211.739077
7       2048.0  3684.893025            4490.729893            3634.670625                    6540.363564                    6530.731937
8       4096.0  3734.498831            4532.994968            3662.073259                    6636.363393                    6630.525610
9       8192.0  3757.974111            4645.556715            3763.901169                    6743.521712                    6721.473664
10     16384.0  4305.834102            4738.611024            3832.780746                    6811.809892                    6806.390517


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default; they only run the fastcheck CI, a small, essential subset of tests that catches errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

static_assert(std::is_same<InType, cutlass::float_e4m3_t>());  // FP8 e4m3 inputs only
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;          // let CUTLASS pick the mainloop schedule
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;  // and the epilogue schedule
using TileShape = Shape<_128, _64, _64>;  // CTA tile: M=128, N=64, K=64
A collaborator commented on this diff:
Shall we try Shape<_64, _64, _64>?
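
For reference, the suggested variant would only swap the tile shape; the schedules stay auto-selected. An untested sketch (whether it helps depends on occupancy and the epilogue, so it would need benchmarking like the configs above):

static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_64, _64, _64>;  // M halved from 128 to 64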

@houseroad (Collaborator) left a comment:
Looks good. Shall we run some accuracy tests?
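
One cheap numerics sanity check, separate from an end-to-end eval, is to emulate e4m3 rounding on the host and compare a small GEMM against an FP32 reference. The sketch below is a standalone, assumption-laden emulation (clamp to ±448, round to 4 significant bits, no quantization scales since the random inputs are already in range, subnormals and the limited exponent range ignored); it is not vLLM's actual test code.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <random>
#include <vector>

// Crude host-side e4m3 emulation (assumption: clamp to +/-448 and keep
// 1 + 3 mantissa bits; ignores subnormals and the limited exponent range).
static float quantize_e4m3(float x) {
  if (x == 0.0f) return 0.0f;
  float ax = std::min(std::fabs(x), 448.0f);
  int e;
  float m = std::frexp(ax, &e);        // ax = m * 2^e, m in [0.5, 1)
  m = std::round(m * 16.0f) / 16.0f;   // round to 4 significant bits
  return std::copysign(std::ldexp(m, e), x);
}

int main() {
  const int M = 64, N = 128, K = 256;  // small problem, runs instantly
  std::mt19937 rng(0);
  std::normal_distribution<float> dist(0.0f, 1.0f);
  std::vector<float> a(M * K), b(K * N);
  for (auto& v : a) v = dist(rng);
  for (auto& v : b) v = dist(rng);

  // Accumulate the squared error of the quantized GEMM vs the FP32 reference.
  double num = 0.0, den = 0.0;
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      double ref = 0.0, q = 0.0;
      for (int k = 0; k < K; ++k) {
        ref += double(a[i * K + k]) * b[k * N + j];
        q += double(quantize_e4m3(a[i * K + k])) * quantize_e4m3(b[k * N + j]);
      }
      num += (q - ref) * (q - ref);
      den += ref * ref;
    }
  }
  std::printf("relative L2 error: %.4f\n", std::sqrt(num / den));
  return 0;
}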

@houseroad (Collaborator) commented:

cc: @chenyang78 @drisspg
