diff --git a/examples/transformers/internvl/generate.py b/examples/transformers/internvl/generate.py new file mode 100644 index 0000000000..184166d27d --- /dev/null +++ b/examples/transformers/internvl/generate.py @@ -0,0 +1,65 @@ +import time + +from PIL import Image +from transformers import GotOcr2ImageProcessor, InternVLProcessor + +import mindspore as ms + +from mindone.transformers import InternVLForConditionalGeneration + +MODEL_HUB = "OpenGVLab/InternVL3-1B-hf" +image = "demo.jpeg" + +# Load processor +start = time.time() +processor = InternVLProcessor.from_pretrained(MODEL_HUB) +# GotOcr2ImageProcessorFast does not support return_tensors="np", use GotOcr2ImageProcessor instead +image_processor = GotOcr2ImageProcessor.from_pretrained(MODEL_HUB) +processor.image_processor = image_processor +print(f"Loaded InternVLProcessor in {time.time()-start:.4f}s") + +# Load model with bfloat16 and eager attention +start = time.time() +model = InternVLForConditionalGeneration.from_pretrained( + MODEL_HUB, + mindspore_dtype=ms.bfloat16, + attn_implementation="eager", +) +print(f"Loaded model in {time.time()-start:.4f}s") + +# load image +image = Image.open(image) + +messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "image": image, + }, + {"type": "text", "text": "Describe this image."}, + ], + } +] +prompt = processor.apply_chat_template(messages, add_generation_prompt=True) + +# Tokenize + encode +inputs = processor(text=prompt, images=[image], return_tensors="np") + +for k, v in inputs.items(): + tensor = ms.Tensor(v) + if tensor.dtype == ms.int64: + tensor = tensor.astype(ms.int32) + else: + tensor = tensor.astype(model.dtype) + inputs[k] = tensor + +# Generate +start = time.time() +generated_ids = model.generate(**inputs, max_new_tokens=500) +print(f"Inference in {time.time()-start:.4f}s") + +# Decode +texts = processor.batch_decode(generated_ids, skip_special_tokens=True) +print(texts) diff --git a/mindone/transformers/__init__.py b/mindone/transformers/__init__.py index dbe07851b9..1a35c3b207 100644 --- a/mindone/transformers/__init__.py +++ b/mindone/transformers/__init__.py @@ -693,6 +693,15 @@ Glm4PreTrainedModel, ) +if version.parse(transformers.__version__) >= version.parse("4.52.0"): + from .models.internvl import ( + InternVLForConditionalGeneration, + InternVLModel, + InternVLPreTrainedModel, + InternVLVisionModel, + InternVLVisionPreTrainedModel, + ) + if version.parse(transformers.__version__) >= version.parse("4.53.0"): from .models.glm4v import ( Glm4vForConditionalGeneration, diff --git a/mindone/transformers/modeling_utils.py b/mindone/transformers/modeling_utils.py index 8731ed7a42..38b83aa0f2 100644 --- a/mindone/transformers/modeling_utils.py +++ b/mindone/transformers/modeling_utils.py @@ -651,6 +651,8 @@ class PreTrainedModel(nn.Cell, ModuleUtilsMixin, GenerationMixin, PushToHubMixin main_input_name = "input_ids" model_tags = None + _checkpoint_conversion_mapping = {} + _auto_class = None _no_split_modules = None _skip_keys_device_placement = None @@ -1920,6 +1922,7 @@ def from_pretrained( use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False) adapter_kwargs = kwargs.pop("adapter_kwargs", {}) adapter_name = kwargs.pop("adapter_name", "default") + key_mapping = kwargs.pop("key_mapping", None) if use_auth_token is not None: warnings.warn( @@ -2391,6 +2394,11 @@ def from_pretrained( sharded_metadata=sharded_metadata, dtype=mindspore_dtype, keep_in_fp32_modules=keep_in_fp32_modules, + key_mapping=( + key_mapping + if key_mapping is not 
None + else (getattr(cls, "_checkpoint_conversion_mapping", None) or None) + ), ) if _adapter_model_path is not None: @@ -2442,6 +2450,62 @@ def from_pretrained( return model + @staticmethod + def _fix_state_dict_key_on_load(key: str) -> tuple[str, bool]: + return key, False + + def _get_key_renaming_mapping( # NEW (HF parity) + self, + checkpoint_keys: list[str], + key_mapping: dict[str, str] | None = None, + loading_base_model_from_task_state_dict: bool = False, + loading_task_model_from_base_state_dict: bool = False, + ) -> dict[str, str]: + prefix = self.base_model_prefix + _prefix = f"{prefix}." + + renamed_keys_for_log: dict[str, tuple[str, str]] = {} + key_renaming_mapping: dict[str, str] = {} + + for key in checkpoint_keys: + new_key, has_changed = self._fix_state_dict_key_on_load(key) + + # 2) optional regex key mapping + if key_mapping is not None: + for pattern, replacement in key_mapping.items(): + updated_key, n_replace = re.subn(pattern, replacement, new_key) + if n_replace > 0: + has_changed = True + new_key = updated_key + + if loading_task_model_from_base_state_dict: + new_key = ".".join([prefix, new_key]) + elif loading_base_model_from_task_state_dict: + if not new_key.startswith(_prefix): + continue + new_key = new_key[len(_prefix) :] + + key_renaming_mapping[key] = new_key + + if has_changed: + if key.endswith("LayerNorm.gamma"): + renamed_keys_for_log["LayerNorm.gamma"] = (key, new_key) + elif key.endswith("LayerNorm.beta"): + renamed_keys_for_log["LayerNorm.beta"] = (key, new_key) + + if renamed_keys_for_log: + msg = ( + f"A pretrained model of type `{self.__class__.__name__}` " + "contains parameters that have been renamed internally (a few are listed):\n" + ) + for old_key, new_key in renamed_keys_for_log.values(): + msg += f"* `{old_key}` -> `{new_key}`\n" + # optional: encourage upstream PRs as HF does + msg += "If you loaded from the Hub, consider submitting a PR to adjust these weights." 
+ logger.info(msg) + + return key_renaming_mapping + @classmethod def _load_pretrained_model( cls, @@ -2454,6 +2518,7 @@ def _load_pretrained_model( sharded_metadata=None, dtype=None, keep_in_fp32_modules=None, + key_mapping: dict[str, str] | None = None, ): model.tie_weights() @@ -2470,6 +2535,17 @@ def _load_pretrained_model( has_prefix_module = False expects_prefix_module = False + loading_task_model_from_base_state_dict = (not has_prefix_module) and expects_prefix_module + loading_base_model_from_task_state_dict = has_prefix_module and (not expects_prefix_module) + + key_renaming_mapping = model._get_key_renaming_mapping( + original_loaded_keys, + key_mapping=key_mapping, + loading_base_model_from_task_state_dict=loading_base_model_from_task_state_dict, + loading_task_model_from_base_state_dict=loading_task_model_from_base_state_dict, + ) + loaded_keys = list(key_renaming_mapping.values()) + # Mapping loaded_keys from pt to ms pt2ms_mappings = _get_pt2ms_mappings(model) loaded_keys = _get_pt2ms_mapped_k(pt2ms_mappings, has_prefix_module, expects_prefix_module, loaded_keys, prefix) @@ -2563,6 +2639,9 @@ def _find_mismatched_keys( # Whole checkpoint state_dict = _convert_state_dict(model, state_dict, prefix) + if key_renaming_mapping: + state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping} + mismatched_keys = _find_mismatched_keys( state_dict, model_state_dict, @@ -2590,6 +2669,11 @@ def _find_mismatched_keys( state_dict = load_state_dict(shard_file) state_dict = _convert_state_dict(model, state_dict, prefix) + if key_renaming_mapping: + state_dict = { + key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping + } + # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not # matching the weights in the model. mismatched_keys += _find_mismatched_keys( diff --git a/mindone/transformers/models/__init__.py b/mindone/transformers/models/__init__.py index cf5d62cda8..185a9fbcab 100644 --- a/mindone/transformers/models/__init__.py +++ b/mindone/transformers/models/__init__.py @@ -106,5 +106,8 @@ if version.parse(transformers.__version__) >= version.parse("4.51.3"): from . import glm4 +if version.parse(transformers.__version__) >= version.parse("4.52.0"): + from . import internvl + if version.parse(transformers.__version__) >= version.parse("4.53.0"): from . 
import glm4v, minimax, vjepa2 diff --git a/mindone/transformers/models/auto/configuration_auto.py b/mindone/transformers/models/auto/configuration_auto.py index 547533e126..93382f89ae 100644 --- a/mindone/transformers/models/auto/configuration_auto.py +++ b/mindone/transformers/models/auto/configuration_auto.py @@ -298,6 +298,13 @@ CONFIG_MAPPING_NAMES.update({"glm4": "Glm4Config"}) MODEL_NAMES_MAPPING.update({"glm4": "glm4"}) +if version.parse(transformers.__version__) >= version.parse("4.52.0"): + CONFIG_MAPPING_NAMES.update({"internvl": "InternVLConfig"}) + CONFIG_MAPPING_NAMES.update({"internvl_vision": "InternVLVisionConfig"}) + MODEL_NAMES_MAPPING.update({"internvl": "InternVLModel"}) # TODO: InternVL + MODEL_NAMES_MAPPING.update({"internvl_vision": "InternVLVisionModel"}) # TODO: InternVLVision + SPECIAL_MODEL_TYPE_TO_MODULE_NAME.update({"internvl_vision": "internvl"}) + if version.parse(transformers.__version__) >= version.parse("4.53.0"): CONFIG_MAPPING_NAMES.update({"minimax": "MiniMaxConfig", "vjepa2": "VJEPA2Model"}) MODEL_NAMES_MAPPING.update({"minimax": "MiniMax", "vjepa2": "VJEPA2Model"}) diff --git a/mindone/transformers/models/auto/modeling_auto.py b/mindone/transformers/models/auto/modeling_auto.py index e3adfbad03..692136f45a 100644 --- a/mindone/transformers/models/auto/modeling_auto.py +++ b/mindone/transformers/models/auto/modeling_auto.py @@ -680,6 +680,11 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.update({"glm4": "Glm4ForSequenceClassification"}) MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES.update({"glm4": "Glm4ForTokenClassification"}) +if version.parse(transformers.__version__) >= version.parse("4.52.0"): + MODEL_MAPPING_NAMES.update({"internvl": "InternVLModel"}) + MODEL_MAPPING_NAMES.update({"internvl_vision": "InternVLVisionModel"}) + MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.update({"internvl": "InternVLForConditionalGeneration"}) + if version.parse(transformers.__version__) >= version.parse("4.53.0"): MODEL_MAPPING_NAMES.update({"minimax": "MiniMaxModel", "vjepa2": "VJEPA2Model"}) MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.update({"minimax": "MiniMaxForCausalLM"}) diff --git a/mindone/transformers/models/internvl/__init__.py b/mindone/transformers/models/internvl/__init__.py new file mode 100644 index 0000000000..de640c080b --- /dev/null +++ b/mindone/transformers/models/internvl/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .modeling_internvl import ( + InternVLForConditionalGeneration, + InternVLModel, + InternVLPreTrainedModel, + InternVLVisionModel, + InternVLVisionPreTrainedModel, +) diff --git a/mindone/transformers/models/internvl/modeling_internvl.py b/mindone/transformers/models/internvl/modeling_internvl.py new file mode 100644 index 0000000000..c1bae816a3 --- /dev/null +++ b/mindone/transformers/models/internvl/modeling_internvl.py @@ -0,0 +1,1016 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/internvl/modular_internvl.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_internvl.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import collections.abc +from dataclasses import dataclass +from typing import Callable, Optional, Union + +from transformers import InternVLConfig, InternVLVisionConfig +from transformers.utils import ModelOutput, auto_docstring, can_return_tuple + +import mindspore as ms +import mindspore.mint.nn.functional as F +from mindspore import mint, nn + +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, MSPreTrainedModel +from ...processing_utils import Unpack +from ..qwen2 import Qwen2Model + + +class InternVLVisionRMSNorm(nn.Cell): + def __init__(self, hidden_size, eps=1e-6): + """ + InternVLVisionRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = ms.Parameter(mint.ones(hidden_size)) + self.variance_epsilon = eps + + def construct(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(ms.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * mint.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +def eager_attention_forward( + module: nn.Cell, + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, + attention_mask: Optional[ms.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = key + value_states = value + + attn_weights = mint.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # No upcasting of the attention weights to float32 in this implementation + attn_weights = 
F.softmax(attn_weights, dim=-1) + attn_weights = F.dropout(attn_weights, p=dropout, training=module.training) + attn_output = mint.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class InternVLVisionAttention(nn.Cell): + """Attention Class for InternVL Vision Encoder""" + + def __init__(self, config: InternVLVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + proj_dropout = config.projection_dropout + qk_norm = config.use_qk_norm + + # Needed for flash attention + self.is_causal = False + + self.q_proj = mint.nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = mint.nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = mint.nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias) + self.projection_layer = mint.nn.Linear(self.embed_dim, self.embed_dim) + self.projection_dropout = mint.nn.Dropout(proj_dropout) if proj_dropout > 0 else mint.nn.Identity() + + self.q_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else mint.nn.Identity() + self.k_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else mint.nn.Identity() + + def construct( + self, + hidden_states: ms.Tensor, + attention_mask: Optional[ms.Tensor] = None, + output_attentions: Optional[ms.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ): + batch_size, seq_len, _ = hidden_states.shape + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + query_states = query_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scale, + is_causal=False, + **kwargs, + ) + attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim) + + output = self.projection_layer(attn_output) + output = self.projection_dropout(output) + + outputs = (output, attn_weights) if output_attentions else (output, None) + return outputs + + +@auto_docstring +class InternVLVisionPreTrainedModel(MSPreTrainedModel): + config_class = InternVLVisionConfig + base_model_prefix = "internvl_vision" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["InternVLVisionLayer"] + _supports_sdpa = True + _supports_flash_attn = True + _supports_flex_attn = True + _supports_attention_backend = 
True + + def _init_weights(self, module): + """Initialize the weights""" + super()._init_weights(module) + if isinstance(module, InternVLVisionEmbeddings): + module.cls_token.data.zero_() + if module.mask_token is not None: + module.mask_token.data.zero_() + if module.position_embeddings is not None: + module.position_embeddings.data.zero_() + elif isinstance(module, InternVLVisionLayer): + module.lambda_1.data.fill_(self.config.layer_scale_init_value) + module.lambda_2.data.fill_(self.config.layer_scale_init_value) + + +@dataclass +@auto_docstring( + custom_intro=""" + Class for outputs of [`InternVLVisionModel`]. + """ +) +class InternVLVisionModelOutputWithPooling(BaseModelOutputWithPooling): + r""" + pooler_output (`ms.Tensor` of shape `(batch_size, hidden_size)`): + Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if + *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token + will be returned. + """ + + +class InternVLVisionPatchEmbeddings(nn.Cell): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, tuple(config.patch_size) + num_channels, hidden_size = config.num_channels, config.hidden_size + + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + self.patch_shape = patch_shape + + self.projection = mint.nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def construct(self, pixel_values: ms.Tensor) -> ms.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + + embeddings = self.projection(pixel_values) + patch_height, patch_width = embeddings.shape[2], embeddings.shape[3] + embeddings = embeddings.flatten(2).transpose(1, 2) + + return embeddings, (patch_height, patch_width) + + +# Based on timm implementation, which can be found here: +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +class InternVLVisionEmbeddings(nn.Cell): + """ + Construct the CLS token, position and patch embeddings. Optionally, also the mask token. 
+ + """ + + def __init__(self, config: InternVLVisionConfig) -> None: + super().__init__() + self.cls_token = ms.Parameter(mint.zeros((1, 1, config.hidden_size))) + if config.use_mask_token: + self.mask_token = ms.Parameter(mint.zeros((1, 1, config.hidden_size))) + else: + self.mask_token = None + self.patch_embeddings = InternVLVisionPatchEmbeddings(config) + self.patch_size = config.patch_size + self.image_size = ( + config.image_size + if isinstance(config.image_size, collections.abc.Iterable) + else (config.image_size, config.image_size) + ) + num_patches = self.patch_embeddings.num_patches + if config.use_absolute_position_embeddings: + self.position_embeddings = ms.Parameter(mint.zeros((1, num_patches + 1, config.hidden_size))) + else: + self.position_embeddings = None + self.dropout = mint.nn.Dropout(config.hidden_dropout_prob) + + def interpolate_pos_encoding(self, embeddings: ms.Tensor, height: int, width: int) -> ms.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. + + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if num_patches == num_positions and height == width: + return self.position_embeddings + + class_pos_embed = self.position_embeddings[:, :1] + patch_pos_embed = self.position_embeddings[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size[0] + new_width = width // self.patch_size[1] + + sqrt_num_positions = int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = F.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return mint.cat((class_pos_embed, patch_pos_embed), dim=1) + + def construct( + self, + pixel_values: ms.Tensor, + bool_masked_pos: Optional[ms.Tensor] = None, + ) -> ms.Tensor: + _, _, height, width = pixel_values.shape + embeddings, (patch_height, patch_width) = self.patch_embeddings(pixel_values) + batch_size, seq_len, _ = embeddings.shape + + if bool_masked_pos is not None: + mask_tokens = self.mask_token.expand((batch_size, seq_len, -1)) + # replace the masked visual tokens by mask_tokens + w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1 - w) + mask_tokens * w + + cls_tokens = self.cls_token.expand((batch_size, -1, -1)) + embeddings = mint.cat((cls_tokens, embeddings), dim=1) + + if self.position_embeddings is not None: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + + embeddings = self.dropout(embeddings) + + return embeddings, (patch_height, patch_width) + + +class InternVLVisionMLP(nn.Cell): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = mint.nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = 
mint.nn.Linear(config.intermediate_size, config.hidden_size) + + def construct(self, hidden_states: ms.Tensor) -> ms.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +NORM2FN = {"layer_norm": mint.nn.LayerNorm, "rms_norm": InternVLVisionRMSNorm} + + +class InternVLVisionLayer(nn.Cell): + """This corresponds to the Block class in the timm implementation.""" + + def __init__(self, config: InternVLVisionConfig) -> None: + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = InternVLVisionAttention(config) + self.mlp = InternVLVisionMLP(config) + # InternVL uses different layernorm implementations for different models + self.layernorm_before = NORM2FN[config.norm_type](config.hidden_size, eps=config.layer_norm_eps) + self.layernorm_after = NORM2FN[config.norm_type](config.hidden_size, eps=config.layer_norm_eps) + + init_values = config.layer_scale_init_value + self.lambda_1 = ms.Parameter(init_values * mint.ones(config.hidden_size), requires_grad=True) + self.lambda_2 = ms.Parameter(init_values * mint.ones(config.hidden_size), requires_grad=True) + self.dropout = mint.nn.Dropout(config.hidden_dropout_prob) + + def construct( + self, + hidden_states: ms.Tensor, + output_attentions: bool = False, + ) -> Union[tuple[ms.Tensor], tuple[ms.Tensor, ms.Tensor]]: + attention_output, attention_weights = self.attention( + self.layernorm_before(hidden_states), # in InternVLVision, layernorm is applied before self-attention + output_attentions=output_attentions, + ) + + attention_output = self.lambda_1 * attention_output + + # first residual connection + hidden_states = attention_output + hidden_states + + # in InternVLVision, layernorm is also applied after self-attention + layer_output = self.layernorm_after(hidden_states) + + layer_output = self.mlp(layer_output) + layer_output = self.dropout(layer_output) + + if self.lambda_2 is not None: + layer_output = self.lambda_2 * layer_output + + # second residual connection + layer_output = layer_output + hidden_states + + return layer_output, attention_weights + + +class InternVLVisionEncoder(nn.Cell): + def __init__(self, config: InternVLVisionConfig) -> None: + super().__init__() + self.config = config + self.layer = nn.CellList([InternVLVisionLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def construct( + self, + hidden_states: ms.Tensor, + output_attentions: bool = False, + output_hidden_states: bool = False, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module(hidden_states, output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +@auto_docstring +class InternVLVisionModel(InternVLVisionPreTrainedModel): + def __init__(self, config: InternVLVisionConfig) -> None: + super().__init__(config) + self.config = config + + self.embeddings = 
InternVLVisionEmbeddings(config) + self.encoder = InternVLVisionEncoder(config) + + self.layernorm = ( + mint.nn.Identity() + if config.use_mean_pooling + else mint.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + ) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + @can_return_tuple + @auto_docstring + def construct( + self, + pixel_values: ms.Tensor, + bool_masked_pos: Optional[ms.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> Union[tuple, InternVLVisionModelOutputWithPooling]: + r""" + bool_masked_pos (`ms.Tensor` of shape `(batch_size, num_patches)`, *optional*): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + embedding_output, _ = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) + + encoder_outputs = self.encoder( + embedding_output, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + + return InternVLVisionModelOutputWithPooling( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@auto_docstring +class InternVLPreTrainedModel(MSPreTrainedModel): + config_class = InternVLConfig + base_model_prefix = "" + supports_gradient_checkpointing = True + _skip_keys_device_placement = "past_key_values" + + _supports_flash_attn = True + _supports_sdpa = True + + _can_compile_fullgraph = True + _supports_flex_attn = True + _supports_attention_backend = True + + +class InternVLMultiModalProjector(nn.Cell): + def __init__(self, config: InternVLConfig): + super().__init__() + self.layer_norm = mint.nn.LayerNorm(config.vision_config.hidden_size * int(1 / config.downsample_ratio) ** 2) + self.linear_1 = mint.nn.Linear( + config.vision_config.hidden_size * int(1 / config.downsample_ratio) ** 2, config.text_config.hidden_size + ) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = mint.nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size) + + def construct(self, image_features): + hidden_states = self.layer_norm(image_features) + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for InternVL outputs, with hidden states and attentions. + """ +) +class InternVLModelOutputWithPast(BaseModelOutputWithPast): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`ms.Tensor`, *optional*): + A `ms.Tensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. 
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + image_hidden_states: Optional[ms.Tensor] = None + + +@auto_docstring( + custom_intro=""" + The InternVL model which consists of a vision backbone and a language model, without a language modeling head. + """ +) +class InternVLModel(InternVLPreTrainedModel): + _checkpoint_conversion_mapping = {"language_model.model": "language_model"} + + def __init__(self, config: InternVLConfig): + super().__init__(config) + self.vision_tower = InternVLVisionModel(config.vision_config) + self.multi_modal_projector = InternVLMultiModalProjector(config) + self.language_model = Qwen2Model(config.text_config) + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.language_model = decoder + + def get_decoder(self): + return self.language_model + + def get_image_features( + self, + pixel_values: ms.Tensor, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + **kwargs, + ): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`ms.Tensor]` of shape `(batch_size, channels, height, width)`) + The tensors corresponding to the input images. + vision_feature_layer (`int` or `list[int]`): + Layer index or list of layer indices to extract features from. + Returns: + vision_features (`ms.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`. + """ + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + downsample_ratio = self.config.downsample_ratio + if vision_feature_layer == -1: + vision_features = self.vision_tower(pixel_values=pixel_values).last_hidden_state + else: + vision_features = self.vision_tower(pixel_values=pixel_values).hidden_states[vision_feature_layer] + if vision_feature_select_strategy == "default": + vision_features = vision_features[:, 1:, :] + + # Calculate dimensions based on vision features + channels = vision_features.shape[1] + feature_size = int(channels**0.5) + batch_size = vision_features.shape[0] + + # Reshape tensor to spatial dimensions + vision_features = vision_features.reshape(batch_size, feature_size, feature_size, -1) + + # Apply downsampling using pixel shuffle + vision_features = self.pixel_shuffle(vision_features, scale_factor=downsample_ratio) + + # Reshape tensor to prepare for projection + vision_features = vision_features.reshape(batch_size, -1, vision_features.shape[-1]) + + # Project features through multi-modal projector + vision_features = self.multi_modal_projector(vision_features) + return vision_features + + @auto_docstring + def construct( + self, + input_ids: ms.Tensor = None, + pixel_values: ms.Tensor = None, + attention_mask: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[ms.Tensor] = None, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + use_cache: Optional[bool] = None, + output_attentions: 
Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[ms.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, InternVLModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + ) + + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + ms.Tensor(self.config.image_token_id, dtype=ms.int64) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = (special_image_mask).sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds) + + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + n_image_features = image_features.shape[0] * image_features.shape[1] + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_features = image_features.to(inputs_embeds.dtype) + # masked_scatter does not support bfloat16 + inputs_embeds_dtype = inputs_embeds.dtype + inputs_embeds = inputs_embeds.to(ms.float32) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features.to(ms.float32)) + inputs_embeds = inputs_embeds.to(inputs_embeds_dtype) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + + return InternVLModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + def pixel_shuffle(self, vision_features: ms.Tensor, scale_factor: float = 0.5): + """Perform pixel shuffle downsampling on vision features. + + Args: + vision_features (`ms.Tensor`): + Input tensor of shape (batch_size, width, height, channels). + scale_factor (`float`, *optional*, defaults to `0.5`): + Factor by which to downsample. Default is 0.5, which halves the dimensions. + + Returns: + vision_features (`ms.Tensor`): + Downsampled tensor of shape (batch_size, height*scale_factor, width*scale_factor, channels/(scale_factor^2)). 
+ """ + batch_size, width, height, channels = vision_features.shape + + if height % scale_factor != 0 or width % scale_factor != 0: + raise ValueError("Height and width must be divisible by scale_factor for proper downsampling.") + + # Reshape to allow downsampling + vision_features = vision_features.view( + batch_size, width, int(height * scale_factor), int(channels / scale_factor) + ) + # Permute dimensions to align downsampled axis correctly + vision_features = vision_features.permute(0, 2, 1, 3).contiguous() + + # Reshape to achieve final downsampled dimensions + vision_features = vision_features.view( + batch_size, int(height * scale_factor), int(width * scale_factor), int(channels / (scale_factor**2)) + ) + + # Swap height and width back for proper orientation + vision_features = vision_features.permute(0, 2, 1, 3).contiguous() + + return vision_features + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for InternVL causal language model (or autoregressive) outputs. + """ +) +class InternVLCausalLMOutputWithPast(ModelOutput): + r""" + loss (`ms.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`ms.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`ms.Tensor`, *optional*): + A `ms.Tensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + loss: Optional[ms.Tensor] = None + logits: Optional[ms.Tensor] = None + past_key_values: Optional[list[ms.Tensor]] = None + hidden_states: Optional[tuple[ms.Tensor]] = None + attentions: Optional[tuple[ms.Tensor]] = None + image_hidden_states: Optional[ms.Tensor] = None + + +@auto_docstring( + custom_intro=""" + The INTERNVL model which consists of a vision backbone and a language model. 
+ """ +) +class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = { + "embed_tokens.weight$": "embed_tokens.embedding_table", + "^language_model.model": "model.language_model", + "^vision_tower": "model.vision_tower", + "^multi_modal_projector": "model.multi_modal_projector", + "^language_model.lm_head": "lm_head", + } + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: InternVLConfig): + super().__init__(config) + self.model = InternVLModel(config) + self.lm_head = mint.nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Cell: + return self.lm_head + + def set_decoder(self, decoder): + self.model.set_decoder(decoder) + + def get_decoder(self): + return self.model.get_decoder() + + def get_image_features( + self, + pixel_values: ms.Tensor, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + **kwargs, + ): + return self.model.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + **kwargs, + ) + + # Make modules available throught conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def vision_tower(self): + return self.model.vision_tower + + @property + def multi_modal_projector(self): + return self.model.multi_modal_projector + + @auto_docstring + def construct( + self, + input_ids: ms.Tensor = None, + pixel_values: ms.Tensor = None, + attention_mask: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[ms.Tensor] = None, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + labels: Optional[ms.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[ms.Tensor] = None, + logits_to_keep: Optional[Union[int, ms.Tensor]] = None, + image_sizes: Optional[ms.Tensor] = None, + ) -> Union[tuple, InternVLCausalLMOutputWithPast]: + r""" + Example: + + ```python + >>> import mindspore as ms + >>> from transformers import InternVLProcessor, GotOcr2ImageProcessor + >>> from mindone.transformers import InternVLForConditionalGeneration + >>> from PIL import Image + >>> + >>> MODEL_HUB = "OpenGVLab/InternVL3-1B-hf" + >>> image_path = "demo.jpeg" + >>> + >>> processor = InternVLProcessor.from_pretrained(MODEL_HUB) + >>> # GotOcr2ImageProcessorFast does not support return_tensors="np"; use GotOcr2ImageProcessor instead + >>> image_processor = GotOcr2ImageProcessor.from_pretrained(MODEL_HUB) + >>> processor.image_processor = image_processor + >>> + >>> model = InternVLForConditionalGeneration.from_pretrained( + ... MODEL_HUB, + ... mindspore_dtype=ms.bfloat16, + ... attn_implementation="eager", + ... ) + >>> + >>> image = Image.open(image_path) + >>> + >>> messages = [ + ... { + ... "role": "user", + ... "content": [ + ... {"type": "image", "image": image}, + ... {"type": "text", "text": "Describe this image."}, + ... ], + ... }, + ... 
] + >>> + >>> prompt = processor.apply_chat_template(messages, add_generation_prompt=True) + >>> + >>> inputs = processor(text=prompt, images=[image], return_tensors="np") + >>> + >>> for k, v in list(inputs.items()): + ... t = ms.Tensor(v) + ... if t.dtype == ms.int64: + ... t = t.astype(ms.int32) + ... else: + ... t = t.astype(model.dtype) + ... inputs[k] = t + >>> start = time.time() + >>> generated_ids = model.generate(**inputs, max_new_tokens=500) + >>> + >>> # Decode + >>> texts = processor.batch_decode(generated_ids, skip_special_tokens=True) + >>> print(texts) + A woman and a dog are sitting on a beach during sunset. + The woman is smiling and giving a treat to the dog, which is wearing a harness. + The dog is sitting patiently, and the ocean waves are visible in the background. + The scene is warm and serene. + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + image_sizes=image_sizes, + ) + + hidden_states = outputs[0] + if logits_to_keep is None: + logits_to_keep = 1 + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + + return InternVLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + attention_mask=None, + cache_position=None, + logits_to_keep=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **kwargs, + ) + + if cache_position[0] == 0: + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model + model_inputs["pixel_values"] = pixel_values + + return model_inputs + + 
+__all__ = [ + "InternVLVisionPreTrainedModel", + "InternVLVisionModel", + "InternVLPreTrainedModel", + "InternVLModel", + "InternVLForConditionalGeneration", +] diff --git a/tests/transformers_tests/models/internvl/__init__.py b/tests/transformers_tests/models/internvl/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/transformers_tests/models/internvl/test_modeling_internvl.py b/tests/transformers_tests/models/internvl/test_modeling_internvl.py new file mode 100644 index 0000000000..55098e5e22 --- /dev/null +++ b/tests/transformers_tests/models/internvl/test_modeling_internvl.py @@ -0,0 +1,236 @@ +import inspect + +import numpy as np +import pytest +import torch +from transformers import InternVLConfig, InternVLVisionConfig, Qwen2Config + +import mindspore as ms + +from tests.modeling_test_utils import ( + MS_DTYPE_MAPPING, + PT_DTYPE_MAPPING, + compute_diffs, + generalized_parse_args, + get_modules, +) +from tests.transformers_tests.models.modeling_common import ids_numpy + +DTYPE_AND_THRESHOLDS = {"fp32": 5e-2, "fp16": 5e-2, "bf16": 5e-2} +MODES = [1] + + +class InternVLModelTester: + def __init__( + self, + batch_size=1, + seq_length=7, + # common + is_training=False, + use_attention_mask=True, + use_cache=False, + output_attentions=False, + # text model + vocab_size=99, + hidden_size=128, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=2, + intermediate_size=32, + hidden_act="silu", + max_position_embeddings=512, + # vision model + image_size=(32, 32), # 32x32 with 16x16 patches -> 2x2 patches -> 4 tokens + patch_size=(16, 16), # removing CLS => 4 -> reshape (2,2) + downsample_ratio=0.5, # pixel_shuffle(0.5) => (2,2) -> (1,1) and channels x4 => exactly 1 image feature vector + # run-time impl + attn_implementation="eager", + torch_dtype="float32", + image_token_id=5, + ): + self.batch_size = batch_size + self.seq_length = seq_length + + self.is_training = is_training + self.use_attention_mask = use_attention_mask + self.use_cache = use_cache + self.output_attentions = output_attentions + + self.vocab_size = vocab_size + + # shared dims + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + + # vision dims + self.image_size = image_size + self.patch_size = patch_size + self.downsample_ratio = downsample_ratio + + # impl & dtype + self.attn_implementation = attn_implementation + self.torch_dtype = torch_dtype + self.image_token_id = image_token_id + + def get_config(self): + text_config = Qwen2Config( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + max_position_embeddings=self.max_position_embeddings, + use_cache=self.use_cache, + attn_implementation=self.attn_implementation, + torch_dtype=self.torch_dtype, + ) + + vision_config = InternVLVisionConfig( + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + num_channels=3, + image_size=list(self.image_size), + patch_size=list(self.patch_size), + attn_implementation=self.attn_implementation, + 
torch_dtype=self.torch_dtype, + ) + + config = InternVLConfig( + use_cache=self.use_cache, + vision_config=vision_config, + text_config=text_config, + image_token_id=self.image_token_id, + attn_implementation=self.attn_implementation, + downsample_ratio=self.downsample_ratio, + torch_dtype=self.torch_dtype, + ) + return config + + def prepare_config_and_inputs(self): + config = self.get_config() + input_ids = ids_numpy([self.batch_size, self.seq_length], self.vocab_size) + + # place exactly one token per sample so it matches the 1 image feature vector produced + # (with config above, image_features per sample is 1) + image_pos = self.seq_length // 2 + input_ids[:, image_pos] = config.image_token_id + + attention_mask = None + if self.use_attention_mask: + attention_mask = ids_numpy([self.batch_size, self.seq_length], vocab_size=2) + + # pixel_values in [-1, 1] (consistency with your Idefics3) + num_channels, height, width = 3, self.image_size[0], self.image_size[1] + pixel_values = ids_numpy([self.batch_size, num_channels, height, width], vocab_size=256) + pixel_values = (pixel_values.astype(np.float32) / 255.0) * 2.0 - 1.0 + + return (config, input_ids, attention_mask, pixel_values) + + +model_tester = InternVLModelTester() +config, input_ids, attention_mask, pixel_values = model_tester.prepare_config_and_inputs() + + +TEST_CASES = [ + [ # text Q&A + "InternVLModel", + "transformers.InternVLModel", + "mindone.transformers.InternVLModel", + (config,), + {}, + (), + { + "input_ids": input_ids, + "attention_mask": attention_mask, + }, + { + "last_hidden_state": 0, # text_model, i.e., Qwen2Model + }, + ], + [ # VQA (multimodal) + "InternVLModel", + "transformers.InternVLModel", + "mindone.transformers.InternVLModel", + (config,), + {}, + (), + { + "input_ids": input_ids, + "attention_mask": attention_mask, + "pixel_values": pixel_values, # (B, C, H, W) for InternVL + }, + { + "last_hidden_state": 0, # Qwen2Model + "image_hidden_states": -1, # Vision Transformer + }, + ], +] + + +@pytest.mark.parametrize( + "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs,outputs_map,dtype,mode", + [case + [dtype] + [mode] for case in TEST_CASES for dtype in DTYPE_AND_THRESHOLDS.keys() for mode in MODES], +) +def test_named_modules( + name, + pt_module, + ms_module, + init_args, + init_kwargs, + inputs_args, + inputs_kwargs, + outputs_map, + dtype, + mode, +): + ms.set_context(mode=mode) + + ( + pt_model, + ms_model, + pt_dtype, + ms_dtype, + ) = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs) + + pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = generalized_parse_args( + pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs + ) + + if "hidden_dtype" in inspect.signature(pt_model.forward).parameters: + pt_inputs_kwargs.update({"hidden_dtype": PT_DTYPE_MAPPING[pt_dtype]}) + ms_inputs_kwargs.update({"hidden_dtype": MS_DTYPE_MAPPING[ms_dtype]}) + + with torch.no_grad(): + pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs) + ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs) + + if outputs_map: + pt_outputs_n = [] + ms_outputs_n = [] + for pt_key, ms_idx in outputs_map.items(): + pt_output = getattr(pt_outputs, pt_key) + ms_output = ms_outputs[ms_idx] + if isinstance(pt_output, (list, tuple)): + pt_outputs_n += list(pt_output) + ms_outputs_n += list(ms_output) + else: + pt_outputs_n.append(pt_output) + ms_outputs_n.append(ms_output) + diffs = compute_diffs(pt_outputs_n, ms_outputs_n) + else: + diffs = compute_diffs(pt_outputs, ms_outputs) 
+
+    THRESHOLD = DTYPE_AND_THRESHOLDS[ms_dtype]
+    assert (np.array(diffs) < THRESHOLD).all(), (
+        f"ms_dtype: {ms_dtype}, pt_dtype: {pt_dtype}, "
+        f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD}"
+    )
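
For reference, a minimal standalone sketch (not part of the patch) of how the regex-based key renaming introduced in `PreTrainedModel._get_key_renaming_mapping` behaves when driven by `InternVLForConditionalGeneration._checkpoint_conversion_mapping`. The mapping dict is copied from the diff above; the sample checkpoint keys are illustrative only and not taken from a real checkpoint.

import re

# Copied from InternVLForConditionalGeneration._checkpoint_conversion_mapping in the diff above.
checkpoint_conversion_mapping = {
    "embed_tokens.weight$": "embed_tokens.embedding_table",
    "^language_model.model": "model.language_model",
    "^vision_tower": "model.vision_tower",
    "^multi_modal_projector": "model.multi_modal_projector",
    "^language_model.lm_head": "lm_head",
}


def rename_key(key: str, key_mapping: dict) -> str:
    # Mirrors the re.subn loop in _get_key_renaming_mapping: each pattern is tried in order
    # and rewrites the key in place whenever it matches.
    for pattern, replacement in key_mapping.items():
        key, _ = re.subn(pattern, replacement, key)
    return key


# Hypothetical checkpoint keys, for illustration only.
for k in [
    "language_model.model.embed_tokens.weight",
    "language_model.lm_head.weight",
    "vision_tower.embeddings.cls_token",
]:
    print(f"{k} -> {rename_key(k, checkpoint_conversion_mapping)}")
# language_model.model.embed_tokens.weight -> model.language_model.embed_tokens.embedding_table
# language_model.lm_head.weight -> lm_head.weight
# vision_tower.embeddings.cls_token -> model.vision_tower.embeddings.cls_token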