Where are the outliers stored in LLM.int8 quantization for inference using the transformers library on AMD GPU? #1320

Closed
vbayanag opened this issue Aug 15, 2024 · 4 comments
Labels
question Further information is requested

Comments

@vbayanag

Hi, I'm using BitsAndBytesConfig from HF's Transformers library to quantize the facebook/opt-66B model. But when I print the dtype of the weights of various layers, all of them turn out to be int8.

[Image: Capture]

This makes me wonder: where are the outliers stored? The LLM.int8 algorithm requires outliers to be kept in fp16/fp32 as part of its matrix decomposition. Can someone kindly clarify?

Moreover, Figure 2 of the paper shows that FP16 inputs are converted to int8 after detecting outliers, but in our case the model is already converted/quantized.

[Image: llmint8]

@Xing-Zhuang

Xing-Zhuang commented Aug 18, 2024

When the 8-bit linear layer (linear8bit) runs its forward pass, it finds the outlier columns in the activations, and the corresponding rows of the weight matrix are dequantized.
See bitsandbytes/nn/modules for more details.
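
To make the detection step concrete, here is a minimal sketch in plain Python. It is not the bitsandbytes code: `find_outlier_columns` is a hypothetical helper, and the `>=` comparison is an assumption. The value 6.0 matches the default of `llm_int8_threshold` in BitsAndBytesConfig.

```python
def find_outlier_columns(activations, threshold=6.0):
    """Return indices of feature columns holding at least one outlier.

    Hypothetical helper for illustration; a column counts as an outlier
    column if any entry has magnitude >= threshold (assumed comparison).
    """
    n_cols = len(activations[0])
    return [
        j for j in range(n_cols)
        if any(abs(row[j]) >= threshold for row in activations)
    ]


# Two tokens, four hidden features; feature 2 carries a large value.
X = [
    [0.5, -1.2, 9.0, 0.3],
    [0.1, 0.7, -0.4, 2.0],
]
outlier_cols = find_outlier_columns(X)
print(outlier_cols)  # [2]
```

Note that the outliers are found in the activations at run time, which is why nothing about them shows up when inspecting the dtype of the stored weights.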

@ZxAndJb

ZxAndJb commented Oct 12, 2024

I am not sure whether I understand this correctly. From what I checked, the whole weight matrix is still compressed to int8, including the rows that correspond to outliers. During the forward pass, those rows are converted to float16 once the outliers are detected. So the difference is that the outlier part is dequantized first and then multiplied, while the part without outliers is computed in int8 (accumulated in int32) and dequantized afterwards. As the paper mentions, quantizing outliers harms performance greatly, so I am confused about why the actual implementation still introduces this error. Maybe I am wrong; I'd welcome a discussion.
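
The two paths described here can be sketched in plain Python. This is an illustration under assumptions, not the library's kernel: the function names are hypothetical, the weight is assumed to be stored as [out_features, in_features] with one absmax scale per row, and the threshold of 6.0 is the `llm_int8_threshold` default. With that layout, the "rows" mentioned above are the entries `W[o][j]` for each outlier feature `j`.

```python
def quantize_rows(matrix):
    """Row-wise absmax quantization to the int8 range: (codes, scales)."""
    q_rows, scales = [], []
    for row in matrix:
        scale = max(abs(v) for v in row) or 1.0  # avoid dividing by zero
        q_rows.append([round(v * 127 / scale) for v in row])
        scales.append(scale)
    return q_rows, scales


def mixed_int8_matmul(X, W, threshold=6.0):
    """Approximate X @ W^T, with W shaped [out_features, in_features]."""
    n_in = len(W[0])
    outliers = [j for j in range(n_in)
                if any(abs(r[j]) >= threshold for r in X)]

    # The whole weight matrix is held as int8 codes, outlier channels included.
    Wq, w_scales = quantize_rows(W)

    # Regular part: zero the outlier columns of X, then quantize X row-wise.
    X_reg = [[0.0 if j in outliers else v for j, v in enumerate(r)] for r in X]
    Xq, x_scales = quantize_rows(X_reg)

    Y = []
    for i, xq_row in enumerate(Xq):
        out_row = []
        for o, wq_row in enumerate(Wq):
            # Integer accumulation, then dequantize with both scales.
            acc = sum(a * b for a, b in zip(xq_row, wq_row))
            y = acc * x_scales[i] * w_scales[o] / (127 * 127)
            # Outlier part: the activation stays in floating point, while the
            # matching weight entry is dequantized from its int8 code and so
            # carries the rounding error of the original quantization.
            for j in outliers:
                y += X[i][j] * (wq_row[j] * w_scales[o] / 127)
            out_row.append(y)
        Y.append(out_row)
    return Y


X = [[0.5, -1.2, 9.0, 0.3], [0.1, 0.7, -0.4, 2.0]]
W = [[0.2, -0.1, 0.05, 0.4], [-0.3, 0.25, 0.6, -0.15]]
Y = mixed_int8_matmul(X, W)
exact = [[sum(x * w for x, w in zip(xr, wr)) for wr in W] for xr in X]
max_err = max(abs(a - b) for ya, ye in zip(Y, exact) for a, b in zip(ya, ye))
print(max_err)
```

In this sketch the activation outliers are never quantized; only the weights that multiply them pass through int8 and back.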

@sidhantls

sidhantls commented Dec 28, 2024

@ZxAndJb Using weight.CB and weight.SCB worked well for dequantizing the weights. However, what if there are outliers?

qlayer = model_8bit.model.decoder.layers[1].self_attn.k_proj
# SCB holds one absmax scale per output row, so broadcast it along dim 1
res = (qlayer.weight.CB * qlayer.weight.SCB.unsqueeze(1)) / 127  # dequantized weights

However, what if there is an outlier column? According to bitsandbytes, those columns are saved in fp16, not int8. Do you know where they are saved?
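
As a rough illustration of the round trip behind this dequantization, the sketch below quantizes a small matrix row by row and recovers it again. The helper names are hypothetical and only mirror the `CB`/`SCB` naming used above; the per-row scale is an assumption about the layout, and the real library works on tensors rather than lists.

```python
def int8_rowwise_quant(weight):
    """Return (CB, SCB): int8 codes and one absmax scale per row."""
    CB, SCB = [], []
    for row in weight:
        absmax = max(abs(v) for v in row) or 1.0  # avoid dividing by zero
        CB.append([round(v * 127 / absmax) for v in row])
        SCB.append(absmax)
    return CB, SCB


def int8_rowwise_dequant(CB, SCB):
    """Scale row i of CB by SCB[i] / 127 to recover approximate weights."""
    return [[c * s / 127 for c in row] for row, s in zip(CB, SCB)]


W = [[0.2, -0.1, 0.05, 0.4], [-0.3, 0.25, 0.6, -0.15]]
CB, SCB = int8_rowwise_quant(W)
W_hat = int8_rowwise_dequant(CB, SCB)

# Every entry is stored as an int8 code; no column is set aside in fp16.
assert all(-127 <= c <= 127 for row in CB for c in row)

# Rounding error per entry is at most half a quantization step.
for row, row_hat, s in zip(W, W_hat, SCB):
    assert all(abs(a - b) <= s / 254 + 1e-12 for a, b in zip(row, row_hat))
```

Under these assumptions there is no separate fp16 copy to look for: every weight channel can be recovered from the int8 codes, up to the rounding error shown here.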

@matthewdouglas
Member

@sidhantls Please see the answer here: #1400 (comment)

In short, outliers in the activations are kept in fp16. The corresponding channels in the weights are dequantized (with some error) from int8 to fp16.

@matthewdouglas matthewdouglas added the question Further information is requested label Jan 22, 2025
5 participants