
Add int8 ops for CPU #1178

Conversation

Xia-Weiwen

@Xia-Weiwen Xia-Weiwen commented Apr 12, 2024

Waiting for #1173 to land first.

This PR adds int8 ops for CPU.
UTs are added for the following CPU ops:

  • double_quant
  • transform
  • igemmlt
  • mm_dequant
  • extract_outliers

You can run the tests with pytest tests/test_functional.py -k <op name>

I have removed the example for CPU in the first commit 13ad630, because it depends on changes to transformers. If you want to run the example, the code is here: https://gist.github.com/Xia-Weiwen/1fdc3c3933d615460994c94e29a0c889. You will need to bypass the CUDA checks in transformers by changing its source code and building it from source.
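For illustration only (this example is not part of the PR itself), a minimal sketch of exercising one of the new CPU ops, assuming a bitsandbytes build from this branch; the bf16 input and threshold are illustrative choices, following the CPU dtype discussion later in this thread:

import torch
import bitsandbytes.functional as F

# Quantize a small activation tensor on the CPU using the double_quant op covered by this PR.
A = torch.randn(8, 32, dtype=torch.bfloat16, device="cpu")
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=6.0)
print(CA.dtype, CA.shape)  # int8 quantized tensor
print(SCA.shape)           # quantization statistics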


CC @jianan-gu @jiqing-feng @jgong5

@Xia-Weiwen Xia-Weiwen changed the title from "Add int8 ops for Intel CPU & XPU" to "[WIP] Add int8 ops for Intel CPU & XPU" on Apr 12, 2024
@jgong5 jgong5 left a comment

Do we need to add UTs to cover the new device support?

@jiqing-feng
Contributor

jiqing-feng commented Apr 15, 2024

After some minor fixes and bypassing the CUDA checks in transformers, I got the following performance on an Intel CPU and would like to share it with you:

model: meta-llama/Llama-2-7b-hf
input (length=8): 'Hamburg is in which country?\n'
generation config: max_length=64
pytorch version: 2.4.0
[performance results were attached as an image]

with BNB
int8 output (length=57): Hamburg is in which country?
 февр. 14, 2020
Hamburg is in which country?
The answer is: Germany
Which country: Germany
What is the answer to this question: Hamburg is in which country?

without BNB
bf16 output (length=64): Hamburg is in which country?
Hamburg is in Germany.
Is hamburg in germany?
Hamburg is in Germany.
What country is hamburg in?
Hamburg is in Germany.
Is hamburg in Germany?
Is hamburg in eng

@Xia-Weiwen Xia-Weiwen changed the title from "[WIP] Add int8 ops for Intel CPU & XPU" to "Add int8 ops for CPU" on Apr 15, 2024
@matthewdouglas
Member

@jiqing-feng Looks great! I'm curious what CPU model that is?

@jiqing-feng
Contributor

@jiqing-feng Looks great! I'm curious what CPU model that is?

The model is Llama-2-7b-chat-hf
The CPU is Intel 4th Gen Xeon (SPR):
[CPU details were attached as an image]

@Titus-von-Koeller
Collaborator

Dear @Xia-Weiwen @jianan-gu @jiqing-feng @jgong5,

This is looking great 💪🏻 ! Also the preliminary tests you ran look really promising.

Let's try to get this merged this week. I'll have a deeper look by tomorrow evening.

You've marked this as draft: what do you think is needed for this to be ready to merge or do you think that's already the case now?

In our iterative approach of merging onto multi-backend-refactor, the individual PRs don't need to be absolutely final; it's more important to merge quickly, so that follow-up PRs all have a single source of truth (the multi-backend-refactor branch) to address and we can iterate quickly.

@Xia-Weiwen
Author

@Titus-von-Koeller Thanks for your comments. It looks like we need to wait for #1173 to land first. However, if you think it's OK whichever goes first, we are happy to have this merged :)

@Xia-Weiwen Xia-Weiwen marked this pull request as ready for review April 17, 2024 08:17
@@ -312,13 +314,16 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
state.outlier_pool = GlobalOutlierPooler.get_instance()

# Cast A to fp16
if A.dtype != torch.float16:
    warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
A_dtype = torch.float16
Contributor

Suggested change
-A_dtype = torch.float16
+if A.dtype != torch.float16:
+    warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
+    A_dtype = torch.float16

Tensors which are already in fp16 do not need to be set again

Author

@abhilash1910 Thanks for the comment. Here we are considering other dtypes like bfloat16 for CPU.

A_dtype = torch.float16
if A.device == torch.device('cpu'):
    A_dtype = torch.bfloat16
if A.dtype != A_dtype:
    warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to {A_dtype} during quantization")

Contributor

Yes, correct, but if the tensor is already in fp16 then there's no need to convert, right? The condition only applies if bf16 or another precision is used; then it goes into the condition (the logic remains the same, I think). Let me know your thoughts. Looks OK either way.

Author

The conversion is done afterwards; this line only prints a warning.
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(A_dtype), threshold=state.threshold)
And in fact, if the tensor is already in A_dtype, no action is taken.
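To make the flow concrete, here is a small self-contained sketch combining the two snippets above (the helper name is hypothetical and not from the PR):

import warnings
import torch
import bitsandbytes.functional as F

def quantize_input(A, threshold):
    # Hypothetical helper, not part of the PR: pick the target dtype per device, warn, then quantize.
    A_dtype = torch.float16
    if A.device == torch.device('cpu'):
        A_dtype = torch.bfloat16  # the CPU path prefers bf16, as discussed above
    if A.dtype != A_dtype:
        warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to {A_dtype} during quantization")
    # A.to(A_dtype) is a no-op when A is already in A_dtype, so nothing extra happens in that case.
    return F.double_quant(A.to(A_dtype), threshold=threshold)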

Member

Quick question that might be related here: do we need to consider any changes (e.g., falling back to fp32) for users with a CPU that does not have AVX512-BF16 or AMX? Or is that something handled by torch?

Author

It will fall back to fp32 automatically. It's handled by torch.
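For reference, a quick way to see which ISA level torch dispatches to on a given machine; this is only a sketch and assumes a recent PyTorch where torch.backends.cpu.get_cpu_capability() is available (it reports the general vectorization level such as "AVX2" or "AVX512", not AMX or AVX512-BF16 specifically):

import torch

# If bf16 kernels are unavailable on the machine, torch computes in fp32 instead, as noted above.
print(torch.backends.cpu.get_cpu_capability())
print(torch.__version__)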

@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16], ids=id_formatter("dtype"))
def test_coo_double_quant(dim1, dim2, device, dtype):
    if device == "cuda" and dtype == torch.bfloat16:
        pytest.skip("BFloat16 not supported on CUDA")
Member

With bf16 supported since Ampere, it might be best to make this clearer and say something like "bfloat16 is not implemented for this operation on CUDA backend"

Author

Thanks. Updated.

@jiqing-feng
Contributor

jiqing-feng commented Apr 22, 2024

Hi @Titus-von-Koeller. We would like to get your review :)

Besides, after this PR is merged, we can add the CPU build path to the installation readme. It would be great if we could also have a pip wheel.

Dear @Xia-Weiwen @jianan-gu @jiqing-feng @jgong5,

This is looking great 💪🏻 ! Also the preliminary tests you ran look really promising.

Let's try to get this merged this week. I'll have a deeper look by tomorrow evening.

You've marked this as draft: what do you think is needed for this to be ready to merge or do you think that's already the case now?

In our iterative approach of merging onto multi-backend-refactor, the individual PRs don't need to be absolutely final; it's more important to merge quickly, so that follow-up PRs all have a single source of truth (the multi-backend-refactor branch) to address and we can iterate quickly.

Thanks for your support! Please sync with us if you have any concerns.


The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Titus-von-Koeller
Collaborator

Hey!

I am taking the week off to rest up a bit. I'll be back next Monday, the 29th, and this is at the very top of my list. My plan is to first merge #1173, likely with a few modifications/additions, and then merge this one, after merge conflicts are resolved.

Thanks for your work so far, it looks really great!

@jiqing-feng
Contributor

jiqing-feng commented May 6, 2024

Hi @Titus-von-Koeller . I would like to share my tests for the newest changes with you.

For inference

input (length=5): 'I am happy because'
generation config: max_length=64, do_sample=False, num_beams=1
pytorch version: 2.4.0.dev20240417+cpu
disabled compile
There are 2 data types:
  • torch_dtype: users can assign it via model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config, torch_dtype=torch_dtype)
  • bnb input data A dtype: A_dtype, fixed by the code in bnb
torch_dtype = torch.bfloat16
config 1: transformers full bf16
config 2: bf16 + bnb int8 w/ A_dtype=bf16
config 3: bf16 + bnb int8 w/ A_dtype=fp32
model = meta-llama/Llama-2-7b-hf
[benchmark results were attached as an image; a rough reproduction sketch for config 2 follows the outputs below]

output:
config 1: I am happy because I have been able to help my friend who is going through a difficult time. I have been there for her and listened to her when she needed someone to talk to. I have also been able to offer her support and encouragement when she felt like giving up. Seeing her smile and knowing
config 2: I am happy because I have been able to help my friend who is going through a difficult time. I have been able to listen to her, offer support and advice, and help her find resources that can help her. I feel good because I know that I am making a positive impact in her life, and that
config 3: I am happy because I have been able to help my friend who is going through a difficult time. I have been able to listen to her, offer support and advice, and help her find resources that can help her. I feel good because I know that I am making a positive impact in her life, and that
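The rough reproduction sketch referenced above, for config 2 (bf16 + bnb int8). It assumes the patched transformers build mentioned earlier in this thread, since stock transformers at this point still gates load_in_8bit on CUDA availability; generation settings are taken from the list above:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "meta-llama/Llama-2-7b-hf"
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)

inputs = tokenizer("I am happy because", return_tensors="pt")
output = model.generate(**inputs, max_length=64, do_sample=False, num_beams=1)
print(tokenizer.decode(output[0], skip_special_tokens=True))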

For fine-tune

script: https://github.com/huggingface/peft/blob/main/examples/int8_training/Finetune_opt_bnb_peft.ipynb
The only change is in the dataset:
data = data.map(lambda samples: tokenizer(samples["quote"], padding="max_length", max_length=64, truncation=True), batched=True)
CPU: stock pytorch, single socket on SPR (56 cores)
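A condensed sketch of the linked notebook with the dataset change above; the model choice, LoRA hyperparameters, and dataset name are assumptions taken from that notebook rather than from this PR:

import torch
from datasets import load_dataset
from transformers import (AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,
                          DataCollatorForLanguageModeling, Trainer, TrainingArguments)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model_id = "facebook/opt-6.7b"  # model used in the notebook (assumption)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.bfloat16,
)
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05,
                                         target_modules=["q_proj", "v_proj"],
                                         bias="none", task_type="CAUSAL_LM"))

# Dataset preparation with the change noted above (padding/truncation to length 64).
data = load_dataset("Abirate/english_quotes")
data = data.map(lambda samples: tokenizer(samples["quote"], padding="max_length",
                                          max_length=64, truncation=True), batched=True)

trainer = Trainer(
    model=model,
    train_dataset=data["train"],
    args=TrainingArguments(per_device_train_batch_size=4, gradient_accumulation_steps=4,
                           warmup_steps=100, max_steps=200, learning_rate=2e-4,
                           bf16=True, logging_steps=1, output_dir="outputs"),
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()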

CPU: HF with BNB int8
bf16=True
loss:
{'loss': 2.2501, 'grad_norm': 1.544872522354126, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.01}
{'loss': 1.7135, 'grad_norm': 1.0868957042694092, 'learning_rate': 0.00015800000000000002, 'epoch': 0.5}
{'loss': 2.2395, 'grad_norm': 0.8754072189331055, 'learning_rate': 8.6e-05, 'epoch': 1.0}
{'train_loss': 1.9395700389146804, 'epoch': 1.28}

CPU: HF without BNB (amp bf16)
bf16=True
loss:
{'loss': 2.2408, 'grad_norm': 1.5752161741256714, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.01}
{'loss': 1.706, 'grad_norm': 1.1324536800384521, 'learning_rate': 0.00015800000000000002, 'epoch': 0.5}
{'loss': 2.2366, 'grad_norm': 0.8937948346138, 'learning_rate': 8.6e-05, 'epoch': 1.0}
{'train_loss': 1.9383774387836457, 'epoch': 1.28}


@Xia-Weiwen
Author

Hi @Titus-von-Koeller I have rebased and updated this PR. Please review again. Thanks.

@Titus-von-Koeller
Collaborator

Hey @Xia-Weiwen, thanks for the updates and your ongoing work, really appreciated! It's super cool to see what you accomplished so far and I'm excited to have your work be integrated soon into BNB.

On my side, I'm working on a slight rework of the multi-backend abstraction in #1195.

The idea is for me to work on that PR this week (other than Thursday) and receive feedback from the community. We would then integrate the feedback, merge, and ask you to adapt your PR to the changed abstraction.

I think it makes most sense in that order, do you agree?

@Xia-Weiwen
Author

Hi @Titus-von-Koeller Thanks for the info. However, we are wondering why this PR depends on #1195. Looks like there is no conflict. Can we land this PR first and do some refactoring afterwards if needed? cc @jiqing-feng

@jiqing-feng
Contributor

Hi @Titus-von-Koeller Thanks for the info. However, we are wondering why this PR depends on #1195. Looks like there is no conflict. Can we land this PR first and do some refactoring afterwards if needed? cc @jiqing-feng

Agreed.

Hi @Titus-von-Koeller. Thanks for your support. We believe there is no functional change, just some refactoring. Your PR #1195 has only been open for a day, while this PR has been open for at least a month and has passed a few rounds of review. I think it would be better to finish this PR so we can start the next step; as you know, a lot of work relies on this PR, and we need to discuss with other libraries like transformers and peft to enable CPU bnb support.

Besides, refactoring is not short-term work; we need to keep making changes in communication with the community, so it is more reasonable to enable functionality while refactoring instead of waiting for the refactoring to be ready. WDYT? I hope you can understand. We would like to hear your feedback and will consider your opinion first. Thanks sincerely.

@matthewdouglas
Member

Hi all,

I would agree that this could, and probably should, be merged ahead of #1195. Any changes needed here as a result of #1195 could happen in that separate PR.

@Titus-von-Koeller
Collaborator

Hey all,

Thanks for your input. Agreed: the idea of the multi-backend-refactor target branch was to take an iterative approach, and it makes sense to use your work as the basis for further iterations instead of blocking progress.

The other PR branch is currently not conflicting, because it doesn't contain the full work yet. The reason why I was hesitant is that I know the refactor will break a lot of things. Another aspect is that we currently have no way of testing the Intel backend with our personal machines or CI, meaning I have no visibility of what I'm breaking.

Who from the Intel side could we contact regarding collaborating on making an Intel runner available to us?

Regarding potential breaking changes by the refactor: It would be very important for me to have your buy-in and support in future follow-up PRs to make sure the Intel parts will be adjusted for the changes in the device abstraction / backend API, lazy initialization and dispatching that we still have to do. Do I have that?

On the topic of moving forward and the time it has taken so far: we're still in quite a special situation with BNB right now. The owner of BNB, and the only one with truly full ownership of the code base and especially its lower-level components, has been almost completely unavailable for the last 6 months. This is exactly the duration since I officially took over maintenance with the support of HF, and Tim's absence only became apparent right after my start date. Despite the gracious support from HF, BNB is still an independent and very small FOSS project and I'm the only one fully tasked with it. We've had a lot of technical debt from before we took over maintenance, and community concerns had gone largely unaddressed for an extended period before we took over.

We have (had) many high-impact, fundamental things to handle with BNB in order to facilitate this new stage of the BNB story (where we move from a FOSS academic project with a single main contributor to fully maintained, truly community-driven critical infrastructure), many of which needed to be prioritized over the multi-backend refactor. Some other important topics related to the refactor, for example setting up an org to move BNB to, have been temporarily halted as we need consent from Tim. This is, for example, needed to set up a full-scale CI/CD infrastructure that would be super helpful with the refactor and with maintaining multiple backends over time. Formulating a plan forward that aligns with the interests of all stakeholders took a lot of effort and time, yet we can't finalize it yet because Tim is not available for the time being.

On top of that I wasn't feeling very well in the last weeks.

I wasn't sure whether I should share these internal details of my and the project's process, but I sympathize with your situation and can understand it might be a bit frustrating at times to have to keep on waiting, especially given that you have such a skilled and motivated team working on your end of the bargain. To also gain some sympathy from you, I thought it might be helpful to share a bit from my side, too.

Thanks for your continued effort as well as understanding of our circumstances and let me re-affirm my commitment to enabling this effort.

I already reviewed your PR multiple times over the last week and I think it's really good work and agree it's ready to merge. My concerns were more around the conflicts the following refactor would cause, but now I think we agree on the way forward in this regard as well.

Thanks for this important contribution and helping in pushing forward the democratization of AI through free and open source AI!

@Titus-von-Koeller
Collaborator

P.S. the failing docs build was caused by a change from where we're pulling the image. I fixed that in a separate PR.

@Titus-von-Koeller Titus-von-Koeller merged commit 8561f09 into bitsandbytes-foundation:multi-backend-refactor May 7, 2024
1 of 2 checks passed
@Titus-von-Koeller
Collaborator

P.P.S. I'll also work on enabling a nightly build for the multi-backend-refactor branch. That will help you in coordinating the integrations with other libraries.

@jiqing-feng
Contributor

jiqing-feng commented May 8, 2024

Hi, @Titus-von-Koeller. Thanks for your support! I know it's not easy to move forward by refactoring while developing. Feel free to contact us if you need any help or if anything requires our changes. We will keep an eye on your refactoring work because we have 4-bit features that need to be developed. Thanks very much!

@Xia-Weiwen
Author

Xia-Weiwen commented May 8, 2024

Hi @Titus-von-Koeller. Thank you very much for your help, and thanks for sharing your thoughts.

Another aspect is that we currently have no way of testing the Intel backend with our personal machines or CI, meaning I have no visibility of what I'm breaking. Who from the Intel side could we contact regarding collaborating on making an Intel runner available to us?
Regarding potential breaking changes by the refactor: It would be very important for me to have your buy-in and support in future follow-up PRs to make sure the Intel parts will be adjusted for the changes in the device abstraction / backend API, lazy initialization and dispatching that we still have to do. Do I have that?

We fully understand the situation you are facing, and we are glad to take on the development and maintenance work for the CPU part. This PR has enabled CPU tests in test_functional.py. You can run them with pytest tests/test_functional.py -k <op name>, where <op name> can be double_quant, transform, igemmlt, mm_dequant, or extract_outliers. You can also run the example I mentioned in the PR description at the top (changes to transformers are needed).
You can run on your PC or on any AWS instance with 3rd or 4th generation Intel(R) Xeon(R) processors. We are also looking for available machines on our side.
BTW, we are working on 4-bit operators for the CPU backend. That PR will be ready in one or two weeks.
If you need more help on the aspects you mentioned, please feel free to contact @jiqing-feng, @jianan-gu, and me, or @jgong5, who is more senior.
