@@ -1,14 +1,15 @@
 import requests
-from PIL import Image
-from transformers import AutoProcessor
 import torch
 from datasets import load_dataset
+from PIL import Image
+from transformers import AutoProcessor
+
 from llmcompressor.modifiers.quantization import GPTQModifier
 from llmcompressor.transformers import oneshot
 from llmcompressor.transformers.tracing import TraceableIdefics3ForConditionalGeneration
 
 # Load model.
-model_id = "HuggingFaceM4/Idefics3-8B-Llama3" # or "HuggingFaceTB/SmolVLM-Instruct"
+model_id = "HuggingFaceM4/Idefics3-8B-Llama3"  # or "HuggingFaceTB/SmolVLM-Instruct"
 model = TraceableIdefics3ForConditionalGeneration.from_pretrained(
     model_id, device_map="auto", torch_dtype="auto"
 )
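Note (not part of the diff): old lines 15-17 are collapsed here and unchanged
by this commit; since processor is used further down, they presumably load it
along these lines (an assumption from context, not shown in the commit):

    processor = AutoProcessor.from_pretrained(model_id)  # assumed from context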
@@ -18,13 +19,15 @@
 DATASET_ID = "lmms-lab/flickr30k"
 DATASET_SPLIT = "test[:512]"
 NUM_CALIBRATION_SAMPLES = 512
-MAX_SEQUENCE_LENGTH = 4096 # Seems to be required here
+MAX_SEQUENCE_LENGTH = 4096  # Seems to be required here
+
 
 # Define a oneshot data collator for multimodal inputs.
 def data_collator(batch):
     assert len(batch) == 1
     return {key: torch.tensor(value) for key, value in batch[0].items()}
 
+
 # Recipe
 recipe = [
     GPTQModifier(
@@ -39,6 +42,7 @@ def data_collator(batch):
 ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
 ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
 
+
 # Apply chat template
 def preprocess(example):
     messages = [
@@ -47,9 +51,9 @@ def preprocess(example):
             "content": [
                 {"type": "text", "text": "What does the image show?"},
                 {"type": "image"},
-            ]
-        }
-    ]
+            ],
+        }
+    ]
     return {
         "text": processor.apply_chat_template(
             messages,
@@ -58,8 +62,10 @@ def preprocess(example):
         "images": example["image"],
     }
 
+
 ds = ds.map(preprocess)
 
+
 # Tokenize inputs.
 def tokenize(sample):
     return processor(
@@ -70,6 +76,7 @@ def tokenize(sample):
         truncation=True,
     )
 
+
 # long data lengths produced by the phi3_vision processor
 # can lead to integer overflows when mapping, avoid with writer_batch_size
 ds = ds.map(tokenize, writer_batch_size=1, remove_columns=ds.column_names)
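Note (not part of the diff): the phi3_vision mention in the comment above
appears to be carried over from the Phi-3-vision example; the
writer_batch_size=1 workaround applies equally to the long multimodal
sequences produced by the Idefics3 processor. The commit ends here, but in
the llm-compressor multimodal examples the pieces defined above typically
feed into oneshot roughly as follows (argument names assumed from those
examples and may vary by version):

    oneshot(
        model=model,
        dataset=ds,
        recipe=recipe,
        max_seq_length=MAX_SEQUENCE_LENGTH,
        num_calibration_samples=NUM_CALIBRATION_SAMPLES,
        data_collator=data_collator,
    )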