
[BUG] 3bit quant and/or inference regression vs AutoGPTQ #1278

Open
sidhantls opened this issue Feb 14, 2025 · 17 comments · Fixed by #1280

@sidhantls

Describe the bug
Quantizing a model to 3 bits using this repo leads to completely deteriorated performance. On MMLU, it gets 22%. However, when I quantize it using https://github.com/AutoGPTQ/AutoGPTQ, (which is where this repo was forked from?), I get 57%. This was using Llama-3.1-8B-Instruct.

Using this repo, for 4bit I get 66% on MMLU, which is in line with what AutoGPTQ gets for 4 bits

Anyone else noticed that 3bit doesn't work here but works in AutoGPTQ ?

Software Info

Operating System/Version + Python Version
Python 3.10

To Reproduce
Quantize model to 3 bits:

    from datasets import load_dataset
    from transformers import AutoTokenizer
    from gptqmodel import GPTQModel, QuantizeConfig

    # `args` holds the script's CLI arguments (model_name, cache_dir, batch_size)

    # 512 C4 samples as calibration data
    calibration_dataset = load_dataset(
        "allenai/c4",
        data_files="en/c4-train.00001-of-01024.json.gz",
        split="train",
    ).select(range(1024 // 2))["text"]

    # calibration_dataset = [" ".join(item.split()[:20]) for item in calibration_dataset]  # speedup

    quantize_config = QuantizeConfig(
        bits=3,
        group_size=64,
    )

    model = GPTQModel.load(args.model_name, quantize_config)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=args.cache_dir)

    # increase `batch_size` to match GPU/VRAM specs to speed up quantization
    model.quantize(calibration_dataset, batch_size=args.batch_size)
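
For comparison, the AutoGPTQ side of the test would look roughly like the sketch below. This is a minimal sketch, not the exact script behind the 57% number: the model and output paths are placeholders, and the calibration slice simply mirrors the snippet above.

    # Hedged sketch of an equivalent AutoGPTQ 3-bit quantization for comparison.
    # Paths are placeholders; the calibration slice mirrors the GPTQModel snippet above.
    from datasets import load_dataset
    from transformers import AutoTokenizer
    from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

    model_name = "meta-llama/Llama-3.1-8B-Instruct"
    out_dir = "llama-3.1-8b-instruct-3bit-autogptq"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    texts = load_dataset(
        "allenai/c4",
        data_files="en/c4-train.00001-of-01024.json.gz",
        split="train",
    ).select(range(1024 // 2))["text"]

    # AutoGPTQ expects pre-tokenized examples (input_ids / attention_mask)
    examples = [tokenizer(t, truncation=True, max_length=2048) for t in texts]

    quantize_config = BaseQuantizeConfig(bits=3, group_size=64)
    model = AutoGPTQForCausalLM.from_pretrained(model_name, quantize_config)
    model.quantize(examples)
    model.save_quantized(out_dir)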

Expected behavior

Performance should not break at 3 bits and should align with the AutoGPTQ library.

@sidhantls sidhantls added the bug Something isn't working label Feb 14, 2025
@Qubitium
Collaborator

I am aware of the regression. We recently ran an ARC test across all bit widths, and 3-bit had an unexpectedly lower score than 2-bit.

We will backtrack and find out which commit(s) broke 3-bit quant or inference.

@Qubitium Qubitium changed the title [BUG] 3bit performance very bad compared to AutoGPTQ [BUG] 3bit quant and/or inference regression vs AutoGPTQ Feb 14, 2025
@Qubitium
Collaborator

@sidhantls We have now bisected the commits, and the 3-bit PPL/accuracy regression happened between v1.6.0 and v1.8.0 (v1.6.0 had correct 3-bit quality). We should have a fix isolated soon.

@CSY-ModelCloud CSY-ModelCloud linked a pull request Feb 15, 2025 that will close this issue
@Qubitium
Collaborator

Qubitium commented Feb 15, 2025

@sidhantls Please check main and see if 3-bit quality has returned to normal. It is now passing our 3-bit quality CI tests.

@sidhantls
Author

> @sidhantls Please check main and see if 3-bit quality has returned to normal. It is now passing our 3-bit quality CI tests.

Thank you! Will do in the next few days

@sidhantls
Author

@Qubitium Hey, I just checked and it still does not seem to work. Does the CI test you used include evaluation on downstream data?

For Llama-3.2-3B at 3 bits, I get much higher performance with AutoGPTQ than here. For example, MMLU is 50% using AutoGPTQ, but with this repo it is 24%.

@Qubitium Qubitium reopened this Feb 19, 2025
@Qubitium
Collaborator

Qubitium commented Feb 19, 2025

@sidhantls Can you send us the command and/or script code that you used to generate the MMLU score? We need to replicate both the quant and the scoring code for bug-fix alignment. Thanks!

On a side note, GPTQModel's calibration data handling is very different from AutoGPTQ's and will inherently create two different quants. By default, we do not concat calibration data together, and even if calibration_dataset_concat_size (concat mode) is used to mimic AutoGPTQ, we concatenate tokens (not strings as AutoGPTQ does). Obviously, we feel our method produces better quants, based on our hundreds of experiments, given the same high-quality/realistic calibration data. There are also cases where a concatenated dataset works better, for calibration data that is not aligned with native model output. But overall, we recommend non-concat mode for the best quality, at the cost of slower quants, which can be offset with dataset batching.
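
To illustrate, a minimal sketch of the two modes described above. Exactly where calibration_dataset_concat_size is accepted is an assumption based on this comment, so verify against the current GPTQModel API:

    # Hedged sketch: non-concat (default) vs. concat-mode calibration in GPTQModel.
    # `calibration_dataset` is the list of text samples from the repro snippet above;
    # where `calibration_dataset_concat_size` is passed is an assumption from this thread.
    from gptqmodel import GPTQModel, QuantizeConfig

    quantize_config = QuantizeConfig(bits=3, group_size=64)
    model = GPTQModel.load("meta-llama/Llama-3.2-3B", quantize_config)

    # Default: each calibration sample is processed on its own (no concatenation)
    model.quantize(calibration_dataset, batch_size=4)

    # Concat mode (mimics AutoGPTQ-style handling, but concatenating tokens, not strings):
    # model.quantize(calibration_dataset, batch_size=4, calibration_dataset_concat_size=2048)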

With that said, the difference is too large to not be a bug. Once we isolate it, we can finally fix it for good. Thanks for your reports!

You can think of GPTQModel as having started as a fork of AutoGPTQ, but it is now vastly different in almost every possible way. With the pending 2.0 release, you will be hard-pressed to recognize a single contiguous block of code shared with AutoGPTQ without a microscope. =)

@Qubitium
Collaborator

Also, please post your complete MMLU scores for the two models if possible, as MMLU (the aggregate) is sub-divided into many categories. Currently our CI tests use the faster ARC-Challenge for downstream validation in many tests, but I need to check whether the bits tests use PPL or the more rigorous lm-eval.
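
For reference, a minimal sketch of pulling the per-category MMLU numbers with lm-evaluation-harness. The 0.4.x-style Python API is assumed here, and the exact metric key names vary by version:

    # Hedged sketch: per-subtask MMLU scores via lm-evaluation-harness (0.4.x-style API).
    # `quant_path` is a placeholder for the saved quantized model directory.
    import lm_eval

    quant_path = "llama-3.2-3b-3bit-gptq"

    results = lm_eval.simple_evaluate(
        model="hf",
        model_args=f"pretrained={quant_path},dtype=float16",
        tasks=["mmlu"],
        batch_size=8,
    )

    # results["results"] maps the aggregate and each MMLU subtask to its metrics
    # (e.g. an accuracy entry such as "acc,none"; key names depend on the version)
    for task, metrics in sorted(results["results"].items()):
        print(task, metrics)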

@Qubitium
Collaborator

Qubitium commented Feb 19, 2025

@sidhantls Ran out of time today, but for tonight we got the following result, benchmarked using the Torch kernel, all 4 bits, group_size=128, C4. We will do more testing tomorrow. Please send us your benchmarking script so we can align our test env.

[image: benchmark result table]

@sidhantls
Author

@Qubitium Great, thanks for following up. I'm using the evaluation harness. I can share the script.

Also, I had used group_size 64. I can test with group_size 128 and see the result.

@Qubitium
Collaborator

Qubitium commented Feb 19, 2025

@sidhantls Please send us the exact CLI command or short script that triggered the lm-eval, since depending on the model executor, the kernels used, and the args, we can get 5%+ different scores between you and me.

I will also check again with group_size 64.
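
One way to reduce that variance is to pin the kernel explicitly on both sides. A short sketch; whether GPTQModel.load() accepts a backend= argument in the installed version is an assumption here:

    # Hedged sketch: pin the inference kernel so both sides score the same executor.
    # The `backend=` kwarg on load() is an assumption; verify against your GPTQModel version.
    from gptqmodel import GPTQModel, BACKEND

    model = GPTQModel.load("llama-3.2-3b-3bit-gptq", backend=BACKEND.TORCH)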

@sidhantls
Author

sidhantls commented Feb 20, 2025

I ran 3-bit quantization with group_size 128 on meta-llama/Llama-3.2-3B, but the results are still not aligned with yours or AutoGPTQ's. Not sure what the discrepancy is here.

Params:

  • Model Name: meta-llama/Llama-3.2-3B
  • Group Size: 128
  • Bits: 3
  • Batch Size: 4

Metrics

(Used the Evaluation Harness):

  • MMLU: 0.23 (overall)
    • mmlu_humanities: 0.248
    • mmlu_social_sciences: 0.22
    • mmlu_stem: 0.222

How to Reproduce:

I have attached a Google Colab notebook that reproduces these results. The outputs are printed, and the notebook should run as well.

Notebook

@Qubitium Thank you, appreciate you looking into this

@Qubitium
Collaborator

@sidhantls Thanks for the script. Do you have the script for AutoGPTQ + eval as well? We need to evaluate both sides exactly.

@sidhantls
Author

sidhantls commented Feb 21, 2025

@Qubitium Sure, I can share a script to reproduce AutoGPTQ in some time. Any idea, though, where my GPTQModel script deviates from yours so as to produce different results? Maybe if you shared your script that generated the above table, I could also try to see what's different.

@Qubitium
Collaborator

Qubitium commented Feb 21, 2025

@sidhantls We just merged a huge PR into main. Please check the tests/tests_bits_new.py file for an example of how we tested and evaluated the quantized model. main is currently beta quality, as the huge PR has completely reworked the internals. CI tests are passing, but we need to fix more things and do more testing.

Using tests/tests_bits_new.py:
[image: lm-eval scores for the 2/3/4/8-bit quants]

We did get low scores for 3-bit, but the score is somewhat consistent across GPTQModel b2 vs b3 vs b4 vs b8, so that's why we need the AutoGPTQ scripts/eval code to cross-reference and see exactly where it deviated for us, or whether there is an alignment issue with the evaluation.

@sidhantls
Author

> tests/tests_bits_new.py

Thanks for sharing the table. I cannot find tests/tests_bits_new.py. I only see tests/tests_bits.py.

@Qubitium
Collaborator

> tests/tests_bits_new.py
>
> Thanks for sharing the table. I cannot find tests/tests_bits_new.py. I only see tests/tests_bits.py.

Try removing your clone and re-cloning. I had to force-push due to a bad merge yesterday, so if you cloned within that 8-hour window, you need to re-clone:

https://github.com/ModelCloud/GPTQModel/blob/main/tests/test_bits_new.py

@Qubitium
Collaborator

Good news with the latest main PR merge: you can now run lm-eval faster using the GPTQModel.eval() API without having to force backend=BACKEND.TORCH. Just remove it (it will auto-select Marlin) or replace it with backend=BACKEND.MARLIN for a much faster eval at nearly the same quality as the Torch kernel (CI verified).
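
A hedged sketch of what that looks like. Only the backend= switch is taken from this comment; the task-selection argument is an illustrative assumption, so check the GPTQModel docs for the exact GPTQModel.eval() signature:

    # Hedged sketch: lm-eval through GPTQModel.eval() with and without an explicit kernel.
    # The `tasks=` argument is an illustrative assumption; only `backend=` comes from this thread.
    from gptqmodel import GPTQModel, BACKEND

    quant_path = "llama-3.2-3b-3bit-gptq"  # placeholder path to the quantized model

    # No backend argument: the fastest compatible kernel (e.g. Marlin) is auto-selected
    GPTQModel.eval(quant_path, tasks=["mmlu"])

    # Or pin the kernel explicitly:
    GPTQModel.eval(quant_path, tasks=["mmlu"], backend=BACKEND.MARLIN)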
