[MODEL] Support OpenVLA #1350

Open
billamiable opened this issue Feb 26, 2025 · 5 comments

@billamiable

billamiable commented Feb 26, 2025

Hi, I would like to try GPTQ on OpenVLA. I found that most of the examples use text-only input (e.g., https://github.com/ModelCloud/GPTQModel/blob/main/examples/quantization/transformers_usage.py). Is quantizing models with both visual and language inputs supported?

I tried quantizing with the following script:

```python
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor, GPTQConfig

model_id = "./openvla-7b"
load_dtype = torch.bfloat16
# Calibration data: a single text sample (no images).
dataset = ["gptqmodel is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."]
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True, local_files_only=True, use_fast=True)
# Quantization crashes if block_name_to_quantize is not specified.
gptq_config = GPTQConfig(bits=4, dataset=dataset, tokenizer=processor.tokenizer,
                         block_name_to_quantize="language_model.model.layers")
quantized_model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype=load_dtype, trust_remote_code=True,
                                                         local_files_only=True, device_map="cpu",
                                                         quantization_config=gptq_config)
quantized_model.save_pretrained("./openvla-7b-gptq-bf16-language-model-layer-wi-tokenizer")
processor.save_pretrained("./openvla-7b-gptq-bf16-language-model-layer-wi-tokenizer")
```
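The script's comment notes that quantization crashes unless `block_name_to_quantize` is set. A minimal sketch for listing candidate values from the original checkpoint, assuming it is sharded and ships a standard `model.safetensors.index.json` (whose `weight_map` maps every tensor name to a shard file). Any dotted prefix followed by an integer index is treated as a repeated block stack. Both helper names are hypothetical, not part of transformers or GPTQModel.

```python
import json
import re
from pathlib import Path

# Matches "<prefix>.<integer index>." at the first numeric path component,
# e.g. "language_model.model.layers.3.mlp.down_proj.weight".
_BLOCK_RE = re.compile(r"^(.*?)\.(\d+)\.")


def block_candidates(tensor_names):
    """Map each repeated-block prefix to its number of distinct indices."""
    indices = {}
    for name in tensor_names:
        m = _BLOCK_RE.match(name)
        if m:
            indices.setdefault(m.group(1), set()).add(int(m.group(2)))
    return {prefix: len(idx) for prefix, idx in indices.items()}


def tensor_names_from_index(model_dir):
    """Read tensor names from a sharded checkpoint's index file (assumed to exist)."""
    index_path = Path(model_dir) / "model.safetensors.index.json"
    with open(index_path) as f:
        return list(json.load(f)["weight_map"])


# Hypothetical usage with the model path from the script above:
# for prefix, n in sorted(block_candidates(tensor_names_from_index("./openvla-7b")).items()):
#     print(f"{prefix}: {n} blocks")
```

Prefixes that appear here but are not passed to `block_name_to_quantize` (for example, vision-encoder blocks) would stay unquantized.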

Then I tried to run inference with the quantized model on an Intel Arc A770 with torch 2.5.1, using the following script:

```python
import requests
from PIL import Image
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor, GPTQConfig
import time

local_model_path = "./openvla-7b-gptq-bf16-language-model-layer-wi-tokenizer"
device = 'xpu'
processor = AutoProcessor.from_pretrained(local_model_path, trust_remote_code=True, local_files_only=True,
                                          use_fast=True)
dataset = ["gptqmodel is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."]
gptq_config = GPTQConfig(bits=4, dataset=dataset, tokenizer=processor.tokenizer, block_name_to_quantize="language_model.model.layers")
vla = AutoModelForVision2Seq.from_pretrained(
    local_model_path,
    quantization_config=gptq_config,
    torch_dtype="auto",
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    local_files_only=True,
).to(device)

SYSTEM_PROMPT = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
instruction = 'pick up cup'
prompt = f"{SYSTEM_PROMPT} USER: What action should the robot take to {instruction.lower()}?"
image_id = 30
image_url = f"https://rail.eecs.berkeley.edu/datasets/bridge_release/raw/bridge_data_v2/datacol2_toykitchen7/drawer_pnp/01/2023-04-19_09-18-15/raw/traj_group0/traj0/images0/im_{image_id}.jpg"
image = Image.open(requests.get(image_url, stream=True).raw).resize((256, 256))


total_time = 0
num_iterations = 10
for i in range(num_iterations):
    inputs = processor(prompt, image).to(device, dtype=torch.bfloat16)
    start_time = time.time()
    action = vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False)
    end_time = time.time()
    elapsed_time = end_time - start_time
    total_time += elapsed_time
    print(f"step {i}:{elapsed_time}s")

average_time = total_time / num_iterations
print(f"im_{image_id}.jpg | Final action: {action} | Mean time: {average_time:.2f}s")
```

but got the following error:

```
2025-02-26 15:35:19.426445: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-02-26 15:35:19.427699: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-26 15:35:19.445021: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-02-26 15:35:19.445036: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-02-26 15:35:19.445645: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-26 15:35:19.448647: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-26 15:35:19.448779: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-02-26 15:35:19.791722: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
/usr/local/lib/python3.10/dist-packages/torchvision/io/image.py:14: UserWarning: Failed to load image Python extension: 'libpng16.so.16: cannot open shared object file: No such file or directory'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
  warn(
fast_tokenizer_file ./openvla-7b-gptq-bf16-language-model-layer-wi-tokenizer/tokenizer.json
Traceback (most recent call last):
  File "/home/flex-robot-2/openvla-x86/openvla/test/test_gptq_inference.py", line 9, in <module>
    processor = AutoProcessor.from_pretrained(local_model_path, trust_remote_code=True, local_files_only=True,
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/auto/processing_auto.py", line 310, in from_pretrained
    return processor_class.from_pretrained(
  File "/usr/local/lib/python3.10/dist-packages/transformers/processing_utils.py", line 465, in from_pretrained
    args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/processing_utils.py", line 511, in _get_arguments_from_pretrained
    args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/auto/tokenization_auto.py", line 862, in from_pretrained
    return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py", line 2089, in from_pretrained
    return cls._from_pretrained(
  File "/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py", line 2311, in _from_pretrained
    tokenizer = cls(*init_inputs, **init_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/tokenization_llama_fast.py", line 124, in __init__
    super().__init__(
  File "/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_fast.py", line 112, in __init__
    fast_tokenizer = TokenizerFast.from_file(fast_tokenizer_file)
Exception: data did not match any variant of untagged enum ModelWrapper at line 277147 column 3
```

It seems to be caused by the tokenizer file saved with the quantized model, which cannot be parsed. Any idea how to fix this? Thanks!
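The exception is raised while the fast tokenizer parses the saved `tokenizer.json`. One commonly reported cause of exactly this message (an assumption here, not confirmed in this thread) is a `tokenizers` version mismatch: newer releases serialize BPE `merges` as two-element lists such as `["a", "b"]`, while older releases only accept space-joined strings such as `"a b"`. A minimal stdlib sketch to see which form a given file uses, so the original and re-saved tokenizer files can be compared; `merges_format` is a hypothetical helper.

```python
import json
from pathlib import Path


def merges_format(tokenizer_json_path):
    """Report how BPE merges are stored in a tokenizer.json file.

    Returns "pair-list" for entries like ["a", "b"], "string" for entries
    like "a b", "empty" if there are no merges, and "none" for non-BPE models.
    """
    data = json.loads(Path(tokenizer_json_path).read_text(encoding="utf-8"))
    model = data.get("model", {})
    if model.get("type") != "BPE":
        return "none"
    merges = model.get("merges", [])
    if not merges:
        return "empty"
    # Only the first entry is inspected; a file uses one format throughout.
    return "pair-list" if isinstance(merges[0], list) else "string"


# Hypothetical usage with the paths from the scripts above:
# print(merges_format("./openvla-7b/tokenizer.json"))
# print(merges_format("./openvla-7b-gptq-bf16-language-model-layer-wi-tokenizer/tokenizer.json"))
```

If the two files report different formats, aligning the `tokenizers` version between the quantization and inference environments may be worth trying.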

@Qubitium
Collaborator

Qubitium commented Feb 26, 2025

@billamiable We support image quantization for Qwen2-VL and Ovis-VL models.

Please check our tests/models folder to see how we test and quantize the VL models in question.

https://github.com/ModelCloud/GPTQModel/blob/main/tests/models/test_qwen2_vl.py

Other VL models require a PR. If you are willing to help add support, please submit a PR and we can add it for others to use. VL models require special processing because image pixels have to be tokenized separately.

Also, VL model quantization is only supported when using the GPTQModel APIs directly, not through the transformers quantization API, which is only meant for very very simple quantization.

@billamiable
Author

Thanks for the quick response! When you say the transformers quantization API is only meant for "very very simple quantization", does it still actually quantize the model?
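Whether the transformers path actually quantized the weights can be checked from the saved files. A minimal stdlib sketch, assuming the output directory holds a `config.json` and one or more `*.safetensors` files: it returns the saved `quantization_config` and counts tensors ending in `.qweight` (the packed 4-bit weight name used by GPTQ layers) per top-level module, which also shows whether anything outside `language_model` was quantized. Both function names are hypothetical helpers.

```python
import json
import struct
from collections import Counter
from pathlib import Path


def read_safetensors_names(path):
    """Return tensor names from a .safetensors file without loading tensor data.

    The file starts with an 8-byte little-endian header length, followed by a
    JSON header keyed by tensor name (plus an optional "__metadata__" entry).
    """
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    return [name for name in header if name != "__metadata__"]


def summarize_quantization(model_dir):
    """Return (quantization_config or None, Counter of .qweight tensors per top-level module)."""
    model_dir = Path(model_dir)
    config = json.loads((model_dir / "config.json").read_text())
    packed = Counter()
    for shard in sorted(model_dir.glob("*.safetensors")):
        for name in read_safetensors_names(shard):
            if name.endswith(".qweight"):
                packed[name.split(".", 1)[0]] += 1
    return config.get("quantization_config"), packed


# Hypothetical usage with the output path from the quantization script:
# qcfg, packed = summarize_quantization("./openvla-7b-gptq-bf16-language-model-layer-wi-tokenizer")
# print(qcfg)
# print(packed)
```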

Update from my side: I deleted all the generated tokenizer-related files and copied the tokenizer-related files from the original unquantized model instead (see files below).

added_tokens.json
special_tokens_map.json
tokenizer_config.json
tokenizer.json
tokenizer.model
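This manual workaround can be scripted. A minimal stdlib sketch using the five file names listed above; files missing from the source directory are skipped rather than treated as errors (an assumption, since not every model ships all five), and same-named files in the destination are overwritten. `restore_tokenizer_files` is a hypothetical helper.

```python
import shutil
from pathlib import Path

# Tokenizer files listed in the comment above.
TOKENIZER_FILES = [
    "added_tokens.json",
    "special_tokens_map.json",
    "tokenizer_config.json",
    "tokenizer.json",
    "tokenizer.model",
]


def restore_tokenizer_files(src_dir, dst_dir, names=TOKENIZER_FILES):
    """Copy tokenizer files from the original model dir over the quantized one.

    Existing files with the same name in dst_dir are overwritten; names that
    do not exist in src_dir are skipped. Returns the list of copied names.
    """
    src_dir, dst_dir = Path(src_dir), Path(dst_dir)
    copied = []
    for name in names:
        src = src_dir / name
        if src.is_file():
            shutil.copy2(src, dst_dir / name)
            copied.append(name)
    return copied


# Hypothetical usage with the paths from the scripts above:
# restore_tokenizer_files("./openvla-7b", "./openvla-7b-gptq-bf16-language-model-layer-wi-tokenizer")
```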

Surprisingly, inference now runs. However, it is much slower than before: the average used to be ~1.8s/token, but now it takes ~5.1s/token. I suspect this is because only the layers inside the Llama language model were quantized, since I set block_name_to_quantize="language_model.model.layers". Any suggestions?
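When comparing such numbers, note that the loop in the inference script includes the first call, which typically pays one-time warmup costs, in its mean, and that each sample covers a whole `predict_action` call. A small timing helper, sketched with `time.perf_counter`, discarded warmup runs, and a median alongside the mean; `benchmark` is a hypothetical helper and `fn` stands for any zero-argument wrapper around the call being timed. On asynchronous backends such as XPU, `fn` should not return before the device work has finished.

```python
import statistics
import time


def benchmark(fn, iters=10, warmup=2):
    """Time fn() with warmup calls discarded; return (mean_s, median_s, last_result)."""
    result = None
    for _ in range(warmup):
        result = fn()  # early calls may include one-time setup or compilation costs
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        result = fn()
        samples.append(time.perf_counter() - start)
    return statistics.mean(samples), statistics.median(samples), result


# Hypothetical usage with the objects from the inference script:
# mean_s, median_s, action = benchmark(
#     lambda: vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False))
```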

@Qubitium
Collaborator

Qubitium commented Feb 27, 2025

@billamiable You need to provide us with code snippets showing how you quantize and run inference with your model, including any changes you made to GPTQModel. I do not have enough information to answer your question, since your model, code, and environment differ from ours. I noticed you work for Intel. If this is an internal Intel project, we can work together to get OpenVLA officially added. We have an active working relationship with the Intel AI team in Shanghai, so working with the Beijing branch just completes the circle. =)

@billamiable
Author

I've attached the code above, both for quantization and inference. I made no change to GPTQModel.
Sure, definitely interested in working with you to close the loop.

@Qubitium
Collaborator

@billamiable I will reach out to you on Teams from the [email protected] handle.

Qubitium changed the title from "Fail on inference with Visual-Language inputs" to "[MODEL] Support OpenVLA" on Feb 27, 2025