From 25e8b2ce03c0e44c42f470a2edaee6def2fd03cd Mon Sep 17 00:00:00 2001 From: Leon Seidel Date: Fri, 24 Jan 2025 13:35:17 +0100 Subject: [PATCH 1/4] Add Idefics3/SmolVLM --- .../multimodal_vision/idefics3_example.py | 111 +++++ .../transformers/tracing/__init__.py | 6 +- .../transformers/tracing/idefics3.py | 423 ++++++++++++++++++ 3 files changed, 539 insertions(+), 1 deletion(-) create mode 100644 examples/multimodal_vision/idefics3_example.py create mode 100644 src/llmcompressor/transformers/tracing/idefics3.py diff --git a/examples/multimodal_vision/idefics3_example.py b/examples/multimodal_vision/idefics3_example.py new file mode 100644 index 000000000..c23b8577d --- /dev/null +++ b/examples/multimodal_vision/idefics3_example.py @@ -0,0 +1,111 @@ +import requests +from PIL import Image +from transformers import AutoProcessor +import torch +from datasets import load_dataset +from llmcompressor.modifiers.quantization import GPTQModifier +from llmcompressor.transformers import oneshot +from llmcompressor.transformers.tracing import TraceableIdefics3ForConditionalGeneration + +# Load model. +model_id = "HuggingFaceM4/Idefics3-8B-Llama3" # or "HuggingFaceTB/SmolVLM-Instruct" +model = TraceableIdefics3ForConditionalGeneration.from_pretrained( + model_id, device_map="auto", torch_dtype="auto" +) +processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) + +# Oneshot arguments +DATASET_ID = "lmms-lab/flickr30k" +DATASET_SPLIT = "test[:512]" +NUM_CALIBRATION_SAMPLES = 512 +MAX_SEQUENCE_LENGTH = 4096 # Seems to be required here + +# Define a oneshot data collator for multimodal inputs. +def data_collator(batch): + assert len(batch) == 1 + return {key: torch.tensor(value) for key, value in batch[0].items()} + +# Recipe +recipe = [ + GPTQModifier( + targets="Linear", + scheme="W4A16", + sequential_targets=["LlamaDecoderLayer"], + ignore=["re:.*lm_head", "re:model.vision_model.*", "re:model.connector.*"], + ), +] + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) +ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) + +# Apply chat template +def preprocess(example): + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What does the image show?"}, + {"type": "image"}, + ] + } + ] + return { + "text": processor.apply_chat_template( + messages, + add_generation_prompt=True, + ), + "images": example["image"], + } + +ds = ds.map(preprocess) + +# Tokenize inputs. +def tokenize(sample): + return processor( + text=sample["text"], + images=sample["images"], + padding=False, + max_length=MAX_SEQUENCE_LENGTH, + truncation=True, + ) + +# long data lengths produced by the phi3_vision processor +# can lead to integer overflows when mapping, avoid with writer_batch_size +ds = ds.map(tokenize, writer_batch_size=1, remove_columns=ds.column_names) + +# Perform oneshot +oneshot( + model=model, + dataset=ds, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + trust_remote_code_model=True, + data_collator=data_collator, +) + +# Confirm generations of the quantized model look sane. 
+print("========== SAMPLE GENERATION ==============") +messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Please describe the animal in this image\n"}, + {"type": "image"}, + ], + }, +] +prompt = processor.apply_chat_template(messages, add_generation_prompt=True) +image_url = "http://images.cocodataset.org/train2017/000000231895.jpg" +raw_image = Image.open(requests.get(image_url, stream=True).raw) + +inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to("cuda") +output = model.generate(**inputs, max_new_tokens=100) +print(processor.decode(output[0], skip_special_tokens=True)) +print("==========================================") + +# Save to disk compressed. +SAVE_DIR = model_id.split("/")[1] + "-W4A16-G128" +model.save_pretrained(SAVE_DIR, save_compressed=True) +processor.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/transformers/tracing/__init__.py b/src/llmcompressor/transformers/tracing/__init__.py index fae57dbb1..1e4a341eb 100644 --- a/src/llmcompressor/transformers/tracing/__init__.py +++ b/src/llmcompressor/transformers/tracing/__init__.py @@ -7,9 +7,13 @@ from .qwen2_vl import ( Qwen2VLForConditionalGeneration as TraceableQwen2VLForConditionalGeneration, ) +from .idefics3 import ( + Idefics3ForConditionalGeneration as TraceableIdefics3ForConditionalGeneration +) __all__ = [ "TraceableLlavaForConditionalGeneration", "TraceableMllamaForConditionalGeneration", "TraceableQwen2VLForConditionalGeneration", -] + "TraceableIdefics3ForConditionalGeneration" +] \ No newline at end of file diff --git a/src/llmcompressor/transformers/tracing/idefics3.py b/src/llmcompressor/transformers/tracing/idefics3.py new file mode 100644 index 000000000..29add82ea --- /dev/null +++ b/src/llmcompressor/transformers/tracing/idefics3.py @@ -0,0 +1,423 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# vllm-project: no copyright +"""PyTorch Idefics3 model.""" + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.cache_utils import Cache, DynamicCache +from transformers.utils import logging +# from transformers.models.auto import AutoModel +from transformers.models.idefics3.configuration_idefics3 import Idefics3Config +from transformers.models.idefics3.modeling_idefics3 import ( + Idefics3Model, + Idefics3ForConditionalGeneration, + Idefics3VisionTransformer, + Idefics3Connector, + Idefics3BaseModelOutputWithPast +) +from transformers.models.llama.modeling_llama import ( + LlamaModel, + LlamaAttention, + LlamaDecoderLayer, + LlamaConfig, + LlamaRMSNorm, + LlamaMLP, + LlamaRotaryEmbedding, + apply_rotary_pos_emb, + eager_attention_forward + ) + +from typing import Callable +from transformers.processing_utils import Unpack +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.cache_utils import StaticCache + +logger = logging.get_logger(__name__) + + +# TRACING: cannot condition on mask shape +@torch.fx.wrap +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class LlamaAttention(LlamaAttention): + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + # TRACING: Use input_shape[0], input_shape[1] instead of *input_shape + hidden_shape = (input_shape[0], input_shape[1], -1, self.head_dim) #(*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(input_shape[0], input_shape[1], -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class LlamaDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.hidden_size = config.hidden_size + # TRACING: Use custom LlamaAttention + self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx) + + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + +class LlamaModel(LlamaModel): + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + # TRACING: Use custom LlamaDecoderLayer + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). 
+ # TRACING: Use wrapped _prepare_4d_causal_attention_mask_with_cache_position + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +class Idefics3Model(Idefics3Model): + def __init__(self, config: Idefics3Config): + super().__init__(config) + self.padding_idx = self.config.text_config.pad_token_id + self.vocab_size = self.config.text_config.vocab_size + + self.vision_model = Idefics3VisionTransformer._from_config(config.vision_config) + self.connector = Idefics3Connector(config) + # TRACING: Use traceable LlamaModel + self.text_model = LlamaModel(config.text_config) # AutoModel.from_config(config.text_config) + + self.image_seq_len = int( + ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2) + ) + self.image_token_id = self.config.image_token_id + + self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2" + + self.post_init() + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + image_hidden_states: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Idefics3BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.training and self.text_model.gradient_checkpointing and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + # retrieve input_ids and inputs_embeds + if input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + past_seen_tokens = 0 + if use_cache: + if past_key_values is None: + past_key_values = DynamicCache() + past_seen_tokens = past_key_values.get_seq_length() + + if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0: + raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.") + + if inputs_embeds is None: + inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(self.device) + + # START VISUAL INPUTS INTEGRATION + if pixel_values is not None and image_hidden_states is not None: + raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time") + elif pixel_values is not None: + batch_size, num_images, num_channels, height, width = pixel_values.shape + pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility + # TRACING: Use pixel_values.shape[2], pixel_values.shape[3], pixel_values.shape[4] instead of *pixel_values.shape[2:] + pixel_values = pixel_values.view(batch_size * num_images, pixel_values.shape[2], pixel_values.shape[3], pixel_values.shape[4]) + + # Remove padding images - padding images are full 0. + nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image + pixel_values = pixel_values[real_images_inds].contiguous() + + # Handle the vision attention mask + if pixel_attention_mask is None: + pixel_attention_mask = torch.ones( + size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)), + dtype=torch.bool, + device=pixel_values.device, + ) + else: + # Remove padding images from the mask + # TRACING: Use pixel_attention_mask.shape[2], pixel_attention_mask.shape[3] instead of *pixel_attention_mask.shape[2:] + pixel_attention_mask = pixel_attention_mask.view( + batch_size * num_images, pixel_attention_mask.shape[2], pixel_attention_mask.shape[3] + ) + pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous() + + patch_size = self.config.vision_config.patch_size + patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size) + patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size) + patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ).last_hidden_state + + # Modality projection & resampling + image_hidden_states = self.connector(image_hidden_states) + + elif image_hidden_states is not None: + image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device) + + if past_seen_tokens == 0 and inputs_embeds is not None and image_hidden_states is not None: + # When we generate, we don't want to replace the potential image_token_id that we generated by images + # that simply don't exist + inputs_embeds = self.inputs_merger( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + image_hidden_states=image_hidden_states, + ) + + outputs = self.text_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + 
use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return tuple(v for v in [*outputs, image_hidden_states] if v is not None) + + return Idefics3BaseModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_hidden_states, + ) + + +class Idefics3ForConditionalGeneration(Idefics3ForConditionalGeneration): + # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.__init__ with Idefics2->Idefics3 + def __init__(self, config): + super().__init__(config) + # TRACING: Use custom Idefics3Model + self.model = Idefics3Model(config) + self.image_token_id = self.config.image_token_id + + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.vocab_size = config.text_config.vocab_size + + # Initialize weights and apply final processing + self.post_init() From 4ffff657730bb392b20d4e318ea50b7c4f9a89ae Mon Sep 17 00:00:00 2001 From: leon-seidel <83984854+leon-seidel@users.noreply.github.com> Date: Fri, 24 Jan 2025 16:32:02 +0100 Subject: [PATCH 2/4] Update src/llmcompressor/transformers/tracing/idefics3.py Co-authored-by: Kyle Sayers --- src/llmcompressor/transformers/tracing/idefics3.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/llmcompressor/transformers/tracing/idefics3.py b/src/llmcompressor/transformers/tracing/idefics3.py index 29add82ea..8c61ba45e 100644 --- a/src/llmcompressor/transformers/tracing/idefics3.py +++ b/src/llmcompressor/transformers/tracing/idefics3.py @@ -1,3 +1,4 @@ +# flake8: noqa # coding=utf-8 # Copyright 2024 the HuggingFace Inc. team. All rights reserved. # From c174c25b10ee9489bf6cc7a33365efd402393892 Mon Sep 17 00:00:00 2001 From: Leon Seidel Date: Sat, 25 Jan 2025 00:10:52 +0100 Subject: [PATCH 3/4] Style and Quality Signed-off-by: Leon Seidel --- .../multimodal_vision/idefics3_example.py | 21 ++++++++++++------- .../transformers/tracing/__init__.py | 2 +- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/examples/multimodal_vision/idefics3_example.py b/examples/multimodal_vision/idefics3_example.py index c23b8577d..10402ffb9 100644 --- a/examples/multimodal_vision/idefics3_example.py +++ b/examples/multimodal_vision/idefics3_example.py @@ -1,14 +1,15 @@ import requests -from PIL import Image -from transformers import AutoProcessor import torch from datasets import load_dataset +from PIL import Image +from transformers import AutoProcessor + from llmcompressor.modifiers.quantization import GPTQModifier from llmcompressor.transformers import oneshot from llmcompressor.transformers.tracing import TraceableIdefics3ForConditionalGeneration # Load model. -model_id = "HuggingFaceM4/Idefics3-8B-Llama3" # or "HuggingFaceTB/SmolVLM-Instruct" +model_id = "HuggingFaceM4/Idefics3-8B-Llama3" # or "HuggingFaceTB/SmolVLM-Instruct" model = TraceableIdefics3ForConditionalGeneration.from_pretrained( model_id, device_map="auto", torch_dtype="auto" ) @@ -18,13 +19,15 @@ DATASET_ID = "lmms-lab/flickr30k" DATASET_SPLIT = "test[:512]" NUM_CALIBRATION_SAMPLES = 512 -MAX_SEQUENCE_LENGTH = 4096 # Seems to be required here +MAX_SEQUENCE_LENGTH = 4096 # Seems to be required here + # Define a oneshot data collator for multimodal inputs. 
def data_collator(batch): assert len(batch) == 1 return {key: torch.tensor(value) for key, value in batch[0].items()} + # Recipe recipe = [ GPTQModifier( @@ -39,6 +42,7 @@ def data_collator(batch): ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) + # Apply chat template def preprocess(example): messages = [ @@ -47,9 +51,9 @@ def preprocess(example): "content": [ {"type": "text", "text": "What does the image show?"}, {"type": "image"}, - ] - } - ] + ], + } + ] return { "text": processor.apply_chat_template( messages, @@ -58,8 +62,10 @@ def preprocess(example): "images": example["image"], } + ds = ds.map(preprocess) + # Tokenize inputs. def tokenize(sample): return processor( @@ -70,6 +76,7 @@ def tokenize(sample): truncation=True, ) + # long data lengths produced by the phi3_vision processor # can lead to integer overflows when mapping, avoid with writer_batch_size ds = ds.map(tokenize, writer_batch_size=1, remove_columns=ds.column_names) diff --git a/src/llmcompressor/transformers/tracing/__init__.py b/src/llmcompressor/transformers/tracing/__init__.py index 1e4a341eb..d8df42c93 100644 --- a/src/llmcompressor/transformers/tracing/__init__.py +++ b/src/llmcompressor/transformers/tracing/__init__.py @@ -16,4 +16,4 @@ "TraceableMllamaForConditionalGeneration", "TraceableQwen2VLForConditionalGeneration", "TraceableIdefics3ForConditionalGeneration" -] \ No newline at end of file +] From ebc6c0d24306aa3502c83d2f5dee12d09d8b5a48 Mon Sep 17 00:00:00 2001 From: leon-seidel <83984854+leon-seidel@users.noreply.github.com> Date: Sat, 25 Jan 2025 18:46:40 +0100 Subject: [PATCH 4/4] Remove reference to phi3_vision --- examples/multimodal_vision/idefics3_example.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/multimodal_vision/idefics3_example.py b/examples/multimodal_vision/idefics3_example.py index 10402ffb9..2a3934d15 100644 --- a/examples/multimodal_vision/idefics3_example.py +++ b/examples/multimodal_vision/idefics3_example.py @@ -77,8 +77,7 @@ def tokenize(sample): ) -# long data lengths produced by the phi3_vision processor -# can lead to integer overflows when mapping, avoid with writer_batch_size +# avoid errors with writer_batch_size ds = ds.map(tokenize, writer_batch_size=1, remove_columns=ds.column_names) # Perform oneshot
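
For reference, after the example script finishes, the compressed checkpoint lands in SAVE_DIR ("Idefics3-8B-Llama3-W4A16-G128" for the default model_id). The snippet below is only a usage sketch, not part of the patches: it assumes a vLLM build that supports the Idefics3/SmolVLM architecture and compressed-tensors W4A16 checkpoints, and reuses the 4096 max sequence length from the example.

    from vllm import LLM, SamplingParams

    # Load the compressed checkpoint written by model.save_pretrained(SAVE_DIR)
    # in the example script above (directory name is an assumption based on
    # the default model_id).
    llm = LLM(model="Idefics3-8B-Llama3-W4A16-G128", max_model_len=4096)

    # Text-only smoke test; multimodal prompts additionally require image
    # inputs formatted for the Idefics3 chat template.
    outputs = llm.generate(
        "Describe a photo of a dog.", SamplingParams(max_tokens=64)
    )
    print(outputs[0].outputs[0].text)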