Skip to content

Commit c174c25

Browse files
committed
Style and Quality
Signed-off-by: Leon Seidel <[email protected]>
1 parent bdf0c18 commit c174c25

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

examples/multimodal_vision/idefics3_example.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import requests
2-
from PIL import Image
3-
from transformers import AutoProcessor
42
import torch
53
from datasets import load_dataset
4+
from PIL import Image
5+
from transformers import AutoProcessor
6+
67
from llmcompressor.modifiers.quantization import GPTQModifier
78
from llmcompressor.transformers import oneshot
89
from llmcompressor.transformers.tracing import TraceableIdefics3ForConditionalGeneration
910

1011
# Load model.
11-
model_id = "HuggingFaceM4/Idefics3-8B-Llama3" # or "HuggingFaceTB/SmolVLM-Instruct"
12+
model_id = "HuggingFaceM4/Idefics3-8B-Llama3" # or "HuggingFaceTB/SmolVLM-Instruct"
1213
model = TraceableIdefics3ForConditionalGeneration.from_pretrained(
1314
model_id, device_map="auto", torch_dtype="auto"
1415
)
@@ -18,13 +19,15 @@
1819
DATASET_ID = "lmms-lab/flickr30k"
1920
DATASET_SPLIT = "test[:512]"
2021
NUM_CALIBRATION_SAMPLES = 512
21-
MAX_SEQUENCE_LENGTH = 4096 # Seems to be required here
22+
MAX_SEQUENCE_LENGTH = 4096 # Seems to be required here
23+
2224

2325
# Define a oneshot data collator for multimodal inputs.
2426
def data_collator(batch):
2527
assert len(batch) == 1
2628
return {key: torch.tensor(value) for key, value in batch[0].items()}
2729

30+
2831
# Recipe
2932
recipe = [
3033
GPTQModifier(
@@ -39,6 +42,7 @@ def data_collator(batch):
3942
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
4043
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
4144

45+
4246
# Apply chat template
4347
def preprocess(example):
4448
messages = [
@@ -47,9 +51,9 @@ def preprocess(example):
4751
"content": [
4852
{"type": "text", "text": "What does the image show?"},
4953
{"type": "image"},
50-
]
51-
}
52-
]
54+
],
55+
}
56+
]
5357
return {
5458
"text": processor.apply_chat_template(
5559
messages,
@@ -58,8 +62,10 @@ def preprocess(example):
5862
"images": example["image"],
5963
}
6064

65+
6166
ds = ds.map(preprocess)
6267

68+
6369
# Tokenize inputs.
6470
def tokenize(sample):
6571
return processor(
@@ -70,6 +76,7 @@ def tokenize(sample):
7076
truncation=True,
7177
)
7278

79+
7380
# long data lengths produced by the multimodal processor
7481
# can lead to integer overflows when mapping, avoid with writer_batch_size
7582
ds = ds.map(tokenize, writer_batch_size=1, remove_columns=ds.column_names)

src/llmcompressor/transformers/tracing/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@
1616
"TraceableMllamaForConditionalGeneration",
1717
"TraceableQwen2VLForConditionalGeneration",
1818
"TraceableIdefics3ForConditionalGeneration"
19-
]
19+
]

0 commit comments

Comments
 (0)