perf(store): Cache output stride parameters in registers to reduce global loads #6

Open
wants to merge 5 commits into
base: main

garethpaul

Summary

This PR optimizes the store() function in the FlashMLA kernel by caching the frequently used output stride parameters (o_batch_stride, o_row_stride, and o_head_stride) in local registers. Each thread reads these values once via the __ldg intrinsic, which reduces repeated global memory accesses and may lower memory latency. The change is expected to improve kernel performance without affecting functionality.

Key Changes

  • Caches params.o_batch_stride, params.o_row_stride, and params.o_head_stride in local variables in the store() function.
  • Uses the __ldg intrinsic to load these values through the read-only data cache.
  • Leaves the remainder of the function unchanged, so behavior is preserved.
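
The caching pattern described above can be sketched roughly as follows. This is a minimal illustration, not the actual FlashMLA code: `DemoParams`, `store_baseline`, `store_cached`, and `stores_match` are hypothetical names, and the sketch assumes the parameter struct is reachable through a pointer into global memory, which is what `__ldg` expects. If the struct were instead passed by value as a kernel argument, it would already sit in constant memory and `__ldg` would not apply.

```cuda
#include <cuda_runtime.h>
#include <vector>

// Hypothetical stand-in for the kernel's parameter struct; only the three
// stride fields named in this PR are modeled.
struct DemoParams {
    long long o_batch_stride;
    long long o_row_stride;
    long long o_head_stride;
};

// Baseline: every index computation reads the strides through the pointer.
__global__ void store_baseline(float* out, const DemoParams* params,
                               int rows, int heads) {
    const int batch = blockIdx.x;
    const int head = threadIdx.x;
    if (head >= heads) return;
    for (int row = 0; row < rows; ++row) {
        const long long idx = batch * params->o_batch_stride
                            + row * params->o_row_stride
                            + head * params->o_head_stride;
        out[idx] = static_cast<float>(idx);
    }
}

// Cached: each thread loads the strides once with __ldg and reuses the locals.
__global__ void store_cached(float* out, const DemoParams* params,
                             int rows, int heads) {
    const long long batch_stride = __ldg(&params->o_batch_stride);
    const long long row_stride = __ldg(&params->o_row_stride);
    const long long head_stride = __ldg(&params->o_head_stride);
    const int batch = blockIdx.x;
    const int head = threadIdx.x;
    if (head >= heads) return;
    for (int row = 0; row < rows; ++row) {
        const long long idx = batch * batch_stride + row * row_stride
                            + head * head_stride;
        out[idx] = static_cast<float>(idx);
    }
}

// Host helper: runs both kernels and checks that they write identical output.
bool stores_match() {
    const int batches = 2, rows = 4, heads = 8;
    const size_t n = static_cast<size_t>(batches) * rows * heads;
    const DemoParams h_params{rows * heads, heads, 1};  // dense row-major layout

    float* d_a = nullptr;
    float* d_b = nullptr;
    DemoParams* d_params = nullptr;
    bool ok = cudaMalloc(&d_a, n * sizeof(float)) == cudaSuccess
           && cudaMalloc(&d_b, n * sizeof(float)) == cudaSuccess
           && cudaMalloc(&d_params, sizeof(DemoParams)) == cudaSuccess
           && cudaMemcpy(d_params, &h_params, sizeof(DemoParams),
                         cudaMemcpyHostToDevice) == cudaSuccess;

    std::vector<float> a(n, -1.0f), b(n, -2.0f);
    if (ok) {
        store_baseline<<<batches, heads>>>(d_a, d_params, rows, heads);
        store_cached<<<batches, heads>>>(d_b, d_params, rows, heads);
        ok = cudaMemcpy(a.data(), d_a, n * sizeof(float),
                        cudaMemcpyDeviceToHost) == cudaSuccess
          && cudaMemcpy(b.data(), d_b, n * sizeof(float),
                        cudaMemcpyDeviceToHost) == cudaSuccess;
    }
    cudaFree(d_a);
    cudaFree(d_b);
    cudaFree(d_params);

    // With dense strides, element i should hold the value i in both buffers.
    for (size_t i = 0; ok && i < n; ++i) {
        ok = (a[i] == static_cast<float>(i)) && (b[i] == a[i]);
    }
    return ok;
}
```

Calling `stores_match()` from a `main` function checks only that the two variants produce the same output. Whether the cached version is faster depends on what the compiler already does with the baseline, so it says nothing about performance on its own.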

Impact

This is a targeted performance optimization: it aims to reduce redundant loads from global memory without altering correctness.

@beginlner
Collaborator

I don't think the changes make any difference in performance. If they do, we should deal with the other attributes of params in the same way. Could you please add a benchmark comparison?
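
One way such a comparison could be produced is to time each kernel variant with CUDA events and compare the average time per launch. The following is a minimal sketch under stated assumptions: `demo_kernel`, `average_launch_ms`, and `time_demo_kernel` are hypothetical names, and the placeholder kernel stands in for whichever FlashMLA kernel variant is being measured.

```cuda
#include <cuda_runtime.h>

// Hypothetical placeholder; in practice this would be the kernel under test.
__global__ void demo_kernel(float* out, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = 2.0f * i;
}

// Returns the average time per launch in milliseconds, or a negative value
// if any CUDA call fails. `launch` is a callable that enqueues one launch.
template <typename Launch>
float average_launch_ms(Launch launch, int warmup, int iters) {
    cudaEvent_t start, stop;
    if (cudaEventCreate(&start) != cudaSuccess) return -1.0f;
    if (cudaEventCreate(&stop) != cudaSuccess) {
        cudaEventDestroy(start);
        return -1.0f;
    }

    // Warm-up launches keep one-time setup costs out of the measurement.
    for (int i = 0; i < warmup; ++i) launch();

    cudaEventRecord(start);
    for (int i = 0; i < iters; ++i) launch();
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);  // wait until all timed launches have finished

    float total_ms = 0.0f;
    const cudaError_t err = cudaEventElapsedTime(&total_ms, start, stop);
    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    if (err != cudaSuccess || cudaGetLastError() != cudaSuccess) return -1.0f;
    return total_ms / iters;
}

// Example driver: times the placeholder kernel over 100 launches.
float time_demo_kernel() {
    const int n = 1 << 20;
    float* d_out = nullptr;
    if (cudaMalloc(&d_out, n * sizeof(float)) != cudaSuccess) return -1.0f;
    const float ms = average_launch_ms(
        [=] { demo_kernel<<<(n + 255) / 256, 256>>>(d_out, n); }, 10, 100);
    cudaFree(d_out);
    return ms;
}
```

Running the same helper once for the baseline and once for the modified kernel, on the same inputs and GPU, would give two numbers that can be reported side by side.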

5 participants