
[Bug]: InternVL2 support for AWQ quantization #1929

@marvinzh

Description

⚙️ Your current environment

Environment Information

Operating System: Linux-5.15.0-72-generic-x86_64-with-glibc2.35
Python Version: 3.10.18 (main, Jun 5 2025, 13:14:17) [GCC 11.2.0]
llm-compressor Version: 0.8.1
compressed-tensors Version: 0.12.2
transformers Version: 4.56.2
torch Version: 2.8.0
CUDA Devices: ['NVIDIA A100-SXM4-80GB']
AMD Devices: None

πŸ› Describe the bug

Hi team,
we are trying to use AWQ to quantize the InternVL2 model https://huggingface.co/OpenGVLab/InternVL2-8B.

At first, we hit a torch.fx tracing issue: the int() cast in the code below is not supported by torch.fx. The traceback points at extract_feature, and pixel_shuffle (shown after it) uses the same int() pattern:

   (self.create_arg(fn(*args)),),
  File "InternVLChatModel_8756586350349_autowrapped", line 20, in forward
    input_embeds = input_embeds.reshape(B * N, C)
  File "/home/xxx/.cache/huggingface/modules/transformers_modules/InternVL2-8B/modeling_internvl_chat.py", line 198, in extract_feature
    h = w = int(vit_embeds.shape[1] ** 0.5)
TypeError: int() argument must be a string, a bytes-like object or a real number, not 'HFProxy'
def pixel_shuffle(self, x, scale_factor=0.5):
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        if self.ps_version == 'v1':
            warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
                          'which results in a transposed image.')
        else:
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

Since the cast is not strictly necessary (these values are fixed once the model is trained), we manually rewrote the expressions to remove the int() casts.
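For reference, the manual edit looks roughly like the sketch below (the concrete value is illustrative, not the real InternVL2-8B number; the only point is that int() is no longer applied to an fx proxy):

# modeling_internvl_chat.py, extract_feature (sketch of the manual rewrite)
# before: int() on an HFProxy breaks torch.fx tracing
h = w = int(vit_embeds.shape[1] ** 0.5)
# after: hard-code the patch-grid side, which is fixed for a trained checkpoint
# (16 is an illustrative value only)
h = w = 16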

The process can proceed after the rewrite; however, we then hit the following issue in the second propagation step:

Preparing cache: 100%|████████████████████████████████████████| 100/100 [00:02<00:00, 37.15it/s]
(1/2): Calibrating: 100%|████████████████████████████████████████| 100/100 [00:06<00:00, 14.79it/s]
Smoothing: 0it [00:00, ?it/s]
(1/2): Propagating: 100%|████████████████████████████████████████| 100/100 [00:05<00:00, 18.77it/s]
(2/2): Calibrating: 100%|████████████████████████████████████████| 100/100 [00:00<00:00, 130.18it/s]
Smoothing: 0it [00:00, ?it/s]
(2/2): Propagating:   0%|                                                                                                             | 0/100 [00:00<?, ?it/s]

Traceback (most recent call last):
  File "/mnt/xxxx/workspace/xxxxxx/quant/test_awq.py", line 97, in <module>
    oneshot(
  File "/mnt/xxxxx/xxxxxx/miniconda/envs/quant/lib/python3.10/site-packages/llmcompressor/entrypoints/oneshot.py", line 330, in oneshot
    one_shot()
  File "/mnt/xxxxx/xxxxxx/miniconda/envs/quant/lib/python3.10/site-packages/llmcompressor/entrypoints/oneshot.py", line 158, in __call__
    self.apply_recipe_modifiers(
  File "/mnt/xxxxx/xxxxxx/miniconda/envs/quant/lib/python3.10/site-packages/llmcompressor/entrypoints/oneshot.py", line 201, in apply_recipe_modifiers
    pipeline(
  File "/mnt/xxxxx/xxxxxx/miniconda/envs/quant/lib/python3.10/site-packages/llmcompressor/pipelines/independent/pipeline.py", line 45, in __call__
    pipeline(model, dataloader, dataset_args)
  File "/mnt/xxxxx/xxxxxx/miniconda/envs/quant/lib/python3.10/site-packages/llmcompressor/pipelines/sequential/pipeline.py", line 112, in __call__
    inputs = activations.fetch(batch_idx, subgraph.input_names)
  File "/mnt/xxxxx/xxxxxx/miniconda/envs/quant/lib/python3.10/site-packages/llmcompressor/pipelines/cache.py", line 104, in fetch
    return {
  File "/mnt/xxxxx/xxxxxx/miniconda/envs/quant/lib/python3.10/site-packages/llmcompressor/pipelines/cache.py", line 105, in <dictcomp>
    key: self._onload_value(subgraph_input)
  File "/mnt/xxxxx/xxxxxx/miniconda/envs/quant/lib/python3.10/site-packages/llmcompressor/pipelines/cache.py", line 210, in _onload_value
    raise e
  File "/mnt/xxxxx/xxxxxx/miniconda/envs/quant/lib/python3.10/site-packages/llmcompressor/pipelines/cache.py", line 205, in _onload_value
    setattr(value, field.name, self._onload_value(v))
  File "/mnt/xxxxx/xxxxxx/miniconda/envs/quant/lib/python3.10/site-packages/llmcompressor/pipelines/cache.py", line 195, in _onload_value
    value = intermediate.value
AttributeError: 'NoneType' object has no attribute 'value'

Following the stack trace, we printed the variable that triggers the exception in cache.py using the following code snippet (its output is shown right after):

        if is_dataclass(value):
            for field in fields(value):  # `asdict` is recursive, not applicable here
                v = getattr(value, field.name)
                try: 
                    setattr(value, field.name, self._onload_value(v))
                except Exception as e:
                    print("value:", value)
                    print("field:", field)
                    print("v:", v)
                    raise e
value: CausalLMOutputWithPast(loss=None, logits=tensor([[[-6.5625, -5.8438, -5.3125,  ..., -4.2500, -5.4062, -5.5625],
         [ 3.1562,  4.3125,  1.0391,  ...,  5.6875,  5.4688,  5.0625],
         [ 4.5000,  4.5625,  3.0625,  ...,  6.9375,  7.0312,  6.5312],
         ...,
         [ 8.0625,  9.3750,  6.9688,  ...,  9.6250,  9.5000,  9.1875],
         [ 4.8750,  5.7500,  5.8438,  ...,  6.2500,  6.1562,  6.0000],
         [ 2.1562,  3.3906,  0.6875,  ...,  3.4688,  3.1406,  3.2812]]],
       device='cuda:0'), past_key_values=None, hidden_states=None, attentions=None)
field: Field(name='loss',type=typing.Optional[torch.FloatTensor],default=None,default_factory=<dataclasses._MISSING_TYPE object at 0x7f60485226b0>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD)
v: None

It looks like the loss field is None, which is unexpected. Could you please advise on what to do next to resolve this issue?
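In case it helps, a naive local workaround we considered (a sketch only, based on the cache.py excerpt above; we have not verified whether skipping None fields is safe for the rest of the pipeline) would be to leave None dataclass fields untouched instead of recursing into them:

        if is_dataclass(value):
            for field in fields(value):  # `asdict` is recursive, not applicable here
                v = getattr(value, field.name)
                if v is None:
                    # e.g. CausalLMOutputWithPast.loss when no loss was computed;
                    # there is nothing to onload, so keep the field as-is
                    continue
                setattr(value, field.name, self._onload_value(v))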

🛠️ Steps to reproduce

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier, AWQMapping

from utils import load_image

# Load model.
model_id = "OpenGVLab/InternVL2-8B" # better to download and use following code snippet to replace `forward` method
model = AutoModel.from_pretrained(model_id, torch_dtype="auto", trust_remote_code=True)
print(model)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
NUM_CALIBRATION_SAMPLES = 100
DATASET = "xxx.jsonl"

def preprocess(example):
    messages = []
    for turn in example["conversations"]:
        if turn["from"] == "human":
            messages.append({
                "role": "user",
                "content": turn["value"]
            })
        elif turn["from"] == "gpt":
            pass
        else:
            raise ValueError
            
    prompt_ids = tokenizer.apply_chat_template(messages)
    example["input_ids"] = prompt_ids
    return example

# Load dataset and preprocess.
ds = load_dataset('json', data_files=DATASET, split='train')

ds = ds.map(preprocess)
print(ds[0])

def data_collator(batch):
    assert len(batch) == 1
    item = {key: value for key, value in batch[0].items()}
    item["pixel_values"] = torch.concat([load_image(x) for x in item["image"]])
    item["labels"] = torch.LongTensor([item["input_ids"]])
    item["input_ids"] = torch.LongTensor([item["input_ids"]])
    return item
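
# For reference, each record in DATASET is expected to look roughly like the
# dict below (inferred from `preprocess` and `data_collator` above; the path
# and text are placeholders, not values from the real calibration file):
#
# {
#     "conversations": [
#         {"from": "human", "value": "Describe this picture."},
#         {"from": "gpt", "value": "..."},  # assistant turns are skipped by preprocess
#     ],
#     "image": ["path/to/image_0.jpg"],     # each entry is passed to load_image
# }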


# Recipe
recipe = [
    AWQModifier(
        ignore=["re:.*lm_head", "re:mlp1.*", "re:.*vision_model.*"],
        scheme="W4A16", 
        targets=["Linear"],
        offload_device=torch.device("cpu"),
        mappings=[
                # AWQMapping(
                #     "re:.*input_layernorm",
                #     ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"],
                # ),
                AWQMapping("re:.*v_proj", ["re:.*o_proj"]),
                # AWQMapping(
                #     "re:.*post_attention_layernorm",
                #     ["re:.*gate_proj", "re:.*up_proj"],
                # ),
                # AWQMapping(
                #     "re:.*up_proj",
                #     ["re:.*down_proj"],
                # ),
            ]
    ),
]


# Perform oneshot
oneshot(
    model=model,
    tokenizer=model_id,
    dataset=ds,
    recipe=recipe,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    trust_remote_code_model=True,
    data_collator=data_collator,
    sequential_targets=["InternLM2ForCausalLM"]
)

Replace the forward method in modeling_internvl_chat.py with the following for a fast reproduction:

    def forward(
            self,
            pixel_values: torch.FloatTensor,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            image_flags: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # image_flags = image_flags.squeeze(-1)
        input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()

        # vit_embeds = self.extract_feature(pixel_values)
        # vit_embeds = vit_embeds[image_flags == 1]
        # vit_batch_size = pixel_values.shape[0]

        # B, N, C = input_embeds.shape
        # input_embeds = input_embeds.reshape(B * N, C)

        # if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
        #     print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')

        # input_ids = input_ids.reshape(B * N)
        # selected = (input_ids == self.img_context_token_id)
        # try:
        #     input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
        # except Exception as e:
        #     vit_embeds = vit_embeds.reshape(-1, C)
        #     print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
        #           f'vit_embeds.shape={vit_embeds.shape}')
        #     n_token = selected.sum()
        #     input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]

        # input_embeds = input_embeds.reshape(B, N, C)

        outputs = self.language_model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

Metadata

Labels: awq (For any issue / PR related to AWQ support), bug (Something isn't working)
