
fp8 support #54

Open
endurehero wants to merge 31 commits into main
Conversation


@endurehero endurehero commented Feb 28, 2025

Functionality

Adds FP8 WGMMA support based on the async pipeline design of FlashMLA. The TransV part draws on the SmemTranspose64x64 implementation in FA3.
Currently, Q/K/V only support symmetric PerTensor quantization. Since the maximum value of P never exceeds 1, a direct f32-to-fp8 cast (f32tofp8_cast) is used to quantize it.
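
For reference, a minimal sketch of the quantization scheme described above, assuming PyTorch and the e4m3 format; the helper name, the scale formula, and the tensor shapes are illustrative, not the kernel's actual implementation:

```python
import torch

FP8_E4M3_MAX = 448.0  # max representable magnitude of float8_e4m3fn

def per_tensor_quantize(x: torch.Tensor):
    """Symmetric per-tensor quantization to FP8 e4m3.

    Returns the quantized tensor and the scale needed to dequantize
    (x ~= x_fp8.float() * scale).
    """
    scale = x.abs().max().float().clamp(min=1e-12) / FP8_E4M3_MAX
    x_fp8 = (x.float() / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    return x_fp8, scale

# Q/K/V: symmetric per-tensor quantization with an explicit scale.
q_fp8, q_scale = per_tensor_quantize(torch.randn(16, 128, dtype=torch.bfloat16))

# P (softmax output): values lie in [0, 1], so a direct f32 -> fp8 cast
# is used without a scale factor.
p = torch.rand(16, 64)
p_fp8 = p.to(torch.float8_e4m3fn)
```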

Performance

CUDA driver version: 535.183.06
NVCC version: 12.8
PyTorch version: 2.6

On the H20, MLA kernels typically exhibit high arithmetic intensity and are compute-bound. Consequently, Model FLOPs Utilization (MFU) is used as the performance metric.
[image: MFU results on H20]

On the H800, MLA kernels are typically memory-bound. Consequently, the Memory Bandwidth Utilization (MBU) metric is used to evaluate kernel performance. There is still plenty of room for optimization on the H800; we look forward to working on it together.
[image: MBU results on H800]
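
For clarity, both metrics are simple ratios of attained to peak throughput; the sketch below is illustrative, and the numbers in it are placeholders rather than the values behind the results above:

```python
def mfu(attained_flops_per_s: float, peak_flops_per_s: float) -> float:
    """Model FLOPs Utilization: fraction of peak compute throughput achieved."""
    return attained_flops_per_s / peak_flops_per_s

def mbu(attained_bytes_per_s: float, peak_bytes_per_s: float) -> float:
    """Memory Bandwidth Utilization: fraction of peak HBM bandwidth achieved."""
    return attained_bytes_per_s / peak_bytes_per_s

# attained values come from kernel_flops / elapsed_s and kernel_bytes / elapsed_s;
# the peak values below are placeholders, not actual H20/H800 specs.
print(mfu(attained_flops_per_s=1.0e14, peak_flops_per_s=1.5e14))  # ~0.67
print(mbu(attained_bytes_per_s=2.4e12, peak_bytes_per_s=3.0e12))  # ~0.80
```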

Reproduction

python3 ./tests/test_flash_mla.py --dtype e4m3

@endurehero endurehero closed this Feb 28, 2025
@endurehero endurehero changed the title support fp8 fp8 support Feb 28, 2025
@endurehero endurehero reopened this Feb 28, 2025
@endurehero endurehero mentioned this pull request Feb 28, 2025
@sijiac
Contributor

sijiac commented Mar 1, 2025

Awesome! Would you mind adding a compile flag to save compile time when FP8 is not needed? Thanks.

@endurehero
Author

endurehero commented Mar 1, 2025

Awesome! Would you mind adding a compile flag to save compile time when FP8 is not needed? Thanks.

Of course. Already done.

@beginlner
Collaborator

beginlner commented Mar 1, 2025

Great work! However, I can’t merge this PR at the moment because, based on our tests, per-sequence kvcache scaling significantly reduces accuracy for MLA.

@endurehero
Author

Great work! However, I can’t merge this PR at the moment because, based on our tests, per-sequence kvcache scaling significantly reduces accuracy for MLA.

What about PerPageBlock granularity? I can easily adapt to it.

@beginlner
Collaborator

beginlner commented Mar 1, 2025

What about PerPageBlock granularity? I can easily adapt to it.

We think PerPageBlock is not sufficient either; kv_rope (the 64-dim RoPE part) needs to be in bf16.

@endurehero
Author

What about PerPageBlock granularity? I can easily adapt to it.

We think PerPageBlock is not sufficient either; kv_rope (the 64-dim RoPE part) needs to be in bf16.

Got it!
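
For illustration of the layout discussed above (a hedged sketch, not part of this PR): assuming a DeepSeek-style MLA cache with a 512-dim compressed part plus the 64-dim kv_rope part, per-page-block scales could be applied to the compressed part while kv_rope stays in bf16. The dimensions and helper below are assumptions for the example only.

```python
import torch

FP8_E4M3_MAX = 448.0
NOPE_DIM, ROPE_DIM = 512, 64  # assumed MLA latent split: compressed part + RoPE part

def quantize_kv_page(kv_page: torch.Tensor):
    """Quantize one KV cache page of shape [block_size, NOPE_DIM + ROPE_DIM].

    The 512-dim compressed part gets a single per-page fp8 scale;
    the 64-dim kv_rope part is kept in bf16.
    """
    nope, rope = kv_page[:, :NOPE_DIM], kv_page[:, NOPE_DIM:]
    scale = nope.abs().max().float().clamp(min=1e-12) / FP8_E4M3_MAX
    nope_fp8 = (nope.float() / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    return nope_fp8, scale, rope.to(torch.bfloat16)

page = torch.randn(64, NOPE_DIM + ROPE_DIM, dtype=torch.bfloat16)  # block_size = 64
nope_fp8, page_scale, rope_bf16 = quantize_kv_page(page)
```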
