FP8 Support #56

Open

endurehero opened this issue Feb 28, 2025 · 0 comments

endurehero commented Feb 28, 2025

PR

#54

Intro

This PR adds FP8 WGMMA support based on the async pipeline design of FlashMLA. The TransV part draws on the SmemTranspose64x64 implementation in FA3.
Currently, Q/K/V support only symmetric per-tensor quantization. Since the maximum value of P never exceeds 1, P is quantized with a direct f32tofp8_cast and no scale (a minimal sketch of this scheme follows below).
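As a point of reference (not code from this PR), a minimal PyTorch sketch of symmetric per-tensor e4m3 quantization, and of the direct cast that suffices for P, might look like the following; the function names are illustrative:

```python
import torch

def quantize_per_tensor_e4m3(x: torch.Tensor):
    """Symmetric per-tensor FP8 quantization: a single scale for the whole tensor."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max        # 448.0 for e4m3fn
    scale = x.abs().max().clamp(min=1e-12) / fp8_max      # symmetric, no zero-point
    x_fp8 = (x / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return x_fp8, scale                                   # scale is folded back in after the GEMM

def cast_p_e4m3(p: torch.Tensor) -> torch.Tensor:
    """P lies in [0, 1], well inside the e4m3 range, so a direct cast needs no scale."""
    return p.to(torch.float8_e4m3fn)
```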

Performance

CUDA driver version: 535.183.06
nvcc version: 12.8
torch version: 2.6

On the H20, MLA kernels typically show high arithmetic intensity, so FLOPs utilization (MFU) is used as the performance metric.
[Figure: MFU results on H20]
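For context, MFU here is just achieved FLOP/s over peak FLOP/s; a rough sketch of the calculation (with illustrative inputs, not the exact script behind the numbers above):

```python
def mfu(attn_flops: float, elapsed_s: float, peak_flops_per_s: float) -> float:
    """Fraction of the device's peak FLOP/s that the kernel actually achieves."""
    # attn_flops is roughly 2 * batch * heads * seq_q * seq_kv * (head_dim_qk + head_dim_v)
    return attn_flops / (elapsed_s * peak_flops_per_s)
```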

On the H800, MLA is typically memory-bound, so Memory Bandwidth Utilization (MBU) is used to evaluate the kernel. There is still considerable room for optimization on the H800; looking forward to working on it together.
[Figure: MBU results on H800]
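Similarly, MBU is achieved bytes/s over peak bytes/s; a minimal sketch, assuming the traffic is dominated by reading the compressed KV cache once:

```python
def mbu(bytes_accessed: float, elapsed_s: float, peak_bw_bytes_per_s: float) -> float:
    """Fraction of the device's peak memory bandwidth that the kernel achieves."""
    # For a decode-side MLA kernel: bytes_accessed ~= kv_cache_bytes + q_bytes + output_bytes
    return bytes_accessed / (elapsed_s * peak_bw_bytes_per_s)
```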

Reproduction

python3 ./tests/test_flash_mla.py --dtype e4m3