Potential Issue with off_weight_l Calculation in Quantized Matrix Multiplication Kernel #6

Open
LLSGYN opened this issue Feb 20, 2025 · 0 comments

LLSGYN commented Feb 20, 2025

Hello,

I appreciate the effort and time you’ve invested in creating and sharing these exercises and reference solutions. Thank you for supporting the community and providing such valuable resources!

I’ve been working through your exercises. While reviewing the reference solution, I found what I believe is an error in how the weight offsets are calculated.

Issue Details

In this problem, each weight is stored in 4 bits, so FPINT (32 / 4 = 8) weights are packed into a single 32-bit integer. The reference solution computes the weight offset as follows:

```python
off_weight_l = l + tl.arange(0, B_MID // FPINT)
```

However, as I understand it, l counts unpacked weights, whereas each 32-bit integer packs FPINT of them, so the offset calculation should account for the packing by dividing l by FPINT. I believe the correct calculation is:

```python
off_weight_l = l // FPINT + tl.arange(0, B_MID // FPINT)
```

This adjustment maps l onto the index of the 32-bit integer that holds the first weight of the current block. Without the division, off_weight_l mixes unpacked and packed indices, which can load the wrong packed words and corrupt every computation that depends on them.
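To make the unit mismatch concrete, here is a small pure-Python sketch that mimics which packed columns each version of off_weight_l touches once mask_weight_l is applied. It assumes, as the reasoning above does, that l steps over the unpacked MID dimension in increments of B_MID. The sizes MID = 128 and B_MID = 32 and the helpers packed_cols and covered are hypothetical, chosen only for illustration.

```python
# Hypothetical sizes, chosen only for illustration.
FPINT = 8                 # 4-bit weights per 32-bit word (32 // 4)
MID = 128                 # unpacked length of the reduction dimension
B_MID = 32                # unpacked block size along MID
N_WORDS = MID // FPINT    # packed words per weight row


def packed_cols(l, divide):
    """Packed column offsets for one block, mirroring off_weight_l."""
    start = l // FPINT if divide else l   # proposed fix vs. reference formula
    return [start + i for i in range(B_MID // FPINT)]


def covered(divide):
    """Packed columns that survive mask_weight_l over the whole loop."""
    cols = []
    for l in range(0, MID, B_MID):        # assumed: l steps over unpacked indices
        cols += [c for c in packed_cols(l, divide) if c < N_WORDS]
    return cols


print(covered(divide=False))                         # [0, 1, 2, 3]
print(covered(divide=True) == list(range(N_WORDS)))  # True
```

With these sizes each weight row holds 16 packed words, but the reference formula only ever reaches words 0 to 3: every later block is masked out because l alone already exceeds MID // FPINT.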

Reference Code Snippet

Here is the relevant portion of the reference answer for context:

```python
# load weight
# note: our weight will be stored in 4bits.
off_weight_l = l + tl.arange(0, B_MID // FPINT)
mask_weight_l = off_weight_l < (MID // FPINT)
off_weight = off_j[:, None] * (MID // FPINT) + off_weight_l[None, :]
mask_weight = mask_j[:, None] & mask_weight_l[None, :]
weight = tl.load(weight_ptr + off_weight, mask=mask_weight)
```

My Concern

Because l is not divided by FPINT, off_weight_l adds an unpacked index (l) to packed offsets (tl.arange(0, B_MID // FPINT)). Since each packed integer contains FPINT weights, this reads the wrong memory locations, and once l reaches MID // FPINT, mask_weight_l masks out the entire block. Either way, incorrect weight values end up in the computation.
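The numerical effect can be sketched the same way. The pure-Python emulation below packs one row of 4-bit weights into 32-bit words and accumulates a blocked dot product with each offset formula. The nibble layout (weight k of a word in bits 4k..4k+3), treating masked lanes as zero (as if other=0 were passed to tl.load), the sizes, and the helpers pack, unpack and blocked_dot are all assumptions for illustration, not the reference solution's code.

```python
# Hypothetical sizes, chosen only for illustration.
FPINT = 8        # 4-bit weights per 32-bit word
MID = 64         # unpacked reduction length
B_MID = 16       # unpacked block size along MID
N_WORDS = MID // FPINT


def pack(weights):
    """Pack 4-bit weights into 32-bit words; weight k sits in bits 4k..4k+3 (assumed layout)."""
    words = []
    for base in range(0, len(weights), FPINT):
        word = 0
        for k, w in enumerate(weights[base:base + FPINT]):
            word |= (w & 0xF) << (4 * k)
        words.append(word)
    return words


def unpack(word):
    """Inverse of pack for a single 32-bit word."""
    return [(word >> (4 * k)) & 0xF for k in range(FPINT)]


def blocked_dot(words, act, divide):
    """Dot product of one packed weight row with act, block by block."""
    acc = 0
    for l in range(0, MID, B_MID):                # assumed: l steps over unpacked indices
        start = l // FPINT if divide else l       # proposed fix vs. reference formula
        for i in range(B_MID // FPINT):
            col = start + i
            # Masked-out lanes contribute 0, as if tl.load were given other=0.
            word = words[col] if col < N_WORDS else 0
            for k, w in enumerate(unpack(word)):
                acc += w * act[l + i * FPINT + k]  # unpacked index of this weight
    return acc


weights = [k % 16 for k in range(MID)]
act = [1] * MID
words = pack(weights)
expected = sum(w * a for w, a in zip(weights, act))

print(expected)                               # 480
print(blocked_dot(words, act, divide=False))  # 120: only the first block is read
print(blocked_dot(words, act, divide=True))   # 480
```

Under these assumptions, the reference offset only ever uses the first block's weights, while dividing l by FPINT reproduces the unblocked result.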

Request for Clarification

Could you please verify if the offset calculation for off_weight_l in the reference solution is correct?

Thank you for your assistance!
