Add support for W8A8 quantization with CPU weight offloading #1078

Open · NeoChen1024 opened this issue Jan 17, 2025 · 7 comments
Assignees: dsikka
Labels: enhancement (New feature or request)

NeoChen1024 commented Jan 17, 2025

I always run into problems when trying to run oneshot calibration with CPU-offloaded tensors; this limits GPUs with 24 GiB of VRAM to quantizing models no larger than about 7~8B parameters.

Describe the solution you'd like
Add support for block swap / layer-wise loading

Describe alternatives you've considered
accelerate's offloading (device_map="auto") and DeepSpeed ZeRO stage 3 did not work when I tried them with oneshot.

Additional context
Block swap / layer-wise loading would probably slow down SmoothQuantModifier calibration by several times, but it would make it possible to run calibration with longer contexts and bigger models (a rough sketch of the idea is below).
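
As an illustration only, here is a toy sketch of the block-swap idea (this is not llm-compressor API; the block class, dimensions, and statistics hook are made up): keep every block in CPU RAM and move only the block currently being calibrated onto the GPU, so peak VRAM is roughly one block plus the activations.

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    """Stand-in for one transformer decoder layer."""
    def __init__(self, dim: int):
        super().__init__()
        self.ff = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.ff(x))

blocks = nn.ModuleList(ToyBlock(4096) for _ in range(32))  # whole stack stays in CPU RAM
hidden = torch.randn(8, 4096).to("cuda")                   # stand-in for calibration activations

with torch.no_grad():
    for block in blocks:
        block.to("cuda")             # swap this block into VRAM
        # ... collect per-block calibration statistics here (e.g. activation scales) ...
        hidden = block(hidden)
        block.to("cpu")              # swap it back out to free VRAM
        torch.cuda.empty_cache()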

NeoChen1024 added the enhancement label on Jan 17, 2025

dsikka (Collaborator) commented Jan 17, 2025

Hi @NeoChen1024, do you mind providing a code snippet of what you have tried to run with SmoothQuant?

dsikka self-assigned this on Jan 17, 2025

NeoChen1024 (Author) commented:

My own generalized quantization script is here: https://github.com/NeoChen1024/scripts/blob/master/llm-compressor-quantize.py
Running llm-compressor-quantize.py --model_id cognitivecomputations/Dolphin3.0-Llama3.1-8B --max_sequence_length 2048 --num_calibration_samples 512 almost OOMs on a 24 GiB GPU (like my own 3090 Ti). If I increase the max sequence length to 4096, it OOMs.

dsikka (Collaborator) commented Jan 17, 2025

@NeoChen1024, have you tried providing a device map when running GPTQ?
This should enable you to run models much larger than 7B.

E.g.

# Builds a device map that reserves GPU memory for the GPTQ Hessians and
# offloads whatever does not fit onto the CPU.
from llmcompressor.transformers.compression.helpers import calculate_offload_device_map

device_map = calculate_offload_device_map(
    MODEL_ID, reserve_for_hessians=True, num_gpus=NUM_GPUS, torch_dtype="auto"
)
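
For completeness, a minimal sketch of how that device map would typically be wired into the rest of the flow (the dataset name, sequence length, and recipe here are placeholders rather than values confirmed in this thread, and the import paths assume llm-compressor 0.3.x):

from transformers import AutoModelForCausalLM

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot

# Continues from the snippet above: device_map is the map returned by
# calculate_offload_device_map for MODEL_ID / NUM_GPUS.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map=device_map,
    torch_dtype="auto",
)

# During oneshot calibration, layers offloaded to CPU are moved on and off the
# GPU by accelerate's hooks as each forward pass reaches them.
oneshot(
    model=model,
    dataset="open_platypus",  # placeholder calibration dataset
    recipe=GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
    max_seq_length=2048,
    num_calibration_samples=512,
)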

NeoChen1024 (Author) commented Jan 17, 2025

I tried that before, but apparently oneshot doesn't like meta tensors; it expects all tensors to be in GPU memory. So I can't quantize models larger than 8B on my single 24 GiB GPU.

dsikka (Collaborator) commented Jan 17, 2025

Hi @NeoChen1024, can you share which versions of the packages you used the device map with?

NeoChen1024 (Author) commented:

The newest transformers (4.48.0) and the newest llm-compressor (0.3.1).

mhendrey commented Feb 4, 2025

If it helps, here's a bit of code I tried to use for Mistral-Small-24B-2501 on my 4090 (24 GB of VRAM):

from accelerate import infer_auto_device_map, init_empty_weights
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
from llmcompressor.transformers import oneshot

# model_id, dataset_id, dataset_split, num_calibration_samples, and max_seq_len
# are defined elsewhere in the full script.

# Build a device map that caps GPU 0 at 10 GiB and offloads the rest.
with init_empty_weights():
    model = AutoModelForCausalLM.from_pretrained(model_id)
    device_map = infer_auto_device_map(model, max_memory={0: "10GiB"})
# Tried this, but was still getting OOM errors so switched to the above.
# device_map = calculate_offload_device_map(
#    model_id, reserve_for_hessians=True, num_gpus=1
# )
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map=device_map,
    torch_dtype="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load dataset and preprocess.
ds = load_dataset(dataset_id, split=dataset_split)
ds = ds.shuffle(seed=None).select(range(num_calibration_samples))

def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }

ds = ds.map(preprocess)

# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=max_seq_len,
        truncation=True,
        add_special_tokens=False,
    )

ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure algorithms. In this case, we:
#   * apply SmoothQuant to make the activations easier to quantize
#   * quantize the weights to int8 with GPTQ (static per channel)
#   * quantize the activations to int8 (dynamic per token)
recipe = [
    SmoothQuantModifier(smoothing_strength=0.8),
    GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
]

# Apply algorithms and save to output_dir
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=max_seq_len,
    num_calibration_samples=num_calibration_samples,
)

Throws:

  File ".../compressed_tensors/utils/offload.py", line 280, in offload_to_weights_map
    raise NotImplementedError(
NotImplementedError: Updating weights_map with disk offloading is not implemented yet
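
One detail that may be relevant here (an assumption about accelerate's behaviour, not a fix confirmed in this thread): with max_memory={0: "10GiB"} and no "cpu" entry, infer_auto_device_map can place the overflow on "disk", which is exactly the offloading target that compressed-tensors rejects above. Giving it an explicit CPU budget keeps the offload in system RAM instead, for example:

from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoModelForCausalLM

# Hypothetical variant of the snippet above; the model id and the "96GiB" CPU
# budget are placeholders and should be adjusted to the actual model and host.
model_id = "mistralai/Mistral-Small-24B-Instruct-2501"

with init_empty_weights():
    meta_model = AutoModelForCausalLM.from_pretrained(model_id)

device_map = infer_auto_device_map(
    meta_model,
    max_memory={0: "10GiB", "cpu": "96GiB"},
    # Keep whole decoder layers on a single device.
    no_split_module_classes=meta_model._no_split_modules,
)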
