-
Couldn't load subscription status.
- Fork 88
feat(transformers): Transformers 4.54 base #1387
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
SamitHuang
merged 99 commits into
mindspore-lab:master
from
wcrzlh:transformers_4.54_base
Oct 27, 2025
Merged
Changes from all commits
Commits
Show all changes
99 commits
Select commit
Hold shift + click to select a range
a8c5c54
upgrade activation_func to transformers v4.54
wcrzlh c05a707
feat(transformers): upgrade attn_mask/rope to 4.54
wcrzlh 84b3ece
feat(transformers): upgrade modeling_layers to 4.54
wcrzlh 6b94c2d
feat(transformers): upgrade cache_utils to 4.54
wcrzlh 2f70121
feat(transformers): upgrade modeling_utils to v4.54
wcrzlh 44ad424
feat(transformers): upgrade generation/utils to v4.54
wcrzlh 17da5c7
feat(transformers): add ernie4.5 for validation
wcrzlh fd769b1
fix get_type_hints problem
wcrzlh 37fe594
fix get_type_hints problem
wcrzlh 0874e77
fix get_type_hints problem
wcrzlh 94fb78b
fix metadata.get keyerror
wcrzlh bf69ef9
fix masking_utils alignment
wcrzlh a1be89c
fix generation/utils logic
wcrzlh b3334ac
fix get_output_embedding override bug
wcrzlh 833419c
fix __init_subclass__ bug
wcrzlh c532a9a
suplement checkpoint_conversion_mapping
wcrzlh 1ac2f72
feat(transformers): upgrade beam search to v4.54
wcrzlh 375e6ab
feat(transformers): upgrade candidate_generator to v4.54
wcrzlh 25033c1
feat(transformers): upgrade logits_process/stopping_criteria to v4.54
wcrzlh 252f4aa
pre-commit
wcrzlh 02834b0
pre-commit
wcrzlh 65e8256
update backbone_utils
wtomin 913cd3c
update generic
wtomin a0c9dc9
remove add_model_info_to_auto_map & update feature_extraction_utils.py
wtomin 2a517d8
remove add_model_info_to_auto_map & update image_processing_base.py
wtomin 8d29722
remove add_model_info_to_auto_map & update processing_utils.py
wtomin 8a90ca6
remove add_model_info_to_auto_map & update video_utils.py
wtomin dcb98ac
tokenization_utils.py update
wtomin 5120977
add_model_info_to_custom_pipelines
wtomin 9a18655
update tokenization_utils_base.py
wtomin 509a308
update image_transforms.py
wtomin 33ed2be
update video_utils.py and image_utils.py
wtomin 0d8142c
update image_utils.py & image_processing_utils_fast.py
wtomin 991a783
update integration sdpa_attention.py
wtomin ccb0897
update mask_utils.py
wtomin 00f2ba3
update modeling_flash_attention_utils.py
wtomin d0b34fb
update modeling_outputs.py
wtomin f75b06a
fix pre-commit errors
wtomin f9ea8ce
fix pre-commit errors
wtomin 03635c5
rebase
wcrzlh 7bed7a1
add modeling_layers.py from cui yushi
wtomin 61b4f5c
fix import in transformers
wtomin 9205ca2
Merge branch 'transformers_4.54_base' into transformer-v4.54.1
wtomin 3e3f452
rm tokenization_utils.py and tokenization_utils_base.py
wtomin 91609b9
resize stacked images one by one
wtomin ffd3377
remove torchvision decoders
wtomin b38bf63
fix get_default_dtype bug
wcrzlh f32b7cb
load module dynamically from mindone/transformers
wtomin 2cb578b
not support FA
wtomin 7ad706d
Merge pull request #2 from wtomin/transformer-v4.54.1
wcrzlh 9457ebc
add video_processing_utils
wcrzlh 32031d0
fix import error/add audio_utils/fix processor bug/attn_implementatio…
wtomin 294d153
fix attn_implementation configuration bug
wcrzlh a44b0f6
Fix attn_implementation
wtomin ba674bc
fix fa bug/key_renaming_mapping bug
wcrzlh 3ab17b0
pre-commit
wcrzlh ee91d87
upgrade modeling_utils/save_pretrained to transformersv4.54
wcrzlh ff82ffb
refactor fa part
wcrzlh 58e07d6
Fix some model's UT
wtomin ab125b4
revert _support_dynamic_input to _support_jit
wcrzlh 226bd0e
fix class name mismatch in generation/utils
wcrzlh d156ca6
fix pa error/delete unused fa part
wcrzlh fe3304b
remove unused part
wcrzlh 934520f
generation/utils ops-->mint
wcrzlh 4aab9fa
copyright/pre-commit
wcrzlh d104c56
fix bugs
wcrzlh ba0a8eb
supplement activation api
wcrzlh 9e36ba8
reformat
wcrzlh 738d9bb
remove losskwargs
wtomin c80e2fd
fix disable_grouping bug in image processing
wcrzlh 10ec00b
fix attn_implementation setting in modeling_utils/from_pretrained
wcrzlh a813cf9
fix attn_implementation setting in modeling_utils/from_pretrained
wcrzlh cdebac0
fix modeling_utils/from_config mindspore_dtype setting, generation/ut…
wcrzlh 7a20fe1
feat(transformers): add qwen3_vl/qwen3_vl_moe model
wcrzlh 4079e6f
fix moe precision bug
wcrzlh c1cde3a
fix qwen3_vl moe memory bugs
wcrzlh 721d0a3
supplement zero3 model weight shard for moe part
wcrzlh e43f3dd
fix qwen3_vl_moe precision bug
wcrzlh 51515b9
fix qwen3_vl_moe precision bug
wcrzlh 25c8110
fix moe part shard bug
wcrzlh 9650f4f
pre-commit
wcrzlh 3771434
reformat
wcrzlh 2d5f9e7
Merge pull request #1310 from wcrzlh/qwen3_vl
vigo999 fed7ffc
fix(transformers): fix typos in qwen3_vl docs
wcrzlh f2b56bf
Merge pull request #1311 from wcrzlh/qwen3_vl
vigo999 3c81df8
feat(transformers): add processor for qwen3_vl (#1326)
wcrzlh 6e6361a
fix(transformers): supplement condition of taking model as processor
wcrzlh e72e032
fix(transformers): reformat generation/utils
wcrzlh dcdec6c
fix(transformers): supplement candidate generator
wcrzlh 47ca032
fix(transformers): supplement logits processor
wcrzlh c6df7fd
feat(transformers): add assisted_generation/dola_generation/contrasiv…
wcrzlh 11ba44c
rebase
wcrzlh 18d35f6
reformat
wcrzlh 8b75291
fix import bug
wcrzlh 9fd221f
fix ut bug
wcrzlh 46639e0
update pyproject.toml
wcrzlh aa0e7b0
pre-commit
wcrzlh b11c421
reformat
wcrzlh bd76d4c
update loss_type
wcrzlh File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,83 @@ | ||
| # Qwen3-VL series | ||
|
|
||
| ## Introduction | ||
| [Qwen3-VL](https://huggingface.co/papers/2502.13923) is a multimodal vision-language model series, encompassing both dense and MoE variants, as well as Instruct and Thinking versions. Building upon its predecessors, Qwen3-VL delivers significant improvements in visual understanding while maintaining strong pure text capabilities. Key architectural advancements include: enhanced MRope with interleaved layout for better spatial-temporal modeling, DeepStack integration to effectively leverage multi-level features from the Vision Transformer (ViT), and improved video understanding through text-based time alignment—evolving from T-RoPE to text timestamp alignment for more precise temporal grounding. These innovations collectively enable Qwen3-VL to achieve superior performance in complex multimodal tasks. | ||
|
|
||
| # Get Started | ||
|
|
||
| ## Requirements: | ||
| | mindspore | ascend driver | firmware | cann tookit/kernel | | ||
| |-----------|----------------|----------------|--------------------| | ||
| | 2.6.0 | 24.1.RC3.b080 | 7.5.T11.0.B088 | 8.1.RC1 | | ||
|
|
||
| ### Installation: | ||
| ``` | ||
| git clone https://github.com/mindspore-lab/mindone.git -b hf-transformers-4.54 | ||
| cd mindone | ||
| pip install -e . | ||
| cd .. | ||
|
|
||
| # compile newest transformers whl because qwen3-vl(transformers v4.57.dev.0) haven't released | ||
| git clone https://github.com/huggingface/transformers.git | ||
| cd transformers | ||
| git reset --hard d0af4269ec260b9c4aeeda24c346a469e44799e1 | ||
| pip install -e . | ||
| cd .. | ||
|
|
||
| cd mindone/examples/transformers/qwen3_vl | ||
| ``` | ||
|
|
||
| ## Quick Start | ||
|
|
||
| Here is a usage example of Qwen3-VL-4B-Instruct. you can use the following command: | ||
|
|
||
| ```bash | ||
| # for Qwen3-VL-4B-Instruct inference | ||
| python generate_qwen3_vl.py | ||
| --model_name "Qwen/Qwen3-VL-4B-Instruct" | ||
| --image "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" | ||
| --prompt "Describe this image." | ||
| ``` | ||
|
|
||
| ```bash | ||
| # for Qwen3-VL-30B-A3B-Instruct inference | ||
| msrun --worker_num=2 --local_worker_num=2 --master_port=8118 \ | ||
| --log_dir=msrun_log --join=True --cluster_time_out=300 \ | ||
| generate_qwen3_vl_moe.py \ | ||
| --model_name "Qwen/Qwen3-VL-30B-A3B-Instruct" \ | ||
| --image "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" \ | ||
| --prompt "Describe this image." \ | ||
| ``` | ||
|
|
||
| Image: | ||
|  | ||
|
|
||
| Prompt: Describe this image. | ||
|
|
||
| Qwen3-VL-4B Outputs: | ||
| ``` | ||
| ['Of course, here is detailed description of the image provided.\n\n | ||
| This is a close-up photograph of a Pallas\'s cat ($Felis$, $manul$), | ||
| an endangered wild feline species native to Central Aisa. | ||
| ... | ||
| **Appearance:** It has a stocky and robust build with short legs | ||
| and a large head relative to its body size. Its fur is thick and dense, | ||
| appearing somewhat fluffy or "matted,", which is characteristic'] | ||
| ``` | ||
|
|
||
| Qwen3-VL-30B Outputs: | ||
| ``` | ||
| ['Of course, here is detailed description of the image provided.\n\n | ||
| This is a dynamic and charming photograph of a Palla's cat (also known as a manul) in a snowy enviroment. | ||
| ... | ||
| "Appearance:" The cat has a very distinctive apperance, characterized by its stocky, low-slung body and exceptionally | ||
| thick, dense fur. This coat is a mix of brownish"] | ||
| ``` | ||
|
|
||
| `model_name` and `image` could be replaced with your local path. Give it a try with various images and prompts🤗🤗. | ||
|
|
||
| ## Inference Speed | ||
| | model name | mindspore version | precision* | cards | attention type | tokens/s | | ||
| |:------------------------------:|:-----------------:|:----------:|:-----:|:--------------:|:----------:| | ||
| | Qwen/Qwen3-VL-4B-Instruct | 2.6.0 | bf16 | 1 | flash_attn | 1.35 | | ||
| | Qwen/Qwen3-VL-30B-A3B-Instruct | 2.6.0 | bf16 | 2 | flash_attn | 0.5 | | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,79 @@ | ||
| import argparse | ||
|
|
||
| import numpy as np | ||
|
|
||
| import mindspore as ms | ||
|
|
||
| from mindone.transformers import AutoProcessor, Qwen3VLForConditionalGeneration | ||
|
|
||
|
|
||
| def generate(args): | ||
| model = Qwen3VLForConditionalGeneration.from_pretrained( | ||
| args.model_name, | ||
| mindspore_dtype=ms.bfloat16, | ||
| attn_implementation=args.attn_implementation, | ||
| ) | ||
|
|
||
| processor = AutoProcessor.from_pretrained( | ||
| args.model_name, | ||
| use_fast=False, | ||
| ) | ||
|
|
||
| messages = [ | ||
| { | ||
| "role": "user", | ||
| "content": [ | ||
| { | ||
| "type": "image", | ||
| "url": args.image, | ||
| }, | ||
| { | ||
| "type": "text", | ||
| "text": args.prompt, | ||
| }, | ||
| ], | ||
| } | ||
| ] | ||
|
|
||
| inputs = processor.apply_chat_template( | ||
| messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="np" | ||
| ) | ||
|
|
||
| # convert input to Tensor | ||
| for key, value in inputs.items(): | ||
| if isinstance(value, np.ndarray): | ||
| inputs[key] = ms.tensor(value) | ||
| elif isinstance(value, list): | ||
| inputs[key] = ms.Tensor(value) | ||
|
|
||
| generated_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False) | ||
| generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)] | ||
| output_text = processor.batch_decode( | ||
| generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False | ||
| ) | ||
| print(output_text) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| parser = argparse.ArgumentParser(description="Qwen3VL demo.") | ||
|
|
||
| parser.add_argument("--prompt", type=str, default="Describe this image.") | ||
| parser.add_argument( | ||
| "--image", | ||
| type=str, | ||
| default="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg", | ||
| ) | ||
| parser.add_argument( | ||
| "--model_name", type=str, default="Qwen/Qwen3-VL-4B-Instruct", help="Path to the pre-trained model." | ||
| ) | ||
| parser.add_argument( | ||
| "--attn_implementation", | ||
| type=str, | ||
| default="flash_attention_2", | ||
| choices=["flash_attention_2", "eager"], | ||
| ) | ||
|
|
||
| # Parse the arguments | ||
| args = parser.parse_args() | ||
|
|
||
| generate(args) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,91 @@ | ||
| import argparse | ||
| from functools import partial | ||
|
|
||
| import numpy as np | ||
|
|
||
| import mindspore as ms | ||
| import mindspore.mint.distributed as dist | ||
| from mindspore.communication import GlobalComm | ||
|
|
||
| from mindone.trainers.zero import prepare_network | ||
| from mindone.transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration | ||
|
|
||
|
|
||
| def generate(args): | ||
| model = Qwen3VLMoeForConditionalGeneration.from_pretrained( | ||
| args.model_name, | ||
| mindspore_dtype=ms.bfloat16, | ||
| attn_implementation=args.attn_implementation, | ||
| ) | ||
|
|
||
| # use zero3 parallel | ||
| shard_fn = partial(prepare_network, zero_stage=3, optimizer_parallel_group=GlobalComm.WORLD_COMM_GROUP) | ||
| model = shard_fn(model) | ||
|
|
||
| processor = AutoProcessor.from_pretrained( | ||
| args.model_name, | ||
| use_fast=False, | ||
| ) | ||
|
|
||
| messages = [ | ||
| { | ||
| "role": "user", | ||
| "content": [ | ||
| { | ||
| "type": "image", | ||
| "url": args.image, | ||
| }, | ||
| { | ||
| "type": "text", | ||
| "text": args.prompt, | ||
| }, | ||
| ], | ||
| } | ||
| ] | ||
|
|
||
| inputs = processor.apply_chat_template( | ||
| messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="np" | ||
| ) | ||
|
|
||
| # convert input to Tensor | ||
| for key, value in inputs.items(): | ||
| if isinstance(value, np.ndarray): | ||
| inputs[key] = ms.tensor(value) | ||
| elif isinstance(value, list): | ||
| inputs[key] = ms.Tensor(value) | ||
|
|
||
| generated_ids = model.generate(**inputs, max_new_tokens=128) | ||
| generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)] | ||
| output_text = processor.batch_decode( | ||
| generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False | ||
| ) | ||
| print(output_text) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| parser = argparse.ArgumentParser(description="Qwen3VLMoE demo.") | ||
|
|
||
| parser.add_argument("--prompt", type=str, default="Describe this image.") | ||
| parser.add_argument( | ||
| "--image", | ||
| type=str, | ||
| default="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg", | ||
| ) | ||
| parser.add_argument( | ||
| "--model_name", type=str, default="Qwen/Qwen3-VL-30B-A3B-Instruct", help="Path to the pre-trained model." | ||
| ) | ||
| parser.add_argument( | ||
| "--attn_implementation", | ||
| type=str, | ||
| default="flash_attention_2", | ||
| choices=["flash_attention_2", "eager"], | ||
| ) | ||
|
|
||
| # Parse the arguments | ||
| args = parser.parse_args() | ||
|
|
||
| # set up card communication | ||
| dist.init_process_group(backend="hccl") | ||
| ms.set_auto_parallel_context(parallel_mode="data_parallel") | ||
|
|
||
| generate(args) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,70 @@ | ||
| from typing import Literal, Optional | ||
|
|
||
| from mindspore import Tensor | ||
| from mindspore import dtype as mstype | ||
| from mindspore import mint, nn | ||
| from mindspore.communication import get_group_size, get_rank | ||
| from mindspore.communication.management import GlobalComm | ||
| from mindspore.context import ParallelMode | ||
| from mindspore.parallel._utils import _get_parallel_mode | ||
|
|
||
| from .param_wrapper import ZeroParamWrapper | ||
|
|
||
|
|
||
| class MoeTextExperts(nn.Cell): | ||
| def __init__( | ||
| self, | ||
| net: nn.Cell, | ||
| zero_stage: Literal[0, 1, 2, 3] = 0, | ||
| optimizer_parallel_group: str = GlobalComm.WORLD_COMM_GROUP, | ||
| cell_type: Optional[mstype.Type] = None, | ||
| ): | ||
| super().__init__(auto_prefix=False) | ||
| self.net = net | ||
| self.set_param_wrapper(zero_stage, optimizer_parallel_group, cell_type) | ||
|
|
||
| def set_param_wrapper(self, zero_stage, optimizer_parallel_group, cell_type=None): | ||
| self.param_wrapper_gate_up_proj = nn.Identity() | ||
| self.param_wrapper_down_proj = nn.Identity() | ||
| if zero_stage == 3: | ||
| # Init parallel settings | ||
| is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL | ||
| op_group_size = get_group_size(optimizer_parallel_group) if is_parallel else 1 | ||
| op_rank_id = get_rank(optimizer_parallel_group) if is_parallel else 0 | ||
| self.op_group_size = op_group_size | ||
| self.op_rank_id = op_rank_id | ||
| self.param_wrapper_gate_up_proj = ZeroParamWrapper( | ||
| self.net.gate_up_proj, zero_stage, optimizer_parallel_group, cell_type | ||
| ) | ||
| if self.param_wrapper_gate_up_proj.need_rewrite: | ||
| self.net.gate_up_proj.assign_value( | ||
| Tensor.from_numpy( | ||
| self.net.gate_up_proj.numpy().reshape(op_group_size, -1, *self.net.gate_up_proj.shape[1:])[ | ||
| op_rank_id | ||
| ] | ||
| ) | ||
| ) | ||
| self.param_wrapper_down_proj = ZeroParamWrapper( | ||
| self.net.down_proj, zero_stage, optimizer_parallel_group, cell_type | ||
| ) | ||
| if self.param_wrapper_down_proj.need_rewrite: | ||
| self.net.down_proj.assign_value( | ||
| Tensor.from_numpy( | ||
| self.net.down_proj.numpy().reshape(op_group_size, -1, *self.net.down_proj.shape[1:])[op_rank_id] | ||
| ) | ||
| ) | ||
|
|
||
| def construct(self, hidden_states, routing_weights, router_indices): | ||
| batch_size = hidden_states.shape[0] | ||
| hidden_states = hidden_states.reshape(-1, self.net.hidden_size) # (num_tokens, hidden_size) | ||
|
|
||
| hidden_states = hidden_states.repeat(self.net.num_experts, 1) | ||
| hidden_states = hidden_states.view(self.net.num_experts, -1, self.net.hidden_size) | ||
|
|
||
| gate_up = mint.bmm(hidden_states, self.param_wrapper_gate_up_proj(self.net.gate_up_proj)) | ||
| gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors | ||
| next_states = mint.bmm((up * self.net.act_fn(gate)), self.param_wrapper_down_proj(self.net.down_proj)) | ||
| next_states = next_states.reshape(self.net.num_experts, batch_size, -1, self.net.hidden_size) | ||
| next_states = next_states * routing_weights.swapaxes(0, 1).view(self.net.num_experts, batch_size, -1)[..., None] | ||
| next_states = next_states.sum(dim=0) | ||
| return next_states |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.