Skip to content

[Kernel] GGUF MMVQ kernel for multiple input vectors #18754

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

SzymonOzog
Copy link
Contributor

@SzymonOzog SzymonOzog commented May 27, 2025

performance_comparison.pdf
Running some kernels in matvec mode for small input sizes is more beneficial than mma. Currently the heuristic is very lazy and safe, happy to take some inputs on how to make the kernel choice better

Signed-off-by: SzymonOzog <[email protected]>
Signed-off-by: SzymonOzog <[email protected]>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Copy link

mergify bot commented May 27, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @SzymonOzog.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 27, 2025
@mergify mergify bot removed the needs-rebase label May 27, 2025
@Isotr0py Isotr0py self-assigned this May 27, 2025
Copy link
Collaborator

@Isotr0py Isotr0py left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks reasonable to me. Thanks for the optimization!

Signed-off-by: SzymonOzog <[email protected]>
@Isotr0py Isotr0py enabled auto-merge (squash) May 27, 2025 09:28
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label May 27, 2025
@SzymonOzog
Copy link
Contributor Author

Some tests seem to be failing, will take a look

@@ -19,7 +20,7 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *
for (auto i = threadIdx.x / (qi/vdr); i < blocks_per_row; i += blocks_per_warp) {
const int ibx = row*blocks_per_row + i; // x block index

const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
const int iby = vec*blocks_per_row + i * (qk/QK8_1); // y block index that aligns with ibx
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

blocks_per_row points to weights (vx), not activations (vy).

Even if it pointed to activations, it still wouldn’t work because GGML activations need rows aligned to multiples of 512 elements. So, alignment is required.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nlzy/vllm-gfx906@cea98a7

This commit does the same thing and should work well. Feel free to refer to its content as needed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I've just noticed that my internal tests had an error, thanks a lot for pointing this out. Will fix this soon but right now I'm running into problems with vLLM compilation #18691

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants